├── requirements.txt ├── CODEOWNERS ├── CONTRIBUTING-ARCHIVED.md ├── .copyright.tmpl ├── online_conformal ├── __init__.py ├── split_conformal.py ├── nex_conformal.py ├── dataset.py ├── visualize.py ├── ogd.py ├── utils.py ├── model_sigma.py ├── saocp.py ├── enbpi.py ├── base.py └── faci.py ├── .gitignore ├── MANIFEST.in ├── SECURITY.md ├── .pre-commit-config.yaml ├── Dockerfile ├── .github └── workflows │ └── publish.yml ├── setup.py ├── AI_ETHICS.md ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md ├── README.md ├── make_table.py ├── LICENSE ├── cv_utils.py ├── time_series.py └── vision.py /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | jinja2>=3.0 3 | matplotlib 4 | numpy 5 | pandas 6 | salesforce-merlion 7 | scipy 8 | tqdm -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related other information. Please be careful while editing. 2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /CONTRIBUTING-ARCHIVED.md: -------------------------------------------------------------------------------- 1 | # ARCHIVED 2 | 3 | This project is `Archived` and is no longer actively maintained; 4 | We are not accepting contributions or Pull Requests. 5 | 6 | -------------------------------------------------------------------------------- /.copyright.tmpl: -------------------------------------------------------------------------------- 1 | Copyright (c) ${years} ${owner} 2 | All rights reserved. 3 | SPDX-License-Identifier: Apache-2.0 4 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 5 | -------------------------------------------------------------------------------- /online_conformal/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # package 2 | __pycache__ 3 | *.egg-info 4 | dist 5 | # experiments 6 | data 7 | results 8 | cv_models 9 | cv_logits 10 | # IDE/system 11 | .idea 12 | *.swp 13 | .DS_Store 14 | sandbox 15 | .vscode 16 | Icon? 
17 | # build files 18 | .ipynb_checkpoints 19 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CODE_OF_CONDUCT.md LICENSE SECURITY.md requirements.txt 2 | global-exclude *.py[cod] 3 | exclude *.ipynb 4 | exclude Dockerfile .pre-commit-config.yaml .copyright.tmpl .gitignore 5 | recursive-exclude cv_logits * 6 | recursive-exclude cv_models * 7 | recursive-exclude figures * 8 | recursive-exclude results * 9 | recursive-exclude .github * 10 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered. This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as possible, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: '22.10.0' 4 | hooks: 5 | - id: black 6 | args: ["--line-length", "120"] 7 | - repo: https://github.com/johann-petrak/licenseheaders.git 8 | rev: 'v0.8.8' 9 | hooks: 10 | - id: licenseheaders 11 | args: ["-t", ".copyright.tmpl", "-cy", "-o", "salesforce.com, inc.", 12 | "-E", ".py", "-x", "docs/source/conf.py", "-f"] 13 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | # 7 | FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime 8 | 9 | COPY requirements.txt requirements.txt 10 | COPY README.md README.md 11 | COPY online_conformal online_conformal 12 | COPY setup.py setup.py 13 | RUN python3 -m pip install -e . 14 | -------------------------------------------------------------------------------- /online_conformal/split_conformal.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | import pandas as pd 8 | 9 | from online_conformal.base import BasePredictor 10 | 11 | 12 | class SplitConformal(BasePredictor): 13 | """ 14 | Split Conformal Prediction, adapted to time series.
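    The update() method simply appends the new residuals y - yhat for each forecast horizon; the inherited
    predict(horizon) then returns a symmetric interval whose radius is the coverage-level quantile of the
    absolute residuals collected so far (see BasePredictor.predict).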
15 | """ 16 | 17 | def update(self, ground_truth: pd.Series, forecast: pd.Series, horizon): 18 | self.residuals.extend(horizon, (ground_truth - forecast).values.tolist()) 19 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to pip 2 | 3 | on: 4 | release: 5 | types: [ published ] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Set up Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: '3.10' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip setuptools build wheel 19 | - name: Build package 20 | run: | 21 | python -m build 22 | - name: Publish package 23 | uses: pypa/gh-action-pypi-publish@release/v1 24 | with: 25 | user: __token__ 26 | password: ${{ secrets.PYPI_API_TOKEN }} 27 | verbose: true 28 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | from setuptools import find_packages, setup 8 | 9 | setup( 10 | name="online_conformal", 11 | version="1.0.2", 12 | author="Aadyot Bhatnagar", 13 | author_email="abhatnagar@salesforce.com", 14 | description="A library for time series conformal prediction", 15 | long_description=open("README.md", "r", encoding="utf-8").read(), 16 | long_description_content_type="text/markdown", 17 | license="Apache 2.0", 18 | url="https://github.com/salesforce/online_conformal", 19 | packages=find_packages(include=["online_conformal*"]), 20 | package_dir={"online_conformal": "online_conformal"}, 21 | install_requires=open("requirements.txt").read().split("\n"), 22 | ) 23 | -------------------------------------------------------------------------------- /AI_ETHICS.md: -------------------------------------------------------------------------------- 1 | ## Ethics disclaimer for Salesforce AI models, data, code 2 | 3 | This release is for research purposes only in support of an academic 4 | paper. Our models, datasets, and code are not specifically designed or 5 | evaluated for all downstream purposes. We strongly recommend users 6 | evaluate and address potential concerns related to accuracy, safety, and 7 | fairness before deploying this model. We encourage users to consider the 8 | common limitations of AI, comply with applicable laws, and leverage best 9 | practices when selecting use cases, particularly for high-risk scenarios 10 | where errors or misuse could significantly impact people’s lives, rights, 11 | or safety. For further guidance on use cases, refer to our standard 12 | [AUP](https://www.salesforce.com/content/dam/web/en_us/www/documents/legal/Agreements/policies/ExternalFacing_Services_Policy.pdf) 13 | and [AI AUP](https://www.salesforce.com/content/dam/web/en_us/www/documents/legal/Agreements/policies/ai-acceptable-use-policy.pdf). 14 | -------------------------------------------------------------------------------- /online_conformal/nex_conformal.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 
4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | import numpy as np 8 | 9 | from online_conformal.enbpi import EnbMixIn 10 | from online_conformal.split_conformal import SplitConformal 11 | from online_conformal.utils import quantile 12 | 13 | 14 | class NExConformal(SplitConformal): 15 | """ 16 | Non-Exchangeable Split Conformal Prediction, one of the algorithms described in Barber et al., 2022, 17 | "Conformal Prediction Beyond Exchangeability." https://arxiv.org/abs/2202.13415. 18 | """ 19 | 20 | @property 21 | def gamma(self): 22 | return self.coverage + 3 * (1 - self.coverage) / 4 23 | 24 | def quantile(self, arr, q): 25 | weights = np.exp(np.log(self.gamma) * np.arange(len(arr) - 1, -1, -1)) 26 | return quantile(arr, q, weights=weights) 27 | 28 | 29 | class EnbNEx(EnbMixIn, NExConformal): 30 | pass 31 | -------------------------------------------------------------------------------- /online_conformal/dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | from datasets import load_dataset 8 | import numpy as np 9 | import pandas as pd 10 | from ts_datasets.forecast import M4 as _M4 11 | 12 | 13 | class MonashTSF: 14 | def __init__(self, dataset, freq, horizon): 15 | self.freq = freq 16 | self.horizon = horizon 17 | self.data = load_dataset("monash_tsf", dataset)["test"] 18 | 19 | def __getitem__(self, i): 20 | data = self.data[i] 21 | arr = np.asarray(data["target"]).T 22 | t0 = data["start"][0] if isinstance(data["start"], list) else data["start"] 23 | ts = pd.DataFrame(arr.reshape(len(arr), -1), index=pd.date_range(start=t0, periods=len(arr), freq=self.freq)) 24 | ts = (ts - ts.min(axis=0)) / (ts.max(axis=0) - ts.min(axis=0)) 25 | n_test = min(120, int(len(ts) / 5)) 26 | return dict(train_data=ts.iloc[:-n_test], test_data=ts.iloc[-n_test:], horizon=self.horizon, calib_frac=0.2) 27 | 28 | def __len__(self): 29 | return len(self.data) 30 | 31 | def __iter__(self): 32 | for i in range(len(self)): 33 | yield self[i] 34 | 35 | 36 | class M4: 37 | def __init__(self, subset): 38 | self.dataset = _M4(subset) 39 | 40 | def __len__(self): 41 | return len(self.dataset) 42 | 43 | def __iter__(self): 44 | for i in range(len(self)): 45 | yield self[i] 46 | 47 | def __getitem__(self, i): 48 | ts, md = self.dataset[i] 49 | ts = (ts - ts.min(axis=0)) / (ts.max(axis=0) - ts.min(axis=0)) 50 | calib_frac = 0.2 51 | if self.dataset.subset == "Weekly": 52 | horizon = 26 if len(ts) > 200 else 13 53 | n_test = 120 if len(ts) > 400 else 60 if len(ts) > 200 else 26 54 | elif self.dataset.subset == "Daily": 55 | horizon = 28 if len(ts) > 200 else 14 56 | n_test = 120 if len(ts) > 400 else 60 if len(ts) > 200 else 28 57 | calib_frac = 0.05 if len(ts) > 4000 else 0.1 if len(ts) > 2000 else 0.2 58 | elif self.dataset.subset == "Hourly": 59 | horizon = 24 60 | n_test = 120 61 | else: 62 | n_test = (~md.trainval).sum() 63 | horizon = n_test // 2 64 | return dict(train_data=ts.iloc[:-n_test], test_data=ts.iloc[-n_test:], calib_frac=calib_frac, horizon=horizon) 65 | -------------------------------------------------------------------------------- /online_conformal/visualize.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | import math 8 | import matplotlib.pyplot as plt 9 | import pandas as pd 10 | from online_conformal.utils import coverage, interval_miscoverage 11 | 12 | 13 | def plot(ground_truth: pd.DataFrame, pred: pd.DataFrame, lb: pd.DataFrame, ub: pd.DataFrame, title=None, ax=None): 14 | pred_color = "#0072B2" 15 | if ax is None: 16 | fig = plt.figure(facecolor="w", figsize=(10, 6)) 17 | ax = fig.add_subplot(111) 18 | else: 19 | fig = ax.get_figure() 20 | ax.plot(pred.index, pred.values.flatten(), c=pred_color, ls="-", zorder=0) 21 | ax.plot(ground_truth.index, ground_truth.values.flatten(), c="k", alpha=0.8, lw=1, zorder=1) 22 | ax.fill_between(pred.index, lb.values.flatten(), ub.values.flatten(), color=pred_color, alpha=0.2, zorder=2) 23 | ax.set_title(title) 24 | return fig, ax 25 | 26 | 27 | def plot_simulated_forecast(results, horizon=None): 28 | ncols = 3 29 | nrows = math.ceil(len(results) / ncols) 30 | fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 6, nrows * 5 + 1), facecolor="w") 31 | axs = axs.reshape(nrows, ncols) 32 | delta = 0 33 | target_cov = None 34 | for i, (method, result) in enumerate(results.items()): 35 | if method in ["ModelSigma"]: 36 | continue 37 | i = i + delta 38 | ground_truth = result["ground_truth"] 39 | forecast, lb, ub = result["forecast"] 40 | constructed = pd.DataFrame(0, index=ground_truth.index, columns=["yhat", "lb", "ub"]) 41 | for k in range(len(ground_truth)): 42 | h = horizon if horizon else "full" 43 | constructed.iloc[k] = [forecast[h].iloc[k], lb[h].iloc[k], ub[h].iloc[k]] 44 | target_cov = result["target_cov"] if target_cov is None else target_cov 45 | forecast, lb, ub = constructed["yhat"], constructed["lb"], constructed["ub"] 46 | result = (ground_truth, forecast, lb, ub) 47 | title = ( 48 | f"{method}: coverage={coverage(*result):.3f}, width={(ub - lb).mean() / 2:.2f}, " 49 | f"int miscov={interval_miscoverage(*result, window=20, cov=target_cov):.2f}" 50 | ) 51 | plot(*result, ax=axs[i // ncols, i % ncols], title=title) 52 | name = f"Horizon = {horizon}" if horizon else "Simulated Forecast" 53 | fig.suptitle(f"{name}, Target Coverage = {target_cov}", fontsize=16) 54 | fig.tight_layout() 55 | plt.show() 56 | -------------------------------------------------------------------------------- /online_conformal/ogd.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | from collections import defaultdict 8 | import numpy as np 9 | import pandas as pd 10 | from typing import Tuple 11 | 12 | from online_conformal.base import BasePredictor 13 | from online_conformal.enbpi import EnbMixIn 14 | from online_conformal.utils import pinball_loss_grad, Residuals 15 | 16 | 17 | class ScaleFreeOGD(BasePredictor): 18 | """ 19 | Scale-free online gradient descent to learn conformal confidence intervals via online quantile regression. We 20 | perform online gradient descent on the pinball loss to learn the relevant quantiles of the residuals. 
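    Concretely, for each horizon the interval radius delta is updated as delta <- max(0, delta - scale / sqrt(3 * G) * g),
    where g is the gradient of the pinball loss between the newest absolute residual and the current radius, and G is the
    running sum of squared gradients (see update()).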
21 | From Orabona & Pal, 2016, "Scale-Free Online Learning." https://arxiv.org/abs/1601.01974. 22 | """ 23 | 24 | def __init__(self, *args, horizon=1, max_scale=None, **kwargs): 25 | self.scale = {} 26 | self.delta = defaultdict(float) 27 | self.grad_norm = defaultdict(float) 28 | if max_scale is None: 29 | self.scale = {} 30 | else: 31 | self.scale = {h + 1: float(max_scale) for h in range(horizon)} 32 | super().__init__(*args, horizon=horizon, **kwargs) 33 | 34 | # Use calibration to initialize learning rate & estimates for deltas 35 | residuals = self.residuals 36 | self.residuals = Residuals(self.horizon) 37 | for h in range(1, self.horizon + 1): 38 | r = residuals.horizon2residuals[h] 39 | if h not in self.scale: 40 | self.scale[h] = 1 if len(r) == 0 else np.max(np.abs(r)) * np.sqrt(3) 41 | self.update(pd.Series(r, dtype=float), pd.Series(np.zeros(len(r))), h) 42 | 43 | def predict(self, horizon) -> Tuple[float, float]: 44 | return -self.delta[horizon], self.delta[horizon] 45 | 46 | def update(self, ground_truth: pd.Series, forecast: pd.Series, horizon): 47 | residuals = np.abs(ground_truth - forecast).values 48 | self.residuals.extend(horizon, residuals.tolist()) 49 | if horizon not in self.scale: 50 | return 51 | for s in residuals: 52 | delta = self.delta[horizon] 53 | grad = pinball_loss_grad(np.abs(s), delta, self.coverage) 54 | self.grad_norm[horizon] += grad**2 55 | if self.grad_norm[horizon] != 0: 56 | self.delta[horizon] = max(0, delta - self.scale[horizon] / np.sqrt(3 * self.grad_norm[horizon]) * grad) 57 | 58 | 59 | class EnbOGD(EnbMixIn, ScaleFreeOGD): 60 | pass 61 | -------------------------------------------------------------------------------- /online_conformal/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 
4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | import numpy as np 8 | import pandas as pd 9 | 10 | 11 | def quantile(arr, q, weights=None): 12 | q = np.clip(q, 0, 1) 13 | if len(arr) == 0: 14 | return np.zeros(len(q)) if hasattr(q, "__len__") else 0 15 | if weights is None: 16 | return np.quantile(arr, q, method="inverted_cdf") 17 | assert len(weights) == len(arr) 18 | idx = np.argsort(arr) 19 | weights = np.cumsum(weights[idx]) 20 | q_idx = np.searchsorted(weights / weights[-1], q) 21 | return np.asarray(arr)[idx[q_idx]] 22 | 23 | 24 | def pinball_loss(y, yhat, q: float): 25 | return np.maximum(q * (y - yhat), (1 - q) * (yhat - y)) 26 | 27 | 28 | def pinball_loss_grad(y, yhat, q: float): 29 | return -q * (y > yhat) + (1 - q) * (y < yhat) 30 | 31 | 32 | def interval_miscoverage(y: pd.Series, yhat: pd.Series, lb: pd.Series, ub: pd.Series, window: int, cov: float): 33 | interval_cov = ((lb <= y) & (y <= ub)).rolling(window).mean().dropna() 34 | return np.abs(interval_cov - cov).values.max() 35 | 36 | 37 | def interval_regret(y: pd.Series, yhat: pd.Series, lb: pd.Series, ub: pd.Series, window: int, cov: float): 38 | resid = np.abs(y - yhat) 39 | interval_losses = pinball_loss(resid, (ub - lb) / 2, cov).rolling(window).mean().dropna() 40 | opts = resid.rolling(window).quantile(cov).dropna() 41 | opt_losses = [pinball_loss(resid.values[i : i + window], opt, cov).mean() for i, opt in enumerate(opts)] 42 | return max(interval_losses.values - np.asarray(opt_losses)) if len(opt_losses) == len(interval_losses) else np.nan 43 | 44 | 45 | def coverage(y: pd.Series, yhat: pd.Series, lb: pd.Series, ub: pd.Series): 46 | return ((lb <= y) & (y <= ub)).mean() 47 | 48 | 49 | def mae(y: pd.Series, yhat: pd.Series, lb: pd.Series, ub: pd.Series): 50 | return np.abs(y - yhat).mean() 51 | 52 | 53 | def err_std(y: pd.Series, yhat: pd.Series, lb: pd.Series, ub: pd.Series): 54 | return np.abs(y - yhat).std() 55 | 56 | 57 | def width(y: pd.Series, yhat: pd.Series, lb: pd.Series, ub: pd.Series): 58 | return (ub - lb).median() / 2 59 | 60 | 61 | class Residuals: 62 | def __init__(self, horizon): 63 | assert isinstance(horizon, int) and horizon > 0 64 | self.horizon = horizon 65 | self.horizon2residuals = {h + 1: [] for h in range(self.horizon)} 66 | 67 | def __len__(self): 68 | return max(len(r) for r in self.horizon2residuals.values()) 69 | 70 | def get(self, horizon): 71 | assert isinstance(horizon, int) and 1 <= horizon <= self.horizon, f"Got {horizon}, self.horizon={self.horizon}" 72 | max_h = max(h for h, v in self.horizon2residuals.items() if 3 * len(v) >= len(self)) 73 | return self.horizon2residuals[min(horizon, max_h)] 74 | 75 | def extend(self, horizon, vals): 76 | assert isinstance(horizon, int) and 1 <= horizon <= self.horizon, f"Got {horizon}, self.horizon={self.horizon}" 77 | self.horizon2residuals[horizon] += [v for v in (vals if isinstance(vals, list) else [vals]) if not np.isnan(v)] 78 | 79 | def remove_outliers(self): 80 | for h, resid in self.horizon2residuals.items(): 81 | resid = np.asarray(resid) 82 | resid = resid[~np.isnan(resid)] 83 | z_score = np.abs(resid - np.mean(resid)) / np.std(resid) 84 | self.horizon2residuals[h] = resid[z_score < 5].tolist() 85 | -------------------------------------------------------------------------------- /online_conformal/model_sigma.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 
salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | from typing import Dict, Tuple, Union 8 | 9 | from merlion.models.forecast.base import ForecasterBase 10 | from merlion.utils import TimeSeries 11 | import pandas as pd 12 | from scipy.stats import norm 13 | import tqdm 14 | 15 | from online_conformal.base import BasePredictor 16 | 17 | 18 | class ModelSigma(BasePredictor): 19 | """ 20 | Use the model itself to estimate uncertainty. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | model: ForecasterBase, 26 | train_data: pd.DataFrame, 27 | calib_frac=None, 28 | coverage=0.9, 29 | horizon=1, 30 | pretrained_model=False, 31 | verbose=False, 32 | calib_residuals=None, 33 | **kwargs, 34 | ): 35 | super().__init__( 36 | model=model, 37 | train_data=train_data, 38 | calib_frac=calib_frac, 39 | coverage=coverage, 40 | horizon=horizon, 41 | pretrained_model=pretrained_model, 42 | verbose=verbose, 43 | calib_residuals=None, 44 | **kwargs, 45 | ) 46 | 47 | def update(self, ground_truth: pd.Series, forecast: pd.Series, horizon: int): 48 | pass 49 | 50 | @classmethod 51 | def from_other(cls, other, **kwargs): 52 | all_kwargs = dict( 53 | model=other.model, 54 | pretrained_model=True, 55 | train_data=other.train_data, 56 | calib_frac=0, 57 | coverage=other.coverage, 58 | horizon=other.horizon, 59 | verbose=other.verbose, 60 | ) 61 | all_kwargs.update(**kwargs) 62 | return cls(**all_kwargs) 63 | 64 | def forecast( 65 | self, time_series: Union[TimeSeries, pd.DataFrame], time_series_prev: Union[TimeSeries, pd.DataFrame] = None 66 | ) -> Tuple[Dict[int, pd.Series], Dict[int, pd.Series], Dict[int, pd.Series]]: 67 | # Process arguments 68 | t0 = time_series.index[0] 69 | if time_series_prev is None: 70 | time_series_prev = self.train_data.iloc[: -self.horizon + 1] 71 | time_series = pd.concat((self.train_data.iloc[-self.horizon + 1 :], time_series)) 72 | 73 | # Forecast in increments of self.horizon & get the error bars along the way 74 | alpha = 1 - self.coverage 75 | yhat, lb, ub = [], [], [] 76 | for i in tqdm.trange(len(time_series), desc="Forecasting", disable=not self.verbose): 77 | y_t = time_series.iloc[i : i + self.horizon] 78 | yhat_t, err_t = self.model.forecast(y_t.index, TimeSeries.from_pd(time_series_prev)) 79 | yhat_t = yhat_t.to_pd().iloc[:, 0] 80 | if err_t is None: 81 | raise RuntimeError(f"Model {type(self.model).__name__} does not support uncertainty estimation") 82 | err_t = err_t.to_pd().iloc[:, 0] 83 | yhat.append(yhat_t) 84 | lb.append(yhat_t + err_t * norm.ppf(alpha / 2)) 85 | ub.append(yhat_t + err_t * norm.ppf(1 - alpha / 2)) 86 | 87 | # Aggregate & return forecasts for each horizon (along with confidence intervals) 88 | yhat, lb, ub = [ 89 | { 90 | h + 1: pd.concat([x.iloc[h : h + 1] for x in ts if len(x) > h and x.index[h] >= t0]) 91 | for h in range(self.horizon) 92 | } 93 | for ts in [yhat, lb, ub] 94 | ] 95 | return yhat, lb, ub 96 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guide 2 | 3 | This page lists the operational governance model of this project, as well as the recommendations and requirements for how to best contribute to it. We strive to obey these as best as possible. 
As always, thanks for contributing – we hope these guidelines make it easier and shed some light on our approach and processes. 4 | 5 | # Governance Model 6 | 7 | ## Published but not supported 8 | 9 | The intent and goal of open sourcing this project is that it may contain useful or interesting code/concepts that we wish to share with the larger open source community. Although occasional work may be done on it, we will not be looking for or soliciting contributions. 10 | 11 | # Getting started 12 | 13 | Please join the community. Also please make sure to take a look at the project [roadmap](ROADMAP.md), if it exists, to see where we are headed. 14 | 15 | # Issues, requests & ideas 16 | 17 | Use the GitHub Issues page to submit issues and enhancement requests, and to discuss ideas. 18 | 19 | ### Bug Reports and Fixes 20 | - If you find a bug, please search for it in the Issues, and if it isn't already tracked, 21 | create a new issue. Fill out the "Bug Report" section of the issue template. Even if an Issue is closed, feel free to comment and add details; it will still 22 | be reviewed. 23 | - Issues that have already been identified as a bug (note: able to reproduce) will be labelled `bug`. 24 | - If you'd like to submit a fix for a bug, [send a Pull Request](#creating-a-pull-request) and mention the Issue number. 25 | - Include tests that isolate the bug and verify that it was fixed. 26 | 27 | ### New Features 28 | - If you'd like to add new functionality to this project, describe the problem you want to solve in a new Issue. 29 | - Issues that have been identified as a feature request will be labelled `enhancement`. 30 | - If you'd like to implement the new feature, please wait for feedback from the project 31 | maintainers before spending too much time writing the code. In some cases, `enhancement`s may 32 | not align well with the project objectives at the time. 33 | 34 | ### Tests, Documentation, Miscellaneous 35 | - If you'd like to improve the tests, you want to make the documentation clearer, you have an 36 | alternative implementation of something that may have advantages over the way it's currently 37 | done, or you have any other change, we would be happy to hear about it! 38 | - If it's a trivial change, go ahead and [send a Pull Request](#creating-a-pull-request) with the changes you have in mind. 39 | - If not, open an Issue to discuss the idea first. 40 | 41 | If you're new to our project and looking for some way to make your first contribution, look for 42 | Issues labelled `good first contribution`. 43 | 44 | # Contribution Checklist 45 | 46 | - [x] Clean, simple, well-styled code 47 | - [x] Commits should be atomic and messages must be descriptive. Related issues should be mentioned by Issue number. 48 | - [x] Comments 49 | - Module-level & function-level comments. 50 | - Comments on complex blocks of code or algorithms (include references to sources). 51 | - [x] Tests 52 | - The test suite, if provided, must be complete and pass. 53 | - Increase code coverage, do not decrease it. 54 | - Use any of our test kits, which contain a number of testing facilities you may need (for example, `import com.salesforce.op.test._`), and borrow inspiration from existing tests. 55 | - [x] Dependencies 56 | - Minimize the number of dependencies. 57 | - Prefer Apache 2.0, BSD3, MIT, ISC and MPL licenses. 58 | - [x] Reviews 59 | - Changes must be approved via peer code review. 60 | 61 | # Creating a Pull Request 62 | 63 | 1. **Ensure the bug/feature was not already reported** by searching on GitHub under Issues.
If none exists, create a new issue so that other contributors can keep track of what you are trying to add/fix and offer suggestions (or let you know if there is already an effort in progress). 64 | 2. **Clone** the forked repo to your machine. 65 | 3. **Create** a new branch to contain your work (e.g. `git checkout -b fix-issue-11`). 66 | 4. **Commit** changes to your own branch. 67 | 5. **Push** your work back up to your fork (e.g. `git push origin fix-issue-11`). 68 | 6. **Submit** a Pull Request against the `main` branch and refer to the issue(s) you are fixing. Try not to pollute your pull request with unintended changes. Keep it simple and small. 69 | 7. **Sign** the Salesforce CLA (you will be prompted to do so when submitting the Pull Request). 70 | 71 | > **NOTE**: Be sure to [sync your fork](https://help.github.com/articles/syncing-a-fork/) before making a pull request. 72 | 73 | 74 | # Code of Conduct 75 | Please follow our [Code of Conduct](CODE_OF_CONDUCT.md). 76 | 77 | # License 78 | By contributing your code, you agree to license your contribution under the terms of our project [LICENSE](LICENSE) and to sign the [Salesforce CLA](https://cla.salesforce.com/sign-cla). 79 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics.
32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 
100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improved Online Conformal Prediction via Strongly Adaptive Online Learning 2 | This library implements numerous algorithms which perform conformal prediction on data with arbitrary distribution 3 | shifts over time. This is the official implementation for the [paper](https://arxiv.org/abs/2302.07869) Bhatnagar et al., 4 | "Improved Online Conformal Prediction via Strongly Adaptive Online Learning," 2023. We include reference implementations 5 | for the proposed methods Strongly Adaptive Online Conformal Prediction (SAOCP) and Scale-Free Online Gradient Descent 6 | (SF-OGD), as well as Split Conformal Prediction ([Vovk et al., 1999](https://dl.acm.org/doi/10.5555/645528.657641)), 7 | Non-Exchangeable Conformal Prediction ([Barber et al., 2022](https://arxiv.org/pdf/2202.13415.pdf)), and 8 | Fully Adaptive Conformal Inference (FACI, [Gibbs & Candes, 2022](https://arxiv.org/pdf/2208.08401.pdf)). 9 | 10 | ## Replicating Our Experiments 11 | First install the `online_conformal` package by cloning this repo and calling ``pip install .``. 12 | To run our time series forecasting experiments, first clone the [Merlion](https://github.com/salesforce/Merlion) repo 13 | and install their `ts_datasets` package. Then, you can call 14 | ```shell 15 | python time_series.py --model <model> --dataset <dataset> --njobs <njobs> 16 | ``` 17 | where `<model>` can be one of `LGBM`, `ARIMA`, or `Prophet`; `<dataset>` can be one of 18 | `M4_Hourly`, `M4_Daily`, `M4_Weekly`, or `NN5_Daily`; and `<njobs>` indicates the number of 19 | parallel cores to use. The results will be written to a sub-directory 20 | `results`. 21 | 22 | To run our experiments on image classification under distribution shift, first install [PyTorch](https://pytorch.org/). 23 | Then, you can call 24 | ```shell 25 | python vision.py --dataset <dataset> 26 | ``` 27 | where `<dataset>` is one of `ImageNet` or `TinyImageNet`. Various intermediate results will be written to 28 | sub-folders, and checkpointing (e.g. for model training) is automatic. 29 | 30 | ## Using Our Code 31 | To use our code, first install the `online_conformal` package by calling ``pip install online_conformal``. 32 | You can alternatively install the package from source by cloning this repo and calling ``pip install .``. 33 | 34 | Each online conformal prediction method is implemented as its own class in the package. All methods share a common API. 35 | For time series forecasting, we leverage models implemented in [Merlion](https://github.com/salesforce/Merlion). 36 | Below, we demonstrate how to use `SAOCP` to create prediction intervals for multi-horizon time series forecasting. 37 | The update loop is a simplified version of calling 38 | `saocp.forecast(time_series=test_data.iloc[:horizon], time_series_prev=train_data)`, whose implementation you can find 39 | [here](https://github.com/salesforce/online_conformal/blob/main/online_conformal/base.py#L116).
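For reference, that one-call version looks roughly like the following (a sketch, assuming the `saocp`, `train_data`, and `test_data` objects constructed in the example below; each returned value is a dictionary mapping a horizon `h` to a `pandas.Series` of `h`-step-ahead values):
```python
# Forecast the first `horizon` test points; SAOCP's internal state is updated along the way.
yhat, lb, ub = saocp.forecast(time_series=test_data.iloc[:horizon], time_series_prev=train_data)
lb_1, ub_1 = lb[1], ub[1]  # 1-step-ahead prediction interval bounds
```
The expanded example below makes each forecast and state update explicit.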
40 | 41 | ```python 42 | import pandas as pd 43 | from merlion.models.factory import ModelFactory 44 | from merlion.utils import TimeSeries 45 | from online_conformal.dataset import M4 46 | from online_conformal.saocp import SAOCP 47 | 48 | # Get some time series data as pandas.DataFrames 49 | data = M4("Hourly")[0] 50 | train_data, test_data = data["train_data"], data["test_data"] 51 | # Initialize a Merlion model for time series forecasting 52 | model = ModelFactory.create(name="LGBMForecaster") 53 | # Initialize the SAOCP wrapper on top of the model. This splits the data 54 | # into train/calibration splits, trains the model on the train split, 55 | # and initializes SAOCP's internal state on the calibration split. 56 | # The target coverage is 90% here, but you can adjust this freely. 57 | # We also do 24-step-ahead forecasting by setting horizon=24. 58 | horizon = 24 59 | saocp = SAOCP(model=model, train_data=train_data, coverage=0.9, 60 | calib_frac=0.2, horizon=horizon) 61 | 62 | # Get the model's 24-step-ahead prediction, and convert it to prediction intervals 63 | yhat, _ = saocp.model.forecast(horizon, time_series_prev=TimeSeries.from_pd(train_data)) 64 | delta_lb, delta_ub = zip(*[saocp.predict(horizon=h + 1) for h in range(horizon)]) 65 | yhat = yhat.to_pd().iloc[:, 0] 66 | lb, ub = yhat + delta_lb, yhat + delta_ub 67 | 68 | # Update SAOCP's internal state based on the next 24 observations 69 | prev = train_data.iloc[:-horizon + 1] 70 | time_series = pd.concat((train_data.iloc[-horizon + 1:], test_data.iloc[:horizon])) 71 | for i in range(len(time_series)): 72 | # Predict yhat_{t-H+i+1}, ..., yhat_{t-H+i+H} = f(y_1, ..., y_{t-H+i}) 73 | y = time_series.iloc[i:i + horizon, 0] 74 | yhat, _ = saocp.model.forecast(y.index, time_series_prev=TimeSeries.from_pd(prev)) 75 | yhat = yhat.to_pd().iloc[:, 0] 76 | # Use h-step prediction of yhat_{t-k+h} to update SAOCP's h-step prediction interval 77 | for h in range(len(y)): 78 | if i >= h: 79 | saocp.update(ground_truth=y[h:h + 1], forecast=yhat[h:h + 1], horizon=h + 1) 80 | prev = pd.concat((prev, time_series.iloc[i:i+1])) 81 | ``` 82 | 83 | For other use cases, you can initialize `saocp = SAOCP(model=None, train_data=None, max_scale=max_scale, coverage=0.9)`. 84 | Here, `max_scale` indicates the largest value you expect the conformal score to take. Then, you can obtain the conformal 85 | score corresponding to 90% (or your desired level of coverage) by calling `score = saocp.predict(horizon=1)[1]`, and 86 | you can use this value to compute the prediction set `{y: S(X_t, y) < score}` using your own custom code. Finally, after 87 | you observe the true conformal score `new_score = S(X_t, Y_t)`, you can update the conformal predictor by calling 88 | `saocp.update(ground_truth=pd.Series([new_score]), forecast=pd.Series([0]), horizon=1)`. 89 | -------------------------------------------------------------------------------- /online_conformal/saocp.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 
4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | from typing import Dict, Tuple 8 | 9 | import numpy as np 10 | import pandas as pd 11 | 12 | from online_conformal.base import BasePredictor 13 | from online_conformal.enbpi import EnbMixIn 14 | from online_conformal.utils import pinball_loss, pinball_loss_grad, Residuals 15 | 16 | 17 | class _OGD: 18 | def __init__(self, t, scale, alpha, yhat_0, g=8): 19 | """ 20 | Instantiates an online gradient descent learner which starts at time t and has a finite lifetime. The lifetime 21 | is given by the data streaming intervals described in Hazan & Seshadhri, 2007, "Adaptive Algorithms for Online 22 | Decision Problems." (Appendix B) https://www.cs.princeton.edu/techreports/2007/798.pdf. The underlying algorithm 23 | is Scale-Free Online Mirror Descent with regularizer ||y - y_0||^2 / 2 (https://arxiv.org/abs/1601.01974). 24 | """ 25 | # Scale-free online gradient descent parameters 26 | self.scale = scale 27 | self.base_lr = scale / np.sqrt(3) 28 | self.alpha = alpha 29 | self.yhat = yhat_0 30 | self.grad_norm = 0 31 | 32 | # Meta-algorithm parameters 33 | u = 0 34 | while t % 2 == 0: 35 | t /= 2 36 | u += 1 37 | self.lifetime = g * 2**u 38 | self.z = 0 # sum of differences between losses & meta-losses 39 | self.wz = 0 # weighted sum of differences between losses & meta-losses 40 | self.s_t = 0 # how long the predictor has been alive 41 | 42 | @property 43 | def expired(self): 44 | return self.s_t > self.lifetime 45 | 46 | def loss(self, y): 47 | return pinball_loss(y, self.yhat, 1 - self.alpha) 48 | 49 | @property 50 | def w(self): 51 | return 0 if self.s_t == 0 else self.z / self.s_t * (1 + self.wz) 52 | 53 | def update(self, y, meta_loss): 54 | # Update meta-algorithm weights 55 | w = self.w 56 | g = np.clip((meta_loss - self.loss(y)) / self.scale / max(self.alpha, 1 - self.alpha), -1 * (w > 0), 1) 57 | self.z += g 58 | self.wz += g * w 59 | self.s_t += 1 60 | 61 | # Update estimator 62 | grad = pinball_loss_grad(y, self.yhat, 1 - self.alpha) 63 | self.grad_norm += grad**2 64 | if self.grad_norm != 0: 65 | self.yhat = max(0, self.yhat - self.base_lr / np.sqrt(self.grad_norm) * grad) 66 | 67 | 68 | class SAOCP(BasePredictor): 69 | """ 70 | Strongly Adaptive Online Conformal Prediction (SAOCP). The main algorithm of the paper. This algorithm adapts 71 | Coin Betting for Changing Environment (CBCE) to learn conformal confidence intervals. 72 | From Jun et al., 2017, "Improved Strongly Adaptive Online Learning using Coin Betting". 73 | https://proceedings.mlr.press/v54/jun17a/jun17a-supp.pdf. 
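    At each step, SAOCP maintains a pool of _OGD experts with staggered start times and finite lifetimes, and it predicts
    a weighted combination of the experts' interval radii, with weights derived from each expert's coin-betting potential
    (see get_p).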
74 | """ 75 | 76 | def __init__(self, *args, horizon=1, max_scale=None, lifetime=8, **kwargs): 77 | self.t = 1 78 | if max_scale is None: 79 | self.scale = {} 80 | else: 81 | self.scale = {h + 1: max_scale for h in range(horizon)} 82 | self.experts = {h + 1: {} for h in range(horizon)} 83 | self.lifetime = lifetime 84 | super().__init__(*args, horizon=horizon, **kwargs) 85 | 86 | residuals = self.residuals 87 | self.residuals = Residuals(self.horizon) 88 | for h in range(1, self.horizon + 1): 89 | r = residuals.horizon2residuals[h] 90 | if h not in self.scale: 91 | self.scale[h] = 1 if len(r) == 0 else np.max(np.abs(r)) * np.sqrt(3) 92 | self.update(pd.Series(r, dtype=float), pd.Series(np.zeros(len(r))), h) 93 | 94 | def get_p(self, horizon) -> Dict[int, float]: 95 | experts = self.experts[horizon] 96 | prior = {t: 1 / (t**2 * (1 + np.floor(np.log2(t)))) for t in experts.keys()} 97 | z = sum(prior.values()) 98 | prior = {t: v / z for t, v in prior.items()} 99 | p = {t: prior[t] * max(0, expert.w) for t, expert in experts.items()} 100 | sum_p = sum(p.values()) 101 | return {t: v / sum_p for t, v in p.items()} if sum_p > 0 else prior 102 | 103 | def predict(self, horizon) -> Tuple[float, float]: 104 | p = self.get_p(horizon) 105 | delta = sum(p[t] * expert.yhat for t, expert in self.experts[horizon].items()) 106 | return -delta, delta 107 | 108 | def create_expert(self, horizon, s_hat): 109 | return _OGD(self.t, self.scale[horizon], 1 - self.coverage, g=self.lifetime, yhat_0=s_hat) 110 | 111 | def update(self, ground_truth: pd.Series, forecast: pd.Series, horizon: int): 112 | residuals = np.abs(ground_truth - forecast).values 113 | self.residuals.extend(horizon, residuals.tolist()) 114 | if horizon not in self.scale: 115 | return 116 | 117 | experts = self.experts[horizon] 118 | for s in residuals: 119 | # Remove expired experts & add new expert 120 | _, s_hat = self.predict(horizon) 121 | [experts.pop(t) for t in [k for k, v in experts.items() if v.expired]] 122 | experts[self.t] = self.create_expert(horizon, s_hat) 123 | 124 | # Update experts 125 | meta_loss = pinball_loss(s, self.predict(horizon)[1], self.coverage) 126 | [expert.update(s, meta_loss) for expert in experts.values()] 127 | self.t += 1 128 | 129 | 130 | class EnbSAOCP(EnbMixIn, SAOCP): 131 | pass 132 | -------------------------------------------------------------------------------- /online_conformal/enbpi.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | from abc import ABC 8 | import copy 9 | 10 | from merlion.models.ensemble.forecast import ForecasterEnsemble 11 | from merlion.models.forecast.sklearn_base import SKLearnForecaster 12 | from merlion.models.utils.rolling_window_dataset import RollingWindowDataset 13 | from merlion.transform.base import Identity 14 | from merlion.utils import TimeSeries 15 | import numpy as np 16 | import pandas as pd 17 | 18 | from online_conformal.base import BasePredictor 19 | from online_conformal.split_conformal import SplitConformal 20 | from online_conformal.utils import Residuals 21 | 22 | 23 | class EnbMixIn(BasePredictor, ABC): 24 | """ 25 | Mix-in class for Ensemble Prediction Intervals (EnbPI), which is the algorithm proposed by Xu & Xie, 2020, 26 | "Conformal prediction for time series." 
https://arxiv.org/abs/2010.09107. 27 | 28 | Inheriting from this class will transform any base predictor into one uses leave-one-out ensembles. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | model: SKLearnForecaster, 34 | train_data: pd.DataFrame, 35 | calib_frac=None, 36 | coverage=0.9, 37 | horizon=1, 38 | pretrained_model=False, 39 | verbose=False, 40 | calib_residuals=None, 41 | retrain=False, 42 | **kwargs 43 | ): 44 | if isinstance(model, SKLearnForecaster): 45 | model = copy.deepcopy(model) 46 | model.config.transform = Identity() 47 | model.config.prediction_stride = horizon 48 | if not pretrained_model: 49 | model.train(TimeSeries.from_pd(train_data)) 50 | 51 | # Get the rolling window dataset 52 | target_idx = ( 53 | None if train_data.shape[1] > 1 and model.prediction_stride == 1 else (model.target_seq_index or 0) 54 | ) 55 | dataset = RollingWindowDataset( 56 | train_data, 57 | target_seq_index=target_idx, 58 | n_past=model.maxlags, 59 | n_future=model.prediction_stride, 60 | batch_size=None, 61 | ) 62 | inputs, inputs_ts, labels, labels_ts = next(iter(dataset)) 63 | non_nan = ~np.isnan(inputs).any(axis=1) & ~np.isnan(labels).any(axis=1) 64 | inputs, labels = inputs[non_nan], labels[non_nan] 65 | n = len(inputs) 66 | 67 | # Create copies of the models & re-train each copy on a subset of the data. 68 | # Obtain the forecasted values as if doing cross-validation. 69 | b = min(5, n) 70 | state = np.random.RandomState(0) 71 | models = [copy.deepcopy(model) for _ in range(b)] 72 | excluded_idx_sets = np.array_split(state.permutation(n), b) 73 | residuals = np.zeros((n, 1 if target_idx is None else labels.shape[1])) 74 | for k, (model_copy, excluded) in enumerate(zip(models, excluded_idx_sets)): 75 | # Train on (n - 1) / n of the data 76 | included = np.asarray(sorted(set(range(n)).difference(excluded)), dtype=int) 77 | model_copy.model.fit(inputs[included], labels[included]) 78 | # Get residuals on the remaining 1 / n of the data 79 | predict = model_copy.model.predict(inputs[excluded]) 80 | predict = predict[:, model_copy.target_seq_index] if target_idx is None else predict 81 | ground_truth = labels[excluded, model_copy.target_seq_index] if target_idx is None else labels[excluded] 82 | residuals[excluded] = (ground_truth - predict).reshape((-1, residuals.shape[1])) 83 | 84 | # Create an ensemble model & call the superclass initializer with it 85 | model = ForecasterEnsemble(models=models) 86 | model.train_pre_process(TimeSeries.from_pd(train_data)) 87 | residuals = residuals[:, :horizon] 88 | calib_residuals = Residuals(horizon) 89 | for h in range(residuals.shape[1]): 90 | calib_residuals.extend(h + 1, residuals[:, h].tolist()) 91 | 92 | elif not isinstance(model, ForecasterEnsemble): 93 | b = 5 94 | calib_residuals = Residuals(horizon) 95 | models = [copy.deepcopy(model) for _ in range(b)] 96 | excluded_idx_sets = np.array_split(np.arange(len(train_data), dtype=int), b + 1)[1:] 97 | for k, (model, excluded) in enumerate(zip(models, excluded_idx_sets)): 98 | # Train model on the non-excluded data points 99 | model.reset() 100 | included = np.asarray(sorted(set(range(len(train_data))).difference(excluded))) 101 | model.train(TimeSeries.from_pd(train_data.iloc[included])) 102 | 103 | # Update residuals for all horizons 104 | t0 = train_data.index[excluded[0]] 105 | ts = train_data.iloc[excluded[0] - horizon : excluded[-1] + 1] 106 | ts_prev = train_data.iloc[: excluded[0] - horizon] 107 | for i in range(len(ts)): 108 | y_t = ts.iloc[i : i + horizon] 109 | yhat_t = 
model.forecast(y_t.index, TimeSeries.from_pd(ts_prev))[0].to_pd().iloc[:, 0] 110 | resid = y_t.iloc[:, model.target_seq_index] - yhat_t 111 | ts_prev = pd.concat((ts_prev, y_t.iloc[:1])) 112 | for h, r in enumerate(resid): 113 | if y_t.index[h] >= t0: 114 | calib_residuals.extend(h + 1, r) 115 | 116 | # Create an ensemble model to call the superclass initializer with 117 | model = ForecasterEnsemble(models=models) 118 | model.train_pre_process(TimeSeries.from_pd(train_data)) 119 | 120 | super().__init__( 121 | model=model, 122 | train_data=train_data, 123 | calib_frac=0, 124 | coverage=coverage, 125 | horizon=horizon, 126 | pretrained_model=True, 127 | verbose=verbose, 128 | calib_residuals=calib_residuals, 129 | retrain=False, 130 | **kwargs 131 | ) 132 | 133 | 134 | class EnbPI(EnbMixIn, SplitConformal): 135 | pass 136 | -------------------------------------------------------------------------------- /online_conformal/base.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | from abc import ABC, abstractmethod 8 | import copy 9 | import math 10 | from typing import Dict, Tuple 11 | 12 | from merlion.models.forecast.base import ForecasterBase 13 | from merlion.utils import TimeSeries 14 | import numpy as np 15 | import pandas as pd 16 | import tqdm 17 | 18 | from online_conformal.utils import quantile, Residuals 19 | 20 | 21 | class BasePredictor(ABC): 22 | def __init__( 23 | self, 24 | model: ForecasterBase, 25 | train_data: pd.DataFrame, 26 | calib_frac=None, 27 | coverage=0.9, 28 | horizon=1, 29 | pretrained_model=False, 30 | verbose=False, 31 | calib_residuals=None, 32 | retrain=False, 33 | **kwargs, 34 | ): 35 | self.coverage = coverage 36 | self.horizon = horizon 37 | self.verbose = verbose 38 | 39 | # Split data into train & calibration splits after saving the full training dataset 40 | self.train_data = train_data 41 | if calib_frac is None and train_data is not None: 42 | delta, epsilon = 0.1, 1 - coverage 43 | # DKW inequality: sup_x |F_n(x) - F(x)| > epsilon w.p. 
at most delta if n is above the below threshold 44 | n_calib = -np.log(delta / 2) / 2 / epsilon**2 45 | # But we don't want to use more than 20% of the data for calibration unless otherwise specified 46 | n_calib = math.ceil(min(0.2 * len(train_data), n_calib)) 47 | elif train_data is not None: 48 | calib_frac = max(0.0, min(1.0, calib_frac)) 49 | n_calib = math.ceil(len(train_data) * calib_frac) 50 | else: 51 | n_calib = 0 52 | if n_calib == 0: 53 | calib_data = None 54 | else: 55 | calib_data = train_data.iloc[-n_calib:] 56 | train_data = train_data.iloc[:-n_calib] 57 | 58 | # Train model if needed 59 | if model is not None: 60 | assert isinstance(model, ForecasterBase) 61 | assert isinstance(model.target_seq_index, int) or train_data.shape[1] == 1 62 | else: 63 | pretrained_model = True 64 | self.model = model 65 | if not pretrained_model: 66 | self.model.reset() 67 | self.model.train(TimeSeries.from_pd(train_data)) 68 | 69 | # Make predictions on the calibration data, updating the conformal wrapper in the process 70 | self.residuals = Residuals(self.horizon) 71 | if calib_residuals is not None: 72 | self.calib_residuals = copy.deepcopy(calib_residuals) 73 | for h, r in calib_residuals.horizon2residuals.items(): 74 | if len(r) > 0: 75 | self.update(pd.Series(r, dtype=float), pd.Series(np.zeros(len(r))), horizon=h) 76 | else: 77 | self.calib_residuals = None 78 | if calib_data is not None: 79 | self.forecast(calib_data, train_data) 80 | self.residuals.remove_outliers() 81 | self.calib_residuals = copy.deepcopy(self.residuals) 82 | 83 | if retrain and not pretrained_model: 84 | self.model.reset() 85 | self.model.train(TimeSeries.from_pd(self.train_data)) 86 | 87 | @abstractmethod 88 | def update(self, ground_truth: pd.Series, forecast: pd.Series, horizon: int): 89 | raise NotImplementedError 90 | 91 | @classmethod 92 | def from_other(cls, other, **kwargs): 93 | assert isinstance(other, BasePredictor) 94 | all_kwargs = dict( 95 | model=other.model, 96 | pretrained_model=True, 97 | train_data=other.train_data, 98 | calib_frac=0, 99 | coverage=other.coverage, 100 | horizon=other.horizon, 101 | calib_residuals=other.calib_residuals, 102 | verbose=other.verbose, 103 | ) 104 | all_kwargs.update(**kwargs) 105 | return cls(**all_kwargs) 106 | 107 | @staticmethod 108 | def quantile(arr, q): 109 | return quantile(arr, q) 110 | 111 | def predict(self, horizon) -> Tuple[float, float]: 112 | delta_ub = self.quantile(np.abs(self.residuals.get(horizon)), self.coverage) 113 | delta_lb = -delta_ub 114 | return delta_lb, delta_ub 115 | 116 | def forecast( 117 | self, time_series: pd.DataFrame, time_series_prev: pd.DataFrame = None 118 | ) -> Tuple[Dict[int, pd.Series], Dict[int, pd.Series], Dict[int, pd.Series]]: 119 | # Process arguments 120 | t0 = time_series.index[0] 121 | if time_series_prev is None: 122 | time_series_prev = self.train_data.iloc[: -self.horizon + 1] 123 | time_series = pd.concat((self.train_data.iloc[-self.horizon + 1 :], time_series)) 124 | 125 | # Forecast in increments of self.horizon & update the conformal wrapper's internal state along the way. 
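        # Each h-step-ahead prediction below only updates the horizon-h state once its timestamp reaches the start
        # of the requested window (>= t0) and both the observation and the forecast are non-NaN.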
126 | yhat, lb, ub = [], [], [] 127 | for i in tqdm.trange(len(time_series), desc="Forecasting", disable=not self.verbose): 128 | y_t = time_series.iloc[i : i + self.horizon] 129 | yhat_t, _ = self.model.forecast(y_t.index, TimeSeries.from_pd(time_series_prev)) 130 | yhat_t = yhat_t.to_pd().iloc[:, 0] 131 | yhat.append(yhat_t) 132 | lb_t, ub_t = zip(*[self.predict(h + 1) for h in range(len(yhat_t))]) 133 | lb.append(yhat_t + np.asarray(lb_t)) 134 | ub.append(yhat_t + np.asarray(ub_t)) 135 | time_series_prev = pd.concat((time_series_prev, y_t.iloc[:1])) 136 | for h in range(len(y_t)): 137 | idx = self.model.target_seq_index 138 | if y_t.index[h] >= t0 and not np.isnan(y_t.iloc[h, idx]) and not np.isnan(yhat_t.iloc[h]): 139 | self.update(y_t.iloc[h : h + 1, idx], yhat_t.iloc[h : h + 1], horizon=h + 1) 140 | 141 | # Aggregate & return forecasts for each horizon (along with confidence intervals) 142 | yhat, lb, ub = [ 143 | { 144 | h + 1: pd.concat([x.iloc[h : h + 1] for x in ts if len(x) > h and x.index[h] >= t0]) 145 | for h in range(self.horizon) 146 | } 147 | for ts in [yhat, lb, ub] 148 | ] 149 | return yhat, lb, ub 150 | -------------------------------------------------------------------------------- /online_conformal/faci.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | import math 8 | from typing import Tuple 9 | 10 | import numpy as np 11 | import pandas as pd 12 | from scipy.special import logsumexp 13 | 14 | from online_conformal.base import BasePredictor 15 | from online_conformal.enbpi import EnbMixIn 16 | from online_conformal.utils import pinball_loss, pinball_loss_grad, Residuals 17 | 18 | 19 | class FACI(BasePredictor): 20 | """ 21 | Fully Adaptive Conformal Inference, which is the algorithm proposed by Gibbs & Candes, 2022, 22 | "Conformal Inference for Online Prediction with Arbitrary Distribution Shifts." 
https://arxiv.org/abs/2208.08401 23 | """ 24 | 25 | def __init__(self, *args, horizon=1, coverage=0.9, **kwargs): 26 | self.gammas = np.asarray([0.001 * 2**k for k in range(8)]) 27 | self.alphas = np.full((horizon, self.k), 1 - coverage) 28 | self.log_w = np.zeros((horizon, self.k)) 29 | super().__init__(*args, horizon=horizon, coverage=coverage, **kwargs) 30 | 31 | @property 32 | def I(self): 33 | return 100 34 | 35 | @property 36 | def k(self): 37 | return len(self.gammas) 38 | 39 | @property 40 | def sigma(self): 41 | return 1 / (2 * self.I) 42 | 43 | @property 44 | def eta(self): 45 | alpha = 1 - self.coverage 46 | denom = ((1 - alpha) ** 2 * alpha**3 + alpha**2 * (1 - alpha) ** 3) / 3 47 | return np.sqrt(3 / self.I) * np.sqrt((np.log(self.I * self.k) + 2) / denom) 48 | 49 | def predict(self, horizon) -> Tuple[float, float]: 50 | log_w = self.log_w[horizon - 1] 51 | alpha = np.dot(np.exp(log_w - logsumexp(log_w)), self.alphas[horizon - 1]) 52 | delta = self.quantile(np.abs(self.residuals.get(horizon)), 1 - alpha) 53 | return -delta, delta 54 | 55 | def update(self, ground_truth: pd.Series, forecast: pd.Series, horizon: int): 56 | h = horizon - 1 57 | residuals = self.residuals.horizon2residuals[horizon] 58 | for s in np.abs(forecast - ground_truth).values: 59 | if len(residuals) > math.floor(1 / (1 - self.coverage)): 60 | # Compute pinball losses incurred by the current residual 61 | beta = np.mean(residuals >= s) 62 | losses = pinball_loss(beta, self.alphas[h], 1 - self.coverage) 63 | 64 | # Update weights 65 | wbar = self.log_w[h] - self.eta * losses 66 | self.log_w[h] = logsumexp( 67 | [wbar, np.full(self.k, logsumexp(wbar))], b=[[1 - self.sigma], [self.sigma / self.k]], axis=0 68 | ) 69 | self.log_w[h] = self.log_w[h] - logsumexp(self.log_w[h]) 70 | 71 | # Compute coverage errors & update alphas 72 | err = self.alphas[h] > beta 73 | self.alphas[h] = np.clip(self.alphas[h] + self.gammas * ((1 - self.coverage) - err), 0, 1) 74 | residuals.append(s) 75 | 76 | 77 | class FACI_S(BasePredictor): 78 | """FACI on radius, instead of quantiles.""" 79 | 80 | def __init__(self, *args, horizon=1, coverage=0.9, max_scale=None, **kwargs): 81 | self.gammas = np.asarray([0.001 * 2**k for k in range(8)]) 82 | self.s_hats = np.zeros((horizon, self.k)) 83 | self.log_w = np.zeros((horizon, self.k)) 84 | if max_scale is None: 85 | self.scale = {} 86 | else: 87 | self.scale = {h + 1: float(max_scale) for h in range(horizon)} 88 | self.prev_loss_sq = {h + 1: [] for h in range(horizon)} 89 | super().__init__(*args, horizon=horizon, coverage=coverage, **kwargs) 90 | 91 | # Use calibration to initialize learning rate & estimates for deltas 92 | residuals = self.residuals 93 | self.residuals = Residuals(self.horizon) 94 | for h in range(1, self.horizon + 1): 95 | r = residuals.horizon2residuals[h] 96 | self.scale[h] = 1 if len(r) == 0 else np.max(np.abs(r)) * np.sqrt(3) 97 | self.update(pd.Series(r, dtype=float), pd.Series(np.zeros(len(r))), h) 98 | 99 | @property 100 | def I(self): 101 | return 100 102 | 103 | @property 104 | def k(self): 105 | return len(self.gammas) 106 | 107 | @property 108 | def sigma(self): 109 | return 1 / (2 * self.I) 110 | 111 | def eta(self, horizon): 112 | loss_sq = self.prev_loss_sq[horizon][-self.I :] 113 | if len(loss_sq) == 0: 114 | loss_sq_sum = self.I * (self.scale[horizon] * self.coverage) ** 2 115 | else: 116 | loss_sq_sum = np.sum(loss_sq) * (self.I / len(loss_sq)) 117 | return np.sqrt((np.log(self.k * self.I) + 2) / loss_sq_sum) 118 | 119 | def predict(self, horizon) -> 
Tuple[float, float]: 120 | log_w = self.log_w[horizon - 1] 121 | s_hat = np.dot(np.exp(log_w - logsumexp(log_w)), self.s_hats[horizon - 1]) 122 | return -s_hat, s_hat 123 | 124 | def update(self, ground_truth: pd.Series, forecast: pd.Series, horizon: int): 125 | residuals = np.abs(ground_truth - forecast).values 126 | self.residuals.extend(horizon, residuals.tolist()) 127 | if horizon not in self.scale: 128 | return 129 | h = horizon - 1 130 | for s in np.abs(forecast - ground_truth).values: 131 | if horizon in self.scale: 132 | # Compute loss 133 | w = np.exp(self.log_w[h] - logsumexp(self.log_w[h])) 134 | losses = pinball_loss(s, self.s_hats[h], self.coverage) 135 | 136 | # Update weights 137 | wbar = self.log_w[h] - self.eta(horizon) * losses 138 | self.log_w[h] = logsumexp( 139 | [wbar, np.full(self.k, logsumexp(wbar))], b=[[1 - self.sigma], [self.sigma / self.k]], axis=0 140 | ) 141 | self.log_w[h] = self.log_w[h] - logsumexp(self.log_w[h]) 142 | 143 | # Add previous expected loss squared to the list 144 | self.prev_loss_sq[horizon].append(np.dot(w, losses**2)) 145 | 146 | # Update s_hat's 147 | grad = pinball_loss_grad(s, self.s_hats[h], self.coverage) 148 | self.s_hats[h] = self.s_hats[h] - self.gammas * self.scale[horizon] * grad 149 | 150 | 151 | class EnbFACI(EnbMixIn, FACI): 152 | pass 153 | -------------------------------------------------------------------------------- /make_table.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | import argparse 8 | import functools 9 | import itertools 10 | import os 11 | import re 12 | import numpy as np 13 | import pandas as pd 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--coverage", default=90, type=int) 19 | parser.add_argument("--window", default=20, type=int) 20 | parser.add_argument("--ensemble", action="store_true", default=False) 21 | parser.add_argument("--print_sd", action="store_true", default=False) 22 | parser.add_argument("--markdown", action="store_true", default=False) 23 | parser.add_argument("--interval_table", action="store_true", default=False) 24 | parser.add_argument("--dataset", type=str, default="M4_Hourly", help="Only used for interval analysis") 25 | parser.add_argument("--model", type=str, default="LGBM", help="Only used for interval analysis") 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def df_to_str_df(df, interval, print_sd): 31 | if interval or not print_sd: 32 | return df.applymap(lambda x: f"{(np.round(x, 3)):.3f}") 33 | return df.apply( 34 | lambda c: c.apply( 35 | lambda x: (f"{(np.round(x, 3)):.3f}" if "Reg" in c.name else f"{(np.round(x, 3)):.3f}").lstrip("0") 36 | ) 37 | ) 38 | 39 | 40 | def combine_sd_df(df, interval, print_sd): 41 | non_sd_cols = [c for c in df.columns if "SD" not in c] 42 | sd_cols = [c + " SD" for c in non_sd_cols] 43 | str_df = df_to_str_df(df.loc[:, non_sd_cols], interval=interval, print_sd=print_sd) 44 | if print_sd and all(c in df.columns for c in sd_cols): 45 | str_df = str_df + "\\textsubscript{" + df_to_str_df(df.loc[:, sd_cols], interval, True).values + "}" 46 | return str_df 47 | 48 | 49 | def rename_stats(stat): 50 | stat = re.sub("Interval Miscoverage", r"$\\mathrm{LCE}_k$", stat) 51 | stat = re.sub("Interval Regret", 
r"$\\mathrm{SAReg}_k$", stat) 52 | return stat 53 | 54 | 55 | def bold_best(v, dataset, full_df, target_cov): 56 | stat = v.name[1] 57 | if "Coverage" in full_df.columns.get_level_values(2): 58 | cov = full_df.loc[:, (dataset, v.name[0], "Coverage")] 59 | valid = np.abs(cov - target_cov) < 0.05 60 | else: 61 | valid = [True] * len(v) 62 | v = full_df.loc[:, (dataset, *v.name)] 63 | v_sort = sorted(np.round(v.loc[valid].dropna(), 3)) 64 | if stat == "Coverage": 65 | return ["color: ForestGreen" if v else "color: red" for v in valid] 66 | else: 67 | best = [False] * len(v) if len(v_sort) < 1 else (np.round(v, 3) == v_sort[0]) & valid 68 | second_best = [False] * len(v) if len(v_sort) < 2 else (np.round(v, 3) == v_sort[1]) & valid 69 | return ["font-weight: bold" if b else "font-style: italic" if b2 else "" for b, b2 in zip(best, second_best)] 70 | 71 | 72 | def md_rename(method): 73 | return re.sub("ScaleFree", "SF-", re.sub("Split", "S", re.sub("Conformal", "CP", re.sub("ACI_", "ACI-", method)))) 74 | 75 | 76 | def tex_formatting(tex_str): 77 | # Rename methods to match paper 78 | tex_str = re.sub("SAOCP", r"\\method{}", tex_str) 79 | tex_str = re.sub("OGD", r"\\methodBasic{}", re.sub("ScaleFree", "", tex_str)) 80 | tex_str = re.sub("Split", "S", re.sub("Conformal", "CP", re.sub("ACI_", "ACI-", tex_str))) 81 | # Update formatting. Underline second-best instead of italicize, and make index more compact 82 | tex_str = re.sub(r"\\itshape\s*([\d.]*)(\\textsubscript\{[\d.]*\})?", r"\\underline{\1\2}", tex_str) 83 | tex_str = re.sub(f"(?m)^(.*?Coverage)", r"Method\1", re.sub("Method.*?\n", "", tex_str)) 84 | # Put methods in the right order 85 | lines = tex_str.split("\n") 86 | order = ["ModelSigma", "SCP", "NExCP", "FACI", r"\\methodBasic{}", "FACI-S", r"\\method{}"] 87 | order += ["EnbPI", "EnbNEx", "EnbFACI", r"Enb\\methodBasic{}", r"Enb\\method{}"] 88 | model_lines = sum([[i for i, line in enumerate(lines) if re.match(f"^\\s*{m}\\s*&", line)] for m in order], []) 89 | line_order = list(range(min(model_lines))) + model_lines + list(range(max(model_lines) + 1, len(lines))) 90 | tex_str = "\n".join([lines[i] for i in line_order]) 91 | return tex_str 92 | 93 | 94 | def primary_table(args): 95 | full_df, full_str_df = None, None 96 | models = ["LGBM", "ARIMA", "Prophet"] 97 | datasets = ["M4_Hourly", "M4_Daily", "M4_Weekly", "NN5_Daily"] 98 | mae_idx = "Enb" if args.ensemble else "Base" 99 | dirname = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results") 100 | for dataset, model in itertools.product(datasets, models): 101 | fname = "results_enb.csv" if args.ensemble else "results_base.csv" 102 | fname = os.path.join(dirname, dataset, model, f"k={args.window}", fname) 103 | if os.path.exists(fname): 104 | df = pd.read_csv(fname, index_col=0) 105 | df = df[df["Target Coverage"] == args.coverage / 100].drop(columns="Target Coverage") 106 | df = df.rename(columns=rename_stats, index=lambda s: re.sub("CBCE", "SAOCP", s)) 107 | if df.isna().all().all() or len(df) == 0: 108 | continue 109 | for m in ["ModelSigma", "SimpleSAOCP"]: 110 | if m in df.index: 111 | df = df.drop(labels=[m]) 112 | str_df = combine_sd_df(df, interval=False, print_sd=args.print_sd) 113 | mae = pd.read_csv(os.path.join(dirname, dataset, model, "mae.csv"), index_col=0).loc[mae_idx, "MAE"] 114 | model = f"{model} (MAE = {mae:.2f})" 115 | df.columns = pd.MultiIndex.from_tuples([(re.sub("_", " ", dataset), model, c) for c in df.columns]) 116 | str_df.columns = pd.MultiIndex.from_tuples([(re.sub("_", " ", dataset), model, 
c) for c in str_df.columns]) 117 | full_df = df if full_df is None else pd.concat((full_df, df), axis=1) 118 | full_str_df = str_df if full_str_df is None else pd.concat((full_str_df, str_df), axis=1) 119 | 120 | models = full_df.columns.get_level_values(1) 121 | datasets = full_df.columns.get_level_values(0) 122 | models = models[sorted(np.unique(models, return_index=True)[1])] 123 | datasets = datasets[sorted(np.unique(datasets, return_index=True)[1])] 124 | for dataset in datasets: 125 | df = full_str_df.loc[:, dataset] 126 | if args.markdown: 127 | print(df.rename(index=md_rename).to_markdown()) 128 | return 129 | highlight = functools.partial(bold_best, dataset=dataset, full_df=full_df, target_cov=args.coverage / 100) 130 | styler = df.style.format(na_rep="--").apply(highlight).hide(axis=1, level=2) 131 | _models = [m for m in models if m in df.columns.get_level_values(0)] 132 | tex_str = styler.to_latex( 133 | hrules=True, 134 | convert_css=True, 135 | multicol_align="c|", 136 | column_format="l" + "".join(("|" + "c" * df.loc[:, m].shape[1]) for m in _models), 137 | ) 138 | print(dataset) 139 | # No vrule after last multicol 140 | tex_str = re.sub(r"(multicolumn{\d+}{c)\|(}{" + re.sub(r"\(.*?\)", ".*?", _models[-1]) + "})", r"\1\2", tex_str) 141 | print(tex_formatting(tex_str)) 142 | 143 | 144 | def interval_table(args): 145 | dirname = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", args.dataset, args.model) 146 | fname = "results_enb.csv" if args.ensemble else "results_base.csv" 147 | ks = sorted([int(k[2:]) for k in os.listdir(dirname) if k.startswith("k=")]) 148 | full_df, full_str_df = None, None 149 | for k in ks: 150 | df = pd.read_csv(os.path.join(dirname, f"k={k}", fname), index_col=0) 151 | df = df[df["Target Coverage"] == args.coverage / 100].drop(columns="Target Coverage") 152 | df = df.rename(columns=rename_stats, index=lambda m: re.sub("CBCE", "SAOCP", m)) 153 | if df.isna().all().all() or len(df) == 0: 154 | continue 155 | for m in ["ModelSigma", "SimpleSAOCP"]: 156 | if m in df.index: 157 | df = df.drop(labels=[m]) 158 | df = df.loc[:, [c for c in df.columns if "LCE" in c]] 159 | str_df = combine_sd_df(df, interval=True, print_sd=args.print_sd) 160 | df.columns = pd.MultiIndex.from_tuples([(args.dataset, k, c) for c in df.columns]) 161 | str_df.columns = pd.MultiIndex.from_tuples([(k, c) for c in str_df.columns]) 162 | full_df = df if full_df is None else pd.concat((full_df, df), axis=1) 163 | full_str_df = str_df if full_str_df is None else pd.concat((full_str_df, str_df), axis=1) 164 | 165 | if not args.print_sd: 166 | print(full_str_df) 167 | if args.markdown: 168 | full_str_df.columns = full_str_df.columns.droplevel(1) 169 | print(full_str_df.rename(index=md_rename).to_markdown()) 170 | return 171 | highlight = functools.partial(bold_best, dataset=args.dataset, full_df=full_df, target_cov=args.coverage / 100) 172 | styler = full_str_df.style.format(na_rep="--").apply(highlight).hide(axis=1, level=2) 173 | tex_str = styler.to_latex(hrules=True, convert_css=True) 174 | tex_str = "\n".join([line for line in tex_str.split("\n") if "LCE" not in line]) 175 | tex_str = tex_formatting(tex_str) 176 | print(tex_str) 177 | 178 | 179 | def main(): 180 | args = parse_args() 181 | if args.interval_table: 182 | interval_table(args) 183 | else: 184 | primary_table(args) 185 | 186 | 187 | if __name__ == "__main__": 188 | main() 189 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /cv_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | """ 8 | Utilities for computer vision experiments. 9 | """ 10 | import hashlib 11 | import logging 12 | import os 13 | import tarfile 14 | 15 | from datasets import load_dataset 16 | import numpy as np 17 | import requests 18 | import torch 19 | import torch.distributed as dist 20 | import torch.nn as nn 21 | from torch.utils.data import Dataset 22 | import torchvision 23 | from torchvision.datasets import ImageFolder 24 | import torchvision.transforms as transforms 25 | import tqdm 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | def create_model(dataset, model_name="resnet18", device=torch.device("cpu"), **kwargs): 32 | """ 33 | Returns a model with a representation pre-trained on ImageNet, but an untrained final linear classification layer. 
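    Supported model names are resnet18, resnet50, densenet121, inception_v3, and wide_resnet50.
    For example (illustrative usage), create_model(CIFAR10("train"), model_name="resnet50")
    returns a torchvision ResNet-50 with pre-trained weights whose final fc layer is replaced
    by a fresh nn.Linear(2048, 10) head for CIFAR-10's 10 classes.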
34 | """ 35 | model_name = model_name.lower() 36 | if model_name == "resnet18": 37 | model = torchvision.models.resnet18(weights="DEFAULT", **kwargs) 38 | if not isinstance(dataset, ImageNet): 39 | model.fc = nn.Linear(512, dataset.n_class) 40 | 41 | elif model_name == "resnet50": 42 | model = torchvision.models.resnet50(weights="DEFAULT", **kwargs) 43 | if not isinstance(dataset, ImageNet): 44 | model.fc = nn.Linear(2048, dataset.n_class) 45 | 46 | elif model_name == "densenet121": 47 | model = torchvision.models.densenet121(weights="DEFAULT", **kwargs) 48 | if not isinstance(dataset, ImageNet): 49 | model.classifier = nn.Linear(1024, dataset.n_class) 50 | 51 | elif model_name == "inception_v3": 52 | model = torchvision.models.inception_v3(weights="DEFAULT", transform_input=False, **kwargs) 53 | if not isinstance(dataset, ImageNet): 54 | model.fc = nn.Linear(2048, dataset.n_class) 55 | 56 | elif model_name == "wide_resnet50": 57 | model = torchvision.models.wide_resnet50_2(weights="DEFAULT", **kwargs) 58 | if not isinstance(dataset, ImageNet): 59 | model.fc = nn.Linear(2048, dataset.n_class) 60 | 61 | else: 62 | raise NotImplementedError(f"Model {model_name} is not a supported pre-trained ImageNet model.") 63 | 64 | return model.to(device=device) 65 | 66 | 67 | def data_loader(dataset, batch_size=256, epoch=0, pin_memory=True): 68 | shuffle = dataset.split == "train" 69 | if dist.is_available() and dist.is_initialized(): 70 | sampler = torch.utils.data.DistributedSampler(dataset, shuffle=shuffle) 71 | sampler.set_epoch(epoch) 72 | elif shuffle: 73 | sampler = torch.utils.data.RandomSampler(dataset) 74 | else: 75 | sampler = torch.utils.data.SequentialSampler(dataset) 76 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler, pin_memory=pin_memory) 77 | 78 | 79 | class DatasetMixIn(Dataset): 80 | def __init__(self, split): 81 | assert split in ["train", "valid", "test"] 82 | self.split = split 83 | if split == "train": 84 | resize = [transforms.Resize(256), transforms.RandomResizedCrop(224)] 85 | else: 86 | resize = [transforms.Resize(256), transforms.CenterCrop(224)] 87 | self.transform = transforms.Compose([*resize, transforms.ToTensor(), self.norm()]) 88 | self.instance_key = "image" 89 | self.label_key = "label" 90 | 91 | def __len__(self): 92 | test_name = "test" if "test" in self.data else "valid" 93 | if self.split == "train": 94 | return len(self.train_idx) 95 | if self.split == "valid": 96 | return len(self.valid_idx) 97 | return len(self.data[test_name]) 98 | 99 | def __getitem__(self, i): 100 | test_name = "test" if "test" in self.data else "valid" 101 | if self.split == "train": 102 | i = self.train_idx[i] 103 | elif self.split == "valid": 104 | i = self.valid_idx[i] 105 | 106 | # Get the data point 107 | instance = self.data[test_name if self.split == "test" else "train"][i] 108 | return self.transform(instance[self.instance_key].convert("RGB")), instance[self.label_key] 109 | 110 | 111 | class ImageNet: 112 | def __init__(self, split, rootdir="/export/share/datasets/vision/imagenet"): 113 | if split == "train": 114 | self.data = ImageFolder(os.path.join(rootdir, "train")) 115 | else: 116 | self.data = ImageFolder(os.path.join(rootdir, "val")) 117 | resize = [transforms.Resize(256), transforms.CenterCrop(224)] 118 | self.split = split 119 | self.transform = transforms.Compose([*resize, transforms.ToTensor(), TinyImageNet.norm()]) 120 | 121 | def __len__(self): 122 | return len(self.data) if self.split == "train" else len(self.data) // 2 123 | 124 
| def __getitem__(self, i): 125 | img, label = self.data[i if self.split == "train" else 2 * i + (self.split == "test")] 126 | return self.transform(img), label 127 | 128 | @property 129 | def n_class(self): 130 | return 1000 131 | 132 | 133 | class TinyImageNet(DatasetMixIn): 134 | def __init__(self, split): 135 | super().__init__(split) 136 | self.data = load_dataset("Maysee/tiny-imagenet") 137 | self.valid_idx = list(range(0, 100000, 10)) 138 | self.train_idx = [i for i in range(100000) if i not in self.valid_idx] 139 | 140 | @classmethod 141 | def norm(cls): 142 | return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 143 | 144 | @property 145 | def n_class(self): 146 | return 200 147 | 148 | 149 | class CIFAR10(DatasetMixIn): 150 | def __init__(self, split): 151 | super().__init__(split) 152 | self.data = load_dataset("cifar10") 153 | self.train_idx = range(45000) 154 | self.valid_idx = range(45000, 50000) 155 | self.instance_key = "img" 156 | 157 | @classmethod 158 | def norm(cls): 159 | return transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 160 | 161 | @property 162 | def n_class(self): 163 | return 10 164 | 165 | 166 | class CIFAR100(DatasetMixIn): 167 | def __init__(self, split): 168 | super().__init__(split) 169 | self.data = load_dataset("cifar100") 170 | self.train_idx = range(45000) 171 | self.valid_idx = range(45000, 50000) 172 | self.instance_key = "img" 173 | self.label_key = "fine_label" 174 | 175 | @classmethod 176 | def norm(cls): 177 | return transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 178 | 179 | @property 180 | def n_class(self): 181 | return 100 182 | 183 | 184 | class Downloader(Dataset): 185 | def __len__(self): 186 | return len(self.data) 187 | 188 | def __getitem__(self, i): 189 | image, label = self.data[i] 190 | if not isinstance(self.data, DatasetMixIn): 191 | image = self.transform(image) 192 | return image, label 193 | 194 | @staticmethod 195 | def extract_tar(tar_path, extract_dir=None, success_file=None): 196 | if extract_dir is None: 197 | extract_dir = os.path.join(os.path.dirname(tar_path), os.path.basename(tar_path)[: -len(".tar")]) 198 | success_file = os.path.join(extract_dir, "_SUCCESS") if success_file is None else success_file 199 | if not os.path.isfile(success_file): 200 | logger.info(f"Extracting tarfile {os.path.basename(tar_path)}...") 201 | tarfile.open(tar_path).extractall(path=os.path.dirname(tar_path)) 202 | with open(success_file, "w"): 203 | pass 204 | 205 | @staticmethod 206 | def download(url, file_name, expected_md5): 207 | if os.path.exists(file_name): 208 | logger.info("Checking MD5 checksum...") 209 | with open(file_name, "rb") as file_to_check: 210 | data = file_to_check.read() 211 | md5_returned = hashlib.md5(data).hexdigest() 212 | if md5_returned != expected_md5: 213 | logger.info("Invalid MD5 checksum. 
Restarting download.") 214 | os.remove(file_name) 215 | 216 | if os.path.exists(file_name): 217 | return 218 | 219 | logger.info(f"Downloading file {os.path.basename(file_name)}...") 220 | os.makedirs(os.path.dirname(file_name), exist_ok=True) 221 | with open(file_name, "wb") as f: 222 | r = requests.get(url, stream=True) 223 | bar_format = "{l_bar}{bar}| {n:.1f}/{total:.1f}MB [{elapsed}<{remaining}, {rate_fmt}]" 224 | total_mb = int(r.headers["Content-Length"]) / (1024**2) 225 | with tqdm.tqdm(unit="MB", total=total_mb, bar_format=bar_format) as pbar: 226 | for chunk in r.iter_content(chunk_size=1 * (1024**2)): 227 | if chunk: # filter out keep-alive new chunks 228 | f.write(chunk) 229 | f.flush() 230 | pbar.update(len(chunk) / (1024**2)) 231 | 232 | 233 | class ImageNetC(Downloader): 234 | def __init__(self, corruption=None, severity=0): 235 | # Download dataset 236 | base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "ImageNetC") 237 | if not os.path.exists(os.path.join(base_dir, "_SUCCESS")): 238 | base = "https://zenodo.org/record/2235448/files" 239 | suf = "?download=1" 240 | self.download(f"{base}/blur.tar{suf}", f"{base_dir}/blur.tar", "2d8e81fdd8e07fef67b9334fa635e45c") 241 | self.download(f"{base}/digital.tar{suf}", f"{base_dir}/digital.tar", "89157860d7b10d5797849337ca2e5c03") 242 | self.download(f"{base}/noise.tar{suf}", f"{base_dir}/noise.tar", "e80562d7f6c3f8834afb1ecf27252745") 243 | self.download(f"{base}/weather.tar{suf}", f"{base_dir}/weather.tar", "33ffea4db4d93fe4a428c40a6ce0c25d") 244 | with open(os.path.join(base_dir, "_SUCCESS"), "w"): 245 | pass 246 | 247 | # Extract tar files 248 | self.extract_tar(f"{base_dir}/blur.tar", success_file=f"{base_dir}/_SUCCESS_blur") 249 | self.extract_tar(f"{base_dir}/digital.tar", success_file=f"{base_dir}/_SUCCESS_digital") 250 | self.extract_tar(f"{base_dir}/noise.tar", success_file=f"{base_dir}/_SUCCESS_noise") 251 | self.extract_tar(f"{base_dir}/weather.tar", success_file=f"{base_dir}/_SUCCESS_weather") 252 | 253 | # Get the actual dataset 254 | if severity == 0 or corruption is None: 255 | self.data = ImageNet(split="test") 256 | else: 257 | valid_corruptions = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))] 258 | assert severity in range(1, 6), f"Got severity={severity}. Expected an int between 1 and 5." 259 | assert corruption in valid_corruptions, f"Got corruption={corruption}. 
Expected one of {valid_corruptions}" 260 | self.data = ImageFolder(os.path.join(base_dir, corruption, str(severity))) 261 | resize = [transforms.Resize(256), transforms.CenterCrop(224)] 262 | self.transform = transforms.Compose([*resize, transforms.ToTensor(), TinyImageNet.norm()]) 263 | self.split = "test" 264 | 265 | def __getitem__(self, i): 266 | return self.data[i] if isinstance(self.data, ImageNet) else super().__getitem__(i) 267 | 268 | @property 269 | def n_class(self): 270 | return 1000 271 | 272 | 273 | class TinyImageNetC(Downloader): 274 | def __init__(self, corruption=None, severity=0): 275 | # Download data & extract tar if needed 276 | url = "https://zenodo.org/record/2469796/files/TinyImageNet-C.tar?download=1" 277 | base_dir = os.path.dirname(os.path.abspath(__file__)) 278 | file_name = os.path.join(base_dir, "data", "TinyImageNet-C.tar") 279 | data_dir = os.path.join(os.path.dirname(file_name), "TinyImageNet-C", "Tiny-ImageNet-C") 280 | if not os.path.exists(os.path.join(data_dir, "_SUCCESS")): 281 | self.download(url=url, file_name=file_name, expected_md5="3d9c6e89c2609aeb4198f84c8edd1ff0") 282 | self.extract_tar(file_name) 283 | self.extract_tar(os.path.join(os.path.dirname(file_name), "TinyImageNet-C", "Tiny-ImageNet-C.tar")) 284 | 285 | # Get the actual dataset 286 | if severity == 0 or corruption is None: 287 | self.data = TinyImageNet(split="test") 288 | else: 289 | valid_corruptions = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))] 290 | assert severity in range(1, 6), f"Got severity={severity}. Expected an int between 1 and 5." 291 | assert corruption in valid_corruptions, f"Got corruption={corruption}. Expected one of {valid_corruptions}" 292 | self.data = ImageFolder(os.path.join(data_dir, corruption, str(severity))) 293 | resize = [transforms.Resize(256), transforms.CenterCrop(224)] 294 | self.transform = transforms.Compose([*resize, transforms.ToTensor(), TinyImageNet.norm()]) 295 | self.split = "test" 296 | 297 | @property 298 | def n_class(self): 299 | return 200 300 | 301 | 302 | class CIFAR10C(Downloader): 303 | def __init__(self, corruption=None, severity=0): 304 | # Download data & extract tar if needed 305 | url = "https://zenodo.org/record/2535967/files/CIFAR-10-C.tar?download=1" 306 | base_dir = os.path.dirname(os.path.abspath(__file__)) 307 | file_name = os.path.join(base_dir, "data", "CIFAR10-C.tar") 308 | data_dir = os.path.join(base_dir, "data", "CIFAR-10-C") 309 | if not os.path.exists(os.path.join(data_dir, "_SUCCESS")): 310 | self.download(url=url, file_name=file_name, expected_md5="56bf5dcef84df0e2308c6dcbcbbd8499") 311 | self.extract_tar(file_name, extract_dir=data_dir) 312 | 313 | # Get the actual dataset 314 | if severity == 0 or corruption is None: 315 | self.data = CIFAR10(split="test") 316 | else: 317 | valid_corruptions = [d[:-4] for d in os.listdir(data_dir) if d != "labels.npy" and d.endswith(".npy")] 318 | assert severity in range(1, 6), f"Got severity={severity}. Expected an int between 1 and 5." 319 | assert corruption in valid_corruptions, f"Got corruption={corruption}. 
Expected one of {valid_corruptions}" 320 | data = np.load(os.path.join(data_dir, corruption + ".npy")) 321 | labels = np.load(os.path.join(data_dir, "labels.npy")) 322 | self.data = [(data[i], labels[i]) for i in range((severity - 1) * 10000, severity * 10000)] 323 | self.transform = transforms.Compose([transforms.ToTensor(), transforms.Resize(224), CIFAR10.norm()]) 324 | self.split = "test" 325 | 326 | @property 327 | def n_class(self): 328 | return 10 329 | 330 | 331 | class CIFAR100C(Downloader): 332 | def __init__(self, corruption=None, severity=0): 333 | # Download data & extract tar if needed 334 | url = "https://zenodo.org/record/3555552/files/CIFAR-100-C.tar?download=1" 335 | base_dir = os.path.dirname(os.path.abspath(__file__)) 336 | file_name = os.path.join(base_dir, "data", "CIFAR100-C.tar") 337 | data_dir = os.path.join(base_dir, "data", "CIFAR-100-C") 338 | if not os.path.exists(os.path.join(data_dir, "_SUCCESS")): 339 | self.download(url=url, file_name=file_name, expected_md5="11f0ed0f1191edbf9fa23466ae6021d3") 340 | self.extract_tar(file_name, extract_dir=data_dir) 341 | 342 | # Get the actual dataset 343 | if severity == 0 or corruption is None: 344 | self.data = CIFAR100(split="test") 345 | else: 346 | valid_corruptions = [d[:-4] for d in os.listdir(data_dir) if d != "labels.npy" and d.endswith(".npy")] 347 | assert severity in range(1, 6), f"Got severity={severity}. Expected an int between 1 and 5." 348 | assert corruption in valid_corruptions, f"Got corruption={corruption}. Expected one of {valid_corruptions}" 349 | data = np.load(os.path.join(data_dir, corruption + ".npy")) 350 | labels = np.load(os.path.join(data_dir, "labels.npy")) 351 | self.data = [(data[i], labels[i]) for i in range((severity - 1) * 10000, severity * 10000)] 352 | self.transform = transforms.Compose([transforms.ToTensor(), transforms.Resize(224), CIFAR10.norm()]) 353 | self.split = "test" 354 | 355 | @property 356 | def n_class(self): 357 | return 100 358 | -------------------------------------------------------------------------------- /time_series.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | """ 8 | File for running all time series experiments. 
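
Example invocation (illustrative; see parse_args below for the full set of flags):

    python time_series.py --dataset M4_Hourly --model LGBM --window 20 --target_cov 90 95

Summary tables are written to results/<dataset>/<model>/k=<window>/ as results_base.csv,
results_enb.csv, and mae.csv, which make_table.py then formats into LaTeX or Markdown tables.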
9 | """ 10 | import argparse 11 | from collections import OrderedDict 12 | from functools import partial 13 | import glob 14 | import itertools 15 | import logging 16 | import math 17 | import multiprocessing as mp 18 | import os 19 | import pickle 20 | import re 21 | import traceback 22 | 23 | import matplotlib.pyplot as plt 24 | from merlion.models.factory import ModelFactory 25 | from merlion.models.utils.autosarima_utils import ndiffs 26 | from merlion.utils import TimeSeries 27 | import numpy as np 28 | import pandas as pd 29 | from scipy.stats import norm 30 | import tqdm 31 | 32 | from online_conformal.dataset import M4, MonashTSF 33 | from online_conformal.saocp import SAOCP, EnbSAOCP 34 | from online_conformal.enbpi import EnbPI, EnbMixIn 35 | from online_conformal.faci import FACI, FACI_S, EnbFACI 36 | from online_conformal.model_sigma import ModelSigma 37 | from online_conformal.nex_conformal import NExConformal, EnbNEx 38 | from online_conformal.ogd import ScaleFreeOGD, EnbOGD 39 | from online_conformal.split_conformal import SplitConformal 40 | from online_conformal.utils import coverage, interval_miscoverage, interval_regret, mae, width 41 | 42 | 43 | logger = logging.getLogger(__name__) 44 | 45 | name2dataset = dict( 46 | M4_Hourly=lambda: M4("Hourly"), 47 | M4_Daily=lambda: M4("Daily"), 48 | M4_Weekly=lambda: M4("Weekly"), 49 | NN5_Daily=lambda: MonashTSF("nn5_daily", freq="1D", horizon=30), 50 | ) 51 | 52 | name2model = dict( 53 | Prophet=dict(name="Prophet", target_seq_index=0), 54 | LGBM=dict(name="LGBMForecaster", n_jobs=2, target_seq_index=0), 55 | ARIMA=dict(name="Arima", order=(10, None, 10), target_seq_index=0, transform=dict(name="Identity")), 56 | ) 57 | 58 | 59 | def parse_args(): 60 | results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results") 61 | parser = argparse.ArgumentParser(description="Runs conformal prediction experiments on time series datasets.") 62 | parser.add_argument("--dirname", type=str, default=results_dir, help="Directory where results are stored.") 63 | parser.add_argument("--dataset", type=str, default="M4_Hourly", choices=list(name2dataset.keys())) 64 | parser.add_argument("--model", type=str, default="LGBM", choices=list(name2model.keys())) 65 | parser.add_argument("--window", type=int, default=20, help="Interval length for strongly adaptive evaluation.") 66 | parser.add_argument("--target_cov", type=int, nargs="*", help="The target coverages (as a percent).") 67 | parser.add_argument("--njobs", type=int, default=None, help="The number of parallel processes to use") 68 | parser.add_argument("--skip_train", action="store_true", help="Skip running models & only use saved results.") 69 | parser.add_argument("--start", type=int, default=0, help="The index to start at. For parallelization.") 70 | parser.add_argument("--end", type=int, default=None, help="The index to end at. 
For parallelization.") 71 | parser.add_argument("--ignore_checkpoint", action="store_true", help="Ignore saved results & start over.") 72 | parser.add_argument("--skip_model_sigma", action="store_true", help="Skip visualizing model's own uncertainty.") 73 | parser.add_argument("--skip_ensemble", action="store_true", help="Skip using ensemble methods.") 74 | args = parser.parse_args() 75 | 76 | # Set full dirname & convert various arguments to the forms expected downstream 77 | args.dirname = os.path.join(args.dirname, args.dataset, args.model) 78 | args.target_cov = np.asarray(args.target_cov or [80, 90, 95]) / 100 79 | if args.njobs is None: 80 | args.njobs = math.ceil(mp.cpu_count() / 2) 81 | if args.model == "LGBM": 82 | args.njobs = math.ceil(args.njobs / 2) 83 | args.model = name2model[args.model] 84 | return args 85 | 86 | 87 | def evaluate(model, train_data, test_data, horizon, target_covs, calib_frac, ensemble=True, verbose=False, cache=None): 88 | cache = cache or {} 89 | target_idx = None 90 | base_predictor, base_ensemble = None, None 91 | if "ARIMA" in model["name"].upper(): 92 | model["order"] = (model["order"][0], ndiffs(train_data.iloc[:, 0].dropna()), model["order"][2]) 93 | model = ModelFactory.create(**model) if isinstance(model, dict) else model 94 | if not isinstance(target_covs, list): 95 | target_covs = [target_covs] 96 | predictors = OrderedDict() 97 | methods = [SplitConformal, NExConformal, FACI, ScaleFreeOGD, FACI_S, SAOCP, ModelSigma] 98 | methods += [EnbPI, EnbNEx, EnbFACI, EnbOGD, EnbSAOCP] 99 | method_covs = list(itertools.product(methods, target_covs)) 100 | kwargs = dict(train_data=train_data, calib_frac=calib_frac, horizon=horizon) 101 | for method, cov in tqdm.tqdm(method_covs, desc="Model Training", disable=not verbose): 102 | # Collect all the predictors after training the base models 103 | method_name = method.__name__ 104 | if cov in cache and (method_name in cache[cov] or re.sub("SAOCP", "CBCE", method_name) in cache[cov]): 105 | predictors[(method_name, cov)] = None 106 | continue 107 | if issubclass(method, EnbMixIn) and not ensemble: 108 | continue 109 | try: 110 | if issubclass(method, EnbMixIn) and base_ensemble is None: 111 | predictor = method(model, coverage=cov, **kwargs) 112 | base_ensemble = predictor 113 | elif not issubclass(method, EnbMixIn) and base_predictor is None: 114 | predictor = method(model, coverage=cov, **kwargs) 115 | target_idx = predictor.model.target_seq_index 116 | base_predictor = predictor 117 | else: 118 | other = base_ensemble if issubclass(method, EnbMixIn) else base_predictor 119 | predictor = method.from_other(other, coverage=cov) 120 | predictors[(method_name, cov)] = predictor 121 | except Exception as e: 122 | if method is ModelSigma: # model doesn't support uncertainty estimation 123 | continue 124 | elif issubclass(method, EnbMixIn): # model is incompatible with ensembles 125 | ensemble = False 126 | continue 127 | else: 128 | raise e 129 | 130 | # Do the forecasting 131 | t0 = test_data.index[0] 132 | if all(p is None for p in predictors.values()): 133 | target = None 134 | else: 135 | target = test_data.iloc[:, target_idx] 136 | if horizon > 1: 137 | test_data = pd.concat((train_data.iloc[-horizon + 1 :], test_data)) 138 | train_data = train_data.iloc[: -horizon + 1] 139 | 140 | yhat, lb, ub = [OrderedDict((k, []) for k in predictors.keys()) for _ in range(3)] 141 | for i in tqdm.trange(len(test_data), desc="Forecasting", disable=not verbose): 142 | # Don't do anything if we've cached all the results already 
143 | if all(p is None for p in predictors.values()): 144 | break 145 | # Get the base model's forecast for this timestamp, and then move the train data forward one step 146 | y_t = test_data.iloc[i : i + horizon] 147 | if base_predictor is not None: 148 | base_yhat_t, err_t = base_predictor.model.forecast(y_t.index, TimeSeries.from_pd(train_data)) 149 | base_yhat_t = base_yhat_t.to_pd().iloc[:, 0] 150 | err_t = None if err_t is None else err_t.to_pd().iloc[:, 0] 151 | else: 152 | base_yhat_t = err_t = None 153 | if base_ensemble is not None: 154 | ens_yhat_t = base_ensemble.model.forecast(y_t.index, TimeSeries.from_pd(train_data))[0].to_pd().iloc[:, 0] 155 | else: 156 | ens_yhat_t = None 157 | train_data = pd.concat((train_data, y_t.iloc[:1])) 158 | 159 | # Obtain error bars from each predictor, and then update the predictor 160 | for (method, cov), predictor in predictors.items(): 161 | k = (method, cov) 162 | if predictor is None: # cached results 163 | continue 164 | yhat_t = ens_yhat_t if isinstance(predictor, EnbMixIn) else base_yhat_t 165 | if isinstance(predictor, ModelSigma): 166 | if err_t is None: 167 | if k in yhat: 168 | del yhat[k], lb[k], ub[k] 169 | continue 170 | lb_t = yhat_t + err_t * norm.ppf((1 - cov) / 2) 171 | ub_t = yhat_t + err_t * norm.ppf(1 - (1 - cov) / 2) 172 | else: 173 | lb_t, ub_t = zip(*[predictor.predict(h + 1) for h in range(len(yhat_t))]) 174 | lb_t = yhat_t + np.asarray(lb_t) 175 | ub_t = yhat_t + np.asarray(ub_t) 176 | for h in range(len(y_t)): 177 | if y_t.index[h] >= t0 and not np.isnan(y_t.iloc[h, target_idx]) and not np.isnan(yhat_t.iloc[h]): 178 | predictor.update(y_t.iloc[h : h + 1, target_idx], yhat_t.iloc[h : h + 1], horizon=h + 1) 179 | yhat[k].append(yhat_t) 180 | lb[k].append(lb_t) 181 | ub[k].append(ub_t) 182 | 183 | # Aggregate forecasts & error bars for each horizon 184 | results = OrderedDict() 185 | for method, cov in yhat.keys(): 186 | if cov not in results: 187 | results[cov] = OrderedDict() 188 | if cov in cache and (method in cache[cov] or re.sub("SAOCP", "CBCE", method) in cache[cov]): 189 | results[cov][method] = cache[cov].get(method, cache[cov][re.sub("SAOCP", "CBCE", method)]) 190 | else: 191 | yhat_k, lb_k, ub_k = [ 192 | { 193 | h + 1: pd.concat([x.iloc[h : h + 1] for x in ts if len(x) > h and x.index[h] >= t0]) 194 | for h in range(horizon) 195 | } 196 | for ts in [yhat[(method, cov)], lb[(method, cov)], ub[(method, cov)]] 197 | ] 198 | results[cov][method] = {"ground_truth": target, "forecast": [yhat_k, lb_k, ub_k], "target_cov": cov} 199 | 200 | return results 201 | 202 | 203 | def summarize_results(all_results, window): 204 | def construct(true, pred): 205 | return pd.concat([pred[t % len(pred) + 1].iloc[t : t + 1] for t in range(len(true))]) 206 | 207 | summaries = OrderedDict() 208 | for cov, cov_results in all_results.items(): 209 | summary = [] 210 | methods = [re.sub("CBCE", "SAOCP", method) for method in cov_results.keys()] 211 | for method, result in zip(methods, cov_results.values()): 212 | y = result["ground_truth"] 213 | yhat, lb, ub = result["forecast"] 214 | horizons = ["full"] + sorted(yhat.keys()) 215 | yhat["full"], lb["full"], ub["full"] = construct(y, yhat), construct(y, lb), construct(y, ub) 216 | kwargs = dict(cov=result["target_cov"], window=min(window, len(y))) 217 | int_miscov = partial(interval_miscoverage, **kwargs) 218 | int_regret = partial(interval_regret, **kwargs) 219 | for i, fn in enumerate([coverage, width, int_miscov, int_regret, mae]): 220 | if len(summary) < i + 1: 221 | 
summary.append(pd.DataFrame(0, index=pd.Index(horizons, name="Horizon"), columns=methods)) 222 | summary[i].loc[horizons, method] = [fn(y, yhat[h], lb[h], ub[h]) for h in horizons] 223 | summaries[cov] = summary 224 | return summaries 225 | 226 | 227 | def summarize_file(fname, window=20): 228 | with open(fname, "rb") as f: 229 | results = pickle.load(f) 230 | ts_target_cov = list(results.values())[0]["target_cov"] 231 | return ts_target_cov, summarize_results({ts_target_cov: results}, window=window)[ts_target_cov] 232 | 233 | 234 | def synthesize_results_dir(dirname: str, window=20, njobs=1): 235 | target_cov = None 236 | full_summary = [] 237 | files = sorted(glob.glob(os.path.join(dirname, "*.pkl")), key=lambda k: int(re.search(r"(\d+)\.pkl", k).group(1))) 238 | if len(files) == 0: 239 | raise RuntimeError(f"Directory {dirname} has no .pkl files of results in it.") 240 | with mp.Pool(njobs) as pool: 241 | with tqdm.tqdm(total=len(files), desc="Analyzing Results", leave=False) as pbar: 242 | for ts_target_cov, summ in pool.imap_unordered(partial(summarize_file, window=window), files): 243 | if target_cov is None: 244 | target_cov = ts_target_cov 245 | assert ts_target_cov == target_cov 246 | if any((df > 1000).any().any() for df in summ): # Outlier removal 247 | continue 248 | for i, df in enumerate(summ): 249 | if len(full_summary) < i + 1: 250 | full_summary.append([df]) 251 | else: 252 | full_summary[i].append(df) 253 | pbar.update(1) 254 | 255 | gbs = tuple(pd.concat(summ).groupby("Horizon", dropna=False) for summ in full_summary) 256 | mu = {target_cov: tuple(gb.mean() for gb in gbs)} 257 | sd = {target_cov: tuple(gb.apply(lambda s: pd.Series(s.std() / np.sqrt(len(s)))) for gb in gbs)} 258 | return mu, sd 259 | 260 | 261 | def visualize(summaries, ensemble=False, skip_model_sigma=True, plot_regret=True): 262 | def skip(name): 263 | extra_check = name == "ModelSigma" and skip_model_sigma 264 | return extra_check or ("Enb" in name and not ensemble) or ("Enb" not in name and ensemble) 265 | 266 | figs = OrderedDict() 267 | for target_cov, stats in summaries.items(): 268 | cov, subopt, miscov, regret = stats[:4] 269 | results = [("Coverage", target_cov, cov), ("Width", None, subopt), ("Interval Miscoverage", 0, miscov)] 270 | if plot_regret: 271 | results.append(("Interval Regret", 0, regret)) 272 | nrows = math.ceil(len(results) / 3) 273 | ncols = math.ceil(len(results) / nrows) 274 | fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(6 * ncols, 4 * nrows + 1), facecolor="w") 275 | axs = axs.reshape(nrows, ncols) 276 | for i, (title, baseline_target, df) in enumerate(results): 277 | i, j = i // ncols, i % ncols 278 | df = df.loc[[h for h in df.index if isinstance(h, int)], [c for c in df.columns if not skip(c)]] 279 | if baseline_target is not None: 280 | axs[i, j].axhline(baseline_target, ls="--", c="k", label="target") 281 | for k, method in enumerate(df.columns): 282 | c = 1 if method == "SAOCP" else k + int(k > 0) 283 | axs[i, j].plot(df.loc[:, method], label=method, color=f"C{c}") 284 | axs[i, j].set_xlabel(df.index.name, fontsize=14) 285 | axs[i, j].set_title(title, fontsize=16) 286 | if i == j == 0: 287 | fig.legend() 288 | fig.suptitle(f"Target Coverage = {target_cov:.3f}", fontsize=20) 289 | fig.tight_layout() 290 | figs[target_cov] = fig 291 | return figs 292 | 293 | 294 | def main_loop(i_data_args): 295 | cache, fnames = {}, {} 296 | i, data, args = i_data_args 297 | covs = list(args.target_cov) 298 | if not args.start <= i < args.end: 299 | return None, None 300 | 
for cov in covs: 301 | fname = os.path.join(args.dirname, str(int(cov * 100)), f"{i}.pkl") 302 | fnames[cov] = fname 303 | if os.path.exists(fname) and not args.ignore_checkpoint: 304 | try: 305 | with open(fname, "rb") as f: 306 | cache[cov] = pickle.load(f) 307 | except: 308 | continue 309 | logging.basicConfig(format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", level=logging.ERROR) 310 | try: 311 | return fnames, evaluate(args.model, **data, target_covs=covs, ensemble=not args.skip_ensemble, cache=cache) 312 | except Exception: 313 | return fnames, f"Exception on time series {i}\n{traceback.format_exc()}" 314 | 315 | 316 | def main(): 317 | args = parse_args() 318 | dataset = name2dataset[args.dataset]() 319 | dirnames = {cov: os.path.join(args.dirname, str(int(cov * 100))) for cov in args.target_cov} 320 | for dirname in dirnames.values(): 321 | os.makedirs(dirname, exist_ok=True) 322 | os.makedirs(os.path.join(args.dirname, f"k={args.window}"), exist_ok=True) 323 | os.makedirs(os.path.join(args.dirname, f"k={args.window}", "figures"), exist_ok=True) 324 | 325 | if not args.skip_train: 326 | n = len(dataset) 327 | args.end = n if args.end is None else args.end 328 | with tqdm.trange(n, desc="Dataset") as pbar: 329 | with mp.Pool(args.njobs) as pool: 330 | for fnames, results in pool.imap_unordered(main_loop, map(lambda i: (i, dataset[i], args), range(n))): 331 | if isinstance(results, str): 332 | logger.error(results) 333 | elif isinstance(results, dict): 334 | for cov, cov_results in results.items(): 335 | with open(fnames[cov], "wb") as f: 336 | pickle.dump(cov_results, f) 337 | pbar.update(1) 338 | if args.start != 0 or args.end != n: 339 | return 340 | 341 | idx_cols = ["Method", "Target Coverage"] 342 | cols = ["Coverage", "Width", "Interval Miscoverage", "Interval Regret"] 343 | err_cols = ["MAE"] 344 | table = pd.DataFrame(columns=idx_cols + cols).set_index(idx_cols) 345 | enb_table = table.copy() 346 | mae_table = pd.DataFrame(columns=err_cols) 347 | for target_cov, dirname in dirnames.items(): 348 | # Create a table & save it 349 | njobs = args.njobs * 2 if "LGBM" in args.model["name"] else args.njobs 350 | summ, sd = synthesize_results_dir(dirname, njobs=njobs, window=args.window) 351 | for col_name, data, data_std in zip(cols + err_cols, *summ.values(), *sd.values()): 352 | if col_name in err_cols: 353 | enb = [m for m in data.columns if "Enb" in m] 354 | base = [m for m in data.columns if "Enb" not in m] 355 | if len(base) > 0: 356 | mae_table.loc["Base", col_name] = data.loc["full", base[0]] 357 | mae_table.loc["Base SD", col_name] = data_std.loc["full", base[0]] 358 | if len(enb) > 0: 359 | mae_table.loc["Enb", col_name] = data.loc["full", enb[0]] 360 | mae_table.loc["Enb SD", col_name] = data_std.loc["full", enb[0]] 361 | continue 362 | for method in data.columns: 363 | t = enb_table if "Enb" in method else table 364 | t.loc[(method, target_cov), col_name] = data.loc["full", method] 365 | t.loc[(method, target_cov), col_name + " SD"] = data_std.loc["full", method] 366 | table.to_csv(os.path.join(args.dirname, f"k={args.window}", "results_base.csv")) 367 | enb_table.to_csv(os.path.join(args.dirname, f"k={args.window}", "results_enb.csv")) 368 | mae_table.to_csv(os.path.join(args.dirname, f"k={args.window}", "mae.csv")) 369 | 370 | # Make & save figures 371 | figdir = os.path.join(args.dirname, f"k={args.window}", "figures") 372 | fig = visualize(summ, ensemble=False, skip_model_sigma=args.skip_model_sigma)[target_cov] 373 | fig_enb = visualize(summ, 
ensemble=True, skip_model_sigma=args.skip_model_sigma)[target_cov] 374 | fig.savefig(os.path.join(figdir, f"{int(target_cov * 100)}_results_base.png")) 375 | fig_enb.savefig(os.path.join(figdir, f"{int(target_cov * 100)}_results_enb.png")) 376 | 377 | 378 | if __name__ == "__main__": 379 | logging.basicConfig(format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", level=logging.ERROR) 380 | main() 381 | -------------------------------------------------------------------------------- /vision.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023 salesforce.com, inc. 3 | # All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0 6 | # 7 | """ 8 | File for running all computer vision experiments. 9 | """ 10 | import argparse 11 | from collections import defaultdict 12 | import math 13 | import os 14 | from re import sub 15 | 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | import pandas as pd 19 | from scipy.ndimage import gaussian_filter1d 20 | import torch 21 | import torch.distributed as dist 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | from torch.optim import Adam, LBFGS, SGD 25 | import tqdm 26 | 27 | from online_conformal.saocp import SAOCP 28 | from online_conformal.faci import FACI, FACI_S 29 | from online_conformal.nex_conformal import NExConformal 30 | from online_conformal.ogd import ScaleFreeOGD 31 | from online_conformal.split_conformal import SplitConformal 32 | from online_conformal.utils import pinball_loss 33 | from cv_utils import create_model, data_loader 34 | from cv_utils import ImageNet, TinyImageNet, CIFAR10, CIFAR100, ImageNetC, TinyImageNetC, CIFAR10C, CIFAR100C 35 | 36 | 37 | corruptions = [ 38 | None, 39 | "brightness", 40 | "contrast", 41 | "defocus_blur", 42 | "elastic_transform", 43 | "fog", 44 | "frost", 45 | "gaussian_noise", 46 | "glass_blur", 47 | "impulse_noise", 48 | "jpeg_compression", 49 | "motion_blur", 50 | "pixelate", 51 | "shot_noise", 52 | "snow", 53 | "zoom_blur", 54 | ] 55 | 56 | 57 | def parse_args(): 58 | parser = argparse.ArgumentParser( 59 | description=f"Runs conformal prediction experiments on computer vision datasets. If you want to do multi-GPU " 60 | f"training, call this file with `torchrun --nproc_per_node <num_gpus> {os.path.basename(__file__)} ...`. " 61 | f"If training is already finished, single-process execution is recommended." 
62 | ) 63 | parser.add_argument( 64 | "--dataset", 65 | required=True, 66 | choices=["ImageNet", "TinyImageNet", "CIFAR10", "CIFAR100"], 67 | help="Dataset to run on.", 68 | ) 69 | parser.add_argument("--model", default="resnet50", help="Model architecture to use.") 70 | parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate for training.") 71 | parser.add_argument("--batch_size", default=256, type=int, help="Batch size for data loader.") 72 | parser.add_argument("--n_epochs", default=150, type=int, help="Number of epochs to train for.") 73 | parser.add_argument("--patience", default=10, type=int, help="Number of epochs before early stopping.") 74 | parser.add_argument("--ignore_checkpoint", action="store_true", help="Whether to restart from scratch.") 75 | parser.add_argument("--target_cov", default=90, type=int, help="The target coverage (as a percent).") 76 | args = parser.parse_args() 77 | assert 50 < args.target_cov < 100 78 | args.target_cov = args.target_cov / 100 79 | 80 | # Set up distributed training if desired, and set the device 81 | args.local_rank = int(os.environ.get("LOCAL_RANK", -1)) 82 | if args.local_rank == -1: 83 | if torch.cuda.is_available(): 84 | args.device = torch.device("cuda") 85 | else: 86 | args.device = torch.device("cpu") 87 | args.world_size = 1 88 | else: 89 | dist.init_process_group(backend="nccl") 90 | args.device = torch.device(args.local_rank) 91 | args.world_size = dist.get_world_size() 92 | 93 | return args 94 | 95 | 96 | def get_base_dataset(dataset, split): 97 | if dataset == "ImageNet": 98 | return ImageNet(split) 99 | elif dataset == "TinyImageNet": 100 | return TinyImageNet(split) 101 | elif dataset == "CIFAR10": 102 | return CIFAR10(split) 103 | elif dataset == "CIFAR100": 104 | return CIFAR100(split) 105 | raise ValueError(f"Dataset {dataset} is not supported.") 106 | 107 | 108 | def get_model_file(args): 109 | rootdir = os.path.dirname(os.path.abspath(__file__)) 110 | return os.path.join(rootdir, "cv_models", args.dataset, args.model, "model.pt") 111 | 112 | 113 | def get_model(args): 114 | if args.dataset != "ImageNet": 115 | return torch.load(get_model_file(args), map_location=args.device) 116 | return create_model(dataset=ImageNet("valid"), model_name=args.model, device=args.device) 117 | 118 | 119 | def get_results_file(args, corruption, severity): 120 | rootdir = os.path.dirname(os.path.abspath(__file__)) 121 | return os.path.join(rootdir, "cv_logits", args.dataset, args.model, f"{corruption}_{severity}.pt") 122 | 123 | 124 | def get_temp_file(args): 125 | return os.path.join(os.path.dirname(get_results_file(args, None, 0)), "temp.txt") 126 | 127 | 128 | def finished(args): 129 | for corruption in corruptions: 130 | for severity in [0] if corruption is None else [1, 2, 3, 4, 5]: 131 | fname = get_results_file(args, corruption, severity) 132 | if not os.path.isfile(fname): 133 | return False 134 | return os.path.isfile(get_temp_file(args)) 135 | 136 | 137 | def raps_params(dataset): 138 | if dataset == "CIFAR10": 139 | lmbda, k_reg, n_class = 0.1, 1, 10 140 | elif dataset == "CIFAR100": 141 | lmbda, k_reg, n_class = 0.02, 5, 100 142 | elif dataset == "TinyImageNet": 143 | lmbda, k_reg, n_class = 0.01, 20, 200 144 | elif dataset == "ImageNet": 145 | lmbda, k_reg, n_class = 0.01, 10, 1000 146 | else: 147 | raise ValueError(f"Unsupported dataset {dataset}") 148 | return lmbda, k_reg, n_class 149 | 150 | 151 | def train(args): 152 | # Get train/valid data 153 | train_data = get_base_dataset(args.dataset, "train") 154 | valid_data = get_base_dataset(args.dataset, 
"valid") 155 | 156 | # Load model checkpoint one has been saved. Otherwise, initialize everything from scratch. 157 | model_file = get_model_file(args) 158 | ckpt_name = os.path.join(os.path.dirname(model_file), "checkpoint.pt") 159 | if os.path.isfile(ckpt_name) and not args.ignore_checkpoint: 160 | model, opt, epoch, best_epoch, best_valid_acc = torch.load(ckpt_name, map_location=args.device) 161 | else: 162 | # create save directory if needed 163 | if args.local_rank in [-1, 0]: 164 | os.makedirs(os.path.dirname(ckpt_name), exist_ok=True) 165 | model = create_model(dataset=train_data, model_name=args.model, device=args.device) 166 | if "ImageNet" in args.dataset: 167 | opt = SGD(model.parameters(), lr=0.1, momentum=0.9) 168 | else: 169 | opt = Adam(model.parameters(), lr=args.lr) 170 | epoch, best_epoch, best_valid_acc = 0, 0, 0.0 171 | 172 | # Set up distributed data parallel if applicable 173 | writer = args.local_rank in [-1, 0] 174 | if args.local_rank != -1: 175 | model = nn.parallel.DistributedDataParallel(model, device_ids=[args.device]) 176 | 177 | for epoch in range(epoch, args.n_epochs): 178 | # Check early stopping condition 179 | if args.patience and epoch - best_epoch > args.patience: 180 | break 181 | 182 | # Main training loop 183 | train_loader = data_loader(dataset=train_data, batch_size=args.batch_size // args.world_size, epoch=epoch) 184 | for x, y in tqdm.tqdm(train_loader, desc=f"Train epoch {epoch+1:2}/{args.n_epochs}", disable=not writer): 185 | opt.zero_grad() 186 | pred = model(x.to(device=args.device)) 187 | loss = F.cross_entropy(pred, y.to(device=args.device)) 188 | loss.backward() 189 | opt.step() 190 | 191 | # Anneal learning rate by a factor of 10 every 7 epochs 192 | if (epoch + 1) % 7 == 0: 193 | for g in opt.param_groups: 194 | g["lr"] *= 0.1 195 | 196 | # Obtain accuracy on the validation dataset 197 | valid_acc = torch.zeros(2, device=args.device) 198 | valid_loader = data_loader(valid_data, batch_size=args.batch_size, epoch=epoch) 199 | with torch.no_grad(): 200 | for x, y in tqdm.tqdm(valid_loader, desc=f"Valid epoch {epoch + 1:2}/{args.n_epochs}", disable=True): 201 | pred = model(x.to(device=args.device)) 202 | valid_acc[0] += x.shape[0] 203 | valid_acc[1] += (pred.argmax(dim=-1) == y.to(device=args.device)).sum().item() 204 | 205 | # Reduce results from all parallel processes 206 | if args.local_rank != -1: 207 | dist.all_reduce(valid_acc) 208 | valid_acc = (valid_acc[1] / valid_acc[0]).item() 209 | 210 | # Save checkpoint & update best saved model 211 | if writer: 212 | print(f"Epoch {epoch + 1:2} valid acc: {valid_acc:.5f}") 213 | model_to_save = model.module if args.local_rank != -1 else model 214 | if valid_acc > best_valid_acc: 215 | best_epoch = epoch 216 | best_valid_acc = valid_acc 217 | torch.save(model_to_save, model_file) 218 | torch.save([model_to_save, opt, epoch + 1, best_epoch, best_valid_acc], ckpt_name) 219 | 220 | # Synchronize before starting next epoch 221 | if args.local_rank != -1: 222 | dist.barrier() 223 | 224 | 225 | def temperature_scaling(args): 226 | temp = nn.Parameter(torch.tensor(1.0, device=args.device)) 227 | opt = LBFGS([temp], lr=0.01, max_iter=500) 228 | loss_fn = nn.CrossEntropyLoss() 229 | 230 | n_epochs = 10 231 | valid_data = get_base_dataset(args.dataset, "valid") 232 | model = get_model(args) 233 | for epoch in range(n_epochs): 234 | valid_loader = data_loader(valid_data, batch_size=args.batch_size, epoch=epoch) 235 | for x, y in tqdm.tqdm(valid_loader, desc=f"Calibration epoch {epoch + 1:2}/{n_epochs}", 
disable=False): 236 | with torch.no_grad(): 237 | logits = model(x.to(device=args.device)) 238 | 239 | def eval(): 240 | opt.zero_grad() 241 | loss = loss_fn(logits / temp, y.to(device=args.device)) 242 | loss.backward() 243 | return loss 244 | 245 | opt.step(eval) 246 | 247 | return temp.item() 248 | 249 | 250 | def get_logits(args): 251 | if args.dataset == "CIFAR10": 252 | dataset_cls = CIFAR10C 253 | elif args.dataset == "CIFAR100": 254 | dataset_cls = CIFAR100C 255 | elif args.dataset == "TinyImageNet": 256 | dataset_cls = TinyImageNetC 257 | elif args.dataset == "ImageNet": 258 | dataset_cls = ImageNetC 259 | else: 260 | raise ValueError(f"Dataset {args.dataset} is not supported.") 261 | model = None 262 | for corruption in tqdm.tqdm(corruptions, desc="Corruptions", position=1): 263 | severities = [0] if corruption is None else [1, 2, 3, 4, 5] 264 | for severity in tqdm.tqdm(severities, desc="Severity Levels", position=2, leave=False): 265 | fname = get_results_file(args, corruption, severity) 266 | if os.path.isfile(fname) and not args.ignore_checkpoint: 267 | continue 268 | os.makedirs(os.path.dirname(fname), exist_ok=True) 269 | if model is None: 270 | model = get_model(args) 271 | 272 | # Save the model's logits & labels for the whole dataset 273 | logits, labels = [], [] 274 | dataset = dataset_cls(corruption=corruption, severity=severity) 275 | loader = data_loader(dataset, batch_size=args.batch_size) 276 | with torch.no_grad(): 277 | for x, y in loader: 278 | logits.append(model(x.to(device=args.device)).cpu()) 279 | labels.append(y.cpu()) 280 | torch.save([torch.cat(logits), torch.cat(labels)], fname) 281 | 282 | 283 | def t_to_sev(t, window, run_length=100, schedule=None): 284 | if t < window or schedule in [None, "None", "none"]: 285 | return 0 286 | t_base = t - window // 2 287 | if schedule == "gradual": 288 | k = (t_base // run_length) % 10 289 | return k if k <= 5 else 10 - k 290 | return 5 * ((t_base // run_length) % 2) 291 | 292 | 293 | def main(): 294 | # Train the model, save its logits on all the corrupted test datasets, and do temperature scaling 295 | args = parse_args() 296 | if not finished(args) and args.dataset != "ImageNet": 297 | train(args) 298 | if args.local_rank in [-1, 0]: 299 | temp_file = get_temp_file(args) 300 | if not finished(args): 301 | get_logits(args) 302 | temp = temperature_scaling(args) 303 | with open(temp_file, "w") as f: 304 | f.write(str(temp)) 305 | 306 | # Load the saved logits 307 | with open(temp_file) as f: 308 | temp = float(f.readline()) 309 | n_data = None 310 | sev2results = defaultdict(list) 311 | for corruption in corruptions: 312 | severities = [0] if corruption is None else [1, 2, 3, 4, 5] 313 | for severity in severities: 314 | try: 315 | logits, labels = torch.load(get_results_file(args, corruption, severity)) 316 | except: 317 | continue 318 | sev2results[severity].append(list(zip(F.softmax(logits / temp, dim=-1).numpy(), labels.numpy()))) 319 | n_data = len(labels) if n_data is None else min(n_data, len(labels)) 320 | 321 | # Initialize conformal prediction methods, along with accumulators for results 322 | lmbda, k_reg, n_class = raps_params(args.dataset) 323 | D = 1 + lmbda * np.sqrt(n_class - k_reg) 324 | methods = [SplitConformal, NExConformal, FACI, ScaleFreeOGD, FACI_S, SAOCP] 325 | 326 | label2err = defaultdict(list) 327 | plt.rcParams["text.usetex"] = True 328 | h = 5 + 0.5 * (len(methods) > 5) 329 | fig, axs = plt.subplots(nrows=3, ncols=2, sharex="col", sharey="row", figsize=(12, h), height_ratios=[4, 4, 2]) 
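# The loop below evaluates every conformal method under two corruption schedules ("sudden" and "gradual"): it
# streams permuted test examples, converts each temperature-scaled softmax output into a randomized RAPS-style
# conformity score, queries each predictor for a score threshold via predict(horizon=1), records local coverage
# and prediction-set size, and then feeds the realized score back through update(). The first `warmup` steps only
# update the predictors and are excluded from the reported metrics.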
330 | for i_shift, shift in enumerate(["sudden", "gradual"]): 331 | sevs, s_opts, w_opts = [], [], [] 332 | warmup, window, run_length = 1000, 100, 500 333 | state = np.random.RandomState(0) 334 | order = state.permutation(n_data)[: 6 * run_length + window // 2 + warmup] 335 | coverages, s_hats, widths = [{m.__name__: [] for m in methods} for _ in range(3)] 336 | predictors = [m(None, None, max_scale=D, lifetime=32, coverage=args.target_cov) for m in methods] 337 | for t, i in tqdm.tqdm(enumerate(order, start=-warmup), total=len(order)): 338 | # Get saved results for the desired severity 339 | sev = t_to_sev(t, window=window, run_length=run_length, schedule=shift) 340 | probs, label = sev2results[sev][state.randint(0, len(sev2results[sev]))][i] 341 | 342 | # Convert probability to APS score 343 | i_sort = np.flip(np.argsort(probs)) 344 | p_sort_cumsum = np.cumsum(probs[i_sort]) - state.rand() * probs[i_sort] 345 | s_sort_cumsum = p_sort_cumsum + lmbda * np.sqrt(np.cumsum([i > k_reg for i in range(n_class)])) 346 | w_opt = np.argsort(i_sort)[label] + 1 347 | s_opt = s_sort_cumsum[w_opt - 1] 348 | if t >= 0: 349 | sevs.append(sev) 350 | s_opts.append(s_opt) 351 | w_opts.append(w_opt) 352 | 353 | # Update all the conformal predictors 354 | for predictor in predictors: 355 | name = type(predictor).__name__ 356 | if t >= 0: 357 | _, s_hat = predictor.predict(horizon=1) 358 | w = np.sum(s_sort_cumsum <= s_hat) 359 | s_hats[name].append(s_hat) 360 | widths[name].append(w) 361 | coverages[name].append(w >= w_opt) 362 | predictor.update(ground_truth=pd.Series([s_opt]), forecast=pd.Series([0]), horizon=1) 363 | 364 | # Perform evaluation & produce a pretty graph 365 | plot_loss = False 366 | for ax in axs[:, i_shift]: 367 | ax.xaxis.grid(True) 368 | ax.tick_params(axis="both", which="both", labelsize=10) 369 | 370 | ax1, ax2, ax3 = axs[:, i_shift] 371 | sevs = pd.Series(sevs).rolling(window).mean().dropna() 372 | w_opts = pd.Series(s_opts if plot_loss else w_opts).rolling(window).quantile(args.target_cov).dropna() 373 | ax1.set_ylabel("Local Coverage", fontsize=10) 374 | ax2.set_ylabel("Prediction Set Size", fontsize=10) 375 | ax3.set_xlabel("Time", fontsize=10) 376 | ax3.set_ylabel("Corruption Level", fontsize=10) 377 | ax1.axhline(args.target_cov, c="k", ls="--", lw=2, zorder=len(methods), label="Best Fixed") 378 | ax2.plot(range(len(w_opts)), gaussian_filter1d(w_opts, sigma=2), c="k", ls="--", lw=2, zorder=len(methods)) 379 | ax3.plot(range(len(sevs)), sevs, c="k") 380 | 381 | s_opts = np.asarray(s_opts) 382 | int_q = pd.Series(s_opts).rolling(window).quantile(args.target_cov).dropna() 383 | print(f"Distribution shift: {shift}") 384 | for i, m in enumerate(methods): 385 | # Compute various summary statistics 386 | name = m.__name__ 387 | label = sub("Split", "S", sub("Conformal", "CP", sub("ScaleFree", "SF-", sub("_", "-", name)))) 388 | s_hat = np.asarray(s_hats[name]) 389 | int_cov = gaussian_filter1d(pd.Series(coverages[name]).rolling(window).mean().dropna(), sigma=3) 390 | int_w = pd.Series(s_hats[name] if plot_loss else widths[name]).rolling(window).mean().dropna() 391 | int_losses = pd.Series(pinball_loss(s_opts, s_hat, args.target_cov)).rolling(window).mean().dropna() 392 | opts = [pinball_loss(s_opts[i : i + window], q, args.target_cov).mean() for i, q in enumerate(int_q)] 393 | int_regret = int_losses.values - np.asarray(opts) 394 | int_miscov = np.abs(args.target_cov - int_cov) 395 | 396 | # Do the plotting 397 | color = "C" + str(i + (i > 0) if m is not SAOCP else 1) 398 | 
label2err[label].append(f"{np.max(int_miscov):.2f}") 399 | ax1.plot(range(len(int_cov)), int_cov, zorder=i, label=label, color=color) 400 | ax2.plot(range(len(int_w)), gaussian_filter1d(int_w, sigma=2), zorder=i, label=label, color=color) 401 | if min(int_cov) < args.target_cov - 0.2: 402 | ax1.set_ylim(args.target_cov - 0.2, 1.02) 403 | 404 | print( 405 | f"{name:15}: " 406 | f"Cov: {np.mean(coverages[name]):.3f}, " 407 | f"Avg Width: {np.mean(widths[name]):.1f}, " 408 | f"SA Miscov: {np.max(int_miscov):.3f}, " 409 | f"Avg Miscov: {np.mean(int_miscov):.3f}, " 410 | f"SA Regret: {np.max(int_regret):.4f}, " 411 | f"Avg Regret: {np.mean(int_regret):.4f}" 412 | ) 413 | 414 | fig.tight_layout() 415 | labels = [] 416 | lines = axs[0, 0].get_lines() 417 | for line in lines: 418 | label = line.get_label() 419 | if label in label2err: 420 | label = f"{label}: $\\mathrm{{LCE}}_k = ({','.join(label2err[label])})$" 421 | labels.append(label) 422 | ncols = math.ceil(len(lines) / 2) if FACI_S in methods else len(lines) 423 | fig.subplots_adjust(top=0.92 if ncols == len(lines) else 0.88) 424 | fig.legend(lines, labels, loc="upper center", ncols=ncols, fontsize=10, columnspacing=1.5) 425 | figdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "figures") 426 | os.makedirs(figdir, exist_ok=True) 427 | fig.savefig(os.path.join(figdir, f"{args.dataset}.pdf")) 428 | 429 | 430 | if __name__ == "__main__": 431 | main() 432 | --------------------------------------------------------------------------------
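The streaming predict/update pattern used in the evaluation loop of vision.py can be reduced to a short sketch. The snippet below is illustrative only and is not a file in this repository; it assumes the SAOCP constructor and predict/update calls behave exactly as used above, and it substitutes synthetic conformity scores for the real RAPS scores computed from model logits.

import numpy as np
import pandas as pd

from online_conformal.saocp import SAOCP

# Synthetic stream of nonnegative conformity scores (stand-ins for the RAPS scores in vision.py).
rng = np.random.RandomState(0)
scores = rng.gamma(shape=2.0, scale=0.3, size=2000)

# Same constructor call as vision.py: no train/calibration data, a bound on the score scale,
# the SAOCP lifetime parameter, and the target coverage level.
predictor = SAOCP(None, None, max_scale=float(scores.max()), lifetime=32, coverage=0.9)

covered = 0
for s in scores:
    _, s_hat = predictor.predict(horizon=1)  # current score threshold (upper radius around a zero forecast)
    covered += s <= s_hat                    # the prediction set covers iff the realized score is below the threshold
    predictor.update(ground_truth=pd.Series([s]), forecast=pd.Series([0]), horizon=1)

print(f"Empirical coverage: {covered / len(scores):.3f}")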