├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── examples └── Getting started with ifBO.ipynb ├── ifbo ├── __init__.py ├── bar_distribution.py ├── decoders.py ├── download.py ├── encoders.py ├── initializers.py ├── layer.py ├── positional_encodings.py ├── priors │ ├── __init__.py │ ├── ftpfn_prior.py │ ├── output_sorted.npy │ ├── prior.py │ ├── prior_bag.py │ └── utils.py ├── surrogate.py ├── train.py ├── transformer.py ├── utils.py └── version.py ├── pyproject.toml ├── requirements.txt └── tests └── test_surrogate.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.10", "3.11", "3.12", "3.13"] 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install flake8 pytest 31 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 32 | pip install -e . 33 | - name: Lint with flake8 34 | run: | 35 | # stop the build if there are Python syntax errors or undefined names 36 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 37 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 38 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 39 | - name: Test with pytest 40 | run: | 41 | pytest 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .aider* 2 | __pycache__/ 3 | */__pycache__/ 4 | */*/__pycache__/ 5 | ifBO.egg-info/ 6 | 7 | 8 | *data*/ 9 | *build*/ 10 | *dist*/ 11 | *output*/ 12 | *result*/ 13 | .model/ 14 | 15 | .DS_Store 16 | */.DS_Store 17 | 18 | *.ipynb 19 | *.json 20 | *.csv 21 | *.pkl 22 | *.pt 23 | *.pth 24 | *.sh 25 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/astral-sh/ruff-pre-commit 6 | rev: v0.3.5 7 | hooks: 8 | - id: ruff 9 | args: [ --fix, --exit-non-zero-on-fix ] 10 | types_or: [python, pyi] 11 | - id: ruff-format 12 | types_or: [python, pyi] 13 | - repo: https://github.com/pre-commit/mirrors-mypy 14 | rev: v1.9.0 15 | hooks: 16 | - id: mypy 17 | additional_dependencies: [ 18 | "types-requests", 19 | ] 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 AutoML-Freiburg-Hannover 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and 
this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![arXiv](https://img.shields.io/badge/arXiv-2404.16795-b31b1b.svg)](https://arxiv.org/abs/2404.16795) 2 | 3 | # `ifBO`: In-context Freeze-Thaw Bayesian Optimization for Hyperparameter Optimization 4 | 5 | [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-md.svg)](https://huggingface.co/spaces/herilalaina/ifbo) 6 | 7 | This repository contains the official code for our [ICML 2024 paper](https://openreview.net/forum?id=VyoY3Wh9Wd). `ifBO` is an efficient Bayesian Optimization algorithm that dynamically selects and incrementally evaluates candidates during the optimization process. It uses a model called the `Freeze-Thaw surrogate (FT-PFN)` to predict the performance of candidate configurations as more resources are allocated. The `main` branch includes the necessary API to use `FT-PFN`. Refer to the following sections: 8 | - [Surrogate API](#surrogate-api): to learn how to initialize and use the surrogate model. 9 | - [Bayesian Optimization with ifBO](#bayesian-optimization-with-ifbo): to understand how to use `ifBO` for Hyperparameter Optimization. 
10 | 11 | 12 | > To reproduce experiments from the above paper version, please refer to the branch [`icml-2024`](https://github.com/automl/ifBO/tree/icml-2024). 13 | 14 | # Installation 15 | 16 | Requires Python 3.11. 17 | 18 | ```bash 19 | pip install -U ifBO 20 | ``` 21 | 22 | # Usage 23 | 24 | ## Surrogate API 25 | 26 | Check out this [notebook](https://github.com/automl/ifBO/blob/main/examples/Getting%20started%20with%20ifBO.ipynb). 27 | 28 | **Initializing the model** 29 | 30 | ```python 31 | from ifbo.surrogate import FTPFN 32 | from ifbo import Curve, PredictionResult 33 | 34 | model = FTPFN(version="0.0.1") 35 | ``` 36 | 37 | This creates a ``.model/`` directory in the current working directory for the surrogate model. To control where the model is stored, specify a ``target_path: Path`` when initializing. 38 | 39 | Supported versions: 40 | 41 | | Version | Identifier | Notes | 42 | | ------- | ---------------- | --------------------------------------------------------------------- | 43 | | 0.0.1 | ICML '24 submission | Supports up to ``1000`` unique configurations in the context, with each configuration having a maximum of ``10`` dimensions. | 44 | 45 | **Creating context and query points** 46 | 47 | The code snippet below demonstrates how to create instances of learning curves using the `ifbo.Curve` class. Each curve represents the performance over time of a configuration (vector of hyperparameter values). These instances are used to form the context and query points for the model: 48 | 49 | - `context`: known data points with both time (`t`) and observed values (`y`). 50 | - `query`: points where predictions are needed, with only time (`t`) provided. 51 | 52 | > __Note__: All values (hyperparameters, performances, and times) must be normalized to the range $[0, 1]$. 
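Raw hyperparameters and budgets rarely lie in $[0, 1]$ already, so they usually need rescaling first. The following is a minimal sketch of one way to do that with min-max scaling, optionally on a log scale. The helper `normalize`, the bounds, and the choice to express time as a fraction of a maximum budget are illustrative assumptions and not part of the `ifbo` API.

```python
import math


def normalize(value, low, high, log=False):
    """Min-max scale `value` from [low, high] to [0, 1], optionally on a log scale.

    Hypothetical helper for illustration; ifbo does not ship this function.
    """
    if log:
        value, low, high = math.log(value), math.log(low), math.log(high)
    return (value - low) / (high - low)


# Hypothetical raw configuration and search-space bounds.
learning_rate = normalize(1e-3, 1e-5, 1e-1, log=True)  # midpoint on the log scale
num_layers = normalize(3, 1, 5)
# Assumption: time is expressed as a fraction of the maximum budget (10 epochs here).
epoch_fraction = normalize(5, 0, 10)
```

The resulting floats can then be placed in the tensors passed to `Curve` as `hyperparameters` and `t`.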
53 | 54 | ```python 55 | import torch 56 | 57 | context = [ 58 | Curve( 59 | hyperparameters=torch.tensor([0.2, 0.1, 0.5]), 60 | t=torch.tensor([0.1, 0.2, 0.3]), 61 | y=torch.tensor([0.1, 0.15, 0.3]) 62 | ), 63 | Curve( 64 | hyperparameters=torch.tensor([0.2, 0.3, 0.25]), 65 | t=torch.tensor([0.1, 0.2, 0.3, 0.4]), 66 | y=torch.tensor([0.2, 0.5, 0.6, 0.75]) 67 | ), 68 | ] 69 | query = [ 70 | Curve( 71 | hyperparameters=torch.tensor([0.2, 0.1, 0.5]), 72 | t=torch.tensor([0.3, 0.4, 0.5, 0.6, 0.7, 0.9]) 73 | ), 74 | Curve( 75 | hyperparameters=torch.tensor([0.2, 0.3, 0.25]), 76 | t=torch.tensor([0.4, 0.5, 0.6, 0.7, 0.8, 0.9]) 77 | ), 78 | ] 79 | ``` 80 | 81 | **Making predictions** 82 | 83 | Use the model to predict performances at the ``query`` points. 84 | 85 | ```python 86 | predictions: list[PredictionResult] = model.predict(context=context, query=query) 87 | 88 | # Get predictions for the first curve 89 | prediction: PredictionResult = predictions[0] 90 | 91 | # Print the 5% and 95% percentiles of the predictive posterior distribution 92 | print(prediction.quantile(0.05), prediction.quantile(0.95)) 93 | ``` 94 | 95 | Following the PFN approach, the FT-PFN model outputs the Predictive Posterior Distribution (PPD) of the performances for each query point. Each PPD is encapsulated in an `ifbo.PredictionResult` object, which provides an interface to compute various quantities from the distribution, including: 96 | 97 | * ``likelihood(y_test: torch.Tensor)``: Computes the negative log-likelihood of the test targets (``y_test``). 98 | * ``ucb()``: Computes the upper confidence bound. 99 | * ``ei(y_best: torch.Tensor)``: Computes the expected improvement over ``y_best``. 100 | * ``pi(y_best: torch.Tensor)``: Computes the probability of improvement over ``y_best``. 101 | * `quantile(q: float)`: Computes the value at the specified quantile level ``q``. 
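
To make these quantities concrete: `ifbo/bar_distribution.py` represents a distribution as probabilities over buckets with uniform density inside each bucket. Assuming the PPD is of that kind, the pure-Python sketch below mirrors the idea behind `quantile` and `pi` for a single query point. The names `bar_quantile` and `bar_pi`, the toy `borders`, and the toy `probs` are illustrative only; this is not the library's implementation, and it ignores the half-normal tails used by `FullSupportBarDistribution`.

```python
from bisect import bisect_left
from itertools import accumulate


def bar_quantile(borders, probs, q):
    """Value with probability mass `q` to its left (uniform density per bucket)."""
    cum = list(accumulate(probs))
    # First bucket whose cumulative mass reaches q.
    idx = min(bisect_left(cum, q), len(probs) - 1)
    mass_before = cum[idx - 1] if idx > 0 else 0.0
    width = borders[idx + 1] - borders[idx]
    # Interpolate linearly inside the selected bucket.
    return borders[idx] + width * (q - mass_before) / probs[idx]


def bar_pi(borders, probs, y_best):
    """Probability that the outcome exceeds `y_best`."""
    total = 0.0
    for left, right, p in zip(borders[:-1], borders[1:], probs):
        # Fraction of this bucket that lies at or below y_best, clamped to [0, 1].
        frac_below = min(max((y_best - left) / (right - left), 0.0), 1.0)
        total += p * (1.0 - frac_below)
    return total


# Toy distribution over four equally wide buckets on [0, 1].
borders = [0.0, 0.25, 0.5, 0.75, 1.0]
probs = [0.1, 0.2, 0.3, 0.4]

median = bar_quantile(borders, probs, 0.5)
prob_improve = bar_pi(borders, probs, 0.75)
```

Here only the last bucket lies above `y_best = 0.75`, so the probability of improvement equals that bucket's mass.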
102 | 103 | 104 | ## Bayesian Optimization with ifBO 105 | 106 | To use the `ifBO` algorithm in practice, refer to [NePS](https://automl.github.io/neps/latest/), a package for hyperparameter optimization that includes the latest, improved version of `ifBO`. Below is a template example of how to use `ifBO` with NePS. For a complete Python script, see the [full example](https://github.com/automl/neps/blob/master/neps_examples/efficiency/freeze_thaw.py). 107 | 108 | ```python 109 | import neps 110 | 111 | def training_pipeline( 112 | num_layers, 113 | num_neurons, 114 | epochs, 115 | learning_rate, 116 | weight_decay 117 | ): 118 | # Training logic and checkpoint loading here 119 | pass 120 | 121 | pipeline_space = { 122 | "learning_rate": neps.Float(1e-5, 1e-1, log=True), 123 | "num_layers": neps.Integer(1, 5), 124 | "num_neurons": neps.Integer(64, 128), 125 | "weight_decay": neps.Float(1e-5, 0.1, log=True), 126 | "epochs": neps.Integer(1, 10, is_fidelity=True), 127 | } 128 | 129 | neps.run( 130 | pipeline_space=pipeline_space, 131 | run_pipeline=training_pipeline, 132 | searcher="ifbo", 133 | max_evaluations_total=50, 134 | step_size=1, 135 | surrogate_model_args=dict( 136 | version="0.0.1", 137 | target_path=None, 138 | ), 139 | ) 140 | ``` 141 | 142 | 143 | 144 | # Citation 145 | 146 | If you use our surrogate, code, or experiment setup, please cite: 147 | ```bibtex 148 | @inproceedings{ 149 | rakotoarison-icml24, 150 | title={In-Context Freeze-Thaw Bayesian Optimization for Hyperparameter Optimization}, 151 | author={H. Rakotoarison and S. Adriaensen and N. Mallik and S. Garibov and E. Bergman and F. 
Hutter}, 152 | booktitle={Forty-first International Conference on Machine Learning}, 153 | year={2024}, 154 | url={https://openreview.net/forum?id=VyoY3Wh9Wd} 155 | } 156 | ``` 157 | -------------------------------------------------------------------------------- /ifbo/__init__.py: -------------------------------------------------------------------------------- 1 | from .bar_distribution import BarDistribution 2 | from .download import VERSION_MAP 3 | from .priors import ftpfn_prior 4 | from .priors.prior import Batch 5 | from .priors.utils import get_batch_sequence as get_batch_sequence 6 | from .priors.utils import get_batch_to_dataloader as get_batch_to_dataloader 7 | from .surrogate import FTPFN 8 | from .utils import Curve 9 | from .utils import PredictionResult 10 | from .version import __version__ 11 | 12 | 13 | __all__ = [ 14 | "FTPFN", 15 | "Curve", 16 | "PredictionResult", 17 | "VERSION_MAP", 18 | "BarDistribution", 19 | "Batch", 20 | "get_batch_sequence", 21 | "get_batch_to_dataloader", 22 | "ftpfn_prior", 23 | "__version__", 24 | ] 25 | -------------------------------------------------------------------------------- /ifbo/bar_distribution.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import logging 7 | 8 | 9 | class BarDistribution(nn.Module): 10 | def __init__( 11 | self, borders: torch.Tensor, smoothing: float = 0.0, ignore_nan_targets: bool = True 12 | ) -> None: # here borders should start with min and end with max, where all values lie in (min,max) and are sorted 13 | """ 14 | :param borders: 15 | :param smoothing: 16 | :param append_mean_pred: Whether to predict the mean of the other positions as a last output in forward, 17 | is enabled when additionally y has a sequence length 1 shorter than logits, i.e. 
len(logits) == 1 + len(y) 18 | """ 19 | super().__init__() 20 | assert len(borders.shape) == 1 21 | self.register_buffer("borders", borders) 22 | self.register_buffer("smoothing", torch.tensor(smoothing)) 23 | self.register_buffer("bucket_widths", self.borders[1:] - self.borders[:-1]) 24 | full_width = self.bucket_widths.sum() 25 | assert ( 26 | (1 - (full_width / (self.borders[-1] - self.borders[0]))).abs() < 1e-2 27 | ), f"diff: {full_width - (self.borders[-1] - self.borders[0])} with {full_width} {self.borders[-1]} {self.borders[0]}" 28 | assert ( 29 | self.bucket_widths >= 0.0 30 | ).all(), "Please provide sorted borders!" # This also allows size zero buckets 31 | self.num_bars = len(borders) - 1 32 | self.ignore_nan_targets = ignore_nan_targets 33 | self.to(borders.device) 34 | 35 | def __setstate__(self, state: Any) -> None: 36 | super().__setstate__(state) 37 | self.__dict__.setdefault("append_mean_pred", False) 38 | 39 | def map_to_bucket_idx(self, y: torch.Tensor) -> torch.Tensor: 40 | target_sample = torch.searchsorted(self.borders, y) - 1 41 | target_sample[y == self.borders[0]] = 0 42 | target_sample[y == self.borders[-1]] = self.num_bars - 1 43 | return target_sample 44 | 45 | def ignore_init(self, y: torch.Tensor) -> torch.Tensor: 46 | ignore_loss_mask = torch.isnan(y) 47 | if ignore_loss_mask.any(): 48 | if not self.ignore_nan_targets: 49 | raise ValueError(f"Found NaN in target {y}") 50 | y[ignore_loss_mask] = self.borders[ 51 | 0 52 | ] # this is just a default value, it will be ignored anyway 53 | return ignore_loss_mask 54 | 55 | def compute_scaled_log_probs(self, logits: torch.Tensor) -> torch.Tensor: 56 | # this is equivalent to log(p(y)) of the density p 57 | bucket_log_probs = torch.log_softmax(logits, -1) 58 | scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths) 59 | return scaled_bucket_log_probs 60 | 61 | def forward( 62 | self, 63 | logits: torch.Tensor, 64 | y: torch.Tensor, 65 | mean_prediction_logits: torch.Tensor 
| None = None, 66 | ) -> ( 67 | torch.Tensor 68 | ): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars 69 | y = y.clone().view(*logits.shape[:-1]) # no trailing one dimension 70 | ignore_loss_mask = self.ignore_init(y) 71 | target_sample = self.map_to_bucket_idx(y) 72 | assert (target_sample >= 0).all() and ( 73 | target_sample < self.num_bars 74 | ).all(), f"y {y} not in support set for borders (min_y, max_y) {self.borders}" 75 | assert logits.shape[-1] == self.num_bars, f"{logits.shape[-1]} vs {self.num_bars}" 76 | scaled_bucket_log_probs = self.compute_scaled_log_probs(logits) 77 | nll_loss = -scaled_bucket_log_probs.gather(-1, target_sample[..., None]).squeeze( 78 | -1 79 | ) # T x B 80 | if mean_prediction_logits is not None: 81 | if not self.training: 82 | logging.warning("Calculating loss incl mean prediction loss for nonmyopic BO.") 83 | scaled_mean_log_probs = self.compute_scaled_log_probs(mean_prediction_logits) 84 | nll_loss = torch.cat((nll_loss, self.mean_loss(logits, scaled_mean_log_probs)), 0) 85 | smooth_loss = -scaled_bucket_log_probs.mean(dim=-1) 86 | smoothing = self.smoothing if self.training else 0.0 87 | loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss 88 | loss[ignore_loss_mask] = 0.0 89 | return loss 90 | 91 | def mean_loss(self, logits: torch.Tensor, scaled_mean_logits: torch.Tensor) -> torch.Tensor: 92 | assert (len(logits.shape) == 3) and (len(scaled_mean_logits.shape) == 2), ( 93 | len(logits.shape), 94 | len(scaled_mean_logits.shape), 95 | ) 96 | means = self.mean(logits).detach() # T x B 97 | target_mean = self.map_to_bucket_idx(means).clamp_(0, self.num_bars - 1) # T x B 98 | return -scaled_mean_logits.gather(1, target_mean.T).mean(1).unsqueeze(0) # 1 x B 99 | 100 | def mean(self, logits: torch.Tensor) -> torch.Tensor: 101 | bucket_means = self.borders[:-1] + self.bucket_widths / 2 102 | p = torch.softmax(logits, -1) 103 | return p @ bucket_means 104 | 105 | def median(self, logits: 
torch.Tensor) -> torch.Tensor: 106 | return self.icdf(logits, 0.5) 107 | 108 | def icdf(self, logits: torch.Tensor, left_prob: float) -> torch.Tensor: 109 | """ 110 | Implementation of the quantile function 111 | :param logits: Tensor of any shape, with the last dimension being logits 112 | :param left_prob: float: The probability mass to the left of the result. 113 | :return: Position with `left_prob` probability weight to the left. 114 | """ 115 | probs = logits.softmax(-1) 116 | cumprobs = torch.cumsum(probs, -1) 117 | idx = ( 118 | torch.searchsorted( 119 | cumprobs, 120 | left_prob * torch.ones(*cumprobs.shape[:-1], 1, device=logits.device), 121 | ) 122 | .squeeze(-1) 123 | .clamp(0, cumprobs.shape[-1] - 1) 124 | ) # this might not do the right for outliers 125 | cumprobs = torch.cat( 126 | [torch.zeros(*cumprobs.shape[:-1], 1, device=logits.device), cumprobs], -1 127 | ) 128 | rest_prob = left_prob - cumprobs.gather(-1, idx[..., None]).squeeze(-1) 129 | left_border = self.borders[idx] 130 | right_border = self.borders[idx + 1] 131 | return left_border + (right_border - left_border) * rest_prob / probs.gather( 132 | -1, idx[..., None] 133 | ).squeeze(-1) 134 | 135 | def quantile(self, logits: torch.Tensor, center_prob: float = 0.682) -> torch.Tensor: 136 | side_probs = (1.0 - center_prob) / 2 137 | return torch.stack( 138 | (self.icdf(logits, side_probs), self.icdf(logits, 1.0 - side_probs)), -1 139 | ) 140 | 141 | def ucb( 142 | self, 143 | logits: torch.Tensor, 144 | best_f: torch.Tensor | float, 145 | rest_prob: float = (1 - 0.682) / 2, 146 | maximize: bool = True, 147 | ) -> torch.Tensor: 148 | """ 149 | UCB utility. Rest Prob is the amount of utility above (below) the confidence interval that is ignored. 150 | Higher rest_prob is equivalent to lower beta in the standard GP-UCB formulation. 151 | :param logits: Logits, as returned by the Transformer. 152 | :param rest_prob: The amount of utility above (below) the confidence interval that is ignored. 
153 | The default is equivalent to using GP-UCB with `beta=1`. 154 | To get the corresponding `beta`, where `beta` is from 155 | the standard GP definition of UCB `ucb_utility = mean + beta * std`, 156 | you can use this computation: `beta = math.sqrt(2)*torch.erfinv(torch.tensor(2*(1-rest_prob)-1))`. 157 | :param maximize: 158 | :return: utility 159 | """ 160 | if maximize: 161 | rest_prob = 1 - rest_prob 162 | return self.icdf(logits, rest_prob) 163 | 164 | def mode(self, logits: torch.Tensor) -> torch.Tensor: 165 | mode_inds = logits.argmax(-1) 166 | bucket_means = self.borders[:-1] + self.bucket_widths / 2 167 | return bucket_means[mode_inds] 168 | 169 | def ei( 170 | self, logits: torch.Tensor, best_f: torch.Tensor | float, maximize: bool = True 171 | ) -> torch.Tensor: # logits: evaluation_points x batch x feature_dim 172 | bucket_diffs = self.borders[1:] - self.borders[:-1] 173 | assert maximize 174 | if not torch.is_tensor(best_f): 175 | best_f = torch.full(logits[..., 0].shape, best_f, device=logits.device) 176 | assert isinstance(best_f, torch.Tensor) 177 | if not len(best_f.shape): 178 | best_f = torch.full(logits[..., 0].shape, best_f, device=logits.device) 179 | best_f = best_f[..., None].repeat(*[1] * len(best_f.shape), logits.shape[-1]) 180 | clamped_best_f = best_f.clamp(self.borders[:-1], self.borders[1:]) 181 | # bucket_contributions = (best_f[...,None] < self.borders[:-1]).float() * bucket_means 182 | # true bucket contributions 183 | bucket_contributions = ( 184 | (self.borders[1:] ** 2 - clamped_best_f**2) / 2 185 | - best_f * (self.borders[1:] - clamped_best_f) 186 | ) / bucket_diffs 187 | p = torch.softmax(logits, -1) 188 | return torch.einsum("...b,...b->...", p, bucket_contributions) 189 | 190 | def pi( 191 | self, logits: torch.Tensor, best_f: torch.Tensor | float, maximize: bool = True 192 | ) -> torch.Tensor: # logits: evaluation_points x batch x feature_dim 193 | """ 194 | Acquisition Function: Probability of Improvement 195 | :param 
logits: as returned by Transformer 196 | :param best_f: best evaluation so far (the incumbent) 197 | :param maximize: whether to maximize 198 | :return: utility 199 | """ 200 | assert maximize is True 201 | p = torch.softmax(logits, -1) 202 | border_widths = self.borders[1:] - self.borders[:-1] 203 | factor = 1.0 - ((best_f - self.borders[:-1]) / border_widths).clamp(0.0, 1.0) 204 | return (p * factor).sum(-1) 205 | 206 | def mean_of_square(self, logits: torch.Tensor) -> torch.Tensor: 207 | """ 208 | Computes E[x^2]. 209 | :param logits: Output of the model. 210 | """ 211 | left_borders = self.borders[:-1] 212 | right_borders = self.borders[1:] 213 | bucket_mean_of_square = ( 214 | left_borders.square() + right_borders.square() + left_borders * right_borders 215 | ) / 3.0 216 | p = torch.softmax(logits, -1) 217 | return p @ bucket_mean_of_square 218 | 219 | def variance(self, logits: torch.Tensor) -> torch.Tensor: 220 | return self.mean_of_square(logits) - self.mean(logits).square() 221 | 222 | 223 | class FullSupportBarDistribution(BarDistribution): 224 | @staticmethod 225 | def halfnormal_with_p_weight_before( 226 | range_max: float, p: float = 0.5 227 | ) -> torch.distributions.HalfNormal: 228 | s = range_max / torch.distributions.HalfNormal(torch.tensor(1.0)).icdf(torch.tensor(p)) 229 | return torch.distributions.HalfNormal(max(s, 1e-9)) 230 | 231 | def forward( 232 | self, 233 | logits: torch.Tensor, 234 | y: torch.Tensor, 235 | mean_prediction_logits: torch.Tensor | None = None, 236 | ) -> ( 237 | torch.Tensor 238 | ): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars 239 | assert self.num_bars > 1 240 | y = y.clone().view(len(y), -1) # no trailing one dimension 241 | ignore_loss_mask = self.ignore_init(y) # alters y 242 | target_sample = self.map_to_bucket_idx(y) # shape: T x B (same as y) 243 | target_sample.clamp_(0, self.num_bars - 1) 244 | assert logits.shape[-1] == self.num_bars, f"{logits.shape[-1]} vs 
{self.num_bars}" 245 | assert (target_sample >= 0).all() and ( 246 | target_sample < self.num_bars 247 | ).all(), f"y {y} not in support set for borders (min_y, max_y) {self.borders}" 248 | assert logits.shape[-1] == self.num_bars, f"{logits.shape[-1]} vs {self.num_bars}" 249 | # ignore all position with nan values 250 | 251 | scaled_bucket_log_probs = self.compute_scaled_log_probs(logits) 252 | 253 | assert len(scaled_bucket_log_probs) == len(target_sample), ( 254 | len(scaled_bucket_log_probs), 255 | len(target_sample), 256 | ) 257 | log_probs = scaled_bucket_log_probs.gather(-1, target_sample.unsqueeze(-1)).squeeze(-1) 258 | 259 | side_normals = ( 260 | self.halfnormal_with_p_weight_before(self.bucket_widths[0]), 261 | self.halfnormal_with_p_weight_before(self.bucket_widths[-1]), 262 | ) 263 | 264 | log_probs[target_sample == 0] += side_normals[0].log_prob( 265 | (self.borders[1] - y[target_sample == 0]).clamp(min=0.00000001) 266 | ) + torch.log(self.bucket_widths[0]) 267 | log_probs[target_sample == self.num_bars - 1] += side_normals[1].log_prob( 268 | (y[target_sample == self.num_bars - 1] - self.borders[-2]).clamp(min=0.00000001) 269 | ) + torch.log(self.bucket_widths[-1]) 270 | 271 | nll_loss = -log_probs 272 | 273 | if mean_prediction_logits is not None: 274 | assert ( 275 | not ignore_loss_mask.any() 276 | ), "Ignoring examples is not implemented with mean pred." 277 | if not torch.is_grad_enabled(): 278 | logging.warning( 279 | "loss is not correct in absolute terms, only the gradient is right, when using `append_mean_pred`." 
280 | ) 281 | scaled_mean_log_probs = self.compute_scaled_log_probs(mean_prediction_logits) 282 | nll_loss = torch.cat((nll_loss, self.mean_loss(logits, scaled_mean_log_probs)), 0) 283 | # ignore_loss_mask = torch.zeros_like(nll_loss, dtype=torch.bool) 284 | 285 | if self.smoothing: 286 | smooth_loss = -scaled_bucket_log_probs.mean(dim=-1) 287 | smoothing = self.smoothing if self.training else 0.0 288 | nll_loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss 289 | 290 | if ignore_loss_mask.any(): 291 | nll_loss[ignore_loss_mask] = 0.0 292 | 293 | return nll_loss 294 | 295 | def mean(self, logits: torch.Tensor) -> torch.Tensor: 296 | bucket_means = self.borders[:-1] + self.bucket_widths / 2 297 | p = torch.softmax(logits, -1) 298 | side_normals = ( 299 | self.halfnormal_with_p_weight_before(self.bucket_widths[0]), 300 | self.halfnormal_with_p_weight_before(self.bucket_widths[-1]), 301 | ) 302 | bucket_means[0] = -side_normals[0].mean + self.borders[1] 303 | bucket_means[-1] = side_normals[1].mean + self.borders[-2] 304 | return p @ bucket_means.to(logits.device) 305 | 306 | def mean_of_square(self, logits: torch.Tensor) -> torch.Tensor: 307 | """ 308 | Computes E[x^2]. 309 | :param logits: Output of the model. 
310 | """ 311 | left_borders = self.borders[:-1] 312 | right_borders = self.borders[1:] 313 | bucket_mean_of_square = ( 314 | left_borders.square() + right_borders.square() + left_borders * right_borders 315 | ) / 3.0 316 | side_normals = ( 317 | self.halfnormal_with_p_weight_before(self.bucket_widths[0]), 318 | self.halfnormal_with_p_weight_before(self.bucket_widths[-1]), 319 | ) 320 | bucket_mean_of_square[0] = ( 321 | side_normals[0].variance + (-side_normals[0].mean + self.borders[1]).square() 322 | ) 323 | bucket_mean_of_square[-1] = ( 324 | side_normals[1].variance + (side_normals[1].mean + self.borders[-2]).square() 325 | ) 326 | p = torch.softmax(logits, -1) 327 | return p @ bucket_mean_of_square 328 | 329 | def pi( 330 | self, logits: torch.Tensor, best_f: torch.Tensor | float, maximize: bool = True 331 | ) -> torch.Tensor: # logits: evaluation_points x batch x feature_dim 332 | """ 333 | Acquisition Function: Probability of Improvement 334 | :param logits: as returned by Transformer (evaluation_points x batch x feature_dim) 335 | :param best_f: best evaluation so far (the incumbent) 336 | :param maximize: whether to maximize 337 | :return: utility 338 | """ 339 | assert maximize is True 340 | if not torch.is_tensor(best_f): 341 | best_f = torch.full( 342 | logits[..., 0].shape, best_f, device=logits.device 343 | ) # evaluation_points x batch 344 | assert isinstance(best_f, torch.Tensor) 345 | if not len(best_f.shape): 346 | best_f = torch.full( 347 | logits[..., 0].shape, best_f, device=logits.device 348 | ) # evaluation_points x batch 349 | assert ( 350 | best_f.shape == logits[..., 0].shape 351 | ), f"best_f.shape: {best_f.shape}, logits.shape: {logits.shape}" 352 | p = torch.softmax(logits, -1) # evaluation_points x batch x num_bars 353 | border_widths = self.borders[1:] - self.borders[:-1] 354 | factor = 1.0 - ((best_f[..., None] - self.borders[:-1]) / border_widths).clamp( 355 | 0.0, 1.0 356 | ) # evaluation_points x batch x num_bars 357 | 358 | 
side_normals = ( 359 | self.halfnormal_with_p_weight_before(self.bucket_widths[0]), 360 | self.halfnormal_with_p_weight_before(self.bucket_widths[-1]), 361 | ) 362 | position_in_side_normals = ( 363 | -(best_f - self.borders[1]).clamp(max=0.0), 364 | (best_f - self.borders[-2]).clamp(min=0.0), 365 | ) # evaluation_points x batch 366 | factor[..., 0] = 0.0 367 | factor[..., 0][position_in_side_normals[0] > 0.0] = side_normals[0].cdf( 368 | position_in_side_normals[0][position_in_side_normals[0] > 0.0] 369 | ) 370 | factor[..., -1] = 1.0 371 | factor[..., -1][position_in_side_normals[1] > 0.0] = 1.0 - side_normals[1].cdf( 372 | position_in_side_normals[1][position_in_side_normals[1] > 0.0] 373 | ) 374 | return (p * factor).sum(-1) 375 | 376 | def ei_for_halfnormal( 377 | self, scale: float, best_f: torch.Tensor | float, maximize: bool = True 378 | ) -> torch.Tensor: 379 | """ 380 | This is the EI for a standard normal distribution with mean 0 and variance `scale` times 2. 381 | Which is the same as the half normal EI. 382 | I tested this with MC approximation: 383 | ei_for_halfnormal = lambda scale, best_f: (torch.distributions.HalfNormal(torch.tensor(scale)).sample((10_000_000,))- best_f ).clamp(min=0.).mean() 384 | print([(ei_for_halfnormal(scale,best_f), FullSupportBarDistribution().ei_for_halfnormal(scale,best_f)) for scale in [0.1,1.,10.] 
for best_f in [.1,10.,4.]]) 385 | :param scale: 386 | :param best_f: 387 | :param maximize: 388 | :return: 389 | """ 390 | assert maximize 391 | mean = torch.tensor(0.0) 392 | u = (mean - best_f) / scale 393 | normal = torch.distributions.Normal(torch.zeros_like(u), torch.ones_like(u)) 394 | try: 395 | ucdf = normal.cdf(u) 396 | except ValueError: 397 | raise 398 | updf = torch.exp(normal.log_prob(u)) 399 | normal_ei = scale * (updf + u * ucdf) 400 | return 2 * normal_ei 401 | 402 | def ei( 403 | self, logits: torch.Tensor, best_f: torch.Tensor | float, maximize: bool = True 404 | ) -> torch.Tensor: # logits: evaluation_points x batch x feature_dim 405 | if torch.isnan(logits).any(): 406 | raise ValueError(f"logits contains NaNs: {logits}") 407 | bucket_diffs = self.borders[1:] - self.borders[:-1] 408 | assert maximize 409 | if not torch.is_tensor(best_f): 410 | best_f = torch.full( 411 | logits[..., 0].shape, best_f, device=logits.device 412 | ) # evaluation_points x batch 413 | assert isinstance(best_f, torch.Tensor) 414 | if not len(best_f.shape): 415 | best_f = torch.full(logits[..., 0].shape, best_f, device=logits.device) 416 | assert ( 417 | best_f.shape == logits[..., 0].shape 418 | ), f"best_f.shape: {best_f.shape}, logits.shape: {logits.shape}" 419 | 420 | best_f_per_logit = best_f[..., None].repeat(*[1] * len(best_f.shape), logits.shape[-1]) 421 | clamped_best_f = best_f_per_logit.clamp(self.borders[:-1], self.borders[1:]) 422 | 423 | # true bucket contributions 424 | bucket_contributions = ( 425 | (self.borders[1:] ** 2 - clamped_best_f**2) / 2 426 | - best_f_per_logit * (self.borders[1:] - clamped_best_f) 427 | ) / bucket_diffs 428 | 429 | # extra stuff for continuous 430 | side_normals = ( 431 | self.halfnormal_with_p_weight_before(self.bucket_widths[0]), 432 | self.halfnormal_with_p_weight_before(self.bucket_widths[-1]), 433 | ) 434 | position_in_side_normals = ( 435 | -(best_f - self.borders[1]).clamp(max=0.0), 436 | (best_f - 
self.borders[-2]).clamp(min=0.0), 437 | ) # evaluation_points x batch 438 | 439 | bucket_contributions[..., -1] = self.ei_for_halfnormal( 440 | side_normals[1].scale, position_in_side_normals[1] 441 | ) 442 | 443 | bucket_contributions[..., 0] = self.ei_for_halfnormal( 444 | side_normals[0].scale, torch.zeros_like(position_in_side_normals[0]) 445 | ) - self.ei_for_halfnormal(side_normals[0].scale, position_in_side_normals[0]) 446 | 447 | p = torch.softmax(logits, -1) 448 | return torch.einsum("...b,...b->...", p, bucket_contributions) 449 | 450 | 451 | def get_bucket_limits( 452 | num_outputs: int, 453 | full_range: tuple | None = None, 454 | ys: torch.Tensor = None, 455 | verbose: bool = False, 456 | ) -> torch.Tensor: 457 | assert (ys is None) != (full_range is None), "Either full_range or ys must be passed." 458 | 459 | if ys is not None: 460 | ys = ys.flatten() 461 | ys = ys[~torch.isnan(ys)] 462 | if len(ys) % num_outputs: 463 | ys = ys[: -(len(ys) % num_outputs)] 464 | logging.info( 465 | f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys." 466 | ) 467 | ys_per_bucket = len(ys) // num_outputs 468 | if full_range is None: 469 | full_range = (ys.min(), ys.max()) 470 | else: 471 | assert ( 472 | full_range[0] <= ys.min() and full_range[1] >= ys.max() 473 | ), f"full_range {full_range} not in range of ys {ys.min(), ys.max()}" 474 | full_range = torch.tensor(full_range) 475 | ys_sorted, ys_order = ys.sort(0) 476 | bucket_limits = ( 477 | ys_sorted[ys_per_bucket - 1 :: ys_per_bucket][:-1] 478 | + ys_sorted[ys_per_bucket::ys_per_bucket] 479 | ) / 2 480 | if verbose: 481 | logging.info( 482 | f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys." 
483 | ) 484 | assert isinstance(full_range, torch.Tensor) 485 | bucket_limits = torch.cat( 486 | [full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0 487 | ) 488 | 489 | else: 490 | assert isinstance(full_range, torch.Tensor) 491 | class_width = (full_range[1] - full_range[0]) / num_outputs 492 | bucket_limits = torch.cat( 493 | [ 494 | full_range[0] + torch.arange(num_outputs).float() * class_width, 495 | torch.tensor(full_range[1]).unsqueeze(0), 496 | ], 497 | 0, 498 | ) 499 | 500 | assert ( 501 | len(bucket_limits) - 1 == num_outputs 502 | ), f"len(bucket_limits) - 1 == {len(bucket_limits) - 1} != {num_outputs} == num_outputs" 503 | assert full_range[0] == bucket_limits[0], f"{full_range[0]} != {bucket_limits[0]}" 504 | assert full_range[-1] == bucket_limits[-1], f"{full_range[-1]} != {bucket_limits[-1]}" 505 | 506 | return bucket_limits 507 | 508 | 509 | def get_custom_bar_dist(borders: torch.Tensor, criterion: BarDistribution) -> BarDistribution: 510 | # Tested that a bar_dist with borders 0.54 (-> softplus 1.0) yields the same bar distribution as the passed one. 
511 | borders_ = torch.nn.functional.softplus(borders) + 0.001 512 | borders_ = torch.cumsum( 513 | torch.cat([criterion.borders[0:1], criterion.bucket_widths]) * borders_, 0 514 | ) 515 | criterion_ = criterion.__class__( 516 | borders=borders_, ignore_nan_targets=criterion.ignore_nan_targets 517 | ) 518 | return criterion_ 519 | -------------------------------------------------------------------------------- /ifbo/decoders.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class ScaledDecoder(nn.Module): 8 | def __init__(self, ninp: int, nhid: int, nout: int) -> None: 9 | super().__init__() 10 | self.linear = nn.Linear(ninp, nhid) 11 | self.linear1 = nn.Linear(nhid, nout) 12 | self.linear2 = nn.Linear(nhid, 10) 13 | 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | # return torch.cat([self.linear1(x), self.linear2(x)], -1) 16 | x = self.linear(x) 17 | x = nn.GELU()(x) 18 | temps = self.linear2(x).softmax(-1) @ torch.tensor( 19 | [1.0, 1.4, 1.7, 2.0, 5.0, 10.0, 20.0, 40.0, 80.0, 160.0], device=x.device 20 | ) 21 | return self.linear1(x) / temps.unsqueeze(-1) 22 | 23 | 24 | class FixedScaledDecoder(nn.Module): 25 | def __init__(self, ninp: int, nhid: int, nout: int) -> None: 26 | super().__init__() 27 | self.mapper = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout)) 28 | self.T = nn.Parameter(torch.ones(10000) / 10000) 29 | 30 | def forward(self, x: torch.Tensor) -> torch.Tensor: 31 | return self.mapper(x) / self.T.sum() 32 | -------------------------------------------------------------------------------- /ifbo/download.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | import tarfile 5 | 6 | import requests 7 | 8 | 9 | VERSION_MAP = { 10 | "0.0.1": dict( 11 | 
url="https://ml.informatik.uni-freiburg.de/research-artifacts/ifbo/ftpfnv0.0.1.tar.gz", 12 | name="ftpfnv0.0.1", 13 | final_name="bopfn_broken_unisep_1000curves_10params_2M", 14 | extension="pt", 15 | ) 16 | } 17 | 18 | 19 | # Helper functions to generate the file names 20 | def FILENAME(version: str) -> str: 21 | return f"{VERSION_MAP[version].get('name')}.tar.gz" 22 | 23 | 24 | def FILE_URL(version: str) -> str: 25 | return f"{VERSION_MAP[version].get('url')}" 26 | 27 | 28 | def WEIGHTS_FILE_NAME(version: str) -> str: 29 | return f"{VERSION_MAP[version].get('name')}.{VERSION_MAP[version].get('extension')}" 30 | 31 | 32 | def WEIGHTS_FINAL_NAME(version: str) -> str: 33 | return f"{VERSION_MAP[version].get('final_name')}.{VERSION_MAP[version].get('extension')}" 34 | 35 | 36 | def download_and_decompress(url: str, path: Path) -> None: 37 | """Download a .tar.gz archive from a URL, save it at the given path, and extract it next to it. 38 | 39 | Args: 40 | url (str): URL of the file to download 41 | path (Path): Path along with filename to save the downloaded file 42 | 43 | Raises: 44 | ValueError: If the server responds with a status code other than 200 45 | """ 46 | # Skip the download if the file already exists 47 | if path.exists(): 48 | return 49 | 50 | # Send an HTTP request to the URL of the file 51 | response = requests.get(url, allow_redirects=True) 52 | 53 | # Check if the request is successful 54 | if response.status_code != 200: 55 | raise ValueError( 56 | f"Failed to download the surrogate from {url}." 57 | f" Received HTTP status code: {response.status_code}." 58 | " Please either try again later, use an alternative link or contact the authors through GitHub." 
59 | ) 60 | 61 | # Save the .tar.gz file 62 | with open(path, "wb") as f: 63 | f.write(response.content) 64 | 65 | # Decompress the .tar.gz file 66 | with tarfile.open(path, "r:gz") as tar: 67 | tar.extractall(path.parent.absolute()) 68 | 69 | 70 | def parse_args() -> argparse.Namespace: 71 | """Helper function to parse the command line arguments.""" 72 | parser = argparse.ArgumentParser() 73 | 74 | parser.add_argument( 75 | "--version", type=str, default="0.0.1", help="The version of the PFN model to download" 76 | ) 77 | parser.add_argument( 78 | "--path", type=str, default=None, help="The path to save the downloaded file" 79 | ) 80 | 81 | args = parser.parse_args() 82 | return args 83 | 84 | 85 | if __name__ == "__main__": 86 | args = parse_args() 87 | 88 | assert args.version in VERSION_MAP, "The version provided is not available" 89 | 90 | if args.path is None: 91 | args.path = Path(__file__).parent.absolute() / ".." / ".." / "PFNS4HPO" / "final_models" 92 | else: 93 | args.path = Path(args.path) 94 | 95 | if not args.path.exists(): 96 | os.makedirs(args.path) 97 | 98 | # Download and extract the weights archive 99 | download_and_decompress( 100 | url=VERSION_MAP[args.version]["url"], path=args.path / FILENAME(args.version) 101 | ) 102 | print(f"Successfully downloaded FT-PFN v{args.version} into {args.path}!") 103 | -------------------------------------------------------------------------------- /ifbo/encoders.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | import math 5 | from typing import Any 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from ifbo.utils import normalize_data 11 | 12 | 13 | class StyleEncoder(nn.Module): 14 | def __init__(self, num_hyperparameters: int, em_size: int) -> None: 15 | super().__init__() 16 | self.em_size = em_size 17 | self.embedding = nn.Linear(num_hyperparameters, self.em_size) 18 | 19 | def forward(self, 
hyperparameters: torch.Tensor) -> torch.Tensor: # B x num_hps 20 | return self.embedding(hyperparameters) 21 | 22 | 23 | class StyleEmbEncoder(nn.Module): 24 | def __init__(self, num_hyperparameters: int, em_size: int, num_embeddings: int = 100) -> None: 25 | super().__init__() 26 | assert num_hyperparameters == 1 27 | self.em_size = em_size 28 | self.embedding = nn.Embedding(num_embeddings, self.em_size) 29 | 30 | def forward(self, hyperparameters: torch.Tensor) -> torch.Tensor: # B x num_hps 31 | return self.embedding(hyperparameters.squeeze(1)) 32 | 33 | 34 | class _PositionalEncoding(nn.Module): 35 | def __init__(self, d_model: int, dropout: float = 0.0) -> None: 36 | super().__init__() 37 | self.dropout = nn.Dropout(p=dropout) 38 | self.d_model = d_model 39 | self.device_test_tensor = nn.Parameter(torch.tensor(1.0)) 40 | 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: # T x B x num_features 42 | assert self.d_model % (x.shape[-1] * 2) == 0 # each feature needs an even number of dims (sin/cos pairs) 43 | d_per_feature = self.d_model // x.shape[-1] 44 | pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device) 45 | # position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 46 | interval_size = 10 47 | div_term = ( 48 | (1.0 / interval_size) 49 | * 2 50 | * math.pi 51 | * torch.exp( 52 | torch.arange(0, d_per_feature, 2, device=self.device_test_tensor.device).float() 53 | * math.log(math.sqrt(2)) 54 | ) 55 | ) 56 | pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term) 57 | pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term) 58 | return self.dropout(pe).view(x.shape[0], x.shape[1], self.d_model) 59 | 60 | 61 | class Positional(_PositionalEncoding): 62 | def __init__(self, num_features: int, emsize: int) -> None: 63 | super().__init__(d_model=emsize) 64 | self.num_features = num_features 65 | 66 | 67 | class EmbeddingEncoder(nn.Module): 68 | def __init__(self, num_features: int, em_size: int, num_embs: int = 100) -> None: 69 | super().__init__() 70 | self.num_embs = num_embs 71 | 
self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True) 72 | self.init_weights(0.1) 73 | self.min_max = (-2, +2) 74 | 75 | @property 76 | def width(self) -> float: 77 | return self.min_max[1] - self.min_max[0] 78 | 79 | def init_weights(self, initrange: float) -> None: 80 | self.embeddings.weight.data.uniform_(-initrange, initrange) 81 | 82 | def discretize(self, x: torch.Tensor) -> torch.Tensor: 83 | split_size = self.width / self.num_embs # width of one bin 84 | return ((x - self.min_max[0]) // split_size).int().clamp(0, self.num_embs - 1) 85 | 86 | def forward(self, x: torch.Tensor) -> torch.Tensor: # T x B x num_features 87 | x_idxs = self.discretize(x) 88 | x_idxs += torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs 89 | return self.embeddings(x_idxs).mean(-2) 90 | 91 | 92 | class Normalize(nn.Module): 93 | def __init__(self, mean: float, std: float) -> None: 94 | super().__init__() 95 | self.mean = mean 96 | self.std = std 97 | 98 | def forward(self, x: torch.Tensor) -> torch.Tensor: 99 | return (x - self.mean) / self.std 100 | 101 | 102 | class SqueezeBetween0and1(nn.Module): # take care of test set here 103 | def forward(self, x: torch.Tensor) -> torch.Tensor: 104 | width = x.max(0).values - x.min(0).values 105 | result = (x - x.min(0).values) / width 106 | result[(width == 0)[None].repeat(len(x), *[1] * (len(x.shape) - 1))] = 0.5 107 | return result 108 | 109 | 110 | def get_normalized_uniform_encoder( 111 | encoder_creator: Callable[[int, int], nn.Module], 112 | ) -> Callable[[int, int], nn.Module]: 113 | """ 114 | Wraps an encoder that is fed uniform samples in [0,1] and normalizes these to 0 mean and 1 std first. 115 | For example, it can be used as `encoder_creator = get_normalized_uniform_encoder(encoders.Linear)`, which can 116 | then be initialized with `encoder_creator(feature_dim, in_dim)`. 
117 | :param encoder: 118 | :return: 119 | """ 120 | return lambda in_dim, out_dim: nn.Sequential( 121 | Normalize(0.5, math.sqrt(1 / 12)), encoder_creator(in_dim, out_dim) 122 | ) 123 | 124 | 125 | def get_normalized_encoder( 126 | encoder_creator: Callable[[int, int], nn.Module], data_std: float 127 | ) -> Callable[[int, int], nn.Module]: 128 | return lambda in_dim, out_dim: nn.Sequential( 129 | Normalize(0.0, data_std), encoder_creator(in_dim, out_dim) 130 | ) 131 | 132 | 133 | def get_log_dims(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor: 134 | logged_x = ((x + eps).log() - math.log(eps)) / (math.log(1.0 + eps) - math.log(eps)) 135 | return logged_x 136 | 137 | 138 | def add_log_neglog_dims(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor: 139 | logged_x = get_log_dims(x, eps) / 2.0 140 | neglogged_x = 1 - get_log_dims(1 - x, eps) / 2.0 141 | logged_x[x > 0.5] = neglogged_x[x > 0.5] 142 | return torch.stack([x, logged_x], -1).view(*x.shape[:-1], -1) 143 | 144 | 145 | class AddLogNegLogDims(nn.Module): 146 | def __init__(self, eps: float = 1e-10) -> None: 147 | super().__init__() 148 | self.eps = eps 149 | 150 | def forward(self, x: torch.Tensor) -> torch.Tensor: 151 | return add_log_neglog_dims(x, self.eps) 152 | 153 | 154 | def get_logdim_encoder( 155 | encoder_creator: Callable[[int, int], nn.Module], eps: float = 1e-10 156 | ) -> Callable[[int, int], nn.Module]: 157 | return lambda in_dim, out_dim: nn.Sequential( 158 | AddLogNegLogDims(eps), encoder_creator(in_dim * 2, out_dim) 159 | ) 160 | 161 | 162 | class ZNormalize(nn.Module): 163 | def forward(self, x: torch.Tensor) -> torch.Tensor: 164 | std = x.std(-1, keepdim=True) 165 | std[std == 0.0] = 1.0 166 | return (x - x.mean(-1, keepdim=True)) / std 167 | 168 | 169 | class ZNormalizePerDataset(nn.Module): 170 | def forward(self, x: torch.Tensor) -> torch.Tensor: 171 | std = x.std(0, keepdim=True) 172 | std[std == 0.0] = 1.0 173 | return (x - x.mean(0, keepdim=True)) / std 174 | 175 | 176 | class 
AppendEmbeddingEncoder(nn.Module): 177 | def __init__( 178 | self, base_encoder: Callable[[torch.Tensor], torch.Tensor], num_features: int, emsize: int 179 | ) -> None: 180 | super().__init__() 181 | self.num_features = num_features 182 | self.base_encoder = base_encoder 183 | self.emb = nn.Parameter(torch.zeros(emsize)) 184 | 185 | def forward(self, x: torch.Tensor) -> torch.Tensor: 186 | if (x[-1] == 1.0).all(): 187 | append_embedding = True 188 | else: 189 | assert (x[-1] == 0.0).all(), ( 190 | "You need to specify as last position whether to append embedding. " 191 | "If you don't want this behavior, please use the wrapped encoder instead." 192 | ) 193 | append_embedding = False 194 | x = x[:-1] 195 | encoded_x = self.base_encoder(x) 196 | if append_embedding: 197 | encoded_x = torch.cat( 198 | [encoded_x, self.emb[None, None, :].repeat(1, encoded_x.shape[1], 1)], 0 199 | ) 200 | return encoded_x 201 | 202 | 203 | def get_append_embedding_encoder( 204 | encoder_creator: Callable[[int, int], nn.Module], 205 | ) -> Callable[[int, int], nn.Module]: 206 | return lambda num_features, emsize: AppendEmbeddingEncoder( 207 | encoder_creator(num_features, emsize), num_features, emsize 208 | ) 209 | 210 | 211 | class VariableNumFeaturesEncoder(nn.Module): 212 | def __init__( 213 | self, base_encoder: Callable[[torch.Tensor], torch.Tensor], num_features: int 214 | ) -> None: 215 | super().__init__() 216 | self.base_encoder = base_encoder 217 | self.num_features = num_features 218 | 219 | def forward(self, x: torch.Tensor) -> torch.Tensor: 220 | x = x * (self.num_features / x.shape[-1]) 221 | x = torch.cat( 222 | (x, torch.zeros(*x.shape[:-1], self.num_features - x.shape[-1], device=x.device)), -1 223 | ) 224 | return self.base_encoder(x) 225 | 226 | 227 | def get_variable_num_features_encoder( 228 | encoder_creator: Callable[[int, int], nn.Module], 229 | ) -> Callable[[int, int], nn.Module]: 230 | return lambda num_features, emsize: VariableNumFeaturesEncoder( 231 | 
encoder_creator(num_features, emsize), num_features 232 | ) 233 | 234 | 235 | class NoMeanEncoder(nn.Module): 236 | """ 237 | This can be useful for any prior that is translation invariant in x or y. 238 | A standard GP for example is translation invariant in x. 239 | That is, GP(x_test+const,x_train+const,y_train) = GP(x_test,x_train,y_train). 240 | """ 241 | 242 | def __init__(self, base_encoder: Callable[[torch.Tensor], torch.Tensor]) -> None: 243 | super().__init__() 244 | self.base_encoder = base_encoder 245 | 246 | def forward(self, x: torch.Tensor) -> torch.Tensor: 247 | return self.base_encoder(x - x.mean(0, keepdim=True)) 248 | 249 | 250 | def get_no_mean_encoder( 251 | encoder_creator: Callable[[int, int], nn.Module], 252 | ) -> Callable[[int, int], nn.Module]: 253 | return lambda num_features, emsize: NoMeanEncoder(encoder_creator(num_features, emsize)) 254 | 255 | 256 | def MLP(num_features: int, emsize: int) -> nn.Sequential: 257 | return nn.Sequential( 258 | nn.Linear(num_features, emsize * 2), 259 | nn.ReLU(), 260 | nn.Linear(emsize * 2, emsize), 261 | ) 262 | 263 | 264 | class NanHandlingEncoder(nn.Module): 265 | def __init__(self, num_features: int, emsize: int, keep_nans: bool = True) -> None: 266 | super().__init__() 267 | self.num_features = 2 * num_features if keep_nans else num_features 268 | self.emsize = emsize 269 | self.keep_nans = keep_nans 270 | self.layer = nn.Linear(self.num_features, self.emsize) 271 | 272 | def forward(self, x: torch.Tensor) -> torch.Tensor: 273 | if self.keep_nans: 274 | x = torch.cat( 275 | [ 276 | torch.nan_to_num(x, nan=0.0), 277 | normalize_data( 278 | torch.isnan(x) * -1 279 | + torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1 280 | + torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2 281 | ), 282 | ], 283 | -1, 284 | ) 285 | else: 286 | x = torch.nan_to_num(x, nan=0.0) 287 | return self.layer(x) 288 | 289 | 290 | class Linear(nn.Linear): 291 | def __init__(self, num_features: int, emsize: int, 
replace_nan_by_zero: bool = False) -> None: 292 | super().__init__(num_features, emsize) 293 | self.num_features = num_features 294 | self.emsize = emsize 295 | self.replace_nan_by_zero = replace_nan_by_zero 296 | 297 | def forward(self, x: torch.Tensor) -> torch.Tensor: 298 | if self.replace_nan_by_zero: 299 | x = torch.nan_to_num(x, nan=0.0) 300 | return super().forward(x) 301 | 302 | def __setstate__(self, state: dict[str, Any]) -> None: 303 | super().__setstate__(state) 304 | self.__dict__.setdefault("replace_nan_by_zero", True) 305 | 306 | 307 | class Conv(nn.Module): 308 | def __init__(self, input_size: int, emsize: int) -> None: 309 | super().__init__() 310 | self.convs = torch.nn.ModuleList([nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)]) 311 | self.linear = nn.Linear(64, emsize) 312 | 313 | def forward(self, x: torch.Tensor) -> torch.Tensor: 314 | size = math.isqrt(x.shape[-1]) 315 | assert size * size == x.shape[-1] 316 | x = x.reshape(*x.shape[:-1], 1, size, size) 317 | for conv in self.convs: 318 | if x.shape[-1] < 4: 319 | break 320 | x = conv(x) 321 | x.relu_() 322 | x = nn.AdaptiveAvgPool2d((1, 1))(x).squeeze(-1).squeeze(-1) 323 | return self.linear(x) 324 | 325 | 326 | class CanEmb(nn.Embedding): 327 | def __init__( 328 | self, num_features: int, num_embeddings: int, embedding_dim: int, *args: Any, **kwargs: Any 329 | ) -> None: 330 | assert embedding_dim % num_features == 0 331 | embedding_dim = embedding_dim // num_features 332 | super().__init__(num_embeddings, embedding_dim, *args, **kwargs) 333 | 334 | def forward(self, x: torch.Tensor) -> torch.Tensor: 335 | lx = x.long() 336 | assert (lx == x).all(), "CanEmb only works with tensors of whole numbers" 337 | x = super().forward(lx) 338 | return x.view(*x.shape[:-2], -1) 339 | 340 | 341 | def get_Canonical(num_classes: int) -> Callable[[int, int], nn.Module]: 342 | return lambda num_features, emsize: CanEmb(num_features, num_classes, emsize) 343 | 344 | 345 | def 
get_Embedding(num_embs_per_feature: int = 100) -> Callable[[int, int], nn.Module]: 346 | return lambda num_features, emsize: EmbeddingEncoder( 347 | num_features, emsize, num_embs=num_embs_per_feature 348 | ) 349 | -------------------------------------------------------------------------------- /ifbo/initializers.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | from torch import nn 4 | 5 | 6 | def get_NormalInitializer(std: float) -> Callable[[nn.Module], nn.Module]: 7 | def initializer(m: nn.Module) -> nn.Module: 8 | if isinstance(m, nn.Linear): 9 | nn.init.normal_(m.weight, 0, std) 10 | nn.init.normal_(m.bias, 0, std) 11 | return m 12 | 13 | return initializer 14 | -------------------------------------------------------------------------------- /ifbo/layer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import partial 4 | from typing import Any 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.nn.modules.transformer import _get_activation_fn 9 | from torch.nn.modules.transformer import Dropout 10 | from torch.nn.modules.transformer import LayerNorm 11 | from torch.nn.modules.transformer import Linear 12 | from torch.nn.modules.transformer import Module 13 | from torch.nn.modules.transformer import MultiheadAttention 14 | from torch.nn.modules.transformer import Tensor 15 | from torch.utils.checkpoint import checkpoint 16 | 17 | 18 | class TransformerEncoderLayer(Module): 19 | r"""TransformerEncoderLayer is made up of self-attn and feedforward network. 20 | This standard encoder layer is based on the paper "Attention Is All You Need". 21 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 22 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 23 | Neural Information Processing Systems, pages 6000-6010. 
Users may modify or implement 24 | in a different way during application. 25 | 26 | Args: 27 | d_model: the number of expected features in the input (required). 28 | nhead: the number of heads in the multiheadattention models (required). 29 | dim_feedforward: the dimension of the feedforward network model (default=2048). 30 | dropout: the dropout value (default=0.1). 31 | activation: the activation function of intermediate layer, relu or gelu (default=relu). 32 | layer_norm_eps: the eps value in layer normalization components (default=1e-5). 33 | batch_first: If ``True``, then the input and output tensors are provided 34 | as (batch, seq, feature). Default: ``False``. 35 | 36 | Examples:: 37 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 38 | >>> src = torch.rand(10, 32, 512) 39 | >>> out = encoder_layer(src) 40 | 41 | Alternatively, when ``batch_first`` is ``True``: 42 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) 43 | >>> src = torch.rand(32, 10, 512) 44 | >>> out = encoder_layer(src) 45 | """ 46 | 47 | __constants__ = ["batch_first"] 48 | 49 | def __init__( 50 | self, 51 | d_model: int, 52 | nhead: int, 53 | dim_feedforward: int = 2048, 54 | dropout: float = 0.1, 55 | activation: str = "relu", 56 | layer_norm_eps: float = 1e-5, 57 | batch_first: bool = False, 58 | pre_norm: bool = False, 59 | device: torch.device | None = None, 60 | dtype: torch.dtype | None = None, 61 | recompute_attn: bool = False, 62 | save_trainingset_representations: bool = False, 63 | ) -> None: 64 | factory_kwargs = {"device": device, "dtype": dtype} 65 | super().__init__() 66 | self.self_attn = MultiheadAttention( 67 | d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs 68 | ) 69 | # Implementation of Feedforward model 70 | self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) 71 | self.dropout = Dropout(dropout) 72 | self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) 73 | 
74 | self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 75 | self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 76 | self.dropout1 = Dropout(dropout) 77 | self.dropout2 = Dropout(dropout) 78 | self.pre_norm = pre_norm 79 | self.recompute_attn = recompute_attn 80 | self.save_trainingset_representations = save_trainingset_representations 81 | self.saved_src_to_attend_to = None 82 | 83 | self.activation = _get_activation_fn(activation) 84 | 85 | def __setstate__(self, state: dict[str, Any]) -> None: 86 | if "activation" not in state: 87 | state["activation"] = F.relu 88 | super().__setstate__(state) 89 | self.__dict__.setdefault("save_trainingset_representations", False) 90 | 91 | def forward( 92 | self, 93 | src: Tensor, 94 | src_mask: Tensor | None = None, 95 | src_key_padding_mask: Tensor | None = None, 96 | ) -> Tensor: 97 | r"""Pass the input through the encoder layer. 98 | 99 | Args: 100 | src: the sequence to the encoder layer (required). 101 | src_mask: the mask for the src sequence (optional). 102 | src_key_padding_mask: the mask for the src keys per batch (optional). 103 | 104 | Shape: 105 | see the docs in Transformer class. 
106 | """ 107 | if self.save_trainingset_representations: 108 | assert ( 109 | isinstance(src_mask, int) and not self.training 110 | ), "save_trainingset_representations is only supported in eval mode and requires src_mask to be an int" 111 | 112 | if self.pre_norm: 113 | src_ = self.norm1(src) 114 | else: 115 | src_ = src 116 | if isinstance(src_mask, tuple): 117 | # global attention setup 118 | assert not self.self_attn.batch_first 119 | assert src_key_padding_mask is None 120 | 121 | global_src_mask, trainset_src_mask, valset_src_mask = src_mask 122 | 123 | num_global_tokens = global_src_mask.shape[0] 124 | num_train_tokens = trainset_src_mask.shape[0] 125 | 126 | global_tokens_src = src_[:num_global_tokens] 127 | train_tokens_src = src_[num_global_tokens : num_global_tokens + num_train_tokens] 128 | global_and_train_tokens_src = src_[: num_global_tokens + num_train_tokens] 129 | eval_tokens_src = src_[num_global_tokens + num_train_tokens :] 130 | 131 | attn = partial(checkpoint, self.self_attn) if self.recompute_attn else self.self_attn 132 | 133 | global_tokens_src2 = attn( 134 | global_tokens_src, 135 | global_and_train_tokens_src, 136 | global_and_train_tokens_src, 137 | None, 138 | True, 139 | global_src_mask, 140 | )[0] 141 | train_tokens_src2 = attn( 142 | train_tokens_src, 143 | global_tokens_src, 144 | global_tokens_src, 145 | None, 146 | True, 147 | trainset_src_mask, 148 | )[0] 149 | eval_tokens_src2 = attn(eval_tokens_src, src_, src_, None, True, valset_src_mask)[0] 150 | 151 | src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0) 152 | 153 | elif isinstance(src_mask, int): 154 | assert src_key_padding_mask is None 155 | single_eval_position = src_mask 156 | src_to_attend_to = src_[:single_eval_position] 157 | if self.save_trainingset_representations: 158 | if single_eval_position == src_.shape[0] or single_eval_position is None: 159 | self.saved_src_to_attend_to = src_to_attend_to 160 | elif single_eval_position == 0: 
161 | if self.saved_src_to_attend_to is None: 162 | raise ValueError( 163 | "First save the training set representations by passing in a src_mask equal to the length of the src" 164 | ) 165 | src_to_attend_to = self.saved_src_to_attend_to 166 | else: 167 | raise ValueError( 168 | "save_trainingset_representations only supports single_eval_position == 0 or single_eval_position == src.shape[0]" 169 | ) 170 | src_left = self.self_attn( 171 | src_[:single_eval_position], 172 | src_[:single_eval_position], 173 | src_[:single_eval_position], 174 | )[0] 175 | src_right = self.self_attn( 176 | src_[single_eval_position:], src_to_attend_to, src_to_attend_to 177 | )[0] 178 | src2 = torch.cat([src_left, src_right], dim=0) 179 | else: 180 | if self.recompute_attn: 181 | src2 = checkpoint( 182 | self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask 183 | )[0] 184 | else: 185 | src2 = self.self_attn( 186 | src_, src_, src_, attn_mask=src_mask, key_padding_mask=src_key_padding_mask 187 | )[0] 188 | src = src + self.dropout1(src2) 189 | if not self.pre_norm: 190 | src = self.norm1(src) 191 | 192 | if self.pre_norm: 193 | src_ = self.norm2(src) 194 | else: 195 | src_ = src 196 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src_)))) 197 | src = src + self.dropout2(src2) 198 | 199 | if not self.pre_norm: 200 | src = self.norm2(src) 201 | return src 202 | -------------------------------------------------------------------------------- /ifbo/positional_encodings.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | # Protocol for positional encodings. 8 | # __init__(d_model, max_len=..[, more optionals]) 9 | # forward(x: (seq_len, bs, d_model)) -> Tensor of shape (*x.shape[:2],d_model) containing pos. 
embeddings 10 | 11 | 12 | class NoPositionalEncoding(nn.Module): 13 | def __init__(self, d_model: int, max_len: int | None = None) -> None: 14 | super(NoPositionalEncoding, self).__init__() 15 | pass 16 | 17 | def forward(self, x: torch.Tensor) -> torch.Tensor: 18 | return x # * math.sqrt(x.shape[-1]) 19 | 20 | 21 | class PositionalEncoding(nn.Module): 22 | def __init__(self, d_model: int, max_len: int = 5000) -> None: 23 | super(PositionalEncoding, self).__init__() 24 | pe = torch.zeros(max_len, d_model) 25 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 26 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 27 | pe[:, 0::2] = torch.sin(position * div_term) 28 | pe[:, 1::2] = torch.cos(position * div_term) 29 | pe = pe.unsqueeze(0).transpose(0, 1) 30 | self.register_buffer("pe", pe) 31 | 32 | def forward(self, x: torch.Tensor) -> torch.Tensor: 33 | x = self.pe[: x.size(0), :] + x # * math.sqrt(x.shape[-1]) 34 | return x 35 | 36 | 37 | class LearnedPositionalEncoding(nn.Module): 38 | def __init__(self, d_model: int, max_len: int = 5000) -> None: 39 | super(LearnedPositionalEncoding, self).__init__() 40 | self.max_seq_len = max_len 41 | # self.positional_embeddings = nn.Embedding(max_len, d_model) 42 | self.positional_embeddings = nn.Parameter(torch.empty(max_len, d_model)) 43 | nn.init.normal_(self.positional_embeddings, mean=0, std=d_model**-0.5) 44 | 45 | def forward(self, x: torch.Tensor) -> torch.Tensor: 46 | seq_len, bs, d_model = x.shape 47 | assert seq_len <= len(self.positional_embeddings), "seq_len can be at most max_len." 48 | pos_emb = self.positional_embeddings[:seq_len] 49 | return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x # * math.sqrt(x.shape[-1]) 50 | 51 | 52 | class PairedScrambledPositionalEncodings(LearnedPositionalEncoding): 53 | # TODO check whether it is a problem to use the same perm. 
for full batch 54 | def forward(self, x: torch.Tensor) -> torch.Tensor: 55 | seq_len, bs, d_model = x.shape 56 | assert seq_len <= len(self.positional_embeddings), "seq_len can be at most max_len." 57 | assert len(self.positional_embeddings) % 2 == 0, "Please specify an even max_len." 58 | 59 | paired_embs = self.positional_embeddings.view(len(self.positional_embeddings), -1, 2) 60 | pos_emb = paired_embs[torch.randperm(len(paired_embs))].view( 61 | *self.positional_embeddings.shape 62 | )[:seq_len] 63 | 64 | return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x # * math.sqrt(x.shape[-1]) 65 | -------------------------------------------------------------------------------- /ifbo/priors/__init__.py: -------------------------------------------------------------------------------- 1 | class AbstractDatasetPrior: 2 | def new_dataset(self) -> None: 3 | raise NotImplementedError 4 | -------------------------------------------------------------------------------- /ifbo/priors/ftpfn_prior.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | import math 3 | import os 4 | from typing import Any 5 | import warnings 6 | 7 | import numpy as np 8 | from scipy.stats import beta 9 | from scipy.stats import expon 10 | from scipy.stats import gamma 11 | from scipy.stats import norm 12 | import torch 13 | 14 | from ifbo import encoders 15 | from ifbo.encoders import Normalize 16 | from ifbo.priors.prior import Batch 17 | from ifbo.utils import default_device 18 | 19 | 20 | OUTPUT_SORTED = np.load( 21 | os.path.join(os.path.dirname(os.path.abspath(__file__)), "output_sorted.npy") 22 | ) 23 | 24 | 25 | def progress_noise(X: np.ndarray, sigma: float, L: float) -> np.ndarray: 26 | EPS = 10**-9 27 | N = len(X) 28 | 29 | Z = np.random.normal(0, sigma, size=(N,)) 30 | 31 | SIGMA = np.exp(-(np.subtract.outer(X, X) ** 2) / L) 32 | 33 | SIGMA += EPS * np.eye(N) # to guarantee SPD 34 | 35 | C = 
np.linalg.cholesky(SIGMA) 36 | 37 | return C @ Z 38 | 39 | 40 | def add_noise_and_break( 41 | x: np.ndarray, x_noise: None, Xsat: np.ndarray, Rpsat: np.ndarray 42 | ) -> np.ndarray: 43 | x = np.where( 44 | x < Xsat, x, Rpsat * (x - Xsat) + Xsat 45 | ) # add a breaking point when saturation is reached 46 | return x 47 | # noisy_x = x + x_noise 48 | # add the exponential tails to avoid negative x 49 | # TODO: actually make curve go to 0 in the negative range (would allow divergence beyond Y0) 50 | # noisy_x = np.where(noisy_x > 1/1000, noisy_x, np.exp(noisy_x-1/1000+np.log(1/1000))) 51 | # return noisy_x 52 | 53 | 54 | def comb( 55 | x: np.ndarray, 56 | Y0: float = 0.2, 57 | Yinf: float = 0.8, 58 | sigma: float | None = 0.01, 59 | L: float | None = 0.0001, 60 | PREC: list[int] = [100] * 4, 61 | Xsat: list[float] = [1.0] * 4, 62 | alpha: list[float] = [np.exp(1), np.exp(-1), 1 + np.exp(-4), np.exp(0)], 63 | Rpsat: list[float] = [1.0] * 4, 64 | w: list[float] = [1 / 4] * 4, 65 | ) -> float: 66 | # x_noise = progress_noise(x,sigma,L) 67 | x_noise = None 68 | EPS = 10**-9 69 | x_eps = np.array([EPS, 2 * EPS]) 70 | 71 | # POW4 with exponential tail 72 | x_pow = add_noise_and_break(x, x_noise, Xsat[0], Rpsat[0]) 73 | pow_eps = ( 74 | Yinf - (Yinf - Y0) * (((PREC[0]) ** (1 / alpha[0]) - 1) / Xsat[0] * x_eps + 1) ** -alpha[0] 75 | ) 76 | pow_grad = (pow_eps[1] - pow_eps[0]) / EPS 77 | pow_y = np.where( 78 | x_pow > 0, 79 | Yinf 80 | - (Yinf - Y0) * (((PREC[0]) ** (1 / alpha[0]) - 1) / Xsat[0] * x_pow + 1) ** -alpha[0], 81 | Y0 * np.exp(x_pow * (pow_grad + EPS) / Y0), 82 | ) 83 | 84 | x_exp = add_noise_and_break(x, x_noise, Xsat[1], Rpsat[1]) 85 | exp_eps = Yinf - (Yinf - Y0) * PREC[1] ** (-((x_eps / Xsat[1]) ** alpha[1])) 86 | exp_grad = (exp_eps[1] - exp_eps[0]) / EPS 87 | exp_y = np.where( 88 | x_exp > 0, 89 | Yinf - (Yinf - Y0) * PREC[1] ** (-((x_exp / Xsat[1]) ** alpha[1])), 90 | Y0 * np.exp(x_exp * (exp_grad + EPS) / Y0), 91 | ) 92 | 93 | x_log = add_noise_and_break(x, 
x_noise, Xsat[2], Rpsat[2]) 94 | log_eps = Yinf - (Yinf - Y0) * np.log(alpha[2]) / ( 95 | np.log((alpha[2] ** PREC[2] - alpha[2]) * x_eps / Xsat[2] + alpha[2]) 96 | ) 97 | log_grad = (log_eps[1] - log_eps[0]) / EPS 98 | log_y = np.where( 99 | x_log > 0, 100 | Yinf 101 | - (Yinf - Y0) 102 | * np.log(alpha[2]) 103 | / (np.log((alpha[2] ** PREC[2] - alpha[2]) * x_log / Xsat[2] + alpha[2])), 104 | Y0 * np.exp(x_log * (log_grad + EPS) / Y0), 105 | ) 106 | 107 | x_hill = add_noise_and_break(x, x_noise, Xsat[3], Rpsat[3]) 108 | hill_eps = Yinf - (Yinf - Y0) / ((x_eps / Xsat[3]) ** alpha[3] * (PREC[3] - 1) + 1) 109 | hill_grad = (hill_eps[1] - hill_eps[0]) / EPS 110 | hill_y = np.where( 111 | x_hill > 0, 112 | Yinf - (Yinf - Y0) / ((x_hill / Xsat[3]) ** alpha[3] * (PREC[3] - 1) + 1), 113 | Y0 * np.exp(x_hill * (hill_grad + EPS) / Y0), 114 | ) 115 | 116 | return w[0] * pow_y + w[1] * exp_y + w[2] * log_y + w[3] * hill_y 117 | 118 | 119 | class MLP(torch.nn.Module): 120 | def __init__(self, num_inputs: int, num_outputs: int) -> None: 121 | super(MLP, self).__init__() 122 | 123 | num_layers = np.random.randint(8, 16) 124 | num_hidden = np.random.randint(36, 150) 125 | self.init_std = np.random.uniform(0.089, 0.193) 126 | self.sparseness = 0.145 127 | self.preactivation_noise_std = np.random.uniform( 128 | 0.0003, 0.0014 129 | ) # TODO: check value for this! 
130 | self.output_noise = np.random.uniform(0.0004, 0.0013) 131 | activation = "tanh" 132 | 133 | self.linears = torch.nn.ModuleList( 134 | [torch.nn.Linear(num_inputs, num_hidden)] 135 | + [torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers - 2)] 136 | + [torch.nn.Linear(num_hidden, num_outputs)] 137 | ) 138 | 139 | self.reset_parameters() 140 | 141 | self.activation = { 142 | "tanh": torch.nn.Tanh(), 143 | "relu": torch.nn.ReLU(), 144 | "elu": torch.nn.ELU(), 145 | "identity": torch.nn.Identity(), 146 | }[activation] 147 | 148 | def reset_parameters( 149 | self, init_std: float | None = None, sparseness: float | None = None 150 | ) -> None: 151 | init_std = init_std if init_std is not None else self.init_std 152 | sparseness = sparseness if sparseness is not None else self.sparseness 153 | for linear in self.linears: 154 | linear.reset_parameters() 155 | 156 | with torch.no_grad(): 157 | if init_std is not None: 158 | for linear in self.linears: 159 | linear.weight.normal_(0, init_std) 160 | linear.bias.normal_(0, init_std) 161 | 162 | if sparseness > 0.0: 163 | for linear in self.linears[1:-1]: 164 | linear.weight /= (1.0 - sparseness) ** (1 / 2) 165 | linear.weight *= torch.bernoulli( 166 | torch.ones_like(linear.weight) * (1.0 - sparseness) 167 | ) 168 | 169 | def forward(self, x: torch.Tensor) -> torch.Tensor: 170 | for linear in self.linears[:-1]: 171 | x = linear(x) 172 | x = x + torch.randn_like(x) * self.preactivation_noise_std 173 | x = torch.tanh(x) 174 | x = self.linears[-1](x) 175 | return x + torch.randn_like(x) * self.output_noise 176 | 177 | 178 | class DatasetPrior: 179 | def _get_model(self) -> torch.nn.Module: 180 | return MLP(self.num_inputs, self.num_outputs).to("cpu") 181 | 182 | def _output_for(self, input: torch.Tensor) -> torch.Tensor: 183 | with torch.no_grad(): 184 | # normalize the inputs 185 | input = self.normalizer(input) 186 | # reweight the inputs for parameter importance 187 | # input = input*self.input_weights # 
TODO: consider adding this again 188 | # apply the model to produce the output 189 | output = self.model(input.float()) 190 | # rescale and shift outputs to account for parameter sensitivity 191 | # This output scaling causes issues with 192 | # output = output * self.output_sensitivity + self.output_offset 193 | return output 194 | 195 | def __init__(self, num_params: int, num_outputs: int) -> None: 196 | self.num_features = num_params 197 | self.num_outputs = num_outputs 198 | self.num_inputs = num_params 199 | 200 | self.normalizer = Normalize(0.5, math.sqrt(1 / 12)) 201 | 202 | self.new_dataset() 203 | 204 | def new_dataset(self) -> None: 205 | # reinitialize all dataset specific random variables 206 | # reinit the parameters of the BNN 207 | self.model = self._get_model() 208 | # initial performance (after init) & max performance 209 | u1 = np.random.uniform() 210 | u2 = np.random.uniform() 211 | self.y0 = min(u1, u2) 212 | self.ymax = max(u1, u2) if np.random.uniform() < 0.25 else 1.0 213 | # TODO: this is not standard BOPFN BNN, but consider adding this 214 | # the input weights (parameter importance & magnitude of aleatoric uncertainty on the curve) 215 | # param_importance = np.random.dirichlet([1]*(self.num_inputs-1) + [0.1]) # relative parameter importance 216 | # lscale = np.exp(np.random.normal(2, 0.5)) # length scale ~ complexity of the landscape 217 | # self.input_weights = np.concatenate((param_importance*lscale*self.num_inputs, np.full((1,),lscale)), axis=0) 218 | # the output weights (curve property sensitivity) 219 | # self.output_sensitivity = np.random.uniform(size=(self.num_outputs,)) 220 | # self.output_offset = np.random.uniform((self.output_sensitivity-1)/2,(1-self.output_sensitivity)/2) 221 | 222 | def curves_for_configs( 223 | self, configs: np.ndarray, noise: bool = True 224 | ) -> Callable[[np.ndarray, int], np.ndarray]: 225 | # more efficient batch-wise 226 | ncurves = 4 227 | bnn_outputs = self.output_for_config(configs, noise=noise) 228
| 229 | indices = np.searchsorted(OUTPUT_SORTED, bnn_outputs, side="left") 230 | 231 | rng4config = MyRNG(indices) 232 | 233 | Y0 = self.y0 234 | 235 | # sample Yinf (shared by all components) 236 | Yinf = rng4config.uniform(a=Y0, b=self.ymax) # 0 237 | assert isinstance(Yinf, np.ndarray) 238 | 239 | # sample weights for basis curves (dirichlet) 240 | w = np.stack([rng4config.gamma(a=1) for i in range(ncurves)]).T # 1, 2, 3, 4 241 | w = w / w.sum(axis=1, keepdims=1) 242 | 243 | # sample shape/skew parameter for each basis curve 244 | alpha = np.stack( 245 | [ 246 | np.exp(rng4config.normal(1, 1)), # 5 247 | np.exp(rng4config.normal(0, 1)), # 6 248 | 1.0 + np.exp(rng4config.normal(-4, 1)), # 7 249 | np.exp(rng4config.normal(0.5, 0.5)), 250 | ] 251 | ).T # 8 252 | 253 | # sample saturation x for each basis curve 254 | Xsat_max = 10 ** rng4config.normal(0, 1) # max saturation # 9 255 | assert isinstance(Xsat_max, np.ndarray) 256 | 257 | Xsat_rel = np.stack( 258 | [rng4config.gamma(a=1) for i in range(ncurves)] 259 | ).T # relative saturation points # 10, 11, 12, 13 260 | 261 | Xsat = ((Xsat_max.T * Xsat_rel.T) / np.max(Xsat_rel, axis=1)).T 262 | 263 | # sample relative saturation y (PREC) for each basis curve 264 | PREC = np.stack( 265 | [1.0 / 10 ** rng4config.uniform(-3, 0) for i in range(ncurves)] 266 | ).T # 14, 15, 16, 17 267 | 268 | # post saturation convergence/divergence rate for each basis curve 269 | Rpsat = np.stack( 270 | [1.0 - rng4config.exponential(scale=1) for i in range(ncurves)] 271 | ).T # 18, 19, 20, 21 272 | 273 | # sample noise parameters 274 | sigma = np.exp(rng4config.normal(loc=-5, scale=1)) 275 | # sigma_x = np.exp(rng4config.normal(-4,0.5)) # STD of the xGP 22 276 | # sigma_y_scaler = np.exp(rng4config.uniform(-5,0.0)) # STD of the yGP 23 277 | # L = 10**rng4config.normal(-5,1) # Length-scale of the xyGP 24 278 | 279 | def foo(x_: np.ndarray, cid: int = 0) -> np.ndarray: 280 | warnings.filterwarnings("ignore") 281 | y_ = comb( 282 | x_, 283 
| Y0=Y0, 284 | Yinf=Yinf[cid], 285 | sigma=None, 286 | L=None, 287 | Xsat=Xsat[cid], 288 | alpha=alpha[cid], 289 | Rpsat=Rpsat[cid], 290 | w=w[cid], 291 | PREC=PREC[cid], 292 | ) 293 | # y_ = comb(x_, Y0=Y0, Yinf=Yinf[cid], sigma=sigma_x[cid], L=L[cid], Xsat=Xsat[cid], alpha=alpha[cid], Rpsat=Rpsat[cid], w=w[cid], PREC=PREC[cid]) 294 | y_noise = np.random.normal(size=x_.shape, scale=sigma[cid]) 295 | # y_noise = progress_noise(x_,1,L) 296 | # y_noise *= np.minimum(y_,1.0-y_)/4*sigma_y_scaler[cid] 297 | return np.clip(y_ + y_noise, 0.0, 1.0) 298 | 299 | return foo 300 | 301 | def output_for_config(self, config: np.ndarray, noise: bool = True) -> np.ndarray: 302 | # add aleatoric noise & bias 303 | output = self._output_for(torch.from_numpy(config)) 304 | return output.numpy() 305 | 306 | def uniform(self, bnn_output: np.ndarray, a: float = 0.0, b: float = 1.0) -> np.ndarray: 307 | indices = np.searchsorted(OUTPUT_SORTED, bnn_output, side="left") 308 | return (b - a) * indices / len(OUTPUT_SORTED) + a 309 | 310 | def normal(self, bnn_output: np.ndarray, loc: float = 0, scale: float = 1) -> np.ndarray: 311 | eps = 0.5 / len(OUTPUT_SORTED) # to avoid infinite samples 312 | u = self.uniform(bnn_output, a=eps, b=1 - eps) 313 | return norm.ppf(u, loc=loc, scale=scale) 314 | 315 | def beta( 316 | self, bnn_output: np.ndarray, a: float = 1, b: float = 1, loc: float = 0, scale: float = 1 317 | ) -> np.ndarray: 318 | eps = 0.5 / len(OUTPUT_SORTED) # to avoid infinite samples 319 | u = self.uniform(bnn_output, a=eps, b=1 - eps) 320 | return beta.ppf(u, a=a, b=b, loc=loc, scale=scale) 321 | 322 | def gamma( 323 | self, bnn_output: np.ndarray, a: float = 1, loc: float = 0, scale: float = 1 324 | ) -> np.ndarray: 325 | eps = 0.5 / len(OUTPUT_SORTED) # to avoid infinite samples 326 | u = self.uniform(bnn_output, a=eps, b=1 - eps) 327 | return gamma.ppf(u, a=a, loc=loc, scale=scale) 328 | 329 | def exponential(self, bnn_output: np.ndarray, scale: float = 1) -> np.ndarray: 330 | eps 
= 0.5 / len(OUTPUT_SORTED) # to avoid infinite samples 331 | u = self.uniform(bnn_output, a=eps, b=1 - eps) 332 | return expon.ppf(u, scale=scale) 333 | 334 | 335 | class MyRNG: 336 | def __init__(self, indices: np.ndarray) -> None: 337 | self.indices = indices.T 338 | self.reset() 339 | 340 | def reset(self) -> None: 341 | self.counter = 0 342 | 343 | def uniform(self, a: float = 0.0, b: float = 1.0) -> float | np.ndarray: 344 | u = (b - a) * self.indices[self.counter] / len(OUTPUT_SORTED) + a 345 | self.counter += 1 346 | return u 347 | 348 | def normal(self, loc: float = 0, scale: float = 1) -> float | np.ndarray: 349 | eps = 0.5 / len(OUTPUT_SORTED) # to avoid infinite samples 350 | u = self.uniform(a=eps, b=1 - eps) 351 | return norm.ppf(u, loc=loc, scale=scale) 352 | 353 | def beta( 354 | self, a: float = 1, b: float = 1, loc: float = 0, scale: float = 1 355 | ) -> float | np.ndarray: 356 | eps = 0.5 / len(OUTPUT_SORTED) # to avoid infinite samples 357 | u = self.uniform(a=eps, b=1 - eps) 358 | return beta.ppf(u, a=a, b=b, loc=loc, scale=scale) 359 | 360 | def gamma(self, a: float = 1, loc: float = 0, scale: float = 1) -> float | np.ndarray: 361 | eps = 0.5 / len(OUTPUT_SORTED) # to avoid infinite samples 362 | u = self.uniform(a=eps, b=1 - eps) 363 | return gamma.ppf(u, a=a, loc=loc, scale=scale) 364 | 365 | def exponential(self, scale: float = 1) -> float | np.ndarray: 366 | eps = 0.5 / len(OUTPUT_SORTED) # to avoid infinite samples 367 | u = self.uniform(a=eps, b=1 - eps) 368 | return expon.ppf(u, scale=scale) 369 | 370 | 371 | def curve_prior( 372 | dataset: DatasetPrior, config: np.ndarray 373 | ) -> Callable[[np.ndarray, int], np.ndarray]: 374 | # calls the more efficient batch-wise method 375 | return dataset.curves_for_configs(np.array([config])) 376 | 377 | 378 | # function producing batches for PFN training 379 | @torch.no_grad() 380 | def get_batch( 381 | batch_size: int, 382 | seq_len: int, 383 | num_features: int, 384 | single_eval_pos: int, 385 
| device: torch.device = default_device, 386 | hyperparameters: dict[str, Any] | None = None, 387 | **kwargs: Any, 388 | ) -> Batch: 389 | # assert num_features == 2 390 | assert num_features >= 2 391 | EPS = 10**-9 392 | 393 | if hyperparameters is not None and "hp_dim" in hyperparameters: 394 | num_params = hyperparameters["hp_dim"] 395 | else: 396 | num_params = np.random.randint(1, num_features - 1) # beware upper bound is exclusive! 397 | 398 | dataset_prior = DatasetPrior(num_params, 23) 399 | 400 | x = [] 401 | y = [] 402 | 403 | for i in range(batch_size): 404 | epoch = torch.zeros(seq_len) 405 | id_curve = torch.zeros(seq_len) 406 | curve_val = torch.zeros(seq_len) 407 | config = torch.zeros(seq_len, num_params) 408 | 409 | # determine the number of fidelity levels (ranging from 1: BB, up to seq_len) 410 | n_levels = int(np.round(10 ** np.random.uniform(0, 3))) 411 | 412 | # determine # observations/queries per curve 413 | # TODO: also make this a dirichlet thing 414 | alpha = 10 ** np.random.uniform(-4, -1) 415 | weights = np.random.gamma(alpha, alpha, seq_len) + EPS 416 | p = weights / np.sum(weights) 417 | ids = np.arange(seq_len) 418 | all_levels = np.repeat(ids, n_levels) 419 | all_p = np.repeat(p, n_levels) / n_levels 420 | ordering = np.random.choice(all_levels, p=all_p, size=seq_len, replace=False) 421 | 422 | # calculate the cutoff/samples for each curve 423 | cutoff_per_curve = np.zeros((seq_len,), dtype=int) 424 | epochs_per_curve = np.zeros((seq_len,), dtype=int) 425 | for i in range(seq_len): # loop over every pos 426 | cid = ordering[i] 427 | epochs_per_curve[cid] += 1 428 | if i < single_eval_pos: 429 | cutoff_per_curve[cid] += 1 430 | 431 | # fix dataset specific random variables 432 | dataset_prior.new_dataset() 433 | 434 | # determine config, x, y for every curve 435 | curve_configs = np.random.uniform(size=(seq_len, num_params)) 436 | curves = dataset_prior.curves_for_configs(curve_configs) 437 | curve_xs = [] 438 | curve_ys = [] 439 | 
for cid in range(seq_len): # loop over every curve 440 | if epochs_per_curve[cid] > 0: 441 | # determine x (observations + query) 442 | x_ = np.zeros((epochs_per_curve[cid],)) 443 | if cutoff_per_curve[cid] > 0: # observations (if any) 444 | x_[: cutoff_per_curve[cid]] = ( 445 | np.arange(1, cutoff_per_curve[cid] + 1) / n_levels 446 | ) 447 | if cutoff_per_curve[cid] < epochs_per_curve[cid]: # queries (if any) 448 | x_[cutoff_per_curve[cid] :] = ( 449 | np.random.choice( 450 | np.arange(cutoff_per_curve[cid] + 1, n_levels + 1), 451 | size=epochs_per_curve[cid] - cutoff_per_curve[cid], 452 | replace=False, 453 | ) 454 | / n_levels 455 | ) 456 | curve_xs.append(x_) 457 | # determine y's 458 | y_ = curves(x_, cid) 459 | curve_ys.append(y_) 460 | else: 461 | curve_xs.append(None) 462 | curve_ys.append(None) 463 | 464 | # construct the batch data element 465 | curve_counters = torch.zeros(seq_len).type(torch.int64) 466 | for i in range(seq_len): 467 | cid = ordering[i] 468 | if i < single_eval_pos or curve_counters[cid] > 0: 469 | id_curve[i] = cid + 1 # reserve ID 0 for queries 470 | else: 471 | id_curve[i] = 0 # queries for unseen curves always have ID 0 472 | epoch[i] = curve_xs[cid][curve_counters[cid]] 473 | config[i] = torch.from_numpy(curve_configs[cid]) 474 | curve_val[i] = curve_ys[cid][curve_counters[cid]] 475 | curve_counters[cid] += 1 476 | 477 | x.append(torch.cat([torch.stack([id_curve, epoch], dim=1), config], dim=1)) 478 | y.append(curve_val) 479 | 480 | x = torch.stack(x, dim=1).to(device).float() 481 | y = torch.stack(y, dim=1).to(device).float() 482 | 483 | return Batch(x=x, y=y, target_y=y) 484 | 485 | 486 | class MultiCurvesEncoder(torch.nn.Module): 487 | def __init__(self, in_dim: int, out_dim: int) -> None: 488 | super().__init__() 489 | seq_len = 1000 490 | self.normalizer = torch.nn.Sequential( 491 | encoders.Normalize(0.5, math.sqrt(1 / 12)), 492 | ) 493 | self.epoch_enc = torch.nn.Linear(1, out_dim, bias=False) 494 | self.idcurve_enc = 
torch.nn.Embedding(seq_len + 1, out_dim) 495 | self.configuration_enc = encoders.get_variable_num_features_encoder(encoders.Linear)( 496 | in_dim - 2, out_dim 497 | ) 498 | 499 | def forward(self, *x, **kwargs) -> torch.Tensor: 500 | x = torch.cat(x, dim=-1) 501 | out = ( 502 | self.epoch_enc(self.normalizer(x[..., 1:2])) 503 | + self.idcurve_enc(x[..., :1].int()).squeeze(2) 504 | + self.configuration_enc(x[..., 2:]) 505 | ) 506 | return out 507 | 508 | 509 | def get_encoder() -> Callable[[int, int], torch.nn.Module]: 510 | return lambda num_features, emsize: MultiCurvesEncoder(num_features, emsize) 511 | 512 | 513 | def sample_curves( 514 | num_hyperparameters: int = 1000, curve_length: int = 100, hyperparameter_dimensions: int = 2 515 | ) -> tuple[np.ndarray, np.ndarray]: 516 | dataset_prior = DatasetPrior(hyperparameter_dimensions, 23) 517 | hyperparameters = np.random.uniform(size=(num_hyperparameters, hyperparameter_dimensions)) 518 | dataset_prior.new_dataset() 519 | curve_sampler = dataset_prior.curves_for_configs(hyperparameters) 520 | curves = np.array( 521 | [curve_sampler(np.linspace(0, 1, curve_length), cid) for cid in range(num_hyperparameters)] 522 | ) 523 | return hyperparameters, curves 524 | -------------------------------------------------------------------------------- /ifbo/priors/output_sorted.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ifBO/53f1207e27e00059fce308c51b91f067f96f50c4/ifbo/priors/output_sorted.npy -------------------------------------------------------------------------------- /ifbo/priors/prior.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABCMeta 4 | from abc import abstractmethod 5 | from collections.abc import Callable 6 | from dataclasses import dataclass 7 | from dataclasses import fields 8 | from typing import Any 9 | 10 | import torch 11 | 
from torch.utils.data import DataLoader 12 | 13 | 14 | @dataclass 15 | class Batch: 16 | """ 17 | A batch of data, with non-optional x, y, and target_y attributes. 18 | All other attributes are optional. 19 | 20 | If you want to add an attribute for testing only, you can just assign it after creation like: 21 | ``` 22 | batch = Batch(x=x, y=y, target_y=target_y) 23 | batch.test_attribute = test_attribute 24 | ``` 25 | """ 26 | 27 | # Required entries 28 | x: torch.Tensor 29 | y: torch.Tensor 30 | target_y: torch.Tensor 31 | 32 | # Optional Batch Entries 33 | style: torch.Tensor | None = None 34 | style_hyperparameter_values: torch.Tensor | None = None 35 | single_eval_pos: torch.Tensor | None = None 36 | causal_model_dag: object | None = None 37 | mean_prediction: bool | None = ( 38 | None # this controls whether to do mean prediction in bar_distribution for nonmyopic BO 39 | ) 40 | 41 | def other_filled_attributes( 42 | self, set_of_attributes: set[str] = set(("x", "y", "target_y")) 43 | ) -> list[str]: 44 | return [ 45 | f.name 46 | for f in fields(self) 47 | if f.name not in set_of_attributes and getattr(self, f.name) is not None 48 | ] 49 | 50 | 51 | def safe_merge_batches_in_batch_dim(*batches: Any, ignore_attributes: list[str] = []) -> Batch: 52 | """ 53 | Merge all supported non-None fields in a pre-specified (general) way, 54 | e.g. multiple batch.x are concatenated in the batch dimension. 55 | :param ignore_attributes: attributes to remove from the merged batch, treated as if they were None. 56 | :return: 57 | """ 58 | not_none_fields = [ 59 | f.name 60 | for f in fields(batches[0]) 61 | if f.name not in ignore_attributes and getattr(batches[0], f.name) is not None 62 | ] 63 | assert all( 64 | [ 65 | set(not_none_fields) 66 | == set( 67 | [ 68 | f.name 69 | for f in fields(b) 70 | if f.name not in ignore_attributes and getattr(b, f.name) is not None 71 | ] 72 | ) 73 | for b in batches 74 | ] 75 | ), "All batches must have the same fields!"
76 | merge_funcs = { 77 | "x": lambda xs: torch.cat(xs, 1), 78 | "y": lambda ys: torch.cat(ys, 1), 79 | "target_y": lambda target_ys: torch.cat(target_ys, 1), 80 | "style": lambda styles: torch.cat(styles, 0), 81 | } 82 | assert all( 83 | f in merge_funcs for f in not_none_fields 84 | ), "Unknown fields encountered in `safe_merge_batches_in_batch_dim`." 85 | return Batch( 86 | **{f: merge_funcs[f]([getattr(batch, f) for batch in batches]) for f in not_none_fields} # type: ignore 87 | ) 88 | 89 | 90 | def merge_batches(*batches: Any, ignore_attributes: list[str] = []) -> Batch: 91 | assert False, "TODO: isn't this broken!? because catting in dim 0 seems wrong!?" 92 | 93 | def merge_attribute(attr_name: str, batch_sizes: list[int]) -> Any: 94 | attr = [getattr(batch, attr_name) for batch in batches] 95 | if isinstance(attr[0], list): 96 | 97 | def make_list(sublist, i): 98 | if sublist is None: 99 | return [None for _ in range(batch_sizes[i])] 100 | return sublist 101 | 102 | return sum([make_list(sublist, i) for i, sublist in enumerate(attr)], []) 103 | elif type(attr[0]) is torch.Tensor: 104 | return torch.cat(attr, 0) 105 | else: 106 | assert all(a is None for a in attr), ( 107 | f"Unknown type encountered in `merge_batches`. " 108 | f"To ignore this, please add `{attr_name}` to the `ignore_attributes`. " 109 | f"The following values are the problem: {attr}."
110 | ) 111 | return None 112 | 113 | batch_sizes = [batch.x.shape[0] for batch in batches] 114 | return Batch( 115 | **{ 116 | f.name: merge_attribute(f.name, batch_sizes) 117 | for f in fields(batches[0]) 118 | if f.name not in ignore_attributes 119 | } 120 | ) 121 | 122 | 123 | class PriorDataLoader(DataLoader, metaclass=ABCMeta): 124 | @abstractmethod 125 | def __init__( 126 | self, 127 | num_steps: int, 128 | batch_size: int, 129 | eval_pos_seq_len_sampler: Callable, 130 | seq_len_maximum: int, 131 | device: torch.device | None, 132 | **kwargs: Any, 133 | ) -> None: 134 | """ 135 | 136 | :param num_steps: int, first argument, the number of steps to take per epoch, i.e. iteration of the DataLoader 137 | :param batch_size: int, number of datasets per batch 138 | :param eval_pos_seq_len_sampler: callable, it takes no arguments and returns a tuple (single eval pos, bptt) 139 | :param kwargs: for future compatibility it is good to have a final all catch, as new kwargs might be introduced 140 | """ 141 | pass 142 | 143 | # A class or object variable `num_features`: int 144 | # Optional: `validate` function that accepts a transformer model 145 | 146 | # The DataLoader iter should return batches of the form ([style], x, y), target_y, single_eval_pos 147 | # We follow sequence len (s) first, batch size (b) second. So x: (s,b,num_features), y,target_y: (s,b) 148 | # and style: Optional[(b,num_style_params)], style can be omitted or set to None, if it is not intended to be used. 149 | 150 | # For more references, see `priors/utils.py` for a pretty general implementation of a DataLoader 151 | # and `train.py` for the only call of it. 
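The shape convention described in the comments above (sequence length first, batch size second) and the dimension-1 merge used by `safe_merge_batches_in_batch_dim` can be illustrated without PyTorch. The following is a minimal sketch and not part of ifBO: `MiniBatch`, `filled_attributes`, and `cat_batch_dim` are hypothetical names, and nested lists stand in for tensors.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class MiniBatch:
    """Hypothetical stand-in for `Batch`; nested lists replace tensors."""

    x: list  # (seq_len, batch_size, num_features)
    y: list  # (seq_len, batch_size)
    target_y: list  # (seq_len, batch_size)
    style: Optional[list] = None  # optional, (batch_size, num_style_params)

    def filled_attributes(self) -> list:
        # Names of all fields that are not None, in declaration order.
        return [f.name for f in fields(self) if getattr(self, f.name) is not None]


def cat_batch_dim(parts: list) -> list:
    # Concatenate along dimension 1: for each sequence position,
    # join the per-batch entries of every part.
    return [sum(rows, []) for rows in zip(*parts)]


# Two batches with seq_len = 2 and batch sizes 1 and 2.
a = MiniBatch(x=[[[0.0, 0.1]], [[0.0, 0.2]]], y=[[0.5], [0.6]], target_y=[[0.5], [0.6]])
b = MiniBatch(
    x=[[[1.0, 0.1], [2.0, 0.1]], [[1.0, 0.2], [2.0, 0.2]]],
    y=[[0.3, 0.4], [0.35, 0.45]],
    target_y=[[0.3, 0.4], [0.35, 0.45]],
)

merged_y = cat_batch_dim([a.y, b.y])  # seq_len stays 2, batch size becomes 3
merged_x = cat_batch_dim([a.x, b.x])  # feature dimension is left untouched
```

With real tensors the same merge is `torch.cat(xs, 1)`, as in the `merge_funcs` table of `safe_merge_batches_in_batch_dim`.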
152 | -------------------------------------------------------------------------------- /ifbo/priors/prior_bag.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | import torch 6 | 7 | from ifbo.priors.prior import Batch 8 | from ifbo.priors.utils import get_batch_to_dataloader 9 | from ifbo.utils import default_device 10 | 11 | 12 | def get_batch( 13 | batch_size: int, 14 | seq_len: int, 15 | num_features: int, 16 | hyperparameters: dict[str, Any], 17 | device: torch.device = default_device, 18 | batch_size_per_gp_sample: int | None = None, 19 | **kwargs: Any, 20 | ) -> Batch: 21 | batch_size_per_gp_sample = batch_size_per_gp_sample or (min(64, batch_size)) 22 | num_models = batch_size // batch_size_per_gp_sample 23 | assert ( 24 | num_models * batch_size_per_gp_sample == batch_size 25 | ), f"Batch size ({batch_size}) not divisible by batch_size_per_gp_sample ({batch_size_per_gp_sample})" 26 | 27 | args = { 28 | "device": device, 29 | "seq_len": seq_len, 30 | "num_features": num_features, 31 | "batch_size": batch_size_per_gp_sample, 32 | } 33 | 34 | prior_bag_priors_get_batch = hyperparameters["prior_bag_get_batch"] 35 | prior_bag_priors_p = [1.0] + [ 36 | hyperparameters[f"prior_bag_exp_weights_{i}"] 37 | for i in range(1, len(prior_bag_priors_get_batch)) 38 | ] 39 | 40 | weights = torch.tensor(prior_bag_priors_p, dtype=torch.float) # create a tensor of weights 41 | batch_assignments = torch.multinomial( 42 | torch.softmax(weights, 0), num_models, replacement=True 43 | ).numpy() 44 | 45 | if "verbose" in hyperparameters and hyperparameters["verbose"]: 46 | print( 47 | "PRIOR_BAG:", 48 | weights, 49 | batch_assignments, 50 | num_models, 51 | batch_size_per_gp_sample, 52 | batch_size, 53 | ) 54 | sample: list[Batch] = [ 55 | prior_bag_priors_get_batch[int(prior_idx)]( 56 | hyperparameters=hyperparameters, **args, **kwargs 57 | ) 58 | for prior_idx in 
batch_assignments 59 | ] 60 | 61 | def merge(sample: list[Batch], k: str) -> Any: 62 | x = [getattr(x_, k) for x_ in sample] 63 | if torch.is_tensor(x[0]): 64 | return torch.cat(x, 1).detach() 65 | else: 66 | return [*x] 67 | 68 | merged_sample = {k: merge(sample, k) for k in sample[0].other_filled_attributes(set())} 69 | if hyperparameters.get("verbose"): 70 | print({k: v.shape for k, v in merged_sample.items()}) 71 | 72 | return Batch(**merged_sample) 73 | 74 | 75 | if __name__ == "__main__": 76 | DataLoader = get_batch_to_dataloader(get_batch) 77 | -------------------------------------------------------------------------------- /ifbo/priors/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | from collections.abc import Generator 5 | from functools import partial 6 | import inspect 7 | import math 8 | import os 9 | import random 10 | import sys 11 | import time 12 | import types 13 | from typing import Any 14 | 15 | import cloudpickle 16 | import numpy as np 17 | from scipy import stats 18 | import submitit 19 | import torch 20 | from torch import nn 21 | 22 | from ifbo.priors.prior import Batch 23 | from ifbo.utils import normalize_data 24 | from ifbo.utils import set_locals_in_self 25 | 26 | 27 | def get_uniform_sampler(min_eval_pos: int, seq_len: int) -> Callable[[], int]: 28 | print(f"Using this sampler single_eval_pos = {min_eval_pos} is equally likely as {seq_len-1}") 29 | 30 | def foo() -> int: 31 | return np.random.randint(min_eval_pos, seq_len) 32 | 33 | return foo 34 | 35 | 36 | def get_expon_sep_sampler(base: int, min_eval_pos: int, seq_len: int) -> Callable[[], int]: 37 | p_levels = np.array([np.power(float(base), i) for i in range(seq_len - min_eval_pos)]) # float base keeps the in-place division below valid 38 | p_levels /= p_levels.sum() 39 | print( 40 | f"Using this sampler single_eval_pos = {min_eval_pos} is {p_levels[0]/p_levels[-1]} times more likely than {seq_len-1}" 41 | ) 42
| 43 | def foo() -> int: 44 | return np.random.choice(seq_len - min_eval_pos, p=p_levels) + min_eval_pos 45 | 46 | return foo 47 | 48 | 49 | class PriorDataLoader: 50 | def _load_chunk(self, chunk_id: int) -> None: 51 | if self.partition: 52 | partition_id = chunk_id // 1000 53 | chunk_file = os.path.join( 54 | self.path, f"partition_{partition_id}", f"chunk_{chunk_id}.pkl" 55 | ) 56 | else: 57 | chunk_file = os.path.join(self.path, f"chunk_{chunk_id}.pkl") 58 | with open(chunk_file, "rb") as f: 59 | self.loaded_chunk = cloudpickle.load(f) 60 | self.loaded_chunk_id = chunk_id 61 | self.batch_counter = 0 62 | self.subsample_counter = 0 63 | 64 | def __init__( 65 | self, 66 | load_path: str, 67 | n_chunks: int = 2_000, 68 | store: bool = False, 69 | subsample: int = 1, 70 | partition: bool | None = None, 71 | ) -> None: 72 | self.path = load_path 73 | if store: 74 | self.partition = True # new collections are partitioned 75 | elif partition is None: 76 | # check whether the current directory contains a directory partition_0 77 | self.partition = os.path.isdir(os.path.join(self.path, "partition_0")) 78 | else: 79 | self.partition = partition 80 | if not store: 81 | self._load_chunk(0) 82 | 83 | self.n_chunks = n_chunks 84 | self.subsample = subsample 85 | 86 | def get_batch(self, device: torch.device | None) -> Batch: 87 | if self.subsample == 1: 88 | _, batch_data = self.loaded_chunk[self.batch_counter] 89 | batch_data.x = batch_data.x.to(device) 90 | batch_data.y = batch_data.y.to(device) 91 | batch_data.target_y = batch_data.target_y.to(device) 92 | self.batch_counter += 1 93 | if self.batch_counter >= len(self.loaded_chunk): 94 | self._load_chunk((self.loaded_chunk_id + 1) % self.n_chunks) 95 | else: 96 | _, full_batch_data = self.loaded_chunk[self.batch_counter] 97 | seq_len, batch_size = full_batch_data.y.shape 98 | subsample_size = batch_size // self.subsample 99 | if self.subsample_counter < self.subsample - 1: 100 | low = subsample_size * 
self.subsample_counter 101 | high = subsample_size * (self.subsample_counter + 1) 102 | batch_data = Batch( 103 | full_batch_data.x[:, low:high, :].to(device), 104 | full_batch_data.y[:, low:high].to(device), 105 | full_batch_data.target_y[:, low:high].to(device), 106 | ) 107 | self.subsample_counter += 1 108 | else: 109 | low = subsample_size * self.subsample_counter 110 | batch_data = Batch( 111 | full_batch_data.x[:, low:, :].to(device), 112 | full_batch_data.y[:, low:].to(device), 113 | full_batch_data.target_y[:, low:].to(device), 114 | ) 115 | self.subsample_counter = 0 116 | self.batch_counter += 1 117 | if self.batch_counter >= len(self.loaded_chunk): 118 | self._load_chunk((self.loaded_chunk_id + 1) % self.n_chunks) 119 | 120 | return batch_data 121 | 122 | def get_single_eval_pos(self) -> int: 123 | single_eval_pos, _ = self.loaded_chunk[self.batch_counter] 124 | if single_eval_pos == 1000: 125 | print( 126 | "WARNING: as a TEMP hack single eval pos = 1000 is manually corrected to 999", 127 | file=sys.stderr, 128 | ) 129 | single_eval_pos = 999 130 | return single_eval_pos 131 | 132 | def store_prior( 133 | self, 134 | prior: Any, 135 | local: bool = False, 136 | chunk_size: int = 1_000, 137 | batch_size: int = 25, 138 | seq_len: int = 1_000, 139 | n_features: int = 12, 140 | prior_hyperparameters: dict[str, Any] = {}, 141 | partition: str = "gki_cpu-cascadelake", 142 | eval_pos_sampler: Callable[[], int] | None = None, 143 | ) -> None: 144 | # generate batches in parallel and store them for efficient training 145 | assert chunk_size % batch_size == 0 146 | 147 | def store_batch( 148 | path: str, 149 | chunk_id: int, 150 | chunk_size: int, 151 | batch_size: int, 152 | seq_len: int, 153 | n_features: int, 154 | partition: bool, 155 | prior_hyperparameters: dict[str, Any], 156 | ) -> None: 157 | if partition: 158 | partition_id = chunk_id // 1000 159 | chunk_dir = os.path.join(self.path, f"partition_{partition_id}") 160 | chunk_file = 
os.path.join(chunk_dir, f"chunk_{chunk_id}.pkl") 161 | else: 162 | chunk_file = os.path.join(self.path, f"chunk_{chunk_id}.pkl") 163 | if not os.path.exists(chunk_file): 164 | np.random.seed((os.getpid() * int(time.time())) % 123456789) 165 | chunk_data = [] 166 | for bid in range(chunk_size // batch_size): 167 | if eval_pos_sampler is None: 168 | # sample single eval pos log-uniformly ({1, ..., seq_len} log-uniformly - 1) 169 | single_eval_pos = int( 170 | np.floor(np.exp(np.random.uniform(0, np.log(seq_len + 1)))) - 1 171 | ) 172 | else: 173 | single_eval_pos = eval_pos_sampler() 174 | assert single_eval_pos < seq_len 175 | b = prior.get_batch( # type: ignore 176 | batch_size=batch_size, 177 | single_eval_pos=single_eval_pos, 178 | seq_len=seq_len, 179 | num_features=n_features, 180 | hyperparameters=prior_hyperparameters, 181 | ) 182 | chunk_data.append((single_eval_pos, b)) 183 | with open(chunk_file, "wb") as file: 184 | cloudpickle.dump(chunk_data, file) 185 | else: 186 | print("Already done.") 187 | 188 | if partition: 189 | for partition_id in range(self.n_chunks // 1000): 190 | chunk_dir = os.path.join(self.path, f"partition_{partition_id}") 191 | if not os.path.exists(chunk_dir): 192 | print(f"Creating directory: {chunk_dir}") 193 | os.makedirs(chunk_dir) 194 | 195 | kwargss = [ 196 | { 197 | "path": self.path, 198 | "chunk_id": i, 199 | "chunk_size": chunk_size, 200 | "batch_size": batch_size, 201 | "seq_len": seq_len, 202 | "n_features": n_features, 203 | "partition": self.partition, 204 | "prior_hyperparameters": prior_hyperparameters, 205 | } 206 | for i in range(0, self.n_chunks) 207 | ] 208 | 209 | # check how long a task takes & if too long run them on submitit 210 | runonsubmitit = False 211 | done = 0 212 | for kwargs in kwargss: 213 | print(f"Calculating the {kwargs['chunk_id']}th chunk...") 214 | before = time.time() 215 | store_batch(**kwargs) # type: ignore 216 | duration = time.time() - before 217 | print(f"Done, took {duration}s") 218 | 
done += 1 219 | if not local and duration > 5: 220 | # run stuff on submitit 221 | runonsubmitit = True 222 | # allocate 3x the measured duration (+ 1 min) to avoid timeouts 223 | tlimit = int(3 * duration / 60.0 + 1) 224 | print(f"Generating remaining chunks using submitit with time limit {tlimit}min") 225 | break 226 | 227 | if runonsubmitit: 228 | executor = submitit.get_executor(folder="/tmp/") 229 | executor.update_parameters( 230 | time=tlimit, 231 | # partition="alldlc_gpu-rtx2080", 232 | partition=partition, 233 | cpus_per_task=1, 234 | slurm_gres="gpu:0", 235 | ) 236 | 237 | kwargss = kwargss[done:] 238 | job_name = os.path.basename(self.path) 239 | print(job_name, len(kwargss)) 240 | job_group = executor.submit_group(job_name, store_batch, kwargss) 241 | print(job_group) 242 | 243 | 244 | def get_rank() -> int: 245 | if "LOCAL_RANK" in os.environ: 246 | # launched with torch.distributed.launch 247 | rank = int(os.environ["LOCAL_RANK"]) 248 | return rank 249 | elif "SLURM_PROCID" in os.environ and torch.cuda.device_count() > 1: 250 | # this is for multi gpu when starting with submitit 251 | rank = int(os.environ["SLURM_PROCID"]) 252 | return rank 253 | else: 254 | # not using a distributed setting, assume rank 0 255 | return 0 256 | 257 | 258 | class DistributedPriorDataLoader(PriorDataLoader): 259 | def __init__( 260 | self, 261 | load_path: str, 262 | n_gpus: int = 1, 263 | n_chunks: int = 2_000, 264 | store: bool = False, 265 | subsample: int = 1, 266 | partition: bool | None = None, 267 | ): 268 | self.path = load_path 269 | if store: 270 | self.partition = True # new collections are partitioned 271 | elif partition is None: 272 | # check whether the current directory contains a directory partition_0 273 | self.partition = os.path.isdir(os.path.join(self.path, "partition_0")) 274 | else: 275 | self.partition = partition 276 | if not store: 277 | self.n_gpus = n_gpus 278 | self.loaded_chunk = ( 279 | None # lazy load, as the rank of the process may be unknown on 
initialization 280 | ) 281 | 282 | self.n_chunks = n_chunks 283 | self.subsample = subsample 284 | self.rank: int 285 | 286 | def data_sync(self) -> None: 287 | # lazy loading or reloading in case this object is used by multiple processes (should be avoided) 288 | if self.loaded_chunk is None or self.rank != get_rank(): 289 | self.rank = get_rank() 290 | offset = self.rank * self.n_chunks // self.n_gpus 291 | self._load_chunk(offset) 292 | 293 | def get_batch(self, device: torch.device | None) -> Batch: 294 | self.data_sync() 295 | 296 | if self.subsample == 1: 297 | _, batch_data = self.loaded_chunk[self.batch_counter] 298 | batch_data.x = batch_data.x.to(device) 299 | batch_data.y = batch_data.y.to(device) 300 | batch_data.target_y = batch_data.target_y.to(device) 301 | self.batch_counter += 1 302 | if self.batch_counter >= len(self.loaded_chunk): 303 | self._load_chunk((self.loaded_chunk_id + 1) % self.n_chunks) 304 | else: 305 | _, full_batch_data = self.loaded_chunk[self.batch_counter] 306 | seq_len, batch_size = full_batch_data.y.shape 307 | subsample_size = batch_size // self.subsample 308 | if self.subsample_counter < self.subsample - 1: 309 | low = subsample_size * self.subsample_counter 310 | high = subsample_size * (self.subsample_counter + 1) 311 | batch_data = Batch( 312 | full_batch_data.x[:, low:high, :].to(device), 313 | full_batch_data.y[:, low:high].to(device), 314 | full_batch_data.target_y[:, low:high].to(device), 315 | ) 316 | self.subsample_counter += 1 317 | else: 318 | low = subsample_size * self.subsample_counter 319 | batch_data = Batch( 320 | full_batch_data.x[:, low:, :].to(device), 321 | full_batch_data.y[:, low:].to(device), 322 | full_batch_data.target_y[:, low:].to(device), 323 | ) 324 | self.subsample_counter = 0 325 | self.batch_counter += 1 326 | if self.batch_counter >= len(self.loaded_chunk): 327 | self._load_chunk((self.loaded_chunk_id + 1) % self.n_chunks) 328 | 329 | return batch_data 330 | 331 | def get_single_eval_pos(self) 
-> int: 332 | self.data_sync() 333 | 334 | single_eval_pos, _ = self.loaded_chunk[self.batch_counter] 335 | if single_eval_pos == 1000: 336 | print( 337 | "WARNING: as a TEMP hack single eval pos = 1000 is manually corrected to 999", 338 | file=sys.stderr, 339 | ) 340 | single_eval_pos = 999 341 | return single_eval_pos 342 | 343 | 344 | def get_batch_to_dataloader(get_batch_method_: Callable) -> PriorDataLoader: 345 | # DL = partial(DL, get_batch_method=get_batch_method_) 346 | class DL(PriorDataLoader): 347 | get_batch_method = get_batch_method_ 348 | 349 | # Caution, you might need to set self.num_features manually if it is not part of the args. 350 | def __init__(self, num_steps: int, **get_batch_kwargs: Any) -> None: 351 | set_locals_in_self(locals()) 352 | 353 | # The stuff outside the or is set as class attribute before instantiation. 354 | self.num_features: int = get_batch_kwargs.get("num_features") or self.num_features 355 | self.epoch_count = 0 356 | print("DataLoader.__dict__", self.__dict__) 357 | self.num_steps: int 358 | self.get_batch_kwargs: dict[str, Any] 359 | 360 | @staticmethod 361 | def gbm( 362 | *args: Any, eval_pos_seq_len_sampler: Callable[[], tuple[int, int]], **kwargs: Any 363 | ) -> Batch: 364 | kwargs["single_eval_pos"], kwargs["seq_len"] = eval_pos_seq_len_sampler() 365 | # Scales the batch size dynamically with the power of 'dynamic_batch_size'. 366 | # A transformer with quadratic memory usage in the seq len would need a power of 2 to keep memory constant. 
367 | if ( 368 | "dynamic_batch_size" in kwargs 369 | and kwargs["dynamic_batch_size"] is not None 370 | and kwargs["dynamic_batch_size"] > 0 371 | ): 372 | kwargs["batch_size"] = kwargs["batch_size"] * math.floor( 373 | math.pow(kwargs["seq_len_maximum"], kwargs["dynamic_batch_size"]) 374 | / math.pow(kwargs["seq_len"], kwargs["dynamic_batch_size"]) 375 | ) 376 | batch: Batch = get_batch_method_(*args, **kwargs) 377 | if batch.single_eval_pos is None: 378 | batch.single_eval_pos = kwargs["single_eval_pos"] 379 | return batch 380 | 381 | def __len__(self) -> int: 382 | return self.num_steps 383 | 384 | def get_test_batch(self, **kwargs: Any) -> Batch: # does not increase epoch_count 385 | return self.gbm( 386 | **self.get_batch_kwargs, 387 | epoch=self.epoch_count, 388 | model=self.model if hasattr(self, "model") else None, 389 | **kwargs, 390 | ) 391 | 392 | def __iter__(self) -> Generator: 393 | assert hasattr( 394 | self, "model" 395 | ), "Please assign model with `dl.model = ...` before training." 
396 | self.epoch_count += 1 397 | return iter( 398 | self.gbm(**self.get_batch_kwargs, epoch=self.epoch_count - 1, model=self.model) 399 | for _ in range(self.num_steps) 400 | ) 401 | 402 | return DL # type: ignore 403 | 404 | 405 | def trunc_norm_sampler_f(mu: float, sigma: float) -> Callable[[], float]: 406 | return lambda: stats.truncnorm( 407 | (0 - mu) / sigma, (1000000 - mu) / sigma, loc=mu, scale=sigma 408 | ).rvs(1)[0] 409 | 410 | 411 | def beta_sampler_f(a: float, b: float) -> Callable[[], float]: 412 | return lambda: np.random.beta(a, b) 413 | 414 | 415 | def gamma_sampler_f(a: float, b: float) -> Callable[[], float]: 416 | return lambda: np.random.gamma(a, b) 417 | 418 | 419 | def uniform_sampler_f(a: float, b: float) -> Callable[[], float]: 420 | return lambda: np.random.uniform(a, b) 421 | 422 | 423 | def uniform_int_sampler_f(a: float, b: float) -> Callable[[], float]: 424 | return lambda: round(np.random.uniform(a, b)) 425 | 426 | 427 | def zipf_sampler_f(a: float, b: float, c: float) -> Callable[[], int]: 428 | x = np.arange(b, c) 429 | weights = x ** (-a) 430 | weights /= weights.sum() 431 | return lambda: stats.rv_discrete(name="bounded_zipf", values=(x, weights)).rvs(1) 432 | 433 | 434 | def scaled_beta_sampler_f(a: float, b: float, scale: float, minimum: float) -> Callable[[], float]: 435 | return lambda: minimum + round(beta_sampler_f(a, b)() * (scale - minimum)) 436 | 437 | 438 | def normalize_by_used_features_f( 439 | x: torch.Tensor, num_features_used: int, num_features: int, normalize_with_sqrt: bool = False 440 | ) -> torch.Tensor: 441 | if normalize_with_sqrt: 442 | return x / (num_features_used / num_features) ** (1 / 2) 443 | return x / (num_features_used / num_features) 444 | 445 | 446 | def order_by_y(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 447 | order = torch.argsort(y if random.randint(0, 1) else -y, dim=0)[:, 0, 0] 448 | order = order.reshape(2, -1).transpose(0, 1).reshape(-1) # .reshape(seq_len) 449 
| x = x[order] # .reshape(2, -1).transpose(0, 1).reshape(-1).flip([0]).reshape(seq_len, 1, -1) 450 | y = y[order] # .reshape(2, -1).transpose(0, 1).reshape(-1).reshape(seq_len, 1, -1) 451 | 452 | return x, y 453 | 454 | 455 | def randomize_classes(x: torch.Tensor, num_classes: int) -> torch.Tensor: 456 | classes = torch.arange(0, num_classes, device=x.device) 457 | random_classes = torch.randperm(num_classes, device=x.device).type(x.type()) 458 | x = ((x.unsqueeze(-1) == classes) * random_classes).sum(-1) 459 | return x 460 | 461 | 462 | @torch.no_grad() 463 | def sample_num_feaetures_get_batch( 464 | batch_size: int, 465 | seq_len: int, 466 | num_features: int, 467 | hyperparameters: dict[str, Any], 468 | get_batch: Callable, 469 | **kwargs: Any, 470 | ) -> Batch: 471 | if ( 472 | hyperparameters.get("sample_num_features", True) and kwargs["epoch"] > 0 473 | ): # don't sample on test batch 474 | num_features = random.randint(3, num_features) 475 | return get_batch(batch_size, seq_len, num_features, hyperparameters=hyperparameters, **kwargs) 476 | 477 | 478 | class CategoricalActivation(nn.Module): 479 | def __init__( 480 | self, 481 | categorical_p: float = 0.1, 482 | ordered_p: float = 0.7, 483 | keep_activation_size: bool = False, 484 | num_classes_sampler: Callable[[], int] = zipf_sampler_f(0.8, 1, 10), 485 | ) -> None: 486 | self.categorical_p = categorical_p 487 | self.ordered_p = ordered_p 488 | self.keep_activation_size = keep_activation_size 489 | self.num_classes_sampler = num_classes_sampler 490 | 491 | super().__init__() 492 | 493 | def forward(self, x: torch.Tensor) -> torch.Tensor: 494 | # x shape: T, B, H 495 | 496 | x = nn.Softsign()(x) 497 | 498 | num_classes = self.num_classes_sampler() 499 | hid_strength = torch.abs(x).mean(0).unsqueeze(0) if self.keep_activation_size else None 500 | 501 | categorical_classes = torch.rand((x.shape[1], x.shape[2])) < self.categorical_p 502 | class_boundaries = torch.zeros( 503 | (num_classes - 1, x.shape[1], 
x.shape[2]), device=x.device, dtype=x.dtype 504 | ) 505 | # Sample a different index for each hidden dimension, but shared for all batches 506 | for b in range(x.shape[1]): 507 | for h in range(x.shape[2]): 508 | ind = torch.randint(0, x.shape[0], (num_classes - 1,)) 509 | class_boundaries[:, b, h] = x[ind, b, h] 510 | 511 | for b in range(x.shape[1]): 512 | x_rel = x[:, b, categorical_classes[b]] 513 | boundaries_rel = class_boundaries[:, b, categorical_classes[b]].unsqueeze(1) 514 | x[:, b, categorical_classes[b]] = (x_rel > boundaries_rel).sum( 515 | dim=0 516 | ).float() - num_classes / 2 517 | 518 | ordered_classes = torch.rand((x.shape[1], x.shape[2])) < self.ordered_p 519 | ordered_classes = torch.logical_and(ordered_classes, categorical_classes) 520 | x[:, ordered_classes] = randomize_classes(x[:, ordered_classes], num_classes) 521 | 522 | x = x * hid_strength if self.keep_activation_size else x 523 | 524 | return x 525 | 526 | 527 | class QuantizationActivation(torch.nn.Module): 528 | def __init__(self, n_thresholds: int, reorder_p: float = 0.5) -> None: 529 | super().__init__() 530 | self.n_thresholds = n_thresholds 531 | self.reorder_p = reorder_p 532 | self.thresholds = torch.nn.Parameter(torch.randn(self.n_thresholds)) 533 | 534 | def forward(self, x: torch.Tensor) -> torch.Tensor: 535 | x = normalize_data(x) 536 | assert isinstance(x, torch.Tensor) 537 | x = x.unsqueeze(-1) 538 | x = (x > self.thresholds).sum(-1) 539 | 540 | if random.random() < self.reorder_p: 541 | x = randomize_classes(x.unsqueeze(-1), self.n_thresholds).squeeze(-1) 542 | # x = ((x.float() - self.n_thresholds/2) / self.n_thresholds)# * data_std + data_mean 543 | x = normalize_data(x) 544 | return x 545 | 546 | 547 | class NormalizationActivation(torch.nn.Module): 548 | def __init__(self) -> None: 549 | super().__init__() 550 | 551 | def forward(self, x: torch.Tensor) -> torch.Tensor: 552 | x = normalize_data(x) 553 | return x 554 | 555 | 556 | class 
PowerActivation(torch.nn.Module): 557 | def __init__(self) -> None: 558 | super().__init__() 559 | # self.exp = torch.nn.Parameter(0.5 * torch.ones(1)) 560 | self.shared_exp_strength = 0.5 561 | # TODO: Somehow this is only initialized once, so it's the same for all runs 562 | 563 | def forward(self, x: torch.Tensor) -> torch.Tensor: 564 | # print(torch.nn.functional.softplus(x), self.exp) 565 | shared_exp = torch.randn(1) 566 | exp = torch.nn.Parameter( 567 | ( 568 | shared_exp * self.shared_exp_strength 569 | + shared_exp * torch.randn(x.shape[-1]) * (1 - self.shared_exp_strength) 570 | ) 571 | * 2 572 | + 0.5 573 | ).to(x.device) 574 | x_ = torch.pow(torch.nn.functional.softplus(x) + 0.001, exp) 575 | if False: 576 | print( 577 | x[0:3, 0, 0].cpu().numpy(), 578 | torch.nn.functional.softplus(x[0:3, 0, 0]).cpu().numpy(), 579 | x_[0:3, 0, 0].cpu().numpy(), 580 | normalize_data(x_)[0:3, 0, 0].cpu().numpy(), 581 | self.exp.cpu().numpy(), 582 | ) 583 | return x_ 584 | 585 | 586 | def lambda_time(f: Callable, name: str = "", enabled: bool = True) -> Any: 587 | if not enabled: 588 | return f() 589 | start = time.time() 590 | r = f() 591 | print("Timing", name, time.time() - start) 592 | return r 593 | 594 | 595 | def pretty_get_batch(get_batch: Callable) -> str: 596 | """ 597 | Generate a string representation of the get_batch function 598 | :param get_batch: 599 | :return: 600 | """ 601 | if isinstance(get_batch, types.FunctionType): 602 | return f"<{get_batch.__module__}.{get_batch.__name__} {inspect.signature(get_batch)}>" 603 | else: 604 | return repr(get_batch) 605 | 606 | 607 | class get_batch_sequence(list): 608 | """ 609 | This will call the get_batch_methods in order from the back and pass the previous as `get_batch` kwarg. 610 | For example for `get_batch_methods=[get_batch_1, get_batch_2, get_batch_3]` this will produce a call 611 | equivalent to `get_batch_3(*args,get_batch=partial(partial(get_batch_2),get_batch=get_batch_1,**kwargs))`. 
612 | get_batch_methods: all priors except the first must have a `get_batch` argument 613 | """ 614 | 615 | def __init__(self, *get_batch_methods: Any) -> None: 616 | if len(get_batch_methods) == 0: 617 | raise ValueError("Must have at least one get_batch method") 618 | super().__init__(get_batch_methods) 619 | 620 | def __repr__(self) -> str: 621 | s = ",\n\t".join([f"{pretty_get_batch(get_batch)}" for get_batch in self]) 622 | return f"get_batch_sequence(\n\t{s}\n)" 623 | 624 | def __call__(self, *args: Any, **kwargs: Any) -> Batch: 625 | """ 626 | 627 | Standard kwargs are: batch_size, seq_len, num_features 628 | This returns a priors.Batch object. 629 | """ 630 | final_get_batch = self[0] 631 | for get_batch in self[1:]: 632 | final_get_batch = partial(get_batch, get_batch=final_get_batch) 633 | return final_get_batch(*args, **kwargs) 634 | -------------------------------------------------------------------------------- /ifbo/surrogate.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from pathlib import Path 5 | import warnings 6 | 7 | import torch 8 | 9 | from ifbo.download import download_and_decompress 10 | from ifbo.download import FILE_URL 11 | from ifbo.download import FILENAME 12 | from ifbo.download import VERSION_MAP 13 | from ifbo.download import WEIGHTS_FINAL_NAME 14 | from ifbo.utils import Curve 15 | from ifbo.utils import PredictionResult 16 | from ifbo.utils import tokenize 17 | 18 | 19 | def _resolve_model_path(target_path: Path | None = None) -> Path: 20 | """Resolve the model path. 21 | 22 | Args: 23 | target_path: Path to the trained model. 24 | 25 | Returns: 26 | path: Path to the trained model. 27 | """ 28 | # Resolve target path 29 | if target_path is None: 30 | target_path = Path.cwd().absolute() / ".model" 31 | warnings.warn( 32 | "No target path provided. 
" f"Defaulting to current working directory: {target_path}" 33 | ) 34 | if target_path.name == ".model" and target_path.is_dir(): 35 | target_path = target_path.absolute() 36 | elif (target_path / ".model").is_dir() or ( 37 | target_path.is_dir() and not (target_path / ".model").is_dir() 38 | ): 39 | # if target_path is a directory, and if `.model` subdirectory exists or not 40 | target_path = (target_path / ".model").absolute() 41 | else: 42 | raise ValueError("Invalid target path. Please provide a valid directory path.") 43 | target_path.mkdir(parents=True, exist_ok=True) 44 | 45 | return target_path 46 | 47 | 48 | class FTPFN(torch.nn.Module): 49 | """FTPFN surrogate model.""" 50 | 51 | def __init__( 52 | self, 53 | target_path: Path | str | None = None, 54 | version: str = "0.0.1", 55 | device: torch.device | None = None, 56 | ): 57 | """Initialize the FTPFN surrogate model. 58 | 59 | Args: 60 | target_path (Path, optional): Path to the trained model. Defaults to None. 61 | If None, creates a `.model/` directory in the current working directory. 62 | version (str, optional): Version of the model. Defaults to "0.0.1". 63 | """ 64 | super(FTPFN, self).__init__() 65 | 66 | self.version = version 67 | if target_path is None: 68 | # choose the current working directory if no path is provided 69 | target_path = Path.cwd().absolute() 70 | warnings.warn( 71 | "No target path provided. Defaulting to current" 72 | f" working directory: {target_path}." 73 | "\nPlease provide the above path or any other valid path to avoid this warning." 
74 | ) 75 | self.target_path = _resolve_model_path( 76 | Path(target_path).absolute() 77 | if isinstance(target_path, str) 78 | else target_path.absolute() 79 | ) 80 | self.device = device 81 | 82 | if self.version not in VERSION_MAP: 83 | raise ValueError(f"Version {version} is not available for the surrogate model!") 84 | 85 | _target_file_zip = self.target_path / FILENAME(self.version) 86 | download_and_decompress(url=FILE_URL(self.version), path=_target_file_zip) 87 | 88 | # Loading and initializing the model with the pre-trained weights 89 | self.model = torch.load( 90 | os.path.join(self.target_path, WEIGHTS_FINAL_NAME(version)), 91 | map_location=self.device if self.device is not None else torch.device("cpu"), 92 | # TODO: See issue #12 93 | weights_only=False, 94 | ) 95 | self.model.eval() 96 | 97 | @torch.no_grad() 98 | def predict(self, context: list[Curve], query: list[Curve]) -> list[PredictionResult]: 99 | """Obtain the logits for the given context and query curves. 100 | 101 | Function to perform Bayesian inference using FT-PFN that uses the logits obtained to 102 | compute various measures like likelihood, UCB, EI, PI, and quantile. 103 | 104 | Args: 105 | context (list[Curve]): List of context curves. 106 | query (list[Curve]): List of query curves. 
107 | 108 | Returns: 109 | list[PredictionResult]: List of prediction results for each query curve 110 | """ 111 | x_train, y_train, x_test = tokenize(context, query, device=self.device) 112 | logits = self(x_train=x_train, y_train=y_train, x_test=x_test) 113 | results = torch.split(logits, [len(curve.t) for curve in query], dim=0) 114 | return [ 115 | PredictionResult( 116 | logits=logit, 117 | criterion=self.model.criterion, 118 | ) 119 | for curve, logit in zip(query, results) 120 | ] 121 | 122 | def _check_input( 123 | self, x_train: torch.Tensor, y_train: torch.Tensor, x_test: torch.Tensor 124 | ) -> None: 125 | """Check the input values.""" 126 | if y_train.min() < 0 or y_train.max() > 1: 127 | raise Exception("y values should be in the range [0,1]") 128 | if ( 129 | x_train[:, 1].min() < 0 130 | or x_train[:, 1].max() > 1 131 | or x_test[:, 1].min() < 0 132 | or x_test[:, 1].max() > 1 133 | ): 134 | raise Exception("step values should be in the range [0,1]") 135 | if ( 136 | x_train[:, 0].min() < 0 137 | or x_train[:, 0].max() > 1000 138 | or x_test[:, 0].min() < 0 139 | or x_test[:, 0].max() > 1000 140 | ): 141 | raise Exception("id values should be in the range [0,1000]") 142 | if ( 143 | x_train[:, 2:].min() < 0 144 | or x_train[:, 2:].max() > 1 145 | or x_test[:, 2:].min() < 0 146 | or x_test[:, 2:].max() > 1 147 | ): 148 | raise Exception("hyperparameter values should be in the range [0,1]") 149 | 150 | def forward( 151 | self, x_train: torch.Tensor, y_train: torch.Tensor, x_test: torch.Tensor 152 | ) -> torch.Tensor: 153 | """Forward pass through the model. 154 | 155 | Args: 156 | x_train (torch.Tensor): context points, shape (n_context_observations x features). 157 | y_train (torch.Tensor): context values, shape (n_context_observations). 158 | x_test (torch.Tensor): query points, shape (n_query_observations x features). 159 | 160 | Returns: 161 | torch.Tensor: logits for the query points. 
162 | """ 163 | 164 | self._check_input(x_train, y_train, x_test) 165 | if x_train.shape[0] == 0: 166 | x_test[:, 0] = 0 167 | elif x_train[:, 0].min() == 0: 168 | x_train[:, 0] += 1 169 | x_test[:, 0] += 1 170 | 171 | # reserve id=0 to curves that are not in x_train 172 | # set to 0 for all id in x_test[:, 0] that is not x_train[:, 0] 173 | x_test[:, 0] = torch.where( 174 | torch.isin(x_test[:, 0], x_train[:, 0]), 175 | x_test[:, 0], 176 | torch.zeros_like(x_test[:, 0]), 177 | ) 178 | 179 | single_eval_pos = x_train.shape[0] 180 | batch_size = 2000 # maximum batch size 181 | n_batches = (x_test.shape[0] + batch_size - 1) // batch_size 182 | 183 | results = [] 184 | for i in range(n_batches): 185 | start = i * batch_size 186 | end = min((i + 1) * batch_size, x_test.shape[0]) 187 | x_batch = torch.cat([x_train, x_test[start:end]], dim=0).unsqueeze(1) 188 | y_batch = y_train.unsqueeze(1) 189 | result = self.model((x_batch, y_batch), single_eval_pos=single_eval_pos) 190 | results.append(result) 191 | 192 | return torch.cat(results, dim=0) 193 | -------------------------------------------------------------------------------- /ifbo/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | from contextlib import nullcontext 5 | import itertools 6 | import time 7 | from typing import Any 8 | 9 | import torch 10 | from torch import nn 11 | from torch.cuda.amp import autocast 12 | from torch.cuda.amp import GradScaler 13 | from tqdm import tqdm 14 | 15 | from ifbo import positional_encodings 16 | from ifbo import utils 17 | from ifbo.bar_distribution import BarDistribution 18 | from ifbo.bar_distribution import get_custom_bar_dist 19 | from ifbo.priors import prior 20 | from ifbo.transformer import TransformerModel 21 | from ifbo.utils import get_cosine_schedule_with_warmup 22 | from ifbo.utils import get_openai_lr 23 | from ifbo.utils import init_dist 24 
| 25 | 26 | class Losses: 27 | def get_cross_entropy_loss(self, num_classes: int) -> nn.CrossEntropyLoss: 28 | return nn.CrossEntropyLoss(reduction="none", weight=torch.ones(num_classes)) 29 | 30 | gaussian = nn.GaussianNLLLoss(full=True, reduction="none") 31 | mse = nn.MSELoss(reduction="none") 32 | ce = get_cross_entropy_loss 33 | bce = nn.BCEWithLogitsLoss(reduction="none") 34 | get_BarDistribution = BarDistribution 35 | 36 | 37 | def train( 38 | priordataloader_class: prior.PriorDataLoader, 39 | criterion: nn.Module | BarDistribution, 40 | encoder_generator: Callable[[int, int], nn.Module], 41 | style_encoder_generator: Callable[[int, int], nn.Module], 42 | y_encoder_generator: Callable[[int, int], nn.Module], 43 | emsize: int = 200, 44 | nhid: int = 200, 45 | nlayers: int = 6, 46 | nhead: int = 2, 47 | dropout: float = 0.0, 48 | epochs: int = 10, 49 | steps_per_epoch: int = 100, 50 | batch_size: int = 200, 51 | bptt: int = 10, 52 | lr: float | None = None, 53 | weight_decay: float = 0.0, 54 | warmup_epochs: int = 10, 55 | input_normalization: bool = False, 56 | pos_encoder_generator: Callable[[int, int], nn.Module] | None = None, 57 | decoder_dict: dict[str, Any] = {}, 58 | extra_prior_kwargs_dict: dict[str, Any] = {}, 59 | scheduler_generator: Callable = get_cosine_schedule_with_warmup, 60 | load_weights_from_this_state_dict: dict[str, Any] | None = None, 61 | validation_period: int = 10, 62 | single_eval_pos_gen: Callable[[], int] | int | None = None, 63 | bptt_extra_samples: int | None = None, 64 | gpu_device: torch.device = "cuda:0", 65 | aggregate_k_gradients: int = 1, 66 | verbose: bool = True, 67 | epoch_callback: Callable | None = None, 68 | step_callback: Callable | None = None, 69 | continue_model: nn.Module = None, 70 | initializer: Callable | None = None, 71 | initialize_with_model: TransformerModel | None = None, 72 | train_mixed_precision: bool = False, 73 | efficient_eval_masking: bool = True, 74 | border_decoder: nn.Module | None = None, 75 | 
num_global_att_tokens: int = 0, 76 | progress_bar: bool = False, 77 | **model_extra_args: Any, 78 | ) -> tuple[float, float, TransformerModel, prior.PriorDataLoader] | None: 79 | device = gpu_device if torch.cuda.is_available() else "cpu:0" 80 | print(f"Using {device} device") 81 | using_dist, rank, device = init_dist(device) 82 | 83 | def eval_pos_seq_len_sampler() -> tuple[int | None, int]: 84 | if isinstance(single_eval_pos_gen, int): 85 | single_eval_pos = single_eval_pos_gen 86 | elif callable(single_eval_pos_gen): 87 | single_eval_pos = single_eval_pos_gen() 88 | else: 89 | single_eval_pos = None 90 | if bptt_extra_samples and False: # TODO: Currently disabled 91 | return single_eval_pos, single_eval_pos + bptt_extra_samples 92 | else: 93 | return single_eval_pos, bptt 94 | 95 | dl = priordataloader_class( 96 | num_steps=steps_per_epoch, 97 | batch_size=batch_size, 98 | eval_pos_seq_len_sampler=eval_pos_seq_len_sampler, 99 | seq_len_maximum=bptt, # +(bptt_extra_samples if bptt_extra_samples else 0) # TODO: Currently disabled 100 | device=device, 101 | **extra_prior_kwargs_dict, 102 | ) 103 | 104 | test_batch: prior.Batch = dl.get_test_batch() 105 | style_def = test_batch.style 106 | print( 107 | f"Style definition of first 3 examples: {style_def[:3] if style_def is not None else None}" 108 | ) 109 | style_encoder = ( 110 | style_encoder_generator(style_def.shape[1], emsize) if (style_def is not None) else None 111 | ) 112 | pos_encoder = (pos_encoder_generator or positional_encodings.NoPositionalEncoding)( 113 | emsize, bptt * 2 114 | ) 115 | if isinstance(criterion, nn.GaussianNLLLoss): 116 | n_out = 2 117 | elif ( 118 | isinstance(criterion, BarDistribution) or "BarDistribution" in criterion.__class__.__name__ 119 | ): # TODO remove this fix (only for dev) 120 | n_out = criterion.num_bars 121 | elif isinstance(criterion, nn.CrossEntropyLoss): 122 | n_out = criterion.weight.shape[0] 123 | else: 124 | n_out = 1 125 | 126 | # border_decoder = None if 
border_decoder is None else border_decoder(emsize, criterion.num_bars + 1).to(device) 127 | 128 | if continue_model: 129 | model = continue_model 130 | else: 131 | decoder_dict = decoder_dict if decoder_dict else {"standard": (None, n_out)} 132 | 133 | decoder_once_dict = {} 134 | if test_batch.mean_prediction is not None: 135 | decoder_once_dict["mean_prediction"] = decoder_dict["standard"] 136 | 137 | encoder = encoder_generator(dl.num_features, emsize) 138 | model = TransformerModel( 139 | encoder=encoder, 140 | nhead=nhead, 141 | ninp=emsize, 142 | nhid=nhid, 143 | nlayers=nlayers, 144 | dropout=dropout, 145 | style_encoder=style_encoder, 146 | y_encoder=y_encoder_generator(1, emsize), 147 | input_normalization=input_normalization, 148 | pos_encoder=pos_encoder, 149 | decoder_dict=decoder_dict, 150 | init_method=initializer, 151 | efficient_eval_masking=efficient_eval_masking, 152 | decoder_once_dict=decoder_once_dict, 153 | num_global_att_tokens=num_global_att_tokens, 154 | **model_extra_args, 155 | ) 156 | model.criterion = criterion 157 | if load_weights_from_this_state_dict is not None: 158 | model.load_state_dict(load_weights_from_this_state_dict) 159 | if initialize_with_model is not None: 160 | model.init_from_small_model(initialize_with_model) 161 | 162 | print( 163 | f"Using a Transformer with {sum(p.numel() for p in model.parameters())/1000/1000:.{2}f} M parameters" 164 | ) 165 | 166 | try: 167 | for (k, v), (k2, v2) in zip( 168 | model.state_dict().items(), 169 | initialize_with_model.state_dict().items(), # type: ignore 170 | ): 171 | print(k, ((v - v2) / v).abs().mean(), v.shape) 172 | except Exception: 173 | pass 174 | 175 | model.to(device) 176 | if using_dist: 177 | print("Distributed training") 178 | model = torch.nn.parallel.DistributedDataParallel( 179 | model, 180 | device_ids=[rank], 181 | output_device=rank, 182 | broadcast_buffers=False, 183 | find_unused_parameters=test_batch.mean_prediction is not None, 184 | ) 185 | dl.model = 
model.module # use local model, should not use multi-gpu functionality.. 186 | else: 187 | dl.model = model 188 | 189 | # learning rate 190 | if lr is None: 191 | lr = get_openai_lr(model) 192 | print(f"Using OpenAI max lr of {lr}.") 193 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) 194 | scheduler = scheduler_generator( 195 | optimizer, warmup_epochs, epochs if epochs is not None else 100 196 | ) # when training for fixed time lr schedule takes 100 steps 197 | 198 | scaler = GradScaler() if train_mixed_precision else None 199 | 200 | # check that everything uses up-to-date APIs 201 | utils.check_compatibility(dl) 202 | 203 | def train_epoch() -> tuple[float, list[float], float, float, float, float, float]: 204 | model.train() # Turn on the train mode 205 | total_loss = 0.0 206 | total_positional_losses = torch.zeros(bptt) 207 | total_positional_losses_recorded = torch.zeros(bptt) 208 | nan_steps = torch.zeros(1) 209 | ignore_steps = torch.zeros(1) 210 | before_get_batch = time.time() 211 | assert ( 212 | len(dl) % aggregate_k_gradients == 0 213 | ), "Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it." 
214 | tqdm_iter = ( 215 | tqdm(range(len(dl)), desc="Training Epoch") if rank == 0 and progress_bar else None 216 | ) # , disable=not verbose 217 | 218 | for batch, full_data in enumerate(dl): 219 | data, targets, single_eval_pos = ( 220 | (full_data.style, full_data.x, full_data.y), 221 | full_data.target_y, 222 | full_data.single_eval_pos, 223 | ) 224 | 225 | def get_metrics() -> tuple[float, list[float], float, float, float, float, float]: 226 | return ( 227 | total_loss / steps_per_epoch, 228 | (total_positional_losses / total_positional_losses_recorded).tolist(), 229 | time_to_get_batch, 230 | forward_time, 231 | step_time, 232 | nan_steps.cpu().item() / (batch + 1), 233 | ignore_steps.cpu().item() / (batch + 1), 234 | ) 235 | 236 | tqdm_iter.update() if tqdm_iter is not None else None 237 | if using_dist and not (batch % aggregate_k_gradients == aggregate_k_gradients - 1): 238 | cm = model.no_sync() 239 | else: 240 | cm = nullcontext() 241 | with cm: 242 | time_to_get_batch = time.time() - before_get_batch 243 | before_forward = time.time() 244 | # TODO: This disables the bptt_extra_samples functionality but otherwise single eval pos is overwritten 245 | # if bptt_extra_samples is None: 246 | # single_eval_pos = single_eval_pos_gen() if callable(single_eval_pos_gen) else single_eval_pos_gen 247 | # else: 248 | # single_eval_pos = targets.shape[0] - bptt_extra_samples 249 | try: 250 | metrics_to_log: dict[str, Any] = {} 251 | with autocast(enabled=scaler is not None): 252 | # If style is set to None, it should not be transferred to device 253 | out = model( 254 | tuple(e.to(device) if torch.is_tensor(e) else e for e in data), 255 | single_eval_pos=single_eval_pos, 256 | only_return_standard_out=False, 257 | ) 258 | 259 | # this handling is for training old models only, this can be deleted soon(ish) 260 | # to only support models that return a tuple of dicts 261 | out, output_once = out if isinstance(out, tuple) else (out, None) 262 | output = out["standard"] 
if isinstance(out, dict) else out 263 | 264 | forward_time = time.time() - before_forward 265 | 266 | if single_eval_pos is not None: 267 | targets = targets[single_eval_pos:] 268 | 269 | if len(targets.shape) == len(output.shape): 270 | # this implies the prior uses a trailing 1 dimension 271 | # below we assume this not to be the case 272 | targets = targets.squeeze(-1) 273 | assert targets.shape == output.shape[:-1], ( 274 | f"Target shape {targets.shape} " 275 | f"does not match output shape {output.shape}" 276 | ) 277 | if isinstance(criterion, nn.GaussianNLLLoss): 278 | assert ( 279 | output.shape[-1] == 2 280 | ), "need to write a little bit of code to handle multiple regression targets at once" 281 | 282 | mean_pred = output[..., 0] 283 | var_pred = output[..., 1].abs() 284 | losses = criterion( 285 | mean_pred.flatten(), targets.flatten(), var=var_pred.flatten() 286 | ) 287 | elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)): 288 | targets[torch.isnan(targets)] = -100 289 | losses = criterion(output.flatten(), targets.flatten()) 290 | elif isinstance(criterion, nn.CrossEntropyLoss): 291 | targets[torch.isnan(targets)] = -100 292 | print(f"{targets.min()=}, {targets.max()=}") 293 | losses = criterion(output.reshape(-1, n_out), targets.long().flatten()) 294 | elif border_decoder is not None: 295 | 296 | def apply_batch_wise_criterion(i: int) -> torch.Tensor: 297 | output_, targets_, borders_ = ( 298 | output_adaptive[:, i], 299 | targets[:, i], 300 | borders[i], 301 | ) 302 | criterion_ = get_custom_bar_dist(borders_, criterion).to(device) 303 | return criterion_(output_, targets_) 304 | 305 | output_adaptive, borders = out["adaptive_bar"], output_once["borders"] 306 | losses_adaptive_bar = torch.stack( 307 | [ 308 | apply_batch_wise_criterion(i) 309 | for i in range(output_adaptive.shape[1]) 310 | ], 311 | 1, 312 | ) 313 | losses_fixed_bar = criterion(output, targets) 314 | losses = (losses_adaptive_bar + losses_fixed_bar) / 2 315 | 316 |
metrics_to_log = { 317 | **metrics_to_log, 318 | **{ 319 | "loss_fixed_bar": losses_fixed_bar.mean() 320 | .cpu() 321 | .detach() 322 | .item(), 323 | "loss_adaptive_bar": losses_adaptive_bar.mean() 324 | .cpu() 325 | .detach() 326 | .item(), 327 | }, 328 | } 329 | elif isinstance(criterion, BarDistribution) and full_data.mean_prediction: 330 | assert "mean_prediction" in output_once 331 | utils.print_once("Using mean prediction for loss") 332 | losses = criterion( 333 | output, 334 | targets, 335 | mean_prediction_logits=output_once["mean_prediction"], 336 | ) 337 | # the mean pred loss appears as the last per sequence 338 | else: 339 | losses = criterion(output, targets) 340 | losses = losses.view( 341 | -1, output.shape[1] 342 | ) # sometimes the seq length can be one off 343 | # that is because bar dist appends the mean 344 | loss, nan_share = utils.torch_nanmean(losses.mean(0), return_nanshare=True) 345 | loss_scaled = loss / aggregate_k_gradients 346 | 347 | if scaler: 348 | loss_scaled = scaler.scale(loss_scaled) 349 | loss_scaled.backward() 350 | 351 | if batch % aggregate_k_gradients == aggregate_k_gradients - 1: 352 | if scaler: 353 | scaler.unscale_(optimizer) 354 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 355 | if scaler: 356 | scaler.step(optimizer) 357 | scaler.update() 358 | else: 359 | optimizer.step() 360 | optimizer.zero_grad() 361 | 362 | step_time = time.time() - before_forward 363 | 364 | if not torch.isnan(loss): 365 | total_loss += loss.cpu().detach().item() 366 | total_positional_losses += ( 367 | losses.mean(1).cpu().detach() 368 | if single_eval_pos is None 369 | else nn.functional.one_hot(torch.tensor(single_eval_pos), bptt) 370 | * utils.torch_nanmean(losses[: bptt - single_eval_pos].mean(0)) 371 | .cpu() 372 | .detach() 373 | ) 374 | 375 | total_positional_losses_recorded += ( 376 | torch.ones(bptt) 377 | if single_eval_pos is None 378 | else nn.functional.one_hot(torch.tensor(single_eval_pos), bptt) 379 | ) 380 | 381 | 
metrics_to_log = { 382 | **metrics_to_log, 383 | **{"loss": loss, "single_eval_pos": single_eval_pos}, 384 | } 385 | if step_callback is not None and rank == 0: 386 | step_callback(metrics_to_log) 387 | nan_steps += nan_share 388 | ignore_steps += (targets == -100).float().mean() 389 | except Exception as e: 390 | print("Invalid step encountered, skipping...") 391 | print(e) 392 | raise (e) 393 | 394 | # total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time, nan_share, ignore_share = get_metrics() 395 | if tqdm_iter: 396 | tqdm_iter.set_postfix( 397 | { 398 | "data_time": time_to_get_batch, 399 | "step_time": step_time, 400 | "mean_loss": total_loss / (batch + 1), 401 | } 402 | ) 403 | 404 | before_get_batch = time.time() 405 | return get_metrics() 406 | 407 | total_loss = float("inf") * torch.ones(1) 408 | total_positional_losses = float("inf") * torch.ones(bptt) 409 | try: 410 | # Initially test the epoch callback function 411 | if epoch_callback is not None and rank == 0: 412 | epoch_callback(model, 1, data_loader=dl, scheduler=scheduler) 413 | for epoch in range(1, epochs + 1) if epochs is not None else itertools.count(1): 414 | epoch_start_time = time.time() 415 | try: 416 | ( 417 | total_loss, 418 | total_positional_losses, 419 | time_to_get_batch, 420 | forward_time, 421 | step_time, 422 | nan_share, 423 | ignore_share, 424 | ) = train_epoch() 425 | except Exception as e: 426 | print("Invalid epoch encountered, skipping...") 427 | print(e) 428 | raise (e) 429 | if hasattr(dl, "validate") and epoch % validation_period == 0: 430 | with torch.no_grad(): 431 | val_score = dl.validate(model) 432 | 433 | else: 434 | val_score = None 435 | 436 | if verbose: 437 | print("-" * 89) 438 | print( 439 | f'| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | ' 440 | f"pos losses {','.join([f'{loss:5.2f}' for loss in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}" 441 | f' data 
time {time_to_get_batch:5.2f} step time {step_time:5.2f}' 442 | f' forward time {forward_time:5.2f}' 443 | f' nan share {nan_share:5.2f} ignore share (for classification tasks) {ignore_share:5.4f}' 444 | + (f"val score {val_score}" if val_score is not None else "") 445 | ) 446 | print("-" * 89) 447 | 448 | # stepping with wallclock time based scheduler 449 | if epoch_callback is not None and rank == 0: 450 | epoch_callback(model, epoch, data_loader=dl, scheduler=scheduler) 451 | scheduler.step() 452 | except KeyboardInterrupt: 453 | pass 454 | 455 | if rank == 0: # trivially true for non-parallel training 456 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 457 | model = model.module 458 | dl = None 459 | return total_loss, total_positional_losses, model.to("cpu"), dl 460 | 461 | return None 462 | -------------------------------------------------------------------------------- /ifbo/transformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | import math 5 | from typing import Any 6 | 7 | import torch 8 | from torch import Tensor 9 | import torch.nn as nn 10 | from torch.nn import Module 11 | from torch.nn import TransformerEncoder 12 | 13 | from ifbo.layer import TransformerEncoderLayer 14 | from ifbo.utils import bool_mask_to_att_mask 15 | from ifbo.utils import SeqBN 16 | 17 | 18 | class TransformerModel(nn.Module): 19 | def __init__( 20 | self, 21 | encoder: nn.Module, 22 | ninp: int, 23 | nhead: int, 24 | nhid: int, 25 | nlayers: int, 26 | dropout: float = 0.0, 27 | style_encoder: nn.Module | None = None, 28 | y_encoder: nn.Module | None = None, 29 | pos_encoder: nn.Module | None = None, 30 | decoder_dict: dict[str, tuple[nn.Module | None, int]] | None = None, 31 | input_normalization: bool = False, 32 | init_method: Callable | None = None, 33 | pre_norm: bool = False, 34 | activation: str = "gelu", 35 | recompute_attn: 
bool = False, 36 | num_global_att_tokens: int = 0, 37 | full_attention: bool = False, 38 | all_layers_same_init: bool = False, 39 | efficient_eval_masking: bool = True, 40 | decoder_once_dict: dict[str, tuple[nn.Module | None, int]] | None = None, 41 | return_all_outputs: bool = False, 42 | save_trainingset_representations: bool = False, 43 | ) -> None: 44 | super().__init__() 45 | self.model_type = "Transformer" 46 | 47 | def encoder_layer_creator() -> TransformerEncoderLayer: 48 | return TransformerEncoderLayer( 49 | ninp, 50 | nhead, 51 | nhid, 52 | dropout, 53 | activation=activation, 54 | pre_norm=pre_norm, 55 | recompute_attn=recompute_attn, 56 | save_trainingset_representations=save_trainingset_representations, 57 | ) 58 | 59 | self.transformer_encoder = ( 60 | TransformerEncoder(encoder_layer_creator(), nlayers) 61 | if all_layers_same_init 62 | else TransformerEncoderDiffInit(encoder_layer_creator, nlayers) # type: ignore 63 | ) 64 | self.ninp = ninp 65 | self.encoder = encoder 66 | self.y_encoder = y_encoder 67 | self.pos_encoder = pos_encoder 68 | self.return_all_outputs = return_all_outputs 69 | 70 | def make_decoder_dict( 71 | decoder_description_dict: dict[str, tuple[nn.Module | None, int]] | None, 72 | ) -> dict[str, nn.Module] | None: 73 | if decoder_description_dict is None or len(decoder_description_dict) == 0: 74 | return None 75 | initialized_decoder_dict = {} 76 | for decoder_key in decoder_description_dict: 77 | decoder_model, decoder_n_out = decoder_description_dict[decoder_key] 78 | if decoder_model is None: 79 | initialized_decoder_dict[decoder_key] = nn.Sequential( 80 | nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, decoder_n_out) 81 | ) 82 | else: 83 | initialized_decoder_dict[decoder_key] = decoder_model( 84 | ninp, nhid, decoder_n_out 85 | ) 86 | print( 87 | "Initialized decoder for", 88 | decoder_key, 89 | "with", 90 | decoder_description_dict[decoder_key], 91 | " and nout", 92 | decoder_n_out, 93 | ) 94 | return 
torch.nn.ModuleDict(initialized_decoder_dict) 95 | 96 | self.decoder_dict = make_decoder_dict(decoder_dict) 97 | self.decoder_dict_once = make_decoder_dict(decoder_once_dict) 98 | 99 | # N(0,1) is the initialization as the default of nn.Embedding 100 | self.decoder_dict_once_embeddings = ( 101 | torch.nn.Parameter(torch.randn((len(self.decoder_dict_once), 1, ninp))) 102 | if self.decoder_dict_once is not None 103 | else None 104 | ) 105 | # nn.Embedding(len(self.decoder_dict.keys()), nhid) 106 | self.input_ln = SeqBN(ninp) if input_normalization else None 107 | self.style_encoder = style_encoder 108 | self.init_method = init_method 109 | if num_global_att_tokens is not None: 110 | assert not full_attention 111 | self.global_att_embeddings = ( 112 | nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None 113 | ) 114 | self.full_attention = full_attention 115 | self.efficient_eval_masking = efficient_eval_masking 116 | 117 | self.nhid = nhid 118 | 119 | self.init_weights() 120 | 121 | def __setstate__(self, state: dict[str, Any]) -> None: 122 | super().__setstate__(state) 123 | self.__dict__.setdefault("efficient_eval_masking", False) 124 | if not hasattr(self, "decoder_dict_once"): 125 | self.__dict__.setdefault("decoder_dict_once", None) 126 | if hasattr(self, "decoder") and not hasattr(self, "decoder_dict"): 127 | self.add_module("decoder_dict", nn.ModuleDict({"standard": self.decoder})) 128 | self.__dict__.setdefault("return_all_outputs", False) 129 | 130 | def add_approximate_false(module: nn.Module) -> None: 131 | if isinstance(module, nn.GELU): 132 | module.__dict__.setdefault("approximate", "none") 133 | 134 | self.apply(add_approximate_false) 135 | 136 | @staticmethod 137 | def generate_square_subsequent_mask(sz: int) -> Tensor: 138 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 139 | return bool_mask_to_att_mask(mask) 140 | 141 | @staticmethod 142 | def generate_D_q_matrix(sz: int, query_size: int) -> Tensor: 143 | 
train_size = sz - query_size 144 | mask = torch.zeros(sz, sz) == 0 145 | mask[:, train_size:].zero_() 146 | mask |= torch.eye(sz) == 1 147 | return bool_mask_to_att_mask(mask) 148 | 149 | @staticmethod 150 | def generate_global_att_query_matrix( 151 | num_global_att_tokens: int, seq_len: int, num_query_tokens: int 152 | ) -> Tensor: 153 | train_size = seq_len + num_global_att_tokens - num_query_tokens 154 | sz = seq_len + num_global_att_tokens 155 | mask = torch.zeros(num_query_tokens, sz) == 0 156 | mask[:, train_size:].zero_() 157 | mask[:, train_size:] |= torch.eye(num_query_tokens) == 1 158 | return bool_mask_to_att_mask(mask) 159 | 160 | @staticmethod 161 | def generate_global_att_trainset_matrix( 162 | num_global_att_tokens: int, seq_len: int, num_query_tokens: int 163 | ) -> Tensor: 164 | trainset_size = seq_len - num_query_tokens 165 | mask = torch.zeros(trainset_size, num_global_att_tokens) == 0 166 | # mask[:,num_global_att_tokens:].zero_() 167 | # mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1 168 | return bool_mask_to_att_mask(mask) 169 | 170 | @staticmethod 171 | def generate_global_att_globaltokens_matrix( 172 | num_global_att_tokens: int, seq_len: int, num_query_tokens: int 173 | ) -> Tensor: 174 | mask = ( 175 | torch.zeros(num_global_att_tokens, num_global_att_tokens + seq_len - num_query_tokens) 176 | == 0 177 | ) 178 | return bool_mask_to_att_mask(mask) 179 | 180 | def init_weights(self) -> None: 181 | # initrange = 1.0 182 | # if isinstance(self.encoder,EmbeddingEncoder): 183 | # self.encoder.weight.data.uniform_(-initrange, initrange) 184 | # self.decoder.bias.data.zero_() 185 | # self.decoder.weight.data.uniform_(-initrange, initrange) 186 | if self.init_method is not None: 187 | self.apply(self.init_method) 188 | for layer in self.transformer_encoder.layers: 189 | nn.init.zeros_(layer.linear2.weight) 190 | nn.init.zeros_(layer.linear2.bias) 191 | attns = ( 192 | layer.self_attn 193 | if isinstance(layer.self_attn, 
nn.ModuleList) 194 | else [layer.self_attn] 195 | ) 196 | for attn in attns: 197 | nn.init.zeros_(attn.out_proj.weight) 198 | nn.init.zeros_(attn.out_proj.bias) 199 | 200 | def forward(self, *args: Any, **kwargs: Any) -> Tensor: 201 | """ 202 | This will perform a forward-pass (possibly recording gradients) of the model. 203 | We have multiple interfaces we support with this model: 204 | 205 | model(train_x, train_y, test_x, src_mask=None, style=None, only_return_standard_out=True) 206 | model((x,y), src_mask=None, single_eval_pos=None, only_return_standard_out=True) 207 | model((style,x,y), src_mask=None, single_eval_pos=None, only_return_standard_out=True) 208 | """ 209 | if len(args) == 3: 210 | # case model(train_x, train_y, test_x, src_mask=None, style=None, only_return_standard_out=True) 211 | assert all( 212 | kwarg in {"src_mask", "style", "only_return_standard_out"} 213 | for kwarg in kwargs.keys() 214 | ), f"Unrecognized keyword argument in kwargs: {set(kwargs.keys()) - {'src_mask', 'style', 'only_return_standard_out'}}" 215 | x = args[0] 216 | if args[2] is not None: 217 | x = torch.cat((x, args[2]), dim=0) 218 | style = kwargs.pop("style", None) 219 | return self._forward((style, x, args[1]), single_eval_pos=len(args[0]), **kwargs) 220 | elif len(args) == 1 and isinstance(args, tuple): 221 | # case model((x,y), src_mask=None, single_eval_pos=None, only_return_standard_out=True) 222 | # case model((style,x,y), src_mask=None, single_eval_pos=None, only_return_standard_out=True) 223 | assert all( 224 | kwarg in {"src_mask", "single_eval_pos", "only_return_standard_out"} 225 | for kwarg in kwargs.keys() 226 | ), f"Unrecognized keyword argument in kwargs: {set(kwargs.keys()) - {'src_mask', 'single_eval_pos', 'only_return_standard_out'}}" 227 | return self._forward(*args, **kwargs) 228 | 229 | def _forward( 230 | self, 231 | src: tuple, 232 | src_mask: tuple | int | None = None, 233 | single_eval_pos: int | None = None, 234 | only_return_standard_out: bool = 
True, 235 | ) -> dict[str, Tensor] | tuple[dict[str, Tensor], dict[str, Tensor]]: 236 | assert isinstance( 237 | src, tuple 238 | ), "inputs (src) have to be given as (x,y) or (style,x,y) tuple" 239 | 240 | if len(src) == 2: # (x,y) and no style 241 | src = (None,) + src 242 | 243 | style_src, x_src, y_src = src 244 | 245 | if single_eval_pos is None: 246 | single_eval_pos = x_src.shape[0] 247 | assert single_eval_pos is not None 248 | 249 | x_src = self.encoder(x_src) 250 | 251 | if self.decoder_dict_once is not None: 252 | x_src = torch.cat( 253 | [x_src, self.decoder_dict_once_embeddings.repeat(1, x_src.shape[1], 1)], dim=0 254 | ) 255 | 256 | assert self.y_encoder is not None 257 | y_src = ( 258 | self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src) 259 | if y_src is not None 260 | else None 261 | ) 262 | if self.style_encoder: 263 | assert style_src is not None, "style_src must be given if style_encoder is used" 264 | style_src = self.style_encoder(style_src).unsqueeze(0) 265 | else: 266 | style_src = torch.tensor([], device=x_src.device) 267 | global_src = ( 268 | torch.tensor([], device=x_src.device) 269 | if self.global_att_embeddings is None 270 | else self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1) 271 | ) 272 | 273 | if src_mask is not None: 274 | assert self.global_att_embeddings is None or isinstance(src_mask, tuple) 275 | 276 | if src_mask is None: 277 | if self.global_att_embeddings is None: 278 | full_len = len(x_src) + len(style_src) 279 | if self.full_attention: 280 | src_mask = bool_mask_to_att_mask( 281 | torch.ones((full_len, full_len), dtype=torch.bool) 282 | ).to(x_src.device) 283 | elif self.efficient_eval_masking: 284 | src_mask = single_eval_pos + len(style_src) 285 | else: 286 | src_mask = self.generate_D_q_matrix(full_len, len(x_src) - single_eval_pos).to( 287 | x_src.device 288 | ) 289 | else: 290 | src_mask_args = ( 291 | self.global_att_embeddings.num_embeddings, 292 | 
len(x_src) + len(style_src), 293 | len(x_src) + len(style_src) - single_eval_pos, 294 | ) 295 | src_mask = ( 296 | self.generate_global_att_globaltokens_matrix(*src_mask_args).to(x_src.device), 297 | self.generate_global_att_trainset_matrix(*src_mask_args).to(x_src.device), 298 | self.generate_global_att_query_matrix(*src_mask_args).to(x_src.device), 299 | ) 300 | 301 | train_x = x_src[:single_eval_pos] 302 | if y_src is not None: 303 | train_x = train_x + y_src[:single_eval_pos] 304 | src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0) 305 | 306 | if self.input_ln is not None: 307 | src = self.input_ln(src) 308 | 309 | if self.pos_encoder is not None: 310 | src = self.pos_encoder(src) 311 | 312 | output = self.transformer_encoder(src, src_mask) 313 | 314 | num_prefix_positions = len(style_src) + ( 315 | self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0 316 | ) 317 | if self.return_all_outputs: 318 | out_range_start = num_prefix_positions 319 | else: 320 | out_range_start = single_eval_pos + num_prefix_positions 321 | 322 | # In the line below, we use the indexing feature, that we have `x[i:None] == x[i:]` 323 | out_range_end = ( 324 | -len(self.decoder_dict_once_embeddings) if self.decoder_dict_once is not None else None 325 | ) 326 | 327 | # take care the output once are counted from the end 328 | output_once = ( 329 | {k: v(output[-(i + 1)]) for i, (k, v) in enumerate(self.decoder_dict_once.items())} 330 | if self.decoder_dict_once is not None 331 | else {} 332 | ) 333 | 334 | output = ( 335 | {k: v(output[out_range_start:out_range_end]) for k, v in self.decoder_dict.items()} 336 | if self.decoder_dict is not None 337 | else {} 338 | ) 339 | 340 | if only_return_standard_out: 341 | return output["standard"] 342 | 343 | if output_once: 344 | return output, output_once 345 | return output 346 | 347 | @torch.no_grad() 348 | def init_from_small_model(self, small_model: TransformerModel) -> None: 349 | assert ( 
350 | isinstance(self.decoder, nn.Linear) 351 | and isinstance(self.encoder, (nn.Linear, nn.Sequential)) 352 | and isinstance(self.y_encoder, (nn.Linear, nn.Sequential)) 353 | ) 354 | 355 | def set_encoder_weights( 356 | my_encoder: nn.Linear | nn.Sequential, 357 | small_model_encoder: nn.Linear | nn.Sequential, 358 | ) -> None: 359 | my_encoder_linear, small_encoder_linear = ( 360 | (my_encoder, small_model_encoder) 361 | if isinstance(my_encoder, nn.Linear) 362 | else (my_encoder[-1], small_model_encoder[-1]) 363 | ) 364 | small_in_dim = small_encoder_linear.out_features 365 | my_encoder_linear.weight.zero_() 366 | my_encoder_linear.bias.zero_() 367 | my_encoder_linear.weight[:small_in_dim] = small_encoder_linear.weight 368 | my_encoder_linear.bias[:small_in_dim] = small_encoder_linear.bias 369 | 370 | set_encoder_weights(self.encoder, small_model.encoder) 371 | set_encoder_weights(self.y_encoder, small_model.y_encoder) 372 | 373 | small_in_dim = small_model.decoder.in_features 374 | 375 | self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight 376 | self.decoder.bias = small_model.decoder.bias 377 | 378 | for my_layer, small_layer in zip( 379 | self.transformer_encoder.layers, small_model.transformer_encoder.layers 380 | ): 381 | small_hid_dim = small_layer.linear1.out_features 382 | my_in_dim = my_layer.linear1.in_features 383 | 384 | # packed along q,k,v order in first dim 385 | my_in_proj_w = my_layer.self_attn.in_proj_weight 386 | small_in_proj_w = small_layer.self_attn.in_proj_weight 387 | 388 | my_in_proj_w.view(3, my_in_dim, my_in_dim)[:, :small_in_dim, :small_in_dim] = ( 389 | small_in_proj_w.view(3, small_in_dim, small_in_dim) 390 | ) 391 | my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:, :small_in_dim] = ( 392 | small_layer.self_attn.in_proj_bias.view(3, small_in_dim) 393 | ) 394 | 395 | my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = ( 396 | small_layer.self_attn.out_proj.weight 397 | ) 398 | 
my_layer.self_attn.out_proj.bias[:small_in_dim] = small_layer.self_attn.out_proj.bias 399 | 400 | my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = small_layer.linear1.weight 401 | my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias 402 | 403 | my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = small_layer.linear2.weight 404 | my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias 405 | 406 | my_layer.norm1.weight[:small_in_dim] = ( 407 | math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight 408 | ) 409 | my_layer.norm2.weight[:small_in_dim] = ( 410 | math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight 411 | ) 412 | 413 | my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias 414 | my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias 415 | 416 | 417 | class TransformerEncoderDiffInit(Module): 418 | r"""TransformerEncoder is a stack of N encoder layers 419 | 420 | Args: 421 | encoder_layer_creator: a function generating objects of TransformerEncoderLayer class without args (required). 422 | num_layers: the number of sub-encoder-layers in the encoder (required). 423 | norm: the layer normalization component (optional). 424 | """ 425 | 426 | __constants__ = ["norm"] 427 | 428 | def __init__( 429 | self, 430 | encoder_layer_creator: Callable[[], nn.Module], 431 | num_layers: int, 432 | norm: nn.Module | None = None, 433 | ): 434 | super().__init__() 435 | self.layers = nn.ModuleList([encoder_layer_creator() for _ in range(num_layers)]) 436 | self.num_layers = num_layers 437 | self.norm = norm 438 | 439 | def forward( 440 | self, 441 | src: Tensor, 442 | mask: Tensor | None = None, 443 | src_key_padding_mask: Tensor | None = None, 444 | ) -> Tensor: 445 | r"""Pass the input through the encoder layers in turn. 446 | 447 | Args: 448 | src: the sequence to the encoder (required). 449 | mask: the mask for the src sequence (optional). 450 | src_key_padding_mask: the mask for the src keys per batch (optional). 
451 | 452 | Shape: 453 | see the docs in Transformer class. 454 | """ 455 | output = src 456 | 457 | for mod in self.layers: 458 | output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) 459 | 460 | if self.norm is not None: 461 | output = self.norm(output) 462 | 463 | return output 464 | -------------------------------------------------------------------------------- /ifbo/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | from collections.abc import Callable 5 | from collections.abc import Generator 6 | from collections.abc import Sequence 7 | from dataclasses import dataclass 8 | import datetime 9 | import itertools 10 | import math 11 | import os 12 | import random 13 | from typing import Any 14 | 15 | import numpy as np 16 | import torch 17 | from torch import nn 18 | from torch.optim import Optimizer 19 | from torch.optim.lr_scheduler import LambdaLR 20 | 21 | from ifbo.bar_distribution import BarDistribution 22 | from ifbo.priors.prior import Batch 23 | 24 | 25 | @dataclass 26 | class Curve: 27 | """ 28 | A class to represent a performance curve. 29 | 30 | Attributes: 31 | ----------- 32 | hyperparameters : torch.Tensor 33 | A tensor containing the hyperparameters. Should not have more than 10 dimensions and values should be in the range [0, 1]. 34 | t : torch.Tensor 35 | A tensor containing the time steps. Values should be in the range [0, 1]. 36 | y : torch.Tensor | None, optional 37 | A tensor containing the performance values (higher is better). Values should be in the range [0, 1]. Default is None. 38 | """ 39 | 40 | hyperparameters: torch.Tensor 41 | t: torch.Tensor 42 | y: torch.Tensor | None = None 43 | 44 | 45 | @dataclass(unsafe_hash=True) 46 | class PredictionResult: 47 | """ 48 | A dataclass for storing prediction results and computing various metrics. 
49 | 50 | Attributes: 51 | logits (torch.Tensor): The logits output from the model. 52 | criterion (BarDistribution): The criterion used for computing various metrics. 53 | 54 | Methods: 55 | likelihood(y_test: torch.Tensor) -> torch.Tensor: 56 | Computes the log-likelihood of the test targets. 57 | 58 | ucb() -> torch.Tensor: 59 | Computes the upper confidence bound (UCB) of the logits. 60 | 61 | ei(y_best: torch.Tensor) -> torch.Tensor: 62 | Computes the expected improvement (EI) given the best observed value. 63 | 64 | pi(y_best: torch.Tensor) -> torch.Tensor: 65 | Computes the probability of improvement (PI) given the best observed value. 66 | 67 | quantile(q: float) -> torch.Tensor: 68 | Computes the quantile of the logits at the given quantile level. 69 | """ 70 | 71 | logits: torch.Tensor 72 | criterion: BarDistribution 73 | 74 | def likelihood(self, y_test: torch.Tensor) -> torch.Tensor: 75 | """ 76 | Computes the log-likelihood of the test targets. 77 | 78 | Args: 79 | y_test (torch.Tensor): The test targets. 80 | 81 | Returns: 82 | torch.Tensor: The log-likelihood. 83 | """ 84 | return -self.criterion(self.logits, y_test).squeeze(1) 85 | 86 | def ucb(self) -> torch.Tensor: 87 | """ 88 | Computes the upper confidence bound (UCB) of the logits. 89 | 90 | Returns: 91 | torch.Tensor: The upper confidence bound. 92 | """ 93 | return self.criterion.ucb(self.logits, best_f=None).squeeze(1) 94 | 95 | def ei(self, y_best: torch.Tensor) -> torch.Tensor: 96 | """ 97 | Computes the expected improvement (EI) given the best observed value. 98 | 99 | Args: 100 | y_best (torch.Tensor): The best observed value. 101 | 102 | Returns: 103 | torch.Tensor: The expected improvement. 104 | """ 105 | return self.criterion.ei(self.logits, best_f=y_best).squeeze(1) 106 | 107 | def pi(self, y_best: torch.Tensor) -> torch.Tensor: 108 | """ 109 | Computes the probability of improvement (PI) given the best observed value.
110 | 111 | Args: 112 | y_best (torch.Tensor): The best observed value. 113 | 114 | Returns: 115 | torch.Tensor: The probability of improvement. 116 | """ 117 | return self.criterion.pi(self.logits, best_f=y_best).squeeze(1) 118 | 119 | def quantile(self, q: float) -> torch.Tensor: 120 | """ 121 | Computes the quantile of the logits at the given quantile level. 122 | 123 | Args: 124 | q (float): The quantile level. 125 | 126 | Returns: 127 | torch.Tensor: The quantile at the given level. 128 | """ 129 | return self.criterion.icdf(self.logits, q).squeeze(1) 130 | 131 | 132 | # copied from huggingface 133 | def get_cosine_schedule_with_warmup( 134 | optimizer: Optimizer, 135 | num_warmup_steps: int, 136 | num_training_steps: int, 137 | num_cycles: float = 0.5, 138 | last_epoch: int = -1, 139 | ) -> LambdaLR: 140 | """Create a schedule with a learning rate that decreases following the 141 | values of the cosine function between 0 and `pi * cycles` after a warmup 142 | period during which it increases linearly between 0 and 1. 
143 | """ 144 | 145 | def lr_lambda(current_step: int) -> float: 146 | if current_step < num_warmup_steps: 147 | return float(current_step) / float(max(1, num_warmup_steps)) 148 | progress = float(current_step - num_warmup_steps) / float( 149 | max(1, num_training_steps - num_warmup_steps) 150 | ) 151 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 152 | 153 | return LambdaLR(optimizer, lr_lambda, last_epoch) 154 | 155 | 156 | # copied from huggingface 157 | def get_restarting_cosine_schedule_with_warmup( 158 | optimizer: Optimizer, 159 | num_warmup_steps: int, 160 | num_training_steps: int, 161 | steps_per_restart: int, 162 | num_cycles: float = 0.5, 163 | last_epoch: int = -1, 164 | ) -> LambdaLR: 165 | assert num_training_steps % steps_per_restart == 0 166 | 167 | def inner_lr_lambda( 168 | current_step: int, num_warmup_steps: int, num_training_steps: int 169 | ) -> float: 170 | if current_step < num_warmup_steps: 171 | return float(current_step) / float(max(1, num_warmup_steps)) 172 | progress = float(current_step - num_warmup_steps) / float( 173 | max(1, num_training_steps - num_warmup_steps) 174 | ) 175 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 176 | 177 | def lr_lambda(current_step: int) -> float: 178 | inner_step = current_step % steps_per_restart 179 | return inner_lr_lambda( 180 | inner_step, 181 | num_warmup_steps if current_step < steps_per_restart else 0, 182 | steps_per_restart, 183 | ) 184 | 185 | return LambdaLR(optimizer, lr_lambda, last_epoch) 186 | 187 | 188 | # copied from huggingface 189 | def get_linear_schedule_with_warmup( 190 | optimizer: Optimizer, 191 | num_warmup_steps: int, 192 | num_training_steps: int, 193 | last_epoch: int = -1, 194 | ) -> LambdaLR: 195 | """ 196 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 197 | a warmup period during which it increases linearly from 0 to the 
initial lr set in the optimizer. 198 | 199 | Args: 200 | optimizer (:class:`~torch.optim.Optimizer`): 201 | The optimizer for which to schedule the learning rate. 202 | num_warmup_steps (:obj:`int`): 203 | The number of steps for the warmup phase. 204 | num_training_steps (:obj:`int`): 205 | The total number of training steps. 206 | last_epoch (:obj:`int`, `optional`, defaults to -1): 207 | The index of the last epoch when resuming training. 208 | 209 | Return: 210 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 211 | """ 212 | 213 | def lr_lambda(current_step: int) -> float: 214 | if current_step < num_warmup_steps: 215 | return float(current_step) / float(max(1, num_warmup_steps)) 216 | return max( 217 | 0.0, 218 | float(num_training_steps - current_step) 219 | / float(max(1, num_training_steps - num_warmup_steps)), 220 | ) 221 | 222 | return LambdaLR(optimizer, lr_lambda, last_epoch) 223 | 224 | 225 | def get_openai_lr(transformer_model: nn.Module) -> float: 226 | num_params = sum(p.numel() for p in transformer_model.parameters()) 227 | return 0.003239 - 0.0001395 * math.log(num_params) 228 | 229 | 230 | def get_weighted_single_eval_pos_sampler( 231 | max_len: int, min_len: int = 0, p: float = 1.0 232 | ) -> Callable[[], int]: 233 | """ 234 | This gives a sampler that can be used for `single_eval_pos` which yields good performance for all positions p, 235 | where p <= `max_len`. At most `max_len` - 1 examples are shown to the Transformer. 236 | :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`. 
237 | """ 238 | return lambda: random.choices( 239 | range(min_len, max_len), 240 | [1 / math.pow(((max_len - min_len) - i), p) for i in range(max_len - min_len)], 241 | )[0] 242 | 243 | 244 | def get_uniform_single_eval_pos_sampler(max_len: int, min_len: int = 0) -> Callable[[], int]: 245 | """ 246 | Just sample any evaluation position with the same weight 247 | :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`. 248 | """ 249 | return lambda: random.choices(range(min_len, max_len))[0] 250 | 251 | 252 | class SeqBN(nn.Module): 253 | def __init__(self, d_model: int) -> None: 254 | super().__init__() 255 | self.bn = nn.BatchNorm1d(d_model) 256 | self.d_model = d_model 257 | 258 | def forward(self, x: torch.Tensor) -> torch.Tensor: 259 | assert self.d_model == x.shape[-1] 260 | flat_x = x.view(-1, self.d_model) 261 | flat_x = self.bn(flat_x) 262 | return flat_x.view(*x.shape) 263 | 264 | 265 | def set_locals_in_self(locals: dict[str, Any]) -> None: 266 | """ 267 | Call this function like `set_locals_in_self(locals())` to set all local variables as object variables. 268 | Especially useful right at the beginning of `__init__`. 
269 | :param locals: `locals()` 270 | """ 271 | self = locals["self"] 272 | for var_name, val in locals.items(): 273 | if var_name != "self": 274 | setattr(self, var_name, val) 275 | 276 | 277 | default_device = "cuda:0" if torch.cuda.is_available() else "cpu:0" 278 | 279 | 280 | # Copied from StackOverflow, but we do an eval on the values additionally 281 | class StoredictKeyPair(argparse.Action): 282 | def __init__( 283 | self, option_strings: list[str], dest: str, nargs: int | None = None, **kwargs: Any 284 | ) -> None: 285 | self._nargs = nargs 286 | super(StoredictKeyPair, self).__init__(option_strings, dest, nargs=nargs, **kwargs) 287 | 288 | def __call__( 289 | self, 290 | parser: argparse.ArgumentParser, 291 | namespace: argparse.Namespace, 292 | values: str | Sequence[Any] | None, 293 | option_string: str | None = None, 294 | ) -> None: 295 | my_dict = {} 296 | if values is not None: 297 | for kv in values: 298 | k, v = kv.split("=") 299 | try: 300 | my_dict[k] = eval(v) 301 | except NameError: 302 | my_dict[k] = v 303 | setattr(namespace, self.dest, my_dict) 304 | 305 | 306 | def get_nan_value(v: float, set_value_to_nan: float = 1.0) -> float: 307 | if random.random() < set_value_to_nan: 308 | return v 309 | else: 310 | return random.choice([-999, 0, 1, 999]) 311 | 312 | 313 | def to_ranking(data: torch.Tensor) -> torch.Tensor: 314 | x = data >= data.unsqueeze(-3) 315 | x = x.sum(0) 316 | return x 317 | 318 | 319 | # TODO: Is there a better way to do this? 320 | # 1. Comparing to unique elements: When all values are different we still get quadratic blowup 321 | # 2. Argsort(Argsort()) returns ranking, but with duplicate values there is an ordering which is problematic 322 | # 3. Argsort(Argsort(Unique))->Scatter seems a bit complicated, doesn't have quadratic blowup, but how fast?
323 | def to_ranking_low_mem(data: torch.Tensor) -> torch.Tensor: 324 | x = torch.zeros_like(data) 325 | for col in range(data.shape[-1]): 326 | x_ = data[:, :, col] >= data[:, :, col].unsqueeze(-2) 327 | x_ = x_.sum(0) 328 | x[:, :, col] = x_ 329 | return x 330 | 331 | 332 | def nan_handling_missing_for_unknown_reason_value(nan_prob: float = 1.0) -> float: 333 | return get_nan_value(float("nan"), nan_prob) 334 | 335 | 336 | def nan_handling_missing_for_no_reason_value(nan_prob: float = 1.0) -> float: 337 | return get_nan_value(float("-inf"), nan_prob) 338 | 339 | 340 | def nan_handling_missing_for_a_reason_value(nan_prob: float = 1.0) -> float: 341 | return get_nan_value(float("inf"), nan_prob) 342 | 343 | 344 | def torch_nanmean(x: torch.Tensor, axis: int = 0, return_nanshare: bool = False) -> torch.Tensor: 345 | num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(axis=axis) 346 | value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis) 347 | if return_nanshare: 348 | return value / num, 1.0 - num / x.shape[axis] 349 | return value / num 350 | 351 | 352 | def torch_nanstd(x: torch.Tensor, axis: int = 0) -> torch.Tensor: 353 | num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(axis=axis) 354 | value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis) 355 | mean = value / num 356 | mean_broadcast = torch.repeat_interleave(mean.unsqueeze(axis), x.shape[axis], dim=axis) 357 | return torch.sqrt(torch.nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1)) 358 | 359 | 360 | def normalize_data( 361 | data: torch.Tensor, normalize_positions: int = -1, return_scaling: bool = False 362 | ) -> torch.Tensor | tuple[torch.Tensor, tuple[float, float]]: 363 | if normalize_positions > 0: 364 | mean = torch_nanmean(data[:normalize_positions], axis=0) 365 | std = torch_nanstd(data[:normalize_positions], axis=0) + 0.000001 366 | else: 367 | mean = torch_nanmean(data, 
axis=0) 368 | std = torch_nanstd(data, axis=0) + 0.000001 369 | data = (data - mean) / std 370 | data = torch.clip(data, min=-100, max=100) 371 | 372 | if return_scaling: 373 | return data, (mean, std) 374 | return data 375 | 376 | 377 | def remove_outliers( 378 | X: torch.Tensor, n_sigma: int = 4, normalize_positions: int = -1 379 | ) -> torch.Tensor: 380 | # Expects T, B, H 381 | assert len(X.shape) == 3, "X must be T,B,H" 382 | # for b in range(X.shape[1]): 383 | # for col in range(X.shape[2]): 384 | data = X if normalize_positions == -1 else X[:normalize_positions] 385 | data_clean = data[:].clone() 386 | data_mean, data_std = torch_nanmean(data, axis=0), torch_nanstd(data, axis=0) 387 | cut_off = data_std * n_sigma 388 | lower, upper = data_mean - cut_off, data_mean + cut_off 389 | 390 | data_clean[torch.logical_or(data_clean > upper, data_clean < lower)] = np.nan 391 | data_mean, data_std = ( 392 | torch_nanmean(data_clean, axis=0), 393 | torch_nanstd(data_clean, axis=0), 394 | ) 395 | cut_off = data_std * n_sigma 396 | lower, upper = data_mean - cut_off, data_mean + cut_off 397 | 398 | X = torch.maximum(-torch.log(1 + torch.abs(X)) + lower, X) 399 | X = torch.minimum(torch.log(1 + torch.abs(X)) + upper, X) 400 | return X 401 | 402 | 403 | def bool_mask_to_att_mask(mask: torch.Tensor) -> torch.Tensor: 404 | return mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) 405 | 406 | 407 | def print_on_master_only(is_master: bool) -> None: 408 | import builtins as __builtin__ 409 | 410 | builtin_print = __builtin__.print 411 | 412 | def print(*args: Any, **kwargs: Any) -> None: 413 | force = kwargs.pop("force", False) 414 | if is_master or force: 415 | builtin_print(*args, **kwargs) 416 | 417 | __builtin__.print = print 418 | 419 | 420 | def init_dist(device: torch.device) -> tuple[bool, int, torch.device]: 421 | print("init dist") 422 | if "LOCAL_RANK" in os.environ: 423 | # launched with torch.distributed.launch 424 | rank = 
int(os.environ["LOCAL_RANK"]) 425 | print("torch.distributed.launch and my rank is", rank) 426 | torch.cuda.set_device(rank) 427 | os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) 428 | torch.distributed.init_process_group( 429 | backend="nccl", 430 | init_method="env://", 431 | timeout=datetime.timedelta(seconds=20), 432 | world_size=torch.cuda.device_count(), 433 | rank=rank, 434 | ) 435 | torch.distributed.barrier() 436 | print_on_master_only(rank == 0) 437 | print( 438 | f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, " 439 | "only I can print, but when using print(..., force=True) it will print on all ranks." 440 | ) 441 | return True, rank, f"cuda:{rank}" 442 | elif "SLURM_PROCID" in os.environ and torch.cuda.device_count() > 1: 443 | # this is for multi gpu when starting with submitit 444 | assert device != "cpu:0" 445 | rank = int(os.environ["SLURM_PROCID"]) 446 | os.environ["MASTER_ADDR"] = "localhost" 447 | os.environ["MASTER_PORT"] = "12355" 448 | torch.cuda.set_device(rank) 449 | os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) 450 | print("distributed submitit launch and my rank is", rank) 451 | torch.distributed.init_process_group( 452 | backend="nccl", 453 | init_method="env://", 454 | timeout=datetime.timedelta(seconds=20), 455 | world_size=torch.cuda.device_count(), 456 | rank=rank, 457 | ) 458 | torch.distributed.barrier() 459 | print_on_master_only(rank == 0) 460 | print( 461 | f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, " 462 | "only I can print, but when using print(..., force=True) it will print on all ranks." 
463 | ) 464 | 465 | return True, rank, f"cuda:{rank}" 466 | else: 467 | print("Not using distributed") 468 | # will not change any of the behavior of print, but allows putting the force=True in the print calls 469 | print_on_master_only(True) 470 | return False, 0, device 471 | 472 | 473 | # NOP context manager for Python with statements (x = NOP(); with x:) 474 | class NOP: 475 | def __enter__(self) -> None: 476 | pass 477 | 478 | def __exit__(self, type: Any, value: Any, traceback: Any) -> None: 479 | pass 480 | 481 | 482 | def check_compatibility(dl: torch.utils.data.DataLoader) -> None: 483 | if hasattr(dl, "num_outputs"): 484 | print("`num_outputs` for the DataLoader is deprecated. It is assumed to be 1 from now on.") 485 | assert dl.num_outputs == 1, ( 486 | "We assume num_outputs to be 1. Instead of changing num_outputs, change your loss. " 487 | "We specify the number of classes in the CE loss." 488 | ) 489 | 490 | 491 | def product_dict(dic: dict[str, Any]) -> Generator[dict[str, Any], None, None]: 492 | keys = dic.keys() 493 | vals = dic.values() 494 | for instance in itertools.product(*vals): 495 | yield dict(zip(keys, instance)) 496 | 497 | 498 | def to_tensor(x: torch.Tensor, device: torch.device | None = None) -> torch.Tensor: 499 | if isinstance(x, torch.Tensor): 500 | return x.to(device) 501 | else: 502 | return torch.tensor(x, device=device) 503 | 504 | 505 | printed_already = set() 506 | 507 | 508 | def print_once(*msgs: str) -> None: 509 | msg = " ".join([repr(m) for m in msgs]) 510 | if msg not in printed_already: 511 | print(msg) 512 | printed_already.add(msg) 513 | 514 | 515 | def tokenize( 516 | context: list[Curve], query: list[Curve], device: torch.device | None = None 517 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 518 | # takes as input a list of curves and query points (does not have y values) 519 | # returns the tokenized representation of 520 | # - context curves: ([id curve, t value, hyperparameters]) and the corresponding y values.
521 | # - query points: ([id curve, t value, hyperparameters]) 522 | # The id curve is a unique identifier for each curve in the context. 523 | 524 | config_to_id: dict[torch.Tensor, int] = {} 525 | context_tokens = [] 526 | context_y_values = [] 527 | query_tokens = [] 528 | current_id = 1 529 | 530 | def get_curve_id(hyperparameters: torch.Tensor) -> int: 531 | nonlocal current_id 532 | for config, cid in config_to_id.items(): 533 | if torch.equal(config, hyperparameters): 534 | return cid 535 | config_to_id[hyperparameters] = current_id 536 | current_id += 1 537 | return config_to_id[hyperparameters] 538 | 539 | for curve in context: 540 | curve_id = get_curve_id(curve.hyperparameters) 541 | num_points = curve.t.size(0) 542 | for i in range(num_points): 543 | context_tokens.append( 544 | torch.cat( 545 | (torch.tensor([curve_id, curve.t[i].item()]), curve.hyperparameters.cpu()) 546 | ) 547 | ) 548 | assert curve.y is not None 549 | context_y_values.append(curve.y[i]) 550 | 551 | for curve in query: 552 | curve_id = get_curve_id(curve.hyperparameters) 553 | num_points = curve.t.size(0) 554 | for i in range(num_points): 555 | query_tokens.append( 556 | torch.cat( 557 | (torch.tensor([curve_id, curve.t[i].item()]), curve.hyperparameters.cpu()) 558 | ) 559 | ) 560 | 561 | # Convert lists to tensors 562 | context_tokens_tensor = torch.stack(context_tokens, dim=0).to(device) 563 | context_y_values_tensor = torch.stack(context_y_values, dim=0).to(device) 564 | query_tokens_tensor = torch.stack(query_tokens, dim=0).to(device) 565 | 566 | return context_tokens_tensor, context_y_values_tensor, query_tokens_tensor 567 | 568 | 569 | def detokenize( 570 | batch: Batch, context_size: int, device: torch.device | None = None 571 | ) -> tuple[list[Curve], list[Curve]]: 572 | ( 573 | context_tokens_tensor, 574 | context_y_values_tensor, 575 | query_tokens_tensor, 576 | query_y_values_tensor, 577 | ) = ( 578 | batch.x.squeeze(1)[:context_size, ...], 579 | 
batch.y.squeeze(1)[:context_size, ...], 580 | batch.x.squeeze(1)[context_size:, ...], 581 | batch.y.squeeze(1)[context_size:, ...], 582 | ) 583 | id_to_config: dict[int, torch.Tensor] = {} 584 | context_curves: dict[int, list[tuple[float, float]]] = {} 585 | query_curves: dict[int, list[tuple[float, float]]] = {} 586 | used_ids: set[int] = set() 587 | all_possible_ids: set[int] = set(range(1, 1001)) 588 | 589 | def get_curve_config(curve_id: int) -> torch.Tensor: 590 | if curve_id in id_to_config: 591 | return id_to_config[curve_id] 592 | else: 593 | raise KeyError(f"Curve ID {curve_id} not found in id_to_config") 594 | 595 | # Process context tokens and y values 596 | for i in range(context_tokens_tensor.size(0)): 597 | token = context_tokens_tensor[i] 598 | y_value = context_y_values_tensor[i] 599 | 600 | curve_id = int(token[0].item()) 601 | x_value = token[1].item() 602 | configuration = token[2:] 603 | 604 | if curve_id not in id_to_config: 605 | id_to_config[curve_id] = configuration 606 | used_ids.add(curve_id) 607 | 608 | if curve_id not in context_curves: 609 | context_curves[curve_id] = [] 610 | 611 | context_curves[curve_id].append((x_value, y_value.item())) 612 | 613 | unused_ids = all_possible_ids - used_ids 614 | 615 | # Process query tokens 616 | for i in range(query_tokens_tensor.size(0)): 617 | token = query_tokens_tensor[i] 618 | y_value = ( 619 | query_y_values_tensor[i] 620 | if query_y_values_tensor is not None 621 | else torch.tensor([0.0]) * float("nan") 622 | ) 623 | 624 | curve_id = int(token[0].item()) 625 | x_value = token[1].item() 626 | configuration = token[2:] 627 | 628 | # Assign a new unique ID for configurations with curve_id 0 not in context tokens 629 | if curve_id not in used_ids: 630 | found = False 631 | for existing_id, config in id_to_config.items(): 632 | if torch.equal(config, configuration): 633 | curve_id = existing_id 634 | found = True 635 | break 636 | if not found: 637 | if not unused_ids: 638 | raise ValueError("No 
unused IDs available") 639 | curve_id = unused_ids.pop() 640 | id_to_config[curve_id] = configuration 641 | 642 | if curve_id not in query_curves: 643 | query_curves[curve_id] = [] 644 | 645 | query_curves[curve_id].append((x_value, y_value.item())) 646 | 647 | # Convert the context curves dictionary to list of Curve objects 648 | context_list = [] 649 | for curve_id, points in context_curves.items(): 650 | x_values = torch.tensor([p[0] for p in points]).to(device) 651 | y_values = torch.tensor([p[1] for p in points]).to(device) 652 | configuration = get_curve_config(curve_id) 653 | context_list.append(Curve(t=x_values, y=y_values, hyperparameters=configuration)) 654 | 655 | # Convert the query curves dictionary to list of Curve objects 656 | query_list = [] 657 | for curve_id, points in query_curves.items(): 658 | x_values = torch.tensor([p[0] for p in points]).to(device) 659 | if query_y_values_tensor is not None: 660 | y_values = torch.tensor([p[1] for p in points]).to(device) 661 | configuration = get_curve_config(curve_id) 662 | query_list.append( 663 | Curve( 664 | t=x_values, 665 | hyperparameters=configuration, 666 | y=y_values if query_y_values_tensor is not None else None, 667 | ) 668 | ) 669 | 670 | return context_list, query_list 671 | -------------------------------------------------------------------------------- /ifbo/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3.12" 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ifBO" 3 | description = "In-context Freeze-Thaw Bayesian Optimization for Hyperparameter Optimization" 4 | readme = {file = "README.md", content-type = 'text/markdown'} 5 | license = {file = "LICENSE"} 6 | authors = [ 7 | {name = "Herilalaina Rakotoarison", email = "rakotoah@cs.uni-freiburg.de"}, 8 | {name = 
"Steven Adriaensen", email= "adriaens@cs.uni-freiburg.de"}, 9 | {name = "Neeratyoy Mallik", email = "mallik@cs.uni-freiburg.de"}, 10 | {name = "Samir Garibov"}, 11 | {name = "Edward Bergman"}, 12 | {name = "Frank Hutter"}, 13 | ] 14 | requires-python = ">=3.10,<3.14" 15 | dependencies = [ 16 | "cloudpickle>=3.0.0", 17 | "torch>=1.9.0", 18 | "numpy>=1.21.2", 19 | "scipy>=1.13.1", 20 | "requests>=2.23.0", 21 | "submitit>=1.5.1", 22 | ] 23 | dynamic = ["version"] 24 | classifiers = [ 25 | 'Intended Audience :: Science/Research', 26 | 'License :: OSI Approved :: MIT License', 27 | 'Programming Language :: Python', 28 | 'Topic :: Software Development', 29 | 'Topic :: Scientific/Engineering', 30 | 'Operating System :: Unix', 31 | 'Operating System :: MacOS', 32 | 'Programming Language :: Python :: 3', 33 | 'Programming Language :: Python :: 3.10', 34 | 'Programming Language :: Python :: 3.11', 35 | 'Programming Language :: Python :: 3.12', 36 | 'Programming Language :: Python :: 3.13', 37 | ] 38 | 39 | [project.optional-dependencies] 40 | checking = [ 41 | "pre-commit", 42 | "mypy", 43 | "ruff", 44 | ] 45 | 46 | [project.urls] 47 | homepage = "https://github.com/automl/ifBO" 48 | repository = "https://github.com/automl/ifBO" 49 | bugtracker = "https://github.com/automl/ifBO/issues" 50 | 51 | [tool.setuptools.packages.find] 52 | include = ["ifbo*"] 53 | 54 | [tool.setuptools.package-data] 55 | ifbo = ["priors/output_sorted.npy"] 56 | 57 | [tool.setuptools.dynamic] 58 | version = {attr = "ifbo.version.__version__"} 59 | 60 | [tool.ruff] 61 | line-length = 99 62 | 63 | [tool.ruff.lint] 64 | extend-select = [ 65 | "I", 66 | ] 67 | 68 | [tool.ruff.lint.isort] 69 | known-third-party = [] 70 | lines-after-imports = 2 71 | force-single-line = true 72 | force-sort-within-sections = true 73 | order-by-type = false 74 | 75 | [tool.mypy] 76 | # Options configure mypy's strict mode. 
77 | warn_unused_configs = true 78 | disallow_untyped_calls = true 79 | disallow_untyped_defs = true 80 | disallow_incomplete_defs = true 81 | check_untyped_defs = true 82 | no_implicit_optional = true 83 | warn_redundant_casts = true 84 | strict_equality = true 85 | extra_checks = true 86 | no_implicit_reexport = true 87 | ignore_missing_imports = true 88 | explicit_package_bases = true 89 | exclude = [ 90 | ".venv", 91 | "venv", 92 | "build", 93 | "work", 94 | ".*/.ipynb_checkpoints/.*", 95 | ] 96 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | numpy>=1.21.2,<2 3 | scipy>=1.13.1 4 | requests>=2.23.0 -------------------------------------------------------------------------------- /tests/test_surrogate.py: -------------------------------------------------------------------------------- 1 | import ifbo 2 | import unittest 3 | 4 | 5 | class TestSurrogateModel(unittest.TestCase): 6 | 7 | def setUp(self): 8 | # Initialize the surrogate model and any required data 9 | self.model = ifbo.surrogate.FTPFN(version="0.0.1") 10 | single_eval_pos = 700 11 | batch = ifbo.priors.ftpfn_prior.get_batch( 12 | batch_size=1, 13 | seq_len=1000, 14 | num_features=12, 15 | single_eval_pos=single_eval_pos 16 | ) 17 | self.context, self.query = ifbo.utils.detokenize(batch, context_size=single_eval_pos, device="cpu") 18 | 19 | def test_prediction_shape(self): 20 | # Test if the prediction output has the correct shape 21 | predictions = self.model.predict(self.context, self.query) 22 | self.assertEqual(len(predictions), len(self.query)) 23 | 24 | def assertBetween(self, value, min, max): 25 | """Fail if value is not between min and max (inclusive).""" 26 | self.assertTrue((value >= min).all()) 27 | self.assertTrue((value <= max).all()) 28 | 29 | def test_prediction_values(self): 30 | predictions = self.model.predict(self.context, 
self.query) 31 | for prediction in predictions: 32 | for q in [0.01, 0.5, 0.99]: 33 | self.assertBetween(prediction.quantile(q), 0, 1) 34 | self.assertBetween(prediction.ucb(), 0, 1) 35 | self.assertBetween(prediction.ei(0.5), 0, 1) 36 | self.assertBetween(prediction.pi(0.5), 0, 1) 37 | 38 | def test_exception_hyperparameters(self): 39 | """Test if the model raises an exception for invalid input.""" 40 | invalid_context = self.context.copy() 41 | invalid_context[0].hyperparameters[-1] = -1 42 | self.assertRaises(Exception, self.model.predict, invalid_context, self.query) 43 | 44 | invalid_context = self.context.copy() 45 | invalid_context[0].hyperparameters[-1] = 2 46 | self.assertRaises(Exception, self.model.predict, invalid_context, self.query) 47 | 48 | invalid_query = self.query.copy() 49 | invalid_query[0].hyperparameters[-1] = -1 50 | self.assertRaises(Exception, self.model.predict, self.context, invalid_query) 51 | 52 | invalid_query = self.query.copy() 53 | invalid_query[0].hyperparameters[-1] = 2 54 | self.assertRaises(Exception, self.model.predict, self.context, invalid_query) 55 | 56 | def test_exception_step(self): 57 | """Test if the model raises an exception for invalid input.""" 58 | invalid_context = self.context.copy() 59 | invalid_context[0].t = -1 60 | self.assertRaises(Exception, self.model.predict, invalid_context, self.query) 61 | 62 | invalid_context = self.context.copy() 63 | invalid_context[0].t = 2 64 | self.assertRaises(Exception, self.model.predict, invalid_context, self.query) 65 | 66 | invalid_query = self.query.copy() 67 | invalid_query[0].t = -1 68 | self.assertRaises(Exception, self.model.predict, self.context, invalid_query) 69 | 70 | invalid_query = self.query.copy() 71 | invalid_query[0].t = 2 72 | self.assertRaises(Exception, self.model.predict, self.context, invalid_query) 73 | 74 | def test_exception_performance(self): 75 | """Test if the model raises an exception for invalid input.""" 76 | invalid_context = 
self.context.copy() 77 | invalid_context[0].y = -1 78 | self.assertRaises(Exception, self.model.predict, invalid_context, self.query) 79 | 80 | invalid_context = self.context.copy() 81 | invalid_context[0].y = 2 82 | self.assertRaises(Exception, self.model.predict, invalid_context, self.query) 83 | 84 | invalid_query = self.query.copy() 85 | invalid_query[0].y = -1 86 | self.assertRaises(Exception, self.model.predict, self.context, invalid_query) 87 | 88 | invalid_query = self.query.copy() 89 | invalid_query[0].y = 2 90 | self.assertRaises(Exception, self.model.predict, self.context, invalid_query) 91 | 92 | 93 | 94 | if __name__ == '__main__': 95 | unittest.main() --------------------------------------------------------------------------------