├── .gitignore ├── LICENSE ├── lcpfn ├── __init__.py ├── bar_distribution.py ├── curves.py ├── decoders.py ├── domhan_prior.py ├── encoders.py ├── initializers.py ├── layer.py ├── model.py ├── positional_encodings.py ├── priors │ ├── __init__.py │ ├── prior.py │ └── utils.py ├── train.py ├── train_lcpfn.py ├── transformer.py ├── utils.py └── version.py ├── notebooks ├── curve_normalization.ipynb ├── inference.ipynb └── training.ipynb ├── pyproject.toml ├── readme.md └── tests └── test_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | lcpfn/trained_models/ 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 AutoML-Freiburg-Hannover 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /lcpfn/__init__.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.insert(0, os.path.dirname(__file__)) 4 | 5 | 6 | model_path = "trained_models" 7 | 8 | 9 | def prepare_models(): 10 | pfns4bo_dir = os.path.dirname(__file__) 11 | model_names = [ 12 | "pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt", 13 | "pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt", 14 | ] 15 | 16 | for name in model_names: 17 | weights_path = os.path.join(pfns4bo_dir, model_path, name) 18 | compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + ".gz") 19 | if not os.path.exists(weights_path): 20 | if not os.path.exists(compressed_weights_path): 21 | print("Downloading", os.path.abspath(compressed_weights_path)) 22 | import requests 23 | 24 | url = f'https://ml.informatik.uni-freiburg.de/research-artifacts/lcpfn/{name + ".gz"}' 25 | r = requests.get(url, allow_redirects=True) 26 | os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True) 27 | with open(compressed_weights_path, "wb") as f: 28 | f.write(r.content) 29 | if os.path.exists(compressed_weights_path): 30 | print("Unzipping", name) 31 | os.system(f"gzip -dk {compressed_weights_path}") 32 | else: 33 | print("Failed to find", compressed_weights_path) 34 | print( 35 | "Make sure you have an internet connection to download the model automatically.." 36 | ) 37 | if os.path.exists(weights_path): 38 | print("Successfully located model at", weights_path) 39 | 40 | 41 | model_dict = { 42 | "EMSIZE512_NLAYERS12_NBUCKETS1000": os.path.join( 43 | os.path.dirname(__file__), 44 | model_path, 45 | "pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt", 46 | ), 47 | "EMSIZE512_NLAYERS6_NBUCKETS1000": os.path.join( 48 | os.path.dirname(__file__), 49 | model_path, 50 | "pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt", 51 | ), 52 | } 53 | 54 | 55 | def __getattr__(name): 56 | if name in model_dict: 57 | if not os.path.exists(model_dict[name]): 58 | print( 59 | "Can't find", 60 | os.path.abspath(model_dict[name]), 61 | "thus unzipping/downloading models now.", 62 | ) 63 | print("This might take a while..") 64 | prepare_models() 65 | return model_dict[name] 66 | raise AttributeError(f"module '{__name__}' has no attribute '{name}'") 67 | 68 | 69 | from .version import __version__ 70 | from lcpfn.model import LCPFN 71 | from lcpfn.train_lcpfn import train_lcpfn 72 | from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func 73 | 74 | __all__ = [ 75 | "LCPFN", 76 | "train_lcpfn", 77 | "sample_from_prior", 78 | "create_get_batch_func", 79 | "__version__", 80 | ] 81 | -------------------------------------------------------------------------------- /lcpfn/bar_distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class BarDistribution(nn.Module): 6 | def __init__( 7 | self, borders: torch.Tensor, smoothing=0.0 8 | ): # here borders should start with min and end with max, where all values lie in (min,max) and are sorted 9 | # sorted list of borders 10 | super().__init__() 11 | assert len(borders.shape) == 1 12 | # self.borders = borders 13 | self.register_buffer("borders", borders) 14 | self.register_buffer("smoothing", torch.tensor(smoothing)) 15 | # self.bucket_widths = self.borders[1:] - self.borders[:-1] 16 | self.register_buffer("bucket_widths", self.borders[1:] - self.borders[:-1]) 17 | full_width = self.bucket_widths.sum() 18 | border_order = torch.argsort(borders) 19 | assert ( 20 | full_width - (self.borders[-1] - self.borders[0]) 21 | ).abs() < 1e-4, f"diff: {full_width - (self.borders[-1] - self.borders[0])}" 22 | assert ( 23 | border_order == torch.arange(len(borders)).to(border_order.device) 24 | ).all(), "Please provide sorted borders!" 25 | self.num_bars = len(borders) - 1 26 | 27 | def map_to_bucket_idx(self, y): 28 | target_sample = torch.searchsorted(self.borders, y) - 1 29 | target_sample[y == self.borders[0]] = 0 30 | target_sample[y == self.borders[-1]] = self.num_bars - 1 31 | return target_sample 32 | 33 | def forward( 34 | self, logits, y 35 | ): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars 36 | target_sample = self.map_to_bucket_idx(y) 37 | assert (target_sample >= 0).all() and ( 38 | target_sample < self.num_bars 39 | ).all(), f"y {y} not in support set for borders (min_y, max_y) {self.borders}" 40 | assert ( 41 | logits.shape[-1] == self.num_bars 42 | ), f"{logits.shape[-1]} vs {self.num_bars}" 43 | 44 | bucket_log_probs = torch.log_softmax(logits, -1) 45 | scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths) 46 | # print(bucket_log_probs, logits.shape) 47 | 48 | nll_loss = -scaled_bucket_log_probs.gather( 49 | -1, target_sample.unsqueeze(-1) 50 | ).squeeze(-1) 51 | 52 | smooth_loss = -scaled_bucket_log_probs.mean(dim=-1) 53 | smoothing = self.smoothing if self.training else 0.0 54 | loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss 55 | return loss 56 | 57 | def mean(self, logits): 58 | bucket_means = self.borders[:-1] + self.bucket_widths / 2 59 | p = torch.softmax(logits, -1) 60 | return p @ bucket_means 61 | 62 | def icdf(self, logits, left_prob): 63 | """ 64 | Implementation of the quantile function 65 | :param logits: Tensor of any shape, with the last dimension being logits 66 | :param left_prob: float: The probability mass to the left of the result. 67 | :return: Position with `left_prob` probability weight to the left. 68 | """ 69 | probs = logits.softmax(-1) 70 | cumprobs = torch.cumsum(probs, -1) 71 | idx = ( 72 | torch.searchsorted( 73 | cumprobs, 74 | left_prob * torch.ones(*cumprobs.shape[:-1], 1, device=probs.device), 75 | ) 76 | .squeeze(-1) 77 | .clamp(0, cumprobs.shape[-1] - 1) 78 | ) # this might not do the right for outliers 79 | cumprobs = torch.cat( 80 | [torch.zeros(*cumprobs.shape[:-1], 1, device=logits.device), cumprobs], -1 81 | ) 82 | 83 | rest_prob = left_prob - cumprobs.gather(-1, idx[..., None]).squeeze(-1) 84 | left_border = self.borders[idx] 85 | right_border = self.borders[idx + 1] 86 | return left_border + (right_border - left_border) * rest_prob / probs.gather( 87 | -1, idx[..., None] 88 | ).squeeze(-1) 89 | 90 | def quantile(self, logits, center_prob=0.682): 91 | side_probs = (1.0 - center_prob) / 2 92 | return torch.stack( 93 | (self.icdf(logits, side_probs), self.icdf(logits, 1.0 - side_probs)), -1 94 | ) 95 | 96 | def ucb(self, logits, best_f, rest_prob=(1 - 0.682) / 2, maximize=True): 97 | """ 98 | UCB utility. Rest Prob is the amount of utility above (below) the confidence interval that is ignored. 99 | Higher rest_prob is equivalent to lower beta in the standard GP-UCB formulation. 100 | :param logits: Logits, as returned by the Transformer. 101 | :param best_f: Only here, since the other utilities have it. 102 | :param rest_prob: The amount of utility above (below) the confidence interval that is ignored. 103 | The default is equivalent to using GP-UCB with `beta=1`. 104 | To get the corresponding `beta`, where `beta` is from 105 | the standard GP definition of UCB `ucb_utility = mean + beta * std`, 106 | you can use this computation: `beta = math.sqrt(2)*torch.erfinv(torch.tensor(2*rest_prob-1))`. 107 | :param maximize: 108 | :return: utility 109 | """ 110 | if maximize: 111 | rest_prob = 1 - rest_prob 112 | return self.icdf(logits, rest_prob) 113 | 114 | def mode(self, logits): 115 | mode_inds = logits.argmax(-1) 116 | bucket_means = self.borders[:-1] + self.bucket_widths / 2 117 | return bucket_means[mode_inds] 118 | 119 | def ei( 120 | self, logits, best_f, maximize=True 121 | ): # logits: evaluation_points x batch x feature_dim 122 | bucket_means = self.borders[:-1] + self.bucket_widths / 2 123 | if maximize: 124 | bucket_contributions = torch.tensor( 125 | [ 126 | max((bucket_max + max(bucket_min, best_f)) / 2 - best_f, 0) 127 | for bucket_min, bucket_max, bucket_mean in zip( 128 | self.borders[:-1], self.borders[1:], bucket_means 129 | ) 130 | ], 131 | dtype=logits.dtype, 132 | device=logits.device, 133 | ) 134 | else: 135 | bucket_contributions = torch.tensor( 136 | [ 137 | -min((min(bucket_max, best_f) + bucket_min) / 2 - best_f, 0) 138 | for bucket_min, bucket_max, bucket_mean in zip( # min on max instead of max on min, and compare min < instead of max > 139 | self.borders[:-1], self.borders[1:], bucket_means 140 | ) 141 | ], 142 | dtype=logits.dtype, 143 | device=logits.device, 144 | ) 145 | p = torch.softmax(logits, -1) 146 | return p @ bucket_contributions 147 | 148 | def pi( 149 | self, logits, best_f, maximize=True 150 | ): # logits: evaluation_points x batch x feature_dim 151 | """ 152 | Acquisition Function: Probability of Improvement 153 | :param logits: as returned by Transformer 154 | :param best_f: best evaluation so far (the incumbent) 155 | :param maximize: whether to maximize 156 | :return: utility 157 | """ 158 | assert maximize is True 159 | p = torch.softmax(logits, -1) 160 | border_widths = self.borders[1:] - self.borders[:-1] 161 | factor = 1.0 - ((best_f - self.borders[:-1]) / border_widths).clamp(0.0, 1.0) 162 | return (p * factor).sum(-1) 163 | 164 | def mean_of_square(self, logits): 165 | """ 166 | Computes E[x^2]. 167 | :param logits: Output of the model. 168 | """ 169 | left_borders = self.borders[:-1] 170 | right_borders = self.borders[1:] 171 | bucket_mean_of_square = ( 172 | left_borders.square() 173 | + right_borders.square() 174 | + left_borders * right_borders 175 | ) / 3.0 176 | p = torch.softmax(logits, -1) 177 | return p @ bucket_mean_of_square 178 | 179 | def variance(self, logits): 180 | return self.mean_of_square(logits) - self.mean(logits).square() 181 | 182 | 183 | class FullSupportBarDistribution(BarDistribution): 184 | @staticmethod 185 | def halfnormal_with_p_weight_before(range_max, p=0.5): 186 | s = range_max / torch.distributions.HalfNormal(torch.tensor(1.0)).icdf( 187 | torch.tensor(p) 188 | ) 189 | return torch.distributions.HalfNormal(s) 190 | 191 | def forward( 192 | self, logits, y 193 | ): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars 194 | assert self.num_bars > 1 195 | target_sample = self.map_to_bucket_idx(y) 196 | target_sample.clamp_(0, self.num_bars - 1) 197 | assert logits.shape[-1] == self.num_bars 198 | 199 | bucket_log_probs = torch.log_softmax(logits, -1) 200 | scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths) 201 | # print(bucket_log_probs, logits.shape) 202 | log_probs = scaled_bucket_log_probs.gather( 203 | -1, target_sample.unsqueeze(-1) 204 | ).squeeze(-1) 205 | 206 | side_normals = ( 207 | self.halfnormal_with_p_weight_before(self.bucket_widths[0]), 208 | self.halfnormal_with_p_weight_before(self.bucket_widths[-1]), 209 | ) 210 | 211 | # TODO look over it again 212 | log_probs[target_sample == 0] += side_normals[0].log_prob( 213 | (self.borders[1] - y[target_sample == 0]).clamp(min=0.00000001) 214 | ) + torch.log(self.bucket_widths[0]) 215 | log_probs[target_sample == self.num_bars - 1] += side_normals[1].log_prob( 216 | y[target_sample == self.num_bars - 1] - self.borders[-2] 217 | ) + torch.log(self.bucket_widths[-1]) 218 | 219 | nll_loss = -log_probs 220 | 221 | smooth_loss = -scaled_bucket_log_probs.mean(dim=-1) 222 | smoothing = self.smoothing if self.training else 0.0 223 | loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss 224 | 225 | return loss 226 | 227 | def mean(self, logits): 228 | bucket_means = self.borders[:-1] + self.bucket_widths / 2 229 | p = torch.softmax(logits, -1) 230 | side_normals = ( 231 | self.halfnormal_with_p_weight_before(self.bucket_widths[0]), 232 | self.halfnormal_with_p_weight_before(self.bucket_widths[-1]), 233 | ) 234 | bucket_means[0] = -side_normals[0].mean + self.borders[1] 235 | bucket_means[-1] = side_normals[1].mean + self.borders[-2] 236 | return p @ bucket_means 237 | 238 | 239 | def get_bucket_limits_( 240 | num_outputs: int, 241 | full_range: tuple = None, 242 | ys: torch.Tensor = None, 243 | verbose: bool = False, 244 | ): 245 | assert (ys is not None) or (full_range is not None) 246 | if ys is not None: 247 | ys = ys.flatten() 248 | if len(ys) % num_outputs: 249 | ys = ys[: -(len(ys) % num_outputs)] 250 | print( 251 | f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys." 252 | ) 253 | ys_per_bucket = len(ys) // num_outputs 254 | if full_range is None: 255 | full_range = (ys.min(), ys.max()) 256 | else: 257 | assert full_range[0] <= ys.min() and full_range[1] >= ys.max() 258 | full_range = torch.tensor(full_range) 259 | ys_sorted, ys_order = ys.sort(0) 260 | bucket_limits = ( 261 | ys_sorted[ys_per_bucket - 1 :: ys_per_bucket][:-1] 262 | + ys_sorted[ys_per_bucket::ys_per_bucket] 263 | ) / 2 264 | if verbose: 265 | print( 266 | f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys." 267 | ) 268 | print(full_range) 269 | bucket_limits = torch.cat( 270 | [full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0 271 | ) 272 | 273 | else: 274 | class_width = (full_range[1] - full_range[0]) / num_outputs 275 | bucket_limits = torch.cat( 276 | [ 277 | full_range[0] + torch.arange(num_outputs).float() * class_width, 278 | torch.tensor(full_range[1]).unsqueeze(0), 279 | ], 280 | 0, 281 | ) 282 | 283 | assert ( 284 | len(bucket_limits) - 1 == num_outputs 285 | and full_range[0] == bucket_limits[0] 286 | and full_range[-1] == bucket_limits[-1] 287 | ) 288 | return bucket_limits 289 | 290 | 291 | def get_bucket_limits( 292 | num_outputs: int, 293 | full_range: tuple = None, 294 | ys: torch.Tensor = None, 295 | verbose: bool = False, 296 | ): 297 | assert (ys is None) != ( 298 | full_range is None 299 | ), "Either full_range or ys must be passed." 300 | 301 | if ys is not None: 302 | ys = ys.flatten() 303 | ys = ys[~torch.isnan(ys)] 304 | if len(ys) % num_outputs: 305 | ys = ys[: -(len(ys) % num_outputs)] 306 | print( 307 | f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys." 308 | ) 309 | ys_per_bucket = len(ys) // num_outputs 310 | if full_range is None: 311 | full_range = (ys.min(), ys.max()) 312 | else: 313 | assert ( 314 | full_range[0] <= ys.min() and full_range[1] >= ys.max() 315 | ), f"full_range {full_range} not in range of ys {ys.min(), ys.max()}" 316 | full_range = torch.tensor(full_range) 317 | ys_sorted, ys_order = ys.sort(0) 318 | bucket_limits = ( 319 | ys_sorted[ys_per_bucket - 1 :: ys_per_bucket][:-1] 320 | + ys_sorted[ys_per_bucket::ys_per_bucket] 321 | ) / 2 322 | if verbose: 323 | print( 324 | f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys." 325 | ) 326 | print(full_range) 327 | bucket_limits = torch.cat( 328 | [full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0 329 | ) 330 | 331 | else: 332 | class_width = (full_range[1] - full_range[0]) / num_outputs 333 | bucket_limits = torch.cat( 334 | [ 335 | full_range[0] + torch.arange(num_outputs).float() * class_width, 336 | torch.tensor(full_range[1]).unsqueeze(0), 337 | ], 338 | 0, 339 | ) 340 | 341 | assert ( 342 | len(bucket_limits) - 1 == num_outputs 343 | ), f"len(bucket_limits) - 1 == {len(bucket_limits) - 1} != {num_outputs} == num_outputs" 344 | assert full_range[0] == bucket_limits[0], f"{full_range[0]} != {bucket_limits[0]}" 345 | assert ( 346 | full_range[-1] == bucket_limits[-1] 347 | ), f"{full_range[-1]} != {bucket_limits[-1]}" 348 | 349 | return bucket_limits 350 | -------------------------------------------------------------------------------- /lcpfn/curves.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | 4 | prior = { 5 | "pow3": { 6 | "uniform": OrderedDict( 7 | a={"type": "uniform", "param1": -1, "param2": 1}, 8 | c={"type": "uniform", "param1": 0, "param2": 1}, 9 | alpha={"type": "uniform", "param1": 0, "param2": 1}, 10 | ), 11 | "peaked": OrderedDict( 12 | a={"type": "uniform", "param1": -0.6, "param2": 0.6}, 13 | c={"type": "uniform", "param1": 0, "param2": 1.25}, 14 | alpha={"type": "log_normal", "param1": 0, "param2": 2}, 15 | ), 16 | }, 17 | "ilog2": { 18 | "uniform": OrderedDict( 19 | c={"type": "uniform", "param1": 0, "param2": 1}, 20 | a={"type": "uniform", "param1": -1, "param2": 1}, 21 | ), 22 | "peaked": OrderedDict( 23 | c={"type": "uniform", "param1": 0, "param2": 1}, 24 | a={"type": "uniform", "param1": -0.5, "param2": 0.5}, 25 | ), 26 | }, 27 | "janoschek": { 28 | "uniform": OrderedDict( 29 | a={"type": "uniform", "param1": 0, "param2": 1}, 30 | beta={"type": "uniform", "param1": 0, "param2": 2}, 31 | k={"type": "uniform", "param1": 0, "param2": 1}, 32 | delta={"type": "uniform", "param1": -5, "param2": 5}, 33 | ), 34 | "peaked": OrderedDict( 35 | a={"type": "uniform", "param1": 0, "param2": 1}, 36 | beta={"type": "uniform", "param1": 0, "param2": 2}, 37 | k={"type": "log_normal", "param1": -2, "param2": 1}, 38 | delta={"type": "log_normal", "param1": 0, "param2": 0.5}, 39 | ), 40 | }, 41 | } 42 | 43 | 44 | def prior_sampler(rng, type, param1, param2): 45 | if type == "uniform": 46 | return rng.uniform(param1, param2) 47 | elif type == "log_normal": 48 | return rng.lognormal(param1, param2) 49 | raise Exception("Unknown prior type: {}".format(type)) 50 | 51 | 52 | def pow3(x, c, a, alpha): 53 | return c - a * (x) ** (-alpha) 54 | 55 | 56 | def prior_pow3(rng): 57 | return { 58 | p: prior_sampler( 59 | rng, 60 | prior["pow3"]["peaked"][p]["type"], 61 | param1=prior["pow3"]["peaked"][p]["param1"], 62 | param2=prior["pow3"]["peaked"][p]["param2"], 63 | ) 64 | for p in ["a", "c", "alpha"] 65 | } 66 | 67 | 68 | def uniform_prior_pow3(rng): 69 | return { 70 | p: prior_sampler( 71 | rng, 72 | prior["pow3"]["uniform"][p]["type"], 73 | param1=prior["pow3"]["uniform"][p]["param1"], 74 | param2=prior["pow3"]["uniform"][p]["param2"], 75 | ) 76 | for p in ["a", "c", "alpha"] 77 | } 78 | 79 | 80 | def ilog2(x, c, a): 81 | return c - a / (np.log(x + 1)) 82 | 83 | 84 | def prior_ilog2(rng): 85 | return { 86 | p: prior_sampler( 87 | rng, 88 | prior["ilog2"]["peaked"][p]["type"], 89 | param1=prior["ilog2"]["peaked"][p]["param1"], 90 | param2=prior["ilog2"]["peaked"][p]["param2"], 91 | ) 92 | for p in ["a", "c"] 93 | } 94 | 95 | 96 | def uniform_prior_ilog2(rng): 97 | return { 98 | p: prior_sampler( 99 | rng, 100 | prior["ilog2"]["uniform"][p]["type"], 101 | param1=prior["ilog2"]["uniform"][p]["param1"], 102 | param2=prior["ilog2"]["uniform"][p]["param2"], 103 | ) 104 | for p in ["a", "c"] 105 | } 106 | 107 | 108 | def janoschek(x, a, beta, k, delta): 109 | """ 110 | http://www.pisces-conservation.com/growthhelp/janoschek.htm 111 | """ 112 | return a - (a - beta) * np.exp(-k * x**delta) 113 | 114 | 115 | def prior_janoschek(rng): 116 | return { 117 | p: prior_sampler( 118 | rng, 119 | prior["janoschek"]["peaked"][p]["type"], 120 | param1=prior["janoschek"]["peaked"][p]["param1"], 121 | param2=prior["janoschek"]["peaked"][p]["param2"], 122 | ) 123 | for p in ["a", "beta", "k", "delta"] 124 | } 125 | 126 | 127 | def uniform_prior_janoschek(rng): 128 | return { 129 | p: prior_sampler( 130 | rng, 131 | prior["janoschek"]["uniform"][p]["type"], 132 | param1=prior["janoschek"]["uniform"][p]["param1"], 133 | param2=prior["janoschek"]["uniform"][p]["param2"], 134 | ) 135 | for p in ["a", "beta", "k", "delta"] 136 | } 137 | 138 | 139 | def log_power(x, a, b, c): 140 | # a: upper bound 141 | # c: growth rate 142 | # initial = a/ (1 + (1/e^b)^c 143 | return a / (1.0 + (x / np.exp(b)) ** c) 144 | 145 | 146 | def prior_log_power(rng): 147 | # a ~ N(0.8,0.1) 148 | # b ~ N(1,1) 149 | # c ~ U(-3,0) 150 | a = rng.normal(0.8, 0.1) 151 | b = rng.normal(1.0, 1.0) 152 | c = rng.uniform(-3.0, 0.0) 153 | return {"a": a, "b": b, "c": c} 154 | 155 | 156 | def weibull(x, alpha, beta, kappa, delta): 157 | """ 158 | Weibull modell 159 | http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm 160 | alpha: upper asymptote 161 | beta: lower asymptote 162 | k: growth rate 163 | delta: controls the x-ordinate for the point of inflection 164 | """ 165 | return alpha - (alpha - beta) * np.exp(-((kappa * x) ** delta)) 166 | 167 | 168 | def prior_weibull(rng): 169 | alpha = rng.uniform(0.0, 1.5) 170 | beta = rng.uniform(0.0, 1) 171 | kappa = np.exp(rng.normal(-2.0, 1.0)) 172 | delta = np.exp(rng.normal(0, 0.5)) 173 | return {"alpha": alpha, "beta": beta, "kappa": kappa, "delta": delta} 174 | 175 | 176 | def mmf(x, alpha, beta, kappa, delta): 177 | """ 178 | Morgan-Mercer-Flodin 179 | description: 180 | Nonlinear Regression page 342 181 | http://bit.ly/1jodG17 182 | http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm 183 | alpha: upper asymptote 184 | kappa: growth rate 185 | beta: initial value 186 | delta: controls the point of inflection 187 | """ 188 | return alpha - (alpha - beta) / (1.0 + (kappa * x) ** delta) 189 | 190 | 191 | def prior_mmf(rng): 192 | # alpha ~ N(0.8,0.1) 193 | # beta ~ N(0.2,0.1) 194 | # ln(kappa) ~ N(0,2) 195 | # ln(delta) ~ N(0,1) 196 | alpha = rng.normal(0.8, 0.1) 197 | beta = rng.normal(0.2, 0.1) 198 | kappa = np.exp(rng.normal(0, 2)) 199 | delta = np.exp(rng.normal(0, 1)) 200 | return {"alpha": alpha, "beta": beta, "kappa": kappa, "delta": delta} 201 | 202 | 203 | def vap(x, a, b, c): 204 | """Vapor pressure model""" 205 | # no upper bound if c > 0 206 | # a = ln(upper bound) for c=0 207 | # a+b = ln(initial) 208 | return np.exp(a + b / x + c * np.log(x)) 209 | 210 | 211 | def prior_vap(rng): 212 | a = rng.uniform(-2.0, 0.0) # @heri: range check 213 | b = rng.uniform(-4.0, 0.0) # @heri: range check 214 | c = np.exp(rng.uniform(-8.0, 0.0)) # @heri: same as weights 215 | return {"a": a, "b": b, "c": c} 216 | 217 | 218 | def loglog_linear(x, a, b): 219 | x = np.log(x) 220 | return np.log(a * x + b) 221 | 222 | 223 | def prior_loglog_linear(rng): 224 | # ln(a) ~ N(-2, 1) 225 | # ln(b) ~ U(0, 1) 226 | a = np.exp(rng.normal(-2.0, 1.0)) 227 | b = np.exp(rng.uniform(0.0, 1.0)) 228 | return {"a": a, "b": b} 229 | 230 | 231 | def exp4(x, c, a, b, alpha): 232 | return c - np.exp(-a * (x**alpha) + b) 233 | 234 | 235 | def prior_exp4(rng): 236 | # c ~ N(0.8,0.1) 237 | c = rng.normal(0.8, 0.1) 238 | # ln(a) ~ N(-2,1) 239 | a = np.exp(rng.normal(-2, 1)) 240 | # ln(alpha) ~ N(0,1) 241 | alpha = np.exp(rng.normal(0, 1)) 242 | # ln(b) ~ N(0,0.5) 243 | b = np.exp(rng.normal(0, 0.5)) 244 | return {"a": a, "b": b, "c": c, "alpha": alpha} 245 | 246 | 247 | def pow4(x, c, a, b, alpha): 248 | return c - (a * x + b) ** -alpha 249 | 250 | 251 | def prior_pow4(rng): 252 | # ln(1 - c) ~ U(-5, 0) 253 | c = 1 - np.exp(rng.uniform(-5.0, 0)) 254 | # ln(a) ~ N(-3, 2) 255 | a = np.exp(rng.normal(-3.0, 2)) 256 | # ln(alpha) ~ N(0,1) 257 | alpha = np.exp(rng.normal(0, 1)) 258 | # ln(b) ~ U(0, 1) 259 | b = np.exp(rng.uniform(0, 1)) 260 | return {"a": a, "b": b, "c": c, "alpha": alpha} 261 | 262 | 263 | def dr_hill_zero_background(x, theta, eta, kappa): 264 | # theta: upper bound 265 | # eta: growth rate 266 | # initial = theta/(kappa^eta + 1) 267 | return (theta * x**eta) / (kappa**eta + x**eta) 268 | 269 | 270 | def prior_dr_hill_zero_background(rng): 271 | # theta ~ U(1,0) N(0.8,0.1) 272 | # ln(eta) ~ N(1,1) 273 | # ln(kappa) ~ N(1,2) 274 | theta = rng.normal(0.8, 0.1) 275 | eta = np.exp(rng.normal(1.0, 1.0)) 276 | kappa = np.exp(rng.normal(1.0, 2.0)) 277 | return {"theta": theta, "eta": eta, "kappa": kappa} 278 | -------------------------------------------------------------------------------- /lcpfn/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import random 4 | 5 | from torch import Tensor 6 | import torch.nn.functional as F 7 | 8 | 9 | class GELU(nn.Module): 10 | def forward(self, input: Tensor) -> Tensor: 11 | return F.gelu(input) 12 | 13 | 14 | class ScaledDecoder(nn.Module): 15 | def __init__(self, ninp, nhid, nout): 16 | super().__init__() 17 | self.linear = nn.Linear(ninp, nhid) 18 | self.linear1 = nn.Linear(nhid, nout) 19 | self.linear2 = nn.Linear(nhid, 10) 20 | 21 | def forward(self, x): 22 | # return torch.cat([self.linear1(x), self.linear2(x)], -1) 23 | x = self.linear(x) 24 | x = GELU()(x) 25 | temps = self.linear2(x).softmax(-1) @ torch.tensor( 26 | [1.0, 1.4, 1.7, 2.0, 5.0, 10.0, 20.0, 40.0, 80.0, 160.0], device=x.device 27 | ) 28 | if random.random() > 0.99: 29 | print(temps.shape, temps[:, :2]) 30 | return self.linear1(x) / temps.unsqueeze(-1) 31 | 32 | 33 | class FixedScaledDecoder(nn.Module): 34 | def __init__(self, ninp, nhid, nout): 35 | super().__init__() 36 | self.mapper = nn.Sequential( 37 | nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout) 38 | ) 39 | self.T = nn.Parameter(torch.ones(10000) / 10000) 40 | 41 | def forward(self, x): 42 | return self.mapper(x) / self.T.sum() 43 | -------------------------------------------------------------------------------- /lcpfn/domhan_prior.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | import numpy as np 4 | from lcpfn.curves import ( 5 | pow3, 6 | ilog2, 7 | janoschek, 8 | log_power, 9 | prior_ilog2, 10 | uniform_prior_pow3, 11 | weibull, 12 | mmf, 13 | vap, 14 | loglog_linear, 15 | exp4, 16 | pow4, 17 | dr_hill_zero_background, 18 | ) 19 | from lcpfn.curves import ( 20 | prior_pow3, 21 | prior_janoschek, 22 | prior_log_power, 23 | prior_weibull, 24 | prior_mmf, 25 | prior_vap, 26 | prior_loglog_linear, 27 | prior_exp4, 28 | prior_pow4, 29 | prior_dr_hill_zero_background, 30 | ) 31 | from lcpfn.curves import ( 32 | uniform_prior_pow3, 33 | uniform_prior_ilog2, 34 | uniform_prior_janoschek, 35 | ) 36 | 37 | 38 | def prior_weights( 39 | rng, 40 | components=[ 41 | "pow3", 42 | "ilog2", 43 | "janoschek", 44 | "log_power", 45 | "weibull", 46 | "mmf", 47 | "vap", 48 | "loglog_linear", 49 | "exp4", 50 | "pow4", 51 | "dr_hill_zero_background", 52 | ], 53 | ): 54 | K = len(components) 55 | weights = rng.uniform(0.0, 1, size=(K,)) 56 | return {f: weights[i] for i, f in enumerate(components)} 57 | 58 | 59 | def sample_from_prior(rng, seq_len=100): 60 | return sample_prior_comb( 61 | rng=rng, 62 | seq_len=seq_len, 63 | components=["pow3", "ilog2", "janoschek"], 64 | distribution="peaked", 65 | ) 66 | 67 | 68 | def sample_prior_comb( 69 | rng, 70 | components, 71 | distribution, 72 | var_lnloc=-4, 73 | var_lnscale=1, 74 | range_constraint=True, 75 | seq_len=100, 76 | ): 77 | f_components = { 78 | "pow3": pow3, 79 | "ilog2": ilog2, 80 | "janoschek": janoschek, 81 | "log_power": log_power, 82 | "weibull": weibull, 83 | "mmf": mmf, 84 | "vap": vap, 85 | "loglog_linear": loglog_linear, 86 | "exp4": exp4, 87 | "pow4": pow4, 88 | "dr_hill_zero_background": dr_hill_zero_background, 89 | } 90 | 91 | if distribution == "peaked": 92 | f_priors = { 93 | "pow3": prior_pow3, 94 | "ilog2": prior_ilog2, 95 | "janoschek": prior_janoschek, 96 | "log_power": prior_log_power, 97 | "weibull": prior_weibull, 98 | "mmf": prior_mmf, 99 | "vap": prior_vap, 100 | "loglog_linear": prior_loglog_linear, 101 | "exp4": prior_exp4, 102 | "pow4": prior_pow4, 103 | "dr_hill_zero_background": prior_dr_hill_zero_background, 104 | } 105 | elif distribution == "uniform": 106 | f_priors = { 107 | "pow3": uniform_prior_pow3, 108 | "ilog2": uniform_prior_ilog2, 109 | "janoschek": uniform_prior_janoschek, 110 | } 111 | else: 112 | raise NotImplemented() 113 | 114 | x = np.arange(1, seq_len + 1) 115 | 116 | while True: 117 | # sample the noiseless curve 118 | weights = prior_weights(rng, components=components) 119 | y = np.zeros(x.shape, dtype="float") 120 | kwargs = 0 121 | for f, w in weights.items(): 122 | kwargs = f_priors[f](rng) 123 | # print(f_components[f](x, **kwargs)) 124 | y += w * f_components[f](x, **kwargs) 125 | # add noise (can exceed [0,1], but afaik no way to implement this prior in Tobias work) 126 | # Note: This is the correct definition, but it differs from the noise prior definition in the paper 127 | std = np.exp( 128 | rng.normal(var_lnloc, var_lnscale) 129 | ) 130 | 131 | # reject any curves that are non-increasing, exceed the [0,1] range 132 | if ( 133 | y[-1] <= y[0] 134 | or (range_constraint and (np.any(y < 0) or np.any(y > 1))) 135 | or np.isnan(y).any() 136 | ): 137 | continue 138 | else: 139 | break 140 | 141 | def curve(): # generates a sample from the same model, but with independent noise 142 | y_noisy = y + rng.normal(np.zeros_like(y), std) 143 | return y, y_noisy 144 | 145 | return curve 146 | 147 | 148 | def generate_prior_dataset(n, prior=sample_prior_comb, seed=42): 149 | """ 150 | Returns a fixed sample from the prior (with fixed seq_len) as an n x seq_len np.ndarray 151 | """ 152 | rng = np.random.RandomState(seed) 153 | prior_data = np.stack([prior(rng)()[1] for _ in range(n)]) 154 | return prior_data 155 | 156 | 157 | def create_get_batch_func(prior): 158 | return partial(get_batch_domhan, prior=prior) 159 | 160 | 161 | # function producing batches for PFN training 162 | def get_batch_domhan( 163 | batch_size, 164 | seq_len, 165 | num_features, 166 | prior, 167 | device="cpu", 168 | noisy_target=True, 169 | **_, 170 | ): 171 | assert num_features == 1 172 | 173 | x = np.arange(1, seq_len + 1) 174 | y_target = np.empty((batch_size, seq_len), dtype=float) 175 | y_noisy = np.empty((batch_size, seq_len), dtype=float) 176 | 177 | for i in range(batch_size): 178 | curve_func = prior(np.random, seq_len=seq_len) # uses numpy rng 179 | if noisy_target: 180 | _, y_noisy[i] = curve_func() 181 | y_target[i] = y_noisy[i] 182 | else: 183 | y_target[i], y_noisy[i] = curve_func() 184 | 185 | # turn numpy arrays into correctly shaped torch tensors & move them to device 186 | x = ( 187 | torch.arange(1, seq_len + 1) 188 | .repeat((num_features, batch_size, 1)) 189 | .transpose(2, 0) 190 | .to(device) 191 | ) 192 | y_target = torch.from_numpy(y_target).transpose(1, 0).to(device) 193 | y_noisy = torch.from_numpy(y_noisy).transpose(1, 0).to(device) 194 | 195 | # changes 196 | x = x.float() 197 | y_target = y_target.float() 198 | y_noisy = y_noisy.float() 199 | 200 | return x, y_noisy, y_target 201 | -------------------------------------------------------------------------------- /lcpfn/encoders.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from lcpfn.utils import normalize_data 6 | import torch.nn.functional as F 7 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 8 | 9 | 10 | class StyleEncoder(nn.Module): 11 | def __init__(self, em_size, hyperparameter_definitions): 12 | super().__init__() 13 | self.em_size = em_size 14 | self.embedding = nn.Linear(hyperparameter_definitions.shape[0], self.em_size) 15 | 16 | def forward(self, hyperparameters): # T x B x num_hps 17 | return self.embedding(hyperparameters) 18 | 19 | 20 | class _PositionalEncoding(nn.Module): 21 | def __init__(self, d_model, dropout=0.0): 22 | super().__init__() 23 | self.dropout = nn.Dropout(p=dropout) 24 | self.d_model = d_model 25 | self.device_test_tensor = nn.Parameter(torch.tensor(1.0)) 26 | 27 | def forward(self, x): # T x B x num_features 28 | assert self.d_model % x.shape[-1] * 2 == 0 29 | d_per_feature = self.d_model // x.shape[-1] 30 | pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device) 31 | # position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 32 | interval_size = 10 33 | div_term = ( 34 | (1.0 / interval_size) 35 | * 2 36 | * math.pi 37 | * torch.exp( 38 | torch.arange( 39 | 0, d_per_feature, 2, device=self.device_test_tensor.device 40 | ).float() 41 | * math.log(math.sqrt(2)) 42 | ) 43 | ) 44 | # print(div_term/2/math.pi) 45 | pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term) 46 | pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term) 47 | return self.dropout(pe).view(x.shape[0], x.shape[1], self.d_model) 48 | 49 | 50 | Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize) 51 | 52 | 53 | class EmbeddingEncoder(nn.Module): 54 | def __init__(self, num_features, em_size, num_embs=100): 55 | super().__init__() 56 | self.num_embs = num_embs 57 | self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True) 58 | self.init_weights(0.1) 59 | self.min_max = (-2, +2) 60 | 61 | @property 62 | def width(self): 63 | return self.min_max[1] - self.min_max[0] 64 | 65 | def init_weights(self, initrange): 66 | self.embeddings.weight.data.uniform_(-initrange, initrange) 67 | 68 | def discretize(self, x): 69 | split_size = self.width / self.num_embs 70 | return (x - self.min_max[0] // split_size).int().clamp(0, self.num_embs - 1) 71 | 72 | def forward(self, x): # T x B x num_features 73 | x_idxs = self.discretize(x) 74 | x_idxs += ( 75 | torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs 76 | ) 77 | # print(x_idxs,self.embeddings.weight.shape) 78 | return self.embeddings(x_idxs).mean(-2) 79 | 80 | 81 | class Normalize(nn.Module): 82 | def __init__(self, mean, std): 83 | super().__init__() 84 | self.mean = mean 85 | self.std = std 86 | 87 | def forward(self, x): 88 | return (x - self.mean) / self.std 89 | 90 | 91 | def get_normalized_uniform_encoder(encoder_creator): 92 | """ 93 | This can be used to wrap an encoder that is fed uniform samples in [0,1] and normalizes these to 0 mean and 1 std. 94 | For example, it can be used as `encoder_creator = get_normalized_uniform_encoder(encoders.Linear)`, now this can 95 | be initialized with `encoder_creator(feature_dim, in_dim)`. 96 | :param encoder: 97 | :return: 98 | """ 99 | return lambda in_dim, out_dim: nn.Sequential( 100 | Normalize(0.5, math.sqrt(1 / 12)), encoder_creator(in_dim, out_dim) 101 | ) 102 | 103 | 104 | Linear = nn.Linear 105 | MLP = lambda num_features, emsize: nn.Sequential( 106 | nn.Linear(num_features + 1, emsize * 2), nn.ReLU(), nn.Linear(emsize * 2, emsize) 107 | ) 108 | 109 | 110 | class NanHandlingEncoder(nn.Module): 111 | def __init__(self, num_features, emsize, keep_nans=True): 112 | super().__init__() 113 | self.num_features = 2 * num_features if keep_nans else num_features 114 | self.emsize = emsize 115 | self.keep_nans = keep_nans 116 | self.layer = nn.Linear(self.num_features, self.emsize) 117 | 118 | def forward(self, x): 119 | if self.keep_nans: 120 | x = torch.cat( 121 | [ 122 | torch.nan_to_num(x, nan=0.0), 123 | normalize_data( 124 | torch.isnan(x) * -1 125 | + torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1 126 | + torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2 127 | ), 128 | ], 129 | -1, 130 | ) 131 | else: 132 | x = torch.nan_to_num(x, nan=0.0) 133 | return self.layer(x) 134 | 135 | 136 | class Linear(nn.Linear): 137 | def __init__(self, num_features, emsize): 138 | super().__init__(num_features, emsize) 139 | self.num_features = num_features 140 | self.emsize = emsize 141 | 142 | def forward(self, x): 143 | x = torch.nan_to_num(x, nan=0.0) 144 | return super().forward(x) 145 | 146 | 147 | class Conv(nn.Module): 148 | def __init__(self, input_size, emsize): 149 | super().__init__() 150 | self.convs = torch.nn.ModuleList( 151 | [nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)] 152 | ) 153 | self.linear = nn.Linear(64, emsize) 154 | 155 | def forward(self, x): 156 | size = math.isqrt(x.shape[-1]) 157 | assert size * size == x.shape[-1] 158 | x = x.reshape(*x.shape[:-1], 1, size, size) 159 | for conv in self.convs: 160 | if x.shape[-1] < 4: 161 | break 162 | x = conv(x) 163 | x.relu_() 164 | x = nn.AdaptiveAvgPool2d((1, 1))(x).squeeze(-1).squeeze(-1) 165 | return self.linear(x) 166 | 167 | 168 | class CanEmb(nn.Embedding): 169 | def __init__( 170 | self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs 171 | ): 172 | assert embedding_dim % num_features == 0 173 | embedding_dim = embedding_dim // num_features 174 | super().__init__(num_embeddings, embedding_dim, *args, **kwargs) 175 | 176 | def forward(self, x): 177 | lx = x.long() 178 | assert (lx == x).all(), "CanEmb only works with tensors of whole numbers" 179 | x = super().forward(lx) 180 | return x.view(*x.shape[:-2], -1) 181 | 182 | 183 | def get_Canonical(num_classes): 184 | return lambda num_features, emsize: CanEmb(num_features, num_classes, emsize) 185 | 186 | 187 | def get_Embedding(num_embs_per_feature=100): 188 | return lambda num_features, emsize: EmbeddingEncoder( 189 | num_features, emsize, num_embs=num_embs_per_feature 190 | ) 191 | -------------------------------------------------------------------------------- /lcpfn/initializers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def get_NormalInitializer(std): 6 | def initializer(m): 7 | if isinstance(m, nn.Linear): 8 | nn.init.normal_(m.weight, 0, std) 9 | nn.init.normal_(m.bias, 0, std) 10 | 11 | return initializer 12 | -------------------------------------------------------------------------------- /lcpfn/layer.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional 3 | from torch import Tensor 4 | from torch import nn 5 | from torch.nn.modules.transformer import * 6 | from torch.nn.modules.transformer import _get_activation_fn 7 | 8 | from torch.utils.checkpoint import checkpoint 9 | 10 | 11 | class TransformerEncoderLayer(nn.Module): 12 | r"""TransformerEncoderLayer is made up of self-attn and feedforward network. 13 | This standard encoder layer is based on the paper "Attention Is All You Need". 14 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 15 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 16 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 17 | in a different way during application. 18 | 19 | Args: 20 | d_model: the number of expected features in the input (required). 21 | nhead: the number of heads in the multiheadattention models (required). 22 | dim_feedforward: the dimension of the feedforward network model (default=2048). 23 | dropout: the dropout value (default=0.1). 24 | activation: the activation function of intermediate layer, relu or gelu (default=relu). 25 | layer_norm_eps: the eps value in layer normalization components (default=1e-5). 26 | batch_first: If ``True``, then the input and output tensors are provided 27 | as (batch, seq, feature). Default: ``False``. 28 | 29 | Examples:: 30 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 31 | >>> src = torch.rand(10, 32, 512) 32 | >>> out = encoder_layer(src) 33 | 34 | Alternatively, when ``batch_first`` is ``True``: 35 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) 36 | >>> src = torch.rand(32, 10, 512) 37 | >>> out = encoder_layer(src) 38 | """ 39 | 40 | __constants__ = ["batch_first"] 41 | 42 | def __init__( 43 | self, 44 | d_model, 45 | nhead, 46 | dim_feedforward=2048, 47 | dropout=0.1, 48 | activation="relu", 49 | layer_norm_eps=1e-5, 50 | batch_first=False, 51 | pre_norm=False, 52 | device=None, 53 | dtype=None, 54 | recompute_attn=False, 55 | ) -> None: 56 | factory_kwargs = {"device": device, "dtype": dtype} 57 | super().__init__() 58 | self.self_attn = MultiheadAttention( 59 | d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs 60 | ) 61 | # Implementation of Feedforward model 62 | self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) 63 | self.dropout = Dropout(dropout) 64 | self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) 65 | 66 | self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 67 | self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 68 | self.dropout1 = Dropout(dropout) 69 | self.dropout2 = Dropout(dropout) 70 | self.pre_norm = pre_norm 71 | self.recompute_attn = recompute_attn 72 | 73 | self.activation = _get_activation_fn(activation) 74 | 75 | def __setstate__(self, state): 76 | if "activation" not in state: 77 | state["activation"] = F.relu 78 | super().__setstate__(state) 79 | 80 | def forward( 81 | self, 82 | src: Tensor, 83 | src_mask: Optional[Tensor] = None, 84 | src_key_padding_mask: Optional[Tensor] = None, 85 | ) -> Tensor: 86 | r"""Pass the input through the encoder layer. 87 | 88 | Args: 89 | src: the sequence to the encoder layer (required). 90 | src_mask: the mask for the src sequence (optional). 91 | src_key_padding_mask: the mask for the src keys per batch (optional). 92 | 93 | Shape: 94 | see the docs in Transformer class. 95 | """ 96 | if self.pre_norm: 97 | src_ = self.norm1(src) 98 | else: 99 | src_ = src 100 | if isinstance(src_mask, tuple): 101 | # global attention setup 102 | assert not self.self_attn.batch_first 103 | assert src_key_padding_mask is None 104 | 105 | global_src_mask, trainset_src_mask, valset_src_mask = src_mask 106 | 107 | num_global_tokens = global_src_mask.shape[0] 108 | num_train_tokens = trainset_src_mask.shape[0] 109 | 110 | global_tokens_src = src_[:num_global_tokens] 111 | train_tokens_src = src_[ 112 | num_global_tokens : num_global_tokens + num_train_tokens 113 | ] 114 | global_and_train_tokens_src = src_[: num_global_tokens + num_train_tokens] 115 | eval_tokens_src = src_[num_global_tokens + num_train_tokens :] 116 | 117 | attn = ( 118 | partial(checkpoint, self.self_attn) 119 | if self.recompute_attn 120 | else self.self_attn 121 | ) 122 | 123 | global_tokens_src2 = attn( 124 | global_tokens_src, 125 | global_and_train_tokens_src, 126 | global_and_train_tokens_src, 127 | None, 128 | True, 129 | global_src_mask, 130 | )[0] 131 | train_tokens_src2 = attn( 132 | train_tokens_src, 133 | global_tokens_src, 134 | global_tokens_src, 135 | None, 136 | True, 137 | trainset_src_mask, 138 | )[0] 139 | eval_tokens_src2 = attn( 140 | eval_tokens_src, src_, src_, None, True, valset_src_mask 141 | )[0] 142 | 143 | src2 = torch.cat( 144 | [global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0 145 | ) 146 | 147 | else: 148 | if self.recompute_attn: 149 | src2 = checkpoint( 150 | self.self_attn, 151 | src_, 152 | src_, 153 | src_, 154 | src_key_padding_mask, 155 | True, 156 | src_mask, 157 | )[0] 158 | else: 159 | src2 = self.self_attn( 160 | src_, 161 | src_, 162 | src_, 163 | attn_mask=src_mask, 164 | key_padding_mask=src_key_padding_mask, 165 | )[0] 166 | src = src + self.dropout1(src2) 167 | if not self.pre_norm: 168 | src = self.norm1(src) 169 | 170 | if self.pre_norm: 171 | src_ = self.norm2(src) 172 | else: 173 | src_ = src 174 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src_)))) 175 | src = src + self.dropout2(src2) 176 | 177 | if not self.pre_norm: 178 | src = self.norm2(src) 179 | return src 180 | -------------------------------------------------------------------------------- /lcpfn/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lcpfn 3 | import warnings 4 | from lcpfn import utils 5 | 6 | 7 | class LCPFN(torch.nn.Module): 8 | def __init__(self, model_name="EMSIZE512_NLAYERS12_NBUCKETS1000"): 9 | super(LCPFN, self).__init__() 10 | self.model = torch.load( 11 | getattr(lcpfn, model_name) if model_name in lcpfn.model_dict else model_name 12 | ) 13 | self.model.eval() 14 | 15 | def check_input(self, x_train, x_test, y_train, y_test=None): 16 | if torch.any(x_train < 0) or torch.any(x_test < 0): 17 | # raise warning if input has negative values 18 | raise Exception("x values should be non-negative") 19 | if torch.any((0 > y_train) | (y_train > 1)) or ( 20 | y_test is not None and torch.any(0 < y_test < 1) 21 | ): 22 | # raise warning if input has values outside [0,1] 23 | raise Exception( 24 | "y values should be in the range [0,1]. Please set normalizer_kwargs accordingly." 25 | ) 26 | 27 | @torch.no_grad() 28 | def predict_mean( 29 | self, x_train, y_train, x_test, normalizer=utils.identity_normalizer() 30 | ): 31 | y_train_norm = normalizer[0](y_train) 32 | logits = self(x_train=x_train, y_train=y_train_norm, x_test=x_test) 33 | return normalizer[1](self.model.criterion.mean(logits)) 34 | 35 | @torch.no_grad() 36 | def predict_quantiles( 37 | self, x_train, y_train, x_test, qs, normalizer=utils.identity_normalizer() 38 | ): 39 | y_train_norm = normalizer[0](y_train) 40 | logits = self(x_train=x_train, y_train=y_train_norm, x_test=x_test) 41 | return normalizer[1]( 42 | torch.cat([self.model.criterion.icdf(logits, q) for q in qs], dim=1) 43 | ) 44 | 45 | @torch.no_grad() 46 | def nll_loss(self, x_train, y_train, x_test, y_test): 47 | # TODO add normalizer_kwargs 48 | logits = self(x_train=x_train, y_train=y_train, x_test=x_test) 49 | return self.model.criterion(logits, y_test) 50 | 51 | def forward(self, x_train, y_train, x_test): 52 | self.check_input(x_train, x_test, y_train) 53 | single_eval_pos = x_train.shape[0] 54 | x = torch.cat([x_train, x_test], dim=0).unsqueeze(1) 55 | y = y_train.unsqueeze(1) 56 | return self.model((x, y), single_eval_pos=single_eval_pos) 57 | -------------------------------------------------------------------------------- /lcpfn/positional_encodings.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | # Protocol for positonal encodings. 8 | # __init__(d_model, max_len=..[, more optionals]) 9 | # forward(x: (seq_len, bs, d_model)) -> Tensor of shape (*x.shape[:2],d_model) containing pos. embeddings 10 | 11 | 12 | class NoPositionalEncoding(nn.Module): 13 | def __init__(self, d_model, max_len=None): 14 | super(NoPositionalEncoding, self).__init__() 15 | pass 16 | 17 | def forward(self, x): 18 | return x # * math.sqrt(x.shape[-1]) 19 | 20 | 21 | class PositionalEncoding(nn.Module): 22 | def __init__(self, d_model, max_len=5000): 23 | super(PositionalEncoding, self).__init__() 24 | pe = torch.zeros(max_len, d_model) 25 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 26 | div_term = torch.exp( 27 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) 28 | ) 29 | pe[:, 0::2] = torch.sin(position * div_term) 30 | pe[:, 1::2] = torch.cos(position * div_term) 31 | pe = pe.unsqueeze(0).transpose(0, 1) 32 | self.register_buffer("pe", pe) 33 | 34 | def forward(self, x): 35 | x = self.pe[: x.size(0), :] + x # * math.sqrt(x.shape[-1]) 36 | return x 37 | 38 | 39 | class LearnedPositionalEncoding(nn.Module): 40 | def __init__(self, d_model, max_len=5000): 41 | super(LearnedPositionalEncoding, self).__init__() 42 | self.max_seq_len = max_len 43 | # self.positional_embeddings = nn.Embedding(max_len, d_model) 44 | self.positional_embeddings = nn.Parameter(torch.empty(max_len, d_model)) 45 | nn.init.normal_(self.positional_embeddings, mean=0, std=d_model**-0.5) 46 | 47 | def forward(self, x): 48 | seq_len, bs, d_model = x.shape 49 | assert seq_len <= len( 50 | self.positional_embeddings 51 | ), "seq_len can be at most max_len." 52 | pos_emb = self.positional_embeddings[:seq_len] 53 | return ( 54 | pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x 55 | ) # * math.sqrt(x.shape[-1]) 56 | 57 | 58 | class PairedScrambledPositionalEncodings(LearnedPositionalEncoding): 59 | # TODO check whether it is a problem to use the same perm. for full batch 60 | def forward(self, x): 61 | seq_len, bs, d_model = x.shape 62 | assert seq_len <= len( 63 | self.positional_embeddings 64 | ), "seq_len can be at most max_len." 65 | assert ( 66 | len(self.positional_embeddings) % 2 == 0 67 | ), "Please specify an even max_len." 68 | 69 | paired_embs = self.positional_embeddings.view( 70 | len(self.positional_embeddings), -1, 2 71 | ) 72 | pos_emb = paired_embs[torch.randperm(len(paired_embs))].view( 73 | *self.positional_embeddings.shape 74 | )[:seq_len] 75 | 76 | return ( 77 | pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x 78 | ) # * math.sqrt(x.shape[-1]) 79 | -------------------------------------------------------------------------------- /lcpfn/priors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/lcpfn/ba892f6f451027f69c50edf00c765ded98c75d30/lcpfn/priors/__init__.py -------------------------------------------------------------------------------- /lcpfn/priors/prior.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from torch.utils.data import DataLoader 3 | 4 | 5 | class PriorDataLoader(DataLoader, metaclass=ABCMeta): 6 | @abstractmethod 7 | def __init__( 8 | self, 9 | num_steps, 10 | batch_size, 11 | eval_pos_seq_len_sampler, 12 | seq_len_maximum, 13 | device, 14 | **kwargs, 15 | ): 16 | """ 17 | 18 | :param num_steps: int, first argument, the number of steps to take per epoch, i.e. iteration of the DataLoader 19 | :param batch_size: int, number of datasets per batch 20 | :param eval_pos_seq_len_sampler: callable, it takes no arguments and returns a tuple (single eval pos, bptt) 21 | :param kwargs: for future compatibility it is good to have a final all catch, as new kwargs might be introduced 22 | """ 23 | pass 24 | 25 | # A class or object variable `num_features`: int 26 | # Optional: `validate` function that accepts a transformer model 27 | 28 | # The DataLoader iter should return batches of the form ([style], x, y), target_y, single_eval_pos 29 | # We follow sequence len (s) first, batch size (b) second. So x: (s,b,num_features), y,target_y: (s,b) 30 | # and style: Optional[(b,num_style_params)], style can be omitted or set to None, if it is not intended to be used. 31 | 32 | # For more references, see `priors/utils.py` for a pretty general implementation of a DataLoader 33 | # and `train.py` for the only call of it. 34 | -------------------------------------------------------------------------------- /lcpfn/priors/utils.py: -------------------------------------------------------------------------------- 1 | from lcpfn.utils import set_locals_in_self 2 | from .prior import PriorDataLoader 3 | import math 4 | 5 | 6 | def get_batch_to_dataloader(get_batch_method_): 7 | class DL(PriorDataLoader): 8 | get_batch_method = get_batch_method_ 9 | 10 | # Caution, you might need to set self.num_features manually if it is not part of the args. 11 | def __init__(self, num_steps, **get_batch_kwargs): 12 | set_locals_in_self(locals()) 13 | 14 | # The stuff outside the or is set as class attribute before instantiation. 15 | self.num_features = ( 16 | get_batch_kwargs.get("num_features") or self.num_features 17 | ) 18 | print("DataLoader.__dict__", self.__dict__) 19 | 20 | @staticmethod 21 | def gbm(*args, eval_pos_seq_len_sampler, **kwargs): 22 | kwargs["single_eval_pos"], kwargs["seq_len"] = eval_pos_seq_len_sampler() 23 | # Scales the batch size dynamically with the power of 'dynamic_batch_size'. 24 | # A transformer with quadratic memory usage in the seq len would need a power of 2 to keep memory constant. 25 | if "dynamic_batch_size" in kwargs and kwargs["dynamic_batch_size"] > 0: 26 | kwargs["batch_size"] = kwargs["batch_size"] * math.floor( 27 | math.pow(kwargs["seq_len_maximum"], kwargs["dynamic_batch_size"]) 28 | / math.pow(kwargs["seq_len"], kwargs["dynamic_batch_size"]) 29 | ) 30 | batch = get_batch_method_(*args, **kwargs) 31 | x, y, target_y, style = ( 32 | batch if len(batch) == 4 else (batch[0], batch[1], batch[2], None) 33 | ) 34 | return (style, x, y), target_y, kwargs["single_eval_pos"] 35 | 36 | def __len__(self): 37 | return self.num_steps 38 | 39 | def __iter__(self): 40 | return iter( 41 | self.gbm(**self.get_batch_kwargs) for _ in range(self.num_steps) 42 | ) 43 | 44 | return DL 45 | -------------------------------------------------------------------------------- /lcpfn/train.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import time 3 | from contextlib import nullcontext 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from lcpfn import utils 9 | from lcpfn.transformer import TransformerModel 10 | from lcpfn.bar_distribution import ( 11 | BarDistribution, 12 | ) 13 | from lcpfn.utils import ( 14 | get_cosine_schedule_with_warmup, 15 | get_openai_lr, 16 | ) 17 | from lcpfn import positional_encodings 18 | from lcpfn.utils import init_dist 19 | from torch.cuda.amp import autocast, GradScaler 20 | 21 | 22 | class Losses: 23 | gaussian = nn.GaussianNLLLoss(full=True, reduction="none") 24 | mse = nn.MSELoss(reduction="none") 25 | ce = lambda num_classes: nn.CrossEntropyLoss( 26 | reduction="none", weight=torch.ones(num_classes) 27 | ) 28 | bce = nn.BCEWithLogitsLoss(reduction="none") 29 | get_BarDistribution = BarDistribution 30 | 31 | 32 | def train( 33 | priordataloader_class, 34 | criterion, 35 | encoder_generator, 36 | emsize=200, 37 | nhid=200, 38 | nlayers=6, 39 | nhead=2, 40 | dropout=0.2, 41 | epochs=10, 42 | steps_per_epoch=100, 43 | batch_size=200, 44 | bptt=10, 45 | lr=None, 46 | weight_decay=0.0, 47 | warmup_epochs=10, 48 | input_normalization=False, 49 | y_encoder_generator=None, 50 | pos_encoder_generator=None, 51 | decoder=None, 52 | extra_prior_kwargs_dict={}, 53 | scheduler=get_cosine_schedule_with_warmup, 54 | load_weights_from_this_state_dict=None, 55 | validation_period=10, 56 | single_eval_pos_gen=None, 57 | bptt_extra_samples=None, 58 | gpu_device="cuda:0", 59 | aggregate_k_gradients=1, 60 | verbose=True, 61 | style_encoder_generator=None, 62 | epoch_callback=None, 63 | initializer=None, 64 | initialize_with_model=None, 65 | train_mixed_precision=False, 66 | saving_period=10, 67 | checkpoint_file=None, 68 | load_optimizer_from_this_state_dict=None, 69 | output_path=None, 70 | **model_extra_args, 71 | ): 72 | device = gpu_device if torch.cuda.is_available() else "cpu:0" 73 | print(f"Using {device} device") 74 | using_dist, rank, device = init_dist(device) 75 | single_eval_pos_gen = ( 76 | single_eval_pos_gen 77 | if callable(single_eval_pos_gen) 78 | else lambda: single_eval_pos_gen 79 | ) 80 | 81 | def eval_pos_seq_len_sampler(): 82 | single_eval_pos = single_eval_pos_gen() 83 | if bptt_extra_samples: 84 | return single_eval_pos, single_eval_pos + bptt_extra_samples 85 | else: 86 | return single_eval_pos, bptt 87 | 88 | dl = priordataloader_class( 89 | num_steps=steps_per_epoch, 90 | batch_size=batch_size, 91 | eval_pos_seq_len_sampler=eval_pos_seq_len_sampler, 92 | seq_len_maximum=bptt + (bptt_extra_samples if bptt_extra_samples else 0), 93 | device=device, 94 | **extra_prior_kwargs_dict, 95 | ) 96 | 97 | encoder = encoder_generator(dl.num_features, emsize) 98 | style_def = next(iter(dl))[0][ 99 | 0 100 | ] # This is (style, x, y), target with x and y with batch size 101 | print(f"Style definition: {style_def}") 102 | style_encoder = ( 103 | style_encoder_generator(hyperparameter_definitions=style_def[0], em_size=emsize) 104 | if (style_def is not None) 105 | else None 106 | ) 107 | if isinstance(criterion, nn.GaussianNLLLoss): 108 | n_out = 2 109 | elif ( 110 | isinstance(criterion, BarDistribution) 111 | or "BarDistribution" in criterion.__class__.__name__ 112 | ): # TODO remove this fix (only for dev) 113 | n_out = criterion.num_bars 114 | elif isinstance(criterion, nn.CrossEntropyLoss): 115 | n_out = criterion.weight.shape[0] 116 | else: 117 | n_out = 1 118 | model = TransformerModel( 119 | encoder, 120 | n_out, 121 | emsize, 122 | nhead, 123 | nhid, 124 | nlayers, 125 | dropout, 126 | style_encoder=style_encoder, 127 | y_encoder=y_encoder_generator(1, emsize), 128 | input_normalization=input_normalization, 129 | pos_encoder=( 130 | pos_encoder_generator or positional_encodings.NoPositionalEncoding 131 | )(emsize, bptt * 2), 132 | decoder=decoder, 133 | init_method=initializer, 134 | **model_extra_args, 135 | ) 136 | model.criterion = criterion 137 | if load_weights_from_this_state_dict is not None: 138 | model.load_state_dict(load_weights_from_this_state_dict) 139 | if initialize_with_model is not None: 140 | model.init_from_small_model(initialize_with_model) 141 | 142 | print( 143 | f"Using a Transformer with {sum(p.numel() for p in model.parameters())/1000/1000:.{2}f} M parameters" 144 | ) 145 | 146 | try: 147 | for (k, v), (k2, v2) in zip( 148 | model.state_dict().items(), initialize_with_model.state_dict().items() 149 | ): 150 | print(k, ((v - v2) / v).abs().mean(), v.shape) 151 | except Exception: 152 | pass 153 | 154 | model.to(device) 155 | if using_dist: 156 | print("Distributed training") 157 | model = torch.nn.parallel.DistributedDataParallel( 158 | model, device_ids=[rank], output_device=rank, broadcast_buffers=False 159 | ) 160 | 161 | # learning rate 162 | if lr is None: 163 | lr = get_openai_lr(model) 164 | print(f"Using OpenAI max lr of {lr}.") 165 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) 166 | scheduler = scheduler( 167 | optimizer, warmup_epochs, epochs if epochs is not None else 100 168 | ) # when training for fixed time lr schedule takes 100 steps 169 | 170 | if load_optimizer_from_this_state_dict is not None: 171 | optimizer.load_state_dict(load_optimizer_from_this_state_dict) 172 | scaler = GradScaler() if train_mixed_precision else None 173 | 174 | # check that everything uses up-to-date APIs 175 | utils.check_compatibility(dl) 176 | 177 | def train_epoch(): 178 | model.train() # Turn on the train mode 179 | total_loss = 0.0 180 | total_positional_losses = 0.0 181 | total_positional_losses_recorded = 0 182 | before_get_batch = time.time() 183 | assert ( 184 | len(dl) % aggregate_k_gradients == 0 185 | ), "Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it." 186 | for batch, (data, targets, single_eval_pos) in enumerate(dl): 187 | if using_dist and not ( 188 | batch % aggregate_k_gradients == aggregate_k_gradients - 1 189 | ): 190 | cm = model.no_sync() 191 | else: 192 | cm = nullcontext() 193 | with cm: 194 | time_to_get_batch = time.time() - before_get_batch 195 | before_forward = time.time() 196 | 197 | with autocast(enabled=scaler is not None): 198 | # If style is set to None, it should not be transferred to device 199 | output = model( 200 | tuple(e.to(device) if torch.is_tensor(e) else e for e in data) 201 | if isinstance(data, tuple) 202 | else data.to(device), 203 | single_eval_pos=single_eval_pos, 204 | ) 205 | 206 | forward_time = time.time() - before_forward 207 | 208 | if single_eval_pos is not None: 209 | targets = targets[single_eval_pos:] 210 | if isinstance(criterion, nn.GaussianNLLLoss): 211 | assert ( 212 | output.shape[-1] == 2 213 | ), "need to write a little bit of code to handle multiple regression targets at once" 214 | 215 | mean_pred = output[..., 0] 216 | var_pred = output[..., 1].abs() 217 | losses = criterion( 218 | mean_pred.flatten(), 219 | targets.to(device).flatten(), 220 | var=var_pred.flatten(), 221 | ) 222 | elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)): 223 | losses = criterion( 224 | output.flatten(), targets.to(device).flatten() 225 | ) 226 | elif isinstance(criterion, nn.CrossEntropyLoss): 227 | losses = criterion( 228 | output.reshape(-1, n_out), 229 | targets.to(device).long().flatten(), 230 | ) 231 | else: 232 | losses = criterion(output, targets) 233 | losses = losses.view(*output.shape[0:2]) 234 | loss = losses.mean() / aggregate_k_gradients 235 | 236 | if scaler: 237 | loss = scaler.scale(loss) 238 | loss.backward() 239 | 240 | if batch % aggregate_k_gradients == aggregate_k_gradients - 1: 241 | if scaler: 242 | scaler.unscale_(optimizer) 243 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 244 | try: 245 | if scaler: 246 | scaler.step(optimizer) 247 | scaler.update() 248 | else: 249 | optimizer.step() 250 | except: 251 | print("Invalid optimization step encountered") 252 | optimizer.zero_grad() 253 | 254 | step_time = time.time() - before_forward 255 | 256 | if not torch.isnan(loss): 257 | total_loss += losses.mean().cpu().detach() 258 | total_positional_losses += ( 259 | losses.mean(1).cpu().detach() 260 | if single_eval_pos is None 261 | else nn.functional.one_hot(torch.tensor(single_eval_pos), bptt) 262 | * losses[: bptt - single_eval_pos].mean().cpu().detach() 263 | ) 264 | 265 | total_positional_losses_recorded += ( 266 | torch.ones(bptt) 267 | if single_eval_pos is None 268 | else nn.functional.one_hot(torch.tensor(single_eval_pos), bptt) 269 | ) 270 | 271 | before_get_batch = time.time() 272 | return ( 273 | total_loss / steps_per_epoch, 274 | (total_positional_losses / total_positional_losses_recorded).tolist(), 275 | time_to_get_batch, 276 | forward_time, 277 | step_time, 278 | ) 279 | 280 | total_loss = float("inf") 281 | total_positional_losses = float("inf") 282 | list_losses = [] 283 | try: 284 | for epoch in range(1, epochs + 1) if epochs is not None else itertools.count(1): 285 | epoch_start_time = time.time() 286 | ( 287 | total_loss, 288 | total_positional_losses, 289 | time_to_get_batch, 290 | forward_time, 291 | step_time, 292 | ) = train_epoch() 293 | list_losses.append(total_loss.item()) 294 | if hasattr(dl, "validate") and epoch % validation_period == 0: 295 | with torch.no_grad(): 296 | val_score = dl.validate(model) 297 | 298 | else: 299 | val_score = None 300 | 301 | if epoch % saving_period == 0 and checkpoint_file is not None: 302 | checkpoint = { 303 | "model_state_dict": model.state_dict(), 304 | "optimizer_state_dict": optimizer.state_dict(), 305 | "epoch": epoch, 306 | } 307 | torch.save(checkpoint, checkpoint_file) 308 | full_model_path = checkpoint_file.split(".")[0] + "_full_model.pt" 309 | torch.save(model, full_model_path) 310 | 311 | if verbose: 312 | print("-" * 89) 313 | print( 314 | f"| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | " 315 | f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}" 316 | f" data time {time_to_get_batch:5.2f} step time {step_time:5.2f}" 317 | f" forward time {forward_time:5.2f}" 318 | + (f"val score {val_score}" if val_score is not None else "") 319 | ) 320 | print("-" * 89) 321 | 322 | # stepping with wallclock time based scheduler 323 | if epoch_callback is not None and rank == 0: 324 | epoch_callback(model, epoch / epochs) 325 | scheduler.step() 326 | except KeyboardInterrupt: 327 | pass 328 | 329 | if rank == 0: # trivially true for non-parallel training 330 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 331 | model = model.module 332 | dl = None 333 | if output_path is not None: 334 | torch.save(model.to("cpu"), output_path) 335 | print("Checkpoint stored at ", output_path) 336 | return total_loss, total_positional_losses, model.to("cpu"), dl 337 | -------------------------------------------------------------------------------- /lcpfn/train_lcpfn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch import nn 4 | 5 | from lcpfn import bar_distribution, encoders, train 6 | from lcpfn import utils 7 | 8 | from lcpfn.priors import utils as putils 9 | 10 | 11 | def train_lcpfn( 12 | get_batch_func, 13 | seq_len: int = 100, 14 | emsize: int = 512, 15 | nlayers: int = 12, 16 | num_borders: int = 1000, 17 | lr: float = 0.0001, 18 | batch_size: int = 100, 19 | epochs: int = 1000, 20 | ): 21 | """ 22 | Train a LCPFN model using the specified hyperparameters. 23 | 24 | Args: 25 | get_batch_func (callable): A function that returns a batch of learning curves. 26 | seq_len (int, optional): The length of the input sequence. Defaults to 100. 27 | emsize (int, optional): The size of the embedding layer. Defaults to 512. 28 | nlayers (int, optional): The number of layers in the model. Defaults to 12. 29 | num_borders_choices (int, optional): The number of borders to use. Defaults to 1000. 30 | lr (float, optional): The learning rate for the optimizer. Defaults to 0.0001. 31 | batch_size (int, optional): The batch size for training. Defaults to 100. 32 | epochs (int, optional): The number of epochs to train for. Defaults to 1000. 33 | 34 | Returns: 35 | torch.module: The trained model. 36 | """ 37 | 38 | hps = {} 39 | 40 | # PFN training hyperparameters 41 | dataloader = putils.get_batch_to_dataloader(get_batch_func) # type: ignore 42 | 43 | num_features = 1 44 | 45 | ys = get_batch_func( 46 | 10_000, 47 | seq_len, 48 | num_features, 49 | hyperparameters=hps, 50 | single_eval_pos=seq_len, 51 | ) 52 | 53 | bucket_limits = bar_distribution.get_bucket_limits(num_borders, ys=ys[2]) 54 | 55 | # Discretization of the predictive distributions 56 | criterions = { 57 | num_features: { 58 | num_borders: bar_distribution.FullSupportBarDistribution(bucket_limits) 59 | } 60 | } 61 | 62 | config = dict( 63 | nlayers=nlayers, 64 | priordataloader_class=dataloader, 65 | criterion=criterions[num_features][num_borders], 66 | encoder_generator=lambda in_dim, out_dim: nn.Sequential( 67 | encoders.Normalize(0.0, 101.0), 68 | encoders.Normalize(0.5, math.sqrt(1 / 12)), 69 | encoders.Linear(in_dim, out_dim), 70 | ), 71 | emsize=emsize, 72 | nhead=(emsize // 128), 73 | warmup_epochs=(epochs // 4), 74 | y_encoder_generator=encoders.get_normalized_uniform_encoder(encoders.Linear), 75 | batch_size=batch_size, 76 | scheduler=utils.get_cosine_schedule_with_warmup, 77 | extra_prior_kwargs_dict={ 78 | # "num_workers": 10, 79 | "num_features": num_features, 80 | "hyperparameters": { 81 | **hps, 82 | }, 83 | }, 84 | epochs=epochs, 85 | lr=lr, 86 | bptt=seq_len, 87 | single_eval_pos_gen=utils.get_uniform_single_eval_pos_sampler( 88 | seq_len, min_len=1 89 | ), 90 | aggregate_k_gradients=1, 91 | nhid=(emsize * 2), 92 | steps_per_epoch=100, 93 | train_mixed_precision=False, 94 | ) 95 | 96 | return train.train(**config) 97 | -------------------------------------------------------------------------------- /lcpfn/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | import torch.nn.functional as F 8 | from torch.nn import Module, TransformerEncoder 9 | 10 | from lcpfn.layer import TransformerEncoderLayer, _get_activation_fn 11 | from lcpfn.utils import SeqBN, bool_mask_to_att_mask 12 | 13 | 14 | class GELU(nn.Module): 15 | def forward(self, input: Tensor) -> Tensor: 16 | return F.gelu(input) 17 | 18 | 19 | class TransformerModel(nn.Module): 20 | def __init__( 21 | self, 22 | encoder, 23 | n_out, 24 | ninp, 25 | nhead, 26 | nhid, 27 | nlayers, 28 | dropout=0.0, 29 | style_encoder=None, 30 | y_encoder=None, 31 | pos_encoder=None, 32 | decoder=None, 33 | input_normalization=False, 34 | init_method=None, 35 | pre_norm=False, 36 | activation="gelu", 37 | recompute_attn=False, 38 | num_global_att_tokens=0, 39 | full_attention=False, 40 | all_layers_same_init=True, 41 | ): 42 | super().__init__() 43 | self.model_type = "Transformer" 44 | encoder_layer_creator = lambda: TransformerEncoderLayer( 45 | ninp, 46 | nhead, 47 | nhid, 48 | dropout, 49 | activation=activation, 50 | pre_norm=pre_norm, 51 | recompute_attn=recompute_attn, 52 | ) 53 | self.transformer_encoder = ( 54 | TransformerEncoder(encoder_layer_creator(), nlayers) 55 | if all_layers_same_init 56 | else TransformerEncoderDiffInit(encoder_layer_creator, nlayers) 57 | ) 58 | self.ninp = ninp 59 | self.encoder = encoder 60 | self.y_encoder = y_encoder 61 | self.pos_encoder = pos_encoder 62 | self.decoder = ( 63 | decoder(ninp, nhid, n_out) 64 | if decoder is not None 65 | else nn.Sequential(nn.Linear(ninp, nhid), GELU(), nn.Linear(nhid, n_out)) 66 | ) 67 | self.input_ln = SeqBN(ninp) if input_normalization else None 68 | self.style_encoder = style_encoder 69 | self.init_method = init_method 70 | if num_global_att_tokens is not None: 71 | assert not full_attention 72 | self.global_att_embeddings = ( 73 | nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None 74 | ) 75 | self.full_attention = full_attention 76 | 77 | self.n_out = n_out 78 | self.nhid = nhid 79 | 80 | self.init_weights() 81 | 82 | @staticmethod 83 | def generate_square_subsequent_mask(sz): 84 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 85 | return bool_mask_to_att_mask(mask) 86 | 87 | @staticmethod 88 | def generate_D_q_matrix(sz, query_size): 89 | train_size = sz - query_size 90 | mask = torch.zeros(sz, sz) == 0 91 | mask[:, train_size:].zero_() 92 | mask |= torch.eye(sz) == 1 93 | return bool_mask_to_att_mask(mask) 94 | 95 | @staticmethod 96 | def generate_global_att_query_matrix( 97 | num_global_att_tokens, seq_len, num_query_tokens 98 | ): 99 | train_size = seq_len + num_global_att_tokens - num_query_tokens 100 | sz = seq_len + num_global_att_tokens 101 | mask = torch.zeros(num_query_tokens, sz) == 0 102 | mask[:, train_size:].zero_() 103 | mask[:, train_size:] |= torch.eye(num_query_tokens) == 1 104 | return bool_mask_to_att_mask(mask) 105 | 106 | @staticmethod 107 | def generate_global_att_trainset_matrix( 108 | num_global_att_tokens, seq_len, num_query_tokens 109 | ): 110 | train_size = seq_len + num_global_att_tokens - num_query_tokens 111 | trainset_size = seq_len - num_query_tokens 112 | mask = torch.zeros(trainset_size, num_global_att_tokens) == 0 113 | # mask[:,num_global_att_tokens:].zero_() 114 | # mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1 115 | return bool_mask_to_att_mask(mask) 116 | 117 | @staticmethod 118 | def generate_global_att_globaltokens_matrix( 119 | num_global_att_tokens, seq_len, num_query_tokens 120 | ): 121 | mask = ( 122 | torch.zeros( 123 | num_global_att_tokens, 124 | num_global_att_tokens + seq_len - num_query_tokens, 125 | ) 126 | == 0 127 | ) 128 | return bool_mask_to_att_mask(mask) 129 | 130 | def init_weights(self): 131 | initrange = 1.0 132 | # if isinstance(self.encoder,EmbeddingEncoder): 133 | # self.encoder.weight.data.uniform_(-initrange, initrange) 134 | # self.decoder.bias.data.zero_() 135 | # self.decoder.weight.data.uniform_(-initrange, initrange) 136 | if self.init_method is not None: 137 | self.apply(self.init_method) 138 | for layer in self.transformer_encoder.layers: 139 | nn.init.zeros_(layer.linear2.weight) 140 | nn.init.zeros_(layer.linear2.bias) 141 | attns = ( 142 | layer.self_attn 143 | if isinstance(layer.self_attn, nn.ModuleList) 144 | else [layer.self_attn] 145 | ) 146 | for attn in attns: 147 | nn.init.zeros_(attn.out_proj.weight) 148 | nn.init.zeros_(attn.out_proj.bias) 149 | 150 | def forward(self, src, src_mask=None, single_eval_pos=None): 151 | assert isinstance( 152 | src, tuple 153 | ), "inputs (src) have to be given as (x,y) or (style,x,y) tuple" 154 | 155 | if len(src) == 2: # (x,y) and no style 156 | src = (None,) + src 157 | 158 | style_src, style_src_size = (src[0], (0 if (src[0] is None) else 1)) 159 | if src_mask is not None: 160 | assert self.global_att_embeddings is None or isinstance(src_mask, tuple) 161 | if src_mask is None: 162 | x_src = src[1] 163 | if self.global_att_embeddings is None: 164 | full_len = len(x_src) + style_src_size 165 | if self.full_attention: 166 | src_mask = bool_mask_to_att_mask( 167 | torch.ones((full_len, full_len), dtype=torch.bool) 168 | ).to(x_src.device) 169 | else: 170 | src_mask = self.generate_D_q_matrix( 171 | len(x_src) + style_src_size, 172 | len(x_src) + style_src_size - single_eval_pos, 173 | ).to(x_src.device) 174 | else: 175 | src_mask_args = ( 176 | self.global_att_embeddings.num_embeddings, 177 | len(x_src) + style_src_size, 178 | len(x_src) + style_src_size - single_eval_pos, 179 | ) 180 | src_mask = ( 181 | self.generate_global_att_globaltokens_matrix(*src_mask_args).to( 182 | x_src.device 183 | ), 184 | self.generate_global_att_trainset_matrix(*src_mask_args).to( 185 | x_src.device 186 | ), 187 | self.generate_global_att_query_matrix(*src_mask_args).to( 188 | x_src.device 189 | ), 190 | ) 191 | 192 | style_src, x_src, y_src = src 193 | x_src = self.encoder(x_src) 194 | y_src = self.y_encoder( 195 | y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src 196 | ) 197 | style_src = ( 198 | self.style_encoder(style_src).unsqueeze(0) 199 | if self.style_encoder 200 | else torch.tensor([], device=x_src.device) 201 | ) 202 | global_src = ( 203 | torch.tensor([], device=x_src.device) 204 | if self.global_att_embeddings is None 205 | else self.global_att_embeddings.weight.unsqueeze(1).repeat( 206 | 1, x_src.shape[1], 1 207 | ) 208 | ) 209 | train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos] 210 | src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0) 211 | 212 | if self.input_ln is not None: 213 | src = self.input_ln(src) 214 | 215 | if self.pos_encoder is not None: 216 | src = self.pos_encoder(src) 217 | 218 | # If we have style input, drop its output 219 | output = self.transformer_encoder(src, src_mask)[style_src_size:] 220 | output = self.decoder(output) 221 | return output[ 222 | single_eval_pos 223 | + ( 224 | self.global_att_embeddings.num_embeddings 225 | if self.global_att_embeddings 226 | else 0 227 | ) : 228 | ] 229 | 230 | @torch.no_grad() 231 | def init_from_small_model(self, small_model): 232 | assert ( 233 | isinstance(self.decoder, nn.Linear) 234 | and isinstance(self.encoder, (nn.Linear, nn.Sequential)) 235 | and isinstance(self.y_encoder, (nn.Linear, nn.Sequential)) 236 | ) 237 | 238 | def set_encoder_weights(my_encoder, small_model_encoder): 239 | my_encoder_linear, small_encoder_linear = ( 240 | (my_encoder, small_model_encoder) 241 | if isinstance(my_encoder, nn.Linear) 242 | else (my_encoder[-1], small_model_encoder[-1]) 243 | ) 244 | small_in_dim = small_encoder_linear.out_features 245 | my_encoder_linear.weight.zero_() 246 | my_encoder_linear.bias.zero_() 247 | my_encoder_linear.weight[:small_in_dim] = small_encoder_linear.weight 248 | my_encoder_linear.bias[:small_in_dim] = small_encoder_linear.bias 249 | 250 | set_encoder_weights(self.encoder, small_model.encoder) 251 | set_encoder_weights(self.y_encoder, small_model.y_encoder) 252 | 253 | small_in_dim = small_model.decoder.in_features 254 | 255 | self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight 256 | self.decoder.bias = small_model.decoder.bias 257 | 258 | for my_layer, small_layer in zip( 259 | self.transformer_encoder.layers, small_model.transformer_encoder.layers 260 | ): 261 | small_hid_dim = small_layer.linear1.out_features 262 | my_in_dim = my_layer.linear1.in_features 263 | 264 | # packed along q,k,v order in first dim 265 | my_in_proj_w = my_layer.self_attn.in_proj_weight 266 | small_in_proj_w = small_layer.self_attn.in_proj_weight 267 | 268 | my_in_proj_w.view(3, my_in_dim, my_in_dim)[ 269 | :, :small_in_dim, :small_in_dim 270 | ] = small_in_proj_w.view(3, small_in_dim, small_in_dim) 271 | my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:, :small_in_dim] = ( 272 | small_layer.self_attn.in_proj_bias.view(3, small_in_dim) 273 | ) 274 | 275 | my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = ( 276 | small_layer.self_attn.out_proj.weight 277 | ) 278 | my_layer.self_attn.out_proj.bias[:small_in_dim] = ( 279 | small_layer.self_attn.out_proj.bias 280 | ) 281 | 282 | my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = ( 283 | small_layer.linear1.weight 284 | ) 285 | my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias 286 | 287 | my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = ( 288 | small_layer.linear2.weight 289 | ) 290 | my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias 291 | 292 | my_layer.norm1.weight[:small_in_dim] = ( 293 | math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight 294 | ) 295 | my_layer.norm2.weight[:small_in_dim] = ( 296 | math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight 297 | ) 298 | 299 | my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias 300 | my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias 301 | 302 | 303 | class TransformerEncoderDiffInit(Module): 304 | r"""TransformerEncoder is a stack of N encoder layers 305 | 306 | Args: 307 | encoder_layer_creator: a function generating objects of TransformerEncoderLayer class without args (required). 308 | num_layers: the number of sub-encoder-layers in the encoder (required). 309 | norm: the layer normalization component (optional). 310 | """ 311 | 312 | __constants__ = ["norm"] 313 | 314 | def __init__(self, encoder_layer_creator, num_layers, norm=None): 315 | super().__init__() 316 | self.layers = nn.ModuleList( 317 | [encoder_layer_creator() for _ in range(num_layers)] 318 | ) 319 | self.num_layers = num_layers 320 | self.norm = norm 321 | 322 | def forward( 323 | self, 324 | src: Tensor, 325 | mask: Optional[Tensor] = None, 326 | src_key_padding_mask: Optional[Tensor] = None, 327 | ) -> Tensor: 328 | r"""Pass the input through the encoder layers in turn. 329 | 330 | Args: 331 | src: the sequence to the encoder (required). 332 | mask: the mask for the src sequence (optional). 333 | src_key_padding_mask: the mask for the src keys per batch (optional). 334 | 335 | Shape: 336 | see the docs in Transformer class. 337 | """ 338 | output = src 339 | 340 | for mod in self.layers: 341 | output = mod( 342 | output, src_mask=mask, src_key_padding_mask=src_key_padding_mask 343 | ) 344 | 345 | if self.norm is not None: 346 | output = self.norm(output) 347 | 348 | return output 349 | -------------------------------------------------------------------------------- /lcpfn/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import random 5 | import datetime 6 | 7 | import torch 8 | from torch import nn 9 | from torch.optim.lr_scheduler import LambdaLR 10 | import numpy as np 11 | 12 | 13 | # copied from huggingface 14 | def get_cosine_schedule_with_warmup( 15 | optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1 16 | ): 17 | """Create a schedule with a learning rate that decreases following the 18 | values of the cosine function between 0 and `pi * cycles` after a warmup 19 | period during which it increases linearly between 0 and 1. 20 | """ 21 | 22 | def lr_lambda(current_step): 23 | if current_step < num_warmup_steps: 24 | return float(current_step) / float(max(1, num_warmup_steps)) 25 | progress = float(current_step - num_warmup_steps) / float( 26 | max(1, num_training_steps - num_warmup_steps) 27 | ) 28 | return max( 29 | 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) 30 | ) 31 | 32 | return LambdaLR(optimizer, lr_lambda, last_epoch) 33 | 34 | 35 | # copied from huggingface 36 | def get_linear_schedule_with_warmup( 37 | optimizer, num_warmup_steps, num_training_steps, last_epoch=-1 38 | ): 39 | """ 40 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 41 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 42 | 43 | Args: 44 | optimizer (:class:`~torch.optim.Optimizer`): 45 | The optimizer for which to schedule the learning rate. 46 | num_warmup_steps (:obj:`int`): 47 | The number of steps for the warmup phase. 48 | num_training_steps (:obj:`int`): 49 | The total number of training steps. 50 | last_epoch (:obj:`int`, `optional`, defaults to -1): 51 | The index of the last epoch when resuming training. 52 | 53 | Return: 54 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 55 | """ 56 | 57 | def lr_lambda(current_step: int): 58 | if current_step < num_warmup_steps: 59 | return float(current_step) / float(max(1, num_warmup_steps)) 60 | return max( 61 | 0.0, 62 | float(num_training_steps - current_step) 63 | / float(max(1, num_training_steps - num_warmup_steps)), 64 | ) 65 | 66 | return LambdaLR(optimizer, lr_lambda, last_epoch) 67 | 68 | 69 | def get_openai_lr(transformer_model): 70 | num_params = sum(p.numel() for p in transformer_model.parameters()) 71 | return 0.003239 - 0.0001395 * math.log(num_params) 72 | 73 | 74 | def get_weighted_single_eval_pos_sampler(max_len): 75 | """ 76 | This gives a sampler that can be used for `single_eval_pos` which yields good performance for all positions p, 77 | where p <= `max_len`. At most `max_len` - 1 examples are shown to the Transformer. 78 | :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`. 79 | """ 80 | return lambda: random.choices( 81 | range(max_len), [1 / (max_len - i) for i in range(max_len)] 82 | )[0] 83 | 84 | 85 | def get_uniform_single_eval_pos_sampler(max_len, min_len=0): 86 | """ 87 | Just sample any evaluation position with the same weight 88 | :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`. 89 | """ 90 | return lambda: random.choices(range(min_len, max_len))[0] 91 | 92 | 93 | class SeqBN(nn.Module): 94 | def __init__(self, d_model): 95 | super().__init__() 96 | self.bn = nn.BatchNorm1d(d_model) 97 | self.d_model = d_model 98 | 99 | def forward(self, x): 100 | assert self.d_model == x.shape[-1] 101 | flat_x = x.view(-1, self.d_model) 102 | flat_x = self.bn(flat_x) 103 | return flat_x.view(*x.shape) 104 | 105 | 106 | def set_locals_in_self(locals): 107 | """ 108 | Call this function like `set_locals_in_self(locals())` to set all local variables as object variables. 109 | Especially useful right at the beginning of `__init__`. 110 | :param locals: `locals()` 111 | """ 112 | self = locals["self"] 113 | for var_name, val in locals.items(): 114 | if var_name != "self": 115 | setattr(self, var_name, val) 116 | 117 | 118 | default_device = "cuda:0" if torch.cuda.is_available() else "cpu:0" 119 | 120 | 121 | # Copied from StackOverflow, but we do an eval on the values additionally 122 | class StoreDictKeyPair(argparse.Action): 123 | def __init__(self, option_strings, dest, nargs=None, **kwargs): 124 | self._nargs = nargs 125 | super(StoreDictKeyPair, self).__init__( 126 | option_strings, dest, nargs=nargs, **kwargs 127 | ) 128 | 129 | def __call__(self, parser, namespace, values, option_string=None): 130 | my_dict = {} 131 | for kv in values: 132 | k, v = kv.split("=") 133 | try: 134 | my_dict[k] = eval(v) 135 | except NameError: 136 | my_dict[k] = v 137 | setattr(namespace, self.dest, my_dict) 138 | print("dict values: {}".format(my_dict)) 139 | 140 | 141 | def get_nan_value(v, set_value_to_nan=0.0): 142 | if random.random() < set_value_to_nan: 143 | return v 144 | else: 145 | return random.choice([-999, 0, 1, 999]) 146 | 147 | 148 | def to_ranking(data): 149 | x = data >= data.unsqueeze(-3) 150 | x = x.sum(0) 151 | return x 152 | 153 | 154 | # TODO: Is there a better way to do this? 155 | # 1. Cmparing to unique elements: When all values are different we still get quadratic blowup 156 | # 2. Argsort(Argsort()) returns ranking, but with duplicate values there is an ordering which is problematic 157 | # 3. Argsort(Argsort(Unique))->Scatter seems a bit complicated, doesn't have quadratic blowup, but how fast? 158 | def to_ranking_low_mem(data): 159 | x = torch.zeros_like(data) 160 | for col in range(data.shape[-1]): 161 | x_ = data[:, :, col] >= data[:, :, col].unsqueeze(-2) 162 | x_ = x_.sum(0) 163 | x[:, :, col] = x_ 164 | return x 165 | 166 | 167 | def nan_handling_missing_for_unknown_reason_value(set_value_to_nan=0.0): 168 | return get_nan_value(float("nan"), set_value_to_nan) 169 | 170 | 171 | def nan_handling_missing_for_no_reason_value(set_value_to_nan=0.0): 172 | return get_nan_value(float("-inf"), set_value_to_nan) 173 | 174 | 175 | def nan_handling_missing_for_a_reason_value(set_value_to_nan=0.0): 176 | return get_nan_value(float("inf"), set_value_to_nan) 177 | 178 | 179 | def torch_nanmean(x, axis=0): 180 | num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum( 181 | axis=axis 182 | ) 183 | value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis) 184 | return value / num 185 | 186 | 187 | def torch_nanstd(x, axis=0): 188 | num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum( 189 | axis=axis 190 | ) 191 | value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis) 192 | mean = value / num 193 | mean_broadcast = torch.repeat_interleave( 194 | mean.unsqueeze(axis), x.shape[axis], dim=axis 195 | ) 196 | return torch.sqrt( 197 | torch.nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1) 198 | ) 199 | 200 | 201 | def normalize_data(data, normalize_positions=-1): 202 | if normalize_positions > 0: 203 | mean = torch_nanmean(data[:normalize_positions], axis=0) 204 | std = torch_nanstd(data[:normalize_positions], axis=0) + 0.000001 205 | else: 206 | mean = torch_nanmean(data, axis=0) 207 | std = torch_nanstd(data, axis=0) + 0.000001 208 | data = (data - mean) / std 209 | data = torch.clip(data, min=-100, max=100) 210 | 211 | return data 212 | 213 | 214 | def remove_outliers(X, n_sigma=4): 215 | # Expects T, B, H 216 | assert len(X.shape) == 3, "X must be T,B,H" 217 | # for b in range(X.shape[1]): 218 | # for col in range(X.shape[2]): 219 | data = X 220 | data_mean, data_std = torch_nanmean(data, axis=0), torch_nanstd(data, axis=0) 221 | cut_off = data_std * n_sigma 222 | lower, upper = data_mean - cut_off, data_mean + cut_off 223 | 224 | data_clean = X[:].clone() 225 | data_clean[torch.logical_or(data > upper, data < lower)] = np.nan 226 | data_mean, data_std = ( 227 | torch_nanmean(data_clean, axis=0), 228 | torch_nanstd(data_clean, axis=0), 229 | ) 230 | cut_off = data_std * n_sigma 231 | lower, upper = data_mean - cut_off, data_mean + cut_off 232 | 233 | X = torch.maximum(-torch.log(1 + torch.abs(X)) + lower, X) 234 | X = torch.minimum(torch.log(1 + torch.abs(X)) + upper, X) 235 | # print(ds[1][data < lower, col], ds[1][data > upper, col], ds[1][~np.isnan(data), col].shape, data_mean, data_std) 236 | return X 237 | 238 | 239 | def bool_mask_to_att_mask(mask): 240 | return ( 241 | mask.float() 242 | .masked_fill(mask == 0, float("-inf")) 243 | .masked_fill(mask == 1, float(0.0)) 244 | ) 245 | 246 | 247 | def print_on_master_only(is_master): 248 | import builtins as __builtin__ 249 | 250 | builtin_print = __builtin__.print 251 | 252 | def print(*args, **kwargs): 253 | force = kwargs.pop("force", False) 254 | if is_master or force: 255 | builtin_print(*args, **kwargs) 256 | 257 | __builtin__.print = print 258 | 259 | 260 | def init_dist(device): 261 | print("init dist") 262 | if "LOCAL_RANK" in os.environ: 263 | # launched with torch.distributed.launch 264 | rank = int(os.environ["LOCAL_RANK"]) 265 | print("torch.distributed.launch and my rank is", rank) 266 | torch.cuda.set_device(rank) 267 | os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) 268 | torch.distributed.init_process_group( 269 | backend="nccl", 270 | init_method="env://", 271 | timeout=datetime.timedelta(seconds=20), 272 | world_size=torch.cuda.device_count(), 273 | rank=rank, 274 | ) 275 | torch.distributed.barrier() 276 | print_on_master_only(rank == 0) 277 | print( 278 | f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, " 279 | "only I can print, but when using print(..., force=True) it will print on all ranks." 280 | ) 281 | return True, rank, f"cuda:{rank}" 282 | elif "SLURM_PROCID" in os.environ and torch.cuda.device_count() > 1: 283 | # this is for multi gpu when starting with submitit 284 | assert device != "cpu:0" 285 | rank = int(os.environ["SLURM_PROCID"]) 286 | os.environ["MASTER_ADDR"] = "localhost" 287 | os.environ["MASTER_PORT"] = "12355" 288 | torch.cuda.set_device(rank) 289 | os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) 290 | print("distributed submitit launch and my rank is", rank) 291 | torch.distributed.init_process_group( 292 | backend="nccl", 293 | init_method="env://", 294 | timeout=datetime.timedelta(seconds=20), 295 | world_size=torch.cuda.device_count(), 296 | rank=rank, 297 | ) 298 | torch.distributed.barrier() 299 | print_on_master_only(rank == 0) 300 | print( 301 | f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, " 302 | "only I can print, but when using print(..., force=True) it will print on all ranks." 303 | ) 304 | 305 | return True, rank, f"cuda:{rank}" 306 | else: 307 | print("Not using distributed") 308 | # will not change any of the behavior of print, but allows putting the force=True in the print calls 309 | print_on_master_only(True) 310 | return False, 0, device 311 | 312 | 313 | def check_compatibility(dl): 314 | if hasattr(dl, "num_outputs"): 315 | print( 316 | "`num_outputs` for the DataLoader is deprecated. It is assumed to be 1 from now on." 317 | ) 318 | assert dl.num_outputs != 1, ( 319 | "We assume num_outputs to be 1. Instead of the num_ouputs change your loss." 320 | "We specify the number of classes in the CE loss." 321 | ) 322 | 323 | 324 | def pfn_normalize( 325 | lb=torch.tensor(float("-inf")), 326 | ub=torch.tensor(float("inf")), 327 | soft_lb=0.0, 328 | soft_ub=1.0, 329 | minimize=False, 330 | ): 331 | """ 332 | LC-PFN curve prior assumes curves to be normalized within the range [0,1] and to be maximized. 333 | This function allows to normalize and denormalize data to fit this assumption. 334 | 335 | Parameters: 336 | lb (torch.Tensor): Lower bound of the data. 337 | ub (torch.Tensor): Upper bound of the data. 338 | soft_lb (float): Soft lower bound for normalization. Default is 0.0. 339 | soft_ub (float): Soft upper bound for normalization. Default is 1.0. 340 | minimize (bool): If True, the original curve is a minization. Default is False. 341 | 342 | Returns: Two functions for normalizing and denormalizing the data. 343 | """ 344 | assert lb <= soft_lb and soft_lb < soft_ub and soft_ub <= ub 345 | # step 1: linearly transform [soft_lb,soft_ub] [-1,1] (where the sigmoid behaves approx linearly) 346 | # 2.0/(soft_ub - soft_lb)*(x - soft_lb) - 1.0 347 | # step 2: apply a vertically scaled/shifted the sigmoid such that [lb,ub] --> [0,1] 348 | 349 | def cinv(x): 350 | return 1 - x if minimize else x 351 | 352 | def lin_soft(x): 353 | return 2 / (soft_ub - soft_lb) * (x - soft_lb) - 1 354 | 355 | def lin_soft_inv(y): 356 | return (y + 1) / 2 * (soft_ub - soft_lb) + soft_lb 357 | 358 | try: 359 | if torch.exp(-lin_soft(lb)) > 1e300: 360 | raise RuntimeError 361 | # otherwise overflow causes issues, treat these cases as if the lower bound was -infinite 362 | # print(f"WARNING: {lb} --> NINF to avoid overflows ({np.exp(-lin_soft(lb))})") 363 | except RuntimeError: 364 | lb = torch.tensor(float("-inf")) 365 | if torch.isinf(lb) and torch.isinf(ub): 366 | return lambda x: cinv( 367 | 1 / (1 + torch.exp(-lin_soft(x))) 368 | ), lambda y: lin_soft_inv(torch.log(cinv(y) / (1 - cinv(y)))) 369 | elif torch.isinf(lb): 370 | a = 1 + torch.exp(-lin_soft(ub)) 371 | return lambda x: cinv( 372 | a / (1 + torch.exp(-lin_soft(x))) 373 | ), lambda y: lin_soft_inv(torch.log((cinv(y) / a) / (1 - (cinv(y) / a)))) 374 | elif torch.isinf(ub): 375 | a = 1 / (1 - 1 / (1 + torch.exp(-lin_soft(lb)))) 376 | b = 1 - a 377 | return lambda x: cinv( 378 | a / (1 + torch.exp(-lin_soft(x))) + b 379 | ), lambda y: lin_soft_inv( 380 | torch.log(((cinv(y) - b) / a) / (1 - ((cinv(y) - b) / a))) 381 | ) 382 | else: 383 | a = ( 384 | 1 385 | + torch.exp(-lin_soft(ub)) 386 | + torch.exp(-lin_soft(lb)) 387 | + torch.exp(-lin_soft(ub) - lin_soft(lb)) 388 | ) / (torch.exp(-lin_soft(lb)) - torch.exp(-lin_soft(ub))) 389 | b = -a / (1 + torch.exp(-lin_soft(lb))) 390 | return lambda x: cinv( 391 | a / (1 + torch.exp(-lin_soft(x))) + b 392 | ), lambda y: lin_soft_inv( 393 | torch.log(((cinv(y) - b) / a) / (1 - ((cinv(y) - b) / a))) 394 | ) 395 | 396 | 397 | def get_default_normalizer(): 398 | default_normalizer_kwargs = { 399 | "lb": torch.tensor(0.0), 400 | "ub": torch.tensor(1.0), 401 | "soft_lb": 0.0, 402 | "soft_ub": 1.0, 403 | "minimize": False, 404 | } 405 | return pfn_normalize(**default_normalizer_kwargs) 406 | 407 | 408 | def identity_normalizer(): 409 | return lambda x: x, lambda x: x 410 | -------------------------------------------------------------------------------- /lcpfn/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.3" 2 | -------------------------------------------------------------------------------- /notebooks/inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "718131ff-a3f1-4e41-918b-57ed59ce5af3", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "ae91ba35-9239-4dce-b767-0264c8e4809b", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "%cd -q .." 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 15, 27 | "id": "2d627399-e989-432e-b12f-f5aa1b62dae6", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import lcpfn\n", 32 | "import torch\n", 33 | "import numpy as np\n", 34 | "from matplotlib import pyplot as plt" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "id": "d6dc8df6-5267-4888-bc5a-e5e942331b79", 40 | "metadata": {}, 41 | "source": [ 42 | "## Load trained LC-PFN model" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 4, 48 | "id": "e4188fdb-4a1e-4508-9530-8d2448e238cb", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "model = lcpfn.LCPFN()" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "03d80175-5f10-404a-85ce-59505f043691", 58 | "metadata": {}, 59 | "source": [ 60 | "## Generate a learning curve from the prior" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 14, 66 | "id": "fafc60aa-0681-494f-9be3-09b40e1495ad", 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "data": { 71 | "text/plain": [ 72 | "(0.0, 1.0)" 73 | ] 74 | }, 75 | "execution_count": 14, 76 | "metadata": {}, 77 | "output_type": "execute_result" 78 | }, 79 | { 80 | "data": { 81 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAVSklEQVR4nO3df5BV5Z3n8feX/sFPFQZQWVBhI8pSJK7ZDmsqG9dMslFkC9aYilqxHKOllcqYmDW1iRaJumwqqWQSHabKNeBMVo3ZGEXKpWbNWhtGy5TBxNaMGnVw8AfyQ6URJYpAC3z3j76wl6abvsBtrvfp96vqFOc557nnfI9P8/H0ueccIjORJDW/YY0uQJJUHwa6JBXCQJekQhjoklQIA12SCmGgS1IhBgz0iPhpRGyMiD/2sz4i4m8iYnVEPB0RH61/mZKkgdRyhn47cM4B1s8BplemK4FbD78sSdLBGjDQM/MRYPMBuswH7swejwFjI2JSvQqUJNWmtQ7bmAysrWqvqyx7rXfHiLiSnrN4Ro8e/W9mzJhRh91L0tDxxBNPbMrMiX2tq0eg1ywzlwBLADo6OrKzs/NI7l6Sml5ErOlvXT3uclkPnFDVnlJZJkk6guoR6MuBSyp3u5wBbMnM/S63SJIG14CXXCLiF8BZwISIWAfcALQBZOZPgAeAc4HVwHvAlwarWElS/wYM9My8aID1Cfxl3SqSJB0SnxSVpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTiij/5LzSwz2b17N7t372bXrl19zvfVrmWq3vZA6/uaP5T11X/2Xt/XugN97kDz1f/tBvpM7+lQ1x3sZ/bUeDhTLdvY0+e6667j/PPPr/vPqIGug5aZdHd3s2PHDrZv386OHTv2tqv/fP/99+nu7t477Wnv3LmT999/f79pz/KdO3cecNq1a9d+872X1TJVh28ty3Rwhg0bRkQQEfu1+5uvnvpbPtC6A00H+hxw2NusZRsAI0eOHJT/5gZ6oTKT7du3s2XLFt5++23+9Kc/7TO9++67+0xbt27lvffe22/atm0b27ZtY/v27ftMg6WlpYXW1lba2tpobW2lpaVln/nW1tZ+51taWmhpaaG9vX3v/EDTsGHD+ly2Z+rdb89f3j3L+5qv/mx1//7m9/xF76/Pnu0daL6vbVUHTn+f76tPf4FbSyDvmVdjGOhNIjPZsmULGzZs4LXXXuONN95g48aNbNy4ka6uLjZt2sSmTZvYvHkzmzdv5q233mLHjh0DbnfYsGGMHj2a0aNHM2rUqL3zI0eO5Pjjj2fkyJH7TCNGjGDEiBEMHz6c4cOHM2LECNrb2/e2hw8fTltbG8OHD6e9vX3vfFtbG21tbbS3t+8N7OppTzgPG+bXOtKhMtA/QDZv3swLL7zASy+9tHd69dVXWbt2LWvXrmXbtm37faa1tZWJEycyYcIExo8fz4wZMxg/fjzjxo1j3LhxHHPMMYwdO5ajjz5673TUUUdx1FFHMWbMGEaMGOEZlVQIA70B3nnnHZ566imefvppnnnmGZ555hlWrVrFpk2b9uk3adIkTjzxRE477TTmzp3LlClTmDRpEpMmTeL444/nuOOOY+zYsQayJMBAH3SZyapVq3jkkUd49NFH6ezs5Pnnn9/7jffYsWOZNWsW5513HqeeeiqnnHIKJ598MlOnTh20L04klclAHwQbN27kwQcf5Fe/+hW//vWv6erqAuDYY49l9uzZfOELX6Cjo4PTTjuNyZMne4YtqS4M9DpZv349S5cu5d577+XRRx8FegL87LPP5qyzzuLMM8/k5JNPNrwlDRoD/TB0d3dz//33s2TJElasWAHARz7yERYuXMi5557L6aef7l0bko4YA/0QbNq0iUWLFrF48WK6uro48cQTufHGG7nggguYMWNGo8uTNEQZ6Adhw4YN/OhHP2Lx4sW89957zJs3jy9/+ct89rOfpaWlpdHlSRriDPQabNu2jZtuuonvfe977Nixgy9+8Yt861vfYubMmY0uTZL2MtAHsHz5cq6++mpeeeUVPve5z/HDH/6QD33oQ40uS5L24zd2/XjnnXe47LLLmD9/PmPGjGHFihXcd999hrmkDyzP0Pvw29/+losvvpg1a9awYMECrr/+etrb2xtdliQdkIHeyx133MEVV1zBlClTeOSRR/jEJz7R6JIkqSZecqnITK6//nouvfRSzjzzTJ588knDXFJT8Qwd2LlzJ1/60pe46667uPzyy7n11ltpa2trdFmSdFCG/Bn67t27ueKKK7jrrrv47ne/y2233WaYS2pKQ/oMPTO55ppruP3227nhhhtYsGBBo0uSpEM2pM/QFy5cyKJFi/j617/ODTfc0OhyJOmwDNlAX7ZsGTfeeCOXXnopP/7xj30LoqSmNyQD/eWXX+ayyy7jYx/7GIsXL/aNiJKKMOSSrLu7mwsuuACAX/7ylz4wJKkYQ+5L0W9+85s8/vjjLFu2jGnTpjW6HEmqmyF1hv6b3/yGRYsW8bWvfY3zzjuv0eVIUl0NmUDftWsXX/3qVznxxBP5/ve/3+hyJKnuhswll9tuu42nnnqKe++9l1GjRjW6HEmqu5rO0CPinIhYFRGrI+LaPtafGBEPRcQfIuLpiDi3/qUeus2bN7NgwQI+9alPcf755ze6HEkaFAMGekS0ALcAc4CZwEUR0fuf6vk2cE9mng5cCPz3ehd6OL7zne+wZcsWFi1a5P3mkopVyxn6bGB1Zr6Umd3A3cD8Xn0SOLoyfwywoX4lHp5Vq1bxk5/8hK985St8+MMfbnQ5kjRoagn0ycDaqva6yrJqNwIXR8Q64AHgq31tKCKujIjOiOjs6uo6hHIP3s0330xbWxvf/va3j8j+JKlR6nWXy0XA7Zk5BTgX+FlE7LftzFySmR2Z2TFx4sQ67bp/XV1d3HHHHVxyySUce+yxg74/SWqkWgJ9PXBCVXtKZVm1y4F7ADJzJTACmFCPAg/Hrbfeyvbt27nmmmsaXYokDbpaAv1xYHpETIuIdnq+9Fzeq8+rwKcBIuJf0RPoR+aaSj+2b9/OLbfcwty5c5kxY0YjS5GkI2LAQM/MncBVwIPA8/TczfJsRCyMiHmVbt8AroiIp4BfAJdmZg5W0bX4+c9/zsaNGz07lzRkRKNyt6OjIzs7Owdl25nJrFmzaG9v58knn/RWRUnFiIgnMrOjr3VFPin68MMP89xzz3HnnXca5pKGjCLf5bLn8X6fCpU0lBQX6Lt27WLZsmXMnTvXd7ZIGlKKC/RHH32UN954g89//vONLkWSjqjiAn3p0qWMHDmSc8/9QL0fTJIGXVGBvnv3bu677z7mzJnDmDFjGl2OJB1RRQX6ypUr2bBhg5dbJA1JRQX60qVLGT58OHPnzm10KZJ0xBUT6Lt372bp0qWcffbZHH300QN/QJIKU0ygd3Z2sm7dOi+3SBqyign0hx9+GIA5c+Y0thBJapBiAn3lypVMnz6dCRMa/tZeSWqIIgI9M1m5ciUf//jHG12KJDVMEYH+8ssv88Ybbxjokoa0IgJ95cqVAAa6pCGtmEAfM2YMs2bNanQpktQwxQT67NmzaWlpaXQpktQwTR/oW7du5amnnvJyi6Qhr+kDvbOzk127dhnokoa8pg/0PV+InnHGGQ2uRJIaq4hAP+WUUxg/fnyjS5GkhmrqQPeBIkn6/5o60F988UW6uroMdEmiyQP997//PeD1c0mCJg/01atXA3Dqqac2uBJJarymDvQ1a9Zw/PHHM2LEiEaXIkkN1/SBftJJJzW6DEn6QDDQJakQTRvou3fv5tVXXzXQJamiaQP99ddfp7u720CXpIqmDfQ1a9YAGOiSVGGgS1IhDHRJKkRTB/q4ceM4+uijG12KJH0g1BToEXFORKyKiNURcW0/fb4QEc9FxLMR8T/rW+b+vGVRkvbVOlCHiGgBbgH+A7AOeDwilmfmc1V9pgPXAZ/IzLci4tjBKniPNWvWcPLJJw/2biSpadRyhj4bWJ2ZL2VmN3A3ML9XnyuAWzLzLYDM3FjfMveVmZ6hS1IvtQT6ZGBtVXtdZVm1U4BTIuLRiHgsIs7pa0MRcWVEdEZEZ1dX16FVDLz11lu8++67BrokVanXl6KtwHTgLOAi4LaIGNu7U2YuycyOzOyYOHHiIe/slVdeAbzDRZKq1RLo64ETqtpTKsuqrQOWZ+b7mfky8AI9AT8o9tyyOHXq1MHahSQ1nVoC/XFgekRMi4h24EJgea8+99Nzdk5ETKDnEsxL9StzX96DLkn7GzDQM3MncBXwIPA8cE9mPhsRCyNiXqXbg8CbEfEc8BDwXzLzzcEqes2aNYwaNcp/GFqSqgx42yJAZj4APNBr2fVV8wlcU5kG3Z47XCLiSOxOkppCUz4p6i2LkrQ/A12SCtF0gb5161befPNNA12Semm6QPeWRUnqW9MGumfokrQvA12SCtF0gT527Fg++clPMmnSpEaXIkkfKNFzC/mR19HRkZ2dnQ3ZtyQ1q4h4IjM7+lrXdGfokqS+GeiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKUVOgR8Q5EbEqIlZHxLUH6Hd+RGREdNSvRElSLQYM9IhoAW4B5gAzgYsiYmYf/Y4CrgZ+V+8iJUkDq+UMfTawOjNfysxu4G5gfh/9/hvwA2B7HeuTJNWolkCfDKytaq+rLNsrIj4KnJCZ//tAG4qIKyOiMyI6u7q6DrpYSVL/DvtL0YgYBtwEfGOgvpm5JDM7MrNj4sSJh7trSVKVWgJ9PXBCVXtKZdkeRwGzgIcj4hXgDGC5X4xK0pFVS6A/DkyPiGkR0Q5cCCzfszIzt2TmhMycmplTgceAeZnZOSgVS5L6NGCgZ+ZO4CrgQeB54J7MfDYiFkbEvMEuUJJUm9ZaOmXmA8ADvZZd30/fsw6/LEnSwfJJUUkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFqCnQI+KciFgVEasj4to+1l8TEc9FxNMRsSIiTqp/qZKkAxkw0COiBbgFmAPMBC6KiJm9uv0B6MjMjwBLgR/Wu1BJ0oHVcoY+G1idmS9lZjdwNzC/ukNmPpSZ71WajwFT6lumJGkgtQT6ZGBtVXtdZVl/Lgd+1deKiLgyIjojorOrq6v2KiVJA6rrl6IRcTHQAfxVX+szc0lmdmRmx8SJE+u5a0ka8lpr6LMeOKGqPaWybB8R8RlgAfDvM3NHfcqTJNWqljP0x4HpETEtItqBC4Hl1R0i4nRgMTAvMzfWv0xJ0kAGDPTM3AlcBTwIPA/ck5nPRsTCiJhX6fZXwBjg3oj4x4hY3s/mJEmDpJZLLmTmA8ADvZZdXzX/mTrXJUk6SD4pKkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFaKmQI+IcyJiVUSsjohr+1g/PCJ+WVn/u4iYWvdKJUkHNGCgR0QLcAswB5gJXBQRM3t1uxx4KzNPBm4GflDvQiVJB1bLGfpsYHVmvpSZ3cDdwPxefeYDd1TmlwKfjoioX5mSpIG01tBnMrC2qr0O+Lf99cnMnRGxBRgPbKruFBFXAldWmu9GxKpDKRqY0HvbQ8RQPO6heMwwNI97KB4zHPxxn9TfiloCvW4ycwmw5HC3ExGdmdlRh5KaylA87qF4zDA0j3soHjPU97hrueSyHjihqj2lsqzPPhHRChwDvFmPAiVJtakl0B8HpkfEtIhoBy4Elvfqsxz4i8r854F/yMysX5mSpIEMeMmlck38KuBBoAX4aWY+GxELgc7MXA78HfCziFgNbKYn9AfTYV+2aVJD8biH4jHD0DzuoXjMUMfjDk+kJakMPikqSYUw0CWpEE0X6AO9hqAEEXFCRDwUEc9FxLMRcXVl+Z9FxP+NiH+u/Dmu0bXWW0S0RMQfIuLvK+1plddJrK68XqK90TXWW0SMjYilEfFPEfF8RHx8iIz1f678fP8xIn4RESNKG++I+GlEbIyIP1Yt63Nso8ffVI796Yj46MHur6kCvcbXEJRgJ/CNzJwJnAH8ZeU4rwVWZOZ0YEWlXZqrgeer2j8Abq68VuItel4zUZpFwP/JzBnAafQcf9FjHRGTga8BHZk5i54bLi6kvPG+HTin17L+xnYOML0yXQncerA7a6pAp7bXEDS9zHwtM5+szL9Dz1/wyez7ioU7gP/UkAIHSURMAeYCf1tpB/Dn9LxOAso85mOAM+m5U4zM7M7Mtyl8rCtagZGVZ1dGAa9R2Hhn5iP03PlXrb+xnQ/cmT0eA8ZGxKSD2V+zBXpfryGY3KBajojKmytPB34HHJeZr1VWvQ4c16i6BslfA98Edlfa44G3M3NnpV3ieE8DuoD/UbnU9LcRMZrCxzoz1wM/Al6lJ8i3AE9Q/nhD/2N72PnWbIE+pETEGOA+4OuZ+afqdZUHt4q55zQi/iOwMTOfaHQtR1gr8FHg1sw8HdhKr8srpY01QOW68Xx6/of2L4DR7H9ponj1HttmC/RaXkNQhIhooyfMf56ZyyqL39jzK1jlz42Nqm8QfAKYFxGv0HMp7c/pubY8tvIrOZQ53uuAdZn5u0p7KT0BX/JYA3wGeDkzuzLzfWAZPT8DpY839D+2h51vzRbotbyGoOlVrh3/HfB8Zt5Utar6FQt/AfyvI13bYMnM6zJzSmZOpWdc/yEzvwg8RM/rJKCwYwbIzNeBtRFxamXRp4HnKHisK14FzoiIUZWf9z3HXfR4V/Q3tsuBSyp3u5wBbKm6NFObzGyqCTgXeAF4EVjQ6HoG6Rj/HT2/hj0N/GNlOpeea8orgH8Gfg38WaNrHaTjPwv4+8r8vwR+D6wG7gWGN7q+QTjefw10Vsb7fmDcUBhr4L8C/wT8EfgZMLy08QZ+Qc93BO/T89vY5f2NLRD03MX3IvAMPXcAHdT+fPRfkgrRbJdcJEn9MNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIf4fg/wIRPMvMt8AAAAASUVORK5CYII=\n", 82 | "text/plain": [ 83 | "
" 84 | ] 85 | }, 86 | "metadata": { 87 | "needs_background": "light" 88 | }, 89 | "output_type": "display_data" 90 | } 91 | ], 92 | "source": [ 93 | "prior = lcpfn.sample_from_prior(np.random)\n", 94 | "curve, _ = prior()\n", 95 | "plt.plot(curve, \"black\")\n", 96 | "plt.ylim(0, 1)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "id": "18d1ceef-30a8-47c7-9754-ed7bf194cd37", 102 | "metadata": {}, 103 | "source": [ 104 | "## Extrapolate the learning curve with a cutoff of 10% " 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 40, 110 | "id": "3610482a-5e21-4109-a3b3-16d937e26cd9", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "# construct\n", 115 | "\n", 116 | "x = torch.arange(1, 101).unsqueeze(1)\n", 117 | "y = torch.from_numpy(curve).float().unsqueeze(1)\n", 118 | "cutoff = 10" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 37, 124 | "id": "31d375db-ec12-41d0-86b7-0ceffebdd908", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "predictions = model.predict_quantiles(\n", 129 | " x_train=x[:cutoff], y_train=y[:cutoff], x_test=x[cutoff:], qs=[0.05, 0.5, 0.95]\n", 130 | ")" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 60, 136 | "id": "e96900b1-c844-4540-9d9c-183429c525f1", 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "data": { 141 | "text/plain": [ 142 | "" 143 | ] 144 | }, 145 | "execution_count": 60, 146 | "metadata": {}, 147 | "output_type": "execute_result" 148 | }, 149 | { 150 | "data": { 151 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAp00lEQVR4nO3de3hU1b3/8fc3M7lxF4hICQIWqsgl3EQsBZV65YdQFHugqCBVTitWqbU+nuKDqOhPqz3SHlFLvVA8WhRRRH9ADypobbkkHCEgIIKCJCoGAkgg91m/P2YmTMIkmcAkQyaf1/PsZ/Zl7b3Xzg4f9qy994o55xARkcYvIdYVEBGR6FCgi4jECQW6iEicUKCLiMQJBbqISJxQoIuIxIlaA93MXjCzb81sSzXLzcz+ZGY7zSzbzAZEv5oiIlKbSK7Q5wNX1bD8aqBHYJgKPHPq1RIRkbqqNdCdcx8C+TUUGQMscH5rgTZm1jFaFRQRkch4o7CNTsDekOmcwLyvqxY0s6n4r+Jp3rz5wPPOOy8Ku29YR44coWXLlrGuhog0URs2bNjvnEsLtywagR4x59w8YB7AoEGDXFZWVkPuPipmzZrFrFmzYl0NETkNBHtOqdqDSuh0uHGvF8xObp9mtqe6ZdEI9Fygc8h0emCeiDRCzp3cUNO64ZaFzvP5/NM+n38InR86HZwXbr3gNquWj6SOwfWCy2uqb+jPyez4vEjGg+t17w49epz8OapONAJ9KXC7mS0ELgQOO+dOaG4RkeNhU/WzpvHqpoNDebl/CJ0Xuqxq+dBlofUI3QdUvoKMZh9+VQOu6r7MKg9V51cdD12/uuWhyxISql+vuvrUVLauDh+GsrJT20Z1ag10M/sbcAnQ3sxygPuBRADn3LPAMmAksBM4BtxcP1UVib7QMKsuFKsOZWXHh+B6ZWXHg7XqELrt2tR0ZRduurYhWC4hofK0x3P8a391ASqNT62B7pybUMtyB0yLWo1EahAajuFCMziUlvqHYPBWHQ+GcLiQDQZnMNxCv1qHBl8wJIOfVceTkk6cL1KfGvSmqDRtwbCt7mq2tBRKSo5/lpUd/wwGNFT/9b9q6Ho84T+TkytPi8QLBbrUiXOVmxxCQ7qsDIqLjw8lJceH0tLw7aahzQvBkE1IqDwkJvqvdhXAIjVToAs+3/Er4OBQUgLHjvmDuajIPxQXH7+ZE3o1HBrKwUAOttF6PNC8ucJYpCEo0OOcc/6ADr1iPnrUPxQW+kO7pOTE9cz8gRwMZa/X31Th8TT8MUjjF+5pnKo3jGua59yJN61Dn+AJLg/3RE/Vp4pqWhbuccmanj6qbqiuHPgvjq67Dnr2jP7PWYEeB5zzh3LwKvrYMf+jUQUF/qFqU4fX62/G8Hr9V8+tW8em3k1daFNVbUPovYfqllXXFFbdZ9V7GdWVDTde9SZ0aFgHy4aWiXfBG99wvKkwOD/cU0QDB9ZPPRTojUx5uT+wCwvh0CH/cPjw8auTYLNHUpI/tNu2rfzcbVMSvNEa/HZSXHy8OSm0bT84HrwJGzo/tAkquDz0CZrQZqrQp2mqfoYbb8i/zx78xhXaFBYc93r9vyPhxj0e/+9RSsrxdULXDW1eCwZZcL3gsnDjwenQz9rmhU6Hu9kdWiZ0ftXx0G2Ehm/oeNXyNY3X1eHD0LGeertSoJ/mCgv9V9kHD8L+/XDkyPEgSEz0N4O0aXP6h7bP5w/UwsLjbfKhQ7CtPvQzNIiLio6PV73pGm68uDh6V4ZmjqSk499s/N9uHImJLhB8rmJITHSkpDg8Hh8ej3+ef9yH1+tISPCRkODD6/UvD057PD4SEsorpkPHPZ5yoHIZs/LAeDlmwSE4vywwXlYxH8oBh8/nwzmHc/7PcNPgKi0LrhdcdnwdV2k63DLnHKWl1W833PTxwVdluur+qGa96utStUxN05HMC9ahprLBMv/+7w/yox+NjM4vZTUU6KeZkhL47jvIy4NvvvEHGfivuFNSoF27hrm5WFZ2vMkmdDh61HHkSDmHD5dRUFBOQYGPY8dcYICiIgsMCRQXJ1BS4qGkxENZ2ck1vicklAaGEjyeEhISSjALfhZXGqAIj6eY1NQiUlOLAf885wpxrggoBIpx7vg8/2cxzhXi8/nn+XyF+HzHKj6dK6v4T0RqZmYVA/g/ExISKqaD49UtD13fLKHKdLhtcMJ6tW/31OdVrUNt6wMkJ6fU+89fgX4aKCqC/HzIzYUDB/zzEhOhRQto1erktllWVkpBwWEOHjzE118f49tvi8jLKyU/v5xDhxzffZfA0aMejh5NpLAwmeLiZEpKUiktTaWsrDk+X7Nqtmz4f228+K/6jgAFIZ9HA5/HAuNHA+PHQsYLq8z3z0tIKMXrLcXrLcHjKSMhoZTERA8ejxePJ/jpJSHBg9d7fNzj8eL1+sf9054TxoPbMEsIzEsiISGFhIR2lconJCQEBk/Ip6fiH3Bw2swq1gvO93g8gFXsJ1jGrPI2g2FVeZsJJ6zrX348OIL7inw8Iey2gvsKX+bEwAqta7hllQNMYkmBHiOlpf7w3rvX35RiBs2aQfv24a/AS0qK2b//a/LyvmL//q/Jz/+W3NxjfPWVj/37PRw8mExBQQsKC1tRUtIWn+9MoAPQvYZaHMEsn4SEw3i9B/B6C2jWrJCkpGJSUopJTi4hNbWU1NRyUlN9NGvmo0ULfz2bN4fUVA/JyckkJiaTlOQfvN4kkpJSSUxsjdebSGJiEl5vEl5vYo2DP8gUCCKnQoHegHw+/03M3Fz46it/W3izZpCWBuXlpeTkfE5Ozk5ycj4nN/dzcnN3k5NTzr59LTh6tCPw/cDQFzgbSK60/YSEYlJSDtO27VFatSqiVasczjjjS9q2Ndq185CWlkRaWjJnnpnKWWc1p0WL5iQkqG93kXihQG8AhYXw9dewZw8UFjry8z8nN3cju3Zl89ln2Xz++VZycr7G5+sDDAAyMLsBOB/nUiu2k5JSTFpaId/7nuPss0vp0sVLp04eOnSADh2gVatkzM6M1WGKSIwp0OtJebn/yZTt24/y4Ydr+OSTD9m5cy3bt2fx3XcHga6YDad5838DLsS5swH/jcPWrX2ce67RvbtxzjnQrRt06QJt2iRT9apcRCRIgR5l333nY9WqjSxZspz165fz6afrKC8vw+x7dOgwidatZ+FcX44caVHx9lifPnD++dCrl//tsbS0hAZ5kkVE4osCPQoKC338z/+s5eWXX+P99xdx4MBXgJezz55Cz54Pk5/fn6++asU33/jfyhw82P+mWP/+cM45ep1eRKJDgX6SnINNm77mz39+gTfffI59+3bj9Z5F9+6/pWPHseze3YUvv0zA64V+/eDaa2HIEPjBD6L3ElB1/UwEl1XtQyLcdOi2IplXXT2qOpVvGJG8QRmtbzC1baem5Se7LNLy1W2jrvPrss/63G+kx1iXn0Us9nE6U6CfhI8//piHHvq/vPXWm/h85/O9791H587XkJubxvbtRtu2cOmlMHy4/2q8efOatxfaJW3oH18IsjA9GwbnV30du+qr0qGvN4d7Zbnq/OB2Q4fQeaH7rvoPoCH+QdQU9jX9BxTJf061lalueW37rcv8+t5+TfOrq0uk9amubCTnI9J1I61/XbZX27aq/rurbl5t5UPr4K2n5DVXXa3q2aBBg1xWVlZM9n2y1q5dy0033cRnn31FYuKf8HqvpbCwDeBv+/7Rj2DoUH97eLir8NDX00P/Uo6Z/y1Q/7Pd/iE5+XgHWtWFdjxcUYicrur6H19d5p/Kv18z2+CcGxRuma7QI7B3717uueceFi5cSGpqKrNn38eLL07irLM8jBjhD/H27SuvE+y3JPhXdsAf2G3a+NvRmzf3h3Zy8vE/VSYip4/G2CyjQK9BaWkpv//973nkkUfw+XzMnDmT0tJSZsz4HXfcAf/8p/+lIJ/veA+IwSvvli2hUyd/b4fBq+76+polIgIK9Gp99tlnTJw4kczMTK677jr+8Ic/0KVLF2bNmgX4/6cuLfV3opWQ4A/uLl38fa80b+5vLhERaUgK9Cqcc7zwwgvccccdJCcns2jRIsaNG3dCuZQUGDTIH97BP7EmIhJLp3kv2g2rvLycO+64g1tuuYUhQ4aQnZ0dNszB33zSsaP/ilxhLiKnAwV6wNGjRxk7dixPPfUUv/nNb1i5ciXp6emxrpaISMTU5ALk5+dzxRVX8PHHHzN37lxuu+22WFdJRKTOmnygFxQUMHLkSDZv3sySJUu45pprYl0lEZGT0qQDvbi4mLFjx5KVlcXixYsV5iLSqDXZQC8vL2fixIm8++67zJ8/nzFjxsS6SiIip6TJ3hSdPXs2ixcv5sknn2TSpEmxro6IyClrkoG+evVqHnzwQW688UamT58e6+qIiERFkwv0vLw8fvazn9G9e3eefvrpWFdHRCRqmlQbus/n46abbiI/P5/ly5fTokWLWFdJRCRqmlSgv/jii6xYsYKnn36ajIyMWFdHRCSqmkyTy+HDh/nd737H0KFD+cUvfhHr6oiIRF2TuUJ/6KGHyMvLY9myZdjp3qmxiMhJiOgK3cyuMrNPzWynmd0bZvnZZrbKzD42s2wzGxn9qp68Tz/9lD/+8Y9MmTKFgQMHxro6IiL1otZANzMPMBe4GjgfmGBm51cpdh/wmnOuPzAeOK0eH7nrrrto1qwZjzzySKyrIiJSbyK5Qh8M7HTOfe6cKwEWAlVfq3RAq8B4a+Cr6FXx1Lz33nssW7aM+++/nzPPPDPW1RERqTeRBHonYG/IdE5gXqhZwA1mlgMsA34VbkNmNtXMsswsKy8v7ySqW3dPPPEEHTp0YNq0aQ2yPxGRWInWUy4TgPnOuXRgJPCSmZ2wbefcPOfcIOfcoLS0tCjtunpbt25lxYoV3H777SQnJ9f7/kREYimSQM8FOodMpwfmhfo58BqAc24NkAK0j0YFT8WcOXNISUnRY4oi0iREEuiZQA8z62ZmSfhvei6tUuZL4McAZtYTf6A3TJtKNfLy8liwYAE33XQT7dvH/P8WEZF6V2ugO+fKgNuBvwPb8D/N8omZPWhmowPFfgPcamabgL8Bk51zrr4qHYlnn32W4uJidb4lIk1GRC8WOeeW4b/ZGTpvZsj4VmBodKt28oqLi5k7dy5XX301PXv2jHV1REQaRFy++v/mm2+yb98+XZ2LSJMSl4G+aNEiOnbsyGWXXRbrqoiINJi4C/SjR4+yfPlyrr32WhIS4u7wRESqFXeJt2LFCgoLC7nuuutiXRURkQYVd4H++uuv0759e4YNGxbrqoiINKi4CvSioiLeeecdxo4di9fbZHoGFhEB4izQV65cSUFBgZpbRKRJiqtAf/3112nTpg2XXnpprKsiItLg4ibQS0pKWLp0KaNHjyYpKSnW1RERaXBxE+irV6/m0KFDjBs3LtZVERGJibgJ9A8++ACv18uPf/zjWFdFRCQm4ibQ16xZQ0ZGBs2aNYt1VUREYiIuAr28vJz169dz0UUXxboqIiIxExeBvmXLFo4ePcqQIUNiXRURkZiJi0Bfs2YNgK7QRaRJi4tAX7t2LWlpaXTr1i3WVRERiZm4CPQ1a9Zw0UUXYWaxroqISMw0+kA/cOAAO3bsUPu5iDR5jT7Q161bB6j9XESk0Qf62rVrSUhIYNCgQbGuiohITDX6QF+zZg19+/alRYsWsa6KiEhMNepALy8vZ926dWo/FxGhkQf6tm3bOHLkiNrPRURo5IG+fv16AF2hi4jQyAN9165deDwezjnnnFhXRUQk5hp1oO/Zs4f09HT9/VARERp5oH/55Zd06dIl1tUQETktNOpA37NnjwJdRCSg0QZ6WVkZubm5nH322bGuiojIaaHRBnpubi7l5eW6QhcRCWi0gb5nzx4ABbqISIACXUQkTjTaQP/yyy8B1IYuIhLQaAN9z549pKWlkZqaGuuqiIicFiIKdDO7ysw+NbOdZnZvNWV+amZbzewTM3slutU8kR5ZFBGprNZXLM3MA8wFLgdygEwzW+qc2xpSpgfwH8BQ59xBMzuzvioctGfPHnr37l3fuxERaTQiuUIfDOx0zn3unCsBFgJjqpS5FZjrnDsI4Jz7NrrVrMw5p7dERUSqiCTQOwF7Q6ZzAvNC/QD4gZn908zWmtlV4TZkZlPNLMvMsvLy8k6uxkBeXh6FhYUKdBGRENG6KeoFegCXABOAv5hZm6qFnHPznHODnHOD0tLSTnpnesJFROREkQR6LtA5ZDo9MC9UDrDUOVfqnPsC2IE/4OuFnkEXETlRJIGeCfQws25mlgSMB5ZWKbME/9U5ZtYefxPM59GrZmUKdBGRE9Ua6M65MuB24O/ANuA159wnZvagmY0OFPs7cMDMtgKrgN865w7UV6X37NlDixYtOOOMM+prFyIijU5EfxnCObcMWFZl3syQcQfcFRjqXfAZdDNriN2JiDQKjfJN0T179uiGqIhIFY0y0PUMuojIiRpdoBcUFJCfn69AFxGpotEFup5wEREJT4EuIhInGm2g66aoiEhljS7QzzjjDIYNG0bHjh1jXRURkdNKowv08ePH8+GHH+LxeGJdFRGR00qjC3QREQlPgS4iEicU6CIicUKBLiISJxToIiJxQoEuIhInFOgiInFCgS4iEicU6CIicUKBLiISJxToIiJxQoEuIhInFOgiInFCgS4iEicU6CIicUKBLiISJxToIiJxQoEuIhInFOgiInFCgS4iEicU6CIicUKBLiISJxToIiJxQoEuIhInFOgiInFCgS4iEicU6CIicSKiQDezq8zsUzPbaWb31lDuOjNzZjYoelUUEZFI1BroZuYB5gJXA+cDE8zs/DDlWgJ3AuuiXUkREaldJFfog4GdzrnPnXMlwEJgTJhyDwGPAUVRrJ+IiEQokkDvBOwNmc4JzKtgZgOAzs65/1fThsxsqpllmVlWXl5enSsrIiLVO+WbomaWAPwn8Jvayjrn5jnnBjnnBqWlpZ3qrkVEJEQkgZ4LdA6ZTg/MC2oJ9AZWm9luYAiwVDdGRUQaViSBngn0MLNuZpYEjAeWBhc65w4759o757o657oCa4HRzrmseqmxiIiEVWugO+fKgNuBvwPbgNecc5+Y2YNmNrq+KygiIpHxRlLIObcMWFZl3sxqyl5y6tUSEZG60puiIiJxQoEuIhInFOgiInFCgS4iEicU6CIicUKBLiISJxToIiJxQoEuIhInFOgiInFCgS4iEicU6CIicUKBLiISJxToIiJxIqLeFkUkOkpLS8nJyaGoSH96V2qWkpJCeno6iYmJEa+jQBdpQDk5ObRs2ZKuXbtiZrGujpymnHMcOHCAnJwcunXrFvF6anIRaUBFRUW0a9dOYS41MjPatWtX529yCnSRBqYwl0iczO+JAl1EJE4o0EWakEOHDvH000/X+36WLFnC1q1b630/UpkCXaQJqWugO+fw+Xx13o8CPTb0lItIjEyfPp2NGzdGdZv9+vVjzpw51S6/99572bVrF/369ePSSy8lOzubgwcPUlpayuzZsxkzZgy7d+/myiuv5MILL2TDhg0sW7aMBQsW8N///d+kpaXRuXNnBg4cyN13382uXbuYNm0aeXl5NGvWjL/85S/k5+ezdOlSPvjgA2bPns3ixYv5/ve/H9XjlPAU6CJNyKOPPsqWLVvYuHEjZWVlHDt2jFatWrF//36GDBnC6NGjAfjss8/461//ypAhQ8jMzGTx4sVs2rSJ0tJSBgwYwMCBAwGYOnUqzz77LD169GDdunXcdtttvP/++4wePZpRo0Yxbty4WB5uk6NAF4mRmq6kG4Jzjt/97nd8+OGHJCQkkJuby759+wDo0qULQ4YMAeCf//wnY8aMISUlhZSUFK655hoACgoK+Ne//sX1119fsc3i4uKGPxCpoEAXaaJefvll8vLy2LBhA4mJiXTt2rXiuefmzZvXur7P56NNmzZRbzaSk6eboiJNSMuWLTly5AgAhw8f5swzzyQxMZFVq1axZ8+esOsMHTqUt99+m6KiIgoKCnjnnXcAaNWqFd26dWPRokWA/4p/06ZNJ+xHGo4CXaQJadeuHUOHDqV3795s3LiRrKws+vTpw4IFCzjvvPPCrnPBBRcwevRo+vbty9VXX02fPn1o3bo14L/Kf/7558nIyKBXr1689dZbAIwfP57HH3+c/v37s2vXrgY7vqZOTS4iTcwrr7xSa5ktW7ZUmr777ruZNWsWx44dY/jw4RU3Rbt168aKFStOWH/o0KF6bDEGFOgiUqupU6eydetWioqKmDRpEgMGDIh1lSQMBbqI1CqSq3qJPbWhi4jECQW6iEicUKCLiMQJBbqISJxQoIs0MR6Ph379+lUMjz76aI3lH3nkkQap16xZs3jiiSdqLFO1F8eZM2fy7rvvnvK+58+fz+23337S6wd/pr179+b666/n2LFjleYHh927d7N69WrMjLfffrti/VGjRrF69epTPYzIAt3MrjKzT81sp5ndG2b5XWa21cyyzew9M+tyyjUTkXqRmprKxo0bK4Z77z3hn3Ql1QX6yXateyqqBvqDDz7IZZdd1qB1CCf4M92yZQtJSUk8++yzleYHh65duwKQnp7Oww8/HPV61PrYopl5gLnA5UAOkGlmS51zoW8NfAwMcs4dM7NfAr8H/i3qtRWJI9OnQ7S7QenXD06mz6/Dhw8zePBgli5dyrnnnsuECRMYMWIEu3btorCwkH79+tGrVy8efvjhE7rWffTRR8nMzKSwsJBx48bxwAMPANC1a1d++tOfsnz5clJTU3nllVfo3r07u3fvZsqUKezfv5+0tDRefPFFzj777Er1+ctf/sK8efMoKSmhe/fuvPTSS2zcuPGEbnkfeuihil4d33vvPe6++27Kysq44IILeOaZZ0hOTqZr165MmjSJt99+m9LSUhYtWhT2rdi9e/dyySWXkJubyw033MD999/PzJkzadu2LdOnTwdgxowZnHnmmdx5553V/iyHDRtGdnZ2jT/vjIwMSktLWblyJZdffnkdz1b1IrlCHwzsdM597pwrARYCY0ILOOdWOeeOBSbXAulRq6GIRFUwoIPDq6++SuvWrXnqqaeYPHkyCxcu5ODBg9x66608+uijFVeZL7/8MuDvWve2227jk08+oUuXLjz88MNkZWWRnZ3NBx98UCnMWrduzebNm7n99tsrQvFXv/oVkyZNIjs7m4kTJ3LHHXecUMdrr72WzMxMNm3aRM+ePXn++ef54Q9/yOjRo3n88cfZuHFjpT7Wi4qKmDx5Mq+++iqbN2+mrKyMZ555pmJ5+/bt+d///V9++ctfVtuss379ehYvXkx2djaLFi0iKyuLKVOmsGDBAsDfGdnChQu54YYbqv3ZlpWVsXz5cvr06XPCz3rs2LGVys6YMYPZs2fXdKrqLJIXizoBe0Omc4ALayj/c2B5uAVmNhWYCpzwP7JIUxOr3nODAV3V5ZdfzqJFi5g2bVpFJ1vhhHatC/Daa68xb948ysrK+Prrr9m6dSt9+/YFYMKECRWfv/71rwFYs2YNb7zxBgA33ngj99xzzwn72LJlC/fddx+HDh2ioKCAK6+8ssZj+vTTT+nWrRs/+MEPAJg0aRJz586t+E/k2muvBWDgwIEV+w53/O3ataso/9FHHzF9+nTatWvHxx9/zL59++jfv39FmVDB4Ab/FfrPf/5zoPqfNcDw4cMB+Oijj2o8trqI6puiZnYDMAi4ONxy59w8YB7AoEGDXDT3LSKnxufzsW3bNpo1a8bBgwdJTw//RTu0a90vvviCJ554gszMTM444wwmT55c0QUvVP7L9XX5K/aTJ09myZIlZGRkMH/+/FO+YZicnAz4b1KWlZWFLVO1fsHpW265hfnz5/PNN98wZcqUsOvWFNw1CV6le73RieJImlxygc4h0+mBeZWY2WXADGC0c0693Is0Mk8++SQ9e/bklVde4eabb6a0tBSAxMTEivGqvvvuO5o3b07r1q3Zt28fy5dX/nL+6quvVnxedNFFAPzwhz9k4cKFgL+3xmHDhp2w3SNHjtCxY0dKS0srmnqg+m55zz33XHbv3s3OnTsBeOmll7j44rDXldVauXIl+fn5FBYWsmTJEoYOHQrA2LFjWbFiBZmZmbV+U6irK664goMHD9ba5h6pSP5byAR6mFk3/EE+HvhZaAEz6w/8GbjKOfdtVGomIvUitHkA4KqrruLmm2/mueeeY/369bRs2ZLhw4cze/ZsHnjgAaZOnUrfvn0ZMGDACU9mZGRk0L9/f8477zw6d+5cEYJBBw8epG/fviQnJ/O3v/0NgP/6r//i5ptv5vHHH6+4KVrVQw89xIUXXkhaWhoXXnhhRYiPHz+eW2+9lT/96U+8/vrrFeVTUlJ48cUXuf766ytuiv7iF7+o089l8ODBXHfddeTk5HDDDTcwaNAgAJKSkrj00ktp06YNHo+nTtuMxIwZMxgzZkztBSPhnKt1AEYCO4BdwIzAvAfxX40DvAvsAzYGhqW1bXPgwIGuMbr//vtjXQVpxLZu3RrrKjSYLl26uLy8vFhX45SVl5e7jIwMt2PHjgbfd7jfFyDLVZOrETXcOOeWAcuqzJsZMh77B0FFRKJs69atjBo1irFjx9KjR49YV6dW6j5XROrF7t27Y12FU3b++efz+eefx7oaEdOr/yIicUKBLiISJxToIiJxQoEuIhIndFNUJIbWroVDh6K3vTZtIOSt/LC++eYbpk+fTmZmJm3atKFDhw7MmTOHpKQkRo0axZYtWyLe329/+1uWLVvGyJEjefzxxyvmHzx4kClTprBr1y5SUlJ44YUX6N27NwArVqzgzjvvpLy8nFtuuaWit8eJEyeyefNmRo0aVdHD4+zZs+nduzc/+clP6vRzaKoU6CIxdOgQpKVFb3t5eTUvd84xduxYJk2aVPG25qZNm9i3bx+dO3eueeUw5s2bR35+/gkv3DzyyCP069ePN998k+3btzNt2jTee+89ysvLmTZtGitXriQ9PZ0LLriA0aNHU1ZWRmpqKtnZ2Vx++eUcPnyYY8eOsW7dOu67774616upUpOLSBOyatUqEhMTK71FmZGREfb1+yDnHL/97W/p3bs3ffr0qXidf/To0RQUFDBw4MCKeUFbt25lxIgRAJx33nns3r2bffv2sX79erp3784555xDUlIS48eP56233iIxMZHCwkJ8Ph+lpaV4PB5mzpxZ0RWvREZX6CJNyJYtWxg4cGCd1nnjjTfYuHEjmzZtYv/+/VxwwQUMHz6cpUuX0qJFi7CdUmVkZPDGG28wbNgw1q9fz549e8jJySE3N7fSN4H09HTWrVtHz549SUtLY8CAAdx4443s3LkTn8/HgAEDTvWQmxQFuojU6KOPPmLChAl4PB46dOjAxRdfTGZmJqNHj652nXvvvZc777yTfv360adPH/r3719rPyhzQvoTvuaaa/jzn//Mww8/zKZNm7j88su59dZbo3VIcUuBLtKE9OrVq1KnVvWlVatWFZ1uOefo1q0b55xzDoWFhezde/zPK+Tk5NCpU6dK67711lsMHDiQgoICdu3axWuvvcaVV17JxIkTadasWb3XvTFTG7pIEzJixAiKi4uZN29exbzs7Gz+8Y9/VLvOsGHDePXVVykvLycvL48PP/yQwYMH17ifQ4cOUVJSAsBzzz3H8OHDadWqFRdccAGfffYZX3zxBSUlJSxcuLDSlX5paSlz5szhnnvuobCwsKJP8vLy8ortSfV0hS4SQ23a1P5kSl23VxMz480332T69Ok89thjpKSk0LVr10rNHVWNHTuWNWvWkJGRgZnx+9//nrPOOqvG/Wzbto1JkyZhZvTq1Yvnn38eAK/Xy1NPPcWVV15JeXk5U6ZMoVevXhXrzZ07l0mTJtGsWTP69u3LsWPH6NOnDyNHjqRNbQcnmL83xoY3aNAgl5WVFZN9n4pZs2Yxa9asWFdDGqlt27bRs2fPWFdDGolwvy9mtsE5NyhceTW5iIjECQW6iEicUKCLiMQJBbqISJxQoIuIxAkFuohInFCgi0hYu3fv5pVXXomo7IQJE+jbty9PPvkk27dvp1+/fvTv359du3bVcy0llAJdRMKKNNC/+eYbMjMzyc7O5te//jVLlixh3LhxfPzxx3z/+99vgJpKkAJdpIlZsGABffv2JSMjgxtvvJHJkydX6t+lRYsWgL+DrX/84x/069ePJ598kqKiIm6++eaKzrZWrVoFwBVXXEFubi79+vXjgQceYM6cOTzzzDNceumlMTm+pkyv/ovE0OrVq1m9enXUtnfJJZdwySWXVLv8k08+Yfbs2fzrX/+iffv25Ofnc9ddd4Ut++ijj/LEE0/wzjvvAPCHP/wBM2Pz5s1s376dK664gh07drB06VJGjRpV0Y2uc44WLVpw9913R+24JDIKdJEYqi2Ao+3999/n+uuvp3379gC0bds24nU/+ugjfvWrXwH+P1rRpUsXduzYQatWreqlrlJ3anIRaeK8Xi8+nw8An8+nXg0bMQW6SBMyYsQIFi1axIEDBwDIz8+na9eubNiwAYClS5dSWloKQMuWLTly5EjFusOGDePll18GYMeOHXz55Zece+65DXwEUhM1uYg0Ib169WLGjBlcfPHFeDwe+vfvz2OPPcaYMWPIyMjgqquuonnz5gD07dsXj8dDRkYGkydP5rbbbuOXv/wlffr0wev1Mn/+fJKTk2N8RBJK3efWkbrPlVOh7nOlLtR9rohIE6VAFxGJEwp0kQYWq2ZOaVxO5vdEgS7SgFJSUjhw4IBCXWrknOPAgQOkpKTUaT095SLSgNLT08nJySEvmn8ZWuJSSkoK6enpdVpHgS7SgBITE+nWrVusqyFxKqImFzO7ysw+NbOdZnZvmOXJZvZqYPk6M+sa9ZqKiEiNag10M/MAc4GrgfOBCWZ2fpViPwcOOue6A08Cj0W7oiIiUrNIrtAHAzudc58750qAhcCYKmXGAH8NjL8O/NjMLHrVFBGR2kTSht4J2BsynQNcWF0Z51yZmR0G2gH7QwuZ2VRgamCywMw+PZlKA+2rbrshPfDAA7HYbUyPOUZ0zE2DjrluulS3oEFvijrn5gHzTnU7ZpZV3auv8UrH3DTomJuG+jrmSJpccoHOIdPpgXlhy5iZF2gNHIhGBUVEJDKRBHom0MPMuplZEjAeWFqlzFJgUmB8HPC+05sTIiINqtYml0Cb+O3A3wEP8IJz7hMzexDIcs4tBZ4HXjKznUA+/tCvT6fcbNMI6ZibBh1z01Avxxyz7nNFRCS61JeLiEicUKCLiMSJRhfotXVDEA/MrLOZrTKzrWb2iZndGZjf1sxWmtlngc8zYl3XaDIzj5l9bGbvBKa7BbqS2BnoWiIp1nWMJjNrY2avm9l2M9tmZhc1gXP868Dv9BYz+5uZpcTbeTazF8zsWzPbEjIv7Hk1vz8Fjj3bzAacyr4bVaBH2A1BPCgDfuOcOx8YAkwLHOe9wHvOuR7Ae4HpeHInsC1k+jHgyUCXEgfxdzERT/4IrHDOnQdk4D/2uD3HZtYJuAMY5Jzrjf8hi/HE33meD1xVZV515/VqoEdgmAo8cyo7blSBTmTdEDR6zrmvnXP/Gxg/gv8feicqd7HwV+AnMalgPTCzdOD/AM8Fpg0Ygb8rCYi/420NDMf/hBjOuRLn3CHi+BwHeIHUwPsqzYCvibPz7Jz7EP/TfqGqO69jgAXOby3Qxsw6nuy+G1ugh+uGoFOM6tIgAj1X9gfWAR2cc18HFn0DdIhVverBHOAewBeYbgcccs6VBabj7Vx3A/KAFwPNTM+ZWXPi+Bw753KBJ4Av8Qf5YWAD8X2eg6o7r1HNtMYW6E2KmbUAFgPTnXPfhS4LvLgVF8+cmtko4Fvn3IZY16UBeYEBwDPOuf7AUao0r8TTOQYItBuPwf+f2feA5pzYNBH36vO8NrZAj6QbgrhgZon4w/xl59wbgdn7gl/HAp/fxqp+UTYUGG1mu/E3o43A377cJvDVHOLvXOcAOc65dYHp1/EHfLyeY4DLgC+cc3nOuVLgDfznPp7Pc1B15zWqmdbYAj2SbggavUD78fPANufcf4YsCu1iYRLwVkPXrT445/7DOZfunOuK/5y+75ybCKzC35UExNHxAjjnvgH2mtm5gVk/BrYSp+c44EtgiJk1C/yOB485bs9ziOrO61LgpsDTLkOAwyFNM3XnnGtUAzAS2AHsAmbEuj71dIw/wv+VLBvYGBhG4m9Xfg/4DHgXaBvrutbDsV8CvBMYPwdYD+wEFgHJsa5flI+1H5AVOM9LgDPi/RwDDwDbgS3AS0ByvJ1n4G/47xGU4v8m9vPqzitg+J/c2wVsxv8E0EnvW6/+i4jEicbW5CIiItVQoIuIxAkFuohInFCgi4jECQW6iEicUKCLiMQJBbqISJz4/3OT6k3SfLRMAAAAAElFTkSuQmCC\n", 152 | "text/plain": [ 153 | "
" 154 | ] 155 | }, 156 | "metadata": { 157 | "needs_background": "light" 158 | }, 159 | "output_type": "display_data" 160 | } 161 | ], 162 | "source": [ 163 | "# plot data\n", 164 | "plt.plot(curve, \"black\", label=\"target\")\n", 165 | "\n", 166 | "# plot extrapolation\n", 167 | "plt.plot(x[cutoff:], predictions[:, 1], \"blue\", label=\"Extrapolation by PFN\")\n", 168 | "plt.fill_between(\n", 169 | " x[cutoff:].flatten(),\n", 170 | " predictions[:, 0],\n", 171 | " predictions[:, 2],\n", 172 | " color=\"blue\",\n", 173 | " alpha=0.2,\n", 174 | " label=\"CI of 90%\",\n", 175 | ")\n", 176 | "\n", 177 | "# plot cutoff\n", 178 | "plt.vlines(cutoff, 0, 1, linewidth=0.5, color=\"k\", label=\"cutoff\")\n", 179 | "plt.ylim(0, 1)\n", 180 | "plt.legend(loc=\"lower right\")" 181 | ] 182 | } 183 | ], 184 | "metadata": { 185 | "kernelspec": { 186 | "display_name": "lcpfn", 187 | "language": "python", 188 | "name": "lcpfn" 189 | }, 190 | "language_info": { 191 | "codemirror_mode": { 192 | "name": "ipython", 193 | "version": 3 194 | }, 195 | "file_extension": ".py", 196 | "mimetype": "text/x-python", 197 | "name": "python", 198 | "nbconvert_exporter": "python", 199 | "pygments_lexer": "ipython3", 200 | "version": "3.10.11" 201 | } 202 | }, 203 | "nbformat": 4, 204 | "nbformat_minor": 5 205 | } 206 | -------------------------------------------------------------------------------- /notebooks/training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "97c21741-fc4a-4116-99ce-e400a17d3727", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "028b0665-850b-4add-bd8a-fffea70cc9bc", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "%cd -q .." 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "id": "b816ca61-893f-47fc-8b2c-5730ef41199d", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import lcpfn\n", 32 | "import numpy as np\n", 33 | "from matplotlib import pyplot as plt" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "d8cabdf7-ab64-492e-a0c8-7d1ef8f49f95", 39 | "metadata": {}, 40 | "source": [ 41 | "## Generate samples from LC prior" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 4, 47 | "id": "c0218450-7968-44db-8a4a-a25b6a02dfb7", 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "data": { 52 | "text/plain": [ 53 | "(0.0, 1.0)" 54 | ] 55 | }, 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "output_type": "execute_result" 59 | }, 60 | { 61 | "data": { 62 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAABY5UlEQVR4nO29e6zsSHof9qvuJtlN9uv0ec29d+7M7OyOIgkrJdpMtDIUxIIlBSsl0AZIkKwUI3aiZIDAmyi240CGA0VR/pHysCMDCzmT9VqyEGgTK4YzcDbeJLIMAUGk7CpKFO3Ku55dr2bOzNx7Xv1uNtnNrvzR56tbXafIJrvZ5xyeWz+AIFkskkU2+1dfffU9GOccBgYGBgbFR+m2G2BgYGBgkA8MoRsYGBjcExhCNzAwMLgnMIRuYGBgcE9gCN3AwMDgnsAQuoGBgcE9wVpCZ4x9jjF2yhj7g5jjjDH2VxljbzPGfp8x9rH8m2lgYGBgsA5pJPRfBvCJhOM/AuC1q+UNAL+0fbMMDAwMDLJiLaFzzn8LwGVClU8C+Jt8id8G0GaMPcirgQYGBgYG6VDJ4RqPALwr7Z9clX2gVmSMvYGlFA/P8/7pb//2b8/h9gYGBgbPD373d3/3nHN+qDuWB6GnBuf8TQBvAsDrr7/Ov/zlL9/k7Q0MDO4I5JAjabbzOGfT6+9i3/M8OI6DTcAY+6O4Y3kQ+nsAHkv7L16V3SgWiwWiKLpWzjlHGIbihTLGUC6XUSqVUC6XUancaJ9mYCBA3+Sm603rpK2/yfXSbhcRjLHc9nf1LvJgs7cAfJox9nkAHwfQ55xfU7fsArPZDN1uF1EUXftwxuMxgiDAbDbTnksvt1QqoVKpwLZt2LYNx3FgWZaooy5qedy+XK7b1q0N8oFMOOp22mO6OnHnqcfTHLstJH17SWXydqlUSjye5Vppr5PH+VmO6faLgLWEzhj7NQA/AOCAMXYC4D8BYAEA5/yvAfgCgB8F8DaACYB/c1eNVTGZTBBFETzPE1J3FEW4vLxEpVKB67qCqMvlMq7ajCiKMJ/PxTKbzRCGIUajEYbDIRhjsCwLlmXBcRyUy+Vrf/pdYB3hb/MnTLudx75aFkeecSS7DfnKv0/aP2RSvbh3XyqVtJ01kZ3uvKRjadbb1jW4/1hL6JzzH19znAP4M7m1KAOCIIDjOGg2m2K/3+/Dtm0cHx9n1lHN53OEYYgwDBEEwYoKp1qtwnEcOI6DUqmUinR020lladZx11ksFrH1krZ1xzjnWCwWK2vdknRs2w4w7ahIPRZXP+6cddsq6Hl06r2bxm0StekktkOz2YTrurlft7AKZJKyPc8DAIRhiIuLC1QqFXQ6nY1045VKRUj2wDOCD4IA0+kUk8kEAITk7jgObNu+cx83ka26yOW6bSJgmmega6nXBp4Rnyx1rlvi6unKdffWtWPde9jm+Lb1tz3vpq53V+95l7Ht+9jV3F1hCX06nQKAkMKHwyHK5TIODw9zI1iV4GezmSD30WiE0WgExpggd8dxdvJDyeQbRZF2WyXuOMhErC5ULq/jtg0MDO4eCkvoQRAIKxUi2mazuVOyIb16vV4H5xxBEIiFOphyuSwmV0n/ngTS6asLETZt60CSNFnsWJalJWqVpA0MDO4nCknoRKYkOQ+HQ5RKJaF+uQkwxlCtVlGtVgEsVUAywfu+D2Ap5VuWJYgXgJiMJcJWQQRNJE3bVC6TtIGBgQGhkIQ+m83AOYfjOJjP55hOp2g0GrdKcOVyWVjVVKtVTCYTTCYTDAYDTKdTIWWXy2VUq1XUajXUajW4risImxZD1AYGBpugkIRO6g3btjEYDMAYu1HpnBBFkbCKmc1moqMhOI6Der2OSqUiVCuLxQKz2UysoyiCbduwLAu2bRvJ28DAYGMUktCDIIBt2+Ccw/d9uK674uywK8hqlTAMhbqE7NZJQqfJ1CRilk0kwzAUnRTwTE1DRG9ZliF5AwODtSgcoZNk22g0MJ1OwTnfqXROZDudTjGfzwEsddxkskikmxWqBQ09F0n7sh6e6hO5y3p5AwMDA0LhCD0IAgAQemrGWO6mgvP5HJPJBL7vCynccRy4rotqtboT00TqJGRnqCiKhCqHyF4meRoZUOdA24boDQyeTxSO0AGIeCtRFOVGrpxzTKdTjMdjhGEo7tNoNFCtVm9EpaOCJknJkgZYSvIUroDW8qQr8MyckYieSN6QvYHB/UbhCJ2sQ4ClJL0toZMefjgcIooilMtl4ZZ7GyS+DqVSSah6ZBDRq0sQBNfim6hWNcbKxsDgfqBwhC4jiqIV6TUriMjn8zls20ar1drqereJOKIHnoVJkNdRFF2T7OVrybbwqv27vG1gYHB3UFhCp5C5m6gQoihCr9dDEASwLAudTqewRJ4GRMw6qJ6qqoeqGqRMRZxnquqhqgspYGBgkC8KTehA9iA3k8kE/X4fANBqtW7Ffv0ugSaV173HpPgxVDafz1cCfa27b1LMGHnRlcvXiDtmcDNYFxFULctyfNvtvM/Jcixpv9VqmWiLMsiEMAuh9/t9jMdjOI6DdrttJgiR7YOU9e/r6qeJ+EidgRr5UQ7Pm7bdumO6kLjysySR/rqOQXc9tV1yHTXUMWMsdYhltTxuHXdO3DXk96uGX9a937i26erEYZNzk9qSFRRNlEC/g1pOx3Tb8nm6Y/J7Ur8B2n/11VcNocsgQk9DypxzdLtdTKdT1Ot11Ot1oU9WJcqkDzqpPOn4pr14lmOb1rsriNPJy0SlW6gOraljAHCtE5HrJF1HrUtYR4A6skozUkmDNPXiOjC1c0o6Fne9pGvpzlWTeajbcVmPkrIhxbVHHZ3pzk9bnvQesiCuLbSmHA55o7CETiaLup5VxmKxwMnJCSaTCRzHAWMMo9Fo6/unke6SyrJKAwSV9OLO2/T629RNKksi47j9tOXy8aQ2pDkmH9epd3TbujL6LtOSYtxad81119sGeXYw9wFx31rSsTSjqU0TRK9DYQmdbLCfPHmCw8PDa6qXxWKB8XiM9957D5PJBHt7e2g2m9eiFxodrB6yFKsmwdAdi9tPGmInYZ0uPa5O3HFdfblM3Ta4Gazr2OO2s5Bqmu24sl2hVCpt5GG+DoUldJoU5XyZELrVaoljQRDg8vISo9EIURThlVdeQafTua2m3hpkFUTWDEZpoCNasmNPIuM0i8HukTT6iTuWpixLnTyRZQSlZtrKen5SWZpju0IhCZ2Ih3J7TiYTNBoNlEolzGYzXF5eIooiOI6D/f39e0PmMunKJoZJhB0H+qiJaCn2uky+69YG+SOOXDddkq65LdKMhHQpCuX9PI7JbXneUUhCJ+m8VCoJK4nJZIJarYaLiwsAyx+4Wq1ib2/vNpuaGpzfTOYiYwe+HXTEmEci7W1Idp26aZ1qMcsxufy+Qff+1bI0ddKcl+Qbsg0KSeiqhUulUsFwOMR4PAawjJMeBAE6nc6d+vBk93zVezPOY1PNXCR7asrzAAZ66Mg1bg5g3XxAkhllHOKIdZ0KSj533TY9p/zMWbbV97Su3i6P5bG/6Tk3iV35wBSa0OmP4Xkeut0uGGN48OAB+v0+XNfdWWbtdaAoiWpcFZUQKGBWtVo1MVU0UEk47SLbtifpbNX9dYQKXDfHSyPB0n3k319HcLdNMptAfk71e03az1p3Xf111896z23PW1e2K24qJKFTEC3OuVCtEGFS1qB6vb7zdnDOhbWNvMh/TCLtWq12LerhXSfsdXpYtVy3r+r5Vd2/rGrSTdjSNeW1DF1IAZ2+X11UaVmFjhx0+to0x9Kem3U763nb3Ctu3+BuoZCEPp/PBaHT2rIsBEGAbreLdru9kx4wKeUcY8+yFskxyncZwGqdrjar/nadPlcnKa8ro3dDa51kS6ojWa1E5fLxuPkA9R55rA0MiohCEjpZsFBwrSAIUK1WRfTEF198Mbf7ULo5OUgVkbfneStZhLJCJt40JoU6kk6LJF2t2ha1XfL95PNkE0VZr6/T8SeRsYGBQT4oHKHTEL1cLouYGJSPs1arCQl6U4KllHNBEKyknLNtG/V6XeQMTSIjnXWKTvWwjpBVtQGpa3TqBFW1QFB1y2ksZ1SCVolYV2ZgYHD7KByhy1EWOecolUqYTCbgnKNWq8F1XQyHQ9RqtVQSIOdc5O+kHKWMMdi2Ddd1RXYk9Rx50pPIkfT4cbpeWbXgOM7akLNJkEmZVEGqyaMKuq6sx48jbiM9GxgUD4UjdFlqpv0oioRUube3h/Pzc4zH48SJ0TAMMR6PBYmXSiXUajVUq1UR84Wu7/u+0JnT/WQQMdq2fc1Shchy02fVJaegoGIy6PnL5TIcx7l2f9o2MDC4vyg8oVP+T8YYHMeBbduoVqtCSldDvY7HY0wmE8znczDGREo7CpYzm81EXtEwDFcm9iqVCmzbXpn03NZaRbZNlxdVJSOrXKjjUKVsAwOD5xuFI3THcVbitoRhKEibSLnZbOLs7AxnZ2fodDqwLAvj8Rij0QiLxQK2baPdbovcpGQdEwSBIHCyD5cJfBviltU08lrWYVOnQfeWOw0Tu93AwGAdCkfoZFVC5Dufz2FZFhaLhSD0SqWCw8NDXFxc4OTkREjvjuOg0WjAtm3M53MMBgP4vi/iwjiOI1Qu20i8ZA8vLzSyAJ5ZyRBpU4dhSNvAwGAbFI7QCTSZSdYsFL+EUCqVUKlUMJlMAEBI2mEYot/vIwxDlEolVKtVoXJZJ4HHWaXM53Nh2hiG4Qp5k9u+bOIYl/EnKXfnOty2l+Ft33/XuO/PZ6DHrn73XQlwhSV0MsVjjCGKopV0TvP5HBcXF4iiCC+99JKwhDk5OcF0OkWpVILrunBdF9PpVJg9ykj6IWUCl9UmZMEiL2QKSbp+AwMDAxPLRQE5usxmM1QqFaFuWSwWuLy8BOcch4eHsCxLqFWazSYODw/huu5a5xzV6SYIAkyn05WJ0lqthlarBdu2hX26em5a5GEmeNumhjd1/3XerUllum3del1Z3H5Sm7c5fhvY9ve8zfOznnsTbb2J/0cqQmeMfQLALwIoA/gs5/znleMvAfgVAO2rOj/NOf9Cvk1dhepabts2OOciFvr+/j4sy8JgMMBoNIJt2zg4OEjtcLRYLDCdTuH7PsIwFOEF6vW60McbnXd2qJ6x6zxi88qAJEPn7q+GJZDXSdtp9teVb1pvF8ijY9n0Gtvcu0j3BNLlQt4Ea9mNMVYG8BkAPwzgBMCXGGNvcc6/KlX7jwH8D5zzX2KMfSeALwB4ZQftFSAyn81mIjFDt9tFGIbY29uDbdvo9/sYj8fwPG/FMiYOnHNMp1NMJhMEQQBg+eI9zxMWL/cJSZKtvE/b8nnqdlwc9yxJN4Bn3rFxgbSSynXnA9cJm5CFOO9C3W3O2fbcuzSC3EWHdxOd6F2R0L8XwNuc828CAGPs8wA+CUAmdA6A0li3ALyfZyN1kCU7x3Hg+z5830ej0UCtVkO324Xv+6jX62szbM/nc4zHY6GaKZfLqNfrqNVqO8n7lzcoHEJSBqM4aTfL9ZOSb+igxnJR87hm9Y6V20PPY2BQRNymDv0RgHel/RMAH1fq/CyA/5Ux9u8B8AD8kO5CjLE3ALwBAC+99FLWtq6ATBbJ3HA4HKJcLqPRaGA0GglybzQasdcIwxCj0QjT6RSMLcPwkrv/XQN5iKrhBpIyGalkKceB0S1E3GoSDrq+zvtUF0LgNuzmswx/sw6Vd3HtTYbrt6FWyOP8u9KGXV4v6z12NdrPa1L0xwH8Muf8v2KM/TEAv8oY+yjnfIVpOOdvAngTAF5//fWt3ijZelerVZRKJQRBgEajgfl8juFwiGq1Gkvms9kMg8EAQRCgVCqh0WjAdd07oxOn2CyyHbvOAUkNN6DGZEmCLo677h5kK6+S913DbU8IGxjcBaQh9PcAPJb2X7wqk/GTAD4BAJzz/5MxVgVwAOA0j0bqQMPuUqkkzA5d10Wv1wMArc48iiLhTERu/9VqFZwvc5Kum4DT9bhJelndtm69WCyEaWNcuAHHcYQTkhxuQNUTE8ikk7aT4rjLxG2cnAyKjDgLpnUWTmnLNq2v1mm1Wium1nkhDaF/CcBrjLEPYUnknwLwE0qddwD8IIBfZox9B4AqgLM8G6qCdMPkPER5RIMgQKvVukZIg8EAZ2dnCIJAeISS3p0QN9EWR5o6gpfL4jLucL7MdEQ28GSjTmF61XADNFmrXj9uez6fCxt56iCo8yPSJisdy7JEfflZ5XeSZr1J/U3qJN37tnAb7bjtSbysE+pJE+xpiDcNme7qnaThA12ZLl0hrW/NyoVzPmeMfRrAF7E0Sfwc5/wrjLGfA/BlzvlbAP48gP+WMfZnsZwg/dN8x0oqeSJwPp+j2WxiMBjAtm0x2cA5x2g0wsnJCXzfh23b6HQ6Ih2cnCEna8hYeWJON+kYl0FoMpmsmEKWy2W0221hxy5P9skdQlKvTyoa1U6eCJyuLZtskp5c7SjUny3pj5h03k0ibQcQdzypXpZrpt3Ock9dG9IQpExwaUk27jpxdfJAHDnGvfO09VXhKysZq+frsM37oFF33kilQ+dLm/IvKGU/I21/FcD359u0ZJBlRRiG4iNfLBZot9vgfBlVsdfr4eLiAuVyGY8fP0an00nVM1InkWSKl/aHLJVKWCwWIlQvsHT7lfX2cR9X3JrzZ96nFP63Wq3C87yViJPqs64jMR2ydHLUNiB52CmvNz22bqSS5lqbHpe31axOdEzu1KmefExHrEnl8jW2hSo5yqNSKqdFrqvuJx3T3UeVdNV5nrhOLM0+vaNtrhFXppP+0/6X4sru+qTojYHik8tZeGq1mkhHVy6XcXl5icFggMlkgr29PTx48EBrfsg5F278csIK9Y8jZ+qxbTsxKYX80UZRhOFwCN/34XkeDg4O4HneRj8m588ScVBgMsYYGo2GUCHtKpN4kRE3WopbsuZhjYvLoxvpyNsqgcplm0qqac9Z17Y0z5DlOmn3k86PKysqWq3WTkyiC8cA0+kUvV5PSKhEaJPJBJ7n4fz8HMPhEJxzHB0dXZPKoygSKeaCIBAfCcVhIXWMbNGRVTe3WCwwGAwwHo/BGEO9XofneZn1ZjKJy4k4qtXqtUQc9wFJKqt1+3FlWaFKm3Gd9bpFvlbc8fuGtCScV5m6naWT0I3G0l4n7ny1jm5EFTeSyAuFI3TXdXF5eYnhcCgmRUklMhqNEEURSqWSkIjpxU2nU4zH4xUPUNd1xSRkXpMU4/EYg8EAnHN4nodGo5E5FG8YhmLClkL7UiIO27bvHCGsc9tPmmNQ1RVpIJOuTJI0ibzO21S+Bm3TelsVDD1T2vPyqJd1f13dTesYpMeu3l3hCJ0sQcbjMRaLhYiNPp1O4bqucDTqdDpgbJlAut/vi8TS5Emat3piPp+LkQMl4chyD5owlbMpZQntuw2IhNTYKnGxVnTbdB2ZnOVtnV5Vp0tVj8vPLddT7xFnURRXtul72rQsy/1l1YhuYk93jI7HSYCquiXpXN1xHVRBRdcm3bXi2q7WiWuHTnee1GadqmmT/1PcOXHPLX//2wgwWVA4QgeWUjqFr61UKhiPx8LSo1qtotPpAAAuLy8xnU5hWRZarRaq1epO2kNSOWMM7XY7k32pmtuUsimRw1RWcP7M21Oe0NV5f8rHZVJUVRcEHbmoEjCAa+oKtX1J26o+WYZ8bJt9tf26Z4wj0ziCTbMt/5676qDXXTeJlJL205yT9rwsSPuesrzPtHV1ar6kMt3keByq1WpizuNNUUhCp6H1ZDIB51y4/VuWhb29PSwWC5yfnwNYpqPbxYsDlj9Yv9/HZDKB4zhot9ta1Y2qZoiiCL7vYzgcCht013VFDtTZbIbpdLpilhlFkZi4JdNE8vZUSTpOd0jEq+qFZdNNeQJYjb0Spx+W9wmqBB53DZ3ETvu67ZvYz1KWx/lZr72r6+VxPM86eSButJlm9Cm3Vf5W5XL1v5R2fWt26HcRJMnO53OhN282m0LlcnFxgUqlktpMMct96ccOwxDn5+ci/EAURXjy5InwxFSlYxpB+L6P8Xgs1CoUX0UnMeskZJlwyaqH3PIpqYY8qSvXUQNk0YRvnJpjE1WI7nieBHpTxJf2uMHNQBdwLo6o00b2VImW/keq0BNHynfx2ygkoZPuvFQq4fT0FJ1OR1h9kN35/v5+JpWFSqiy7bma1DkIAjx9+hTT6VTEYVeHYypJ+74v0uGVy2VUq1WRC5XKVCKWLW1oO2tkQgKNDOTIiEkSsU7tkVZCv6sf+66Qt0Sa9d1t+q5v6j7qNVTilX075P+RqgqMu16SCbGOnJMIme6bl81/HHYVBLCwhE7kSoG4xuOx6GXjyJxzvhKxUF4oFyipMuQenqTZMAzR6/Xw9OlTLBbLDEiytCwTId2PnH/q9boI5UvWNboohXEfWdJ20npdmW5bhzRWEvIEEL23uI6B1ro6urX6bnX1defG1aXrJWHb42nrZK27qZ56V+fJpBy3pHXIk0mXhLZ1JL3J82xL2Osm3eP+ewQKv5E3Ckno9IPQD7u3t4d2uw3LstBsNlEul4VELUcVlCVuNbogESqFDpDNIcmC5fLyEv1+H57n4ejoCJ7nCbWJfH65XBa27owtnX/Ivb8I0BF92g5h02O6bZ35X95I05HEHUuzjvOWlO+/yTquLA/IkrJM1HHbMuRIn+tUFjoVRhohRbfOu+66sjRI+n12NYItJKHLpPzCCy/g4cOHaDQa8H0fvV5vhaxlNQOVUXwTWedcLpfFhxqGoTAf7Ha7mEwmgtgfP36MF154QUQ/lCMUlkrLZNSj0UiMFBqNRuE8OO+iymTTUUfa42mOAViRMtOQwi6QNLKiUQyVyyQpqwCJtGlbLo+i6NoIhnOeqFtW1Rny3NU6Qtt2W9d5pu04dZZHWUZ+m5btihOKxTRXWCwWwgW+XC4Lb0oAYhKQ6skTgXIkQxrOBUGAyWSCbreL6XSK2WwGYCmx+76PcrmM4+NjcM5Rq9Xw4MEDraTt+z4GgwGiKBKx2IuQ7ago2JU0ugl0I4+kRWfWppYD1z1l5bIkO3uZqGWVoS7+kHo+EZ88GSirAYmcVSl6HVkuFotrv1USsW8zQtH9LnH3TWpTmu28jpFFW94oJKGT+/5kMhEfGdl+B0EgJAxK5uw4zgq5cs4xGAxwcXEB3/cxm81EvXa7DcYYRqMR6vU6Wq2WMI/UJZkOwxD9fl/kNm2323cy49HzBJUwk0h2k2VbyBKkrHIgc1xVz79OJ10qla6dS2W0lrflifY8O8g41cQ61UVWNUeW62W9b5rz4zqOtNcAloLnLlSwhST0+XwuyJwkbrIgIelY55gzn89xenqKy8tLhGEI27bRaDSE05Ft25jNZri8vITneWi32+j3+1gsFtfIXE6WUS6Xsbe3h1qtdqPv4b4gSZLVbcdJupsSruoUpRItcJ2EdWW6Yb6uXH5unXmr7BSmIwk5SJw8sZ40ua6+FzK9Taqja+86pH3/N1lPJzWro4y82pH2HGOHLmE2m2E0GiEMQ6HzTgqAFQQBzs/PcX5+jiiK0Gg08OjRo2v67SAIcHl5iUqlgv39fQwGA8xmM3Q6nRUJfzQaYTgcAgDq9ToajcZOVQFpdbZ5SEhZjsn7suQoS5WyTb3OdlhWJ6SBKtlSmeq8RJJrGhKWOwK1LdtI5ETYqmeurBZRoao8dNvU5qQE3QbpkPV/u8n/XHfOrkbxhSR0inVeqVRweHiI4+Nj7UujOC5nZ2ciCcaDBw+0krRK5qTSIWkfeKZeCcNQWNSUSiWRrCLNMB64rnuVy2hbXt804hw3ZLJWJWQZ6m+hSr20yGoC1T44zmZ43QTbuvIsdZPKCap0Tdu69yInG5GTacvLOjO8pPakIZt1dbY9nrZO1rq7uOZ9ROEIfbFY4OnTpxgMBnjttdfw4MGDaz/gbDYTLvmDwQC1Wg3Hx8exIQDm87kIFeB5nkhXRw5AZ2dn6Pf7GI1GYIwJifzy8jKxreuG6Tq3dx15bTJRpJapUnOSzTBjbCUZtHxN1aIhzqljnQNHUcA5X1GFqGoRGTRvQ4lLdMRtYLBLFI7QB4MBLi8vUSqVcHR0dG1igdQhZIO+v7+P/f19UY+Gu3K2+6dPn4q63W4XFxcX4Jzj8PAQ4/EYo9EInHN0Oh2Rr1RnQ6vu3yR0Xq6qx6sO9CyyE4c87FfJ+T6CSFsl7rhkJ6qkLRO3gcFtonCETn8eSrlGOvDFYiEmO+nPRZEXaQKTEicTSqUSxuMxyuUyHjx4ANd1MR6Psbe3h4ODA0ynU/i+j4ODA7RarVt1DFIlRXWtQta76ibQ0gzv7xtUspZJXIb8/ciknbdViIFB3igcoQdBgDAMUSqV0Gq1hDrh4uIC8/kcrVYL4/EYwDIR68XFhTBjtCwLjUZDWMZMp1MsFgscHR2h0WiILEiWZaHf74NzjkajgXq9HmulkHZbnXTTWWXQRBeNHIhwaFuGjqDVeC90TR3hJ018JmEb072k8/KyIpDfodrpyfVVsz7ZlI/eGSVD2bR9eSPv95d0jbT3SjtpHvfdpD0/qV1p6ujOSXOtpPPTnqsrf/HFF3F0dJTqXllQOEIns0LOOfb39wEA3W5XWKP4vo/Ly0sR38W2bdRqNViWJX5EikF+dnaGSqWCarWK8/NznJ6eYjAYCB1os9kULv9xE4VJk53rJjflkLhEPPIQX9Zly4Qju1c/r5AtSFT1kgyy71aDnAFYMdtL+jPqTP7U89Ksk0gg7XmbTJhnIZqs18gDcdfeZjSU97nblOlwcHCQvWEpUEhCPz09xXA4RBAE+MY3viEsWJ48eYJ33nkHtm2L0Lkk7crgnOPs7Ex0At1uF6enp3j69Cnq9bpITzccDrWTlTpb5Tj7ZdqmcAUyiZMnKxE3ebDSAjyzOKF26zwHVWsTzrm4p66zoX2CjqDUcvUY50uzwHXS2jrpSR5J6MoBrMx5yDHg5fehsxahTo/aQefK70A2W9SR6DZWJ2nOiTOnlMt090m6hlyeNHGubmd1g09zn6R2x7U57f3TtC1L2U2i1Wrt5LqFI/ROp4NarYbFYoF+v4/Ly0u88MILGA6H+Na3vgXP8/DKK6+gVquhVFoNYUvkOJlMsFgssLe3B8dx0Ov1MB6P8ejRI7z88svCKSnOiiPuY5Al+Pl8LkLmTiYTEc2Rcp6S8weRCtVXpf5NEPfB6wJFyeW0TeWqnj2LtY36XuROR+2E5P3ZbIYwDK+FLJbbalmW0G8njVh0JKm2Uz0u23nHkavuPHlN5+raomsDkDyEz1I3zTXykLZ3RYhZrrvJaCXp+mmIf5sOQ/2f7QKFI/TFYhkL/eDgAIvFAoeHh2i32xgMBjg+Psarr74qvOBkyZyG3pZlYTwei/PIFPHll1/Gq6++GmupIOtlZXtjSoVHhEwTqUEQiD82WUU4jiMmKHXqFJLuddYm8rG4EUESyeT5/nVmj2qZTNZx15H13KTnpxAMNJlLLtK06J41bjFIj211wpvouvMuy/O8Ta+f9lq7+j4LR+jz+Ry2bQsrFM45Tk5O8MEHH+AjH/kIhsOhkOIoQJYcsKvX64nEy8PhEFEUoV6vY39/X0RcVCVE3Zogp4UjAt7f34fneSKtHN3/Lttmk05aNXVUt+OgOgzJ+wC09tu2bceqm4wJ4M3iLqolDLKjcIROwbQ8z0OtVsPl5SVarZYIirW/vx/rVhuGIbrdrrB6qdfrmEwmIsri6empsIagvJ2qjtZxHJRKJUHkruuiUqmgXq+LjuKuTljGmTzG2anLIwS1U1Lt1AmkNqFlOp2udIBkDqhGvTQwMNgehSN0IlrbthFFESqVCg4ODvDgwQNYloVer4e9vb1rNuOLxQLvvPMOxuMxjo+PUa1W0ev1cH5+jlarJfTXROAkPdJQv1KpCElcDuxVq9VQrVbvjDRDkrZsPUOLPOyTVT5E1rI1SFpHIkqvR6MaedKxVCoJKyMicEPeBga7Q+EInSRCeeg+nU5xeHgI13VxeXmJi4sLtFotEVJ3Mpng9PQUvV4PzWYTi8UC4/EYg8EA1WoVtVpNWE2QBElSJOcck8lEmC5WKhW4ritMG4Hr1iBZJ16yHJdBemh5UW3OaeKQ9NJy3tKsIEsR6tQoOBq127IseJ4n9N5GbWJgcLMoHKE3Gg3s7++LJBcUHIsmPA8PD9HtdoULP01Wvv/++2KijUj64uIC9XodYRgKYmeMIQgCEZxrPB6LEQF5ppLlyk2CVBlEqKqHqEzW8kI6cHKSyWpFQJI3rUn6Vict6V6+74tkI3nhrox+ZNzFNm2D+/Y8dwVx79XzPBH0L08UjtBLpRKazSbq9TouLi5wfHwM27YRBAEcxxG5PN9//31cXl5iPp8L6fH4+BjtdluQTrvdxssvvwzXdVdUAZPJREyuHh4exmYf2tYyIOk4kSktRKakDlLT6K1TZaQx7aJ7UogEmcAdx0G9Xt+Z9L1Lx5VNcRfblBX34RnuCjZ5lzf9/gtH6OTWv7e3h6dPn2I0GqHZbOLrX/86ms0mZrMZGGNoNptot9sYj8fodrs4ODgQcV1KpRKOj4/x6NGjlV6SkkFTfPVms3mj8VtoEpFInD4Gz/NWTPfyIlOS+oMgECROaifSe8vmggYGBncbhSN02XOSnIL29vYwmUxgWRYeP34Mz/MwmUwwm81Qq9XgeZ6Q2huNBsrl8kqcc+BZlEbGGNrtttC/7xKcc6HeobgyAIQumuzW8xwOywQudxq7vKeBgcHNoHCEfnR0hKOjI+GU8vDhQ3zsYx/DaDSC7/vY399Hr9cTgbpI7bK/v49qtYrhcIjz83OUSiURK304HGI6naJarYrwuLsC5/ya81GpVBIONbrUedtgsVgIAqd8qwDE5K7s7GRgYFBsFI7QafJNjlFO2YOm0yn+6I/+CNVqFe12W2QhOjg4wOHhISqVCmazGY6OjuB5Hnq9Hr71rW8JFcze3t7OJFMicXKGongxZE2TtxROcwkUgEruNMjixcDAYPfQxQnaVYC9VITOGPsEgF8EUAbwWc75z2vq/KsAfhYAB/D/cs5/Isd2CpCd+Gg0wmKxEEFuyOrigw8+wNHRESzLwje+8Q3Yto2joyNUKhXhor+/vy+sVUg/PZvN8OTJEziOk5ttOSWzptgxpVJJeI/mqZsn1Q2pb0gKJ29ZksINDJ5H6AhVXeKO6crT1qV9HVqtFjzPy/1Z1xI6Y6wM4DMAfhjACYAvMcbe4px/VarzGoC/COD7Oeddxlj+gX6vQBLubDZDpVIR+UHJtI5ipFMUxg9/+MPCQmU8Hgsvx4uLC1QqFbzyyisol8sIw1BI0NPpVDgVkU16Fol2Op1iPB4LU8FqtSrUG3lJ4qRKUePGOI4j5ge2kQDuQuyMdeW7OpZ3nV3U2/ScTa0utrXWkM+PI9WkhbAub696/TyeRxcbSV7ov6eLpRQX1C1LYvQsSCOhfy+Atznn37xq0OcBfBLAV6U6/w6Az3DOuwDAOT/Nu6EEWWXBORdu/jSh+corr+Ab3/gG3n33XXzHd3yHmNyMogjT6VSEC2CM4eDgQBA1WXO0Wi2EYSiIvd/vo9/vC+ccnc6Zcy6iOA6HQ8zn8xVpnMLMTqdTUV/twePK5DU9Ay3AUpVSrVbhOM6KOmo0Gq2cK7dVt63bN8gfaTv0TTr+bc/ZhGBpXxfOOe6cbSCTpByoTj0WF6Quy6K+n6ztTMKuRsxpCP0RgHel/RMAH1fqfBsAMMb+DyzVMj/LOf976oUYY28AeAMAXnrppU3ai1qthkajgfF4jCiKYFkW5vM5ptOpSN7s+75IaNHtdtFoNDCZTBBFESaTCSqVigjGRR+juhBRzmYz+L6P4XCI09NT8XHKqd1IZ00OSKQbl4k3LdQPiToC0ofTB6yzSFElgXXbWY7p9jetE1eWZ3lex/OskwdkgoyTVtdJsXKdpHDNaYLIZSXJTUnVWF2lQ16TohUArwH4AQAvAvgtxth3cc57ciXO+ZsA3gSA119/faMum9QmZJ1SqVRECFzP83B+fo4oivBd3/VdIlQueYVOJhMRWZESQScNfehDotC3rVZrJWTuYDAQHqOe56HT6cQSbdwHqiNjInFSATmOIzzLKHqjwd1HHMmuiwmftJ8FSURJ+VGzhCOOI9ws72NdmU763+Q6N1WWtZzKyBgib6Qh9PcAPJb2X7wqk3EC4Hc45zMA/5gx9nUsCf5LubRSgm3baDab6Pf7Qhfl+z5c18ViscCTJ09Qr9fR6XTA2DKi4vn5uXDvb7VaQpcdl8QiKcs95xzj8Rij0UhEWnQcZyXsLjk36ULCJv0JgiAQ0R9pFEBRHA2J3xxU6ZXi/MjbcsYktZ6aZYqQpO5KUgsAq4lH0pTL27r7ys8jI4tKLuu+wRLUod4WoX8JwGuMsQ9hSeSfAqBasPwdAD8O4G8wxg6wVMF8M8d2CtTrdTiOg5OTE9i2jW63C8456vU63n//fURRhMePH6988LPZDLZt4/Hjx9jb29v43r7vYzAYIIoiEapXJVo5XVoYhphMJuLDph+SCJ6kJJLGyYu1VqvBdd17Z5mizhPo9mk7rr56XF3LCTeSCDhtjtg4yNKtrMNdl3gk7tgm7xJYDQy3iWpNradOpO9aRbfJddKq3za9f5Z7bFq+K6wldM75nDH2aQBfxFI//jnO+VcYYz8H4Muc87eujv3zjLGvAogA/AXO+cVOGnwVKZD01b1eD0dHR/B9H+fn5yJFHSEMQ5yfn6PRaKDdbm90z/l8jn6/jyAIYFmWiL0e1z7Z+obOlyMi+r6Pi4sLEXaWdOKNRkPES9m1nfg6HausjlpXT1dHLcsCmWBlwlXL4u6bpM5ijImREnWocrYodV8to7WO2NKWbVPXwCAJqXTonPMvAPiCUvYz0jYH8Oeulp2DvEQtyxIThmdnZyiXy3j06NFK3bOzMywWCzx8+HCjPwaFBAA2tx0lkied/mw2E3p2CrTF+TI07WAwEOcR6cghb9Xkx6p0mVZXuwl0qgCd1Enlahvj2qWWy9eoVCor15Pzsaop+ZJS9d3VTFEGBnmicJ6iAETiYNd1BTk5joN2uy0IAFhKehcXF2g2m5lDVcqBumq1GprN5sZSs2yXzhhDtVoVSTjk9G4U31yOsKhLHAE8IzZ1oY5DR27r1ADrFiJdXU5RXTLudSoLNeuRjpR15GxgYKBHIQmdTAFJTXFwcIDT09MVNQcAEcfl6CibnxMlv2CMYW9v79p104CSaPT7feF+T1L5bDZDt9uNtbChTD/kHERElqT7laVP2pfTxamLKuXH5RBVddI6qAQsjyZ0xG0kZQOD3aCQhO77PjjnsCwLe3t7KxnjCZxznJ2doVqtotFopLou5xy9Xg++7wuJP61UTmnfptMper0eBoOBSGhNdunkYKSmelPXm4CImDI5UXvG4/FKflQ6JkMmeiJgmrSlbYr/IhNzllR1BgYGu0chCZ1isjiOI/TojLEVi5PhcAjf91M7MM3ncyHRUwKNJFAWIHL4IbXKdDpFpVJBo9FAq9VCrVZbkVjzAlluyAQuLyS9k4OUTOSynlqewFR19HJaPp3lh6zXNqRuYHD7KCShU4ozMv/r9/srcVI45zg/P4dt26ksW4IgQLfbBQDs7+/HWrDIKdZms5koI52+53l49OgR6vV6LlYqRNYyadO2Tl1DUjYlwVBVLEltWqd6IYeqJDWRLLUn7RvyNzDYDQpJ6OS4Q5JjFEUrEjVJy3t7eyuTpDr4vo9er4dyuSzCAaig/KKku7csC7VaTZA6JYfwPC+zFE5kKS9qvlDg2SSiqg6Rl20gXz8JOuKX1Tk0sSubFurutc6pK42T100gjUVQWquhvOvt6pq7asN9aEte59ymp+idg+/7wppDjmhIoDgv65yIaNLStm10Op1rZDydToUuvFQqodFowLZtjMdjkaCabMfXEQ5J8mSTTuStqjDouarV6oq5InVeOhtvddIy7TpuW/344urr9uWOgTpblfCTFhWybbnOBFFnxSOrjeI62E1NNw0MtgHxxG16it45TKdTkayYdNayZEkknWSdMhqNMBgMhAmhTMhhGGIwGCAMQ1QqFezt7cGyLIxGI1xcXIAxFkvkZE9OpodypiAiYFUFIaslOOfCXDGt12JeSOvYksV7kEwpk+4lQ3UeSnIskjuzpGdKMuGMM43cxFQyz3pZRiR3oe5duvZN1N/0nF2jcIROsVsajQYqlQqGw+FK/k+y7Njb24uVzsgssVqtotPpiHLOOYbDIUajEcrlMtrtNmq1GkajEXq9HoClqWStVgPnXMQiJ+em6XQqpG+aeCT1iLzI0qPqqJPGThyIj9uRVKYeV8uKijiiT7Mtl8UhieiTHJjUbQODXaNwhE7Z6R3HEVKbPIk5Go1EPlEdfN9Hv98XkjlhPp+j2+1iNpvBdV00m00Mh0M8efIEYRgKvXmv18PZ2ZloB4XaLZfLIhqiGjc9TkI0yAdp9f86qF6q64hfnUPI4nkbR/RxTl/r9s03ZKCicIROUrDneWKoLeui+v0+LMvSmh2SNYvjOCtqljAMcXp6ijAM4bouJpMJTk5ORJwViqY4GAwEeTcaDZHAgpJLGBQPeUjQcWEN1oVjkOc+Ng3LkDSiWzfaW3fcoHgoHAtRVMJGoyGiE8pej4PBAM1m89ofNIoidLtdVCoVEVo3DEP0+308efIEnHO0222MRiNxD8/z4LqukP5kydskWTYgyASYx3eRJhZPmpg46rLpc+1yofsY5IPCETqwzFpEErr8ByJ1S7PZXKnPOcfl5aUgbUrcPBwO0e/34bou2u02+v2+mGTd29sT+UTJ4sTA4CawK717GtJPu+TVaRDSzhlte0wtu28oHEtVKhUcHBzAdV2MRqMVoqWkF6r+nOKp2LYtMhXR0ul0sFgs8N577wlbdIrfYiayDO4T8h5J6LBJ55D2PF29bRFH8mk6AbVDoH0yiNAdI9BIP28UjtCjKBJ6a/IQJfR6PTQajZWPdTqd4uJiGZqdMSYmLj/44AORWBoADg8P8eDBAyOJGxhsgTwkXx2Bx23r5iWS5jGSwksnzYHQPeUoott0Ki+++GLmoIFpUDj2IumcsWU4VyJgShZxeHiIMAyF6/rbb7+NIAjwwgsvwPM8zGYzfO1rX8NwOES9Xke73cbx8fFGERUNDO4TskjFWfbX1VG384YsXcvx9VUdfpYy+RhJ5Lr2xx2LCy+yLQpH6JZlodVqCbd78kh88uQJLi4u0Ol0hEnZyckJ+v0+Hj16BM453nnnHVxcXGAwGKDdbovMQP1+H6PRSBv5UF6nlTySJAD5uLyvW8vXi7tPGqxrd9xxXfm6oeQ2+zdxrCi4SbVFXkSaRmVBasw0dddtx52r1n2eUDhCJ5DJYhAEwlKlVqthf38fpVIJZ2dnqFQq+OhHPyr07bVaDdVqFfv7+zg6OoLneeCcXwtCpUvOIAfEknVktC1/SHnEHllHhGmOpfmjbttZFA15dgwyGapmh5tIp7oOfZMOUl7kb1Hep4XUk1nOSZqETHpfWb5pg81QWEKfz+cYj8dYLBaoVCpwXVe4419cXGA0GuHg4ACVSgW+76NWq6Hf78PzPHzkIx+JVbEsFosVl30Kj6vGIqE/oPwnUB2IdIkl6E+k7qsxSO4i0owckuqsOz+PunH1qOOWR09xTkTA9bym6nlxiLv/OqkzTiDQ1Ys7P+sIEkBi2ITbQJqOYBdlarn8u+m+L125LOTFgY41m82dqHkLS+iUbLnT6Qi7cM/zhClipVIRfz7P83B2doYoivDhD3945UUuFgvhtk+6dwJds9VqiZgkcXk9aZHD26qdAEn/cVD/1Eleplkkp7yGn7clYaWZ2FpXZx3o96TQDEnvOO69J/0eeSBLJ6crS9sBp1ENph156OqkPQ+A9reTR8q6MnnUpJap14xrg+54Upvi3m3cb/Tyyy8bQpdBKeKazSZ6vZ74Q45GI5GHs9PpwHVddLtd+L6Phw8fol6vg3MuQuJStEbyCKWgX9QhyEkjptNpbMAoYL2agqRxtRNYR06AfkhP15TXcdtyWRIhUV15W13rjsVJnFQm/4HjnlsnBSepH+TydQG20sZcURH3m6rSnPp7rSO/NOoX3bXU95LmmK486btNS05xdTclvCRperFYXPsWZal43fnydpwkrX7H6nlx/zmd1K5K8CrkZDx5opCETm74FH88DEOUy2UMh0MRgbHdbqNer6PX62E6naLT6WBvbw/D4RCTyUQ4JdXrdZEcgzqCyWRyLR45gBWpWScxr/sgkpD0p9Jtr5NUk6RUXYchx2BX5xBUSUcmXTmmSVIclKSPm96PLhriuu0kCTiJQNXjSeXr6srHdb9lEtb9zrr6Sc9METvl71A+Rz0mrwnrBAJdHZVkdaQrj2rjriu3eV2ZSuhZoXvHcffW1dumTMcveaCQhC4H4CI3fQqSRdYq+/v7IkKi67rCqSiKIjiOg2azKST18XgsfkRKIFGr1a6pWW5KxXBbSEpaocYyB/RWOarEr1MTURn9IdNI7Zzza52HuugIPA0Ry1CfJY4caV+uKyPNiEmdL1EFgrhRUhw2+T6znpP0nPJ+FskZ0Met3/TaVJ6mXtycVZyErV5Hldp1UrzuG/M8T3vfbVE4Qo+iCMPhEI7joFqtwvf9Fek6iiIcHh5iNBoJT6yzszM0m02RVHo2m6Hf72OxWKBcLgvrF4qOeN9AZKjmHFW3dZBHIzRM1P2x6KOleQJ5ziBuSVI5yNsqiepGQ3GSu66e7jq6kMa6Z10nvaapn+b4XYAqASeNILLsb3Nu2n3dCCvLNXYNiguVNwpH6KQOaTQaACAsUMbjMVzXheu64JzDsiw0m0187WtfW5HIKXdorVaD67o7M/C/Scgml3EJo2Woahv5OrrjRMyylYh6rgxdHHBVOrdte6VMtfLRJZjQSWxx6gT1eNIx3TtQ3we9E9kKRn6fOolfraPWTaqfFmmIKC1ZZSW1JNVEkiSvU5Ws21ex7fnrrpf2GmnO2/Tam6BwhE6TmjT5CSylwul0KuzKyfno3XffxXw+x+PHj+H7PubzOWq1GprNZuGiJcqEreYeJQmc0tvpzCtVyThObaJTj6hmmDJhVyqVa3XV69A+3UfVr8v6e510riuT92XEkes6VUueEvWm19hGhbPpOXkc25Scbvq8vK+xzfUMoV9BlvooFd3Tp0/BOUcYhmi1Wmi1Wri8vMRoNEKn08FkMhGBt25aIl+n11UlXkpfR8RM5plkG0/lsvRN5CwTK4FImMzxKGuSbduwLAuMMXGc6ujmC+Q/b9Ik6Gw200rxMtGmVWeoEr5ONSJv69Qscn35OmqZ+py6/axSYRrCT1O+7lia42nrZKlncPdQOEIHltJqEARwHAflchmDwQCLxdLB6Pj4WEx0UuJox3FEDHQ6X5ZyVWmW6sikJQ+5abiuQkdIcX8OImt6Ft/3MZlMEIbhiqQNQEjHlmUJUq7VaoKUaSFzS9u2VwhblZTjtmUpXpfQWVbfUEegYp2JoG5/3WJgYJAOhSP02WyGbrcrVCyUz7NcLqPVaoExBt/3MZ1OwTnH/v4+Wq0WxuOxIE15si5JN6pCHfKrxBhHWnTdMAwxnU6FExMRN/AsCzhFknRdF9VqVRA1TaJQjtJNiS5ObUMdmAo5lg11KKpnq0zcBgYGt4fCEXoYhjg/PxcZhJ48eYIgCLC/v4/9/X10u130+310u114noeLiwu89957wu6cSJEWWc0QZ2eukzoJqnphsVhgOp1iOBxiMBig3++LzoR0/mQxQiMH27ZRrVZFjGS5TXJQMJLak9pKoBEAJa2mRTXLI1ULpdGTwxIYgjYwKBYKR+iU89OyLHC+zEREw/8PPvgAk8kEg8EAFxcXeOmll4THKDkQyWSVF2H5vo/hcIjhcIjLy0uR95TzZQLrVquFR48eiTykrusKaTfJrE9e4oKGAc+kf9nKBXim/iApX1bFyB2HgcFtI40Z4aZlac/Ls65qaSb/NxeLBQ4ODmIT2W+DwhH6bDYTOuSTkxO88847aDQaQo1RKpVwcnIiVC2u66JUKgnVTFpdLqC31OB86Yw0Go0wGAyEJ+psNluJ/bK3t4dGo3HNvl1VlWTNVk+Tv6S6IesdInvLslCtVlesUuSPzPd9+L6/ck3daCTt/vOINFY068qyHE/azvucrHWznh93jSKBLLNUfw55USGrKsvl8kYmqmlQOEInhyHOOb7+9a8LKdhxHNRqNfR6PbRaLXz3d3+36AF11iS6bVVHTueSpclgMMBwOMR4PBbqE9d14XkeDg4OUK/XheQLQJBuFuhIkiRwWmQJnFQ3pEqKI1mdO77O61ItXwfV5lw3EUpqI3VegdzUdWaISWaLMnTmiHEmimnXSdffFrsgsyyWOOvOU+eH0l43qQ2bXC+uLM7cNK8ygs4JTyZv9XxZVSlrAuSoqnJ9k+DiCkQCrusKHbHjOKjX6wiCAKVSCd/2bd+Ghw8fbixBkh7c9310u10hjc/nc9i2jU6ng3a7jWazKYiUsI4k0ko7FAyMIkHSszcaDTFJWqlUUhFEUh1dZ6Z2fjpHJV253DnIk6w6kiUkjZh0I6i4UVWcm7xaX75vmrXa1rg6SWVZjsvvJG39XaJIo7C0baXvVDUKkFWb8jXL5fKKilImajk6ZxbsKtVl4QidnIMePnyI3/u930MYhiLRcxiGaDabODw8zPyCicSn0ykGg4HQi0dRhEqlgqOjI7RaLaFG2cUPQioRUh8ByxFAp9MRqhu5vTKBxknecSOROH28CnniVC5LUk/RopK4bJOuOjbpPDRlE0l1HVeW1Cms6zDiyN+geNARtrwtg0a6slGAaiBQFKRiJcbYJwD8IoAygM9yzn8+pt6/DODXAfwznPMv59ZKCe12G67rwvf9FTt0shRptVqZ4gwHQYDJZCJ0y77vYzabYbFYoNlsCp2467o7IXFZJz+ZTLBYLIOMkcVLqVTCfD4XtvayI1ESdCQlD/2SpN11y01jXQeVtK36E2RBlne1ybs02Bykx1ZJO84El4iZJG2VuO8L1jIUY6wM4DMAfhjACYAvMcbe4px/VanXAPBTAH5nFw0l0A9wdnYmrDVIb763t5cq6A3nHJPJRERtpElNcj9vNptotVrwPA/VanXrPx/n/FqcFXJ+Go1GYkKVgoRRLHZSIcnWKknmlPdVwiR1Sh5/vLi5lKT5lLh6qtNZVmzSmcq/rS7MQlzdIn4Paowi2XdClbJJYFFNcGldxOffBGlEzu8F8Dbn/JsAwBj7PIBPAviqUu8/A/ALAP5Cri3UgCw7AAhdehRFYIwlRkxcLBZCEiY7bVkXTNJ4o9HYKAD9YrHQ2n7LEjWNCMguntLmUbhenf25QX6QSS5PyWzdXETWZd1k/SZIoybTla87N8u+ClU1Im+rI1HykiYpWyZtY367RBpCfwTgXWn/BMDH5QqMsY8BeMw5/58ZY7GEzhh7A8AbAPDSSy9lb+0VfN8XH8J8Pke73Ua5XBb6dRWccyENy0PwKIoQhiFqtRoODg7QbDZTq1VIZz+bzcRa/gDp4yOJOwxD+L4Py7Kwv78Pz/NEgg6D4uOmRkVJ1j9JS1w9uVxWS+XZkcikrU6uc74aGlmOKyQvahgLedQbhqG241hXtu7YriA/c97YWinMGCsB+MsA/vS6upzzNwG8CQCvv/76Rl8JxT5hjGE2m2E6nWJvbw/A0kxQfVFhGKLX6wnCpeNBEMCyLEHkadQ0ctJoctkHIGy/aWKFVCMAVlQ7lmUJadxI3wab4DbVaes6CBqdyoHlZB8JWnS+EvLko3x92ftanSyPixgaZ1G1rl7cedSBZHnv8jnqGgCOjo4Eb+WJNIT+HoDH0v6LV2WEBoCPAvgHV419AcBbjLEf28XE6GQyAbB8YRSvpdlsih++2+0iCAI0m02Mx2NcXl4KS5VKpSKkg3q9LtQdcVgsFitWJ/SD2LaNRqORaPtN1jJE5GSpYmBwW9CRMW2nLYsLKSGHP+acr5C0bEFC8zzqfdTYSrpRgaoqk53m4tqbhrR1xLvJu113rnx8VyPzNIT+JQCvMcY+hCWRfwrAT9BBznkfwAHtM8b+AYD/cBdkDgCe56FSqWA4HIpcoo7jYD6fo9PpwLIsjEYjvP/++5hOp7BtG+12G7VaTfTo1WoVe3t7WvUK51zYoJP9d6VSged5wvIkqacOwxCDwQBhGKJSqRgifw4RJ8Um7ac9tul21varZC37GhBIrUikLZO3bJudZp2l7jbn6Nb3CWsJnXM+Z4x9GsAXsTRb/Bzn/CuMsZ8D8GXO+Vu7bqSMUqmEWq0mvCZd1xU9NTkYkdXIgwcP0Ol0RALpIAjguq6IyihjsVhmPVITSNdqtVQTpIvFMnH1ZDJBqVQS5pUGt4c0uuUk/XKSPjmpLC/EkZe6r3OqSrOtizVCKhLGmFAfyqpEeX2fzP3ygCr1x5UBz8Jt5I1UOnTO+RcAfEEp+5mYuj+wfbOSQcM+4NmHSdLzZDJBEAR48OCBUMVcXl4KpyNVxUKWL5Qo2nEctNvtTK65FBBssVgIVc597P3zhI484yw71i1xk3mbIo0liM5ENE9LkDy/H1nSltUlKsHYti38LWQrkrRtSSKxdSqPtOflfY287pEVZBadNwrnKQos9dMkSZAKhD66fr+/kkP08vJSRFxUVR+j0Qij0QiLxQK1Wg2NRiOT81AURej1eiLZRqvV2plL712ASqA6h540Dj+bEK6OWGmh+PCbLLpr3wR0ZEPvSd6X668jRyJqirypms1y/ky3rWawImlRDbOclVhvC7qOcJOyrOdsenwTs+g0KCT7BEEgPjp5gqTb7aJUKonZ416vhzAMsbe3t0Lm8/kc3W4Xs9kM1Wp1I7vz8XiMwWAAAIVTr6iheYlI4spUKXgddO72sl5VPaYuuvI0iFN/6LZ1Uj1tZy3b5Ng20Om3VYmbSFuN/U//F5Vk6Fx5BELIoovepm7W89Rtg4ISOgWt4pyjXq+jVCphNpuhVCrh4OAApVJJxGNpNpsrtumj0QjD4RCMMezt7WUKEwAsyZBC5pJ65rZ1iZzza0GzdNtpwgaonqdqnlFdNEVg9Y+nk8R1ZaqnZVz9uGPq9i4Q94xxZevqpLkOrWVnG5m8OefXbLSJsE2c++cbhST0MAyFswINIefzudCj+76P0WgEz/OEzpxzjl6vB9/3Ua1W0W63M3/0QRCg2+2Cc74zHZiKuHjLcnkcaeokXzmei851XL0GvWfd9fNAHLmp+2pb4+qlKddtx5XdBOgdk5MabasSt+rrsE0qQoP7iUISuuz6b1mWcPuv1WrgnGMwGMCyLBEPXbVPl0leDfcap/elELqVSgWtVkuYNqqIIzr1j6cSpjxhJUtmOt2zLqcn3UPOyJRmmCoPsdOoQmQJnc5PE9I2rs7zBM65dnJSjktCoyKanJTNAA0M1qFwhD4ej/HBBx8IKxfLshAEATzPQ61Ww2AwQBRFQo++WCxwfn4O3/fhui7m8znOz89jM4vIIMLp9/siRECr1VpJHadOCMZNFMpxXihJhSxty/fUedCpVgdJ0qZK2EQGccGc1kGdsMsTaXWjWTqnPM7Pcg3dsbi4PvLIiQhblrxVqxK6lhpKOG2b48oM7icKR+ikLyczQfL+dBxH2JJ7ngfLsjCZTPD+++9jNBqh1WoJ6Zcisqn5RenDlwn47OwMnPMV5yR1OAysSsek05alarJtJ/d/WedJOlDbtrUqkW2hs0zY9XbW8/K4r7yvuoVnvb9uPw4kcdNaFRZUj0mZuOXomzeFvDqFXV0n7bXyvscuz1fLdzXqKhyhk9v+kydPxMQoqVv6/b4gwydPnqDf72M6neL4+BjtdnslvjgtQRBoo7vNZjN0u10w9mzyVPaMk1UP8lCaFvmPK+s9yaPuJqWmJMnSIB2ow5B13SSBM8ZWEnCrk5XyO8/aidz0fp7n6Oqpo7xN7p/l3LsKY4d+hel0KpIxA0uCp0w+o9EIwDPduOM4OD4+huM4CIJAxB6XQeqMarUqJCkaARwfH+Pw8PCabTnnzwJ1kQklAPHHpuxCRN4GxQPnq8GmaCFQ506jLTNJeXeRhfi37TzSnm9S0F2BJoxKpRKazSbm8zkajQa63S76/T5eeOEFVKtVnJ6egnMushBRgghyHoqTlCeTCcbjMWq1GjqdjhgWUcQ32WSSCNx1XRFU3/yhiwdZ3y3rvQnytyOPsgyKgedpXqFwhE45NmezGTzPw9HRESzLwsnJCZrNJlzXxbvvvoswDPHgwQO4rismndZhPB4LT9NOpwNgaVFD4QSApURP11wXqMvg7mEdeZOKjGL4GPI2KBIKR+i2baNWqwndZaVSEe79s9kM3/rWtzAcDvHgwQNYliX+sGQ5Qms1IhyZJZLnKDkmLRaLzIG6DO4GZLUJ6bx15O26riBvYx5YbKTRoWfRs++q7q0G57pLoGBanHOUSiX4vo9+v49erycInJI7y5Ymca7r5XIZQRBgPB4LS5nhcChUO6ROeR6x6YRVnudluR7FMVHJm74V2ZqIyJuuo+rIs+pY8z5n3bFtz931+Wmvkde9igYzKXqF8XiM9957D4PBAIwxvP/+++Cco9Pp4IUXXkAURdjf39fGICfzQTn+xWAwwOnpKXzfh+M4KJVKcF1XmCj6vi8ke5LqVccd1VRvl6Z+m9bbZP8ug1KPyU461H419ySZgwKrcyGbIknNltWUbZtr5nHdNCrDbe6R5niWennVyVJvF9c0wbmuQBIWxR93HAcHBweo1+t4+vSpNisKQd2nzmGxWGBvb09YLJD3JlmwxDkhyXbsaoJnXaLnXW1v4kSj28+rzi5sg2XJm0I/yNK2vI6zINjGbtjAoAgoHKFzvoxZ7rougiBArVbD48ePcXBwgH6/j/39/ZXeT0dojDF0u12cnp7C8zw8evRIJL2Q66hr8uokm3U1rooMnROBzmVfXQyJYEV1QlI4oVKpCGsTIm/zzgwMligcodOffTKZYDabod1u49VXX0UURWi32zg4OEg8n0IBnJ+f4+DgAB/+8IdzswmV48LIRE9rSlSdJuJhUnTDNHFX5OUug3O+Qt70joBn5oK1Wm1F721gYKBH4Qi90Wig2WzC931hA+55Hs7Pz9dOMpAr//n5OVqtFl5++eVcTdLkkLProIa0TYpFrk7ubgIdyceNSOJGKXFlcXV0+6T7li1PCKTvpqiZ8kiLiP8msemcQp6TgbusW5Rr7rLuNudtch86p9FoZA7dnQaFI3QAGAwGmM/nwgqFJsSSkjEHQYCLiwt0u1202208ePDgVu2LtzFbInKXrXfWLep58r5uO2mtbidBNRsk1RRjbEX3LVudkAeuwc3gNicHd9mGTdqR5ZxtQmrsaqRZOELnfJlWDoCInxEEgciJqMNkMkGv18NkMkGr1cLBwUGh7cllKfouOL3IhB+GIYIgECqUUqkk5jzo90qauFSvedvYVGW1K0LbNZkZFBuFI/TpdIrxeLwSrXA6ncYOX3zfR6/XW8kbmiTJG6SHSuCy6SDFOSECvwsdj4HBfUfhCJ3UDBQCl8p0hB6GIXq9niAT27bRbDZvtL33CYvFQkjeQRCs6L9t24bneYLAzeSlgcHNo3CETu7aAITzDwXJkhFFES4vL4VeVk4ebZAOcQRO+u9GoyEI3AzvDQxuH4UjdCIXIvb5fI5qtXot5vTFxQWApSPSdDrF/v6+GfavAU1I6gjctm00Gg04jnMtxreBgcHdQOEIfTabYTqdolKpgHMukufKGA6HIqzucDiE53nPbTyWJOgmMYFVCdwQuIFBcVA4Qi+VSqjVahiPx9d06cDS8YjimVMcFqM3f4bZbCbMAsMwFJOYtm2jXq8LG3BD4AYGxUPhCJ08CclumdzpCf1+HwBEqrn9/f3nmpwoJg0RODkmWZYlJjEdx3mu35GBwX1B4QidJEmSvmV7csom5LqukNKfN1ULTWQGQYDpdCoceUg15TiOiCppYGBwv1A4Qvc8DwcHBzg9PRWJKoClPrjf76NSqSCKIjDG0Gq1brm1NwMicJLCgaUe3HEcoUbZVQ5DAwODu4PC/cspyBVlKyIJ3fd9RFGEZrOJwWCARqNxb6VQiulNJC6rUcxEpoHB84tCEnoYhmCMrahcJpMJKpWKCAOwi2wgtwlZjULmhKVSyahRDAwMBApN6JVKRahYwjBEtVrFdDpFs9ksPLktFgtB4LIUTt6uJIUbGBgYEApH6I1GA/V6XejQLcvCZDIBsLSAKbJ0Pp/PRXo00oUbKdzAwCAtCkfowDIULknotm3j8vJSZBSizENFADn2EImTRYqsC4+LIGlgYGCgIhWhM8Y+AeAXAZQBfJZz/vPK8T8H4N8GMAdwBuDf4pz/Uc5tFZB16JTs+aodcF13V7fNBZxzQeCkSpEtUqrVqglRYGBgsBHWEjpjrAzgMwB+GMAJgC8xxt7inH9VqvZ7AF7nnE8YY/8ugP8cwL+2iwZTlneKL0LZ2xljqNVqd1I6J6sU3/eFdyapUkidchfbbWBgUCykkdC/F8DbnPNvAgBj7PMAPglAEDrn/Del+r8N4E/m2UgZ8/kcURQJL1Hf98EYiw2he1uIogi+76/ow8vlMjzPQ7VaNaoUAwOD3JGG0B8BeFfaPwHw8YT6Pwngf9EdYIy9AeANAHjppZdSNnEVMqFT3k1KlHzbXqE0qen7vjAtJH14tVo1VikGBgY7Ra6TooyxPwngdQB/XHecc/4mgDcB4PXXX98ox1gQBIiiSJAjJby4LcuW+XwuJHEicTItrNVqRh9uYGBwY0hD6O8BeCztv3hVtgLG2A8B+EsA/jjnfGcZfikmCZnvzWYz2LZ9o+qWOBJvtVpmUtPAwODWkIbQvwTgNcbYh7Ak8k8B+Am5AmPsewD8NwA+wTk/zb2VCizLEmFfoyhCuVzeuU6adOKyOsWQuIGBwV3CWkLnnM8ZY58G8EUszRY/xzn/CmPs5wB8mXP+FoD/AkAdwN+6stZ4h3P+Y7to8Gw2w3w+h23b4JwjiqKdSedknTKZTMTEpmVZRp1iYGBwJ5FKh845/wKALyhlPyNt/1DO7YqFbduoVquIokgkuMiT0CkN22QyQRAE4JyjUqmg0WigVquZqIUGBgZ3FoVjJzkgFyW6yMN6JAxDoVIhyxnP81Cr1Yx1ioGBQSFQOEJ3HAe1Wk1kLlLziWbBYrHAZDLBZDLBfD4HYwzVahWu6966CaSBgYFBVhSO0Cn1HAARzyUrgiDAeDwWXqa2baPdbt9ZT1MDAwODNCgcoUdRhCiKwDlfyVi0DiSNj8dj4ZhUr9fhuq7RixsYGNwLFI7J5vO5CMZlWdZaMo6iCKPRCJPJBJxz4fRTrVaNNG5gYHCvUEhCn81mgpzjCN33fXS7XYxGIywWC6F7B4DxeIzxeCzqMsbEou6ri6zuSVoMDJ5nkJ/Ipsdvsk6aa2Stu66eZVk7MXsuHKGTHTpFW5QJfT6fYzQa4ezsDMPhUITTbTabqFQqIuaLDM65CB9A2/KyKdTOQd6W99Os48rS7K8r36buNu8ny/nb/iGTjm967ibX3LQduzi27fFtf/vnHa1WayfhSgpH6BSQi5JbEPkMh0Ocnp5iOBzCsiw8fPgQnU5n62TJ9OHGEb98TNcZxJXJa7qmXKarp24bpEfSN7CLY0nHs14zrjPf1f3u0vGbusZN19uVU2LhCL1Wq6HRaGA6nQrpfDAY4MmTJ5jP53jw4IEg8jyg/pnukneoSu6bSIbb1AWySf7bXGPbOvetI7xvz3OXcBPvdldq2cIROuccs9kMlUoFlmWh3+/j/Pwci8UCDx8+RLvdzuUe5IlK0nOcSiZO6k7aTirLcnzTuru8xk1c08Cg6DAqlyvQpKjneQjDEMPhELPZDIeHh2i1WqmvQx0DWc3QQiSehCRduE7vrTumq5N0vyzYtvff5aTuXZwwvottyhP3/fluC9u8110FEywkoXPOYVkWptMpRqORULOse8GUkDkMQ2EpAzxzUKpUKiiXyyiXyyiVSisLWbiYP4eBgcFdReEInaIdOo6D8XgMy7Kwv78fS7Sz2QyTyQTT6RRRFAFY9o6e5wkrGeNYZGBgcB9QOCbzPA/1el14jNbrdWEbLmM6nWI8HiMIAhGjhRIy6+obGBgYFB2FI3QiY5q0rNfrK8fDMMRgMEAYhiiXy2g2m3Bd15C4gYHBvUfhCB145v5PUjewtOXu9/vwfR/lchntdhuu695ySw0MDAxuDoUj9CiKEIahyFpkWRbCMES328VisUCj0UC9XjeTlwYGBs8dCkfo8/kcYRiiVCrBcRxh6VKpVHBwcGCSURgYGDy3KByhh2GIIAhQqVQwGo1gWRZqtRra7baRyg0MDJ5rFI7QB4MBTk9Psbe3B8dx0Ol0tspaZGBgYHBfUDjTj1KpBNd1cXx8jAcPHhgyNzAwMLhC4Qi9XC7j4OAA7Xbb6MsNDAwMJBSO0B3Hged5KJfLhtANDAwMJBSO0Gu1Glqt1kosdAMDAwODAhK6bdtwHGdn0coMDAwMiorCETq5/Bt1i4GBgcEqCkfos9kMAAyhGxgYGCgwhG5gYGBwT1A4x6JGowHXdc2EqIGBgYGCwknowN1K1GxgYGBwV1BIQjcwMDAwuA5D6AYGBgb3BIbQDQwMDO4JDKEbGBgY3BMYQjcwMDC4J0hF6IyxTzDGvsYYe5sx9tOa4w5j7L+/Ov47jLFXcm+pgYGBgUEi1hI6Y6wM4DMAfgTAdwL4ccbYdyrVfhJAl3P+EQB/BcAv5N1QAwMDA4NkpJHQvxfA25zzb3LOQwCfB/BJpc4nAfzK1favA/hBZjx/DAwMDG4UaTxFHwF4V9o/AfDxuDqc8zljrA9gH8C5XIkx9gaAN652R4yxr23SaAAH6rWfEzyPz/08PjPwfD738/jMQPbnfjnuwI26/nPO3wTw5rbXYYx9mXP+eg5NKhSex+d+Hp8ZeD6f+3l8ZiDf506jcnkPwGNp/8WrMm0dxlgFQAvARR4NNDAwMDBIhzSE/iUArzHGPsQYswF8CsBbSp23APypq+1/BcDf55zz/JppYGBgYLAOa1UuVzrxTwP4IoAygM9xzr/CGPs5AF/mnL8F4K8D+FXG2NsALrEk/V1ia7VNQfE8Pvfz+MzA8/ncz+MzAzk+NzOCtIGBgcH9gPEUNTAwMLgnMIRuYGBgcE9QOEJfF4bgPoAx9pgx9puMsa8yxr7CGPupq/IOY+x/Y4z9o6v13m23NW8wxsqMsd9jjP3dq/0PXYWTePsqvIR9223MG4yxNmPs1xlj/5Ax9oeMsT/2nPzWf/bq+/4DxtivMcaq9+33Zox9jjF2yhj7A6lM+9uyJf7q1bP/PmPsY1nvVyhCTxmG4D5gDuDPc86/E8D3AfgzV8/50wB+g3P+GoDfuNq/b/gpAH8o7f8CgL9yFVaii2WYifuGXwTw9zjn3w7gn8Ty+e/1b80YewTg3wfwOuf8o1gaXHwK9+/3/mUAn1DK4n7bHwHw2tXyBoBfynqzQhE60oUhKDw45x9wzv/vq+0hln/wR1gNsfArAP6lW2ngjsAYexHAvwDgs1f7DMCfwDKcBHA/n7kF4J/D0lIMnPOQc97DPf+tr1ABULvyXXEBfIB79ntzzn8LS8s/GXG/7ScB/E2+xG8DaDPGHmS5X9EIXReG4NEtteVGcBW58nsA/A6AY875B1eHngA4vq127Qj/NYD/CMDian8fQI9zPr/av4+/94cAnAH4G1eqps8yxjzc89+ac/4egP8SwDtYEnkfwO/i/v/eQPxvuzW/FY3QnyswxuoA/kcA/wHnfCAfu3Lcujc2p4yxfxHAKef8d2+7LTeMCoCPAfglzvn3ABhDUa/ct98aAK70xp/EskN7CMDDddXEvUfev23RCD1NGIJ7AcaYhSWZ/3ec8799VfyUhmBX69Pbat8O8P0Afowx9i0sVWl/AkvdcvtqSA7cz9/7BMAJ5/x3rvZ/HUuCv8+/NQD8EIB/zDk/45zPAPxtLL+B+/57A/G/7db8VjRCTxOGoPC40h3/dQB/yDn/y9IhOcTCnwLwP91023YFzvlf5Jy/yDl/Bcvf9e9zzv91AL+JZTgJ4J49MwBwzp8AeJcx9k9cFf0ggK/iHv/WV3gHwPcxxtyr752e+17/3leI+23fAvBvXFm7fB+AvqSaSQfOeaEWAD8K4OsAvgHgL912e3b0jP8slsOw3wfw/1wtP4qlTvk3APwjAP87gM5tt3VHz/8DAP7u1farAP4vAG8D+FsAnNtu3w6e958C8OWr3/vvANh7Hn5rAP8pgH8I4A8A/CoA57793gB+Dcs5ghmWo7GfjPttATAsrfi+AeD/w9ICKNP9jOu/gYGBwT1B0VQuBgYGBgYxMIRuYGBgcE9gCN3AwMDgnsAQuoGBgcE9gSF0AwMDg3sCQ+gGBgYG9wSG0A0MDAzuCf5/K5YWZNAD/YgAAAAASUVORK5CYII=\n", 63 | "text/plain": [ 64 | "
" 65 | ] 66 | }, 67 | "metadata": { 68 | "needs_background": "light" 69 | }, 70 | "output_type": "display_data" 71 | } 72 | ], 73 | "source": [ 74 | "for _ in range(50):\n", 75 | " prior = lcpfn.sample_from_prior(np.random)\n", 76 | " curve, _ = prior()\n", 77 | " plt.plot(curve, \"black\", alpha=0.1)\n", 78 | "plt.ylim(0, 1)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "id": "685552b1-94b2-428a-bca2-66690f3090d2", 84 | "metadata": {}, 85 | "source": [ 86 | "## Train a PFN model with the previous learning curve prior" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 5, 92 | "id": "48b1c249-d9cc-4543-bbd3-c0cf410263c3", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "get_batch_func = lcpfn.create_get_batch_func(prior=lcpfn.sample_from_prior)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 6, 102 | "id": "386dcfc8-d92a-4b78-8881-b68da4b1e9c5", 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "name": "stdout", 107 | "output_type": "stream", 108 | "text": [ 109 | "torch.Size([100, 100, 1]) torch.Size([100, 100]) torch.Size([100, 100])\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "# example of a batch\n", 115 | "\n", 116 | "X, Y, Y_noisy = get_batch_func(batch_size=100, seq_len=100, num_features=1)\n", 117 | "print(X.shape, Y.shape, Y_noisy.shape)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 7, 123 | "id": "091def9b-dd4c-4e7d-8550-5dc016bc3893", 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "text/plain": [ 129 | "" 130 | ] 131 | }, 132 | "execution_count": 7, 133 | "metadata": {}, 134 | "output_type": "execute_result" 135 | } 136 | ], 137 | "source": [ 138 | "# Main function to train a PFN model\n", 139 | "\n", 140 | "lcpfn.train_lcpfn" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 9, 146 | "id": "c18b2d21-1819-425a-aeba-9b3871f5900a", 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "name": "stdout", 151 | "output_type": "stream", 152 | "text": [ 153 | "Using 1000000 y evals to estimate 1000 buckets. Cut off the last 0 ys.\n", 154 | "Using cpu:0 device\n", 155 | "init dist\n", 156 | "Not using distributed\n", 157 | "DataLoader.__dict__ {'num_steps': 100, 'get_batch_kwargs': {'batch_size': 10, 'eval_pos_seq_len_sampler': .eval_pos_seq_len_sampler at 0x174f16e60>, 'seq_len_maximum': 100, 'device': 'cpu:0', 'num_features': 1, 'hyperparameters': {}}, 'num_features': 1}\n", 158 | "Style definition: None\n", 159 | "Using a Transformer with 2.23 M parameters\n", 160 | "-----------------------------------------------------------------------------------------\n", 161 | "| end of epoch 1 | time: 9.84s | mean loss -0.58 | pos losses nan, 0.03,-0.40,-0.33, nan,-0.72, nan,-0.27,-0.37, nan,-0.12,-0.61,-0.57,-0.73,-0.40,-0.58,-0.65,-1.00,-0.10, nan,-0.72,-0.27, nan,-0.30,-0.56, nan,-0.58, nan,-0.41, nan, nan,-0.42,-0.59,-0.54,-0.40, nan,-0.41,-0.93,-0.85, nan, nan,-0.82,-0.56, nan,-0.81,-0.79, nan,-0.91, nan,-1.08, nan,-0.66, nan,-0.65, nan, nan, nan,-0.73,-0.89,-0.52,-0.56,-0.68,-0.74,-0.72, nan, nan,-0.58,-0.25,-0.61,-0.74,-0.67, nan, nan, nan,-0.55, nan, nan,-0.73,-0.81,-0.52,-1.05,-0.49,-0.80,-0.71, nan, nan, nan,-0.86,-0.64, 0.15,-0.71, nan, nan, nan,-0.74,-0.44, nan,-0.46, nan, nan, lr 0.001 data time 0.00 step time 0.10 forward time 0.05\n", 162 | "-----------------------------------------------------------------------------------------\n", 163 | "-----------------------------------------------------------------------------------------\n", 164 | "| end of epoch 2 | time: 9.95s | mean loss -1.09 | pos losses nan, nan, nan,-0.11,-0.31,-0.72, nan, nan, nan,-0.89,-0.74,-0.75, nan, nan, nan,-0.87,-0.87,-1.29, nan,-1.06, nan,-1.25,-1.00, nan,-1.20,-0.85,-1.18,-0.99,-1.31,-1.03,-1.26,-0.81, nan,-1.03,-1.10, nan,-1.32, nan,-1.14, nan,-1.09,-1.21, nan, nan, nan,-1.40,-1.22, nan,-1.19,-1.16, nan, nan,-1.19,-1.09, nan,-1.30,-1.38,-1.40,-1.11,-1.19, nan, nan,-1.28, nan,-1.23, nan, nan, nan,-1.25, nan, nan,-0.92, nan,-0.77,-1.15, nan,-1.04,-1.12,-1.06,-1.15, nan,-1.00, nan,-0.57,-1.20, nan,-0.98,-0.81, nan,-1.04,-1.03,-1.43,-1.08, nan, nan,-1.40,-1.29,-1.47,-1.47,-1.42, lr 0.00075 data time 0.00 step time 0.09 forward time 0.05\n", 165 | "-----------------------------------------------------------------------------------------\n", 166 | "-----------------------------------------------------------------------------------------\n", 167 | "| end of epoch 3 | time: 10.00s | mean loss -1.30 | pos losses nan,-0.08, nan,-0.84,-0.85,-0.96, nan,-1.02, nan,-0.82, 0.05, nan, nan,-1.07,-1.41, nan,-1.35,-1.41,-1.28,-1.44, nan, nan, nan, nan,-1.02, nan,-1.28, nan,-1.49,-1.30,-1.42,-1.21,-1.44,-1.41, nan,-1.36, nan, nan, nan, nan, nan,-1.55,-1.72,-1.27,-1.70,-1.03, nan,-1.57,-1.45,-1.50,-1.49, nan, nan, nan,-1.40,-0.89,-1.54,-1.47,-1.35,-1.28,-1.42, nan,-1.48, nan, nan,-1.60,-1.42,-1.44, nan,-1.08,-1.59, nan,-1.46,-1.03,-0.96,-1.24,-1.36,-0.81,-1.07, nan, nan, nan, nan, nan,-1.65, nan,-1.58,-1.42, nan,-1.53,-1.61,-1.71,-1.47, nan, nan,-1.27,-1.50,-1.56, nan,-1.08, lr 0.0002500000000000001 data time 0.00 step time 0.10 forward time 0.05\n", 168 | "-----------------------------------------------------------------------------------------\n" 169 | ] 170 | } 171 | ], 172 | "source": [ 173 | "# train a small model for 3 epochs\n", 174 | "\n", 175 | "result = lcpfn.train_lcpfn(\n", 176 | " get_batch_func=get_batch_func,\n", 177 | " seq_len=100,\n", 178 | " emsize=256,\n", 179 | " nlayers=3,\n", 180 | " num_borders=1000,\n", 181 | " lr=0.001,\n", 182 | " batch_size=10,\n", 183 | " epochs=3,\n", 184 | ")" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 10, 190 | "id": "4584f7a0-b5d9-4623-a7bd-987da27558aa", 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "name": "stdout", 195 | "output_type": "stream", 196 | "text": [ 197 | "TransformerModel(\n", 198 | " (transformer_encoder): TransformerEncoder(\n", 199 | " (layers): ModuleList(\n", 200 | " (0): TransformerEncoderLayer(\n", 201 | " (self_attn): MultiheadAttention(\n", 202 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 203 | " )\n", 204 | " (linear1): Linear(in_features=256, out_features=512, bias=True)\n", 205 | " (dropout): Dropout(p=0.2, inplace=False)\n", 206 | " (linear2): Linear(in_features=512, out_features=256, bias=True)\n", 207 | " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 208 | " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 209 | " (dropout1): Dropout(p=0.2, inplace=False)\n", 210 | " (dropout2): Dropout(p=0.2, inplace=False)\n", 211 | " )\n", 212 | " (1): TransformerEncoderLayer(\n", 213 | " (self_attn): MultiheadAttention(\n", 214 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 215 | " )\n", 216 | " (linear1): Linear(in_features=256, out_features=512, bias=True)\n", 217 | " (dropout): Dropout(p=0.2, inplace=False)\n", 218 | " (linear2): Linear(in_features=512, out_features=256, bias=True)\n", 219 | " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 220 | " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 221 | " (dropout1): Dropout(p=0.2, inplace=False)\n", 222 | " (dropout2): Dropout(p=0.2, inplace=False)\n", 223 | " )\n", 224 | " (2): TransformerEncoderLayer(\n", 225 | " (self_attn): MultiheadAttention(\n", 226 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 227 | " )\n", 228 | " (linear1): Linear(in_features=256, out_features=512, bias=True)\n", 229 | " (dropout): Dropout(p=0.2, inplace=False)\n", 230 | " (linear2): Linear(in_features=512, out_features=256, bias=True)\n", 231 | " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 232 | " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 233 | " (dropout1): Dropout(p=0.2, inplace=False)\n", 234 | " (dropout2): Dropout(p=0.2, inplace=False)\n", 235 | " )\n", 236 | " )\n", 237 | " )\n", 238 | " (encoder): Sequential(\n", 239 | " (0): Normalize()\n", 240 | " (1): Normalize()\n", 241 | " (2): Linear(in_features=1, out_features=256, bias=True)\n", 242 | " )\n", 243 | " (y_encoder): Sequential(\n", 244 | " (0): Normalize()\n", 245 | " (1): Linear(in_features=1, out_features=256, bias=True)\n", 246 | " )\n", 247 | " (pos_encoder): NoPositionalEncoding()\n", 248 | " (decoder): Sequential(\n", 249 | " (0): Linear(in_features=256, out_features=512, bias=True)\n", 250 | " (1): GELU()\n", 251 | " (2): Linear(in_features=512, out_features=1000, bias=True)\n", 252 | " )\n", 253 | " (criterion): FullSupportBarDistribution()\n", 254 | ")\n" 255 | ] 256 | } 257 | ], 258 | "source": [ 259 | "# Get the trained model\n", 260 | "\n", 261 | "model = result[2]\n", 262 | "print(model)" 263 | ] 264 | } 265 | ], 266 | "metadata": { 267 | "kernelspec": { 268 | "display_name": "lcpfn", 269 | "language": "python", 270 | "name": "lcpfn" 271 | }, 272 | "language_info": { 273 | "codemirror_mode": { 274 | "name": "ipython", 275 | "version": 3 276 | }, 277 | "file_extension": ".py", 278 | "mimetype": "text/x-python", 279 | "name": "python", 280 | "nbconvert_exporter": "python", 281 | "pygments_lexer": "ipython3", 282 | "version": "3.10.11" 283 | } 284 | }, 285 | "nbformat": 4, 286 | "nbformat_minor": 5 287 | } 288 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "lcpfn" 3 | description = "In-context Bayesian Learning Curve Extrapolation" 4 | readme = {file = "readme.md", content-type = 'text/markdown'} 5 | license = {file = "LICENSE"} 6 | authors = [ 7 | {name = "Steven Adriaensen", email= "adriaens@cs.uni-freiburg.de"}, 8 | {name = "Herilalaina Rakotoarison", email = "rakotoah@cs.uni-freiburg.de"}, 9 | {name = "Samuel Müller", email = "muellesa@cs.uni-freiburg.de"}, 10 | {name = "Frank Hutter", email = "fh@cs.uni-freiburg.de"}, 11 | ] 12 | requires-python = ">=3.9,<3.12" 13 | dependencies = [ 14 | "torch<=1.11.0", 15 | "numpy>=1.21.2,<2", 16 | "requests>=2.23.0" 17 | ] 18 | dynamic = ["version"] 19 | classifiers = [ 20 | 'Intended Audience :: Science/Research', 21 | 'License :: OSI Approved :: MIT License', 22 | 'Programming Language :: Python', 23 | 'Topic :: Software Development', 24 | 'Topic :: Scientific/Engineering', 25 | 'Operating System :: Unix', 26 | 'Operating System :: MacOS', 27 | 'Programming Language :: Python :: 3', 28 | 'Programming Language :: Python :: 3.9', 29 | 'Programming Language :: Python :: 3.10', 30 | 'Programming Language :: Python :: 3.11', 31 | ] 32 | 33 | [project.urls] 34 | homepage = "https://github.com/automl/lcpfn" 35 | repository = "https://github.com/automl/lcpfn" 36 | bugtracker = "https://github.com/automl/lcpfn/issues" 37 | 38 | [tool.setuptools.packages.find] 39 | include = ["lcpfn*"] 40 | 41 | [tool.setuptools.dynamic] 42 | version = {attr = "lcpfn.version.__version__"} -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 2 | # Efficient Bayesian Learning Curve Extrapolation using Prior-Data Fitted Networks 3 | 4 | This repository offers an implementation of [LC-PFN](https://openreview.net/pdf?id=xgTV6rmH6n), a method designed for efficient Bayesian learning curve extrapolation. 5 | 6 | **LC-PFN in action on [Google colab](https://colab.research.google.com/drive/1JA2t91xgqZVfjZya41oW5vVQktv_YXhE?usp=sharing) and [HuggingFace](https://huggingface.co/spaces/herilalaina/lcpfn)** 7 | 8 | Installation using pip: 9 | 10 | ```bash 11 | pip install -U lcpfn 12 | ``` 13 | 14 | > **Update**: there is an inconsistency between the code and the paper regarding the definition of the noise prior. The correct definition is the one used in the code, where $\log(\sigma)$ is defined as $\mathcal{N}(-4, 1)$. 15 | 16 | ### Usage 17 | 18 | Try out the `notebooks` (require ``matplotlib``) for training and inference examples. 19 | 20 | **NOTE:** Our model supports only increasing curves with values in $[0,1]$. If needed, please consider normalizing your curves to meet these constraints. See an example in ``notebooks/curve_normalization.ipynb``. 21 | 22 | 23 | ### Reference 24 | 25 | ``` 26 | @inproceedings{ 27 | adriaensens2023lcpfn, 28 | title={Efficient Bayesian Learning Curve Extrapolation using Prior-Data Fitted Networks}, 29 | author={Adriaensen, Steven and Rakotoarison, Herilalaina and Müller, Samuel and Hutter, Frank}, 30 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 31 | year={2023}, 32 | url={https://openreview.net/forum?id=xgTV6rmH6n} 33 | } 34 | ``` 35 | 36 | 37 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from lcpfn.model import LCPFN 4 | 5 | 6 | class TestLCPFN(unittest.TestCase): 7 | def setUp(self): 8 | self.model = LCPFN() 9 | 10 | def test_init(self): 11 | self.assertIsInstance(self.model, LCPFN) 12 | 13 | def test_predict_mean(self): 14 | x_train = torch.arange(1, 11).unsqueeze(-1) 15 | y_train = torch.rand(10).unsqueeze(-1) 16 | x_test = torch.arange(11, 16).unsqueeze(-1) 17 | mean = self.model.predict_mean(x_train, y_train, x_test) 18 | self.assertIsInstance(mean, torch.Tensor) 19 | 20 | def test_predict_quantiles(self): 21 | x_train = torch.arange(1, 11).unsqueeze(-1) 22 | y_train = torch.rand(10).unsqueeze(-1) 23 | x_test = torch.arange(11, 16).unsqueeze(-1) 24 | qs = [0.1, 0.5, 0.9] 25 | quantiles = self.model.predict_quantiles(x_train, y_train, x_test, qs) 26 | self.assertTrue(torch.all(quantiles[0] < quantiles[1])) 27 | self.assertTrue(torch.all(quantiles[1] < quantiles[2])) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | --------------------------------------------------------------------------------