├── examples
│   ├── vcae
│   │   ├── vcae
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── metrics.py
│   │   │   ├── experiment.py
│   │   │   └── model.py
│   │   ├── run_seeds.sh
│   │   ├── run_seeds.py
│   │   ├── vcae_analyze.ipynb
│   │   └── vcae_debug.ipynb
│   ├── benchmark
│   │   ├── readme.md
│   │   └── __init__.py
│   └── pred_intvl
│       ├── readme.md
│       ├── vine_wrapper.py
│       ├── experiments.py
│       ├── data_utils.py
│       └── models.py
├── docs
│   ├── Makefile
│   ├── conf.py
│   └── index.rst
├── torchvinecopulib
│   ├── __init__.py
│   ├── util
│   │   └── __init__.py
│   └── bicop
│       └── __init__.py
├── tests
│   ├── test_package_metadata.py
│   ├── __init__.py
│   ├── test_util.py
│   ├── test_vinecop.py
│   └── test_bicop.py
├── .github
│   └── workflows
│       ├── static.yml
│       ├── python-package.yml
│       └── publish-package.yml
├── pyproject.toml
├── .gitignore
├── README.md
└── LICENSE
/examples/vcae/vcae/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/benchmark/readme.md:
--------------------------------------------------------------------------------
1 | ## Specify project root
2 |
3 | Put a `.env` file in the project root folder that sets the project root location as an environment variable:
4 |
5 | ```bash
6 | DIR_WORK='~/path/to/project/root/folder'
7 | ```
8 |
9 | ## Environment setup
10 |
11 | In the project root folder, run in bash:
12 |
13 | ```bash
14 | # cpu only
15 | uv sync --extra cpu -U
16 | # or with cuda
17 | uv sync --extra cu126 -U
18 | ```
19 |
20 | ## Run
21 |
22 | `cd` into the project root folder and trigger the run with:
23 |
24 | ```bash
25 | uv run ./examples/benchmark/__init__.py
26 | ```
27 |
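28 | The entry script resolves `DIR_WORK` itself via `python-dotenv`, roughly as in the sketch below (mirroring the top of `examples/benchmark/__init__.py`); benchmark outputs are written under `examples/benchmark/out/`:
29 |
30 | ```python
31 | import os
32 | from pathlib import Path
33 |
34 | from dotenv import load_dotenv
35 |
36 | load_dotenv()  # picks up DIR_WORK from the .env file
37 | DIR_WORK = Path(os.getenv("DIR_WORK"))  # note: "~" is not expanded automatically
38 | DIR_OUT = DIR_WORK / "examples" / "benchmark" / "out"
39 | ```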
--------------------------------------------------------------------------------
/examples/pred_intvl/readme.md:
--------------------------------------------------------------------------------
1 | ## Specify project root
2 |
3 | Put a `.env` file in the project root folder that sets the project root location as an environment variable:
4 |
5 | ```bash
6 | DIR_WORK='~/path/to/project/root/folder'
7 | ```
8 |
9 | ## Environment setup
10 |
11 | In the project root folder, run in bash:
12 |
13 | ```bash
14 | # cpu only
15 | uv sync --extra cpu -U
16 | # or with cuda
17 | uv sync --extra cu126 -U
18 | ```
19 |
20 | ## Run
21 |
22 | `cd` into the project root folder and trigger the run with:
23 |
24 | ```bash
25 | uv run ./examples/pred_intvl/experiments.py
26 | ```
27 |
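28 | Each run writes one pickle per dataset/seed under `examples/pred_intvl/out/`. A minimal sketch for loading a result (the file name below is only an example; key names follow `experiments.py`):
29 |
30 | ```python
31 | import pickle
32 | from pathlib import Path
33 |
34 | with open(Path("examples/pred_intvl/out/PredIntvl_california_1.pkl"), "rb") as f:
35 |     result = pickle.load(f)
36 |
37 | y_test = result["Y_test"]              # true targets for this split
38 | median, lower, upper = result["vine"]  # each method stores a (median, lower, upper) tuple
39 | ```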
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | clean:
18 | @$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O);
19 |
20 | # Catch-all target: route all unknown targets to Sphinx using the new
21 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
22 | %: Makefile
23 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
24 |
--------------------------------------------------------------------------------
/torchvinecopulib/__init__.py:
--------------------------------------------------------------------------------
1 | from . import util
2 | from .bicop import BiCop
3 | from .vinecop import VineCop
4 |
5 | __all__ = [
6 | "BiCop",
7 | "VineCop",
8 | "util",
9 | ]
10 | # dynamically grab the version you just built & installed
11 | try:
12 | from importlib.metadata import version, PackageNotFoundError
13 | except ImportError:
14 |     # Python <3.8 fallback (unreachable here, since pyproject.toml requires Python >= 3.11);
15 |     # pkg_resources' get_distribution returns a Distribution, so expose its version string
16 |     from pkg_resources import DistributionNotFound as PackageNotFoundError
17 |     from pkg_resources import get_distribution
18 |     version = lambda name: get_distribution(name).version
19 |
20 | try:
21 | __version__ = version(__name__)
22 | except PackageNotFoundError:
23 | # this can happen if you run from a source checkout
24 | __version__ = "0+unknown"
25 |
26 | __title__ = "torchvinecopulib" # the canonical project name
27 | __author__ = "Tuoyuan Cheng"
28 | __url__ = "https://github.com/TY-Cheng/torchvinecopulib" # the project homepage
29 | __description__ = "Fitting and sampling vine copulas using PyTorch."
30 |
--------------------------------------------------------------------------------
/tests/test_package_metadata.py:
--------------------------------------------------------------------------------
1 | import importlib.metadata
2 | import torchvinecopulib as tvc
3 |
4 |
5 | def test___all___imports_everything():
6 | # make sure every name in __all__ actually exists on the package
7 | for name in tvc.__all__:
8 | assert hasattr(tvc, name), f"{name!r} is missing from torchvinecopulib"
9 |
10 |
11 | def test_version_matches_distribution():
12 | # this will throw if the package isn’t actually installed under that name
13 | dist_version = importlib.metadata.version("torchvinecopulib")
14 | assert tvc.__version__ == dist_version
15 |
16 |
17 | def test_metadata_fields():
18 | # simple sanity‐checks of your metadata dunders
19 | assert isinstance(tvc.__version__, str) and len(tvc.__version__) > 0
20 | assert tvc.__title__ == "torchvinecopulib"
21 | assert isinstance(tvc.__author__, str) and len(tvc.__author__) > 10
22 | assert tvc.__url__.startswith("https://github.com/TY-Cheng/torchvinecopulib")
23 | assert isinstance(tvc.__description__, str) and len(tvc.__description__) > 10
24 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | from sphinx_pyproject import SphinxConfig
5 |
6 | sys.path.append(".")
7 | sys.path.insert(0, str(Path(__file__).parents[1]))
8 | # * load the pyproject.toml file using SphinxConfig
9 | # * using Path for better cross-platform compatibility
10 | try:
11 | config = SphinxConfig()
12 | except FileNotFoundError as err:
13 | raise FileNotFoundError("pyproject.toml not found") from err
14 |
15 | # * project metadata
16 | project = config.name
17 | author = config.author
18 | maintainer = config.get("maintainer", author)
19 | copyright = config.get("copyright", f"2024-, {author}")
20 | version = release = config.version
21 | documentation_summary = config.description
22 | extensions = config.get("extensions", [])
23 | html_theme = config.get("html_theme", "furo")
24 | html_title = f"{project} v{version}"
25 | html_theme_options = {
26 | "sidebar_hide_name": False,
27 | # "light_logo": "../torchvinecopulib.png",
28 | # "dark_logo": "../torchvinecopulib.png",
29 | # "sticky_navigation": True,
30 | # "navigation_with_keys": True,
31 | # "footer_text": f"© {copyright}",
32 | # "navigation_depth": 4,
33 | # "titles_only": False,
34 | }
35 | autosummary_generate = True
36 |
--------------------------------------------------------------------------------
/examples/vcae/vcae/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 |
4 | import torch
5 |
6 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
7 | torch.set_float32_matmul_precision("medium")
8 |
9 |
10 | @dataclass
11 | class Config:
12 | # Reproducibility
13 | seed: int = 42
14 |
15 | # Training-related
16 | data_dir: str = os.environ.get("PATH_DATASETS", ".")
17 | save_dir: str = "logs/"
18 | batch_size: int = 512 if torch.cuda.is_available() else 64
19 | max_epochs: int = 10
20 | accelerator: str = DEVICE
21 | devices: int = 1
22 | num_workers: int = 1 # or min(15, os.cpu_count())
23 |
24 | # Data-related
25 | dims: tuple[int, ...] = (1, 28, 28)
26 | val_train_split: float = 0.1
27 |
28 | # Model-related
29 | hidden_size: int = 64
30 | latent_size: int = 10
31 | learning_rate: float = 2e-4
32 | vine_lambda: float = 0.0
33 | # use_mmd: bool = False
34 | # mmd_sigmas: list[float] = [1e-1, 1, 10]
35 | # mmd_lambda: float = 10.0
36 |
37 | config_mnist = Config(
38 | max_epochs=10,
39 | dims=(1, 28, 28),
40 | hidden_size=64,
41 | latent_size=10,
42 | )
43 |
44 | config_svhn = Config(
45 | max_epochs=50,
46 | dims=(3, 32, 32),
47 | hidden_size=128,
48 | latent_size=32,
49 | )
50 |
--------------------------------------------------------------------------------
/examples/vcae/run_seeds.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | PYTHON_BIN=$(which python)
4 | USE_NOHUP=false
5 |
6 | # Defaults
7 | START=0
8 | END=30
9 | STEP=10
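10 | # Usage sketch: ./run_seeds.sh [START] [END] [STEP] [--nohup]
11 | #   e.g. "./run_seeds.sh 0 30 10" launches three background jobs covering seeds
12 | #   0-10, 10-20 and 20-30, each writing to logs/seeds_<i>_<j>.log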
10 |
11 | # Parse arguments
12 | POSITIONAL=()
13 | while [[ $# -gt 0 ]]; do
14 | case "$1" in
15 | --nohup)
16 | USE_NOHUP=true
17 | shift
18 | ;;
19 | *)
20 | POSITIONAL+=("$1")
21 | shift
22 | ;;
23 | esac
24 | done
25 |
26 | # Restore positional args
27 | set -- "${POSITIONAL[@]}"
28 |
29 | # Assign range values if provided
30 | if [[ $# -ge 1 ]]; then START=$1; fi
31 | if [[ $# -ge 2 ]]; then END=$2; fi
32 | if [[ $# -ge 3 ]]; then STEP=$3; fi
33 |
34 | # Validate input
35 | if (( STEP <= 0 )); then
36 | echo "Error: STEP must be a positive integer." >&2
37 | exit 1
38 | fi
39 |
40 | if (( (END - START) % STEP != 0 )); then
41 | echo "Error: (END - START) must be divisible by STEP." >&2
42 | exit 1
43 | fi
44 |
45 | echo "Using Python binary: $PYTHON_BIN"
46 | echo "Using nohup: $USE_NOHUP"
47 | echo "Range: $START to $END with step $STEP"
48 |
49 | # Launch loop (make sure the log directory exists before redirecting into it)
50 | mkdir -p logs
51 | for ((i = START; i < END; i += STEP)); do
51 | j=$((i + STEP))
52 | name="seeds_${i}_${j}"
53 | echo "Launching $name"
54 | if $USE_NOHUP; then
55 | nohup "$PYTHON_BIN" run_seeds.py $i $j > logs/$name.log 2>&1 &
56 | else
57 | "$PYTHON_BIN" run_seeds.py $i $j > logs/$name.log 2>&1 &
58 | fi
59 | done
60 |
61 | wait
62 |
--------------------------------------------------------------------------------
/examples/vcae/run_seeds.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import logging
3 | import os
4 | import sys
5 | from typing import Union
6 |
7 | import pandas as pd
8 | from tqdm import tqdm
9 | from vcae.config import config_mnist, config_svhn
10 | from vcae.experiment import run_experiment
11 |
12 | dataset = "MNIST" # or "SVHN"
13 | start = int(sys.argv[1])
14 | end = int(sys.argv[2])
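15 | # e.g. "python run_seeds.py 0 10" runs seeds 0..9 and appends to results_MNIST_0_10.csv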
15 |
16 | if dataset == "MNIST":
17 | config = config_mnist
18 | elif dataset == "SVHN":
19 | config = config_svhn
20 | else:
21 | raise ValueError(f"Unsupported dataset: {dataset}")
22 |
23 |
24 | # Redirect tqdm and errors to log file
25 | log_path = f"progress_{dataset}_{start}_{end}.log"
26 | log_file = open(log_path, "w")
27 |
28 | logging.basicConfig(
29 | level=logging.INFO,
30 | format="%(asctime)s [%(levelname)s] %(message)s",
31 | handlers=[logging.StreamHandler(log_file)],
32 | )
33 |
34 |
35 | @contextlib.contextmanager
36 | def suppress_output():
37 | with contextlib.redirect_stdout(log_file), contextlib.redirect_stderr(log_file):
38 | logging_level = logging.getLogger().level
39 | logging.getLogger().setLevel(logging.ERROR)
40 | try:
41 | yield
42 | finally:
43 | logging.getLogger().setLevel(logging_level)
44 |
45 |
46 | results: list[dict[str, Union[float, int, str]]] = []
47 | output_path = f"results_{dataset}_{start}_{end}.csv"
48 | for seed in tqdm(range(start, end), desc=f"Seeds {start}-{end}", file=log_file):
49 | try:
50 | with suppress_output():
51 | result = run_experiment(seed, config, dataset=dataset)
52 | df = pd.DataFrame([result])
53 |
54 | # Write headers only once
55 | if not os.path.exists(output_path):
56 | df.to_csv(output_path, index=False, mode="w")
57 | else:
58 | df.to_csv(output_path, index=False, mode="a", header=False)
59 |
60 | except Exception:
61 | logging.exception(f"Exception while running seed {seed}")
62 |
63 | logging.info(f"All done. Results saved to {output_path}")
64 | log_file.close()
65 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Shared pytest fixtures for the test suite (imported explicitly by the test modules)
2 | import numpy as np
3 | import pytest
4 | import pyvinecopulib as pvc
5 | import torch
6 |
7 | import torchvinecopulib as tvc
8 |
9 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 | N_SIM = 2000
11 | SEEDS = list(range(5))
12 | EPS = tvc.util._EPS
13 | # List the (family, true-parameter) pairs you want to test
14 | FAMILIES = [
15 | (pvc.gaussian, np.array([[0.7]]), 0), # rotation 0
16 | (pvc.clayton, np.array([[0.9]]), 0),
17 | (pvc.clayton, np.array([[0.9]]), 90),
18 | (pvc.clayton, np.array([[0.9]]), 180),
19 | (pvc.clayton, np.array([[0.9]]), 270),
20 | (pvc.frank, np.array([[3.0]]), 0),
21 | (pvc.frank, np.array([[-3.0]]), 0),
22 | # …add more if you like…
23 | ]
24 |
25 |
26 | @pytest.fixture(scope="module", params=FAMILIES, ids=lambda f: f[0].name)
27 | def bicop_pair(request):
28 | """
29 |     Returns a tuple:
30 |         ( family, true_params, rotation, U_tensor, bicop_fastkde, bicop_tll )
31 |
32 |     Note the scope="module": the fixture is built once per parameter set and reused by all tests that request it.
33 | """
34 | family, true_params, rotation = request.param
35 |
36 | # 1) build the 'true' copula and simulate U
37 | true_bc = pvc.Bicop(family=family, parameters=true_params, rotation=rotation)
38 | U = true_bc.simulate(n=N_SIM, seeds=SEEDS) # shape (N_SIM, 2)
39 | U_tensor = torch.tensor(U, device=DEVICE, dtype=torch.float64)
40 |
41 | # 2) fit two torchvinecopulib instances (fast KDE and TLL)
42 | bc_fast = tvc.BiCop(num_step_grid=512).to(DEVICE)
43 | bc_fast.fit(U_tensor, mtd_kde="fastKDE")
44 |
45 | bc_tll = tvc.BiCop(num_step_grid=512).to(DEVICE)
46 | bc_tll.fit(U_tensor, mtd_kde="tll")
47 |
48 | return family, true_params, rotation, U_tensor, bc_fast, bc_tll
49 |
50 |
51 | @pytest.fixture(scope="module")
52 | def U_tensor():
53 | # a moderately‐sized random [0,1]² sample
54 | return torch.rand(500, 2, dtype=torch.float64)
55 |
56 |
57 | @pytest.fixture(scope="module")
58 | def sample_1d():
59 | torch.manual_seed(0)
60 | return torch.randn(1024, 1) # standard normal
61 |
--------------------------------------------------------------------------------
/.github/workflows/static.yml:
--------------------------------------------------------------------------------
1 | # This workflow builds and deploys the documentation to GitHub Pages
2 | name: Deploy Docs
3 |
4 | on:
5 | # Trigger after your Lint Pytest workflow completes on main branch
6 | workflow_run:
7 | workflows: ["Lint Pytest"]
8 | types: [completed]
9 | branches: ["main"]
10 |
11 | # Also allow manual dispatch
12 | workflow_dispatch:
13 |
14 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
15 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete.
16 | concurrency:
17 | group: "pages"
18 | cancel-in-progress: false
19 |
20 | jobs:
21 | deploy:
22 | if: ${{github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success' }}
23 | runs-on: ubuntu-latest
24 | permissions:
25 | contents: write # for pushing to gh-pages branch
26 |
27 | steps:
28 | - name: Checkout code
29 | uses: actions/checkout@v4
30 |
31 | - name: Set up Python 3.11
32 | uses: actions/setup-python@v5
33 | with:
34 | python-version: "3.11"
35 |
36 | - name: Cache pip
37 | uses: actions/cache@v4
38 | with:
39 | path: ~/.cache/pip
40 | key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }}
41 |
42 | - name: Install uv and sync deps
43 | run: |
44 | python3 -m pip install --upgrade pip
45 | python3 -m pip install uv
46 | python3 -m uv sync --extra cpu
47 |
48 | - name: Build Sphinx docs
49 | # regenerate module RST files (force) and build HTML docs
50 | run: |
51 | uv run sphinx-apidoc --force -o docs torchvinecopulib/ --separate
52 | uv run sphinx-build -b html docs/ docs/_build/html
53 |
54 | # Deploy to gh-pages branch
55 |
56 | - name: Deploy to gh-pages branch
57 | uses: peaceiris/actions-gh-pages@v4
58 | with:
59 | force_orphan: true
60 | github_token: ${{ secrets.GITHUB_TOKEN }}
61 | publish_branch: gh-pages
62 | publish_dir: docs/_build/html
63 |
--------------------------------------------------------------------------------
/examples/pred_intvl/vine_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch._dynamo as dynamo
3 |
4 | import torchvinecopulib as tvc
5 |
6 |
7 | def train_vine(Z_train, Y_train, seed=42, device="cpu"):
8 | # * build vine cop on Z and Y together
9 | ZY_train = torch.cat([Z_train, Y_train.view(-1, 1)], dim=1)
10 | # * all except last Y column, for MST in first stage
11 | first_tree_vertex = tuple(range(Z_train.shape[1]))
12 | model_vine: tvc.VineCop = tvc.VineCop(num_dim=ZY_train.shape[1], is_cop_scale=False).to(
13 | device=device
14 | )
15 | model_vine.fit(
16 | obs=ZY_train,
17 | first_tree_vertex=first_tree_vertex,
18 | mtd_bidep="ferreira_tail_dep_coeff",
19 | mtd_kde="tll",
20 | mtd_tll="quadratic",
21 | seed=seed,
22 | )
23 | return model_vine
24 |
25 |
26 | @dynamo.disable
27 | @torch.no_grad()
28 | def vine_pred_intvl(model: tvc.VineCop, Z_test, alpha=0.05, seed=42, device="cpu"):
29 |     # * assumes the vine was fitted on [Z, Y] with Y as the last column
30 | num_sample = Z_test.shape[0]
31 | idx_quantile = Z_test.shape[1]
32 | sample_order = (idx_quantile,)
33 | # * fill marginal obs (will be handled by marginal if not is_cop_scale)
34 | dct_v_s_obs = {(_,): Z_test[:, [_]] for _ in range(Z_test.shape[1])}
35 | # * fill quantile deep in the vine (assuming cop scale)
36 | dct_v_s_obs[model.sample_order] = torch.full((num_sample, 1), alpha / 2, device=device)
37 | lower = model.sample(
38 | num_sample=num_sample,
39 | sample_order=sample_order,
40 | dct_v_s_obs=dct_v_s_obs,
41 | seed=seed,
42 | )[:, idx_quantile]
43 | dct_v_s_obs[model.sample_order] = torch.full((num_sample, 1), 0.5, device=device)
44 | median = model.sample(
45 | num_sample=num_sample,
46 | sample_order=sample_order,
47 | dct_v_s_obs=dct_v_s_obs,
48 | seed=seed,
49 | )[:, idx_quantile]
50 | dct_v_s_obs[model.sample_order] = torch.full((num_sample, 1), 1 - alpha / 2, device=device)
51 | upper = model.sample(
52 | num_sample=num_sample,
53 | sample_order=sample_order,
54 | dct_v_s_obs=dct_v_s_obs,
55 | seed=seed,
56 | )[:, idx_quantile]
57 | return median.cpu().numpy(), lower.cpu().numpy(), upper.cpu().numpy()
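58 |
59 | # Usage sketch (mirroring examples/pred_intvl/experiments.py), where Z_train/Z_test are
60 | # latent features from a fitted base model and Y_train is the training target:
61 | #   model_vine = train_vine(Z_train, Y_train, device=device)
62 | #   median, lower, upper = vine_pred_intvl(model_vine, Z_test, alpha=0.05, seed=42, device=device)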
58 |
--------------------------------------------------------------------------------
/examples/vcae/vcae/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy import linalg
4 |
5 |
6 | def mmd(real: torch.Tensor, fake: torch.Tensor, sigmas=[1e-3, 1e-2, 1e-1, 1, 10, 100]):
7 | """
8 | Differentiable MMD loss using Gaussian kernels with fixed sigmas and
9 | distance normalization via Mxx.mean().
10 |
11 | Parameters
12 | ----------
13 | real : (n, d) tensor
14 | Batch of real samples (features or images).
15 | fake : (m, d) tensor
16 | Batch of generated samples.
17 | sigmas : list of float
18 | Bandwidths for the RBF kernel. Defaults to wide, fixed list.
19 |
20 | Returns
21 | -------
22 | mmd : scalar tensor
23 | Differentiable scalar loss value.
24 | """
25 | real = real.view(real.size(0), -1)
26 | fake = fake.view(fake.size(0), -1)
27 |
28 | def pairwise_squared_distances(x, y):
29 | x_norm = (x**2).sum(dim=1, keepdim=True)
30 | y_norm = (y**2).sum(dim=1, keepdim=True)
31 | return x_norm + y_norm.T - 2.0 * x @ y.T
32 |
33 | Mxx = pairwise_squared_distances(real, real)
34 | Mxy = pairwise_squared_distances(real, fake)
35 | Myy = pairwise_squared_distances(fake, fake)
36 |
37 | # Normalization factor based on real-real distances
38 | scale = Mxx.mean().detach()
39 |
40 | mmd_total = 0.0
41 | for sigma in sigmas:
42 | denom = scale * 2.0 * sigma**2
43 | Kxx = torch.exp(-Mxx / denom)
44 | Kxy = torch.exp(-Mxy / denom)
45 | Kyy = torch.exp(-Myy / denom)
46 |
47 | mmd_total += Kxx.mean() + Kyy.mean() - 2.0 * Kxy.mean()
48 |
49 | return mmd_total / len(sigmas)
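50 |
51 | # Minimal usage sketch for `mmd` (shapes here are illustrative): image batches are
52 | # flattened to (n, d) internally, so both forms below work.
53 | #   loss = mmd(torch.randn(256, 784), torch.randn(256, 784))
54 | #   loss = mmd(real_imgs, fake_imgs, sigmas=[1e-1, 1, 10])  # real_imgs: (n, 1, 28, 28)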
50 |
51 |
52 | def fid(X, Y):  # Fréchet distance between Gaussian fits of X and Y (FID-style score)
53 | m = X.mean(0)
54 | m_w = Y.mean(0)
55 | X_np = X.numpy()
56 | Y_np = Y.numpy()
57 |
58 | C = np.cov(X_np.transpose())
59 | C_w = np.cov(Y_np.transpose())
60 | C_C_w_sqrt = linalg.sqrtm(C.dot(C_w), True).real
61 |
62 | score = m.dot(m) + m_w.dot(m_w) - 2 * m_w.dot(m) + np.trace(C + C_w - 2 * C_C_w_sqrt)
63 | return np.sqrt(score)
64 |
65 |
66 | class Score:
67 | mmd = 0
68 | fid = 0
69 |
70 |
71 | def compute_score(real, fake, sigmas=[1e-3, 1e-2, 1e-1, 1, 10, 100]):
72 | real = real.to("cpu")
73 | fake = fake.to("cpu")
74 |
75 | s = Score()
76 | s.mmd = np.sqrt(mmd(real, fake, sigmas).numpy())
77 | s.fid = fid(fake, real).numpy()
78 |
79 | return s
80 |
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Lint Pytest
5 |
6 | on:
7 | push:
8 | branches: ["main"]
9 | pull_request:
10 | branches: ["main"]
11 |
12 | workflow_dispatch:
13 |
14 | jobs:
15 | lint-pytest:
16 | runs-on: ${{ matrix.os }}
17 | strategy:
18 | fail-fast: false
19 | matrix:
20 | os: [ubuntu-latest, windows-latest, macos-latest]
21 | python-version: ["3.11", "3.12", "3.13"]
22 | steps:
23 | - name: Checkout code
24 | uses: actions/checkout@v4
25 |
26 | - name: Set up Python ${{ matrix.python-version }}
27 | uses: actions/setup-python@v5
28 | with:
29 | python-version: ${{ matrix.python-version }}
30 |
31 | - name: Cache pip
32 | uses: actions/cache@v4
33 | with:
34 | path: ~/.cache/pip
35 | key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }}
36 |
37 | - name: Install uv and sync deps
38 | run: |
39 | python3 -m pip install --upgrade pip
40 | python3 -m pip install uv
41 | python3 -m uv sync --extra cpu
42 | # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
43 |
44 | - name: Lint with flake8
45 | run: |
46 | # stop the build if there are Python syntax errors or undefined names
47 | uv run flake8 ./torchvinecopulib --exclude .venv --count --select=E9,F63,F7,F82 --show-source --statistics
48 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
49 | uv run flake8 ./torchvinecopulib --count --exit-zero --max-complexity=10 --max-line-length=127 --ignore=E203,W503,E501,C901 --statistics
50 |
51 | - name: Test with pytest
52 | run: |
53 | uv run python -c "import sys; print(sys.version)"
54 | uv run python -c "import torch; print(torch.cuda.is_available())"
55 | uv run coverage run --source=torchvinecopulib -m pytest tests
56 | uv run coverage report
57 | uv run coverage xml -o coverage.xml
58 |
59 | - name: Upload coverage to Codacy
60 | if: ${{ matrix.os == 'ubuntu-latest' && matrix.python-version == '3.13' }}
61 | run: |
62 | export CODACY_API_BASE_URL=${{ secrets.CODACY_API_BASE_URL }}
63 | export CODACY_PROJECT_TOKEN=${{ secrets.CODACY_PROJECT_TOKEN }}
64 | bash <(curl -Ls https://coverage.codacy.com/get.sh) report -r ./coverage.xml
65 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. Put a project logo here if you have one
2 | .. .. image:: _static/logo.png
3 | .. :alt: torchvinecopulib Logo
4 |
5 | ============================================
6 | Welcome to torchvinecopulib's documentation!
7 | ============================================
8 |
9 | .. Add badges here
25 |
26 | ``torchvinecopulib`` is a ``Python`` library for fitting and sampling vine copulas using ``PyTorch``.
27 | It is designed for researchers and practitioners in statistics, machine learning, and finance who need flexible, GPU-accelerated, and vectorized copula modeling and sampling.
28 |
29 | **GitHub Repository:** https://github.com/TY-Cheng/torchvinecopulib
30 |
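31 | A minimal usage sketch, following the calls in the bundled ``examples/`` scripts (keyword
32 | names such as ``mtd_kde`` mirror those scripts; arguments not shown are left at their
33 | defaults, and the module docs below give the authoritative signatures):
34 |
35 | .. code-block:: python
36 |
37 |    import torch
38 |    import torchvinecopulib as tvc
39 |
40 |    X = torch.randn(1000, 3, dtype=torch.float64)   # toy data on the original scale
41 |    vine = tvc.VineCop(num_dim=3, is_cop_scale=False)
42 |    vine.fit(obs=X, mtd_kde="tll", seed=42)
43 |    loglik = vine.log_pdf(X).mean()                 # average log-density on the fit data
44 |    samples = vine.sample(num_sample=500, seed=42)  # simulate new observations
45 |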
31 | .. toctree::
32 | :maxdepth: 3
33 | :caption: Core Modules
34 | :titlesonly:
35 |
36 | modules.rst
37 |
38 |
39 | Indices and tables
40 | ==================
41 |
42 | * :ref:`genindex`
43 | * :ref:`modindex`
44 | * :ref:`search`
--------------------------------------------------------------------------------
/examples/vcae/vcae/experiment.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Union
3 |
4 | import numpy as np
5 | import pytorch_lightning as pl
6 | import torch
7 |
8 | from .config import DEVICE, Config
9 | from .metrics import compute_score
10 | from .model import LitAutoencoder, LitMNISTAutoencoder, LitSVHNAutoencoder
11 |
12 |
13 | def set_seed(seed: int):
14 | random.seed(seed)
15 | np.random.seed(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 | torch.backends.cudnn.benchmark = True
19 | torch.backends.cudnn.enabled = True
20 | torch.backends.cudnn.deterministic = True
21 |
22 |
23 | def run_experiment(
24 | seed: int, config: Config, vine_lambda: float = 1.0, dataset: str = "MNIST"
25 | ) -> dict[str, Union[float, int, str]]:
26 | # Set the seed for reproducibility
27 | set_seed(seed)
28 | config.seed = seed
29 |
30 | # Instantiate the model
31 | model_initial: LitAutoencoder
32 | if dataset == "MNIST":
33 | model_initial = LitMNISTAutoencoder(config)
34 | elif dataset == "SVHN":
35 | model_initial = LitSVHNAutoencoder(config)
36 | else:
37 | raise ValueError(f"Unsupported dataset: {dataset}")
38 |
39 | # Set up trainer
40 | trainer_initial = pl.Trainer(
41 | accelerator=config.accelerator,
42 | devices=config.devices,
43 | max_epochs=config.max_epochs,
44 | logger=False, # disables all loggers
45 | enable_progress_bar=False, # disables tqdm
46 | enable_model_summary=False, # disables model summary printout
47 | )
48 |
49 | # Train the base autoencoder
50 | trainer_initial.fit(model_initial)
51 |
52 | # Stay on DEVICE
53 | model_initial.to(DEVICE)
54 |
55 | # Learn vine
56 | model_initial.learn_vine(n_samples=5000)
57 |
58 | # Extract test data
59 | rep_initial, _, data_initial, decoded_initial, samples_initial = model_initial.get_data(
60 | stage="test"
61 | )
62 |
63 |     # Reset the seed so the refit uses the same data split and initialization as the first fit
64 | set_seed(seed)
65 |
66 | # Create a new model with the same configuration but reset vine lambda
67 | config.vine_lambda = vine_lambda
68 | model_refit = model_initial.copy_with_config(config)
69 |
70 | # Set up trainer for refitting
71 | trainer_refit = pl.Trainer(
72 | accelerator=config.accelerator,
73 | devices=config.devices,
74 | max_epochs=config.max_epochs,
75 | logger=False, # disables all loggers
76 | enable_progress_bar=False, # disables tqdm
77 | enable_model_summary=False, # disables model summary printout
78 | )
79 |
80 | # Refit the model
81 | trainer_refit.fit(model_refit)
82 |
83 | # Stay on DEVICE
84 | model_refit.to(DEVICE)
85 |
86 | # Extract test data
87 | rep_refit, _, data_refit, decoded_refit, samples_refit = model_refit.get_data(stage="test")
88 |
89 | assert model_initial.vine is not None
90 | assert model_refit.vine is not None
91 | loglik_initial = model_initial.vine.log_pdf(rep_initial).mean().item()
92 | loglik_refit = model_refit.vine.log_pdf(rep_refit).mean().item()
93 |
94 | mse_initial = torch.nn.functional.mse_loss(decoded_initial, data_initial).item()
95 | mse_refit = torch.nn.functional.mse_loss(decoded_refit, data_refit).item()
96 |
97 | sigmas = [1e-3, 1e-2, 1e-1, 1, 10, 100]
98 | score_initial = compute_score(data_initial, samples_initial, sigmas=sigmas)
99 | score_refit = compute_score(data_refit, samples_refit, sigmas=sigmas)
100 |
101 | return {
102 | "seed": seed,
103 | "dataset": dataset,
104 | "mse_initial": mse_initial,
105 | "mse_refit": mse_refit,
106 | "loglik_initial": loglik_initial,
107 | "loglik_refit": loglik_refit,
108 | "mmd_initial": score_initial.mmd,
109 | "mmd_refit": score_refit.mmd,
110 | "fid_initial": score_initial.fid,
111 | "fid_refit": score_refit.fid,
112 | }
113 |
--------------------------------------------------------------------------------
/.github/workflows/publish-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow builds a Python package and publishes it to PyPI and TestPyPI.
2 | name: Publish to PyPI and TestPyPI
3 |
4 | on:
5 | push:
6 | branches: ["main"]
7 | tags:
8 | - "v*.*.*" # Match semantic version tags
9 | workflow_dispatch:
10 | # Allow manual dispatch for testing or immediate releases
11 |
12 | jobs:
13 | build:
14 | name: Build distributions
15 | runs-on: ubuntu-latest
16 | steps:
17 | - name: Checkout code
18 | uses: actions/checkout@v4
19 |
20 | - name: Set up Python 3.11
21 | uses: actions/setup-python@v5
22 | with:
23 | python-version: 3.11
24 |
25 | - name: Cache pip
26 | uses: actions/cache@v4
27 | with:
28 | path: ~/.cache/pip
29 | key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }}
30 |
31 | - name: Install uv and sync deps, build sdist and wheel
32 | run: |
33 | python3 -m pip install --upgrade pip
34 | python3 -m pip install uv
35 | python3 -m uv sync --extra cpu
36 | python3 -m uv build
37 |
38 | - name: Upload dist/*
39 | uses: actions/upload-artifact@v4
40 | with:
41 | name: python-distribution-packages
42 | path: dist/*
43 |
44 | publish-testpypi:
45 | name: Publish to TestPyPI
46 | needs: build
47 | runs-on: ubuntu-latest
48 | environment:
49 | name: testpypi
50 | url: https://test.pypi.org/p/torchvinecopulib/
51 | permissions:
52 | id-token: write
53 | # contents: write
54 | steps:
55 | - name: Download dist/*
56 | uses: actions/download-artifact@v4
57 | with:
58 | name: python-distribution-packages
59 | path: dist/
60 |
61 | - name: Publish to TestPyPI
62 | uses: pypa/gh-action-pypi-publish@release/v1
63 | with:
64 | verbose: true
65 | skip-existing: true
66 | repository-url: https://test.pypi.org/legacy/
67 | env:
68 | TWINE_USERNAME: __token__
69 | TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}
70 |
71 | publish-pypi:
72 | name: Publish to PyPI
73 | runs-on: ubuntu-latest
74 | needs: publish-testpypi
75 | if: startsWith(github.ref, 'refs/tags/v')
76 | environment:
77 | name: pypi
78 | url: https://pypi.org/project/torchvinecopulib/
79 | permissions:
80 | id-token: write
81 | # contents: write
82 | steps:
83 | - name: Download dist/*
84 | uses: actions/download-artifact@v4
85 | with:
86 | name: python-distribution-packages
87 | path: dist/
88 |
89 | - name: Publish to PyPI via Twine
90 | uses: pypa/gh-action-pypi-publish@release/v1
91 | env:
92 | TWINE_USERNAME: __token__
93 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
94 | with:
95 | verbose: true
96 | skip-existing: true
97 |
98 | github-release:
99 | name: Create GitHub Release
100 | runs-on: ubuntu-latest
101 | needs: publish-testpypi
102 | if: startsWith(github.ref, 'refs/tags/v')
103 | permissions:
104 | contents: write
105 | id-token: write
106 |
107 | steps:
108 | - name: Download dist/*
109 | uses: actions/download-artifact@v4
110 | with:
111 | name: python-distribution-packages
112 | path: dist/
113 |
114 | - name: Sign the distribution packages with Sigstore
115 | uses: sigstore/gh-action-sigstore-python@v3.0.0
116 | with:
117 | inputs: |
118 | ./dist/*.tar.gz
119 | ./dist/*.whl
120 |
121 | - name: Create GitHub Release and Upload Assets
122 | env:
123 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
124 | run: |
125 | gh release create "${{ github.ref_name }}" \
126 | --repo "${{ github.repository }}" \
127 | --title "Release ${{ github.ref_name }}" \
128 | --notes "Automated release for ${{ github.ref_name }}. Signed artifacts are attached." \
129 | dist/*
130 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # List/Specify python version
2 | # uv python list
3 | # uv python pin python3.13
4 | # Install default + dev dependencies with torch for CPU/cu126:
5 | # uv sync --extra cpu -U
6 | # uv sync --extra cu126 -U
7 |
8 | [build-system]
9 | requires = ["hatchling"]
10 | build-backend = "hatchling.build"
11 |
12 | [project.urls]
13 | # homepage = ""
14 | repository = "https://github.com/TY-Cheng/torchvinecopulib"
15 | documentation = "https://ty-cheng.github.io/torchvinecopulib/"
16 |
17 | [project]
18 | name = "torchvinecopulib"
19 | version = "1.1.2"
20 | requires-python = ">=3.11"
21 | authors = [{ name = "Tuoyuan Cheng", email = "tuoyuan.cheng@nus.edu.sg" }]
22 | description = "Fitting and sampling vine copulas using PyTorch."
23 | readme = "README.md"
24 | keywords = ["vine copula", "copula", "torch", "conditional simulation"]
25 | classifiers = [
26 | # "Private :: Do Not Upload",
27 | "License :: OSI Approved :: MIT License",
28 | "Environment :: GPU :: NVIDIA CUDA :: 12",
29 | "Operating System :: Microsoft :: Windows :: Windows 10",
30 | "Operating System :: POSIX :: Linux",
31 | "Operating System :: MacOS",
32 | "Programming Language :: Python :: 3.11",
33 | "Programming Language :: Python :: 3.12",
34 | "Programming Language :: Python :: 3.13",
35 | "Topic :: Scientific/Engineering",
36 | ]
37 | dependencies = ["numpy>=2", "scipy", "fastkde==2.1.3", "pyvinecopulib"]
38 |
39 | [dependency-groups]
40 | dev = [
41 | "blitz-bayesian-pytorch",
42 | "ccxt",
43 | "coverage",
44 | "datasets",
45 | "docformatter",
46 | "flake8",
47 | "furo",
48 | "ipykernel",
49 | "kagglehub",
50 | "matplotlib",
51 | "missingno",
52 | "mypy",
53 | "pandas",
54 | "pytest-cov",
55 | "pytest",
56 | "python-dotenv",
57 | "ruff",
58 | "scikit-learn",
59 | "sphinx_pyproject",
60 | "sphinx",
61 | "tokenize-rt",
62 | "torchvision",
63 | "ucimlrepo",
64 | "yfinance",
65 | ]
66 |
67 | [project.optional-dependencies]
68 | cpu = ["torch>=2"]
69 | cu126 = ["torch>=2"]
70 | cu128 = ["torch>=2"]
71 | examples = ["pytorch-lightning", "tqdm"]
72 |
73 | [tool.uv]
74 | managed = true
75 | default-groups = ["dev"]
76 | conflicts = [[{ extra = "cpu" }, { extra = "cu126" }, { extra = "cu128" }]]
77 |
78 | [tool.uv.sources]
79 | torch = [
80 | { index = "torch-cpu", extra = "cpu" },
81 | { index = "torch-cu126", extra = "cu126" },
82 | { index = "torch-cu128", extra = "cu128" },
83 | ]
84 |
85 | [[tool.uv.index]]
86 | name = "torch-cpu"
87 | url = "https://download.pytorch.org/whl/cpu"
88 | explicit = true
89 |
90 | [[tool.uv.index]]
91 | name = "torch-cu126"
92 | url = "https://download.pytorch.org/whl/cu126"
93 | explicit = true
94 |
95 | [[tool.uv.index]]
96 | name = "torch-cu128"
97 | url = "https://download.pytorch.org/whl/cu128"
98 | explicit = true
99 |
100 | [[tool.uv.index]]
101 | name = "testpypi"
102 | url = "https://test.pypi.org/simple/"
103 | publish-url = "https://test.pypi.org/legacy/"
104 | explicit = true
105 |
106 | [tool.ruff]
107 | line-length = 99
108 |
109 | [tool.ruff.format]
110 | quote-style = "double"
111 | indent-style = "space"
112 | docstring-code-format = true
113 | docstring-code-line-length = 99
114 |
115 | [tool.ruff.lint.pycodestyle]
116 | max-doc-length = 99
117 | max-line-length = 99
118 |
119 | [tool.docformatter]
120 | wrap-summaries = 99
121 | wrap-descriptions = 99
122 | blank-lines-around-summary = true
123 | in-place = true
124 | make-summary-multi-line = true
125 | pre-summary-space = true
126 | recursive = true
127 | skip-ignore = false
128 |
129 | [tool.flake8]
130 | max-line-length = 127
131 | extend-ignore = ["E203", "W503"]
132 | exclude = [".venv", "build", "dist", "docs"]
133 | per-file-ignores = [
134 | "torchvinecopulib/vinecop/__init__.py: E501, C901",
135 | "torchvinecopulib/bicop/__init__.py: E501, C901",
136 | "torchvinecopulib/util/__init__.py: E501, C901",
137 | ]
138 |
139 | [tool.sphinx-pyproject]
140 | source-dir = "docs"
141 | build-dir = "docs/_build"
142 | all_files = true
143 | extensions = [
144 | "sphinx.ext.autodoc",
145 | "sphinx.ext.autosectionlabel",
146 | "sphinx.ext.autosummary",
147 | "sphinx.ext.coverage",
148 | "sphinx.ext.doctest",
149 | "sphinx.ext.extlinks",
150 | "sphinx.ext.githubpages",
151 | "sphinx.ext.graphviz",
152 | "sphinx.ext.ifconfig",
153 | "sphinx.ext.imgconverter",
154 | "sphinx.ext.inheritance_diagram",
155 | "sphinx.ext.intersphinx",
156 | "sphinx.ext.napoleon",
157 | "sphinx.ext.todo",
158 | "sphinx.ext.viewcode",
159 | ]
160 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
161 | html_theme = "furo"
162 |
--------------------------------------------------------------------------------
/tests/test_util.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from scipy.stats import kendalltau as scipy_tau
4 |
5 | from torchvinecopulib.util import (
6 | ENUM_FUNC_BIDEP,
7 | chatterjee_xi,
8 | ferreira_tail_dep_coeff,
9 | kdeCDFPPF1D,
10 | kendall_tau,
11 | mutual_info,
12 | solve_ITP,
13 | )
14 |
15 | from . import EPS, U_tensor, bicop_pair, sample_1d
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "x,y,expected",
20 | [
21 | ([1, 2, 3], [1, 2, 3], 1.0),
22 | ([1, 2, 3], [3, 2, 1], -1.0),
23 | ],
24 | )
25 | def test_kendall_tau_perfect(x, y, expected):
26 | x = torch.tensor(x).view(-1, 1).double()
27 | y = torch.tensor(y).view(-1, 1).double()
28 | tau, p = kendall_tau(x, y)
29 | assert pytest.approx(expected, abs=1e-6) == tau.item()
30 |
31 |
32 | def test_kendall_tau_matches_scipy_random():
33 | torch.manual_seed(0)
34 | x = torch.rand(50, 1)
35 | y = torch.rand(50, 1)
36 | tau_torch, p = kendall_tau(x, y)
37 | tau_scipy, p_scipy = scipy_tau(x.flatten().numpy(), y.flatten().numpy())
38 | assert pytest.approx(tau_scipy, rel=1e-6) == tau_torch.item()
39 |
40 |
41 | def test_mutual_info_independent():
42 | torch.manual_seed(0)
43 | mi = mutual_info(*torch.rand(2, 10000))
44 | assert abs(mi.item()) < 2e-2 # near zero
45 |
46 |
47 | def test_mutual_info_dependent():
48 | torch.manual_seed(0)
49 | u = torch.rand(500, 1)
50 | v = u.clone()
51 | mi = mutual_info(u, v)
52 | assert mi.item() > 0.5 # significantly positive
53 |
54 |
55 | def test_tail_dep_perfect():
56 | u = torch.linspace(0, 1, 500).view(-1, 1)
57 | # y=u → perfect
58 | lam = ferreira_tail_dep_coeff(u, u)
59 | assert pytest.approx(1.0, rel=1e-3) == lam.item()
60 |
61 |
62 | def test_tail_dep_independent():
63 | torch.manual_seed(0)
64 | u = torch.rand(1000, 1)
65 | v = torch.rand(1000, 1)
66 | lam = ferreira_tail_dep_coeff(u, v)
67 | assert lam.item() < 0.2 # near zero
68 |
69 |
70 | def test_xi_perfect():
71 | u = torch.arange(1, 101).view(-1, 1).double()
72 | xi = chatterjee_xi(u, u)
73 | assert pytest.approx(1.0, abs=3e-2) == xi.item()
74 |
75 |
76 | def test_xi_independent():
77 | torch.manual_seed(1)
78 | u = torch.rand(1000, 1)
79 | v = torch.rand(1000, 1)
80 | xi = chatterjee_xi(u, v)
81 | assert abs(xi.item()) < 0.1
82 |
83 |
84 | def test_enum_dispatches_correctly():
85 | u = torch.rand(50, 1)
86 | v = u.clone()
87 | # perfect correlation → tau=1, xi=1, tail=1, mi>0
88 | out_tau = ENUM_FUNC_BIDEP.kendall_tau(u, v)
89 | out_xi = ENUM_FUNC_BIDEP.chatterjee_xi(u, v)
90 | out_tail = ENUM_FUNC_BIDEP.ferreira_tail_dep_coeff(u, v)
91 | out_mi = ENUM_FUNC_BIDEP.mutual_info(u, v)
92 | assert pytest.approx(1.0, abs=3e-2) == out_tau[0].item()
93 | assert pytest.approx(1.0, abs=3e-2) == out_xi.item()
94 | assert pytest.approx(1.0, rel=3e-2) == out_tail.item()
95 | assert out_mi.item() > 0.5
96 |
97 |
98 | def test_kde_cdf_ppf_inverse(sample_1d):
99 | kde = kdeCDFPPF1D(sample_1d, num_step_grid=257)
100 |     span = kde.x_max - kde.x_min
101 |     xs = torch.linspace(
102 |         kde.x_min + span / 10, kde.x_max - span / 10, 50, dtype=torch.float64
103 |     ).view(-1, 1)
104 | qs = kde.cdf(xs)
105 | xs_rec = kde.ppf(qs)
106 | assert torch.allclose(xs, xs_rec, atol=1e-3)
107 |
108 |
109 | def test_kde_bounds_and_pdf(sample_1d):
110 | kde = kdeCDFPPF1D(sample_1d, num_step_grid=257)
111 | # cdf out-of-bounds
112 | oob = torch.tensor([[kde.x_min - 1.0], [kde.x_max + 1.0]], dtype=torch.float64)
113 | assert torch.all(kde.cdf(oob) == torch.tensor([[0.0], [1.0]], dtype=torch.float64))
114 | # ppf out-of-bounds
115 | assert torch.allclose(
116 | torch.tensor([[kde.x_min], [kde.x_max]], dtype=torch.float64),
117 | kde.ppf(torch.tensor([[-1.0], [2.0]])),
118 | )
119 | # pdf ≥ 0
120 | pts = torch.linspace(kde.x_min, kde.x_max, 100).view(-1, 1)
121 | assert (kde.pdf(pts) >= 0).all()
122 | # log_pdf finite
123 | assert torch.isfinite(kde.log_pdf(pts)).all()
124 |
125 |
126 | def test_kde_negloglik_forward(sample_1d):
127 | kde = kdeCDFPPF1D(sample_1d, num_step_grid=None)
128 | val1 = kde.negloglik
129 | val2 = kde.forward(sample_1d)
130 | assert pytest.approx(val1.item(), rel=1e-6) == val2.item()
131 |
132 |
133 | def test_kde_str(sample_1d):
134 | kde = kdeCDFPPF1D(sample_1d, num_step_grid=257)
135 | str_repr = str(kde)
136 | assert "kdeCDFPPF1D" in str_repr
137 | assert "num_step_grid" in str_repr
138 | assert "257" in str_repr
139 | assert "x_min" in str_repr
140 | assert "x_max" in str_repr
141 |
142 |
143 | def test_solve_itp_scalar():
144 | # f(x)=x-0.3 has root at 0.3
145 | f = lambda x: x - 0.3
146 | root = solve_ITP(f, torch.tensor(0.0), torch.tensor(1.0))
147 | assert pytest.approx(0.3, abs=1e-6) == root.item()
148 |
149 |
150 | def test_solve_itp_vectorized():
151 | # two independent eq’s: x-a=0, x-b=0
152 | a = torch.tensor([0.2, 0.7])
153 | b = torch.tensor([1.0, 1.0])
154 | f = lambda x: x - a
155 | roots = solve_ITP(f, torch.zeros_like(a), b)
156 | assert torch.allclose(roots, a, atol=1e-6)
157 |
--------------------------------------------------------------------------------
/examples/vcae/vcae_analyze.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "8410dbd6",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "# add autoreload and other ipython magic\n",
11 | "import glob\n",
12 | "import os\n",
13 | "\n",
14 | "import numpy as np\n",
15 | "import pandas as pd\n",
16 | "from IPython import get_ipython\n",
17 | "\n",
18 | "ipython = get_ipython()\n",
19 | "ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
20 | "ipython.run_line_magic(\"autoreload\", \"2\")\n",
21 | "ipython.run_line_magic(\"matplotlib\", \"inline\")\n",
22 | "ipython.run_line_magic(\"config\", 'InlineBackend.figure_format = \"retina\"')"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "id": "8ac6e908",
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# Replace with your folder path\n",
33 | "folder = os.getcwd()\n",
34 | "\n",
35 | "# List all matching CSV files\n",
36 | "csv_files = glob.glob(os.path.join(folder, \"results_*.csv\"))\n",
37 | "\n",
38 | "# Read and concatenate all into a single DataFrame\n",
39 | "df_all = pd.concat((pd.read_csv(f) for f in csv_files), ignore_index=True)\n",
40 | "\n",
41 | "# Make the data longer for analysis\n",
42 | "df_long = (\n",
43 | " df_all.melt(id_vars=\"seed\", var_name=\"variable\", value_name=\"value\")\n",
44 | " .assign(\n",
45 | " metric=lambda x: x[\"variable\"].str.extract(r\"^([a-zA-Z]+)\"),\n",
46 | " method=lambda x: x[\"variable\"].str.extract(r\"_(.*)$\").fillna(\"\"),\n",
47 | " )\n",
48 | " .drop(columns=[\"variable\"])\n",
49 | " .pivot(index=[\"seed\", \"method\"], columns=\"metric\", values=\"value\")\n",
50 | " .reset_index()\n",
51 | " .drop(columns=[\"seed\"])\n",
52 | ")"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "id": "f854dc62",
59 | "metadata": {},
60 | "outputs": [
61 | {
62 | "data": {
107 | "text/plain": [
108 | " mse loglik mmd fid\n",
109 | "initial 0.054 ± 0.001 -17.179 ± 1.131 0.514 ± 0.027 6.259 ± 0.099\n",
110 | "refit 0.052 ± 0.001 -8.098 ± 1.199 0.492 ± 0.03 6.62 ± 0.197"
111 | ]
112 | },
113 | "execution_count": 61,
114 | "metadata": {},
115 | "output_type": "execute_result"
116 | }
117 | ],
118 | "source": [
119 | "def mean_ci(x, digits: int = 3) -> str:\n",
120 | " x = np.asarray(x, dtype=np.float64)\n",
121 | " n = x.size\n",
122 | " if n == 0:\n",
123 | " return \"nan ± nan\"\n",
124 | " mean = np.mean(x)\n",
125 | " se = np.std(x, ddof=1) / np.sqrt(n)\n",
126 | " ci_half_width = 1.96 * se\n",
127 | " return f\"{round(mean, digits)} ± {round(ci_half_width, digits)}\"\n",
128 | "\n",
129 | "\n",
130 | "summary = df_long.groupby(\"method\").agg(mean_ci)\n",
131 | "summary = summary[[\"mse\", \"loglik\", \"mmd\", \"fid\"]]\n",
132 | "summary.columns.name = None\n",
133 | "summary.index.name = None\n",
134 | "summary"
135 | ]
136 | }
137 | ],
138 | "metadata": {
139 | "kernelspec": {
140 | "display_name": ".venv",
141 | "language": "python",
142 | "name": "python3"
143 | },
144 | "language_info": {
145 | "codemirror_mode": {
146 | "name": "ipython",
147 | "version": 3
148 | },
149 | "file_extension": ".py",
150 | "mimetype": "text/x-python",
151 | "name": "python",
152 | "nbconvert_exporter": "python",
153 | "pygments_lexer": "ipython3",
154 | "version": "3.13.5"
155 | }
156 | },
157 | "nbformat": 4,
158 | "nbformat_minor": 5
159 | }
160 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Manual
2 | /scratch
3 | /docs
4 | /Styles
5 | /data
6 | /out
7 | *.pkl
8 | *.svg
9 | checkpoints/
10 | MNIST/
11 | logs/
12 | examples/MNIST
13 | examples/logs
14 | examples/checkpoints
15 | examples/**/*.gz
16 | # examples/**/*.ipynb
17 |
18 | # Python-generated files
19 | __pycache__/
20 | *.py[oc]
21 | build/
22 | dist/
23 | wheels/
24 | *.egg-info
25 |
26 | # Virtual environments
27 | .venv
28 |
29 | # Byte-compiled / optimized / DLL files
30 | __pycache__/
31 | *.py[cod]
32 | *$py.class
33 |
34 | # C extensions
35 | *.so
36 |
37 | # Distribution / packaging
38 | .Python
39 | build/
40 | develop-eggs/
41 | dist/
42 | downloads/
43 | eggs/
44 | .eggs/
45 | lib/
46 | lib64/
47 | parts/
48 | sdist/
49 | var/
50 | wheels/
51 | share/python-wheels/
52 | *.egg-info/
53 | .installed.cfg
54 | *.egg
55 | MANIFEST
56 |
57 | # PyInstaller
58 | # Usually these files are written by a python script from a template
59 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
60 | *.manifest
61 | *.spec
62 |
63 | # Installer logs
64 | pip-log.txt
65 | pip-delete-this-directory.txt
66 |
67 | # Unit test / coverage reports
68 | htmlcov/
69 | .tox/
70 | .nox/
71 | .coverage
72 | .coverage.*
73 | .cache
74 | nosetests.xml
75 | coverage.xml
76 | *.cover
77 | *.py,cover
78 | .hypothesis/
79 | .pytest_cache/
80 | cover/
81 |
82 | # Translations
83 | *.mo
84 | *.pot
85 |
86 | # Django stuff:
87 | *.log
88 | local_settings.py
89 | db.sqlite3
90 | db.sqlite3-journal
91 |
92 | # Flask stuff:
93 | instance/
94 | .webassets-cache
95 |
96 | # Scrapy stuff:
97 | .scrapy
98 |
99 | # Sphinx documentation
100 | docs/_build/
101 |
102 | # PyBuilder
103 | .pybuilder/
104 | target/
105 |
106 | # Jupyter Notebook
107 | .ipynb_checkpoints
108 |
109 | # IPython
110 | profile_default/
111 | ipython_config.py
112 |
113 | # pyenv
114 | # For a library or package, you might want to ignore these files since the code is
115 | # intended to run in multiple environments; otherwise, check them in:
116 | .python-version
117 |
118 | # pipenv
119 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
120 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
121 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
122 | # install all needed dependencies.
123 | #Pipfile.lock
124 |
125 | # UV
126 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
127 | # This is especially recommended for binary packages to ensure reproducibility, and is more
128 | # commonly ignored for libraries.
129 | uv.lock
130 |
131 | # poetry
132 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
133 | # This is especially recommended for binary packages to ensure reproducibility, and is more
134 | # commonly ignored for libraries.
135 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
136 | poetry.lock
137 |
138 | # pdm
139 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
140 | #pdm.lock
141 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
142 | # in version control.
143 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
144 | .pdm.toml
145 | .pdm-python
146 | .pdm-build/
147 |
148 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
149 | __pypackages__/
150 |
151 | # Celery stuff
152 | celerybeat-schedule
153 | celerybeat.pid
154 |
155 | # SageMath parsed files
156 | *.sage.py
157 |
158 | # Environments
159 | .env
160 | .venv
161 | env/
162 | venv/
163 | ENV/
164 | env.bak/
165 | venv.bak/
166 |
167 | # Spyder project settings
168 | .spyderproject
169 | .spyproject
170 |
171 | # Rope project settings
172 | .ropeproject
173 |
174 | # mkdocs documentation
175 | /site
176 |
177 | # mypy
178 | .mypy_cache/
179 | .dmypy.json
180 | dmypy.json
181 |
182 | # Pyre type checker
183 | .pyre/
184 |
185 | # pytype static type analyzer
186 | .pytype/
187 |
188 | # Cython debug symbols
189 | cython_debug/
190 |
191 | # PyCharm
192 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
193 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
194 | # and can be added to the global gitignore or merged into this file. For a more nuclear
195 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
196 | #.idea/
197 |
198 | # Ruff stuff:
199 | .ruff_cache/
200 |
201 | # PyPI configuration file
202 | .pypirc
203 |
204 | # General
205 | .DS_Store
206 | .AppleDouble
207 | .LSOverride
208 |
209 | # Icon must end with two \r
210 | Icon
211 |
212 | # Thumbnails
213 | ._*
214 |
215 | # Files that might appear in the root of a volume
216 | .DocumentRevisions-V100
217 | .fseventsd
218 | .Spotlight-V100
219 | .TemporaryItems
220 | .Trashes
221 | .VolumeIcon.icns
222 | .com.apple.timemachine.donotpresent
223 |
224 | # Directories potentially created on remote AFP share
225 | .AppleDB
226 | .AppleDesktop
227 | Network Trash Folder
228 | Temporary Items
229 | .apdisk
230 |
231 | .vscode/*
232 | # !.vscode/settings.json
233 | !.vscode/tasks.json
234 | !.vscode/launch.json
235 | !.vscode/extensions.json
236 | !.vscode/*.code-snippets
237 |
238 | # Local History for Visual Studio Code
239 | .history/
240 |
241 | # Built Visual Studio Code Extensions
242 | *.vsix
243 |
244 | *.csv
245 | *.mat
--------------------------------------------------------------------------------
/examples/pred_intvl/experiments.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pickle
3 | from itertools import product
4 |
5 | import numpy as np
6 | import torch
7 | from data_utils import (
8 | DIR_WORK,
9 | device,
10 | extract_XY,
11 | get_logger,
12 | load_california_housing,
13 | load_news_popularity,
14 | )
15 | from models import (
16 | bnn_pred_intvl,
17 | ensemble_pred_intvl,
18 | extract_features,
19 | mc_dropout_pred_intvl,
20 | train_base_model,
21 | train_bnn,
22 | train_ensemble,
23 | )
24 | from vine_wrapper import train_vine, vine_pred_intvl
25 |
26 | DATASETS = {
27 | "california": load_california_housing,
28 | "news": load_news_popularity,
29 | }
30 | ALPHA = 0.05
31 | DIR_OUT = DIR_WORK / "examples" / "pred_intvl" / "out"
32 | DIR_OUT.mkdir(parents=True, exist_ok=True)
33 | LATENT_DIM = 10
34 | lst_lr = [1e-2, 1e-3, 1e-4]
35 | lst_weight_decay = [0.0, 0.91, 0.99]
36 | lst_p_drop = [0.1, 0.3, 0.5]
37 | lr = lst_lr[1]
38 | weight_decay = lst_weight_decay[1]
39 | p_drop = lst_p_drop[1]
40 | # * ===== is_test =====
41 | is_test = False
42 | # * ===== is_test =====
43 |
44 | if is_test:
45 | DIR_OUT = DIR_OUT / "tmp"
46 | SEEDS = list(range(3))
47 | N_ENSEMBLE = 4
48 | NUM_EPOCH = 5
49 | else:
50 | SEEDS = list(range(1, 100))
51 | N_ENSEMBLE = 20
52 | NUM_EPOCH = 150
53 |
54 |
55 | if __name__ == "__main__":
56 | for seed, ds_name in product(
57 | SEEDS,
58 | DATASETS.keys(),
59 | ):
60 | file = DIR_OUT / f"PredIntvl_{ds_name}_{seed}.pkl"
61 | file_log = DIR_OUT / f"PredIntvl_{ds_name}_{seed}.log"
62 | logger = get_logger(file_log, name=f"PredIntvl_{ds_name}_{seed}")
63 | if file.exists():
64 | logger.warning(f"File {file} already exists, skipping...")
65 | continue
66 | logger.info(f"Running {ds_name} with seed {seed}...")
67 | # * set seed
68 | torch.manual_seed(seed)
69 | torch.cuda.manual_seed_all(seed)
70 | torch.backends.cudnn.deterministic = True
71 | torch.backends.cudnn.benchmark = False
72 | np.random.seed(seed)
73 | # * load data
74 | try:
75 | train_loader, val_loader, test_loader, xsc, ysc = DATASETS[ds_name](seed_val=seed)
76 | X_test, Y_test = extract_XY(test_loader, device)
77 | except Exception as e:
78 | logger.error(f"Error loading dataset {ds_name}: {e}")
79 | continue
80 | input_dim = X_test.shape[1]
81 | dct_result = {
82 | "seed": seed,
83 | "dataset": ds_name,
84 | "alpha": ALPHA,
85 | "Y_test": Y_test.flatten().cpu().numpy(),
86 | }
87 | # * train / fit: get PI on test
88 | # ! ensemble
89 | try:
90 | logger.info("Training ensemble...")
91 | torch.manual_seed(seed)
92 | ensemble = train_ensemble(
93 | M=N_ENSEMBLE,
94 | train_loader=train_loader,
95 | val_loader=val_loader,
96 | input_dim=input_dim,
97 | latent_dim=LATENT_DIM,
98 | lr=lr,
99 | weight_decay=weight_decay,
100 | device=device,
101 | num_epochs=NUM_EPOCH,
102 | )
103 | dct_result["ensemble"] = ensemble_pred_intvl(ensemble, X_test, device, ALPHA)
104 | except Exception as e:
105 | logger.error(f"Error training ensemble: {e}")
106 | dct_result["ensemble"] = (None, None, None)
107 | # ! base with mc dropout
108 | try:
109 | logger.info("Training base model with MC dropout...")
110 | torch.manual_seed(seed)
111 | model_mcdropout = train_base_model(
112 | train_loader=train_loader,
113 | val_loader=val_loader,
114 | input_dim=input_dim,
115 | lr=lr,
116 | weight_decay=weight_decay,
117 | p_drop=p_drop,
118 | latent_dim=LATENT_DIM,
119 | num_epoch=NUM_EPOCH,
120 | device=device,
121 | )
122 | dct_result["mcdropout"] = mc_dropout_pred_intvl(
123 | model_mcdropout, X_test, T=200, device=device, alpha=ALPHA
124 | )
125 | except Exception as e:
126 | logger.error(f"Error training base model with MC dropout: {e}")
127 | dct_result["mcdropout"] = (None, None, None)
128 | # ! bnn
129 | try:
130 | logger.info("Training BNN...")
131 | torch.manual_seed(seed)
132 | model_bnn = train_bnn(
133 | train_loader=train_loader,
134 | val_loader=val_loader,
135 | input_dim=input_dim,
136 | latent_dim=LATENT_DIM,
137 | lr=lr,
138 | weight_decay=weight_decay,
139 | device=device,
140 | num_epochs=NUM_EPOCH,
141 | )
142 | dct_result["bnn"] = bnn_pred_intvl(
143 | model_bnn, X_test, T=200, alpha=ALPHA, device=device
144 | )
145 | except Exception as e:
146 | logger.error(f"Error training BNN: {e}")
147 | dct_result["bnn"] = (None, None, None)
148 | # ! vine (also extracting Y_train, Y_test)
149 | try:
150 | logger.info("Training vine...")
151 | torch.manual_seed(seed)
152 | model_base = train_base_model(
153 | train_loader=train_loader,
154 | val_loader=val_loader,
155 | input_dim=input_dim,
156 | lr=lr,
157 | weight_decay=weight_decay,
158 | p_drop=0.0,
159 | latent_dim=LATENT_DIM,
160 | device=device,
161 | )
162 | Z_train, Y_train = extract_features(model_base, train_loader, device)
163 | model_vine = train_vine(Z_train, Y_train, device=device)
164 | Z_test, Y_test = extract_features(model_base, test_loader, device)
165 | dct_result["vine"] = vine_pred_intvl(
166 | model_vine, Z_test, alpha=ALPHA, seed=seed, device=device
167 | )
168 | except Exception as e:
169 | logger.error(f"Error training vine: {e}")
170 | dct_result["vine"] = (None, None, None)
171 | with open(file, "wb") as f:
172 | pickle.dump(dct_result, f)
173 |
174 | logger.info(f"Results saved to {file}")
175 | logging.shutdown()
176 |
--------------------------------------------------------------------------------
/examples/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import pickle
4 | import platform
5 | import sys
6 | import time
7 | from collections import defaultdict
8 | from itertools import product
9 | from pathlib import Path
10 |
11 | import pyvinecopulib as pvc
12 | import torch
13 | from dotenv import load_dotenv
14 | from torch.special import ndtr
15 |
16 | import torchvinecopulib as tvc
17 |
18 | # ! ===================
19 | # ! ===================
20 | is_test = False
21 | # ! ===================
22 | # ! ===================
23 | num_threads = 10
24 | is_cuda_avail = torch.cuda.is_available()
25 | load_dotenv()
26 | DIR_WORK = Path(os.getenv("DIR_WORK"))
27 | DIR_OUT = DIR_WORK / "examples" / "benchmark" / "out"
28 | DIR_OUT.mkdir(parents=True, exist_ok=True)
29 | torch.set_default_dtype(torch.float64)
30 | SEED = 42
31 | if is_test:
32 | lst_num_dim = [10, 30][::-1]
33 | lst_num_obs = [1000, 50000][::-1]
34 | lst_seed = list(range(3))
35 | else:
36 | lst_num_dim = [10, 20, 30, 40, 50][::-1]
37 | lst_num_obs = [1000, 5000, 10000, 20000, 30000, 40000, 50000][::-1]
38 |     lst_seed = list(range(100))  # repetition index; the first run per setting is discarded as warm-up
39 | dct_time_fit = {
40 | "pvc": defaultdict(list),
41 | "tvc": defaultdict(list),
42 | "tvc_cuda": defaultdict(list),
43 | }
44 | dct_time_sample = {
45 | "pvc": defaultdict(list),
46 | "tvc": defaultdict(list),
47 | "tvc_cuda": defaultdict(list),
48 | }
49 | dct_time_pdf = {
50 | "pvc": defaultdict(list),
51 | "tvc": defaultdict(list),
52 | "tvc_cuda": defaultdict(list),
53 | }
54 |
55 |
56 | def cuda_warmup(num_obs, num_dim):
57 | if torch.cuda.is_available():
58 | device = torch.device("cuda")
59 | for _ in range(5):
60 | _ = torch.randn(num_obs, num_dim, device=device)
61 | torch.cuda.synchronize()
62 | else:
63 | print("CUDA is not available. Skipping warm-up.")
64 |
65 |
66 | print(f"Python executable: {sys.executable}")
67 | print(f"Python version: {sys.version.splitlines()[0]}")
68 | print(f"Platform: {platform.platform()}")
69 | print(
70 | f"PyTorch: {torch.__version__} (CUDA available: {torch.cuda.is_available()}, CUDA toolkit: {torch.version.cuda})"
71 | )
72 | print(f"pyvinecopulib: {pvc.__version__}")
73 | print(f"torchvinecopulib: {tvc.__version__}")
74 | print(f"number of seeds: {len(lst_seed)}")
75 |
76 | # %%
77 | for num_dim, num_obs in product(lst_num_dim, lst_num_obs):
78 | print(f"\n\n{time.strftime('%Y-%m-%d %H:%M:%S')}\nnum_dim: {num_dim}, num_obs: {num_obs}\n\n")
79 | # ! preprocess into copula scale (uniform marginals)
80 | torch.manual_seed(SEED)
81 |     # * tensor on cpu: random correlation matrix R (unit-norm rows, then Gram product => unit diagonal), Gaussian sample with correlation R mapped to uniform marginals via ndtr
82 | R = torch.rand(num_dim, num_dim, dtype=torch.float64)
83 | R /= R.norm(dim=1, keepdim=True)
84 | R @= R.T
85 | U = ndtr(
86 | (torch.randn(num_obs, num_dim, dtype=torch.float64) @ torch.linalg.cholesky(R, upper=True))
87 | )
88 | # * tensor on cuda
89 | if is_cuda_avail:
90 | U_cuda = U.cuda()
91 | # * np on cpu
92 | U_numpy = U.numpy().astype("float64")
93 | # ! pvc
94 | pvc_ctrl = pvc.FitControlsVinecop(
95 | family_set=(pvc.BicopFamily.indep, pvc.BicopFamily.tll),
96 | nonparametric_method="quadratic",
97 | tree_criterion="tau",
98 | num_threads=num_threads,
99 | )
100 | # ! tvc
101 | tvc_mdl = tvc.VineCop(num_dim=num_dim, num_step_grid=64, is_cop_scale=True)
102 | # ! tvc_cuda
103 | if is_cuda_avail:
104 | tvc_mdl_cuda = tvc.VineCop(num_dim=num_dim, num_step_grid=64, is_cop_scale=True).cuda()
105 |
106 | # * fit
107 | print(f"\n{time.strftime('%Y-%m-%d %H:%M:%S')} Fitting...\n")
108 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} PVC...")
109 | for seed in lst_seed:
110 | t0 = time.perf_counter()
111 | pvc_mdl = pvc.Vinecop.from_data(data=U_numpy, controls=pvc_ctrl)
112 | t1 = time.perf_counter()
113 | if seed > 0:
114 | dct_time_fit["pvc"][num_obs, num_dim].append(t1 - t0)
115 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} TVC...")
116 | for seed in lst_seed:
117 | t0 = time.perf_counter()
118 | tvc_mdl.fit(obs=U, mtd_bidep="kendall_tau", num_iter_max=11)
119 | t1 = time.perf_counter()
120 | if seed > 0:
121 | dct_time_fit["tvc"][num_obs, num_dim].append(t1 - t0)
122 | if is_cuda_avail:
123 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} TVC CUDA...")
124 | cuda_warmup(num_obs, num_dim)
125 | for seed in lst_seed:
126 | torch.cuda.synchronize()
127 | t0 = time.perf_counter()
128 | tvc_mdl_cuda.fit(obs=U_cuda, mtd_bidep="kendall_tau", num_iter_max=11)
129 | t1 = time.perf_counter()
130 | torch.cuda.synchronize()
131 | if seed > 0:
132 | dct_time_fit["tvc_cuda"][num_obs, num_dim].append(t1 - t0)
133 |
134 | # * sample
135 | print(f"\n{time.strftime('%Y-%m-%d %H:%M:%S')} Sampling...\n")
136 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} PVC...")
137 | for seed in lst_seed:
138 | t0 = time.perf_counter()
139 | pvc_mdl.simulate(n=num_obs, num_threads=num_threads)
140 | t1 = time.perf_counter()
141 | if seed > 0:
142 | dct_time_sample["pvc"][num_obs, num_dim].append(t1 - t0)
143 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} TVC...")
144 | for seed in lst_seed:
145 | t0 = time.perf_counter()
146 | tvc_mdl.sample(num_sample=num_obs)
147 | t1 = time.perf_counter()
148 | if seed > 0:
149 | dct_time_sample["tvc"][num_obs, num_dim].append(t1 - t0)
150 | if is_cuda_avail:
151 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} TVC CUDA...")
152 | cuda_warmup(num_obs, num_dim)
153 | for seed in lst_seed:
154 | torch.cuda.synchronize()
155 | t0 = time.perf_counter()
156 | tvc_mdl_cuda.sample(num_sample=num_obs)
157 | t1 = time.perf_counter()
158 | torch.cuda.synchronize()
159 | if seed > 0:
160 | dct_time_sample["tvc_cuda"][num_obs, num_dim].append(t1 - t0)
161 |
162 | # * pdf
163 | print(f"\n{time.strftime('%Y-%m-%d %H:%M:%S')} PDF...\n")
164 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} PVC...")
165 | for seed in lst_seed:
166 | t0 = time.perf_counter()
167 | pvc_mdl.pdf(U_numpy, num_threads=num_threads)
168 | t1 = time.perf_counter()
169 | if seed > 0:
170 | dct_time_pdf["pvc"][num_obs, num_dim].append(t1 - t0)
171 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} TVC...")
172 | for seed in lst_seed:
173 | t0 = time.perf_counter()
174 | tvc_mdl.log_pdf(U)
175 | t1 = time.perf_counter()
176 | if seed > 0:
177 | dct_time_pdf["tvc"][num_obs, num_dim].append(t1 - t0)
178 | if is_cuda_avail:
179 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} TVC CUDA...")
180 | cuda_warmup(num_obs, num_dim)
181 | for seed in lst_seed:
182 | torch.cuda.synchronize()
183 | t0 = time.perf_counter()
184 | tvc_mdl_cuda.log_pdf(U_cuda)
185 | t1 = time.perf_counter()
186 | torch.cuda.synchronize()
187 | if seed > 0:
188 | dct_time_pdf["tvc_cuda"][num_obs, num_dim].append(t1 - t0)
189 |
190 | # %%
191 | # ! save
192 | with open(DIR_OUT / "time_fit.pkl", "wb") as f:
193 | pickle.dump(dct_time_fit, f)
194 | with open(DIR_OUT / "time_sample.pkl", "wb") as f:
195 | pickle.dump(dct_time_sample, f)
196 | with open(DIR_OUT / "time_pdf.pkl", "wb") as f:
197 | pickle.dump(dct_time_pdf, f)
198 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # torchvinecopulib
2 |
3 | [](https://app.codacy.com/gh/TY-Cheng/torchvinecopulib/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
4 | [](https://app.codacy.com?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_coverage)
5 | [](https://github.com/TY-Cheng/torchvinecopulib/actions/workflows/python-package.yml)
6 | [](https://ty-cheng.github.io/torchvinecopulib/)
7 |
8 | 
9 | [](https://github.com/TY-Cheng/torchvinecopulib/actions/workflows/python-package.yml)
10 |
11 | [](https://github.com/TY-Cheng/torchvinecopulib/blob/main/LICENSE)
12 | [](https://pypi.org/project/torchvinecopulib/)
13 |
14 | A Python library for fitting and sampling vine copulas, using [PyTorch](https://pytorch.org/get-started/locally/).
15 |
16 | - C/D/R-Vine full sampling, quantile regression, and conditional sampling, all in one package (see the quick-start sketch below)
17 | - Flexible sampling order for experienced users
18 | - Vectorized tensor computation with GPU (`device='cuda'`) support
19 | - Shorter runtimes for higher-dimensional simulations
20 | - Pure `Python` library, inspired by [pyvinecopulib](https://github.com/vinecopulib/pyvinecopulib/); runs on Windows, Linux, and macOS
21 | - IO and visualization support
22 |
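A minimal quick-start sketch, adapted from the benchmark script under `./examples/benchmark/` (the 3-dimensional uniform data below is synthetic and purely illustrative):

```python
import torch

import torchvinecopulib as tvc

torch.manual_seed(42)
# synthetic observations already on the copula scale (uniform marginals)
U = torch.rand(1000, 3, dtype=torch.float64)

vcp = tvc.VineCop(num_dim=3, num_step_grid=64, is_cop_scale=True)
vcp.fit(obs=U, mtd_bidep="kendall_tau")  # learn structure and pair-copulas
samples = vcp.sample(num_sample=500)     # draw new copula-scale observations
loglik = vcp.log_pdf(U).sum()            # log-density of the fitted vine
# vcp = vcp.cuda()                       # optionally move the whole model to GPU
```
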
23 | ## Citation
24 | If you use `torchvinecopulib` in your work, please cite:
25 |
26 | > Cheng, Tuoyuan, Thibault Vatter, Thomas Nagler, and Kan Chen. "Vine Copulas as Differentiable Computational Graphs." arXiv preprint arXiv:2506.13318 (2025).
27 |
28 | ```bibtex
29 | @article{cheng2025vine,
30 | title={Vine Copulas as Differentiable Computational Graphs},
31 | author={Cheng, Tuoyuan and Vatter, Thibault and Nagler, Thomas and Chen, Kan},
32 | journal={arXiv preprint arXiv:2506.13318},
33 | year={2025},
34 | url={https://arxiv.org/abs/2506.13318},
35 | }
36 | ```
37 |
38 | ## Examples
39 |
40 | Visit the [`./examples/`](https://github.com/TY-Cheng/torchvinecopulib/tree/main/examples) folder for `.ipynb` Jupyter notebooks.
41 |
42 | ## Installation
43 |
44 | - By `pip` from [`PyPI`](https://pypi.org/project/torchvinecopulib/) (see the dependencies and uv sections below for CUDA support):
45 |
46 | ```bash
47 | pip install torchvinecopulib torch
48 | ```
49 |
50 | - Or `pip` install from `./dist/*.whl` or `./dist/*.tar.gz` in this repo,
51 |   using the actual file name.
52 |
53 | ```bash
54 | # inside project root folder
55 | pip install ./dist/torchvinecopulib-1.1.0-py3-none-any.whl
56 | # or
57 | pip install ./dist/torchvinecopulib-1.1.0.tar.gz
58 | ```
59 |
60 | ### (Recommended) [uv](https://docs.astral.sh/uv/getting-started/) for Dependency Management and Packaging
61 |
62 | After `git clone https://github.com/TY-Cheng/torchvinecopulib.git`, `cd` into the project root where [`pyproject.toml`](https://github.com/TY-Cheng/torchvinecopulib/blob/main/pyproject.toml) exists,
63 |
64 | ```bash
65 | # From inside the project root folder
66 | # Create and activate local virtual environment
67 | uv venv .venv
68 | source .venv/bin/activate
69 |
70 | # Sync dependencies with CPU support (default)
71 | uv sync --extra cpu
72 |
73 | # Or for CUDA 12.6 / 12.8 support (pick the extra matching your CUDA version)
74 | uv sync --extra cu126  # or: uv sync --extra cu128
75 |
76 | # Optionally, install the extra dependencies for the examples
77 | uv sync --extra examples
78 | ```
79 |
80 | ## Dependencies
81 |
82 | ```toml
83 | # inside the `./pyproject.toml` file;
84 | fastkde = "*"
85 | numpy = "*"
86 | pyvinecopulib = "*"
87 | python = ">=3.11"
88 | scipy = "*"
89 | # optional to facilitate customization
90 | torch = [
91 | { index = "torch-cpu", extra = "cpu" },
92 | { index = "torch-cu126", extra = "cu126" },
93 | { index = "torch-cu128", extra = "cu128" },
94 | ]
95 | ```
96 |
97 | For [PyTorch](https://pytorch.org/get-started/locally/) with `cuda`:
98 |
99 | ```bash
100 | pip install torch --index-url https://download.pytorch.org/whl/cu126 --force-reinstall
101 | # check cuda availability
102 | python -c "import torch; print(torch.cuda.is_available())"
103 | ```
104 |
105 | > [!TIP]
106 | > macOS users should set `device='cpu'` at this stage, since `device='mps'` does not support `dtype=torch.float64`; a minimal device-selection sketch follows below.
107 |
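A hedged sketch of the device selection used throughout the examples in this repo (CUDA when available, otherwise CPU; `mps` is avoided because the library works in `torch.float64`):

```python
import torch

# prefer CUDA when available; otherwise stay on CPU (not 'mps', which lacks float64 support)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_dtype(torch.float64)
```
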
108 | ## Documentation
109 |
110 | - Visit [GitHub Pages website](https://ty-cheng.github.io/torchvinecopulib/)
111 |
112 | - Or build it yourself (requires [`Sphinx`](https://github.com/sphinx-doc/sphinx), the [`furo`](https://github.com/pradyunsg/furo) theme, and [GNU `make`](https://www.gnu.org/software/make/))
113 |
114 | ```bash
115 | # inside project root folder
116 | sphinx-apidoc -o ./docs ./torchvinecopulib && cd ./docs && make html && cd ..
117 | # if using uv
118 | uv run sphinx-apidoc -o docs torchvinecopulib/ --separate
119 | uv run sphinx-build docs docs/_build/html
120 | ```
121 |
122 | ## Tests
123 |
124 | ```bash
125 | # inside project root folder
126 | python -m pytest ./tests
127 | # coverage report
128 | coverage run -m pytest ./tests && coverage html
129 | # if using uv
130 | uv run coverage run --source=torchvinecopulib -m pytest ./tests
131 | uv run coverage report -m
132 | ```
133 |
134 | ## TODO
135 |
136 | - `VineCop.rosenblatt`
137 | - replace `dict` with `torch.Tensor` using some `mod`
138 | - vectorized union-find
139 | - flatten `_visit` logic
140 | - `examples/someapplications.ipynb`
141 | - flatten dynamic nested dicts into tensors
142 | - [`fastkde.pdf`](https://github.com/LBL-EESA/fastkde/blob/main/src/fastkde/fastKDE.py) onto `torch.Tensor`
143 |
144 | ## Contributing
145 |
146 | We welcome contributions, whether it's a bug report, feature suggestion, code contribution, or documentation improvement.
147 |
148 | - If you encounter any issues with the project or have ideas for new features, please [open an issue](https://github.com/TY-Cheng/torchvinecopulib/issues/new) on GitHub or [privately email us](mailto:cty120120@gmail.com). Make sure to include detailed information about the problem or feature request, including steps to reproduce for bugs.
149 |
150 | ### Code Contributions
151 |
152 | 1. Fork the repository and create a new branch from the `main` branch.
153 | 2. Make your changes and ensure they adhere to the project's coding style and conventions.
154 | 3. Write tests for any new functionality and ensure existing tests pass.
155 | 4. Commit your changes with clear and descriptive commit messages.
156 | 5. Push your changes to your fork and submit a pull request to the `main` branch of the original repository.
157 |
158 | ### Pull Request Guidelines
159 |
160 | - Keep pull requests focused on addressing a single issue or feature.
161 | - Include a clear and descriptive title and description for your pull request.
162 | - Make sure all tests pass before submitting the pull request.
163 | - If your pull request addresses an open issue, reference the issue number in the description using the syntax `#issue_number`.
164 | - [in-place ops can be slower](https://discuss.pytorch.org/t/are-inplace-operations-faster/61209/4) (a small timing sketch follows below)
165 | - [torch.jit.script can be slower](https://discuss.pytorch.org/t/why-is-torch-jit-script-slower/120131/6)
166 |
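A minimal, hedged timing sketch for checking the two performance notes above on your own hardware (nothing here is part of the library API):

```python
import time

import torch

x = torch.rand(5_000, 5_000, dtype=torch.float64)

t0 = time.perf_counter()
y = x * 2.0   # out-of-place multiply allocates a new tensor
t1 = time.perf_counter()
x.mul_(2.0)   # in-place multiply reuses the existing storage
t2 = time.perf_counter()
print(f"out-of-place: {t1 - t0:.4f}s, in-place: {t2 - t1:.4f}s")
```
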
167 | ## License
168 |
169 | This project is released under the MIT License (© 2024- Tuoyuan Cheng, Kan Chen).
170 | See [LICENSE](./LICENSE) for the full text, including our own grant of rights and disclaimer.
171 |
172 | ### Third-Party Dependencies
173 |
174 | See the “Third-Party Dependencies” section in [LICENSE](./LICENSE) for details on the `PyTorch`, `FastKDE`, and `pyvinecopulib` licenses that govern those components.
175 |
--------------------------------------------------------------------------------
/examples/vcae/vcae_debug.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "f20e9099",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stdout",
11 | "output_type": "stream",
12 | "text": [
13 | "The autoreload extension is already loaded. To reload it, use:\n",
14 | " %reload_ext autoreload\n"
15 | ]
16 | },
17 | {
18 | "name": "stderr",
19 | "output_type": "stream",
20 | "text": [
21 | "Running experiments: 0%| | 0/1 [00:00, ?it/s]"
22 | ]
23 | },
24 | {
25 | "name": "stderr",
26 | "output_type": "stream",
27 | "text": [
28 | "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
29 | "GPU available: True (cuda), used: True\n",
30 | "TPU available: False, using: 0 TPU cores\n",
31 | "HPU available: False, using: 0 HPUs\n",
32 | "/home/tvatter/Dropbox/github/torchvinecopulib/.venv/lib/python3.13/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:658: Checkpoint directory /home/tvatter/Dropbox/github/torchvinecopulib/examples/vcae/checkpoints exists and is not empty.\n",
33 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
34 | "/home/tvatter/Dropbox/github/torchvinecopulib/.venv/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n",
35 | "/home/tvatter/Dropbox/github/torchvinecopulib/.venv/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n",
36 | "`Trainer.fit` stopped: `max_epochs=10` reached.\n",
37 | "/home/tvatter/Dropbox/github/torchvinecopulib/examples/vcae/vcae/model.py:240: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
38 | " )\n",
39 | "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
40 | "GPU available: True (cuda), used: True\n",
41 | "TPU available: False, using: 0 TPU cores\n",
42 | "HPU available: False, using: 0 HPUs\n",
43 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
44 | ]
45 | }
46 | ],
47 | "source": [
48 | "# add autoreload and other ipython magic\n",
49 | "import os\n",
50 | "\n",
51 | "import pandas as pd\n",
52 | "from IPython import get_ipython\n",
53 | "from tqdm import tqdm\n",
54 | "from vcae.config import config_mnist, config_svhn\n",
55 | "from vcae.experiment import run_experiment\n",
56 | "\n",
57 | "ipython = get_ipython()\n",
58 | "ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
59 | "ipython.run_line_magic(\"autoreload\", \"2\")\n",
60 | "ipython.run_line_magic(\"matplotlib\", \"inline\")\n",
61 | "ipython.run_line_magic(\"config\", 'InlineBackend.figure_format = \"retina\"')\n",
62 | "\n",
63 | "results = []\n",
64 | "dataset = \"MNIST\" # or \"SVHN\"\n",
65 | "# dataset = \"SVHN\" # or \"MNIST\"\n",
66 | "vine_lambda = 1.0 # Default value for vine lambda\n",
67 | "save_results = True # Save results for each seed in a separate CSV file\n",
68 | "n_seeds = 1 # Number of seeds to run\n",
69 | "\n",
70 | "if dataset == \"MNIST\":\n",
71 | " config = config_mnist\n",
72 | "elif dataset == \"SVHN\":\n",
73 | " config = config_svhn\n",
74 | "else:\n",
75 | " raise ValueError(f\"Unknown dataset: {dataset}\")\n",
76 | "\n",
77 | "output_path = f\"results_{dataset}_{vine_lambda}.csv\"\n",
78 | "for seed in tqdm(range(n_seeds), desc=\"Running experiments\"):\n",
79 | " result = run_experiment(seed, config, dataset=dataset, vine_lambda=vine_lambda)\n",
80 | " results.append(result)\n",
81 | " # Save result to a CSV file\n",
82 | " if save_results:\n",
83 | " df_result = pd.DataFrame([result])\n",
84 | "\n",
85 | " # Write headers only once\n",
86 | " if not os.path.exists(output_path):\n",
87 | " df_result.to_csv(output_path, index=False, mode=\"w\")\n",
88 | " else:\n",
89 | " df_result.to_csv(output_path, index=False, mode=\"a\", header=False)\n",
90 | "\n",
91 | "\n",
92 | "df_results = pd.DataFrame(results)\n",
93 | "# df_results.to_csv(\"experiment_results.csv\", index=False)\n"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "id": "889cf707",
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "# # Instantiate the LitMNISTAutoencoder\n",
104 | "# model = LitMNISTAutoencoder()\n",
105 | "\n",
106 | "# # Instantiate a trainer with the specified configuration\n",
107 | "# trainer = pl.Trainer(\n",
108 | "# accelerator=config.accelerator,\n",
109 | "# devices=config.devices,\n",
110 | "# max_epochs=config.max_epochs,\n",
111 | "# logger=CSVLogger(save_dir=config.save_dir),\n",
112 | "# )\n",
113 | "\n",
114 | "# # Train the model using the trainer\n",
115 | "# trainer.fit(model)\n",
116 | "\n",
117 | "# # Train the vine\n",
118 | "# model.learn_vine(n_samples=5000)\n",
119 | "# # # Read in the training metrics from the CSV file generated by the logger\n",
120 | "# # metrics = pd.read_csv(f\"{trainer.logger.log_dir}/metrics.csv\")\n",
121 | "\n",
122 | "# # # Remove the \"step\" column, which is not needed for our analysis\n",
123 | "# # del metrics[\"step\"]\n",
124 | "\n",
125 | "# # # Set the epoch column as the index, for easier plotting\n",
126 | "# # metrics.set_index(\"epoch\", inplace=True)\n",
127 | "\n",
128 | "# # # Create a line plot of the training metrics using Seaborn\n",
129 | "# # sns.relplot(data=metrics, kind=\"line\")\n",
130 | "\n",
131 | "# # Train the vine\n",
132 | "# model.learn_vine(n_samples=5000)\n",
133 | "\n",
134 | "# # Copy the model for refitting\n",
135 | "# model_refit = copy.deepcopy(model)\n",
136 | "\n",
137 | "# # Instantiate a new trainer\n",
138 | "# trainer_refit = pl.Trainer(\n",
139 | "# accelerator=config.accelerator,\n",
140 | "# devices=config.devices,\n",
141 | "# max_epochs=config.max_epochs,\n",
142 | "# logger=CSVLogger(save_dir=config.save_dir),\n",
143 | "# )\n",
144 | "\n",
145 | "# # Refit the model\n",
146 | "# trainer_refit.fit(model_refit)\n",
147 | "\n",
148 | "# # Test the model\n",
149 | "# representation, labels, data, decoded, samples = model.get_data(stage=\"test\")\n",
150 | "# representation_refit, labels_refit, data_refit, decoded_refit, samples_refit = model_refit.get_data(\n",
151 | "# stage=\"test\"\n",
152 | "# )\n",
153 | "\n",
154 | "# sigmas = [1e-3, 1e-2, 1e-1, 1, 10, 100]\n",
155 | "# score_model = compute_score(data, samples, DEVICE, sigmas=sigmas)\n",
156 | "# score_refit_model = compute_score(refit_data, refit_samples, DEVICE, sigmas=sigmas)\n",
157 | "# loglik_model = model.vine.log_pdf(representation).mean()\n",
158 | "# loglik_refit_model = model_refit.vine.log_pdf(representation_refit).mean()\n",
159 | "# print(\"Log-likelihood (original vs refit):\")\n",
160 | "# print(\n",
161 | "# f\"Log-likelihood: {loglik_model} vs {loglik_refit_model} => original is {loglik_model / loglik_refit_model} x worse\"\n",
162 | "# )\n",
163 | "# print(\"Model scores (original vs refit):\")\n",
164 | "# print(\n",
165 | "# f\"MMD: {score_model.mmd} vs {score_refit_model.mmd} => original is {score_model.mmd / score_refit_model.mmd} x worse\"\n",
166 | "# )\n",
167 | "# print(\n",
168 | "# f\"FID: {score_model.fid} vs {score_refit_model.fid} => original is {score_model.fid / score_refit_model.fid} x worse\"\n",
169 | "# )\n"
170 | ]
171 | }
172 | ],
173 | "metadata": {
174 | "kernelspec": {
175 | "display_name": ".venv",
176 | "language": "python",
177 | "name": "python3"
178 | },
179 | "language_info": {
180 | "codemirror_mode": {
181 | "name": "ipython",
182 | "version": 3
183 | },
184 | "file_extension": ".py",
185 | "mimetype": "text/x-python",
186 | "name": "python",
187 | "nbconvert_exporter": "python",
188 | "pygments_lexer": "ipython3",
189 | "version": "3.13.5"
190 | }
191 | },
192 | "nbformat": 4,
193 | "nbformat_minor": 5
194 | }
195 |
--------------------------------------------------------------------------------
/examples/pred_intvl/data_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import platform
4 | import sys
5 | from pathlib import Path
6 |
7 | import torch
8 | from dotenv import load_dotenv
9 | from sklearn.datasets import fetch_california_housing, fetch_openml
10 | from sklearn.model_selection import train_test_split
11 | from sklearn.preprocessing import StandardScaler
12 | from torch.utils.data import DataLoader, TensorDataset
13 |
14 | latent_dim = 10
15 | num_epochs = 10
16 | batch_size = 128
17 | test_size = 0.2
18 | val_size = 0.1
19 | random_seed = 42
20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 | torch.manual_seed(random_seed)
22 | load_dotenv()
23 | DIR_WORK = Path(os.getenv("DIR_WORK"))
24 |
25 |
26 | def get_logger(
27 | log_file: Path | str,
28 | console_level: int = logging.WARNING,
29 | file_level: int = logging.INFO,
30 | fmt_console: str = "%(asctime)s - %(levelname)s - %(message)s",
31 | fmt_file: str = "%(asctime)s - %(name)s - %(levelname)s - [%(pathname)s:%(lineno)d] - %(message)s",
32 | name: str | None = None,
33 | ) -> logging.Logger:
34 | """Create (or retrieve) a module‐level logger that writes INFO+ to console and
35 | WARNING+ to file.
36 |
37 | Args:
38 |         log_file: path to the file where INFO+ messages should be logged.
39 | console_level: logging level for console handler.
40 | file_level: logging level for file handler.
41 | fmt_console: format string for console output.
42 | fmt_file: format string for file output.
43 | name: name for the logger; defaults to str(log_file).
45 |
46 | Returns:
47 | A configured logging.Logger instance.
48 | """
49 | log_file = Path(log_file)
50 | log_file.parent.mkdir(parents=True, exist_ok=True)
51 |
52 | logger_name = name or str(log_file)
53 | logger = logging.getLogger(logger_name)
54 | logger.setLevel(logging.INFO)
55 |
56 | # * prevent duplicate handlers if called multiple times
57 | if not logger.handlers:
58 | # * file handler
59 | fh = logging.FileHandler(log_file)
60 | fh.setLevel(file_level)
61 | fh.setFormatter(logging.Formatter(fmt_file))
62 | logger.addHandler(fh)
63 | # * console handler
64 | ch = logging.StreamHandler(sys.stdout)
65 | ch.setLevel(console_level)
66 | ch.setFormatter(logging.Formatter(fmt_console))
67 | logger.addHandler(ch)
68 | # * initial banner
69 | logger.info("--- Logger initialized ---")
70 | logger.info(f"Log file: {log_file}")
71 | logger.info(f"Python: {sys.version.replace(chr(10), ' ')}")
72 | logger.info(f"Platform: {platform.platform()}")
73 | logger.info(f"PyTorch: {torch.__version__}")
74 | logger.info(f"CUDA available: {torch.cuda.is_available()}")
75 |
76 | return logger
77 |
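# A minimal usage sketch for get_logger (the log path below is illustrative, not a path this module requires):
#   logger = get_logger(DIR_WORK / "out" / "demo.log", name="demo")
#   logger.info("recorded in the log file only (file handler is INFO+)")
#   logger.warning("printed to the console and recorded in the log file")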
78 |
79 | def load_california_housing(
80 | batch_size: int = 128,
81 | test_size: float = test_size,
82 | val_size: float = val_size,
83 | seed_val: int = 42,
84 | num_workers: int = 4,
85 | pin_memory: bool = True,
86 | ):
87 | """Fetch California Housing, split into train/val/test, scale (features & target) on
88 | train only, and wrap in PyTorch DataLoaders.
89 |
90 | Returns: train_loader, val_loader, test_loader, x_scaler, y_scaler
91 | """
92 | torch.manual_seed(seed_val)
93 |
94 | # * download & split
95 | data = fetch_california_housing()
96 | X, y = data.data, data.target[:, None]
97 | # ! test set untouched
98 | X_trainval, X_test, y_trainval, y_test = train_test_split(
99 | X, y, test_size=test_size, random_state=seed_val
100 | )
101 | # ! carve off validation from trainval
102 | X_train, X_val, y_train, y_val = train_test_split(
103 | X_trainval, y_trainval, test_size=val_size, random_state=seed_val
104 | )
105 | x_scaler = StandardScaler().fit(X_train)
106 | y_scaler = StandardScaler().fit(y_train)
107 | # * transform all splits
108 | X_train = x_scaler.transform(X_train)
109 | X_val = x_scaler.transform(X_val)
110 | X_test = x_scaler.transform(X_test)
111 | y_train = y_scaler.transform(y_train)
112 | y_val = y_scaler.transform(y_val)
113 | y_test = y_scaler.transform(y_test)
114 | # * to torch tensors
115 | Xtr = torch.from_numpy(X_train).float()
116 | Ytr = torch.from_numpy(y_train).float()
117 | Xva = torch.from_numpy(X_val).float()
118 | Yva = torch.from_numpy(y_val).float()
119 | Xte = torch.from_numpy(X_test).float()
120 | Yte = torch.from_numpy(y_test).float()
121 | # * wrap into DataLoaders
122 | train_loader = DataLoader(
123 | TensorDataset(Xtr, Ytr),
124 | batch_size=batch_size,
125 | shuffle=True,
126 | num_workers=num_workers,
127 | pin_memory=pin_memory,
128 | )
129 | val_loader = DataLoader(
130 | TensorDataset(Xva, Yva),
131 | batch_size=batch_size,
132 | shuffle=False,
133 | num_workers=num_workers,
134 | pin_memory=pin_memory,
135 | )
136 | test_loader = DataLoader(
137 | TensorDataset(Xte, Yte),
138 | batch_size=batch_size,
139 | shuffle=False,
140 | num_workers=num_workers,
141 | pin_memory=pin_memory,
142 | )
143 |
144 | return train_loader, val_loader, test_loader, x_scaler, y_scaler
145 |
146 |
147 | def load_news_popularity(
148 | batch_size: int = 128,
149 | test_size: float = test_size,
150 | val_size: float = val_size,
151 | seed_val: int = 42,
152 | num_workers: int = 4,
153 | pin_memory: bool = True,
154 | ):
155 | """Online News Popularity regression (UCI).
156 |
157 | Splits into train/val/test, scales on train only, wraps in DataLoaders.
158 | """
159 | torch.manual_seed(seed_val)
160 | # * fetch & split
161 | X, y = fetch_openml("OnlineNewsPopularity", version=1, return_X_y=True, as_frame=False)
162 | X = X[:, 1:] # ! remove the first column (website)
163 | y = y.astype(float).reshape(-1, 1)
164 | # ! test set untouched
165 | X_trainval, X_test, y_trainval, y_test = train_test_split(
166 | X, y, test_size=test_size, random_state=seed_val
167 | )
168 | # ! carve off validation
169 | X_train, X_val, y_train, y_val = train_test_split(
170 | X_trainval, y_trainval, test_size=val_size, random_state=seed_val
171 | )
172 | # * standardize
173 | x_scaler = StandardScaler().fit(X_train)
174 | y_scaler = StandardScaler().fit(y_train)
175 |
176 | X_train = x_scaler.transform(X_train)
177 | X_val = x_scaler.transform(X_val)
178 | X_test = x_scaler.transform(X_test)
179 |
180 | y_train = y_scaler.transform(y_train)
181 | y_val = y_scaler.transform(y_val)
182 | y_test = y_scaler.transform(y_test)
183 |
184 | # * to torch tensors
185 | Xtr = torch.from_numpy(X_train).float()
186 | Ytr = torch.from_numpy(y_train).float()
187 | Xva = torch.from_numpy(X_val).float()
188 | Yva = torch.from_numpy(y_val).float()
189 | Xte = torch.from_numpy(X_test).float()
190 | Yte = torch.from_numpy(y_test).float()
191 |
192 | # * DataLoaders
193 | train_loader = DataLoader(
194 | TensorDataset(Xtr, Ytr),
195 | batch_size=batch_size,
196 | shuffle=True,
197 | num_workers=num_workers,
198 | pin_memory=pin_memory,
199 | )
200 | val_loader = DataLoader(
201 | TensorDataset(Xva, Yva),
202 | batch_size=batch_size,
203 | shuffle=False,
204 | num_workers=num_workers,
205 | pin_memory=pin_memory,
206 | )
207 | test_loader = DataLoader(
208 | TensorDataset(Xte, Yte),
209 | batch_size=batch_size,
210 | shuffle=False,
211 | num_workers=num_workers,
212 | pin_memory=pin_memory,
213 | )
214 |
215 | return train_loader, val_loader, test_loader, x_scaler, y_scaler
216 |
217 |
218 | @torch.no_grad()
219 | def extract_XY(loader, device):
220 | Xs, Ys = [], []
221 | for x, y in loader:
222 | Xs.append(x.to(device))
223 | Ys.append(y.to(device))
224 | return torch.cat(Xs, 0).cpu(), torch.cat(Ys, 0).cpu()
225 |
226 |
227 | if __name__ == "__main__":
228 | # quick test / demo
229 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
230 |     train_loader, val_loader, test_loader, x_scaler, y_scaler = load_california_housing()
231 | X_train, Y_train = extract_XY(train_loader, device)
232 | X_test, Y_test = extract_XY(test_loader, device)
233 | print(f"Train X: {X_train.shape}, Y: {Y_train.shape}")
234 | print(f"Test X: {X_test.shape}, Y: {Y_test.shape}")
235 | # quick sanity check
236 |     train_loader, val_loader, test_loader, xscaler, yscaler = load_news_popularity(
237 |         batch_size=128, test_size=0.2, seed_val=42
238 |     )
239 | Xtr, Ytr = extract_XY(train_loader, device)
240 | Xte, Yte = extract_XY(test_loader, device)
241 | print(f"News‐pop train: X={Xtr.shape}, Y={Ytr.shape}")
242 | print(f"News‐pop test: X={Xte.shape}, Y={Yte.shape}")
243 |
--------------------------------------------------------------------------------
/tests/test_vinecop.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import networkx as nx
3 | import pytest
4 | import torch
5 |
6 | from torchvinecopulib.vinecop import VineCop
7 |
8 |
9 | def test_init_attributes_and_defaults():
10 | num_dim = 5
11 | vc = VineCop(num_dim=num_dim, is_cop_scale=True, num_step_grid=32)
12 | # basic attrs
13 | assert vc.num_dim == num_dim
14 | assert vc.is_cop_scale is True
15 | assert vc.num_step_grid == 32
16 | # empty structures
17 | assert isinstance(vc.marginals, torch.nn.ModuleList)
18 | assert len(vc.marginals) == num_dim
19 | assert isinstance(vc.bicops, torch.nn.ModuleDict)
20 | # number of pair-copulas = n(n-1)/2
21 | assert len(vc.bicops) == num_dim * (num_dim - 1) // 2
22 | assert vc.tree_bidep == [{} for _ in range(num_dim - 1)]
23 | assert vc.sample_order == tuple(range(num_dim))
24 | # no data yet
25 | assert vc.num_obs.item() == 0
26 |
27 |
28 | def test_device_and_dtype_after_cuda():
29 | vc = VineCop(3).to("cpu")
30 | assert vc.device.type == "cpu"
31 | assert vc.dtype is torch.float64
32 | if torch.cuda.is_available():
33 | vc2 = VineCop(3).cuda()
34 | assert vc2.device.type == "cuda"
35 |
36 |
37 | @pytest.mark.parametrize("mtd_vine", ["dvine", "cvine", "rvine"])
38 | @pytest.mark.parametrize(
39 | "mtd_bidep",
40 | ["kendall_tau", "mutual_info", "chatterjee_xi", "ferreira_tail_dep_coeff"],
41 | )
42 | def test_fit_sample_logpdf_cdf_forward(mtd_vine, mtd_bidep):
43 | num_dim = 5
44 | torch.manual_seed(0)
45 | U = torch.special.ndtri(torch.rand(200, num_dim, dtype=torch.float64))
46 | vc = VineCop(num_dim=num_dim, is_cop_scale=False, num_step_grid=32)
47 | # fit in copula-scale
48 | vc.fit(
49 | obs=U,
50 | mtd_kde="tll", # use TLL for bidep estimation
51 | is_dissmann=True,
52 | mtd_vine=mtd_vine,
53 | mtd_bidep=mtd_bidep,
54 | thresh_trunc=None, # no truncation
55 | )
56 | # num_obs must be updated
57 | assert vc.num_obs.item() == 200
58 |
59 | # sampling
60 | samp = vc.sample(num_sample=50, seed=1, is_sobol=False)
61 | assert samp.shape == (50, num_dim)
62 |
63 | # log_pdf
64 | lp = vc.log_pdf(U)
65 | assert lp.shape == (200, 1)
66 | # forward = neg average log-lik
67 | fwd = vc.forward(U)
68 | assert fwd.dim() == 0
69 |
70 | # cdf approximation in [0,1]
71 | cdf_vals = vc.cdf(U[:10], num_sample=1000, seed=2)
72 | assert cdf_vals.shape == (10, 1)
73 | assert (cdf_vals >= 0).all() and (cdf_vals <= 1).all()
74 |
75 |
76 | @pytest.mark.parametrize("mtd_vine", ["dvine", "cvine", "rvine"])
77 | @pytest.mark.parametrize(
78 | "mtd_bidep",
79 | ["kendall_tau", "mutual_info", "chatterjee_xi", "ferreira_tail_dep_coeff"],
80 | )
81 | def test_matrix_diagonal_matches_sample_order(mtd_vine, mtd_bidep):
82 | num_dim = 5
83 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=16)
84 | torch.manual_seed(1)
85 | U = torch.rand(100, num_dim, dtype=torch.float64)
86 | vc.fit(
87 | U,
88 | mtd_kde="tll",
89 | is_dissmann=True,
90 | mtd_vine=mtd_vine,
91 | mtd_bidep=mtd_bidep,
92 | )
93 | M = vc.matrix
94 | # must be square
95 | assert M.shape == (num_dim, num_dim)
96 |     # * row i has (num_dim - i) unique non-negative elements
97 | for i in range(num_dim):
98 | elems = set(M[i, :].tolist())
99 | elems.discard(-1) # discard -1
100 | assert len(elems) == num_dim - i, f"Row {i} has incorrect unique elements"
101 | # diag = sample_order
102 | for i in range(num_dim):
103 | assert M[i, i].item() == vc.sample_order[i]
104 |
105 |
106 | def test_fit_with_explicit_matrix_uses_exact_edges():
107 | num_dim = 5
108 | torch.manual_seed(1)
109 | U = torch.rand(200, num_dim, dtype=torch.float64)
110 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=16)
111 | vc.fit(U, mtd_kde="tll", is_dissmann=True)
112 | M = vc.matrix
113 | # M is a square matrix with num_dim rows and columns
114 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=16)
115 | # fit with our explicit matrix
116 | vc.fit(U, is_dissmann=False, matrix=M)
117 | # now verify that for each level lv, and each idx in [0..num_dim-lv-2],
118 | # the edge (v_l, v_r, *cond) comes out exactly as we encoded it in M
119 | for lv in range(num_dim - 1):
120 | tree = vc.tree_bidep[lv]
121 | # must have exactly num_dim-lv-1 edges
122 | assert len(tree) == num_dim - lv - 1
123 |
124 | for idx in range(num_dim - lv - 1):
125 | # the two “free” spots in row idx of M
126 | a = int(M[idx, idx])
127 | b = int(M[idx, num_dim - lv - 1])
128 | v_l, v_r = sorted((a, b))
129 | # any remaining entries in that row form the conditioning set
130 | cond = sorted([int(_) for _ in M[idx, num_dim - lv :].tolist()])
131 | expected_edge = (v_l, v_r, *cond)
132 |
133 | # check that this exact tuple is a key in tree_bidep[lv]
134 | assert expected_edge in tree, (
135 | f"Expected edge {expected_edge} at level {lv} but got {list(tree)}"
136 | )
137 |
138 |
139 | def test_reset_clears_all_levels_and_bicops():
140 | num_dim = 5
141 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=16)
142 | torch.manual_seed(0)
143 | U = torch.rand(50, num_dim, dtype=torch.float64)
144 | # fit with Dissmann algorithm
145 | vc.fit(
146 | U,
147 | is_dissmann=True,
148 | mtd_kde="tll",
149 | mtd_vine="rvine",
150 | mtd_bidep="kendall_tau",
151 | thresh_trunc=None,
152 | )
153 | assert vc.num_obs.item() > 0
154 | # now reset
155 | vc.reset()
156 | assert vc.num_obs.item() == 0
157 | assert vc.tree_bidep == [{} for _ in range(num_dim - 1)]
158 | # all BiCop should be independent again
159 | assert all(bc.is_indep for bc in vc.bicops.values())
160 |
161 |
162 | def test_reset_and_str():
163 | num_dim = 5
164 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=16)
165 | torch.manual_seed(0)
166 | U = torch.rand(50, num_dim, dtype=torch.float64)
167 | # fit with Dissmann algorithm
168 | vc.fit(
169 | U,
170 | is_dissmann=True,
171 | mtd_kde="tll",
172 | mtd_vine="rvine",
173 | mtd_bidep="kendall_tau",
174 | thresh_trunc=None,
175 | )
176 | # __str__ contains key fields
177 | s = str(vc)
178 | for key in [
179 | "num_dim",
180 | "num_obs",
181 | "is_cop_scale",
182 | "num_step_grid",
183 | "mtd_bidep",
184 | "negloglik",
185 | "dtype",
186 | "device",
187 | "sample_order",
188 | ]:
189 | assert key in s, f"'{key}' not found in VineCop string representation"
190 |
191 |
192 | @pytest.mark.parametrize("mtd_vine", ["dvine", "cvine", "rvine"])
193 | @pytest.mark.parametrize(
194 | "mtd_bidep",
195 | ["kendall_tau", "mutual_info", "chatterjee_xi", "ferreira_tail_dep_coeff"],
196 | )
197 | def test_ref_count_hfunc_on_fitted_vine(mtd_vine, mtd_bidep):
198 | num_dim = 5
199 | torch.manual_seed(0)
200 | U = torch.rand(100, num_dim, dtype=torch.float64)
201 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=16)
202 | vc.fit(
203 | U,
204 | mtd_kde="tll",
205 | is_dissmann=True,
206 | mtd_vine=mtd_vine,
207 | mtd_bidep=mtd_bidep,
208 | thresh_trunc=0.1,
209 | )
210 | # test static ref_count_hfunc
211 | ref_cnt, sources, num_hfunc = VineCop.ref_count_hfunc(
212 | num_dim=vc.num_dim,
213 | struct_obs=vc.struct_obs,
214 | sample_order=vc.sample_order,
215 | )
216 | assert isinstance(ref_cnt, dict)
217 | assert isinstance(sources, list)
218 | assert isinstance(num_hfunc, int)
219 | if mtd_vine == "cvine":
220 | # for cvine, we expect 0 hfuncs
221 | assert num_hfunc == 0
222 | else:
223 | assert num_hfunc >= 0
224 |
225 |
226 | def test_draw_lv_and_draw_dag(tmp_path):
227 | num_dim = 5
228 | torch.manual_seed(0)
229 | U = torch.rand(100, num_dim, dtype=torch.float64)
230 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=32)
231 | vc.fit(
232 | U,
233 | mtd_kde="tll",
234 | is_dissmann=True,
235 | mtd_vine="rvine",
236 | mtd_bidep="kendall_tau",
237 | thresh_trunc=None,
238 | )
239 |
240 | # draw level-0 with pseudo-obs nodes
241 | fig, ax, G = vc.draw_lv(lv=0, is_bcp=False)
242 | assert isinstance(G, nx.Graph)
243 | assert G.number_of_nodes() > 0
244 | plt.close(fig)
245 |
246 | # save level-1 bicop view
247 | fpath_lv = tmp_path / "level1.png"
248 | fig2, ax2, G2, outpath_lv = vc.draw_lv(lv=1, is_bcp=True, f_path=fpath_lv)
249 | assert outpath_lv == fpath_lv
250 | assert fpath_lv.exists()
251 | plt.close(fig2)
252 | # save level-1 pseudo-obs view
253 | fpath_lv = tmp_path / "level1.png"
254 | fig2, ax2, G2, outpath_lv = vc.draw_lv(lv=1, is_bcp=False, f_path=fpath_lv)
255 | assert outpath_lv == fpath_lv
256 | assert fpath_lv.exists()
257 | plt.close(fig2)
258 |
259 | # draw the DAG
260 | fig3, ax3, G3 = vc.draw_dag()
261 | assert isinstance(G3, nx.DiGraph)
262 | assert len(G3.nodes) > 0
263 | plt.close(fig3)
264 |
265 | # save DAG
266 | fpath_dag = tmp_path / "dag.png"
267 | fig4, ax4, G4, outpath_dag = vc.draw_dag(f_path=fpath_dag)
268 | assert outpath_dag == fpath_dag
269 | assert fpath_dag.exists()
270 | plt.close(fig4)
271 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | ### MIT License
2 |
3 | Copyright (c) 2024- Tuoyuan Cheng, Kan Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 | ---
24 |
25 | ## Third-Party Dependencies
26 |
27 | This project depends on the following libraries, each of which is governed by its own license. You must read and comply with those licenses when using this software.
28 |
29 | ### 1. PyTorch
30 |
31 | - **URL**: https://github.com/pytorch/pytorch/blob/main/LICENSE
32 |
33 | Full PyTorch license text:
34 |
35 | From PyTorch:
36 |
37 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
38 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
39 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
40 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
41 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
42 | Copyright (c) 2011-2013 NYU (Clement Farabet)
43 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
44 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
45 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
46 |
47 | From Caffe2:
48 |
49 | Copyright (c) 2016-present, Facebook Inc. All rights reserved.
50 |
51 | All contributions by Facebook:
52 | Copyright (c) 2016 Facebook Inc.
53 |
54 | All contributions by Google:
55 | Copyright (c) 2015 Google Inc.
56 | All rights reserved.
57 |
58 | All contributions by Yangqing Jia:
59 | Copyright (c) 2015 Yangqing Jia
60 | All rights reserved.
61 |
62 | All contributions by Kakao Brain:
63 | Copyright 2019-2020 Kakao Brain
64 |
65 | All contributions by Cruise LLC:
66 | Copyright (c) 2022 Cruise LLC.
67 | All rights reserved.
68 |
69 | All contributions by Tri Dao:
70 | Copyright (c) 2024 Tri Dao.
71 | All rights reserved.
72 |
73 | All contributions by Arm:
74 | Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
75 |
76 | All contributions from Caffe:
77 | Copyright(c) 2013, 2014, 2015, the respective contributors
78 | All rights reserved.
79 |
80 | All other contributions:
81 | Copyright(c) 2015, 2016 the respective contributors
82 | All rights reserved.
83 |
84 | Caffe2 uses a copyright model similar to Caffe: each contributor holds
85 | copyright over their contributions to Caffe2. The project versioning records
86 | all such contribution and copyright details. If a contributor wants to further
87 | mark their specific copyright on a particular contribution, they should
88 | indicate their copyright solely in the commit message of the change when it is
89 | committed.
90 |
91 | All rights reserved.
92 |
93 | Redistribution and use in source and binary forms, with or without
94 | modification, are permitted provided that the following conditions are met:
95 |
96 | 1. Redistributions of source code must retain the above copyright
97 | notice, this list of conditions and the following disclaimer.
98 |
99 | 2. Redistributions in binary form must reproduce the above copyright
100 | notice, this list of conditions and the following disclaimer in the
101 | documentation and/or other materials provided with the distribution.
102 |
103 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
104 | and IDIAP Research Institute nor the names of its contributors may be
105 | used to endorse or promote products derived from this software without
106 | specific prior written permission.
107 |
108 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
109 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
110 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
111 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
112 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
113 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
114 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
115 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
116 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
117 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
118 | POSSIBILITY OF SUCH DAMAGE.
119 |
120 | ### 2. FastKDE
121 |
122 | - **URL**: https://github.com/LBL-EESA/fastkde/blob/main/LICENSE.txt
123 |
124 | - **License**: Lawrence Berkeley National Laboratory (“LBNL”) Non-Commercial Use Only License
125 |
126 | Full FastKDE license text:
127 |
128 | LAWRENCE BERKELEY NATIONAL LABORATORY
129 | RESEARCH & DEVELOPMENT, NON-COMMERCIAL USE ONLY, LICENSE
130 |
131 | Copyright (c) 2015, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved.
132 |
133 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
134 |
135 | (1) Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
136 |
137 | (2) Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
138 |
139 | (3) Neither the name of the University of California, Lawrence Berkeley National Laboratory, U.S. Dept. of Energy nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
140 |
141 | (4) Use of the software, in source or binary form is FOR RESEARCH & DEVELOPMENT, NON-COMMERCIAL USE, PURPOSES ONLY. All commercial use rights for the software are hereby reserved. A separate commercial use license is available from Lawrence Berkeley National Laboratory.
142 | (5) In the event you create any bug fixes, patches, upgrades, updates, modifications, derivative works or enhancements to the source code or binary code of the software ("Enhancements") you hereby grant The Regents of the University of California and the U.S. Government a paid-up, non-exclusive, irrevocable, worldwide license in the Enhancements to reproduce, prepare derivative works, distribute copies to the public, perform publicly and display publicly, and to permit others to do so.
143 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
144 | **_ Copyright Notice _**
145 | FastKDE v1.0, Copyright (c) 2015, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved.
146 | If you have questions about your rights to use or distribute this software, please contact Berkeley Lab's Innovation & Partnerships Office at IPO@lbl.gov.
147 | NOTICE. This software was developed under funding from the U.S. Department of Energy. As such, the U.S. Government has been granted for itself and others acting on its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the Software to reproduce, prepare derivative works, and perform publicly and display publicly. Beginning five (5) years after the date permission to assert copyright is obtained from the U.S. Department of Energy, and subject to any subsequent five (5) year renewals, the U.S. Government is granted for itself and others acting on its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the Software to reproduce, prepare derivative works, distribute copies to the public, perform publicly and display publicly, and to permit others to do so.
148 |
149 | ### 3. pyvinecopulib
150 |
151 | - **URL**: https://github.com/vinecopulib/pyvinecopulib/blob/main/LICENSE
152 |
153 | Full pyvinecopulib license text:
154 |
155 | The MIT License (MIT)
156 |
157 | Copyright © 2019- Thomas Nagler and Thibault Vatter
158 |
159 | Permission is hereby granted, free of charge, to any person obtaining a copy of
160 | this software and associated documentation files (the “Software”), to deal in
161 | the Software without restriction, including without limitation the rights to
162 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
163 | the Software, and to permit persons to whom the Software is furnished to do so,
164 | subject to the following conditions:
165 |
166 | The above copyright notice and this permission notice shall be included in all
167 | copies or substantial portions of the Software.
168 |
169 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
170 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
171 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
172 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
173 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
174 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
175 |
176 | ---
177 |
--------------------------------------------------------------------------------
/tests/test_bicop.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | import matplotlib.pyplot as plt
3 | import pytest
4 | import torch
5 |
6 | import torchvinecopulib as tvc
7 |
8 | from . import EPS, U_tensor, bicop_pair
9 |
10 |
11 | def test_device_and_dtype():
12 | cop = tvc.BiCop(num_step_grid=16)
13 | # by default on CPU float64
14 | assert cop.device.type == "cpu"
15 | assert cop.dtype is torch.float64
16 |
17 | if torch.cuda.is_available():
18 | cop = tvc.BiCop(num_step_grid=16).cuda()
19 | assert cop.device.type == "cuda"
20 |
21 |
22 | def test_monotonicity_and_range(bicop_pair):
23 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair
24 |
25 | # pick one of the two implementations or loop both
26 | for bicop in (bc_fast, bc_tll):
27 | # * simple diagonal check
28 | grid = torch.linspace(EPS, 1.0 - EPS, 100, device=U.device, dtype=torch.float64).unsqueeze(
29 | 1
30 | )
31 | pts = torch.hstack([grid, grid])
32 | out = bicop.cdf(pts)
33 | assert out.min() >= -EPS and out.max() <= 1 + EPS
34 | assert out.diff(0).min() >= -EPS
35 |
36 | # * row slices + left/right variants
37 | for u in grid:
38 | pts = torch.hstack([grid, u.repeat(grid.size(0), 1)])
39 | for pts in (pts, pts.flip(1)):
40 | for fn in (bicop.cdf, bicop.hfunc_r, bicop.hinv_r):
41 | v = fn(pts)
42 | assert v.min() >= -EPS and v.max() <= 1 + EPS
43 | assert v.diff(0).min() >= -EPS
44 |
45 |
46 | def test_inversion(bicop_pair):
47 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair
48 |
49 | for bicop in (bc_fast, bc_tll):
50 | grid = torch.linspace(0.1, 0.9, 50, device=U.device, dtype=torch.float64).unsqueeze(1)
51 | for u in grid:
52 | pts = torch.hstack([grid, u.repeat(grid.size(0), 1)])
53 |
54 | # * right‐side inverse
55 | rec0 = bicop.hinv_r(torch.hstack([bicop.hfunc_r(pts), pts[:, [1]]]))
56 | assert torch.allclose(rec0, pts[:, [0]], atol=1e-3)
57 |
58 | rec1 = bicop.hfunc_r(torch.hstack([bicop.hinv_r(pts), pts[:, [1]]]))
59 | assert torch.allclose(rec1, pts[:, [0]], atol=1e-3)
60 |
61 | # * left‐side inverse
62 | pts_rev = pts.flip(1)
63 | rec2 = bicop.hinv_l(torch.hstack([pts_rev[:, [0]], bicop.hfunc_l(pts_rev)]))
64 | assert torch.allclose(rec2, pts_rev[:, [1]], atol=1e-3)
65 |
66 | rec3 = bicop.hfunc_l(torch.hstack([pts_rev[:, [0]], bicop.hinv_l(pts_rev)]))
67 | assert torch.allclose(rec3, pts_rev[:, [1]], atol=1e-3)
68 |
69 |
70 | def test_pdf_integrates_to_one(bicop_pair):
71 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair
72 | for cop in (bc_fast, bc_tll):
73 | # our grid is uniform on [0,1]² with spacing Δ = 1/(N−1)
74 | Δ = 1.0 / (cop.num_step_grid - 1)
75 | # approximate ∫ pdf(u,v) du dv ≈ Σ_pdf_grid * Δ²
76 | approx_mass = (cop._pdf_grid.sum() * Δ**2).item()
77 | assert pytest.approx(expected=1.0, rel=1e-2) == approx_mass
78 | # non-negativity
79 | assert (cop._pdf_grid >= -EPS).all()
80 |
81 |
82 | def test_log_pdf_matches_log_of_pdf(bicop_pair):
83 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair
84 | for cop in (bc_fast, bc_tll):
85 | pts = torch.rand(500, 2, dtype=torch.float64, device=cop.device)
86 | pdf = cop.pdf(pts)
87 | logp = cop.log_pdf(pts)
88 | # where pdf>0, log_pdf == log(pdf)
89 | mask = pdf.squeeze(1) > 0
90 | assert torch.allclose(logp[mask], pdf[mask].log(), atol=1e-6)
91 |
92 |
93 | def test_log_pdf_handles_zero():
94 | cop = tvc.BiCop(num_step_grid=4)
95 | cop.is_indep = False
96 | # ! monkey‐patch pdf to always return zero
97 | cop.pdf = lambda obs: torch.zeros(obs.shape[0], 1, dtype=torch.float64)
98 | pts = torch.rand(100, 2, dtype=torch.float64)
99 | logp = cop.log_pdf(pts)
100 | # every entry should equal the neg‐infinity replacement
101 | assert torch.all(logp == -13.815510557964274)
102 |
103 |
104 | def test_sample_marginals(bicop_pair):
105 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair
106 | for cop in (bc_fast, bc_tll):
107 | for is_sobol in (False, True):
108 | samp = cop.sample(2000, seed=0, is_sobol=is_sobol)
109 | # samples lie in [0,1]
110 | assert samp.min() >= 0.0 and samp.max() <= 1.0
111 | # * marginal histograms should be roughly uniform
112 | counts_u = torch.histc(samp[:, 0], bins=10, min=0, max=1)
113 | counts_v = torch.histc(samp[:, 1], bins=10, min=0, max=1)
114 | # each bin ~200 ± 5 σ (σ≈√(N·p·(1−p))≈√(2000·0.1·0.9)≈13.4)
115 | assert counts_u.std() < 20
116 | assert counts_v.std() < 20
117 |
118 |
119 | def test_internal_buffers_and_flags(bicop_pair):
120 | _, _, _, U, bc_fast, bc_tll = bicop_pair
121 | for cop, mtd_kde in [(bc_fast, "fastKDE"), (bc_tll, "tll")]:
122 | print(cop)
123 | assert not cop.is_indep
124 | assert cop.mtd_kde == mtd_kde
125 | assert cop.num_obs == U.shape[0]
126 | # all the pre‐computed grids are the right shape
127 | m = cop.num_step_grid
128 | for name in ("_pdf_grid", "_cdf_grid", "_hfunc_l_grid", "_hfunc_r_grid"):
129 | grid = getattr(cop, name)
130 | assert grid.shape == (m, m)
131 |
132 |
133 | def test_tau_estimation(bicop_pair):
134 | _, _, _, U, bc_fast, bc_mtd_kde = bicop_pair
135 | # re‐fit with tau estimation
136 | bc = tvc.BiCop(num_step_grid=64)
137 | bc.fit(U, mtd_kde="tll", is_tau_est=True)
138 | # kendalltau must be nonzero for dependent data
139 | assert bc.tau[0].abs().item() > 0
140 | assert bc.tau[1].abs().item() >= 0
141 |
142 |
143 | def test_sample_shape_and_dtype_on_tll(bicop_pair):
144 | _, _, _, U, bc_fast, bc_tll = bicop_pair
145 | for cop in (bc_fast, bc_tll):
146 | s = cop.sample(123, seed=7, is_sobol=True)
147 | assert s.shape == (123, 2)
148 | assert s.dtype is cop.dtype
149 | assert s.device == cop.device
150 |
151 |
152 | def test_imshow_and_plot_api(bicop_pair):
153 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair
154 | cop = bc_fast
155 | # imshow
156 | fig, ax = cop.imshow(is_log_pdf=True)
157 | assert isinstance(fig, matplotlib.figure.Figure)
158 | assert isinstance(ax, matplotlib.axes.Axes)
159 | plt.close(fig)
160 |
161 | # contour
162 | fig2, ax2 = cop.plot(plot_type="contour", margin_type="unif")
163 | assert isinstance(fig2, matplotlib.figure.Figure)
164 | assert isinstance(ax2, matplotlib.axes.Axes)
165 | plt.close(fig2)
166 | fig2, ax2 = cop.plot(plot_type="contour", margin_type="norm")
167 | assert isinstance(fig2, matplotlib.figure.Figure)
168 | assert isinstance(ax2, matplotlib.axes.Axes)
169 | plt.close(fig2)
170 |
171 | # surface
172 | fig3, ax3 = cop.plot(plot_type="surface", margin_type="unif")
173 | assert isinstance(fig3, matplotlib.figure.Figure)
174 | assert isinstance(ax3, matplotlib.axes.Axes)
175 | plt.close(fig3)
176 | fig3, ax3 = cop.plot(plot_type="surface", margin_type="norm")
177 | assert isinstance(fig3, matplotlib.figure.Figure)
178 | assert isinstance(ax3, matplotlib.axes.Axes)
179 | plt.close(fig3)
180 |
181 | # invalid args
182 | with pytest.raises(ValueError):
183 | cop.plot(plot_type="foo")
184 | with pytest.raises(ValueError):
185 | cop.plot(margin_type="bar")
186 |
187 |
188 | def test_plot_accepts_unused_kwargs(bicop_pair):
189 | _, _, _, U, bc_fast, _ = bicop_pair
190 | # just ensure it doesn’t crash
191 | bc_fast.plot(plot_type="contour", margin_type="norm", xylim=(0, 1), grid_size=50)
192 | bc_fast.plot(plot_type="surface", margin_type="unif", xylim=(0, 1), grid_size=20)
193 |
194 |
195 | def test_reset_and_str(bicop_pair):
196 | # ! notice scope="module" so we put this test at the end
197 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair
198 | for cop in (bc_fast, bc_tll):
199 | cop.reset()
200 | # should go back to independent
201 | assert cop.is_indep
202 | assert cop.num_obs == 0
203 | # __str__ contains key fields
204 | s = str(cop)
205 | assert "is_indep" in s and "num_obs" in s and "mtd_kde" in s
206 |
207 |
208 | @pytest.mark.parametrize("method", ["constant", "linear", "quadratic"])
209 | def test_tll_methods_do_not_crash(U_tensor, method):
210 | cop = tvc.BiCop(num_step_grid=32)
211 | # should _not_ raise for any of the valid nonparametric_method names
212 | cop.fit(U_tensor, mtd_kde="tll", mtd_tll=method)
213 |
214 |
215 | def test_fit_invalid_method_raises(U_tensor):
216 | cop = tvc.BiCop(num_step_grid=32)
217 | with pytest.raises(RuntimeError):
218 | # pick something bogus
219 | cop.fit(U_tensor, mtd_kde="tll", mtd_tll="no_such_method")
220 |
221 |
222 | def test_interp_on_trivial_grid():
223 | # make a BiCop with a 2×2 grid
224 | bc = tvc.BiCop(num_step_grid=2)
225 | # override the geometry so that step_grid == 1.0 and target == 1
226 | bc.step_grid = 1.0
227 | bc._target = 1.0
228 | bc._EPS = 0.0 # so we don't get any clamping at the edges
229 |
230 | # grid:
231 | # g00 = 0, g01 = 1
232 | # g10 = 2, g11 = 3
233 | grid = torch.tensor([[0.0, 1.0], [2.0, 3.0]], dtype=torch.float64)
234 |
235 | # corners should map exactly:
236 | pts = torch.tensor(
237 | [
238 | [0.0, 0.0], # g00
239 | [0.0, 1.0], # g01
240 | [1.0, 0.0], # g10
241 | [1.0, 1.0], # g11
242 | ],
243 | dtype=torch.float64,
244 | )
245 | out = bc._interp(grid, pts)
246 | assert torch.allclose(out, torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float64))
247 |
248 | # center point should average correctly:
249 | center = torch.tensor([[0.5, 0.5]], dtype=torch.float64)
250 | val = bc._interp(grid, center)
251 | # manual bilinear: 0 + (2−0)*.5 + (1−0)*.5 + (3−1−2+0)*.5*.5 = 1.5
252 | assert torch.allclose(val, torch.tensor([[1.5]], dtype=torch.float64))
253 |
254 | # if you ask for out-of-bounds it should clamp to [0,1] and then pick the corner:
255 | # e.g. (−1, −1)→(0,0), (2,3)→(1,1)
256 | oob = torch.tensor([[-1.0, -1.0], [2.0, 3.0]], dtype=torch.float64)
257 | val_oob = bc._interp(grid, oob)
258 | assert torch.allclose(val_oob, torch.tensor([0.0, 3.0], dtype=torch.float64))
259 |
260 |
261 | def test_imshow_with_existing_axes():
262 | cop = tvc.BiCop(num_step_grid=32)
263 | us = torch.rand(100, 2)
264 | cop.fit(us, mtd_kde="fastKDE")
265 | fig, outer_ax = plt.subplots()
266 | fig2, ax2 = cop.imshow(is_log_pdf=False, ax=outer_ax, cmap="viridis")
267 | # should have returned the same axes object
268 | assert ax2 is outer_ax
269 | plt.close(fig)
270 | plt.close(fig2)
271 |
272 |
273 | def test_independent_copula_properties():
274 | for cop in (tvc.BiCop(num_step_grid=64), tvc.BiCop(num_step_grid=16)):
275 | # before fit, should be independent
276 | # CDF(u,v) = u·v, PDF(u,v)=1, hfunc_r(u,v)=u, hfunc_l(u,v)=v
277 | us = torch.rand(1000, 2, dtype=torch.float64)
278 | cdf = cop.cdf(us)
279 | assert torch.allclose(cdf, (us[:, 0] * us[:, 1]).unsqueeze(1))
280 | pdf = cop.pdf(us)
281 | assert torch.allclose(pdf, torch.ones_like(pdf))
282 | logpdf = cop.log_pdf(us)
283 | assert torch.allclose(logpdf, torch.zeros_like(logpdf))
284 | hr = cop.hfunc_r(us)
285 | hl = cop.hfunc_l(us)
286 | assert torch.allclose(hr, us[:, [0]])
287 | assert torch.allclose(hl, us[:, [1]])
288 | hir = cop.hinv_r(us)
289 | hil = cop.hinv_l(us)
290 | assert torch.allclose(hir, us[:, [0]])
291 | assert torch.allclose(hil, us[:, [1]])
292 |
--------------------------------------------------------------------------------
/examples/pred_intvl/models.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | import torch.nn as nn
5 | from blitz.modules import BayesianLinear
6 | from blitz.utils import variational_estimator
7 | from torch.utils.data import DataLoader
8 |
9 |
10 | class EncoderRegressor(nn.Module):
11 | def __init__(
12 | self,
13 | input_dim: int = 8,
14 | latent_dim: int = 10,
15 | hidden_dim1: int = 128,
16 | hidden_dim2: int = 64,
17 | hidden_dim3: int = 32,
18 | hidden_dim4: int = 16,
19 | hidden_dim5: int = 8,
20 | head_dim: int = 16,
21 | p_drop: float = 0.3,
22 | ):
23 | super().__init__()
24 | # * MLP encoder mapping ℝ⁸ → ℝ^{latent_dim}
25 | self.encoder = nn.Sequential(
26 | nn.Linear(input_dim, hidden_dim1),
27 | nn.BatchNorm1d(hidden_dim1),
28 | nn.LeakyReLU(0.1, inplace=True),
29 | nn.Dropout(p_drop),
30 | #
31 | nn.Linear(hidden_dim1, hidden_dim2),
32 | nn.BatchNorm1d(hidden_dim2),
33 | nn.LeakyReLU(0.1, inplace=True),
34 | nn.Dropout(p_drop),
35 | #
36 | nn.Linear(hidden_dim2, latent_dim),
37 | nn.BatchNorm1d(latent_dim),
38 | nn.LeakyReLU(0.1, inplace=True),
39 | nn.Dropout(p_drop),
40 | #
41 | nn.Linear(latent_dim, hidden_dim3),
42 | nn.BatchNorm1d(hidden_dim3),
43 | nn.LeakyReLU(0.1, inplace=True),
44 | nn.Dropout(p_drop),
45 | #
46 | nn.Linear(hidden_dim3, hidden_dim4),
47 | nn.BatchNorm1d(hidden_dim4),
48 | nn.LeakyReLU(0.1, inplace=True),
49 | nn.Dropout(p_drop),
50 | #
51 | nn.Linear(hidden_dim4, hidden_dim5),
52 | nn.BatchNorm1d(hidden_dim5),
53 | nn.LeakyReLU(0.1, inplace=True),
54 | nn.Dropout(p_drop),
55 | #
56 | nn.Linear(hidden_dim5, latent_dim),
57 | nn.BatchNorm1d(latent_dim),
58 |             nn.Tanh(),  # keeps latent coords in [-1, 1]
59 | )
60 | # --- head: ℝ^{latent_dim} → ℝ (house value) ---
61 | self.head = nn.Sequential(
62 | nn.Linear(latent_dim, head_dim),
63 | nn.LeakyReLU(0.1, inplace=True),
64 | nn.Dropout(p_drop / 2),
65 | nn.Linear(head_dim, 1),
66 | )
67 |
68 | def forward(self, x: torch.Tensor):
69 | # x: [batch, 8]
70 | z = self.encoder(x) # [batch, latent_dim]
71 | y_hat = self.head(z) # [batch, num_outputs]
72 | return y_hat, z
73 |
74 |
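# A minimal shape-check sketch (illustration only, not part of the original module): the
# encoder maps a batch of input_dim raw features to latent_dim coordinates in [-1, 1]
# (because of the final Tanh), and the head maps those to a single regression output.
# The helper name below is hypothetical.
def _demo_encoder_shapes() -> None:
    model = EncoderRegressor(input_dim=8, latent_dim=10).eval()  # eval() so BatchNorm/Dropout are inert
    x = torch.randn(4, 8)
    y_hat, z = model(x)
    assert y_hat.shape == (4, 1) and z.shape == (4, 10)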
75 | def train_epoch(model, loader, optim, criterion, device):
76 | model.train()
77 | total_loss = 0.0
78 | for x, y in loader:
79 | x, y = x.to(device), y.to(device)
80 |         optim.zero_grad()  # * clear out grad from last step
81 | y_hat, _ = model(x)
82 | # * if using MSELoss, we want y to be float and same shape
83 | loss = criterion(
84 | y_hat,
85 | y.float() if isinstance(criterion, nn.MSELoss) else y,
86 | )
87 | loss.backward() # * compute new grad
88 | optim.step() # * update weights with fresh grad
89 | total_loss += loss.item() * x.size(0)
90 | return total_loss / len(loader.dataset)
91 |
92 |
93 | @torch.no_grad()
94 | def eval_epoch(model, loader, criterion, device):
95 | model.eval()
96 | total_loss = 0.0
97 | for x, y in loader:
98 | x, y = x.to(device), y.to(device)
99 | y_hat, _ = model(x)
100 | # * if using MSELoss, we want y to be float and same shape
101 | total_loss += criterion(
102 | y_hat,
103 | y.float() if isinstance(criterion, nn.MSELoss) else y,
104 | ).item() * x.size(0)
105 | return total_loss / len(loader.dataset)
106 |
107 |
108 | def train_base_model(
109 | train_loader,
110 | val_loader,
111 | input_dim,
112 | lr=1e-3,
113 | weight_decay=0.91,
114 | p_drop=0.5,
115 | latent_dim=10,
116 | num_epoch=100,
117 | patience=5,
118 | device=None,
119 | ):
120 | model = EncoderRegressor(input_dim=input_dim, latent_dim=latent_dim, p_drop=p_drop).to(device)
121 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
122 | criterion = nn.MSELoss()
123 | best_state = copy.deepcopy(model.state_dict())
124 | best_loss = float("inf")
125 | wait = 0
126 | for epoch in range(1, num_epoch + 1):
127 | train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
128 | val_loss = eval_epoch(model, val_loader, criterion, device)
129 | if val_loss < best_loss:
130 | best_loss = val_loss
131 | best_state = copy.deepcopy(model.state_dict())
132 | wait = 0
133 | else:
134 | wait += 1
135 | if wait > patience:
136 | # print(f"Early stopping at epoch {epoch}")
137 | break
138 | # print(f"Epoch {epoch:2d} — train loss: {loss:.4f}")
139 | model.load_state_dict(best_state)
140 | return model.to(device)
141 |
142 |
143 | @torch.no_grad()
144 | def extract_features(model, loader, device):
145 | model.eval()
146 | all_z, all_y = [], []
147 | for x, y in loader:
148 | x = x.to(device)
149 | _, z = model(x)
150 | all_z.append(z.cpu())
151 |         all_y.append(y.cpu().float())  # cast targets to float; expected shape [B, 1]
152 | Z = torch.cat(all_z, dim=0) # [N, latent_dim]
153 | Y = torch.cat(all_y, dim=0) # [N, 1]
154 | return Z, Y
155 |
156 |
157 | def train_ensemble(
158 | M,
159 | train_loader: DataLoader,
160 | val_loader: DataLoader,
161 | input_dim,
162 | latent_dim,
163 | device,
164 | num_epochs,
165 | lr=1e-3,
166 | weight_decay=0.91,
167 | patience=5,
168 | ):
169 | ensemble = nn.ModuleList()
170 | for m in range(M):
171 | torch.manual_seed(m)
172 | model_m = train_base_model(
173 | train_loader=train_loader,
174 | val_loader=val_loader,
175 | input_dim=input_dim,
176 | latent_dim=latent_dim,
177 | lr=lr,
178 | weight_decay=weight_decay,
179 | p_drop=0.0,
180 | num_epoch=num_epochs,
181 | patience=patience,
182 | device=device,
183 | )
184 | ensemble.append(model_m.cpu())
185 | return ensemble.to(device)
186 |
187 |
188 | @torch.no_grad()
189 | def ensemble_pred_intvl(ensemble, x, device, alpha=0.05):
190 | preds = []
191 | for model in ensemble:
192 | model = model.to(device)
193 | y_hat, _ = model(x.to(device)) # [batch,1]
194 | preds.append(y_hat.cpu())
195 | P = torch.stack(preds, dim=0) # [M, batch, 1]
196 | mean = P.mean(dim=0) # [batch,1]
197 | lo = P.quantile(alpha / 2, dim=0) # [batch,1]
198 | hi = P.quantile(1 - alpha / 2, dim=0) # [batch,1]
199 | return (
200 | mean.cpu().flatten().numpy(),
201 | lo.cpu().flatten().numpy(),
202 | hi.cpu().flatten().numpy(),
203 | )
204 |
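# A minimal sketch (illustration only, not part of the original module) of how
# ensemble_pred_intvl could be used to estimate empirical coverage on a single batch.
# The helper name and the (x, y) loader are assumptions for illustration.
@torch.no_grad()
def _demo_ensemble_coverage(ensemble, loader, device, alpha=0.05):
    x, y = next(iter(loader))
    _, lo, hi = ensemble_pred_intvl(ensemble, x, device=device, alpha=alpha)
    y_np = y.cpu().float().flatten().numpy()
    # fraction of targets that fall inside the nominal (1 - alpha) interval
    return float(((y_np >= lo) & (y_np <= hi)).mean())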
205 |
206 | @torch.no_grad()
207 | def mc_dropout_pred_intvl(model, x, T=100, device=None, alpha=0.05):
208 | """
209 |     x: Tensor [batch, input_dim] of raw features
210 |     returns mean, lower, upper: 1-D NumPy arrays of length batch
211 | """
212 | model.eval()
213 | x = x.to(device)
214 |
215 | # * re-enable dropout layers
216 | for m in model.modules():
217 | if isinstance(m, nn.Dropout):
218 | m.train()
219 |
220 | # * collect T predictions
221 | preds = []
222 | for _ in range(T):
223 | y_hat, _ = model(x) # [batch,1]
224 | preds.append(y_hat.cpu())
225 | P = torch.stack(preds, dim=0) # [T, batch, 1]
226 |
227 | mu = P.mean(dim=0) # [batch,1]
228 | lower = P.quantile(alpha / 2, dim=0) # [batch,1]
229 | upper = P.quantile(1 - alpha / 2, dim=0) # [batch,1]
230 | return (
231 | mu.cpu().flatten().numpy(),
232 | lower.cpu().flatten().numpy(),
233 | upper.cpu().flatten().numpy(),
234 | )
235 |
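# A minimal sketch (illustration only, not part of the original module): MC-dropout intervals
# only capture uncertainty if the model contains nn.Dropout layers with p > 0, so the base
# model should be trained with a nonzero p_drop. The helper name below is hypothetical.
@torch.no_grad()
def _demo_mc_dropout_width(model, loader, device, T=50, alpha=0.05):
    x, _ = next(iter(loader))
    _, lo, hi = mc_dropout_pred_intvl(model, x, T=T, device=device, alpha=alpha)
    # average interval width; wider intervals indicate higher predictive uncertainty
    return float((hi - lo).mean())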
236 |
237 | @variational_estimator
238 | class BayesianEncoderRegressor(nn.Module):
239 | def __init__(
240 | self,
241 | input_dim: int,
242 | latent_dim=10,
243 | hidden_dim1=128,
244 | hidden_dim2=64,
245 | hidden_dim3=32,
246 | hidden_dim4=16,
247 | hidden_dim5=8,
248 | head_dim=16,
249 | num_outputs=1,
250 | ):
251 | super().__init__()
252 |         # -- Encoder: input vector to latent z
253 | self._encoder = nn.Sequential(
254 | BayesianLinear(input_dim, hidden_dim1),
255 | nn.LeakyReLU(inplace=True),
256 | BayesianLinear(hidden_dim1, hidden_dim2),
257 | nn.LeakyReLU(inplace=True),
258 | BayesianLinear(hidden_dim2, hidden_dim3),
259 | nn.LeakyReLU(inplace=True),
260 | BayesianLinear(hidden_dim3, hidden_dim4),
261 | nn.LeakyReLU(inplace=True),
262 | BayesianLinear(hidden_dim4, hidden_dim5),
263 | nn.LeakyReLU(inplace=True),
264 | BayesianLinear(hidden_dim5, latent_dim),
265 |             nn.Tanh(),  # keeps latent coords in [-1, 1]
266 | )
267 |         # -- Head: latent z to prediction y
268 | self._head = nn.Sequential(
269 | BayesianLinear(latent_dim, head_dim),
270 | nn.LeakyReLU(0.1, inplace=True),
271 | BayesianLinear(head_dim, num_outputs),
272 | )
273 |
274 | def forward(self, x):
275 | # x: [batch, input_dim]
276 | z = self._encoder(x)
277 | y = self._head(z)
278 | return y
279 |
280 | def encode(self, x):
281 | return self._encoder(x)
282 |
283 |
284 | def train_bnn(
285 | train_loader: DataLoader,
286 | val_loader: DataLoader,
287 | input_dim,
288 | latent_dim,
289 | num_epochs,
290 | lr=1e-3,
291 | weight_decay=0.91,
292 | device=None,
293 | patience=5,
294 | ):
295 | """Trains one BayesianEncoderRegressor with ELBO-loss and early stopping on
296 | validation MSE.
297 |
298 | Returns the model with best val-MSE.
299 | """
300 | model = BayesianEncoderRegressor(
301 | input_dim=input_dim,
302 | latent_dim=latent_dim,
303 | ).to(device)
304 | optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
305 | criterion = nn.MSELoss()
306 | best_state = copy.deepcopy(model.state_dict())
307 | best_val_loss = float("inf")
308 | wait = 0
309 | for epoch in range(num_epochs):
310 | model.train()
311 | for x, y in train_loader:
312 | x, y = x.to(device), y.to(device)
313 | optim.zero_grad()
314 | # * sample_elbo will call forward(x) → y
315 | loss = model.sample_elbo(
316 | inputs=x,
317 | labels=y,
318 | criterion=criterion,
319 | sample_nbr=3,
320 | complexity_cost_weight=1 / len(train_loader.dataset),
321 | )
322 | loss.backward()
323 | optim.step()
324 | model.eval()
325 | val_losses = []
326 | with torch.no_grad():
327 | for x, y in val_loader:
328 | x, y = x.to(device), y.to(device)
329 | preds = []
330 | for _ in range(10):
331 | preds.append(model(x))
332 | y_hat = torch.stack(preds, 0).mean(dim=0).to(device)
333 | val_losses.append(criterion(y_hat, y.float()).item() * x.size(0))
334 | val_loss = sum(val_losses) / len(val_loader.dataset)
335 | if val_loss < best_val_loss - 1e-6:
336 | best_val_loss = val_loss
337 | best_state = copy.deepcopy(model.state_dict())
338 | wait = 0
339 | else:
340 | wait += 1
341 | if wait >= patience:
342 | # print(f"Early stopping at epoch {epoch}")
343 | break
344 | model.load_state_dict(best_state)
345 | return model.to(device)
346 |
347 |
348 | @torch.no_grad()
349 | def bnn_pred_intvl(model, x, T=200, alpha=0.05, device=None):
350 | model.train() # keep Bayesian layers sampling
351 | preds = []
352 | for _ in range(T):
353 | out = model(x.to(device))
354 |         # out may be y, or (y, z), or (y, z, kl); we only need the prediction
355 | y_hat = out if isinstance(out, torch.Tensor) else out[0]
356 | preds.append(y_hat.cpu())
357 | P = torch.stack(preds, 0) # [T, batch, 1]
358 | mu = P.mean(dim=0) # [batch,1]
359 | lower = P.quantile(alpha / 2, dim=0)
360 | upper = P.quantile(1 - alpha / 2, dim=0)
361 | return (
362 | mu.cpu().flatten().numpy(),
363 | lower.cpu().flatten().numpy(),
364 | upper.cpu().flatten().numpy(),
365 | )
366 |
--------------------------------------------------------------------------------
/examples/vcae/vcae/model.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import copy
3 | from dataclasses import asdict
4 | from typing import Optional
5 |
6 | import pytorch_lightning as pl
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 | from torch.utils.data import DataLoader, Dataset, random_split
11 | from torchvision import transforms
12 | from torchvision.datasets import MNIST, SVHN
13 |
14 | import torchvinecopulib as tvc
15 |
16 | from .config import DEVICE, Config
17 |
18 |
19 | class LitAutoencoder(pl.LightningModule, abc.ABC):
20 | def __init__(self, config: Config) -> None:
21 | """Initialize the autoencoder with the given configuration."""
22 | super().__init__()
23 | self.save_hyperparameters(asdict(config))
24 | self.flat_dim = int(torch.prod(torch.tensor(self.hparams["dims"])))
25 |
26 | # Placeholders for data attributes
27 | self.data_test: Optional[Dataset] = None
28 | self.data_train: Optional[Dataset] = None
29 | self.data_val: Optional[Dataset] = None
30 |
31 | # Placeholder for the vine copula
32 | self.vine: Optional[tvc.VineCop] = None
33 |
34 | # Call subclass-defined builders
35 | self.encoder = self.build_encoder()
36 | self.decoder = self.build_decoder()
37 | self.transform = self.build_transform()
38 |
39 | @property
40 | @abc.abstractmethod
41 | def dataset_cls(self) -> type:
42 | """Subclasses must return the dataset class (e.g., MNIST, SVHN)."""
43 | raise NotImplementedError
44 |
45 | @property
46 | @abc.abstractmethod
47 | def dataset_kwargs(self) -> dict:
48 | """Subclasses must return a dictionary of keyword arguments for the dataset."""
49 | raise NotImplementedError
50 |
51 | @abc.abstractmethod
52 | def build_encoder(self) -> nn.Module:
53 | """Subclasses must return an nn.Module mapping x -> z"""
54 | raise NotImplementedError
55 |
56 | @abc.abstractmethod
57 | def build_decoder(self) -> nn.Module:
58 | """Subclasses must return an nn.Module mapping z -> x̂"""
59 | raise NotImplementedError
60 |
61 | @abc.abstractmethod
62 | def build_transform(self) -> transforms.Compose:
63 | """Subclasses must return a torchvision transforms.Compose for data preprocessing."""
64 | raise NotImplementedError
65 |
66 | def copy_with_config(self, new_config: Config) -> "LitAutoencoder":
67 | """Create a copy of the model with a new configuration (for refit)."""
68 | new_model = self.__class__(new_config)
69 | new_model.encoder.load_state_dict(self.encoder.state_dict())
70 | new_model.decoder.load_state_dict(self.decoder.state_dict())
71 | if self.vine is not None:
72 | new_model.vine = copy.deepcopy(self.vine)
73 | return new_model
74 |
75 | def set_vine(self, vine: tvc.VineCop) -> None:
76 | if not isinstance(vine, tvc.VineCop):
77 |             raise ValueError("Vine must be an instance of tvc.VineCop.")
78 | latent_size: int = self.hparams["latent_size"]
79 | if not vine.num_dim == latent_size:
80 | raise ValueError(
81 | f"Vine dimension {vine.num_dim} does not match latent size {latent_size}."
82 | )
83 | self.vine = vine
84 | self.add_module("vine", vine)
85 |
86 | def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
87 | latent_size: int = self.hparams["latent_size"]
88 | dims: tuple[int, ...] = self.hparams["dims"]
89 | z = self.encoder(x)
90 | x_hat = self.decoder(z)
91 | return x_hat.view(-1, *dims), z.view(-1, latent_size)
92 |
93 | def compute_loss(self, x: torch.Tensor) -> torch.Tensor:
94 | x_hat, z = self(x)
95 | loss = F.mse_loss(x_hat, x)
96 | if self.hparams["vine_lambda"] > 0:
97 | if self.vine is None:
98 | raise ValueError("Vine must be set before computing the loss.")
99 | vine_loss = -self.vine.log_pdf(z).mean()
100 | vine_lambda: float = self.hparams["vine_lambda"]
101 | loss += vine_lambda * vine_loss
102 | # use_mmd: bool = self.hparams["use_mmd"]
103 | # if use_mmd:
104 | # mmd_sigmas: list = self.hparams["mmd_sigmas"]
105 | # mmd_lambda: float = self.hparams["mmd_lambda"]
106 | # z_vine = self.vine.sample(x.shape[0])
107 | # z_vine = torch.tensor(z_vine, dtype=z.dtype, device=x.device)
108 | # x_vine = self.decoder(z_vine)
109 | # mmd_loss = mmd(x, x_vine, sigmas=mmd_sigmas)
110 | # loss += mmd_lambda * mmd_loss
111 | return loss
112 |
113 | def training_step(
114 | self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
115 | ) -> torch.Tensor:
116 | """Training step to compute loss on training data."""
117 | x, _ = batch
118 |         x = x.to(DEVICE)
119 | loss = self.compute_loss(x)
120 | self.log("train_loss", loss, prog_bar=True)
121 | return loss
122 |
123 | def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
124 | """Validation step to compute loss on validation data."""
125 | x, _ = batch
126 |         x = x.to(DEVICE)
127 | loss = self.compute_loss(x)
128 | self.log("val_loss", loss, prog_bar=True)
129 |
130 | def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
131 | """Test step to compute loss on test data."""
132 | x, _ = batch
133 |         x = x.to(DEVICE)
134 | loss = self.compute_loss(x)
135 | self.log("test_loss", loss, prog_bar=True)
136 |
137 | def configure_optimizers(self) -> torch.optim.Optimizer:
138 | """Returns the optimizer for training."""
139 | learning_rate: float = self.hparams["learning_rate"]
140 | return torch.optim.Adam(self.parameters(), lr=learning_rate)
141 |
142 | def prepare_data(self) -> None:
143 | """Download the dataset if not already present."""
144 | data_dir = self.hparams["data_dir"]
145 | self.dataset_cls(data_dir, download=True, **self.dataset_kwargs["train"])
146 | self.dataset_cls(data_dir, download=True, **self.dataset_kwargs["test"])
147 |
148 | def setup(self, stage=None) -> None:
149 | """Setup the datasets for training, validation, and testing."""
150 | data_dir = self.hparams["data_dir"]
151 |
152 | if stage in ("fit", None):
153 | data_full = self.dataset_cls(
154 | data_dir, transform=self.transform, **self.dataset_kwargs["train"]
155 | )
156 | n_total = len(data_full)
157 | n_val = int(self.hparams["val_train_split"] * n_total)
158 | n_train = n_total - n_val
159 |
160 | generator = torch.Generator().manual_seed(self.hparams["seed"])
161 | self.data_train, self.data_val = random_split(
162 | data_full, [n_train, n_val], generator=generator
163 | )
164 |
165 | if stage in ("test", None):
166 | self.data_test = self.dataset_cls(
167 | data_dir, transform=self.transform, **self.dataset_kwargs["test"]
168 | )
169 |
170 | def train_dataloader(self) -> DataLoader:
171 | """Returns the training dataloader."""
172 | if self.data_train is None:
173 | self.setup(stage="fit")
174 | assert self.data_train is not None
175 | batch_size: int = self.hparams["batch_size"]
176 | num_workers: int = self.hparams["num_workers"]
177 | return DataLoader(
178 | self.data_train,
179 | batch_size=batch_size,
180 | pin_memory=True,
181 | persistent_workers=True,
182 | num_workers=num_workers,
183 | )
184 |
185 | def val_dataloader(self) -> DataLoader:
186 | """Returns the validation dataloader."""
187 | if self.data_val is None:
188 | self.setup(stage="fit")
189 | assert self.data_val is not None
190 | batch_size: int = self.hparams["batch_size"]
191 | num_workers: int = self.hparams["num_workers"]
192 | return DataLoader(
193 | self.data_val,
194 | batch_size=batch_size,
195 | pin_memory=True,
196 | persistent_workers=True,
197 | num_workers=num_workers,
198 | )
199 |
200 | def test_dataloader(self) -> DataLoader:
201 | """Returns the test dataloader."""
202 | if self.data_test is None:
203 | self.setup(stage="test")
204 | assert self.data_test is not None
205 | batch_size: int = self.hparams["batch_size"]
206 | num_workers: int = self.hparams["num_workers"]
207 | return DataLoader(
208 | self.data_test,
209 | batch_size=batch_size,
210 | pin_memory=True,
211 | persistent_workers=True,
212 | num_workers=num_workers,
213 | )
214 |
215 | def get_data(
216 | self, stage: str = "fit"
217 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
218 | """Extracts representations, labels, data, decoded outputs, and samples (e.g., to compute metrics)."""
219 | if stage == "fit" or stage is None:
220 | data_loader = self.train_dataloader()
221 | elif stage == "test":
222 | data_loader = self.test_dataloader()
223 | representations = []
224 | decoded = []
225 | labels = []
226 | data = []
227 | samples = []
228 | encoder_device = next(self.encoder.parameters()).device
229 | decoder_device = next(self.decoder.parameters()).device
230 | for batch in data_loader:
231 | x, y = batch
232 | x = x.to(encoder_device)
233 | with torch.no_grad():
234 | z = self.encoder(x).to(decoder_device)
235 | x_hat = self.decoder(z)
236 | if self.vine is not None:
237 | sample = self.vine.sample(x.shape[0])
238 | sample = self.decoder(
239 | torch.tensor(sample, dtype=z.dtype, device=decoder_device)
240 | )
241 | decoded.append(x_hat)
242 | representations.append(z)
243 | labels.append(y)
244 | data.append(x)
245 | if self.vine is not None:
246 | samples.append(sample)
247 |
248 | # Concatenate into a single tensor
249 | representations_tensor = torch.cat(representations, dim=0)
250 | labels_tensor = torch.cat(labels, dim=0)
251 |         data_tensor = torch.cat(data, dim=0).flatten(start_dim=1)
252 | decoded_tensor = torch.cat(decoded, dim=0).flatten(start_dim=1)
253 | samples_tensor = (
254 | torch.cat(samples, dim=0).flatten(start_dim=1) if self.vine is not None else None
255 | )
256 |
257 | return representations_tensor, labels_tensor, data_tensor, decoded_tensor, samples_tensor
258 |
259 | def learn_vine(self, n_samples: int = 5000) -> None:
260 | """Learn the vine copula from a subset of representations."""
261 | if self.data_train is None:
262 | self.setup(stage="fit")
263 | representations, _, _, _, _ = self.get_data(stage="fit")
264 |
265 | representations_subset = representations[
266 | torch.randperm(representations.shape[0])[:n_samples]
267 | ]
268 | vine_tvc = tvc.VineCop(
269 | num_dim=representations_subset.shape[1],
270 | is_cop_scale=False,
271 | num_step_grid=30,
272 | ).to(DEVICE)
273 | vine_tvc.fit(
274 | obs=representations_subset,
275 | mtd_kde="tll",
276 | )
277 | self.set_vine(vine_tvc)
278 |
279 |
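# A minimal sketch (illustration only, not part of the original module) of using a fitted
# autoencoder plus its learned vine copula generatively: draw latent codes from the vine and
# decode them back to data space. Assumes learn_vine() has been called; the helper name is
# hypothetical.
@torch.no_grad()
def _sample_from_vine_sketch(model: LitAutoencoder, n: int = 16) -> torch.Tensor:
    if model.vine is None:
        raise ValueError("Call model.learn_vine() before sampling.")
    decoder_param = next(model.decoder.parameters())
    z = torch.as_tensor(model.vine.sample(n), dtype=decoder_param.dtype, device=decoder_param.device)
    return model.decoder(z)  # decoded samples in data space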
280 | class LitMNISTAutoencoder(LitAutoencoder):
281 | def __init__(self, config: Config) -> None:
282 | super().__init__(config)
283 |
284 | def build_transform(self) -> transforms.Compose:
285 | """Returns a torchvision transforms.Compose for MNIST preprocessing."""
286 | return transforms.Compose([transforms.ToTensor()])
287 |
288 | @property
289 | def dataset_cls(self) -> type:
290 | """Returns the dataset class for MNIST."""
291 | return MNIST
292 |
293 | @property
294 | def dataset_kwargs(self) -> dict:
295 | """Returns a dictionary of keyword arguments for the MNIST dataset."""
296 | return {
297 | "train": {"train": True},
298 | "test": {"train": False},
299 | }
300 |
301 | def build_encoder(self) -> nn.Module:
302 | """Returns a fully connected encoder for MNIST."""
303 | # Encoder: flatten → hidden → latent
304 | latent_size: int = self.hparams["latent_size"]
305 | hidden_size: int = self.hparams["hidden_size"]
306 | return nn.Sequential(
307 | nn.Flatten(),
308 | nn.Linear(self.flat_dim, hidden_size),
309 | nn.ReLU(),
310 | nn.Linear(hidden_size, hidden_size // 2),
311 | nn.ReLU(),
312 | nn.Linear(hidden_size // 2, latent_size),
313 | ).to(DEVICE)
314 |
315 | def build_decoder(self) -> nn.Module:
316 | """Returns a fully connected decoder for MNIST."""
317 | # Decoder: latent → hidden → image
318 | latent_size: int = self.hparams["latent_size"]
319 | hidden_size: int = self.hparams["hidden_size"]
320 | return nn.Sequential(
321 | nn.Linear(latent_size, hidden_size // 2),
322 | nn.ReLU(),
323 | nn.Linear(hidden_size // 2, hidden_size),
324 | nn.ReLU(),
325 | nn.Linear(hidden_size, self.flat_dim),
326 | nn.Sigmoid(), # Ensure output in [0,1] range
327 | ).to(DEVICE)
328 |
329 |
330 | class LitSVHNAutoencoder(LitAutoencoder):
331 | def __init__(self, config: Config) -> None:
332 | super().__init__(config)
333 |
334 | def build_transform(self) -> transforms.Compose:
335 | """Returns a torchvision transforms.Compose for SVHN preprocessing."""
336 | return transforms.Compose([transforms.ToTensor()])
337 |
338 | @property
339 | def dataset_cls(self) -> type:
340 | """Returns the dataset class for SVHN."""
341 | return SVHN
342 |
343 | @property
344 | def dataset_kwargs(self) -> dict:
345 | """Returns a dictionary of keyword arguments for the SVHN dataset."""
346 | return {
347 | "train": {"split": "train"},
348 | "test": {"split": "test"},
349 | }
350 |
351 | def build_encoder(self) -> nn.Module:
352 | """Returns a convolutional encoder for SVHN."""
353 | latent_size = self.hparams["latent_size"]
354 | hidden_size = self.hparams["hidden_size"]
355 | return nn.Sequential(
356 | nn.Conv2d(
357 | 3, hidden_size // 4, kernel_size=4, stride=2, padding=1
358 | ), # → [B, 32, 16, 16]
359 | nn.ReLU(),
360 | nn.Conv2d(
361 | hidden_size // 4, hidden_size // 2, kernel_size=4, stride=2, padding=1
362 | ), # → [B, 64, 8, 8]
363 | nn.ReLU(),
364 | nn.Conv2d(
365 | hidden_size // 2, hidden_size, kernel_size=4, stride=2, padding=1
366 | ), # → [B, 128, 4, 4]
367 | nn.ReLU(),
368 | nn.Flatten(),
369 | nn.Linear(hidden_size * 4 * 4, latent_size),
370 | ).to(DEVICE)
371 |
372 | def build_decoder(self) -> nn.Module:
373 | """Returns a convolutional decoder for SVHN."""
374 | latent_size = self.hparams["latent_size"]
375 | hidden_size = self.hparams["hidden_size"]
376 | return nn.Sequential(
377 | nn.Linear(latent_size, hidden_size * 4 * 4),
378 | nn.ReLU(),
379 | nn.Unflatten(1, (hidden_size, 4, 4)),
380 | nn.ConvTranspose2d(
381 | hidden_size, hidden_size // 2, kernel_size=4, stride=2, padding=1
382 | ), # → [B, 64, 8, 8]
383 | nn.ReLU(),
384 | nn.ConvTranspose2d(
385 | hidden_size // 2, hidden_size // 4, kernel_size=4, stride=2, padding=1
386 | ), # → [B, 32, 16, 16]
387 | nn.ReLU(),
388 | nn.ConvTranspose2d(
389 | hidden_size // 4, 3, kernel_size=4, stride=2, padding=1
390 | ), # → [B, 3, 32, 32]
391 | nn.Sigmoid(), # For pixel values in [0,1]
392 | ).to(DEVICE)
--------------------------------------------------------------------------------
/torchvinecopulib/util/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | torchvinecopulib.util
3 | ----------------------
4 | Utility routines for copula‐based dependence measures, 1D KDE CDF/PPF, and root-finding via the Interpolate Truncate and Project (ITP) method.
5 |
6 | Decorators
7 | -----------
8 | * torch.compile() for solve_ITP()
9 | * torch.no_grad() for solve_ITP(), kendall_tau(), mutual_info(), ferreira_tail_dep_coeff(), chatterjee_xi()
10 |
11 | References
12 | -----------
13 | - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D., & O’Brien, J. P. (2016). A fast and objective multidimensional kernel density estimation method: fastKDE. Computational Statistics & Data Analysis, 101, 148-160.
14 | - O’Brien, T. A., Collins, W. D., Rauscher, S. A., & Ringler, T. D. (2014). Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Computational Statistics & Data Analysis, 79, 222-234.
15 | - Purkayastha, S., & Song, P. X. K. (2024). fastMI: A fast and consistent copula-based nonparametric estimator of mutual information. Journal of Multivariate Analysis, 201, 105270.
16 | - Ferreira, M. S. (2013). Nonparametric estimation of the tail-dependence coefficient.
17 | - Chatterjee, S. (2021). A new coefficient of correlation. Journal of the American Statistical Association, 116(536), 2009-2022.
18 | - Lin, Z., & Han, F. (2023). On boosting the power of Chatterjee’s rank correlation. Biometrika, 110(2), 283-299.
19 | - Oliveira, I. F., & Takahashi, R. H. (2020). An enhancement of the bisection method average performance preserving minmax optimality. ACM Transactions on Mathematical Software (TOMS), 47(1), 1-24.
20 | """
21 |
22 | import enum
23 | from pprint import pformat
24 |
25 | import fastkde
26 | import torch
27 | from scipy.stats import kendalltau
28 |
29 | _EPS = 1e-10
30 |
31 |
32 | @torch.no_grad()
33 | def kendall_tau(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
34 | """Compute Kendall's tau correlation coefficient and p-value. Moves inputs to CPU and delegates
35 | to SciPy’s ``kendalltau``.
36 |
37 | Args:
38 | x (torch.Tensor): shape (n, 1)
39 | y (torch.Tensor): shape (n, 1)
40 | Returns:
41 | torch.Tensor: Kendall's tau correlation coefficient and p-value
42 | """
43 | return torch.as_tensor(
44 | kendalltau(x.view(-1).cpu(), y.view(-1).cpu()),
45 | dtype=x.dtype,
46 | device=x.device,
47 | )
48 |
49 |
50 | @torch.no_grad()
51 | def mutual_info(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
52 | """Estimate mutual information using ``fastKDE``. Moves inputs to CPU and delegates to
53 | ``fastKDE.pdf``.
54 |
55 | - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D., & O’Brien, J. P. (2016). A fast and objective multidimensional kernel density estimation method: fastKDE. Computational Statistics & Data Analysis, 101, 148-160.
56 | - O’Brien, T. A., Collins, W. D., Rauscher, S. A., & Ringler, T. D. (2014). Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Computational Statistics & Data Analysis, 79, 222-234.
57 | - Purkayastha, S., & Song, P. X. K. (2024). fastMI: A fast and consistent copula-based nonparametric estimator of mutual information. Journal of Multivariate Analysis, 201, 105270.
58 |
59 | Args:
60 | x (torch.Tensor): shape (n, 1)
61 | y (torch.Tensor): shape (n, 1)
62 | Returns:
63 | torch.Tensor: Estimated mutual information
64 | """
65 | x = x.clamp(_EPS, 1.0 - _EPS).view(-1).cpu()
66 | y = y.clamp(_EPS, 1.0 - _EPS).view(-1).cpu()
67 | joint = torch.as_tensor(fastkde.pdf(x, y).values, dtype=x.dtype, device=x.device)
68 | margin_x = torch.as_tensor(fastkde.pdf(x).values, dtype=x.dtype, device=x.device)
69 | margin_y = torch.as_tensor(fastkde.pdf(y).values, dtype=x.dtype, device=x.device)
70 | return (
71 | joint[joint > 0.0].log().mean()
72 | - margin_x[margin_x > 0.0].log().mean()
73 | - margin_y[margin_y > 0.0].log().mean()
74 | )
75 |
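# A minimal usage sketch (illustration only): the estimate should be near zero for independent
# uniforms and clearly larger for strongly dependent ones. Sample sizes and the noise level
# below are arbitrary illustrative choices.
def _demo_mutual_info() -> tuple:
    torch.manual_seed(0)
    u = torch.rand(2000, 1, dtype=torch.float64)
    v_indep = torch.rand(2000, 1, dtype=torch.float64)
    v_dep = (u + 0.05 * torch.randn_like(u)).clamp(0.0, 1.0)
    # the dependent pair is expected to give a noticeably larger estimate
    return mutual_info(u, v_dep), mutual_info(u, v_indep)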
76 |
77 | @torch.no_grad()
78 | def ferreira_tail_dep_coeff(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
79 | """Estimate tail dependence coefficient (λ), modifed from Ferreira's method, symmetric for
80 | (x,y), (y,1-x), (1-x,1-y), (1-y,x), (y,x), (1-x,y), (1-y,1-x), (x,1-y).
81 |
82 | - Ferreira, M. S. (2013). Nonparametric estimation of the tail-dependence coefficient.
83 |
84 | Args:
85 | x (torch.Tensor): shape (n, 1)
86 | y (torch.Tensor): shape (n, 1)
87 | Returns:
88 | torch.Tensor: Estimated tail dependence coefficient
89 | """
90 | return (
91 | 3.0
92 | - (
93 | 1.0
94 | - torch.stack([torch.max(x, y), torch.max(1.0 - x, y)], dim=1)
95 | .mean(dim=0)
96 | .clamp(0.5, 0.6666666666666666)
97 | .min()
98 | ).reciprocal()
99 | )
100 |
101 |
102 | @torch.no_grad()
103 | def chatterjee_xi(x: torch.Tensor, y: torch.Tensor, M: int = 1) -> torch.Tensor:
104 | """Estimate Chatterjee's rank correlation coefficient (ξ)
105 |
106 | - Chatterjee, S. (2021). A new coefficient of correlation. Journal of the American Statistical Association, 116(536), 2009-2022.
107 | - Lin, Z., & Han, F. (2023). On boosting the power of Chatterjee’s rank correlation. Biometrika, 110(2), 283-299.
108 |
109 | Args:
110 | x (torch.Tensor): shape (n, 1)
111 | y (torch.Tensor): shape (n, 1)
112 |         M (int, optional): number of nearest neighbors. Defaults to 1.
113 | Returns:
114 | torch.Tensor: Estimated Chatterjee's rank correlation coefficient
115 | """
116 | # * ranks of x in the order of (ranks of y)
117 | # * ranks of y in the order of (ranks of x)
118 | xrank, yrank = (
119 | x.argsort(dim=0).argsort(dim=0) + 1,
120 | y.argsort(dim=0).argsort(dim=0) + 1,
121 | )
122 | xrank, yrank = xrank[yrank.argsort(dim=0)], yrank[xrank.argsort(dim=0)]
123 |
124 | # * the inner sum inside the numerator term, ∑min(Ri, Rjm(i))
125 | # * max for symmetry as following Remark 1 in Chatterjee (2021)
126 | def xy_sum(m: int) -> tuple:
127 | return (
128 | (torch.min(xrank[:-m], xrank[m:])).sum() + xrank[-m:].sum(),
129 | (torch.min(yrank[:-m], yrank[m:])).sum() + yrank[-m:].sum(),
130 | )
131 |
132 | # * whole eq. 3 in Lin and Han (2023)
133 | n = x.shape[0]
134 | return -2.0 + 24.0 * (
135 | torch.as_tensor([xy_sum(m) for m in range(1, M + 1)], device=x.device, dtype=x.dtype)
136 | .sum(dim=0)
137 | .max()
138 | ) / (M * (1.0 + n) * (1.0 + M + 4.0 * n))
139 |
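# A minimal usage sketch (illustration only): xi is close to 0 for independent data and
# approaches 1 for a noiseless functional (even non-monotone) relationship. The sample size
# below is an arbitrary illustrative choice.
def _demo_chatterjee_xi() -> tuple:
    torch.manual_seed(0)
    x = torch.rand(3000, 1, dtype=torch.float64)
    y_funct = torch.sin(2.0 * torch.pi * x)  # deterministic, non-monotone function of x
    y_indep = torch.rand(3000, 1, dtype=torch.float64)
    return chatterjee_xi(x, y_funct), chatterjee_xi(x, y_indep)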
140 |
141 | class ENUM_FUNC_BIDEP(enum.Enum):
142 | """
143 | Enum wrapper for bivariate dependence functions.
144 | """
145 |
146 | chatterjee_xi = enum.member(chatterjee_xi)
147 | ferreira_tail_dep_coeff = enum.member(ferreira_tail_dep_coeff)
148 | kendall_tau = enum.member(kendall_tau)
149 | mutual_info = enum.member(mutual_info)
150 |
151 | def __call__(self, x: torch.Tensor, y: torch.Tensor, **kw):
152 | return self.value(x, y, **kw)
153 |
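# A minimal usage sketch (illustration only): the enum lets callers pick a dependence measure
# by name and call it like a function. x and y are assumed to be pseudo-observations in [0, 1]
# so that the copula-based measures are meaningful.
def _demo_enum_bidep(x: torch.Tensor, y: torch.Tensor) -> dict:
    # evaluate every registered bivariate dependence measure on the same pair
    return {member.name: member(x, y) for member in ENUM_FUNC_BIDEP}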
154 |
155 | class kdeCDFPPF1D(torch.nn.Module):
156 | _EPS = _EPS
157 |
158 | def __init__(
159 | self,
160 | x: torch.Tensor,
161 | num_step_grid: int = None,
162 | x_min: float = None,
163 | x_max: float = None,
164 | pad: float = 0.1,
165 | ):
166 | """1D KDE CDF/PPF using ``fastKDE`` + Simpson's rule. Given a sample ``x``, fits a kernel
167 | density estimate via ``fastKDE`` on a grid of size ``num_step_grid`` (power of two plus
168 | one). Precomputes PDF, CDF, and their finite‐difference slopes for fast interpolation.
169 |
170 | - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D., & O’Brien, J. P. (2016). A fast and objective multidimensional kernel density estimation method: fastKDE. Computational Statistics & Data Analysis, 101, 148-160.
171 | - O’Brien, T. A., Collins, W. D., Rauscher, S. A., & Ringler, T. D. (2014). Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Computational Statistics & Data Analysis, 79, 222-234.
172 |
173 | Args:
174 | x (torch.Tensor): input sample to fit the KDE.
175 | num_step_grid (int, optional): number of grid points for the KDE, should be power of 2 plus 1. Defaults to None.
176 | x_min (float, optional): minimum value of the grid. Defaults to x.min() - pad.
177 | x_max (float, optional): maximum value of the grid. Defaults to x.max() + pad.
178 |             pad (float, optional): padding to extend beyond the min/max when ``x_min``/``x_max`` is None. Defaults to 0.1.
179 | """
180 | super().__init__()
181 | self.num_obs = x.shape[0]
182 | self.x_min = x_min if x_min is not None else x.min().item() - pad
183 | self.x_max = x_max if x_max is not None else x.max().item() + pad
184 | # * power of 2 plus 1
185 | if num_step_grid is None:
186 | num_step_grid = int(2 ** torch.log2(torch.tensor(x.numel())).ceil().item()) + 1
187 | self.num_step_grid = num_step_grid
188 | # * fastkde
189 | res = fastkde.pdf(x.view(-1).cpu().numpy(), num_points=num_step_grid)
190 | xs = torch.from_numpy(res.var0.values).to(dtype=torch.float64)
191 | pdfs = torch.from_numpy(res.values).to(dtype=torch.float64).clamp_min(self._EPS)
192 | N = pdfs.shape[0]
193 | ws = torch.ones(N, dtype=torch.float64)
194 | ws[1:-1:2] = 4
195 | ws[2:-1:2] = 2
196 | h = xs[1] - xs[0]
197 | cdf = torch.cumsum(pdfs * ws, dim=0) * (h / 3)
198 | cdf = cdf / cdf[-1]
199 | slope_fwd = (cdf[1:] - cdf[:-1]) / h
200 | slope_inv = h / (cdf[1:] - cdf[:-1])
201 | slope_pdf = (pdfs[1:] - pdfs[:-1]) / h
202 | self.register_buffer("grid_x", xs)
203 | self.register_buffer("grid_pdf", pdfs)
204 | self.register_buffer("grid_cdf", cdf)
205 | self.register_buffer("slope_fwd", slope_fwd)
206 | self.register_buffer("slope_inv", slope_inv)
207 | self.register_buffer("slope_pdf", slope_pdf)
208 | self.h = h
209 | # ! device agnostic
210 | self.register_buffer("_dd", torch.tensor([], dtype=torch.float64))
211 | self.negloglik = -self.log_pdf(x).mean()
212 |
213 | @property
214 | def device(self):
215 | return self._dd.device
216 |
217 | @property
218 | def dtype(self):
219 | return self._dd.dtype
220 |
221 | def cdf(self, x: torch.Tensor) -> torch.Tensor:
222 | """Compute the CDF of the fitted KDE at ``x``.
223 |
224 | Args:
225 | x (torch.Tensor): Points at which to evaluate the CDF.
226 | Returns:
227 | torch.Tensor: CDF values at ``x``, clamped to [0, 1].
228 | """
229 | # ! device agnostic
230 | x = x.to(device=self.device, dtype=self.dtype)
231 | x_clamped = x.clamp(self.x_min, self.x_max)
232 | idx = torch.searchsorted(self.grid_x, x_clamped, right=False)
233 | idx = idx.clamp(1, self.grid_cdf.numel() - 1)
234 | y = (self.grid_cdf[idx - 1]) + (self.slope_fwd[idx - 1]) * (
235 | x_clamped - self.grid_x[idx - 1]
236 | )
237 | y = torch.where(x < self.x_min, torch.zeros_like(y), y)
238 | y = torch.where(x > self.x_max, torch.ones_like(y), y)
239 | return y.clamp(0.0, 1.0)
240 |
241 | def ppf(self, q: torch.Tensor) -> torch.Tensor:
242 | """Compute the PPF (quantile function) of the fitted KDE at ``q``.
243 |
244 | Args:
245 | q (torch.Tensor): Quantiles at which to evaluate the PPF.
246 | Returns:
247 | torch.Tensor: PPF values at ``q``, clamped to [x_min, x_max].
248 | """
249 | # ! device agnostic
250 | q = q.to(device=self.device, dtype=self.dtype)
251 | q_clamped = q.clamp(0.0, 1.0)
252 | idx = torch.searchsorted(self.grid_cdf, q_clamped, right=False)
253 | idx = idx.clamp(1, self.grid_cdf.numel() - 1)
254 | x = (self.grid_x[idx - 1]) + (self.slope_inv[idx - 1]) * (
255 | q_clamped - self.grid_cdf[idx - 1]
256 | )
257 | x = torch.where(q < 0.0, torch.full_like(x, self.x_min), x)
258 | x = torch.where(q > 1.0, torch.full_like(x, self.x_max), x)
259 | return x.clamp(self.x_min, self.x_max)
260 |
261 | def pdf(self, x: torch.Tensor) -> torch.Tensor:
262 | """Compute the PDF of the fitted KDE at ``x``.
263 |
264 | Args:
265 | x (torch.Tensor): Points at which to evaluate the PDF.
266 | Returns:
267 | torch.Tensor: PDF values at ``x``, clamped to [0, ∞).
268 | """
269 | # ! device agnostic
270 | x = x.to(device=self.device, dtype=self.dtype)
271 | x_clamped = x.clamp(self.x_min, self.x_max)
272 | idx = torch.searchsorted(self.grid_x, x_clamped, right=False)
273 | idx = idx.clamp(1, self.grid_pdf.numel() - 1)
274 | f = self.grid_pdf[idx - 1] + (self.slope_pdf[idx - 1]) * (x_clamped - self.grid_x[idx - 1])
275 | f = torch.where((x < self.x_min) | (x > self.x_max), torch.zeros_like(f), f)
276 | return f.clamp_min(0.0)
277 |
278 | def log_pdf(self, x: torch.Tensor) -> torch.Tensor:
279 | """Compute the log PDF of the fitted KDE at ``x``.
280 |
281 | Args:
282 | x (torch.Tensor): Points at which to evaluate the log PDF.
283 | Returns:
284 | torch.Tensor: Log PDF values at ``x``, guaranteed to be finite.
285 | """
286 | return self.pdf(x).log().nan_to_num(posinf=0.0, neginf=-13.815510557964274)
287 |
288 | def forward(self, x: torch.Tensor) -> torch.Tensor:
289 | """Average negative log-likelihood of the fitted KDE at ``x``.
290 |
291 | Args:
292 | x (torch.Tensor): Points at which to evaluate the negative log-likelihood.
293 | Returns:
294 | torch.Tensor: Negative log-likelihood values at ``x``, averaged over the batch.
295 | """
296 | return -self.log_pdf(x).mean()
297 |
298 | def __str__(self):
299 | """String representation of the ``kdeCDFPPF1D`` object.
300 |
301 | Returns:
302 | str: String representation of the ``kdeCDFPPF1D`` object.
303 | """
304 | header = self.__class__.__name__
305 | params = {
306 | "num_obs": int(self.num_obs),
307 | "negloglik": float(self.negloglik.round(decimals=4)),
308 | "x_min": float(round(self.x_min, 4)),
309 | "x_max": float(round(self.x_max, 4)),
310 | "num_step_grid": int(self.num_step_grid),
311 | "dtype": self.dtype,
312 | "device": self.device,
313 | }
314 | params_str = pformat(params, sort_dicts=False, underscore_numbers=True)
315 | return f"{header}\n{params_str[1:-1]}\n\n"
316 |
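# A minimal usage sketch (illustration only): fit the 1D KDE on a sample, map the data to
# approximately uniform pseudo-observations with cdf(), and map back with ppf(). The round
# trip should recover the original points up to interpolation error; the grid size below is
# an arbitrary power-of-two-plus-one choice.
def _demo_kde_cdf_ppf() -> torch.Tensor:
    torch.manual_seed(0)
    x = torch.randn(1000, dtype=torch.float64)
    kde = kdeCDFPPF1D(x, num_step_grid=257)  # 2**8 + 1 grid points
    u = kde.cdf(x)  # pseudo-observations in [0, 1]
    x_back = kde.ppf(u)  # approximate inverse transform
    return (x_back - x).abs().max()  # round-trip error, expected to be small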
317 |
318 | # @torch.compile
319 | @torch.no_grad()
320 | def solve_ITP(
321 | fun: callable,
322 | x_a: torch.Tensor,
323 | x_b: torch.Tensor,
324 | epsilon: float = _EPS,
325 | num_iter_max: int = 31,
326 | k_1: float = 0.2,
327 | ) -> torch.Tensor:
328 | """Root-finding for ``fun`` via the Interpolate Truncate and Project (ITP) method within
329 | [``x_a``, ``x_b``], with guaranteed average performance strictly better than the bisection
330 | method under any continuous distribution.
331 |
332 | - Oliveira, I. F., & Takahashi, R. H. (2020). An enhancement of the bisection method average performance preserving minmax optimality. ACM Transactions on Mathematical Software (TOMS), 47(1), 1-24.
333 | https://en.wikipedia.org/wiki/ITP_method
334 | https://docs.rs/kurbo/latest/kurbo/common/fn.solve_itp.html
335 |
336 | Args:
337 | fun (callable): function to find the root of.
338 | x_a (torch.Tensor): lower bound of the interval to search.
339 | x_b (torch.Tensor): upper bound of the interval to search.
340 | epsilon (float, optional): convergence tolerance. Defaults to _EPS.
341 | num_iter_max (int, optional): maximum number of iterations. Defaults to 31.
342 | k_1 (float, optional): scaling factor for the truncation step. Defaults to 0.2.
343 | Returns:
344 | torch.Tensor: approximated root of the function `fun` in the interval [x_a, x_b].
345 | """
346 | y_a, y_b = fun(x_a), fun(x_b)
347 | # * corner cases
348 | x_a = torch.where(condition=y_b.abs() < epsilon, input=x_b - epsilon * num_iter_max, other=x_a)
349 | x_b = torch.where(condition=y_a.abs() < epsilon, input=x_a + epsilon * num_iter_max, other=x_b)
350 | y_a, y_b, x_wid = fun(x_a), fun(x_b), x_b - x_a
351 | eps_2 = torch.as_tensor(epsilon * 2.0, device=x_a.device, dtype=x_a.dtype)
352 | eps_scale = epsilon * 2.0 ** (
353 | (x_wid / epsilon).max().clamp_min(1.0).log2().ceil().clamp_min(1.0).int()
354 | )
355 | x_half = torch.empty_like(x_wid)
356 | rho = torch.empty_like(x_wid)
357 | sigma = torch.empty_like(x_wid)
358 | delta = torch.empty_like(x_wid)
359 | for _ in range(num_iter_max):
360 | if (x_wid < eps_2).all():
361 | break
362 | # * update parameters
363 | x_half.copy_(0.5 * (x_a + x_b))
364 | rho.copy_(eps_scale - 0.5 * x_wid)
365 | # * interpolation
366 | x_f = (y_b * x_a - y_a * x_b) / (y_b - y_a)
367 | sigma.copy_(x_half - x_f)
368 |         # ! here k_2 = 2 is hardwired for efficiency.
369 | delta.copy_(k_1 * x_wid.square())
370 | # * truncation
371 | x_t = torch.where(
372 | condition=delta <= sigma.abs(),
373 | input=x_f + torch.copysign(delta, sigma),
374 | other=x_half,
375 | )
376 | # * projection
377 | x_itp = torch.where(
378 | condition=rho >= (x_t - x_half).abs(),
379 | input=x_t,
380 | other=x_half - torch.copysign(rho, sigma),
381 | )
382 | # * update interval
383 | y_itp = fun(x_itp)
384 | idx = y_itp > 0.0
385 | x_b[idx], y_b[idx] = x_itp[idx], y_itp[idx]
386 | idx = ~idx
387 | x_a[idx], y_a[idx] = x_itp[idx], y_itp[idx]
388 | x_wid = x_b - x_a
389 | eps_scale *= 0.5
390 | return 0.5 * (x_a + x_b)
391 |
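# A minimal usage sketch (illustration only): solve_ITP works elementwise on tensors, so a whole
# batch of roots can be bracketed and solved at once. Here we invert f(x) = x**3 - c for several
# values of c; the bracketing assumes f(x_a) < 0 < f(x_b).
def _demo_solve_itp() -> torch.Tensor:
    c = torch.tensor([0.1, 0.5, 0.9], dtype=torch.float64)
    roots = solve_ITP(
        fun=lambda x: x**3 - c,
        x_a=torch.zeros_like(c),
        x_b=torch.ones_like(c),
    )
    return (roots - c.pow(1.0 / 3.0)).abs().max()  # expected to be close to zero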
--------------------------------------------------------------------------------
/torchvinecopulib/bicop/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | torchvinecopulib.bicop
3 | -----------------------
4 | Provides ``BiCop`` (``torch.nn.Module``) for estimating, evaluating, and sampling
5 | from bivariate copulas via ``tll`` or ``fastKDE`` approaches.
6 |
7 | Decorators
8 | -----------
9 | * torch.compile() for bilinear interpolation
10 | * torch.no_grad() for fit(), hinv_l(), hinv_r(), sample(), and imshow()
11 |
12 | Key Features
13 | -------------
14 | - Fits a copula density on a uniform [0,1]² grid and caches PDF/CDF/h‐functions
15 | - Device‐agnostic: all buffers live on the same device/dtype you fit on
16 | - Fast bilinear interpolation compiled with ``torch.compile``
17 | - Convenient ``.cdf()``, ``.pdf()``, ``.hfunc_*()``, ``.hinv_*()``, and ``.sample()`` APIs
18 | - Plotting helpers: ``.imshow()`` and ``.plot(contour|surface)``
19 |
20 | Usage
21 | ------
22 | >>> from torchvinecopulib.bicop import BiCop
23 | >>> cop = BiCop(num_step_grid=256)
24 | >>> cop.fit(obs) # obs: Tensor of shape (n,2) in [0,1]²
25 | >>> u = torch.rand(10, 2)
26 | >>> cdf_vals = cop.cdf(u)
27 | >>> samples = cop.sample(1000, is_sobol=True)
28 |
29 | References
30 | -----------
31 | - Nagler, T., Schellhase, C., & Czado, C. (2017). Nonparametric estimation of simplified vine copula models: comparison of methods. Dependence Modeling, 5(1), 99-120.
32 | - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D. & O’Brien, J. P. A fast and objective multidimensional kernel density estimation method: fastKDE. Comput. Stat. Data Anal. 101, 148–160 (2016). http://dx.doi.org/10.1016/j.csda.2016.02.014
33 | - O’Brien, T. A., Collins, W. D., Rauscher, S. A. & Ringler, T. D. Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Comput. Stat. Data Anal. 79, 222–234 (2014). http://dx.doi.org/10.1016/j.csda.2014.06.002
34 | """
35 |
36 | from pprint import pformat
37 | from typing import Optional, cast
38 |
39 | import matplotlib.pyplot as plt
40 | import numpy as np
41 | import pyvinecopulib as pv
42 | import torch
43 | from fastkde import pdf as fkpdf
44 | from matplotlib.colors import LinearSegmentedColormap
45 | from mpl_toolkits.mplot3d.axes3d import Axes3D
46 | from scipy.stats import kendalltau, norm
47 |
48 | from ..util import _EPS, solve_ITP
49 |
50 | __all__ = [
51 | "BiCop",
52 | ]
53 |
54 |
55 | class BiCop(torch.nn.Module):
56 | # ! hinv
57 | _EPS: float = _EPS
58 |
59 | def __init__(
60 | self,
61 | num_step_grid: int = 128,
62 | ):
63 | """Initializes the bivariate copula (BiCop) class. By default an independent bicop.
64 |
65 | Args:
66 | num_step_grid (int, optional): number of steps per dimension for the precomputed grids (must be a power of 2). Defaults to 128.
67 | """
68 | super().__init__()
69 | # * by default an independent bicop, otherwise cache grids from KDE
70 | self.is_indep = True
71 | self.mtd_kde = "tll" # default method for estimating the copula density
72 | self.num_step_grid = num_step_grid
73 | self.register_buffer("tau", torch.zeros(2, dtype=torch.float64))
74 | self.register_buffer("num_obs", torch.empty((), dtype=torch.int))
75 | self.register_buffer("negloglik", torch.zeros((), dtype=torch.float64))
76 | self.register_buffer(
77 | "_pdf_grid",
78 | torch.ones(num_step_grid, num_step_grid, dtype=torch.float64),
79 | )
80 | self.register_buffer(
81 | "_cdf_grid",
82 | torch.empty(num_step_grid, num_step_grid, dtype=torch.float64),
83 | )
84 | self.register_buffer(
85 | "_hfunc_l_grid",
86 | torch.empty(num_step_grid, num_step_grid, dtype=torch.float64),
87 | )
88 | self.register_buffer(
89 | "_hfunc_r_grid",
90 | torch.empty(num_step_grid, num_step_grid, dtype=torch.float64),
91 | )
92 | # ! device agnostic
93 | self.register_buffer("_dd", torch.tensor([], dtype=torch.float64))
94 |
95 | @property
96 | def device(self) -> torch.device:
97 | """Get the device of the bicop model (all internal buffers).
98 |
99 | Returns:
100 | torch.device: The device on which the registered buffers reside.
101 | """
102 | return self._dd.device
103 |
104 | @property
105 | def dtype(self) -> torch.dtype:
106 | """Get the data type of the bicop model (all internal buffers). Should be torch.float64.
107 |
108 | Returns:
109 | torch.dtype: The data type of the registered buffers.
110 | """
111 | return self._dd.dtype
112 |
113 | @torch.no_grad()
114 | def reset(self) -> None:
115 | """Reinitialize state and zero all statistics and precomputed grids.
116 |
117 | Sets the bicop back to independent bicop and clears accumulated
118 | metrics (``tau``, ``num_obs``, ``negloglik``) as well as all grid buffers
119 | (``_pdf_grid``, ``_cdf_grid``, ``_hfunc_l_grid``, ``_hfunc_r_grid``).
120 | """
121 | self.is_indep = True
122 | self.tau.zero_()
123 | self.num_obs.zero_()
124 | self.negloglik.zero_()
125 | self._pdf_grid.zero_()
126 | self._cdf_grid.zero_()
127 | self._hfunc_l_grid.zero_()
128 | self._hfunc_r_grid.zero_()
129 |
130 | @torch.no_grad()
131 | def fit(
132 | self,
133 | obs: torch.Tensor,
134 | mtd_kde: str = "tll",
135 | mtd_tll: str = "constant",
136 | num_iter_max: int = 17,
137 | is_tau_est: bool = False,
138 | ) -> None:
139 | """Estimate and cache PDF/CDF/h-function grids from bivariate copula observations.
140 |
141 |         This method computes a KDE-based bivariate copula density on a uniform [0,1]² grid and populates internal buffers
142 | (``_pdf_grid``, ``_cdf_grid``, ``_hfunc_l_grid``, ``_hfunc_r_grid``, ``negloglik``).
143 |
144 | - Nagler, T., Schellhase, C., & Czado, C. (2017). Nonparametric estimation of simplified vine copula models: comparison of methods. Dependence Modeling, 5(1), 99-120.
145 | - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D. & O’Brien, J. P. A fast and objective multidimensional kernel density estimation method: fastKDE. Comput. Stat. Data Anal. 101, 148–160 (2016). http://dx.doi.org/10.1016/j.csda.2016.02.014
146 | - O’Brien, T. A., Collins, W. D., Rauscher, S. A. & Ringler, T. D. Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Comput. Stat. Data Anal. 79, 222–234 (2014). http://dx.doi.org/10.1016/j.csda.2014.06.002
147 |
148 |
149 | Args:
150 | obs (torch.Tensor): shape (n, 2) bicop obs in [0, 1]².
151 | mtd_kde (str, optional): Method for estimating the copula density. One of ("tll", "fastKDE"). Defaults to "tll".
152 | mtd_tll (str, optional): fit method for the transformation local-likelihood (TLL) nonparametric family, used only when ``mtd_kde="tll"``, one of ("constant", "linear", or "quadratic"). Defaults to "constant".
153 | num_iter_max (int, optional): num of Sinkhorn/IPF iters for grid normalization, used only when ``mtd_kde="fastKDE"``. Defaults to 17.
154 | is_tau_est (bool, optional): If True, compute and store Kendall’s τ. Defaults to ``False``.
155 | """
156 | # ! device agnostic
157 | device, dtype = self.device, self.dtype
158 | self.is_indep = False
159 | self.mtd_kde = mtd_kde
160 | self.num_obs.copy_(obs.shape[0])
161 | # * assuming already in [0, 1]
162 | obs = obs.clamp(min=0.0, max=1.0)
163 | if is_tau_est:
164 | self.tau.copy_(
165 | torch.as_tensor(
166 | kendalltau(obs[:, 0].cpu(), obs[:, 1].cpu()),
167 | device=device,
168 | dtype=dtype,
169 | )
170 | )
171 | self._target = self.num_step_grid - 1.0 # * marginal target
172 | self.step_grid = 1.0 / self._target
173 | # ! pdf
174 | if mtd_kde == "tll":
175 | controls = pv.FitControlsBicop(
176 | family_set=[pv.tll],
177 | num_threads=torch.get_num_threads(),
178 | nonparametric_method=mtd_tll,
179 | )
180 | cop = pv.Bicop.from_data(data=obs.cpu().numpy(), controls=controls)
181 | axis = torch.linspace(
182 | _EPS,
183 | 1.0 - _EPS,
184 | steps=self.num_step_grid,
185 | device="cpu",
186 | dtype=torch.float64,
187 | )
188 | pdf_grid = (
189 | torch.from_numpy(cop.pdf(torch.cartesian_prod(axis, axis).view(-1, 2).numpy()))
190 | .view(self.num_step_grid, self.num_step_grid)
191 | .to(device=device, dtype=dtype)
192 | )
193 | elif mtd_kde == "fastKDE":
194 | pdf_grid = torch.from_numpy(
195 | fkpdf(
196 | obs[:, 0].cpu(),
197 | obs[:, 1].cpu(),
198 | num_points=self.num_step_grid * 2 + 1,
199 | ).values
200 | ).to(device=device, dtype=dtype)
201 | # * padding/trimming after fastkde.pdf
202 | H, W = pdf_grid.shape
203 | if H < self.num_step_grid:
204 | pdf_grid = torch.cat(
205 | [
206 | pdf_grid,
207 | torch.zeros(self.num_step_grid - H, W, dtype=dtype, device=device),
208 | ],
209 | dim=0,
210 | )
211 | H, W = pdf_grid.shape
212 | if W < self.num_step_grid:
213 | pdf_grid = torch.cat(
214 | [
215 | pdf_grid,
216 | torch.zeros(H, self.num_step_grid - W, dtype=dtype, device=device),
217 | ],
218 | dim=1,
219 | )
220 | pdf_grid = pdf_grid[: self.num_step_grid, : self.num_step_grid].clamp_min(0.0)
221 | pdf_grid = pdf_grid.view(self.num_step_grid, self.num_step_grid).T
222 | # * normalization: Sinkhorn / iterative proportional fitting (IPF)
223 | for _ in range(num_iter_max):
224 | pdf_grid *= self._target / pdf_grid.sum(dim=0, keepdim=True)
225 | pdf_grid *= self._target / pdf_grid.sum(dim=1, keepdim=True)
226 | pdf_grid /= pdf_grid.sum() * self.step_grid**2
227 | else:
228 |             raise NotImplementedError(f"mtd_kde must be one of ('tll', 'fastKDE'), got {mtd_kde!r}")
229 | self._pdf_grid = pdf_grid
230 | # * negloglik
231 | self.negloglik = -self.log_pdf(obs=obs).nan_to_num(posinf=0.0, neginf=0.0).sum()
232 | # ! cdf
233 | self._cdf_grid = ((self._pdf_grid * self.step_grid**2).cumsum(dim=0).cumsum(dim=1)).clamp_(
234 | 0.0, 1.0
235 | )
236 | # ! h functions
237 | self._hfunc_l_grid = (self._pdf_grid * self.step_grid).cumsum(dim=1).clamp_(0.0, 1.0)
238 | self._hfunc_r_grid = (self._pdf_grid * self.step_grid).cumsum(dim=0).clamp_(0.0, 1.0)
239 |
240 | # @torch.compile
241 | def _interp(self, grid: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
242 | """Bilinearly interpolate values on a 2D grid at given sample points.
243 |
244 | Args:
245 | grid (torch.Tensor): Precomputed grid of values (e.g., PDF/CDF/h‐function), shape (m,m).
246 | obs (torch.Tensor): Points in [0,1]² where to interpolate (rows are (u₁,u₂)), shape (n,2).
247 |
248 | Returns:
249 |             torch.Tensor: Interpolated grid values at each observation, clamped ≥ 0, shape (n,).
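
        Example:
            A sketch of interpolating the cached PDF grid at two points (internal helper; assumes ``cop`` is a fitted ``BiCop``)::

                import torch

                pts = torch.tensor([[0.25, 0.75], [0.5, 0.5]], dtype=cop.dtype, device=cop.device)
                vals = cop._interp(grid=cop._pdf_grid, obs=pts)  # shape (2,)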
250 | """
251 | idx = obs.clamp(self._EPS, 1 - self._EPS) / self.step_grid
252 | i0 = idx.floor().long()
253 | di = idx - i0
254 | i1 = torch.minimum(
255 | i0 + 1,
256 | torch.full_like(input=i0, fill_value=self._target, device=idx.device),
257 | )
258 | g00 = grid[i0[:, 0], i0[:, 1]]
259 | g10 = grid[i1[:, 0], i0[:, 1]]
260 | g01 = grid[i0[:, 0], i1[:, 1]]
261 | g11 = grid[i1[:, 0], i1[:, 1]]
262 | return (
263 | g00
264 | + (g10 - g00) * di[:, 0]
265 | + (g01 - g00) * di[:, 1]
266 | + (g11 - g01 - g10 + g00) * di[:, 0] * di[:, 1]
267 | ).clamp_min(0.0)
268 |
269 | def cdf(self, obs: torch.Tensor) -> torch.Tensor:
270 | """Evaluate the copula CDF at given points. For independent copula, returns u₁·u₂.
271 |
272 | Args:
273 | obs (torch.Tensor): Points in [0,1]² where to evaluate the CDF (rows are (u₁,u₂)), shape (n,2).
274 |
275 | Returns:
276 | torch.Tensor: CDF values at each observation, shape (n,1).
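
        Example:
            Evaluating the copula CDF (assumes ``cop`` is a fitted ``BiCop``)::

                import torch

                u = torch.tensor([[0.1, 0.9], [0.5, 0.5]])
                c = cop.cdf(u)  # shape (2, 1), values in [0, 1]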
277 | """
278 | # ! device agnostic
279 | obs = obs.to(device=self.device, dtype=self.dtype)
280 | if self.is_indep:
281 | return obs.prod(dim=1, keepdim=True)
282 | return self._interp(grid=self._cdf_grid, obs=obs).unsqueeze(dim=1)
283 |
284 | def hfunc_l(self, obs: torch.Tensor) -> torch.Tensor:
285 | """Evaluate the left h-function at given points. Computes H(u₂ | u₁):= ∂/∂u₁ C(u₁,u₂) for
286 | the fitted copula. For independent copula, returns u₂.
287 |
288 | Args:
289 | obs (torch.Tensor): Points in [0,1]² where to evaluate the left h-function (rows are (u₁,u₂)), shape (n,2).
290 |
291 | Returns:
292 | torch.Tensor: Left h-function values at each observation, shape (n,1).
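
        Example:
            Conditional CDF values H(u₂ | u₁) (assumes ``cop`` is a fitted ``BiCop``)::

                import torch

                u = torch.rand(10, 2)
                h = cop.hfunc_l(u)  # shape (10, 1)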
293 | """
294 | # ! device agnostic
295 | obs = obs.to(device=self.device, dtype=self.dtype)
296 | if self.is_indep:
297 | return obs[:, [1]]
298 | return self._interp(grid=self._hfunc_l_grid, obs=obs).unsqueeze(dim=1)
299 |
300 | def hfunc_r(self, obs: torch.Tensor) -> torch.Tensor:
301 | """Evaluate the right h-function at given points. Computes H(u₁ | u₂):= ∂/∂u₂ C(u₁,u₂) for
302 | the fitted copula. For independent copula, returns u₁.
303 |
304 | Args:
305 | obs (torch.Tensor): Points in [0,1]² where to evaluate the right h-function (rows are (u₁,u₂)), shape (n,2).
306 |
307 | Returns:
308 | torch.Tensor: Right h-function values at each observation, shape (n,1).
309 | """
310 | # ! device agnostic
311 | obs = obs.to(device=self.device, dtype=self.dtype)
312 | if self.is_indep:
313 | return obs[:, [0]]
314 | return self._interp(grid=self._hfunc_r_grid, obs=obs).unsqueeze(dim=1)
315 |
316 | @torch.no_grad()
317 | def hinv_l(self, obs: torch.Tensor) -> torch.Tensor:
318 | """Invert the left h‐function via root‐finding: find u₂ given (u₁, p). Solves H(u₂ | u₁) = p
319 |         by ITP (interpolate, truncate, project) root-finding on [0, 1].
320 |
321 | Args:
322 |             obs (torch.Tensor): Pairs in [0,1]² where each row is (u₁, p): the conditioning value u₁ and the target probability p, shape (n,2).
323 |
324 | Returns:
325 | torch.Tensor: Solutions u₂ ∈ [0,1], shape (n,1).
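
        Example:
            Round-trip check against the left h-function (assumes ``cop`` is a fitted ``BiCop``)::

                import torch

                u = torch.rand(10, 2)
                p = cop.hfunc_l(u)                             # H(u₂ | u₁)
                u2 = cop.hinv_l(torch.hstack([u[:, [0]], p]))  # ≈ u[:, [1]] up to grid error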
326 | """
327 | # ! device agnostic
328 | obs = obs.to(device=self.device, dtype=self.dtype)
329 | if self.is_indep:
330 | return obs[:, [1]]
331 | # * via root-finding
332 | u_l = obs[:, [0]]
333 | p = obs[:, [1]]
334 | return solve_ITP(
335 | fun=lambda u_r: self.hfunc_l(obs=torch.hstack([u_l, u_r])) - p,
336 | x_a=torch.zeros_like(p),
337 | x_b=torch.ones_like(p),
338 | ).clamp(min=0.0, max=1.0)
339 |
340 | @torch.no_grad()
341 | def hinv_r(self, obs: torch.Tensor) -> torch.Tensor:
342 | """Invert the right h‐function via root‐finding: find u₁ given (u₂, p). Solves H(u₁ | u₂) =
343 |         p by ITP (interpolate, truncate, project) root-finding on [0, 1].
344 |
345 | Args:
346 |             obs (torch.Tensor): Pairs in [0,1]² where each row is (p, u₂): the target probability p and the conditioning value u₂, shape (n,2).
347 | Returns:
348 | torch.Tensor: Solutions u₁ ∈ [0,1], shape (n,1).
349 | """
350 | # ! device agnostic
351 | obs = obs.to(device=self.device, dtype=self.dtype)
352 | if self.is_indep:
353 | return obs[:, [0]]
354 | # * via root-finding
355 | u_r = obs[:, [1]]
356 | p = obs[:, [0]]
357 | return solve_ITP(
358 | fun=lambda u_l: self.hfunc_r(obs=torch.hstack([u_l, u_r])) - p,
359 | x_a=torch.zeros_like(p),
360 | x_b=torch.ones_like(p),
361 | ).clamp(min=0.0, max=1.0)
362 |
363 | def pdf(self, obs: torch.Tensor) -> torch.Tensor:
364 | """Evaluate the copula PDF at given points. For independent copula, returns 1.
365 |
366 | Args:
367 | obs (torch.Tensor): Points in [0,1]² where to evaluate the PDF (rows are (u₁,u₂)), shape (n,2).
368 | Returns:
369 | torch.Tensor: PDF values at each observation, shape (n,1).
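
        Example:
            Density values and a rough normalization check via the internal grid buffers (assumes ``cop`` is a fitted ``BiCop``)::

                import torch

                vals = cop.pdf(torch.rand(5, 2))                 # shape (5, 1)
                mass = (cop._pdf_grid * cop.step_grid**2).sum()  # ≈ 1.0 on the grid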
370 | """
371 | # ! device agnostic
372 | obs = obs.to(device=self.device, dtype=self.dtype)
373 | if self.is_indep:
374 | return torch.ones_like(obs[:, [0]])
375 | return self._interp(grid=self._pdf_grid, obs=obs).unsqueeze(dim=1)
376 |
377 | def log_pdf(self, obs: torch.Tensor) -> torch.Tensor:
378 | """Evaluate the copula log-PDF at given points, with safe handling of inf/nan. For
379 | independent copula, returns 0.
380 |
381 | Args:
382 | obs (torch.Tensor): Points in [0,1]² where to evaluate the log-PDF (rows are (u₁,u₂)), shape (n,2).
383 | Returns:
384 | torch.Tensor: log-PDF values at each observation, shape (n,1).
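
        Example:
            Negative log-likelihood of held-out pseudo-observations (assumes ``cop`` is a fitted ``BiCop``)::

                import torch

                nll = -cop.log_pdf(torch.rand(200, 2)).sum()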
385 | """
386 | # ! device agnostic
387 | obs = obs.to(device=self.device, dtype=self.dtype)
388 | if self.is_indep:
389 | return torch.zeros_like(obs[:, [0]])
390 | return self.pdf(obs=obs).log().nan_to_num(posinf=0.0, neginf=-13.815510557964274)
391 |
392 | @torch.no_grad()
393 | def sample(
394 | self, num_sample: int = 100, seed: int = 42, is_sobol: bool = False
395 | ) -> torch.Tensor:
396 | """Sample from the copula by inverse Rosenblatt transform. Uses Sobol sequence if
397 | ``is_sobol=True``, otherwise uniform RNG. For independent copula, returns uniform samples in
398 | [0,1]².
399 |
400 | Args:
401 | num_sample (int, optional): number of samples to generate. Defaults to 100.
402 | seed (int, optional): random seed for reproducibility. Defaults to 42.
403 | is_sobol (bool, optional): If True, use Sobol sampling. Defaults to False.
404 | Returns:
405 | torch.Tensor: Generated samples, shape (num_sample, 2).
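
        Example:
            Drawing quasi-random samples from a fitted ``BiCop`` (assumed)::

                u = cop.sample(num_sample=1000, seed=0, is_sobol=True)  # shape (1000, 2)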
406 | """
407 | # ! device agnostic
408 | device, dtype = self.device, self.dtype
409 | if is_sobol:
410 | obs = (
411 | torch.quasirandom.SobolEngine(dimension=2, scramble=True, seed=seed)
412 | .draw(n=num_sample, dtype=dtype)
413 | .to(device=device)
414 | )
415 | else:
416 | torch.manual_seed(seed=seed)
417 | obs = torch.rand(size=(num_sample, 2), dtype=dtype, device=device)
418 | if not self.is_indep:
419 | obs[:, [1]] = self.hinv_l(obs=obs)
420 | return obs
421 |
422 | def __str__(self) -> str:
423 | """String representation of the BiCop class.
424 |
425 | Returns:
426 | str: String representation of the BiCop class.
427 | """
428 | return f"""{self.__class__.__name__}\n{
429 | pformat(
430 | object={
431 | "is_indep": self.is_indep,
432 | "num_obs": self.num_obs,
433 | "negloglik": self.negloglik.round(decimals=4),
434 | "num_step_grid": self.num_step_grid,
435 | "tau": self.tau.round(decimals=4),
436 | "mtd_kde": self.mtd_kde,
437 | "dtype": self._dd.dtype,
438 | "device": self._dd.device,
439 | },
440 | compact=True,
441 | sort_dicts=False,
442 | underscore_numbers=True,
443 | )
444 | }"""
445 |
446 | @torch.no_grad()
447 | def imshow(
448 | self,
449 | is_log_pdf: bool = False,
450 | ax: plt.Axes | None = None,
451 | cmap: str = "inferno",
452 | xlabel: str = r"$u_{left}$",
453 | ylabel: str = r"$u_{right}$",
454 | title: str = "Estimated bivariate copula density",
455 | colorbartitle: str = "Density",
456 | **imshow_kwargs: dict,
457 | ) -> tuple[plt.Figure, plt.Axes]:
458 | """Display the (log-)PDF grid as a heatmap.
459 |
460 | Args:
461 | is_log_pdf (bool, optional): If True, plot log-PDF. Defaults to False.
462 | ax (plt.Axes, optional): Matplotlib Axes object to plot on. If None, a new figure and axes are created. Defaults to None.
463 | cmap (str, optional): Colormap for the plot. Defaults to "inferno".
464 | xlabel (str, optional): X-axis label. Defaults to r"$u_{left}$".
465 | ylabel (str, optional): Y-axis label. Defaults to r"$u_{right}$".
466 | title (str, optional): Plot title. Defaults to "Estimated bivariate copula density".
467 | colorbartitle (str, optional): Colorbar title. Defaults to "Density".
468 | **imshow_kwargs: Additional keyword arguments for imshow.
469 | Returns:
470 | tuple[plt.Figure, plt.Axes]: The figure and axes objects.
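
        Example:
            Heatmap of the fitted log-density (assumes ``cop`` is a fitted ``BiCop``; the output path is hypothetical)::

                fig, ax = cop.imshow(is_log_pdf=True, cmap="viridis")
                fig.savefig("bicop_density.png")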
471 | """
472 | if ax is None:
473 | fig, ax = plt.subplots()
474 | else:
475 | fig = ax.figure
476 | im = ax.imshow(
477 | X=self._pdf_grid.log().nan_to_num(posinf=0.0, neginf=-13.815510557964274).cpu()
478 | if is_log_pdf
479 | else self._pdf_grid.cpu(),
480 | extent=(0, 1, 0, 1),
481 | origin="lower",
482 | cmap=cmap,
483 | **imshow_kwargs,
484 | )
485 | ax.set_xlabel(xlabel=xlabel)
486 | ax.set_ylabel(ylabel=ylabel)
487 | ax.set_title(label=title)
488 | ax.set_xlim(0, 1)
489 | ax.set_ylim(0, 1)
490 | plt.colorbar(im, ax=ax, label=colorbartitle)
491 | return fig, ax
492 |
493 | @torch.no_grad()
494 | def plot(
495 | self,
496 | plot_type: str = "surface",
497 | margin_type: str = "unif",
498 | xylim: Optional[tuple[float, float]] = None,
499 | grid_size: Optional[int] = None,
500 | ) -> tuple[plt.Figure, plt.Axes]:
501 | """Plot the bivariate copula density.
502 |
503 | Args:
504 | plot_type (str, optional): Type of plot, either "contour" or "surface". Defaults to "surface".
505 | margin_type (str, optional): Type of margin, either "unif" or "norm". Defaults to "unif".
506 | xylim (tuple[float, float], optional): Limits for x and y axes. Defaults to None.
507 | grid_size (int, optional): Size of the grid for the plot. Defaults to None.
508 | Returns:
509 | tuple[plt.Figure, plt.Axes]: The figure and axes objects.
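
        Example:
            Surface plot with standard-normal margins (assumes ``cop`` is a fitted ``BiCop``)::

                fig, ax = cop.plot(plot_type="surface", margin_type="norm")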
510 | """
511 | # * validate inputs
512 | if plot_type not in ["contour", "surface"]:
513 | raise ValueError("Unknown type")
514 | elif plot_type == "contour" and grid_size is None:
515 | grid_size = 100
516 | elif plot_type == "surface" and grid_size is None:
517 | grid_size = 40
518 | # * margin type and grid points
519 | if margin_type not in ["unif", "norm"]:
520 | raise ValueError("Unknown margin type")
521 | elif margin_type == "unif":
522 | if xylim is None:
523 | xylim = (1e-2, 1 - 1e-2)
524 | if plot_type == "contour":
525 | points = np.linspace(1e-5, 1 - 1e-5, grid_size)
526 | else:
527 | points = np.linspace(1, grid_size, grid_size) / (grid_size + 1)
528 | g = np.meshgrid(points, points)
529 | points = g[0][0]
530 | adj = 1
531 | levels = [0.2, 0.6, 1, 1.5, 2, 3, 5, 10, 20]
532 | xlabel, ylabel = "u1", "u2"
533 | elif margin_type == "norm":
534 | if xylim is None:
535 | xylim = (-3, 3)
536 | points = norm.cdf(np.linspace(xylim[0], xylim[1], grid_size))
537 | g = np.meshgrid(points, points)
538 | points = norm.ppf(g[0][0])
539 | adj = np.outer(norm.pdf(points), norm.pdf(points))
540 | levels = [0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5]
541 | xlabel, ylabel = "z1", "z2"
542 |
543 | # * evaluate on grid
544 | g_tensor = torch.from_numpy(np.stack(g, axis=-1).reshape(-1, 2)).to(
545 | device=self.device, dtype=self.dtype
546 | )
547 | vals = self.pdf(g_tensor).cpu().numpy()
548 | cop = np.reshape(vals, (grid_size, grid_size))
549 |
550 | # * adjust for margins
551 | dens = cop * adj
552 | if len(np.unique(dens)) == 1:
553 | dens[0] = 1.000001 * dens[0]
554 | if margin_type == "unif":
555 | zlim = (0, max(3, 1.1 * max(dens.ravel())))
556 | elif margin_type == "norm":
557 | zlim = (0, max(0.4, 1.1 * max(dens.ravel())))
558 |
559 | # * create a colormap
560 | jet_colors = LinearSegmentedColormap.from_list(
561 | name="jet_colors",
562 | colors=[
563 | "#00007F",
564 | "blue",
565 | "#007FFF",
566 | "cyan",
567 | "#7FFF7F",
568 | "yellow",
569 | "#FF7F00",
570 | "red",
571 | "#7F0000",
572 | ],
573 | N=100,
574 | )
575 |
576 | # * plot
577 | if plot_type == "contour":
578 | fig, ax = plt.subplots()
579 | contour = ax.contour(points, points, dens, levels=levels, cmap="gray")
580 | ax.clabel(contour, inline=True, fontsize=8, fmt="%1.2f")
581 | ax.set_aspect("equal")
582 | ax.grid(True)
583 | elif plot_type == "surface":
584 | fig = plt.figure()
585 | ax = cast(Axes3D, fig.add_subplot(111, projection="3d"))
586 | ax.view_init(elev=30, azim=-110)
587 | X, Y = np.meshgrid(points, points)
588 | ax.plot_surface(X, Y, dens, cmap=jet_colors, edgecolor="none", shade=False)
589 | ax.set_zlim(zlim)
590 | ax.set_box_aspect([1, 1, 1])
591 | ax.xaxis.pane.fill = False
592 | ax.yaxis.pane.fill = False
593 | ax.zaxis.pane.fill = False
594 | ax.grid(False)
595 | ax.set_xlabel(xlabel)
596 | ax.set_ylabel(ylabel)
597 | ax.set_xlim(xylim)
598 | ax.set_ylim(xylim)
599 | fig.tight_layout()
600 | plt.draw_if_interactive()
601 | return fig, ax
602 |
--------------------------------------------------------------------------------