├── examples ├── vcae │ ├── vcae │ │ ├── __init__.py │ │ ├── config.py │ │ ├── metrics.py │ │ ├── experiment.py │ │ └── model.py │ ├── run_seeds.sh │ ├── run_seeds.py │ ├── vcae_analyze.ipynb │ └── vcae_debug.ipynb ├── benchmark │ ├── readme.md │ └── __init__.py └── pred_intvl │ ├── readme.md │ ├── vine_wrapper.py │ ├── experiments.py │ ├── data_utils.py │ └── models.py ├── docs ├── Makefile ├── conf.py └── index.rst ├── torchvinecopulib ├── __init__.py ├── util │ └── __init__.py └── bicop │ └── __init__.py ├── tests ├── test_package_metadata.py ├── __init__.py ├── test_util.py ├── test_vinecop.py └── test_bicop.py ├── .github └── workflows │ ├── static.yml │ ├── python-package.yml │ └── publish-package.yml ├── pyproject.toml ├── .gitignore ├── README.md └── LICENSE /examples/vcae/vcae/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/benchmark/readme.md: -------------------------------------------------------------------------------- 1 | ## Specify project root 2 | 3 | Put a `.env` file in project root folder to specify the project root folder location as an env var: 4 | 5 | ```bash 6 | DIR_WORK='~/path/to/project/root/folder' 7 | ``` 8 | 9 | ## Environment setup 10 | 11 | In project root folder run in bash: 12 | 13 | ```bash 14 | # cpu only 15 | uv sync --extra cpu -U 16 | # or with cuda 17 | uv sync --extra cu126 -U 18 | ``` 19 | 20 | ## Run 21 | 22 | `cd` into project root folder, trigger by: 23 | 24 | ```bash 25 | uv run ./examples/benchmark/__init__.py 26 | ``` 27 | -------------------------------------------------------------------------------- /examples/pred_intvl/readme.md: -------------------------------------------------------------------------------- 1 | ## Specify project root 2 | 3 | Put a `.env` file in project root folder to specify the project root folder location as an env var: 4 | 5 | ```bash 6 | DIR_WORK='~/path/to/project/root/folder' 7 | ``` 8 | 9 | ## Environment setup 10 | 11 | In project root folder run in bash: 12 | 13 | ```bash 14 | # cpu only 15 | uv sync --extra cpu -U 16 | # or with cuda 17 | uv sync --extra cu126 -U 18 | ``` 19 | 20 | ## Run 21 | 22 | `cd` into project root folder, trigger by: 23 | 24 | ```bash 25 | uv run ./examples/pred_intvl/experiments.py 26 | ``` 27 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | clean: 18 | @$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O); 19 | 20 | # Catch-all target: route all unknown targets to Sphinx using the new 21 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
22 | %: Makefile 23 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 24 | -------------------------------------------------------------------------------- /torchvinecopulib/__init__.py: -------------------------------------------------------------------------------- 1 | from . import util 2 | from .bicop import BiCop 3 | from .vinecop import VineCop 4 | 5 | __all__ = [ 6 | "BiCop", 7 | "VineCop", 8 | "util", 9 | ] 10 | # dynamically grab the version you just built & installed 11 | try: 12 | from importlib.metadata import version, PackageNotFoundError 13 | except ImportError: 14 | # Python <3.8 fallback 15 | from pkg_resources import ( 16 | get_distribution as version, 17 | DistributionNotFound as PackageNotFoundError, 18 | ) 19 | 20 | try: 21 | __version__ = version(__name__) 22 | except PackageNotFoundError: 23 | # this can happen if you run from a source checkout 24 | __version__ = "0+unknown" 25 | 26 | __title__ = "torchvinecopulib" # the canonical project name 27 | __author__ = "Tuoyuan Cheng" 28 | __url__ = "https://github.com/TY-Cheng/torchvinecopulib" # the project homepage 29 | __description__ = "Fitting and sampling vine copulas using PyTorch." 30 | -------------------------------------------------------------------------------- /tests/test_package_metadata.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | import torchvinecopulib as tvc 3 | 4 | 5 | def test___all___imports_everything(): 6 | # make sure every name in __all__ actually exists on the package 7 | for name in tvc.__all__: 8 | assert hasattr(tvc, name), f"{name!r} is missing from torchvinecopulib" 9 | 10 | 11 | def test_version_matches_distribution(): 12 | # this will throw if the package isn’t actually installed under that name 13 | dist_version = importlib.metadata.version("torchvinecopulib") 14 | assert tvc.__version__ == dist_version 15 | 16 | 17 | def test_metadata_fields(): 18 | # simple sanity‐checks of your metadata dunders 19 | assert isinstance(tvc.__version__, str) and len(tvc.__version__) > 0 20 | assert tvc.__title__ == "torchvinecopulib" 21 | assert isinstance(tvc.__author__, str) and len(tvc.__author__) > 10 22 | assert tvc.__url__.startswith("https://github.com/TY-Cheng/torchvinecopulib") 23 | assert isinstance(tvc.__description__, str) and len(tvc.__description__) > 10 24 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | from sphinx_pyproject import SphinxConfig 5 | 6 | sys.path.append(".") 7 | sys.path.insert(0, str(Path(__file__).parents[1])) 8 | # * load the pyproject.toml file using SphinxConfig 9 | # * using Path for better cross-platform compatibility 10 | try: 11 | config = SphinxConfig() 12 | except FileNotFoundError as err: 13 | raise FileNotFoundError("pyproject.toml not found") from err 14 | 15 | # * project metadata 16 | project = config.name 17 | author = config.author 18 | maintainer = config.get("maintainer", author) 19 | copyright = config.get("copyright", f"2024-, {author}") 20 | version = release = config.version 21 | documentation_summary = config.description 22 | extensions = config.get("extensions", []) 23 | html_theme = config.get("html_theme", "furo") 24 | html_title = f"{project} v{version}" 25 | html_theme_options = { 26 | "sidebar_hide_name": False, 27 | # "light_logo": "../torchvinecopulib.png", 28 | # 
"dark_logo": "../torchvinecopulib.png", 29 | # "sticky_navigation": True, 30 | # "navigation_with_keys": True, 31 | # "footer_text": f"© {copyright}", 32 | # "navigation_depth": 4, 33 | # "titles_only": False, 34 | } 35 | autosummary_generate = True 36 | -------------------------------------------------------------------------------- /examples/vcae/vcae/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | 6 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 7 | torch.set_float32_matmul_precision("medium") 8 | 9 | 10 | @dataclass 11 | class Config: 12 | # Reproducibility 13 | seed: int = 42 14 | 15 | # Training-related 16 | data_dir: str = os.environ.get("PATH_DATASETS", ".") 17 | save_dir: str = "logs/" 18 | batch_size: int = 512 if torch.cuda.is_available() else 64 19 | max_epochs: int = 10 20 | accelerator: str = DEVICE 21 | devices: int = 1 22 | num_workers: int = 1 # or min(15, os.cpu_count()) 23 | 24 | # Data-related 25 | dims: tuple[int, ...] = (1, 28, 28) 26 | val_train_split: float = 0.1 27 | 28 | # Model-related 29 | hidden_size: int = 64 30 | latent_size: int = 10 31 | learning_rate: float = 2e-4 32 | vine_lambda: float = 0.0 33 | # use_mmd: bool = False 34 | # mmd_sigmas: list[float] = [1e-1, 1, 10] 35 | # mmd_lambda: float = 10.0 36 | 37 | config_mnist = Config( 38 | max_epochs=10, 39 | dims=(1, 28, 28), 40 | hidden_size=64, 41 | latent_size=10, 42 | ) 43 | 44 | config_svhn = Config( 45 | max_epochs=50, 46 | dims=(3, 32, 32), 47 | hidden_size=128, 48 | latent_size=32, 49 | ) 50 | -------------------------------------------------------------------------------- /examples/vcae/run_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PYTHON_BIN=$(which python) 4 | USE_NOHUP=false 5 | 6 | # Defaults 7 | START=0 8 | END=30 9 | STEP=10 10 | 11 | # Parse arguments 12 | POSITIONAL=() 13 | while [[ $# -gt 0 ]]; do 14 | case "$1" in 15 | --nohup) 16 | USE_NOHUP=true 17 | shift 18 | ;; 19 | *) 20 | POSITIONAL+=("$1") 21 | shift 22 | ;; 23 | esac 24 | done 25 | 26 | # Restore positional args 27 | set -- "${POSITIONAL[@]}" 28 | 29 | # Assign range values if provided 30 | if [[ $# -ge 1 ]]; then START=$1; fi 31 | if [[ $# -ge 2 ]]; then END=$2; fi 32 | if [[ $# -ge 3 ]]; then STEP=$3; fi 33 | 34 | # Validate input 35 | if (( STEP <= 0 )); then 36 | echo "Error: STEP must be a positive integer." >&2 37 | exit 1 38 | fi 39 | 40 | if (( (END - START) % STEP != 0 )); then 41 | echo "Error: (END - START) must be divisible by STEP." 
>&2 42 | exit 1 43 | fi 44 | 45 | echo "Using Python binary: $PYTHON_BIN" 46 | echo "Using nohup: $USE_NOHUP" 47 | echo "Range: $START to $END with step $STEP" 48 | 49 | # Launch loop 50 | for ((i = START; i < END; i += STEP)); do 51 | j=$((i + STEP)) 52 | name="seeds_${i}_${j}" 53 | echo "Launching $name" 54 | if $USE_NOHUP; then 55 | nohup "$PYTHON_BIN" run_seeds.py $i $j > logs/$name.log 2>&1 & 56 | else 57 | "$PYTHON_BIN" run_seeds.py $i $j > logs/$name.log 2>&1 & 58 | fi 59 | done 60 | 61 | wait 62 | -------------------------------------------------------------------------------- /examples/vcae/run_seeds.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import logging 3 | import os 4 | import sys 5 | from typing import Union 6 | 7 | import pandas as pd 8 | from tqdm import tqdm 9 | from vcae.config import config_mnist, config_svhn 10 | from vcae.experiment import run_experiment 11 | 12 | dataset = "MNIST" # or "SVHN" 13 | start = int(sys.argv[1]) 14 | end = int(sys.argv[2]) 15 | 16 | if dataset == "MNIST": 17 | config = config_mnist 18 | elif dataset == "SVHN": 19 | config = config_svhn 20 | else: 21 | raise ValueError(f"Unsupported dataset: {dataset}") 22 | 23 | 24 | # Redirect tqdm and errors to log file 25 | log_path = f"progress_{dataset}_{start}_{end}.log" 26 | log_file = open(log_path, "w") 27 | 28 | logging.basicConfig( 29 | level=logging.INFO, 30 | format="%(asctime)s [%(levelname)s] %(message)s", 31 | handlers=[logging.StreamHandler(log_file)], 32 | ) 33 | 34 | 35 | @contextlib.contextmanager 36 | def suppress_output(): 37 | with contextlib.redirect_stdout(log_file), contextlib.redirect_stderr(log_file): 38 | logging_level = logging.getLogger().level 39 | logging.getLogger().setLevel(logging.ERROR) 40 | try: 41 | yield 42 | finally: 43 | logging.getLogger().setLevel(logging_level) 44 | 45 | 46 | results: list[dict[str, Union[float, int, str]]] = [] 47 | output_path = f"results_{dataset}_{start}_{end}.csv" 48 | for seed in tqdm(range(start, end), desc=f"Seeds {start}-{end}", file=log_file): 49 | try: 50 | with suppress_output(): 51 | result = run_experiment(seed, config, dataset=dataset) 52 | df = pd.DataFrame([result]) 53 | 54 | # Write headers only once 55 | if not os.path.exists(output_path): 56 | df.to_csv(output_path, index=False, mode="w") 57 | else: 58 | df.to_csv(output_path, index=False, mode="a", header=False) 59 | 60 | except Exception: 61 | logging.exception(f"Exception while running seed {seed}") 62 | 63 | logging.info(f"All done. 
Results saved to {output_path}") 64 | log_file.close() 65 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # conftest.py 2 | import numpy as np 3 | import pytest 4 | import pyvinecopulib as pvc 5 | import torch 6 | 7 | import torchvinecopulib as tvc 8 | 9 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | N_SIM = 2000 11 | SEEDS = list(range(5)) 12 | EPS = tvc.util._EPS 13 | # List the (family, true-parameter) pairs you want to test 14 | FAMILIES = [ 15 | (pvc.gaussian, np.array([[0.7]]), 0), # rotation 0 16 | (pvc.clayton, np.array([[0.9]]), 0), 17 | (pvc.clayton, np.array([[0.9]]), 90), 18 | (pvc.clayton, np.array([[0.9]]), 180), 19 | (pvc.clayton, np.array([[0.9]]), 270), 20 | (pvc.frank, np.array([[3.0]]), 0), 21 | (pvc.frank, np.array([[-3.0]]), 0), 22 | # …add more if you like… 23 | ] 24 | 25 | 26 | @pytest.fixture(scope="module", params=FAMILIES, ids=lambda f: f[0].name) 27 | def bicop_pair(request): 28 | """ 29 | Returns a tuple: 30 | ( family, true_params, U_tensor, bicop_fastkde, bicop_tll ) 31 | 32 | notice the scope="module" so that the fixture is created only once and reused in all tests that use it. 33 | """ 34 | family, true_params, rotation = request.param 35 | 36 | # 1) build the 'true' copula and simulate U 37 | true_bc = pvc.Bicop(family=family, parameters=true_params, rotation=rotation) 38 | U = true_bc.simulate(n=N_SIM, seeds=SEEDS) # shape (N_SIM, 2) 39 | U_tensor = torch.tensor(U, device=DEVICE, dtype=torch.float64) 40 | 41 | # 2) fit two torchvinecopulib instances (fast KDE and TLL) 42 | bc_fast = tvc.BiCop(num_step_grid=512).to(DEVICE) 43 | bc_fast.fit(U_tensor, mtd_kde="fastKDE") 44 | 45 | bc_tll = tvc.BiCop(num_step_grid=512).to(DEVICE) 46 | bc_tll.fit(U_tensor, mtd_kde="tll") 47 | 48 | return family, true_params, rotation, U_tensor, bc_fast, bc_tll 49 | 50 | 51 | @pytest.fixture(scope="module") 52 | def U_tensor(): 53 | # a moderately‐sized random [0,1]² sample 54 | return torch.rand(500, 2, dtype=torch.float64) 55 | 56 | 57 | @pytest.fixture(scope="module") 58 | def sample_1d(): 59 | torch.manual_seed(0) 60 | return torch.randn(1024, 1) # standard normal 61 | -------------------------------------------------------------------------------- /.github/workflows/static.yml: -------------------------------------------------------------------------------- 1 | # This workflow builds and deploys the documentation to GitHub Pages 2 | name: Deploy Docs 3 | 4 | on: 5 | # Trigger after your Lint Pytest workflow completes on main branch 6 | workflow_run: 7 | workflows: ["Lint Pytest"] 8 | types: [completed] 9 | branches: ["main"] 10 | 11 | # ALso allow manual dispatch 12 | workflow_dispatch: 13 | 14 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 15 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 
16 | concurrency: 17 | group: "pages" 18 | cancel-in-progress: false 19 | 20 | jobs: 21 | deploy: 22 | if: ${{github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success' }} 23 | runs-on: ubuntu-latest 24 | permissions: 25 | contents: write # for pushing to gh-pages branch 26 | 27 | steps: 28 | - name: Checkout code 29 | uses: actions/checkout@v4 30 | 31 | - name: Set up Python 3.11 32 | uses: actions/setup-python@v5 33 | with: 34 | python-version: "3.11" 35 | 36 | - name: Cache pip 37 | uses: actions/cache@v4 38 | with: 39 | path: ~/.cache/pip 40 | key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }} 41 | 42 | - name: Install uv and sync deps 43 | run: | 44 | python3 -m pip install --upgrade pip 45 | python3 -m pip install uv 46 | python3 -m uv sync --extra cpu 47 | 48 | - name: Build Sphinx docs 49 | # regenerate module RST files (force) and build HTML docs 50 | run: | 51 | uv run sphinx-apidoc --force -o docs torchvinecopulib/ --separate 52 | uv run sphinx-build -b html docs/ docs/_build/html 53 | 54 | # Deploy to gh-pages branch 55 | 56 | - name: Deploy to gh-pages branch 57 | uses: peaceiris/actions-gh-pages@v4 58 | with: 59 | force_orphan: true 60 | github_token: ${{ secrets.GITHUB_TOKEN }} 61 | publish_branch: gh-pages 62 | publish_dir: docs/_build/html 63 | -------------------------------------------------------------------------------- /examples/pred_intvl/vine_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch._dynamo as dynamo 3 | 4 | import torchvinecopulib as tvc 5 | 6 | 7 | def train_vine(Z_train, Y_train, seed=42, device="cpu"): 8 | # * build vine cop on Z and Y together 9 | ZY_train = torch.cat([Z_train, Y_train.view(-1, 1)], dim=1) 10 | # * all except last Y column, for MST in first stage 11 | first_tree_vertex = tuple(range(Z_train.shape[1])) 12 | model_vine: tvc.VineCop = tvc.VineCop(num_dim=ZY_train.shape[1], is_cop_scale=False).to( 13 | device=device 14 | ) 15 | model_vine.fit( 16 | obs=ZY_train, 17 | first_tree_vertex=first_tree_vertex, 18 | mtd_bidep="ferreira_tail_dep_coeff", 19 | mtd_kde="tll", 20 | mtd_tll="quadratic", 21 | seed=seed, 22 | ) 23 | return model_vine 24 | 25 | 26 | @dynamo.disable 27 | @torch.no_grad() 28 | def vine_pred_intvl(model: tvc.VineCop, Z_test, alpha=0.05, seed=42, device="cpu"): 29 | # * assuming Zy is fitted by model; y is the last column 30 | num_sample = Z_test.shape[0] 31 | idx_quantile = Z_test.shape[1] 32 | sample_order = (idx_quantile,) 33 | # * fill marginal obs (will be handled by marginal if not is_cop_scale) 34 | dct_v_s_obs = {(_,): Z_test[:, [_]] for _ in range(Z_test.shape[1])} 35 | # * fill quantile deep in the vine (assuming cop scale) 36 | dct_v_s_obs[model.sample_order] = torch.full((num_sample, 1), alpha / 2, device=device) 37 | lower = model.sample( 38 | num_sample=num_sample, 39 | sample_order=sample_order, 40 | dct_v_s_obs=dct_v_s_obs, 41 | seed=seed, 42 | )[:, idx_quantile] 43 | dct_v_s_obs[model.sample_order] = torch.full((num_sample, 1), 0.5, device=device) 44 | median = model.sample( 45 | num_sample=num_sample, 46 | sample_order=sample_order, 47 | dct_v_s_obs=dct_v_s_obs, 48 | seed=seed, 49 | )[:, idx_quantile] 50 | dct_v_s_obs[model.sample_order] = torch.full((num_sample, 1), 1 - alpha / 2, device=device) 51 | upper = model.sample( 52 | num_sample=num_sample, 53 | sample_order=sample_order, 54 | dct_v_s_obs=dct_v_s_obs, 55 | seed=seed, 56 | )[:, idx_quantile] 57 | return median.cpu().numpy(), 
lower.cpu().numpy(), upper.cpu().numpy() 58 | -------------------------------------------------------------------------------- /examples/vcae/vcae/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy import linalg 4 | 5 | 6 | def mmd(real: torch.Tensor, fake: torch.Tensor, sigmas=[1e-3, 1e-2, 1e-1, 1, 10, 100]): 7 | """ 8 | Differentiable MMD loss using Gaussian kernels with fixed sigmas and 9 | distance normalization via Mxx.mean(). 10 | 11 | Parameters 12 | ---------- 13 | real : (n, d) tensor 14 | Batch of real samples (features or images). 15 | fake : (m, d) tensor 16 | Batch of generated samples. 17 | sigmas : list of float 18 | Bandwidths for the RBF kernel. Defaults to wide, fixed list. 19 | 20 | Returns 21 | ------- 22 | mmd : scalar tensor 23 | Differentiable scalar loss value. 24 | """ 25 | real = real.view(real.size(0), -1) 26 | fake = fake.view(fake.size(0), -1) 27 | 28 | def pairwise_squared_distances(x, y): 29 | x_norm = (x**2).sum(dim=1, keepdim=True) 30 | y_norm = (y**2).sum(dim=1, keepdim=True) 31 | return x_norm + y_norm.T - 2.0 * x @ y.T 32 | 33 | Mxx = pairwise_squared_distances(real, real) 34 | Mxy = pairwise_squared_distances(real, fake) 35 | Myy = pairwise_squared_distances(fake, fake) 36 | 37 | # Normalization factor based on real-real distances 38 | scale = Mxx.mean().detach() 39 | 40 | mmd_total = 0.0 41 | for sigma in sigmas: 42 | denom = scale * 2.0 * sigma**2 43 | Kxx = torch.exp(-Mxx / denom) 44 | Kxy = torch.exp(-Mxy / denom) 45 | Kyy = torch.exp(-Myy / denom) 46 | 47 | mmd_total += Kxx.mean() + Kyy.mean() - 2.0 * Kxy.mean() 48 | 49 | return mmd_total / len(sigmas) 50 | 51 | 52 | def fid(X, Y): 53 | m = X.mean(0) 54 | m_w = Y.mean(0) 55 | X_np = X.numpy() 56 | Y_np = Y.numpy() 57 | 58 | C = np.cov(X_np.transpose()) 59 | C_w = np.cov(Y_np.transpose()) 60 | C_C_w_sqrt = linalg.sqrtm(C.dot(C_w), True).real 61 | 62 | score = m.dot(m) + m_w.dot(m_w) - 2 * m_w.dot(m) + np.trace(C + C_w - 2 * C_C_w_sqrt) 63 | return np.sqrt(score) 64 | 65 | 66 | class Score: 67 | mmd = 0 68 | fid = 0 69 | 70 | 71 | def compute_score(real, fake, sigmas=[1e-3, 1e-2, 1e-1, 1, 10, 100]): 72 | real = real.to("cpu") 73 | fake = fake.to("cpu") 74 | 75 | s = Score() 76 | s.mmd = np.sqrt(mmd(real, fake, sigmas).numpy()) 77 | s.fid = fid(fake, real).numpy() 78 | 79 | return s 80 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Lint Pytest 5 | 6 | on: 7 | push: 8 | branches: ["main"] 9 | pull_request: 10 | branches: ["main"] 11 | 12 | workflow_dispatch: 13 | 14 | jobs: 15 | lint-pytest: 16 | runs-on: ${{ matrix.os }} 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | os: [ubuntu-latest, windows-latest, macos-latest] 21 | python-version: ["3.11", "3.12", "3.13"] 22 | steps: 23 | - name: Checkout code 24 | uses: actions/checkout@v4 25 | 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | 31 | - name: Cache pip 32 | uses: actions/cache@v4 33 | with: 34 | path: ~/.cache/pip 35 | key: ${{ runner.os 
}}-pip-${{ hashFiles('**/pyproject.toml') }} 36 | 37 | - name: Install uv and sync deps 38 | run: | 39 | python3 -m pip install --upgrade pip 40 | python3 -m pip install uv 41 | python3 -m uv sync --extra cpu 42 | # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 43 | 44 | - name: Lint with flake8 45 | run: | 46 | # stop the build if there are Python syntax errors or undefined names 47 | uv run flake8 ./torchvinecopulib --exclude .venv --count --select=E9,F63,F7,F82 --show-source --statistics 48 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 49 | uv run flake8 ./torchvinecopulib --count --exit-zero --max-complexity=10 --max-line-length=127 --ignore=E203,W503,E501,C901 --statistics 50 | 51 | - name: Test with pytest 52 | run: | 53 | uv run python -c "import sys; print(sys.version)" 54 | uv run python -c "import torch; print(torch.cuda.is_available())" 55 | uv run coverage run --source=torchvinecopulib -m pytest tests 56 | uv run coverage report 57 | uv run coverage xml -o coverage.xml 58 | 59 | - name: Upload coverage to Codacy 60 | if: ${{ matrix.os == 'ubuntu-latest' && matrix.python-version == '3.13' }} 61 | run: | 62 | export CODACY_API_BASE_URL=${{ secrets.CODACY_API_BASE_URL }} 63 | export CODACY_PROJECT_TOKEN=${{ secrets.CODACY_PROJECT_TOKEN }} 64 | bash <(curl -Ls https://coverage.codacy.com/get.sh) report -r ./coverage.xml 65 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Put a project logo here if you have one 2 | .. .. image:: _static/logo.png 3 | .. :alt: torchvinecopulib Logo 4 | 5 | ============================================ 6 | Welcome to torchvinecopulib's documentation! 7 | ============================================ 8 | 9 | .. Add badges here 10 | .. raw:: html 11 | 12 |

13 | Codacy Grade 14 | Codacy Coverage 15 | Lint Pytest 16 | Deploy Docs 17 |
18 | PyPI - Python Version 19 | OS Compatibility 20 |
21 | GitHub License 22 | PyPI - Version 23 | DOI 24 |

25 | 26 | ``torchvinecopulib`` is a ``Python`` library for fitting and sampling vine copulas using ``PyTorch``. 27 | It is designed for researchers and practitioners in statistics, machine learning, and finance who need flexible, GPU-accelerated, and vectorized copula modeling and sampling. 28 | 29 | **GitHub Repository:** https://github.com/TY-Cheng/torchvinecopulib 30 | 31 | .. toctree:: 32 | :maxdepth: 3 33 | :caption: Core Modules 34 | :titlesonly: 35 | 36 | modules.rst 37 | 38 | 39 | Indices and tables 40 | ================== 41 | 42 | * :ref:`genindex` 43 | * :ref:`modindex` 44 | * :ref:`search` -------------------------------------------------------------------------------- /examples/vcae/vcae/experiment.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Union 3 | 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import torch 7 | 8 | from .config import DEVICE, Config 9 | from .metrics import compute_score 10 | from .model import LitAutoencoder, LitMNISTAutoencoder, LitSVHNAutoencoder 11 | 12 | 13 | def set_seed(seed: int): 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | torch.backends.cudnn.benchmark = True 19 | torch.backends.cudnn.enabled = True 20 | torch.backends.cudnn.deterministic = True 21 | 22 | 23 | def run_experiment( 24 | seed: int, config: Config, vine_lambda: float = 1.0, dataset: str = "MNIST" 25 | ) -> dict[str, Union[float, int, str]]: 26 | # Set the seed for reproducibility 27 | set_seed(seed) 28 | config.seed = seed 29 | 30 | # Instantiate the model 31 | model_initial: LitAutoencoder 32 | if dataset == "MNIST": 33 | model_initial = LitMNISTAutoencoder(config) 34 | elif dataset == "SVHN": 35 | model_initial = LitSVHNAutoencoder(config) 36 | else: 37 | raise ValueError(f"Unsupported dataset: {dataset}") 38 | 39 | # Set up trainer 40 | trainer_initial = pl.Trainer( 41 | accelerator=config.accelerator, 42 | devices=config.devices, 43 | max_epochs=config.max_epochs, 44 | logger=False, # disables all loggers 45 | enable_progress_bar=False, # disables tqdm 46 | enable_model_summary=False, # disables model summary printout 47 | ) 48 | 49 | # Train the base autoencoder 50 | trainer_initial.fit(model_initial) 51 | 52 | # Stay on DEVICE 53 | model_initial.to(DEVICE) 54 | 55 | # Learn vine 56 | model_initial.learn_vine(n_samples=5000) 57 | 58 | # Extract test data 59 | rep_initial, _, data_initial, decoded_initial, samples_initial = model_initial.get_data( 60 | stage="test" 61 | ) 62 | 63 | # Reset the seed for refitting to avoid data leakage 64 | set_seed(seed) 65 | 66 | # Create a new model with the same configuration but reset vine lambda 67 | config.vine_lambda = vine_lambda 68 | model_refit = model_initial.copy_with_config(config) 69 | 70 | # Set up trainer for refitting 71 | trainer_refit = pl.Trainer( 72 | accelerator=config.accelerator, 73 | devices=config.devices, 74 | max_epochs=config.max_epochs, 75 | logger=False, # disables all loggers 76 | enable_progress_bar=False, # disables tqdm 77 | enable_model_summary=False, # disables model summary printout 78 | ) 79 | 80 | # Refit the model 81 | trainer_refit.fit(model_refit) 82 | 83 | # Stay on DEVICE 84 | model_refit.to(DEVICE) 85 | 86 | # Extract test data 87 | rep_refit, _, data_refit, decoded_refit, samples_refit = model_refit.get_data(stage="test") 88 | 89 | assert model_initial.vine is not None 90 | assert model_refit.vine is not None 91 | loglik_initial = 
model_initial.vine.log_pdf(rep_initial).mean().item() 92 | loglik_refit = model_refit.vine.log_pdf(rep_refit).mean().item() 93 | 94 | mse_initial = torch.nn.functional.mse_loss(decoded_initial, data_initial).item() 95 | mse_refit = torch.nn.functional.mse_loss(decoded_refit, data_refit).item() 96 | 97 | sigmas = [1e-3, 1e-2, 1e-1, 1, 10, 100] 98 | score_initial = compute_score(data_initial, samples_initial, sigmas=sigmas) 99 | score_refit = compute_score(data_refit, samples_refit, sigmas=sigmas) 100 | 101 | return { 102 | "seed": seed, 103 | "dataset": dataset, 104 | "mse_initial": mse_initial, 105 | "mse_refit": mse_refit, 106 | "loglik_initial": loglik_initial, 107 | "loglik_refit": loglik_refit, 108 | "mmd_initial": score_initial.mmd, 109 | "mmd_refit": score_refit.mmd, 110 | "fid_initial": score_initial.fid, 111 | "fid_refit": score_refit.fid, 112 | } 113 | -------------------------------------------------------------------------------- /.github/workflows/publish-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow builds a Python package and publishes it to PyPI and TestPyPI. 2 | name: Publish to PyPI and TestPyPI 3 | 4 | on: 5 | push: 6 | branches: ["main"] 7 | tags: 8 | - "v*.*.*" # Match semantic version tags 9 | workflow_dispatch: 10 | # Allow manual dispatch for testing or immediate releases 11 | 12 | jobs: 13 | build: 14 | name: Build distributions 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@v4 19 | 20 | - name: Set up Python 3.11 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: 3.11 24 | 25 | - name: Cache pip 26 | uses: actions/cache@v4 27 | with: 28 | path: ~/.cache/pip 29 | key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }} 30 | 31 | - name: Install uv and sync deps, build sdist and wheel 32 | run: | 33 | python3 -m pip install --upgrade pip 34 | python3 -m pip install uv 35 | python3 -m uv sync --extra cpu 36 | python3 -m uv build 37 | 38 | - name: Upload dist/* 39 | uses: actions/upload-artifact@v4 40 | with: 41 | name: python-distribution-packages 42 | path: dist/* 43 | 44 | publish-testpypi: 45 | name: Publish to TestPyPI 46 | needs: build 47 | runs-on: ubuntu-latest 48 | environment: 49 | name: testpypi 50 | url: https://test.pypi.org/p/torchvinecopulib/ 51 | permissions: 52 | id-token: write 53 | # contents: write 54 | steps: 55 | - name: Download dist/* 56 | uses: actions/download-artifact@v4 57 | with: 58 | name: python-distribution-packages 59 | path: dist/ 60 | 61 | - name: Publish to TestPyPI 62 | uses: pypa/gh-action-pypi-publish@release/v1 63 | with: 64 | verbose: true 65 | skip-existing: true 66 | repository-url: https://test.pypi.org/legacy/ 67 | env: 68 | TWINE_USERNAME: __token__ 69 | TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} 70 | 71 | publish-pypi: 72 | name: Publish to PyPI 73 | runs-on: ubuntu-latest 74 | needs: publish-testpypi 75 | if: startsWith(github.ref, 'refs/tags/v') 76 | environment: 77 | name: pypi 78 | url: https://pypi.org/project/torchvinecopulib/ 79 | permissions: 80 | id-token: write 81 | # contents: write 82 | steps: 83 | - name: Download dist/* 84 | uses: actions/download-artifact@v4 85 | with: 86 | name: python-distribution-packages 87 | path: dist/ 88 | 89 | - name: Publish to PyPI via Twine 90 | uses: pypa/gh-action-pypi-publish@release/v1 91 | env: 92 | TWINE_USERNAME: __token__ 93 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 94 | with: 95 | verbose: true 96 | skip-existing: true 97 
| 98 | github-release: 99 | name: Create GitHub Release 100 | runs-on: ubuntu-latest 101 | needs: publish-testpypi 102 | if: startsWith(github.ref, 'refs/tags/v') 103 | permissions: 104 | contents: write 105 | id-token: write 106 | 107 | steps: 108 | - name: Download dist/* 109 | uses: actions/download-artifact@v4 110 | with: 111 | name: python-distribution-packages 112 | path: dist/ 113 | 114 | - name: Sign the distribution packages with Sigstore 115 | uses: sigstore/gh-action-sigstore-python@v3.0.0 116 | with: 117 | inputs: | 118 | ./dist/*.tar.gz 119 | ./dist/*.whl 120 | 121 | - name: Create GitHub Release and Upload Assets 122 | env: 123 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 124 | run: | 125 | gh release create "${{ github.ref_name }}" \ 126 | --repo "${{ github.repository }}" \ 127 | --title "Release ${{ github.ref_name }}" \ 128 | --notes "Automated release for ${{ github.ref_name }}. Signed artifacts are attached." \ 129 | dist/* 130 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # List/Specify python version 2 | # uv python list 3 | # uv python pin python3.13 4 | # Install default + dev dependencies with torch for CPU/cu126: 5 | # uv sync --extra cpu -U 6 | # uv sync --extra cu126 -U 7 | 8 | [build-system] 9 | requires = ["hatchling"] 10 | build-backend = "hatchling.build" 11 | 12 | [project.urls] 13 | # homepage = "" 14 | repository = "https://github.com/TY-Cheng/torchvinecopulib" 15 | documentation = "https://ty-cheng.github.io/torchvinecopulib/" 16 | 17 | [project] 18 | name = "torchvinecopulib" 19 | version = "1.1.2" 20 | requires-python = ">=3.11" 21 | authors = [{ name = "Tuoyuan Cheng", email = "tuoyuan.cheng@nus.edu.sg" }] 22 | description = "Fitting and sampling vine copulas using PyTorch." 
23 | readme = "README.md" 24 | keywords = ["vine copula", "copula", "torch", "conditional simulation"] 25 | classifiers = [ 26 | # "Private :: Do Not Upload", 27 | "License :: OSI Approved :: MIT License", 28 | "Environment :: GPU :: NVIDIA CUDA :: 12", 29 | "Operating System :: Microsoft :: Windows :: Windows 10", 30 | "Operating System :: POSIX :: Linux", 31 | "Operating System :: MacOS", 32 | "Programming Language :: Python :: 3.11", 33 | "Programming Language :: Python :: 3.12", 34 | "Programming Language :: Python :: 3.13", 35 | "Topic :: Scientific/Engineering", 36 | ] 37 | dependencies = ["numpy>=2", "scipy", "fastkde==2.1.3", "pyvinecopulib"] 38 | 39 | [dependency-groups] 40 | dev = [ 41 | "blitz-bayesian-pytorch", 42 | "ccxt", 43 | "coverage", 44 | "datasets", 45 | "docformatter", 46 | "flake8", 47 | "furo", 48 | "ipykernel", 49 | "kagglehub", 50 | "matplotlib", 51 | "missingno", 52 | "mypy", 53 | "pandas", 54 | "pytest-cov", 55 | "pytest", 56 | "python-dotenv", 57 | "ruff", 58 | "scikit-learn", 59 | "sphinx_pyproject", 60 | "sphinx", 61 | "tokenize-rt", 62 | "torchvision", 63 | "ucimlrepo", 64 | "yfinance", 65 | ] 66 | 67 | [project.optional-dependencies] 68 | cpu = ["torch>=2"] 69 | cu126 = ["torch>=2"] 70 | cu128 = ["torch>=2"] 71 | examples = ["pytorch-lightning", "tqdm"] 72 | 73 | [tool.uv] 74 | managed = true 75 | default-groups = ["dev"] 76 | conflicts = [[{ extra = "cpu" }, { extra = "cu126" }, { extra = "cu128" }]] 77 | 78 | [tool.uv.sources] 79 | torch = [ 80 | { index = "torch-cpu", extra = "cpu" }, 81 | { index = "torch-cu126", extra = "cu126" }, 82 | { index = "torch-cu128", extra = "cu128" }, 83 | ] 84 | 85 | [[tool.uv.index]] 86 | name = "torch-cpu" 87 | url = "https://download.pytorch.org/whl/cpu" 88 | explicit = true 89 | 90 | [[tool.uv.index]] 91 | name = "torch-cu126" 92 | url = "https://download.pytorch.org/whl/cu126" 93 | explicit = true 94 | 95 | [[tool.uv.index]] 96 | name = "torch-cu128" 97 | url = "https://download.pytorch.org/whl/cu128" 98 | explicit = true 99 | 100 | [[tool.uv.index]] 101 | name = "testpypi" 102 | url = "https://test.pypi.org/simple/" 103 | publish-url = "https://test.pypi.org/legacy/" 104 | explicit = true 105 | 106 | [tool.ruff] 107 | line-length = 99 108 | 109 | [tool.ruff.format] 110 | quote-style = "double" 111 | indent-style = "space" 112 | docstring-code-format = true 113 | docstring-code-line-length = 99 114 | 115 | [tool.ruff.lint.pycodestyle] 116 | max-doc-length = 99 117 | max-line-length = 99 118 | 119 | [tool.docformatter] 120 | wrap-summaries = 99 121 | wrap-descriptions = 99 122 | blank-lines-around-summary = true 123 | in-place = true 124 | make-summary-multi-line = true 125 | pre-summary-space = true 126 | recursive = true 127 | skip-ignore = false 128 | 129 | [tool.flake8] 130 | max-line-length = 127 131 | extend-ignore = ["E203", "W503"] 132 | exclude = [".venv", "build", "dist", "docs"] 133 | per-file-ignores = [ 134 | "torchvinecopulib/vinecop/__init__.py: E501, C901", 135 | "torchvinecopulib/bicop/__init__.py: E501, C901", 136 | "torchvinecopulib/util/__init__.py: E501, C901", 137 | ] 138 | 139 | [tool.sphinx-pyproject] 140 | source-dir = "docs" 141 | build-dir = "docs/_build" 142 | all_files = true 143 | extensions = [ 144 | "sphinx.ext.autodoc", 145 | "sphinx.ext.autosectionlabel", 146 | "sphinx.ext.autosummary", 147 | "sphinx.ext.coverage", 148 | "sphinx.ext.doctest", 149 | "sphinx.ext.extlinks", 150 | "sphinx.ext.githubpages", 151 | "sphinx.ext.graphviz", 152 | "sphinx.ext.ifconfig", 153 | 
"sphinx.ext.imgconverter", 154 | "sphinx.ext.inheritance_diagram", 155 | "sphinx.ext.intersphinx", 156 | "sphinx.ext.napoleon", 157 | "sphinx.ext.todo", 158 | "sphinx.ext.viewcode", 159 | ] 160 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 161 | html_theme = "furo" 162 | -------------------------------------------------------------------------------- /tests/test_util.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from scipy.stats import kendalltau as scipy_tau 4 | 5 | from torchvinecopulib.util import ( 6 | ENUM_FUNC_BIDEP, 7 | chatterjee_xi, 8 | ferreira_tail_dep_coeff, 9 | kdeCDFPPF1D, 10 | kendall_tau, 11 | mutual_info, 12 | solve_ITP, 13 | ) 14 | 15 | from . import EPS, U_tensor, bicop_pair, sample_1d 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "x,y,expected", 20 | [ 21 | ([1, 2, 3], [1, 2, 3], 1.0), 22 | ([1, 2, 3], [3, 2, 1], -1.0), 23 | ], 24 | ) 25 | def test_kendall_tau_perfect(x, y, expected): 26 | x = torch.tensor(x).view(-1, 1).double() 27 | y = torch.tensor(y).view(-1, 1).double() 28 | tau, p = kendall_tau(x, y) 29 | assert pytest.approx(expected, abs=1e-6) == tau.item() 30 | 31 | 32 | def test_kendall_tau_matches_scipy_random(): 33 | torch.manual_seed(0) 34 | x = torch.rand(50, 1) 35 | y = torch.rand(50, 1) 36 | tau_torch, p = kendall_tau(x, y) 37 | tau_scipy, p_scipy = scipy_tau(x.flatten().numpy(), y.flatten().numpy()) 38 | assert pytest.approx(tau_scipy, rel=1e-6) == tau_torch.item() 39 | 40 | 41 | def test_mutual_info_independent(): 42 | torch.manual_seed(0) 43 | mi = mutual_info(*torch.rand(2, 10000)) 44 | assert abs(mi.item()) < 2e-2 # near zero 45 | 46 | 47 | def test_mutual_info_dependent(): 48 | torch.manual_seed(0) 49 | u = torch.rand(500, 1) 50 | v = u.clone() 51 | mi = mutual_info(u, v) 52 | assert mi.item() > 0.5 # significantly positive 53 | 54 | 55 | def test_tail_dep_perfect(): 56 | u = torch.linspace(0, 1, 500).view(-1, 1) 57 | # y=u → perfect 58 | lam = ferreira_tail_dep_coeff(u, u) 59 | assert pytest.approx(1.0, rel=1e-3) == lam.item() 60 | 61 | 62 | def test_tail_dep_independent(): 63 | torch.manual_seed(0) 64 | u = torch.rand(1000, 1) 65 | v = torch.rand(1000, 1) 66 | lam = ferreira_tail_dep_coeff(u, v) 67 | assert lam.item() < 0.2 # near zero 68 | 69 | 70 | def test_xi_perfect(): 71 | u = torch.arange(1, 101).view(-1, 1).double() 72 | xi = chatterjee_xi(u, u) 73 | assert pytest.approx(1.0, abs=3e-2) == xi.item() 74 | 75 | 76 | def test_xi_independent(): 77 | torch.manual_seed(1) 78 | u = torch.rand(1000, 1) 79 | v = torch.rand(1000, 1) 80 | xi = chatterjee_xi(u, v) 81 | assert abs(xi.item()) < 0.1 82 | 83 | 84 | def test_enum_dispatches_correctly(): 85 | u = torch.rand(50, 1) 86 | v = u.clone() 87 | # perfect correlation → tau=1, xi=1, tail=1, mi>0 88 | out_tau = ENUM_FUNC_BIDEP.kendall_tau(u, v) 89 | out_xi = ENUM_FUNC_BIDEP.chatterjee_xi(u, v) 90 | out_tail = ENUM_FUNC_BIDEP.ferreira_tail_dep_coeff(u, v) 91 | out_mi = ENUM_FUNC_BIDEP.mutual_info(u, v) 92 | assert pytest.approx(1.0, abs=3e-2) == out_tau[0].item() 93 | assert pytest.approx(1.0, abs=3e-2) == out_xi.item() 94 | assert pytest.approx(1.0, rel=3e-2) == out_tail.item() 95 | assert out_mi.item() > 0.5 96 | 97 | 98 | def test_kde_cdf_ppf_inverse(sample_1d): 99 | kde = kdeCDFPPF1D(sample_1d, num_step_grid=257) 100 | range = kde.x_max - kde.x_min 101 | xs = torch.linspace( 102 | kde.x_min + range / 10, kde.x_max - range / 10, 50, dtype=torch.float64 103 | ).view(-1, 1) 104 | qs = kde.cdf(xs) 105 | xs_rec = 
kde.ppf(qs) 106 | assert torch.allclose(xs, xs_rec, atol=1e-3) 107 | 108 | 109 | def test_kde_bounds_and_pdf(sample_1d): 110 | kde = kdeCDFPPF1D(sample_1d, num_step_grid=257) 111 | # cdf out-of-bounds 112 | oob = torch.tensor([[kde.x_min - 1.0], [kde.x_max + 1.0]], dtype=torch.float64) 113 | assert torch.all(kde.cdf(oob) == torch.tensor([[0.0], [1.0]], dtype=torch.float64)) 114 | # ppf out-of-bounds 115 | assert torch.allclose( 116 | torch.tensor([[kde.x_min], [kde.x_max]], dtype=torch.float64), 117 | kde.ppf(torch.tensor([[-1.0], [2.0]])), 118 | ) 119 | # pdf ≥ 0 120 | pts = torch.linspace(kde.x_min, kde.x_max, 100).view(-1, 1) 121 | assert (kde.pdf(pts) >= 0).all() 122 | # log_pdf finite 123 | assert torch.isfinite(kde.log_pdf(pts)).all() 124 | 125 | 126 | def test_kde_negloglik_forward(sample_1d): 127 | kde = kdeCDFPPF1D(sample_1d, num_step_grid=None) 128 | val1 = kde.negloglik 129 | val2 = kde.forward(sample_1d) 130 | assert pytest.approx(val1.item(), rel=1e-6) == val2.item() 131 | 132 | 133 | def test_kde_str(sample_1d): 134 | kde = kdeCDFPPF1D(sample_1d, num_step_grid=257) 135 | str_repr = str(kde) 136 | assert "kdeCDFPPF1D" in str_repr 137 | assert "num_step_grid" in str_repr 138 | assert "257" in str_repr 139 | assert "x_min" in str_repr 140 | assert "x_max" in str_repr 141 | 142 | 143 | def test_solve_itp_scalar(): 144 | # f(x)=x-0.3 has root at 0.3 145 | f = lambda x: x - 0.3 146 | root = solve_ITP(f, torch.tensor(0.0), torch.tensor(1.0)) 147 | assert pytest.approx(0.3, abs=1e-6) == root.item() 148 | 149 | 150 | def test_solve_itp_vectorized(): 151 | # two independent eq’s: x-a=0, x-b=0 152 | a = torch.tensor([0.2, 0.7]) 153 | b = torch.tensor([1.0, 1.0]) 154 | f = lambda x: x - a 155 | roots = solve_ITP(f, torch.zeros_like(a), b) 156 | assert torch.allclose(roots, a, atol=1e-6) 157 | -------------------------------------------------------------------------------- /examples/vcae/vcae_analyze.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "8410dbd6", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# add autoreload and other ipython magic\n", 11 | "import glob\n", 12 | "import os\n", 13 | "\n", 14 | "import numpy as np\n", 15 | "import pandas as pd\n", 16 | "from IPython import get_ipython\n", 17 | "\n", 18 | "ipython = get_ipython()\n", 19 | "ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", 20 | "ipython.run_line_magic(\"autoreload\", \"2\")\n", 21 | "ipython.run_line_magic(\"matplotlib\", \"inline\")\n", 22 | "ipython.run_line_magic(\"config\", 'InlineBackend.figure_format = \"retina\"')" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "8ac6e908", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# Replace with your folder path\n", 33 | "folder = os.getcwd()\n", 34 | "\n", 35 | "# List all matching CSV files\n", 36 | "csv_files = glob.glob(os.path.join(folder, \"results_*.csv\"))\n", 37 | "\n", 38 | "# Read and concatenate all into a single DataFrame\n", 39 | "df_all = pd.concat((pd.read_csv(f) for f in csv_files), ignore_index=True)\n", 40 | "\n", 41 | "# Make the data longer for analysis\n", 42 | "df_long = (\n", 43 | " df_all.melt(id_vars=\"seed\", var_name=\"variable\", value_name=\"value\")\n", 44 | " .assign(\n", 45 | " metric=lambda x: x[\"variable\"].str.extract(r\"^([a-zA-Z]+)\"),\n", 46 | " method=lambda x: x[\"variable\"].str.extract(r\"_(.*)$\").fillna(\"\"),\n", 47 
| " )\n", 48 | " .drop(columns=[\"variable\"])\n", 49 | " .pivot(index=[\"seed\", \"method\"], columns=\"metric\", values=\"value\")\n", 50 | " .reset_index()\n", 51 | " .drop(columns=[\"seed\"])\n", 52 | ")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "f854dc62", 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/html": [ 64 | "
\n", 65 | "\n", 78 | "\n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | "
mse  loglik  mmd  fid
initial  0.054 ± 0.001  -17.179 ± 1.131  0.514 ± 0.027  6.259 ± 0.099
refit  0.052 ± 0.001  -8.098 ± 1.199  0.492 ± 0.03  6.62 ± 0.197
\n", 105 | "
" 106 | ], 107 | "text/plain": [ 108 | " mse loglik mmd fid\n", 109 | "initial 0.054 ± 0.001 -17.179 ± 1.131 0.514 ± 0.027 6.259 ± 0.099\n", 110 | "refit 0.052 ± 0.001 -8.098 ± 1.199 0.492 ± 0.03 6.62 ± 0.197" 111 | ] 112 | }, 113 | "execution_count": 61, 114 | "metadata": {}, 115 | "output_type": "execute_result" 116 | } 117 | ], 118 | "source": [ 119 | "def mean_ci(x, digits: int = 3) -> str:\n", 120 | " x = np.asarray(x, dtype=np.float64)\n", 121 | " n = x.size\n", 122 | " if n == 0:\n", 123 | " return \"nan ± nan\"\n", 124 | " mean = np.mean(x)\n", 125 | " se = np.std(x, ddof=1) / np.sqrt(n)\n", 126 | " ci_half_width = 1.96 * se\n", 127 | " return f\"{round(mean, digits)} ± {round(ci_half_width, digits)}\"\n", 128 | "\n", 129 | "\n", 130 | "summary = df_long.groupby(\"method\").agg(mean_ci)\n", 131 | "summary = summary[[\"mse\", \"loglik\", \"mmd\", \"fid\"]]\n", 132 | "summary.columns.name = None\n", 133 | "summary.index.name = None\n", 134 | "summary" 135 | ] 136 | } 137 | ], 138 | "metadata": { 139 | "kernelspec": { 140 | "display_name": ".venv", 141 | "language": "python", 142 | "name": "python3" 143 | }, 144 | "language_info": { 145 | "codemirror_mode": { 146 | "name": "ipython", 147 | "version": 3 148 | }, 149 | "file_extension": ".py", 150 | "mimetype": "text/x-python", 151 | "name": "python", 152 | "nbconvert_exporter": "python", 153 | "pygments_lexer": "ipython3", 154 | "version": "3.13.5" 155 | } 156 | }, 157 | "nbformat": 4, 158 | "nbformat_minor": 5 159 | } 160 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Manual 2 | /scratch 3 | /docs 4 | /Styles 5 | /data 6 | /out 7 | *.pkl 8 | *.svg 9 | checkpoints/ 10 | MNIST/ 11 | logs/ 12 | examples/MNIST 13 | examples/logs 14 | examples/checkpoints 15 | examples/**/*.gz 16 | # examples/**/*.ipynb 17 | 18 | # Python-generated files 19 | __pycache__/ 20 | *.py[oc] 21 | build/ 22 | dist/ 23 | wheels/ 24 | *.egg-info 25 | 26 | # Virtual environments 27 | .venv 28 | 29 | # Byte-compiled / optimized / DLL files 30 | __pycache__/ 31 | *.py[cod] 32 | *$py.class 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | share/python-wheels/ 52 | *.egg-info/ 53 | .installed.cfg 54 | *.egg 55 | MANIFEST 56 | 57 | # PyInstaller 58 | # Usually these files are written by a python script from a template 59 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
60 | *.manifest 61 | *.spec 62 | 63 | # Installer logs 64 | pip-log.txt 65 | pip-delete-this-directory.txt 66 | 67 | # Unit test / coverage reports 68 | htmlcov/ 69 | .tox/ 70 | .nox/ 71 | .coverage 72 | .coverage.* 73 | .cache 74 | nosetests.xml 75 | coverage.xml 76 | *.cover 77 | *.py,cover 78 | .hypothesis/ 79 | .pytest_cache/ 80 | cover/ 81 | 82 | # Translations 83 | *.mo 84 | *.pot 85 | 86 | # Django stuff: 87 | *.log 88 | local_settings.py 89 | db.sqlite3 90 | db.sqlite3-journal 91 | 92 | # Flask stuff: 93 | instance/ 94 | .webassets-cache 95 | 96 | # Scrapy stuff: 97 | .scrapy 98 | 99 | # Sphinx documentation 100 | docs/_build/ 101 | 102 | # PyBuilder 103 | .pybuilder/ 104 | target/ 105 | 106 | # Jupyter Notebook 107 | .ipynb_checkpoints 108 | 109 | # IPython 110 | profile_default/ 111 | ipython_config.py 112 | 113 | # pyenv 114 | # For a library or package, you might want to ignore these files since the code is 115 | # intended to run in multiple environments; otherwise, check them in: 116 | .python-version 117 | 118 | # pipenv 119 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 120 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 121 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 122 | # install all needed dependencies. 123 | #Pipfile.lock 124 | 125 | # UV 126 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 127 | # This is especially recommended for binary packages to ensure reproducibility, and is more 128 | # commonly ignored for libraries. 129 | uv.lock 130 | 131 | # poetry 132 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 133 | # This is especially recommended for binary packages to ensure reproducibility, and is more 134 | # commonly ignored for libraries. 135 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 136 | poetry.lock 137 | 138 | # pdm 139 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 140 | #pdm.lock 141 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 142 | # in version control. 143 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 144 | .pdm.toml 145 | .pdm-python 146 | .pdm-build/ 147 | 148 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 149 | __pypackages__/ 150 | 151 | # Celery stuff 152 | celerybeat-schedule 153 | celerybeat.pid 154 | 155 | # SageMath parsed files 156 | *.sage.py 157 | 158 | # Environments 159 | .env 160 | .venv 161 | env/ 162 | venv/ 163 | ENV/ 164 | env.bak/ 165 | venv.bak/ 166 | 167 | # Spyder project settings 168 | .spyderproject 169 | .spyproject 170 | 171 | # Rope project settings 172 | .ropeproject 173 | 174 | # mkdocs documentation 175 | /site 176 | 177 | # mypy 178 | .mypy_cache/ 179 | .dmypy.json 180 | dmypy.json 181 | 182 | # Pyre type checker 183 | .pyre/ 184 | 185 | # pytype static type analyzer 186 | .pytype/ 187 | 188 | # Cython debug symbols 189 | cython_debug/ 190 | 191 | # PyCharm 192 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 193 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 194 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 195 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 196 | #.idea/ 197 | 198 | # Ruff stuff: 199 | .ruff_cache/ 200 | 201 | # PyPI configuration file 202 | .pypirc 203 | 204 | # General 205 | .DS_Store 206 | .AppleDouble 207 | .LSOverride 208 | 209 | # Icon must end with two \r 210 | Icon 211 | 212 | # Thumbnails 213 | ._* 214 | 215 | # Files that might appear in the root of a volume 216 | .DocumentRevisions-V100 217 | .fseventsd 218 | .Spotlight-V100 219 | .TemporaryItems 220 | .Trashes 221 | .VolumeIcon.icns 222 | .com.apple.timemachine.donotpresent 223 | 224 | # Directories potentially created on remote AFP share 225 | .AppleDB 226 | .AppleDesktop 227 | Network Trash Folder 228 | Temporary Items 229 | .apdisk 230 | 231 | .vscode/* 232 | # !.vscode/settings.json 233 | !.vscode/tasks.json 234 | !.vscode/launch.json 235 | !.vscode/extensions.json 236 | !.vscode/*.code-snippets 237 | 238 | # Local History for Visual Studio Code 239 | .history/ 240 | 241 | # Built Visual Studio Code Extensions 242 | *.vsix 243 | 244 | *.csv 245 | *.mat -------------------------------------------------------------------------------- /examples/pred_intvl/experiments.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | from itertools import product 4 | 5 | import numpy as np 6 | import torch 7 | from data_utils import ( 8 | DIR_WORK, 9 | device, 10 | extract_XY, 11 | get_logger, 12 | load_california_housing, 13 | load_news_popularity, 14 | ) 15 | from models import ( 16 | bnn_pred_intvl, 17 | ensemble_pred_intvl, 18 | extract_features, 19 | mc_dropout_pred_intvl, 20 | train_base_model, 21 | train_bnn, 22 | train_ensemble, 23 | ) 24 | from vine_wrapper import train_vine, vine_pred_intvl 25 | 26 | DATASETS = { 27 | "california": load_california_housing, 28 | "news": load_news_popularity, 29 | } 30 | ALPHA = 0.05 31 | DIR_OUT = DIR_WORK / "examples" / "pred_intvl" / "out" 32 | DIR_OUT.mkdir(parents=True, exist_ok=True) 33 | LATENT_DIM = 10 34 | lst_lr = [1e-2, 1e-3, 1e-4] 35 | lst_weight_decay = [0.0, 0.91, 0.99] 36 | lst_p_drop = [0.1, 0.3, 0.5] 37 | lr = lst_lr[1] 38 | weight_decay = lst_weight_decay[1] 39 | p_drop = lst_p_drop[1] 40 | # * ===== is_test ===== 41 | is_test = False 42 | # * ===== is_test ===== 43 | 44 | if is_test: 45 | DIR_OUT = DIR_OUT / "tmp" 46 | SEEDS = list(range(3)) 47 | N_ENSEMBLE = 4 48 | NUM_EPOCH = 5 49 | else: 50 | SEEDS = list(range(1, 100)) 51 | N_ENSEMBLE = 20 52 | NUM_EPOCH = 150 53 | 54 | 55 | if __name__ == "__main__": 56 | for seed, ds_name in product( 57 | SEEDS, 58 | DATASETS.keys(), 59 | ): 60 | file = DIR_OUT / f"PredIntvl_{ds_name}_{seed}.pkl" 61 | file_log = DIR_OUT / f"PredIntvl_{ds_name}_{seed}.log" 62 | logger = get_logger(file_log, name=f"PredIntvl_{ds_name}_{seed}") 63 | if file.exists(): 64 | logger.warning(f"File {file} already exists, skipping...") 65 | continue 66 | logger.info(f"Running {ds_name} with seed {seed}...") 67 | # * set seed 68 | torch.manual_seed(seed) 69 | torch.cuda.manual_seed_all(seed) 70 | torch.backends.cudnn.deterministic = True 71 | torch.backends.cudnn.benchmark = False 72 | np.random.seed(seed) 73 | # * load data 74 | try: 75 | train_loader, val_loader, test_loader, xsc, ysc = DATASETS[ds_name](seed_val=seed) 76 | X_test, Y_test = extract_XY(test_loader, device) 77 | except Exception as e: 78 | logger.error(f"Error loading dataset {ds_name}: {e}") 79 | continue 80 | input_dim = X_test.shape[1] 81 | 
dct_result = { 82 | "seed": seed, 83 | "dataset": ds_name, 84 | "alpha": ALPHA, 85 | "Y_test": Y_test.flatten().cpu().numpy(), 86 | } 87 | # * train / fit: get PI on test 88 | # ! ensemble 89 | try: 90 | logger.info("Training ensemble...") 91 | torch.manual_seed(seed) 92 | ensemble = train_ensemble( 93 | M=N_ENSEMBLE, 94 | train_loader=train_loader, 95 | val_loader=val_loader, 96 | input_dim=input_dim, 97 | latent_dim=LATENT_DIM, 98 | lr=lr, 99 | weight_decay=weight_decay, 100 | device=device, 101 | num_epochs=NUM_EPOCH, 102 | ) 103 | dct_result["ensemble"] = ensemble_pred_intvl(ensemble, X_test, device, ALPHA) 104 | except Exception as e: 105 | logger.error(f"Error training ensemble: {e}") 106 | dct_result["ensemble"] = (None, None, None) 107 | # ! base with mc dropout 108 | try: 109 | logger.info("Training base model with MC dropout...") 110 | torch.manual_seed(seed) 111 | model_mcdropout = train_base_model( 112 | train_loader=train_loader, 113 | val_loader=val_loader, 114 | input_dim=input_dim, 115 | lr=lr, 116 | weight_decay=weight_decay, 117 | p_drop=p_drop, 118 | latent_dim=LATENT_DIM, 119 | num_epoch=NUM_EPOCH, 120 | device=device, 121 | ) 122 | dct_result["mcdropout"] = mc_dropout_pred_intvl( 123 | model_mcdropout, X_test, T=200, device=device, alpha=ALPHA 124 | ) 125 | except Exception as e: 126 | logger.error(f"Error training base model with MC dropout: {e}") 127 | dct_result["mcdropout"] = (None, None, None) 128 | # ! bnn 129 | try: 130 | logger.info("Training BNN...") 131 | torch.manual_seed(seed) 132 | model_bnn = train_bnn( 133 | train_loader=train_loader, 134 | val_loader=val_loader, 135 | input_dim=input_dim, 136 | latent_dim=LATENT_DIM, 137 | lr=lr, 138 | weight_decay=weight_decay, 139 | device=device, 140 | num_epochs=NUM_EPOCH, 141 | ) 142 | dct_result["bnn"] = bnn_pred_intvl( 143 | model_bnn, X_test, T=200, alpha=ALPHA, device=device 144 | ) 145 | except Exception as e: 146 | logger.error(f"Error training BNN: {e}") 147 | dct_result["bnn"] = (None, None, None) 148 | # ! 
vine (also extracting Y_train, Y_test) 149 | try: 150 | logger.info("Training vine...") 151 | torch.manual_seed(seed) 152 | model_base = train_base_model( 153 | train_loader=train_loader, 154 | val_loader=val_loader, 155 | input_dim=input_dim, 156 | lr=lr, 157 | weight_decay=weight_decay, 158 | p_drop=0.0, 159 | latent_dim=LATENT_DIM, 160 | device=device, 161 | ) 162 | Z_train, Y_train = extract_features(model_base, train_loader, device) 163 | model_vine = train_vine(Z_train, Y_train, device=device) 164 | Z_test, Y_test = extract_features(model_base, test_loader, device) 165 | dct_result["vine"] = vine_pred_intvl( 166 | model_vine, Z_test, alpha=ALPHA, seed=seed, device=device 167 | ) 168 | except Exception as e: 169 | logger.error(f"Error training vine: {e}") 170 | dct_result["vine"] = (None, None, None) 171 | with open(file, "wb") as f: 172 | pickle.dump(dct_result, f) 173 | 174 | logger.info(f"Results saved to {file}") 175 | logging.shutdown() 176 | -------------------------------------------------------------------------------- /examples/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import pickle 4 | import platform 5 | import sys 6 | import time 7 | from collections import defaultdict 8 | from itertools import product 9 | from pathlib import Path 10 | 11 | import pyvinecopulib as pvc 12 | import torch 13 | from dotenv import load_dotenv 14 | from torch.special import ndtr 15 | 16 | import torchvinecopulib as tvc 17 | 18 | # ! =================== 19 | # ! =================== 20 | is_test = False 21 | # ! =================== 22 | # ! =================== 23 | num_threads = 10 24 | is_cuda_avail = torch.cuda.is_available() 25 | load_dotenv() 26 | DIR_WORK = Path(os.getenv("DIR_WORK")) 27 | DIR_OUT = DIR_WORK / "examples" / "benchmark" / "out" 28 | DIR_OUT.mkdir(parents=True, exist_ok=True) 29 | torch.set_default_dtype(torch.float64) 30 | SEED = 42 31 | if is_test: 32 | lst_num_dim = [10, 30][::-1] 33 | lst_num_obs = [1000, 50000][::-1] 34 | lst_seed = list(range(3)) 35 | else: 36 | lst_num_dim = [10, 20, 30, 40, 50][::-1] 37 | lst_num_obs = [1000, 5000, 10000, 20000, 30000, 40000, 50000][::-1] 38 | lst_seed = list(range(100)) 39 | dct_time_fit = { 40 | "pvc": defaultdict(list), 41 | "tvc": defaultdict(list), 42 | "tvc_cuda": defaultdict(list), 43 | } 44 | dct_time_sample = { 45 | "pvc": defaultdict(list), 46 | "tvc": defaultdict(list), 47 | "tvc_cuda": defaultdict(list), 48 | } 49 | dct_time_pdf = { 50 | "pvc": defaultdict(list), 51 | "tvc": defaultdict(list), 52 | "tvc_cuda": defaultdict(list), 53 | } 54 | 55 | 56 | def cuda_warmup(num_obs, num_dim): 57 | if torch.cuda.is_available(): 58 | device = torch.device("cuda") 59 | for _ in range(5): 60 | _ = torch.randn(num_obs, num_dim, device=device) 61 | torch.cuda.synchronize() 62 | else: 63 | print("CUDA is not available. 
Skipping warm-up.") 64 | 65 | 66 | print(f"Python executable: {sys.executable}") 67 | print(f"Python version: {sys.version.splitlines()[0]}") 68 | print(f"Platform: {platform.platform()}") 69 | print( 70 | f"PyTorch: {torch.__version__} (CUDA available: {torch.cuda.is_available()}, CUDA toolkit: {torch.version.cuda})" 71 | ) 72 | print(f"pyvinecopulib: {pvc.__version__}") 73 | print(f"torchvinecopulib: {tvc.__version__}") 74 | print(f"number of seeds: {len(lst_seed)}") 75 | 76 | # %% 77 | for num_dim, num_obs in product(lst_num_dim, lst_num_obs): 78 | print(f"\n\n{time.strftime('%Y-%m-%d %H:%M:%S')}\nnum_dim: {num_dim}, num_obs: {num_obs}\n\n") 79 | # ! preprocess into copula scale (uniform marginals) 80 | torch.manual_seed(SEED) 81 | # * tensor on cpu 82 | R = torch.rand(num_dim, num_dim, dtype=torch.float64) 83 | R /= R.norm(dim=1, keepdim=True) 84 | R @= R.T 85 | U = ndtr( 86 | (torch.randn(num_obs, num_dim, dtype=torch.float64) @ torch.linalg.cholesky(R, upper=True)) 87 | ) 88 | # * tensor on cuda 89 | if is_cuda_avail: 90 | U_cuda = U.cuda() 91 | # * np on cpu 92 | U_numpy = U.numpy().astype("float64") 93 | # ! pvc 94 | pvc_ctrl = pvc.FitControlsVinecop( 95 | family_set=(pvc.BicopFamily.indep, pvc.BicopFamily.tll), 96 | nonparametric_method="quadratic", 97 | tree_criterion="tau", 98 | num_threads=num_threads, 99 | ) 100 | # ! tvc 101 | tvc_mdl = tvc.VineCop(num_dim=num_dim, num_step_grid=64, is_cop_scale=True) 102 | # ! tvc_cuda 103 | if is_cuda_avail: 104 | tvc_mdl_cuda = tvc.VineCop(num_dim=num_dim, num_step_grid=64, is_cop_scale=True).cuda() 105 | 106 | # * fit 107 | print(f"\n{time.strftime('%Y-%m-%d %H:%M:%S')} Fitting...\n") 108 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} PVC...") 109 | for seed in lst_seed: 110 | t0 = time.perf_counter() 111 | pvc_mdl = pvc.Vinecop.from_data(data=U_numpy, controls=pvc_ctrl) 112 | t1 = time.perf_counter() 113 | if seed > 0: 114 | dct_time_fit["pvc"][num_obs, num_dim].append(t1 - t0) 115 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} TVC...") 116 | for seed in lst_seed: 117 | t0 = time.perf_counter() 118 | tvc_mdl.fit(obs=U, mtd_bidep="kendall_tau", num_iter_max=11) 119 | t1 = time.perf_counter() 120 | if seed > 0: 121 | dct_time_fit["tvc"][num_obs, num_dim].append(t1 - t0) 122 | if is_cuda_avail: 123 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} TVC CUDA...") 124 | cuda_warmup(num_obs, num_dim) 125 | for seed in lst_seed: 126 | torch.cuda.synchronize() 127 | t0 = time.perf_counter() 128 | tvc_mdl_cuda.fit(obs=U_cuda, mtd_bidep="kendall_tau", num_iter_max=11) 129 | t1 = time.perf_counter() 130 | torch.cuda.synchronize() 131 | if seed > 0: 132 | dct_time_fit["tvc_cuda"][num_obs, num_dim].append(t1 - t0) 133 | 134 | # * sample 135 | print(f"\n{time.strftime('%Y-%m-%d %H:%M:%S')} Sampling...\n") 136 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} PVC...") 137 | for seed in lst_seed: 138 | t0 = time.perf_counter() 139 | pvc_mdl.simulate(n=num_obs, num_threads=num_threads) 140 | t1 = time.perf_counter() 141 | if seed > 0: 142 | dct_time_sample["pvc"][num_obs, num_dim].append(t1 - t0) 143 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} TVC...") 144 | for seed in lst_seed: 145 | t0 = time.perf_counter() 146 | tvc_mdl.sample(num_sample=num_obs) 147 | t1 = time.perf_counter() 148 | if seed > 0: 149 | dct_time_sample["tvc"][num_obs, num_dim].append(t1 - t0) 150 | if is_cuda_avail: 151 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} TVC CUDA...") 152 | cuda_warmup(num_obs, num_dim) 153 | for seed in lst_seed: 154 | torch.cuda.synchronize() 155 | t0 = 
time.perf_counter() 156 | tvc_mdl_cuda.sample(num_sample=num_obs) 157 | t1 = time.perf_counter() 158 | torch.cuda.synchronize() 159 | if seed > 0: 160 | dct_time_sample["tvc_cuda"][num_obs, num_dim].append(t1 - t0) 161 | 162 | # * pdf 163 | print(f"\n{time.strftime('%Y-%m-%d %H:%M:%S')} PDF...\n") 164 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} PVC...") 165 | for seed in lst_seed: 166 | t0 = time.perf_counter() 167 | pvc_mdl.pdf(U_numpy, num_threads=num_threads) 168 | t1 = time.perf_counter() 169 | if seed > 0: 170 | dct_time_pdf["pvc"][num_obs, num_dim].append(t1 - t0) 171 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} TVC...") 172 | for seed in lst_seed: 173 | t0 = time.perf_counter() 174 | tvc_mdl.log_pdf(U) 175 | t1 = time.perf_counter() 176 | if seed > 0: 177 | dct_time_pdf["tvc"][num_obs, num_dim].append(t1 - t0) 178 | if is_cuda_avail: 179 | print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} TVC CUDA...") 180 | cuda_warmup(num_obs, num_dim) 181 | for seed in lst_seed: 182 | torch.cuda.synchronize() 183 | t0 = time.perf_counter() 184 | tvc_mdl_cuda.log_pdf(U_cuda) 185 | t1 = time.perf_counter() 186 | torch.cuda.synchronize() 187 | if seed > 0: 188 | dct_time_pdf["tvc_cuda"][num_obs, num_dim].append(t1 - t0) 189 | 190 | # %% 191 | # ! save 192 | with open(DIR_OUT / "time_fit.pkl", "wb") as f: 193 | pickle.dump(dct_time_fit, f) 194 | with open(DIR_OUT / "time_sample.pkl", "wb") as f: 195 | pickle.dump(dct_time_sample, f) 196 | with open(DIR_OUT / "time_pdf.pkl", "wb") as f: 197 | pickle.dump(dct_time_pdf, f) 198 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchvinecopulib 2 | 3 | [![Codacy Badge](https://app.codacy.com/project/badge/Grade/e8a7a7448b2043d9bbefafc5a3ec14f7)](https://app.codacy.com/gh/TY-Cheng/torchvinecopulib/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade) 4 | [![Codacy Badge](https://app.codacy.com/project/badge/Coverage/e8a7a7448b2043d9bbefafc5a3ec14f7)](https://app.codacy.com?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_coverage) 5 | [![Lint Pytest](https://github.com/TY-Cheng/torchvinecopulib/actions/workflows/python-package.yml/badge.svg?branch=main)](https://github.com/TY-Cheng/torchvinecopulib/actions/workflows/python-package.yml) 6 | [![Deploy Docs](https://github.com/TY-Cheng/torchvinecopulib/actions/workflows/static.yml/badge.svg?branch=main)](https://ty-cheng.github.io/torchvinecopulib/) 7 | 8 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchvinecopulib) 9 | [![OS](https://img.shields.io/badge/OS-Windows%7CmacOS%7CUbuntu-blue)](https://github.com/TY-Cheng/torchvinecopulib/actions/workflows/python-package.yml) 10 | 11 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/TY-Cheng/torchvinecopulib/blob/main/LICENSE) 12 | [![PyPI - Version](https://img.shields.io/pypi/v/torchvinecopulib)](https://pypi.org/project/torchvinecopulib/) 13 | 14 | A Python library for fitting and sampling vine copulas, using [PyTorch](https://pytorch.org/get-started/locally/). 
15 | 16 | - C/D/R-Vine full-sampling/ quantile-regression/ conditional-sampling, all in one package 17 | - Flexible sampling order for experienced users 18 | - Vectorized tensor computation with GPU (`device='cuda'`) support 19 | - Shorter runtimes for higher dimension simulations 20 | - Pure `Python` library, inspired by [pyvinecopulib](https://github.com/vinecopulib/pyvinecopulib/) on Windows, Linux, MacOS 21 | - IO and visualization support 22 | 23 | ## Citation 24 | If you use `torchvinecopulib` in your work, please cite: 25 | 26 | > Cheng, Tuoyuan, Thibault Vatter, Thomas Nagler, and Kan Chen. "Vine Copulas as Differentiable Computational Graphs." arXiv preprint arXiv:2506.13318 (2025). 27 | 28 | ```latex 29 | @article{cheng2025vine, 30 | title={Vine Copulas as Differentiable Computational Graphs}, 31 | author={Cheng, Tuoyuan and Vatter, Thibault and Nagler, Thomas and Chen, Kan}, 32 | journal={arXiv preprint arXiv:2506.13318}, 33 | year={2025}, 34 | url={https://arxiv.org/abs/2506.13318}, 35 | } 36 | ``` 37 | 38 | ## Examples 39 | 40 | Visit the [`./examples/`](https://github.com/TY-Cheng/torchvinecopulib/tree/main/examples) folder for `.ipynb` Jupyter notebooks. 41 | 42 | ## Installation 43 | 44 | - By `pip` from [`PyPI`](https://pypi.org/project/torchvinecopulib/) (see the dependencies and uv sections below for CUDA support): 45 | 46 | ```bash 47 | pip install torchvinecopulib torch 48 | ``` 49 | 50 | - Or `pip` from `./dist/*.whl` or `./dist/*.tar.gz` in this repo. 51 | Need to use proper file name. 52 | 53 | ```bash 54 | # inside project root folder 55 | pip install ./dist/torchvinecopulib-1.1.0-py3-none-any.whl 56 | # or 57 | pip install ./dist/torchvinecopulib-1.1.0.tar.gz 58 | ``` 59 | 60 | ### (Recommended) [uv](https://docs.astral.sh/uv/getting-started/) for Dependency Management and Packaging 61 | 62 | After `git clone https://github.com/TY-Cheng/torchvinecopulib.git`, `cd` into the project root where [`pyproject.toml`](https://github.com/TY-Cheng/torchvinecopulib/blob/main/pyproject.toml) exists, 63 | 64 | ```bash 65 | # From inside the project root folder 66 | # Create and activate local virtual environment 67 | uv venv .venv 68 | source .venv/bin/activate 69 | 70 | # Sync dependencies with CPU support (default) 71 | uv sync --extra cpu 72 | 73 | # Or for CUDA 12.6 or 12.8 support (depends on your CUDA version) 74 | uv sync --extra cu126 75 | 76 | # Additionally, to install additional dependencies for the examples 77 | uv sync --extra examples 78 | ``` 79 | 80 | ## Dependencies 81 | 82 | ```toml 83 | # inside the `./pyproject.toml` file; 84 | fastkde = "*" 85 | numpy = "*" 86 | pyvinecopulib = "*" 87 | python = ">=3.11" 88 | scipy = "*" 89 | # optional to facilitate customization 90 | torch = [ 91 | { index = "torch-cpu", extra = "cpu" }, 92 | { index = "torch-cu126", extra = "cu126" }, 93 | { index = "torch-cu128", extra = "cu128" }, 94 | ] 95 | ``` 96 | 97 | For [PyTorch](https://pytorch.org/get-started/locally/) with `cuda`: 98 | 99 | ```bash 100 | pip install torch --index-url https://download.pytorch.org/whl/cu126 --force-reinstall 101 | # check cuda availability 102 | python -c "import torch; print(torch.cuda.is_available())" 103 | ``` 104 | 105 | > [!TIP] 106 | > macOS users should set `device='cpu'` at this stage, for using `device='mps'` won't support `dtype=torch.float64`. 
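## Quick Start

A minimal fit / sample / log-density round trip is sketched below. It mirrors the calls used in `examples/benchmark/__init__.py` and `tests/test_vinecop.py`; the data, dimension, grid size, and sample count are illustrative placeholders, not recommended settings.

```python
import torch

import torchvinecopulib as tvc

torch.set_default_dtype(torch.float64)
# toy observations already on the copula scale (uniform marginals), shape (num_obs, num_dim)
U = torch.rand(1000, 5)

vc = tvc.VineCop(num_dim=5, num_step_grid=64, is_cop_scale=True)
vc.fit(obs=U, mtd_bidep="kendall_tau")  # select structure and fit pair-copulas
samples = vc.sample(num_sample=500, seed=0)  # tensor of shape (500, 5)
log_density = vc.log_pdf(U)  # tensor of shape (1000, 1)
```
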
107 | 108 | ## Documentation 109 | 110 | - Visit [GitHub Pages website](https://ty-cheng.github.io/torchvinecopulib/) 111 | 112 | - Or build by yourself (need [`Sphinx`](https://github.com/sphinx-doc/sphinx), theme [`furo`](https://github.com/pradyunsg/furo) and [the GNU `make`](https://www.gnu.org/software/make/)) 113 | 114 | ```bash 115 | # inside project root folder 116 | sphinx-apidoc -o ./docs ./torchvinecopulib && cd ./docs && make html && cd .. 117 | # if using uv 118 | uv run sphinx-apidoc -o docs torchvinecopulib/ --separate 119 | uv run sphinx-build docs docs/_build/html 120 | ``` 121 | 122 | ## Tests 123 | 124 | ```python 125 | # inside project root folder 126 | python -m pytest ./tests 127 | # coverage report 128 | coverage run -m pytest ./tests && coverage html 129 | # if using uv 130 | uv run coverage run --source=torchvinecopulib -m pytest ./tests 131 | uv run coverage report -m 132 | ``` 133 | 134 | ## TODO 135 | 136 | - `VineCop.rosenblatt` 137 | - replace `dict` with `torch.Tensor` using some `mod` 138 | - vectorized union-find 139 | - flatten `_visit` logic 140 | - `examples/someapplications.ipynb` 141 | - flatten dynamic nested dicts into tensors 142 | - [`fastkde.pdf`](https://github.com/LBL-EESA/fastkde/blob/main/src/fastkde/fastKDE.py) onto `torch.Tensor` 143 | 144 | ## Contributing 145 | 146 | We welcome contributions, whether it's a bug report, feature suggestion, code contribution, or documentation improvement. 147 | 148 | - If you encounter any issues with the project or have ideas for new features, please [open an issue](https://github.com/TY-Cheng/torchvinecopulib/issues/new) on GitHub or [privately email us](mailto:cty120120@gmail.com). Make sure to include detailed information about the problem or feature request, including steps to reproduce for bugs. 149 | 150 | ### Code Contributions 151 | 152 | 1. Fork the repository and create a new branch from the `main` branch. 153 | 2. Make your changes and ensure they adhere to the project's coding style and conventions. 154 | 3. Write tests for any new functionality and ensure existing tests pass. 155 | 4. Commit your changes with clear and descriptive commit messages. 156 | 5. Push your changes to your fork and submit a pull request to the `main` branch of the original repository. 157 | 158 | ### Pull Request Guidelines 159 | 160 | - Keep pull requests focused on addressing a single issue or feature. 161 | - Include a clear and descriptive title and description for your pull request. 162 | - Make sure all tests pass before submitting the pull request. 163 | - If your pull request addresses an open issue, reference the issue number in the description using the syntax `#issue_number`. 164 | - [in-place ops can be slower](https://discuss.pytorch.org/t/are-inplace-operations-faster/61209/4) 165 | - [torch.jit.script can be slower](https://discuss.pytorch.org/t/why-is-torch-jit-script-slower/120131/6) 166 | 167 | ## License 168 | 169 | This project is released under the MIT License (© 2024- Tuoyuan Cheng, Kan Chen). 170 | See [LICENSE](./LICENSE) for the full text, including our own grant of rights and disclaimer. 171 | 172 | ### Third-Party Dependencies 173 | 174 | See the “Third-Party Dependencies” section in [LICENSE](./LICENSE) for details on the `PyTorch`, `FastKDE`, and `pyvinecopulib` licenses that govern those components. 
175 | -------------------------------------------------------------------------------- /examples/vcae/vcae_debug.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "f20e9099", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "The autoreload extension is already loaded. To reload it, use:\n", 14 | " %reload_ext autoreload\n" 15 | ] 16 | }, 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "Running experiments: 0%| | 0/1 [00:00 original is {loglik_model / loglik_refit_model} x worse\"\n", 162 | "# )\n", 163 | "# print(\"Model scores (original vs refit):\")\n", 164 | "# print(\n", 165 | "# f\"MMD: {score_model.mmd} vs {score_refit_model.mmd} => original is {score_model.mmd / score_refit_model.mmd} x worse\"\n", 166 | "# )\n", 167 | "# print(\n", 168 | "# f\"FID: {score_model.fid} vs {score_refit_model.fid} => original is {score_model.fid / score_refit_model.fid} x worse\"\n", 169 | "# )\n" 170 | ] 171 | } 172 | ], 173 | "metadata": { 174 | "kernelspec": { 175 | "display_name": ".venv", 176 | "language": "python", 177 | "name": "python3" 178 | }, 179 | "language_info": { 180 | "codemirror_mode": { 181 | "name": "ipython", 182 | "version": 3 183 | }, 184 | "file_extension": ".py", 185 | "mimetype": "text/x-python", 186 | "name": "python", 187 | "nbconvert_exporter": "python", 188 | "pygments_lexer": "ipython3", 189 | "version": "3.13.5" 190 | } 191 | }, 192 | "nbformat": 4, 193 | "nbformat_minor": 5 194 | } 195 | -------------------------------------------------------------------------------- /examples/pred_intvl/data_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import platform 4 | import sys 5 | from pathlib import Path 6 | 7 | import torch 8 | from dotenv import load_dotenv 9 | from sklearn.datasets import fetch_california_housing, fetch_openml 10 | from sklearn.model_selection import train_test_split 11 | from sklearn.preprocessing import StandardScaler 12 | from torch.utils.data import DataLoader, TensorDataset 13 | 14 | latent_dim = 10 15 | num_epochs = 10 16 | batch_size = 128 17 | test_size = 0.2 18 | val_size = 0.1 19 | random_seed = 42 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | torch.manual_seed(random_seed) 22 | load_dotenv() 23 | DIR_WORK = Path(os.getenv("DIR_WORK")) 24 | 25 | 26 | def get_logger( 27 | log_file: Path | str, 28 | console_level: int = logging.WARNING, 29 | file_level: int = logging.INFO, 30 | fmt_console: str = "%(asctime)s - %(levelname)s - %(message)s", 31 | fmt_file: str = "%(asctime)s - %(name)s - %(levelname)s - [%(pathname)s:%(lineno)d] - %(message)s", 32 | name: str | None = None, 33 | ) -> logging.Logger: 34 | """Create (or retrieve) a module‐level logger that writes INFO+ to console and 35 | WARNING+ to file. 36 | 37 | Args: 38 | log_file: path to the file where warnings+ should be logged. 39 | console_level: logging level for console handler. 40 | file_level: logging level for file handler. 41 | fmt_console: format string for console output. 42 | fmt_file: format string for file output. 43 | name: name for the logger; defaults to str(log_file). 44 | device: the device string to log (e.g. "cuda"/"cpu"); if None, auto‐detects. 45 | 46 | Returns: 47 | A configured logging.Logger instance. 
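        Example (illustrative sketch; the path is a placeholder)::

            logger = get_logger("out/experiment.log")
            logger.warning("printed to console and written to file (console default: WARNING)")
            logger.info("written to the log file only (file default: INFO)")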
48 | """ 49 | log_file = Path(log_file) 50 | log_file.parent.mkdir(parents=True, exist_ok=True) 51 | 52 | logger_name = name or str(log_file) 53 | logger = logging.getLogger(logger_name) 54 | logger.setLevel(logging.INFO) 55 | 56 | # * prevent duplicate handlers if called multiple times 57 | if not logger.handlers: 58 | # * file handler 59 | fh = logging.FileHandler(log_file) 60 | fh.setLevel(file_level) 61 | fh.setFormatter(logging.Formatter(fmt_file)) 62 | logger.addHandler(fh) 63 | # * console handler 64 | ch = logging.StreamHandler(sys.stdout) 65 | ch.setLevel(console_level) 66 | ch.setFormatter(logging.Formatter(fmt_console)) 67 | logger.addHandler(ch) 68 | # * initial banner 69 | logger.info("--- Logger initialized ---") 70 | logger.info(f"Log file: {log_file}") 71 | logger.info(f"Python: {sys.version.replace(chr(10), ' ')}") 72 | logger.info(f"Platform: {platform.platform()}") 73 | logger.info(f"PyTorch: {torch.__version__}") 74 | logger.info(f"CUDA available: {torch.cuda.is_available()}") 75 | 76 | return logger 77 | 78 | 79 | def load_california_housing( 80 | batch_size: int = 128, 81 | test_size: float = test_size, 82 | val_size: float = val_size, 83 | seed_val: int = 42, 84 | num_workers: int = 4, 85 | pin_memory: bool = True, 86 | ): 87 | """Fetch California Housing, split into train/val/test, scale (features & target) on 88 | train only, and wrap in PyTorch DataLoaders. 89 | 90 | Returns: train_loader, val_loader, test_loader, x_scaler, y_scaler 91 | """ 92 | torch.manual_seed(seed_val) 93 | 94 | # * download & split 95 | data = fetch_california_housing() 96 | X, y = data.data, data.target[:, None] 97 | # ! test set untouched 98 | X_trainval, X_test, y_trainval, y_test = train_test_split( 99 | X, y, test_size=test_size, random_state=seed_val 100 | ) 101 | # ! 
carve off validation from trainval 102 | X_train, X_val, y_train, y_val = train_test_split( 103 | X_trainval, y_trainval, test_size=val_size, random_state=seed_val 104 | ) 105 | x_scaler = StandardScaler().fit(X_train) 106 | y_scaler = StandardScaler().fit(y_train) 107 | # * transform all splits 108 | X_train = x_scaler.transform(X_train) 109 | X_val = x_scaler.transform(X_val) 110 | X_test = x_scaler.transform(X_test) 111 | y_train = y_scaler.transform(y_train) 112 | y_val = y_scaler.transform(y_val) 113 | y_test = y_scaler.transform(y_test) 114 | # * to torch tensors 115 | Xtr = torch.from_numpy(X_train).float() 116 | Ytr = torch.from_numpy(y_train).float() 117 | Xva = torch.from_numpy(X_val).float() 118 | Yva = torch.from_numpy(y_val).float() 119 | Xte = torch.from_numpy(X_test).float() 120 | Yte = torch.from_numpy(y_test).float() 121 | # * wrap into DataLoaders 122 | train_loader = DataLoader( 123 | TensorDataset(Xtr, Ytr), 124 | batch_size=batch_size, 125 | shuffle=True, 126 | num_workers=num_workers, 127 | pin_memory=pin_memory, 128 | ) 129 | val_loader = DataLoader( 130 | TensorDataset(Xva, Yva), 131 | batch_size=batch_size, 132 | shuffle=False, 133 | num_workers=num_workers, 134 | pin_memory=pin_memory, 135 | ) 136 | test_loader = DataLoader( 137 | TensorDataset(Xte, Yte), 138 | batch_size=batch_size, 139 | shuffle=False, 140 | num_workers=num_workers, 141 | pin_memory=pin_memory, 142 | ) 143 | 144 | return train_loader, val_loader, test_loader, x_scaler, y_scaler 145 | 146 | 147 | def load_news_popularity( 148 | batch_size: int = 128, 149 | test_size: float = test_size, 150 | val_size: float = val_size, 151 | seed_val: int = 42, 152 | num_workers: int = 4, 153 | pin_memory: bool = True, 154 | ): 155 | """Online News Popularity regression (UCI). 156 | 157 | Splits into train/val/test, scales on train only, wraps in DataLoaders. 158 | """ 159 | torch.manual_seed(seed_val) 160 | # * fetch & split 161 | X, y = fetch_openml("OnlineNewsPopularity", version=1, return_X_y=True, as_frame=False) 162 | X = X[:, 1:] # ! remove the first column (website) 163 | y = y.astype(float).reshape(-1, 1) 164 | # ! test set untouched 165 | X_trainval, X_test, y_trainval, y_test = train_test_split( 166 | X, y, test_size=test_size, random_state=seed_val 167 | ) 168 | # ! 
carve off validation 169 | X_train, X_val, y_train, y_val = train_test_split( 170 | X_trainval, y_trainval, test_size=val_size, random_state=seed_val 171 | ) 172 | # * standardize 173 | x_scaler = StandardScaler().fit(X_train) 174 | y_scaler = StandardScaler().fit(y_train) 175 | 176 | X_train = x_scaler.transform(X_train) 177 | X_val = x_scaler.transform(X_val) 178 | X_test = x_scaler.transform(X_test) 179 | 180 | y_train = y_scaler.transform(y_train) 181 | y_val = y_scaler.transform(y_val) 182 | y_test = y_scaler.transform(y_test) 183 | 184 | # * to torch tensors 185 | Xtr = torch.from_numpy(X_train).float() 186 | Ytr = torch.from_numpy(y_train).float() 187 | Xva = torch.from_numpy(X_val).float() 188 | Yva = torch.from_numpy(y_val).float() 189 | Xte = torch.from_numpy(X_test).float() 190 | Yte = torch.from_numpy(y_test).float() 191 | 192 | # * DataLoaders 193 | train_loader = DataLoader( 194 | TensorDataset(Xtr, Ytr), 195 | batch_size=batch_size, 196 | shuffle=True, 197 | num_workers=num_workers, 198 | pin_memory=pin_memory, 199 | ) 200 | val_loader = DataLoader( 201 | TensorDataset(Xva, Yva), 202 | batch_size=batch_size, 203 | shuffle=False, 204 | num_workers=num_workers, 205 | pin_memory=pin_memory, 206 | ) 207 | test_loader = DataLoader( 208 | TensorDataset(Xte, Yte), 209 | batch_size=batch_size, 210 | shuffle=False, 211 | num_workers=num_workers, 212 | pin_memory=pin_memory, 213 | ) 214 | 215 | return train_loader, val_loader, test_loader, x_scaler, y_scaler 216 | 217 | 218 | @torch.no_grad() 219 | def extract_XY(loader, device): 220 | Xs, Ys = [], [] 221 | for x, y in loader: 222 | Xs.append(x.to(device)) 223 | Ys.append(y.to(device)) 224 | return torch.cat(Xs, 0).cpu(), torch.cat(Ys, 0).cpu() 225 | 226 | 227 | if __name__ == "__main__": 228 | # quick test / demo 229 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 230 | train_loader, test_loader, x_scaler, y_scaler = load_california_housing() 231 | X_train, Y_train = extract_XY(train_loader, device) 232 | X_test, Y_test = extract_XY(test_loader, device) 233 | print(f"Train X: {X_train.shape}, Y: {Y_train.shape}") 234 | print(f"Test X: {X_test.shape}, Y: {Y_test.shape}") 235 | # quick sanity check 236 | train_loader, test_loader, xscaler, yscaler = load_news_popularity( 237 | batch_size=128, test_size=0.2, seed=42 238 | ) 239 | Xtr, Ytr = extract_XY(train_loader, device) 240 | Xte, Yte = extract_XY(test_loader, device) 241 | print(f"News‐pop train: X={Xtr.shape}, Y={Ytr.shape}") 242 | print(f"News‐pop test: X={Xte.shape}, Y={Yte.shape}") 243 | -------------------------------------------------------------------------------- /tests/test_vinecop.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import networkx as nx 3 | import pytest 4 | import torch 5 | 6 | from torchvinecopulib.vinecop import VineCop 7 | 8 | 9 | def test_init_attributes_and_defaults(): 10 | num_dim = 5 11 | vc = VineCop(num_dim=num_dim, is_cop_scale=True, num_step_grid=32) 12 | # basic attrs 13 | assert vc.num_dim == num_dim 14 | assert vc.is_cop_scale is True 15 | assert vc.num_step_grid == 32 16 | # empty structures 17 | assert isinstance(vc.marginals, torch.nn.ModuleList) 18 | assert len(vc.marginals) == num_dim 19 | assert isinstance(vc.bicops, torch.nn.ModuleDict) 20 | # number of pair-copulas = n(n-1)/2 21 | assert len(vc.bicops) == num_dim * (num_dim - 1) // 2 22 | assert vc.tree_bidep == [{} for _ in range(num_dim - 1)] 23 | assert vc.sample_order == 
tuple(range(num_dim)) 24 | # no data yet 25 | assert vc.num_obs.item() == 0 26 | 27 | 28 | def test_device_and_dtype_after_cuda(): 29 | vc = VineCop(3).to("cpu") 30 | assert vc.device.type == "cpu" 31 | assert vc.dtype is torch.float64 32 | if torch.cuda.is_available(): 33 | vc2 = VineCop(3).cuda() 34 | assert vc2.device.type == "cuda" 35 | 36 | 37 | @pytest.mark.parametrize("mtd_vine", ["dvine", "cvine", "rvine"]) 38 | @pytest.mark.parametrize( 39 | "mtd_bidep", 40 | ["kendall_tau", "mutual_info", "chatterjee_xi", "ferreira_tail_dep_coeff"], 41 | ) 42 | def test_fit_sample_logpdf_cdf_forward(mtd_vine, mtd_bidep): 43 | num_dim = 5 44 | torch.manual_seed(0) 45 | U = torch.special.ndtri(torch.rand(200, num_dim, dtype=torch.float64)) 46 | vc = VineCop(num_dim=num_dim, is_cop_scale=False, num_step_grid=32) 47 | # fit in copula-scale 48 | vc.fit( 49 | obs=U, 50 | mtd_kde="tll", # use TLL for bidep estimation 51 | is_dissmann=True, 52 | mtd_vine=mtd_vine, 53 | mtd_bidep=mtd_bidep, 54 | thresh_trunc=None, # no truncation 55 | ) 56 | # num_obs must be updated 57 | assert vc.num_obs.item() == 200 58 | 59 | # sampling 60 | samp = vc.sample(num_sample=50, seed=1, is_sobol=False) 61 | assert samp.shape == (50, num_dim) 62 | 63 | # log_pdf 64 | lp = vc.log_pdf(U) 65 | assert lp.shape == (200, 1) 66 | # forward = neg average log-lik 67 | fwd = vc.forward(U) 68 | assert fwd.dim() == 0 69 | 70 | # cdf approximation in [0,1] 71 | cdf_vals = vc.cdf(U[:10], num_sample=1000, seed=2) 72 | assert cdf_vals.shape == (10, 1) 73 | assert (cdf_vals >= 0).all() and (cdf_vals <= 1).all() 74 | 75 | 76 | @pytest.mark.parametrize("mtd_vine", ["dvine", "cvine", "rvine"]) 77 | @pytest.mark.parametrize( 78 | "mtd_bidep", 79 | ["kendall_tau", "mutual_info", "chatterjee_xi", "ferreira_tail_dep_coeff"], 80 | ) 81 | def test_matrix_diagonal_matches_sample_order(mtd_vine, mtd_bidep): 82 | num_dim = 5 83 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=16) 84 | torch.manual_seed(1) 85 | U = torch.rand(100, num_dim, dtype=torch.float64) 86 | vc.fit( 87 | U, 88 | mtd_kde="tll", 89 | is_dissmann=True, 90 | mtd_vine=mtd_vine, 91 | mtd_bidep=mtd_bidep, 92 | ) 93 | M = vc.matrix 94 | # must be square 95 | assert M.shape == (num_dim, num_dim) 96 | # * d unique elements in each row 97 | for i in range(num_dim): 98 | elems = set(M[i, :].tolist()) 99 | elems.discard(-1) # discard -1 100 | assert len(elems) == num_dim - i, f"Row {i} has incorrect unique elements" 101 | # diag = sample_order 102 | for i in range(num_dim): 103 | assert M[i, i].item() == vc.sample_order[i] 104 | 105 | 106 | def test_fit_with_explicit_matrix_uses_exact_edges(): 107 | num_dim = 5 108 | torch.manual_seed(1) 109 | U = torch.rand(200, num_dim, dtype=torch.float64) 110 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=16) 111 | vc.fit(U, mtd_kde="tll", is_dissmann=True) 112 | M = vc.matrix 113 | # M is a square matrix with num_dim rows and columns 114 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=16) 115 | # fit with our explicit matrix 116 | vc.fit(U, is_dissmann=False, matrix=M) 117 | # now verify that for each level lv, and each idx in [0..num_dim-lv-2], 118 | # the edge (v_l, v_r, *cond) comes out exactly as we encoded it in M 119 | for lv in range(num_dim - 1): 120 | tree = vc.tree_bidep[lv] 121 | # must have exactly num_dim-lv-1 edges 122 | assert len(tree) == num_dim - lv - 1 123 | 124 | for idx in range(num_dim - lv - 1): 125 | # the two “free” spots in row idx of M 126 | a = int(M[idx, idx]) 127 | b = int(M[idx, num_dim - lv - 1]) 
128 | v_l, v_r = sorted((a, b)) 129 | # any remaining entries in that row form the conditioning set 130 | cond = sorted([int(_) for _ in M[idx, num_dim - lv :].tolist()]) 131 | expected_edge = (v_l, v_r, *cond) 132 | 133 | # check that this exact tuple is a key in tree_bidep[lv] 134 | assert expected_edge in tree, ( 135 | f"Expected edge {expected_edge} at level {lv} but got {list(tree)}" 136 | ) 137 | 138 | 139 | def test_reset_clears_all_levels_and_bicops(): 140 | num_dim = 5 141 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=16) 142 | torch.manual_seed(0) 143 | U = torch.rand(50, num_dim, dtype=torch.float64) 144 | # fit with Dissmann algorithm 145 | vc.fit( 146 | U, 147 | is_dissmann=True, 148 | mtd_kde="tll", 149 | mtd_vine="rvine", 150 | mtd_bidep="kendall_tau", 151 | thresh_trunc=None, 152 | ) 153 | assert vc.num_obs.item() > 0 154 | # now reset 155 | vc.reset() 156 | assert vc.num_obs.item() == 0 157 | assert vc.tree_bidep == [{} for _ in range(num_dim - 1)] 158 | # all BiCop should be independent again 159 | assert all(bc.is_indep for bc in vc.bicops.values()) 160 | 161 | 162 | def test_reset_and_str(): 163 | num_dim = 5 164 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=16) 165 | torch.manual_seed(0) 166 | U = torch.rand(50, num_dim, dtype=torch.float64) 167 | # fit with Dissmann algorithm 168 | vc.fit( 169 | U, 170 | is_dissmann=True, 171 | mtd_kde="tll", 172 | mtd_vine="rvine", 173 | mtd_bidep="kendall_tau", 174 | thresh_trunc=None, 175 | ) 176 | # __str__ contains key fields 177 | s = str(vc) 178 | for key in [ 179 | "num_dim", 180 | "num_obs", 181 | "is_cop_scale", 182 | "num_step_grid", 183 | "mtd_bidep", 184 | "negloglik", 185 | "dtype", 186 | "device", 187 | "sample_order", 188 | ]: 189 | assert key in s, f"'{key}' not found in VineCop string representation" 190 | 191 | 192 | @pytest.mark.parametrize("mtd_vine", ["dvine", "cvine", "rvine"]) 193 | @pytest.mark.parametrize( 194 | "mtd_bidep", 195 | ["kendall_tau", "mutual_info", "chatterjee_xi", "ferreira_tail_dep_coeff"], 196 | ) 197 | def test_ref_count_hfunc_on_fitted_vine(mtd_vine, mtd_bidep): 198 | num_dim = 5 199 | torch.manual_seed(0) 200 | U = torch.rand(100, num_dim, dtype=torch.float64) 201 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=16) 202 | vc.fit( 203 | U, 204 | mtd_kde="tll", 205 | is_dissmann=True, 206 | mtd_vine=mtd_vine, 207 | mtd_bidep=mtd_bidep, 208 | thresh_trunc=0.1, 209 | ) 210 | # test static ref_count_hfunc 211 | ref_cnt, sources, num_hfunc = VineCop.ref_count_hfunc( 212 | num_dim=vc.num_dim, 213 | struct_obs=vc.struct_obs, 214 | sample_order=vc.sample_order, 215 | ) 216 | assert isinstance(ref_cnt, dict) 217 | assert isinstance(sources, list) 218 | assert isinstance(num_hfunc, int) 219 | if mtd_vine == "cvine": 220 | # for cvine, we expect 0 hfuncs 221 | assert num_hfunc == 0 222 | else: 223 | assert num_hfunc >= 0 224 | 225 | 226 | def test_draw_lv_and_draw_dag(tmp_path): 227 | num_dim = 5 228 | torch.manual_seed(0) 229 | U = torch.rand(100, num_dim, dtype=torch.float64) 230 | vc = VineCop(num_dim, is_cop_scale=True, num_step_grid=32) 231 | vc.fit( 232 | U, 233 | mtd_kde="tll", 234 | is_dissmann=True, 235 | mtd_vine="rvine", 236 | mtd_bidep="kendall_tau", 237 | thresh_trunc=None, 238 | ) 239 | 240 | # draw level-0 with pseudo-obs nodes 241 | fig, ax, G = vc.draw_lv(lv=0, is_bcp=False) 242 | assert isinstance(G, nx.Graph) 243 | assert G.number_of_nodes() > 0 244 | plt.close(fig) 245 | 246 | # save level-1 bicop view 247 | fpath_lv = tmp_path / "level1.png" 248 | 
fig2, ax2, G2, outpath_lv = vc.draw_lv(lv=1, is_bcp=True, f_path=fpath_lv) 249 | assert outpath_lv == fpath_lv 250 | assert fpath_lv.exists() 251 | plt.close(fig2) 252 | # save level-1 pseudo-obs view 253 | fpath_lv = tmp_path / "level1.png" 254 | fig2, ax2, G2, outpath_lv = vc.draw_lv(lv=1, is_bcp=False, f_path=fpath_lv) 255 | assert outpath_lv == fpath_lv 256 | assert fpath_lv.exists() 257 | plt.close(fig2) 258 | 259 | # draw the DAG 260 | fig3, ax3, G3 = vc.draw_dag() 261 | assert isinstance(G3, nx.DiGraph) 262 | assert len(G3.nodes) > 0 263 | plt.close(fig3) 264 | 265 | # save DAG 266 | fpath_dag = tmp_path / "dag.png" 267 | fig4, ax4, G4, outpath_dag = vc.draw_dag(f_path=fpath_dag) 268 | assert outpath_dag == fpath_dag 269 | assert fpath_dag.exists() 270 | plt.close(fig4) 271 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | ### MIT License 2 | 3 | Copyright (c) 2024- Tuoyuan Cheng, Kan Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | --- 24 | 25 | ## Third-Party Dependencies 26 | 27 | This project depends on the following libraries, each of which is governed by its own license. You must read and comply with those licenses when using this software. 28 | 29 | ### 1. PyTorch 30 | 31 | - **URL**: https://github.com/pytorch/pytorch/blob/main/LICENSE 32 | 33 | Full PyTorch license text: 34 | 35 | From PyTorch: 36 | 37 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 38 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 39 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 40 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 41 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 42 | Copyright (c) 2011-2013 NYU (Clement Farabet) 43 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 44 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 45 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 46 | 47 | From Caffe2: 48 | 49 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 50 | 51 | All contributions by Facebook: 52 | Copyright (c) 2016 Facebook Inc. 53 | 54 | All contributions by Google: 55 | Copyright (c) 2015 Google Inc. 56 | All rights reserved. 
57 | 58 | All contributions by Yangqing Jia: 59 | Copyright (c) 2015 Yangqing Jia 60 | All rights reserved. 61 | 62 | All contributions by Kakao Brain: 63 | Copyright 2019-2020 Kakao Brain 64 | 65 | All contributions by Cruise LLC: 66 | Copyright (c) 2022 Cruise LLC. 67 | All rights reserved. 68 | 69 | All contributions by Tri Dao: 70 | Copyright (c) 2024 Tri Dao. 71 | All rights reserved. 72 | 73 | All contributions by Arm: 74 | Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates 75 | 76 | All contributions from Caffe: 77 | Copyright(c) 2013, 2014, 2015, the respective contributors 78 | All rights reserved. 79 | 80 | All other contributions: 81 | Copyright(c) 2015, 2016 the respective contributors 82 | All rights reserved. 83 | 84 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 85 | copyright over their contributions to Caffe2. The project versioning records 86 | all such contribution and copyright details. If a contributor wants to further 87 | mark their specific copyright on a particular contribution, they should 88 | indicate their copyright solely in the commit message of the change when it is 89 | committed. 90 | 91 | All rights reserved. 92 | 93 | Redistribution and use in source and binary forms, with or without 94 | modification, are permitted provided that the following conditions are met: 95 | 96 | 1. Redistributions of source code must retain the above copyright 97 | notice, this list of conditions and the following disclaimer. 98 | 99 | 2. Redistributions in binary form must reproduce the above copyright 100 | notice, this list of conditions and the following disclaimer in the 101 | documentation and/or other materials provided with the distribution. 102 | 103 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 104 | and IDIAP Research Institute nor the names of its contributors may be 105 | used to endorse or promote products derived from this software without 106 | specific prior written permission. 107 | 108 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 109 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 110 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 111 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 112 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 113 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 114 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 115 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 116 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 117 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 118 | POSSIBILITY OF SUCH DAMAGE. 119 | 120 | ### 2. FastKDE 121 | 122 | - **URL**: https://github.com/LBL-EESA/fastkde/blob/main/LICENSE.txt 123 | 124 | - **License**: Lawrence Berkeley National Laboratory (“LBNL”) Non-Commercial Use Only License 125 | 126 | Full FastKDE license text: 127 | 128 | LAWRENCE BERKELEY NATIONAL LABORATORY 129 | RESEARCH & DEVELOPMENT, NON-COMMERCIAL USE ONLY, LICENSE 130 | 131 | Copyright (c) 2015, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. 
132 | 133 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 134 | 135 | (1) Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 136 | 137 | (2) Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 138 | 139 | (3) Neither the name of the University of California, Lawrence Berkeley National Laboratory, U.S. Dept. of Energy nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 140 | 141 | (4) Use of the software, in source or binary form is FOR RESEARCH & DEVELOPMENT, NON-COMMERCIAL USE, PURPOSES ONLY. All commercial use rights for the software are hereby reserved. A separate commercial use license is available from Lawrence Berkeley National Laboratory. 142 | (5) In the event you create any bug fixes, patches, upgrades, updates, modifications, derivative works or enhancements to the source code or binary code of the software ("Enhancements") you hereby grant The Regents of the University of California and the U.S. Government a paid-up, non-exclusive, irrevocable, worldwide license in the Enhancements to reproduce, prepare derivative works, distribute copies to the public, perform publicly and display publicly, and to permit others to do so. 143 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 144 | **_ Copyright Notice _** 145 | FastKDE v1.0, Copyright (c) 2015, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. 146 | If you have questions about your rights to use or distribute this software, please contact Berkeley Lab's Innovation & Partnerships Office at IPO@lbl.gov. 147 | NOTICE. This software was developed under funding from the U.S. Department of Energy. As such, the U.S. Government has been granted for itself and others acting on its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the Software to reproduce, prepare derivative works, and perform publicly and display publicly. Beginning five (5) years after the date permission to assert copyright is obtained from the U.S. Department of Energy, and subject to any subsequent five (5) year renewals, the U.S. Government is granted for itself and others acting on its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the Software to reproduce, prepare derivative works, distribute copies to the public, perform publicly and display publicly, and to permit others to do so. 
148 | 149 | ### 3. pyvinecopulib 150 | 151 | - **URL**: https://github.com/vinecopulib/pyvinecopulib/blob/main/LICENSE 152 | 153 | Full pyvinecopulib license text: 154 | 155 | The MIT License (MIT) 156 | 157 | Copyright © 2019- Thomas Nagler and Thibault Vatter 158 | 159 | Permission is hereby granted, free of charge, to any person obtaining a copy of 160 | this software and associated documentation files (the “Software”), to deal in 161 | the Software without restriction, including without limitation the rights to 162 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 163 | the Software, and to permit persons to whom the Software is furnished to do so, 164 | subject to the following conditions: 165 | 166 | The above copyright notice and this permission notice shall be included in all 167 | copies or substantial portions of the Software. 168 | 169 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 170 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 171 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 172 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 173 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 174 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 175 | 176 | --- 177 | -------------------------------------------------------------------------------- /tests/test_bicop.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import pytest 4 | import torch 5 | 6 | import torchvinecopulib as tvc 7 | 8 | from . import EPS, U_tensor, bicop_pair 9 | 10 | 11 | def test_device_and_dtype(): 12 | cop = tvc.BiCop(num_step_grid=16) 13 | # by default on CPU float64 14 | assert cop.device.type == "cpu" 15 | assert cop.dtype is torch.float64 16 | 17 | if torch.cuda.is_available(): 18 | cop = tvc.BiCop(num_step_grid=16).cuda() 19 | assert cop.device.type == "cuda" 20 | 21 | 22 | def test_monotonicity_and_range(bicop_pair): 23 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair 24 | 25 | # pick one of the two implementations or loop both 26 | for bicop in (bc_fast, bc_tll): 27 | # * simple diagonal check 28 | grid = torch.linspace(EPS, 1.0 - EPS, 100, device=U.device, dtype=torch.float64).unsqueeze( 29 | 1 30 | ) 31 | pts = torch.hstack([grid, grid]) 32 | out = bicop.cdf(pts) 33 | assert out.min() >= -EPS and out.max() <= 1 + EPS 34 | assert out.diff(0).min() >= -EPS 35 | 36 | # * row slices + left/right variants 37 | for u in grid: 38 | pts = torch.hstack([grid, u.repeat(grid.size(0), 1)]) 39 | for pts in (pts, pts.flip(1)): 40 | for fn in (bicop.cdf, bicop.hfunc_r, bicop.hinv_r): 41 | v = fn(pts) 42 | assert v.min() >= -EPS and v.max() <= 1 + EPS 43 | assert v.diff(0).min() >= -EPS 44 | 45 | 46 | def test_inversion(bicop_pair): 47 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair 48 | 49 | for bicop in (bc_fast, bc_tll): 50 | grid = torch.linspace(0.1, 0.9, 50, device=U.device, dtype=torch.float64).unsqueeze(1) 51 | for u in grid: 52 | pts = torch.hstack([grid, u.repeat(grid.size(0), 1)]) 53 | 54 | # * right‐side inverse 55 | rec0 = bicop.hinv_r(torch.hstack([bicop.hfunc_r(pts), pts[:, [1]]])) 56 | assert torch.allclose(rec0, pts[:, [0]], atol=1e-3) 57 | 58 | rec1 = bicop.hfunc_r(torch.hstack([bicop.hinv_r(pts), pts[:, [1]]])) 59 | assert torch.allclose(rec1, pts[:, [0]], atol=1e-3) 60 
| 61 | # * left‐side inverse 62 | pts_rev = pts.flip(1) 63 | rec2 = bicop.hinv_l(torch.hstack([pts_rev[:, [0]], bicop.hfunc_l(pts_rev)])) 64 | assert torch.allclose(rec2, pts_rev[:, [1]], atol=1e-3) 65 | 66 | rec3 = bicop.hfunc_l(torch.hstack([pts_rev[:, [0]], bicop.hinv_l(pts_rev)])) 67 | assert torch.allclose(rec3, pts_rev[:, [1]], atol=1e-3) 68 | 69 | 70 | def test_pdf_integrates_to_one(bicop_pair): 71 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair 72 | for cop in (bc_fast, bc_tll): 73 | # our grid is uniform on [0,1]² with spacing Δ = 1/(N−1) 74 | Δ = 1.0 / (cop.num_step_grid - 1) 75 | # approximate ∫ pdf(u,v) du dv ≈ Σ_pdf_grid * Δ² 76 | approx_mass = (cop._pdf_grid.sum() * Δ**2).item() 77 | assert pytest.approx(expected=1.0, rel=1e-2) == approx_mass 78 | # non-negativity 79 | assert (cop._pdf_grid >= -EPS).all() 80 | 81 | 82 | def test_log_pdf_matches_log_of_pdf(bicop_pair): 83 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair 84 | for cop in (bc_fast, bc_tll): 85 | pts = torch.rand(500, 2, dtype=torch.float64, device=cop.device) 86 | pdf = cop.pdf(pts) 87 | logp = cop.log_pdf(pts) 88 | # where pdf>0, log_pdf == log(pdf) 89 | mask = pdf.squeeze(1) > 0 90 | assert torch.allclose(logp[mask], pdf[mask].log(), atol=1e-6) 91 | 92 | 93 | def test_log_pdf_handles_zero(): 94 | cop = tvc.BiCop(num_step_grid=4) 95 | cop.is_indep = False 96 | # ! monkey‐patch pdf to always return zero 97 | cop.pdf = lambda obs: torch.zeros(obs.shape[0], 1, dtype=torch.float64) 98 | pts = torch.rand(100, 2, dtype=torch.float64) 99 | logp = cop.log_pdf(pts) 100 | # every entry should equal the neg‐infinity replacement 101 | assert torch.all(logp == -13.815510557964274) 102 | 103 | 104 | def test_sample_marginals(bicop_pair): 105 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair 106 | for cop in (bc_fast, bc_tll): 107 | for is_sobol in (False, True): 108 | samp = cop.sample(2000, seed=0, is_sobol=is_sobol) 109 | # samples lie in [0,1] 110 | assert samp.min() >= 0.0 and samp.max() <= 1.0 111 | # * marginal histograms should be roughly uniform 112 | counts_u = torch.histc(samp[:, 0], bins=10, min=0, max=1) 113 | counts_v = torch.histc(samp[:, 1], bins=10, min=0, max=1) 114 | # each bin ~200 ± 5 σ (σ≈√(N·p·(1−p))≈√(2000·0.1·0.9)≈13.4) 115 | assert counts_u.std() < 20 116 | assert counts_v.std() < 20 117 | 118 | 119 | def test_internal_buffers_and_flags(bicop_pair): 120 | _, _, _, U, bc_fast, bc_tll = bicop_pair 121 | for cop, mtd_kde in [(bc_fast, "fastKDE"), (bc_tll, "tll")]: 122 | print(cop) 123 | assert not cop.is_indep 124 | assert cop.mtd_kde == mtd_kde 125 | assert cop.num_obs == U.shape[0] 126 | # all the pre‐computed grids are the right shape 127 | m = cop.num_step_grid 128 | for name in ("_pdf_grid", "_cdf_grid", "_hfunc_l_grid", "_hfunc_r_grid"): 129 | grid = getattr(cop, name) 130 | assert grid.shape == (m, m) 131 | 132 | 133 | def test_tau_estimation(bicop_pair): 134 | _, _, _, U, bc_fast, bc_mtd_kde = bicop_pair 135 | # re‐fit with tau estimation 136 | bc = tvc.BiCop(num_step_grid=64) 137 | bc.fit(U, mtd_kde="tll", is_tau_est=True) 138 | # kendalltau must be nonzero for dependent data 139 | assert bc.tau[0].abs().item() > 0 140 | assert bc.tau[1].abs().item() >= 0 141 | 142 | 143 | def test_sample_shape_and_dtype_on_tll(bicop_pair): 144 | _, _, _, U, bc_fast, bc_tll = bicop_pair 145 | for cop in (bc_fast, bc_tll): 146 | s = cop.sample(123, seed=7, is_sobol=True) 147 | assert s.shape == (123, 2) 148 | assert s.dtype is cop.dtype 149 | assert s.device == cop.device 150 | 
151 | 152 | def test_imshow_and_plot_api(bicop_pair): 153 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair 154 | cop = bc_fast 155 | # imshow 156 | fig, ax = cop.imshow(is_log_pdf=True) 157 | assert isinstance(fig, matplotlib.figure.Figure) 158 | assert isinstance(ax, matplotlib.axes.Axes) 159 | plt.close(fig) 160 | 161 | # contour 162 | fig2, ax2 = cop.plot(plot_type="contour", margin_type="unif") 163 | assert isinstance(fig2, matplotlib.figure.Figure) 164 | assert isinstance(ax2, matplotlib.axes.Axes) 165 | plt.close(fig2) 166 | fig2, ax2 = cop.plot(plot_type="contour", margin_type="norm") 167 | assert isinstance(fig2, matplotlib.figure.Figure) 168 | assert isinstance(ax2, matplotlib.axes.Axes) 169 | plt.close(fig2) 170 | 171 | # surface 172 | fig3, ax3 = cop.plot(plot_type="surface", margin_type="unif") 173 | assert isinstance(fig3, matplotlib.figure.Figure) 174 | assert isinstance(ax3, matplotlib.axes.Axes) 175 | plt.close(fig3) 176 | fig3, ax3 = cop.plot(plot_type="surface", margin_type="norm") 177 | assert isinstance(fig3, matplotlib.figure.Figure) 178 | assert isinstance(ax3, matplotlib.axes.Axes) 179 | plt.close(fig3) 180 | 181 | # invalid args 182 | with pytest.raises(ValueError): 183 | cop.plot(plot_type="foo") 184 | with pytest.raises(ValueError): 185 | cop.plot(margin_type="bar") 186 | 187 | 188 | def test_plot_accepts_unused_kwargs(bicop_pair): 189 | _, _, _, U, bc_fast, _ = bicop_pair 190 | # just ensure it doesn’t crash 191 | bc_fast.plot(plot_type="contour", margin_type="norm", xylim=(0, 1), grid_size=50) 192 | bc_fast.plot(plot_type="surface", margin_type="unif", xylim=(0, 1), grid_size=20) 193 | 194 | 195 | def test_reset_and_str(bicop_pair): 196 | # ! notice scope="module" so we put this test at the end 197 | family, params, rotation, U, bc_fast, bc_tll = bicop_pair 198 | for cop in (bc_fast, bc_tll): 199 | cop.reset() 200 | # should go back to independent 201 | assert cop.is_indep 202 | assert cop.num_obs == 0 203 | # __str__ contains key fields 204 | s = str(cop) 205 | assert "is_indep" in s and "num_obs" in s and "mtd_kde" in s 206 | 207 | 208 | @pytest.mark.parametrize("method", ["constant", "linear", "quadratic"]) 209 | def test_tll_methods_do_not_crash(U_tensor, method): 210 | cop = tvc.BiCop(num_step_grid=32) 211 | # should _not_ raise for any of the valid nonparametric_method names 212 | cop.fit(U_tensor, mtd_kde="tll", mtd_tll=method) 213 | 214 | 215 | def test_fit_invalid_method_raises(U_tensor): 216 | cop = tvc.BiCop(num_step_grid=32) 217 | with pytest.raises(RuntimeError): 218 | # pick something bogus 219 | cop.fit(U_tensor, mtd_kde="tll", mtd_tll="no_such_method") 220 | 221 | 222 | def test_interp_on_trivial_grid(): 223 | # make a BiCop with a 2×2 grid 224 | bc = tvc.BiCop(num_step_grid=2) 225 | # override the geometry so that step_grid == 1.0 and target == 1 226 | bc.step_grid = 1.0 227 | bc._target = 1.0 228 | bc._EPS = 0.0 # so we don't get any clamping at the edges 229 | 230 | # grid: 231 | # g00 = 0, g01 = 1 232 | # g10 = 2, g11 = 3 233 | grid = torch.tensor([[0.0, 1.0], [2.0, 3.0]], dtype=torch.float64) 234 | 235 | # corners should map exactly: 236 | pts = torch.tensor( 237 | [ 238 | [0.0, 0.0], # g00 239 | [0.0, 1.0], # g01 240 | [1.0, 0.0], # g10 241 | [1.0, 1.0], # g11 242 | ], 243 | dtype=torch.float64, 244 | ) 245 | out = bc._interp(grid, pts) 246 | assert torch.allclose(out, torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float64)) 247 | 248 | # center point should average correctly: 249 | center = torch.tensor([[0.5, 0.5]], 
dtype=torch.float64) 250 | val = bc._interp(grid, center) 251 | # manual bilinear: 0 + (2−0)*.5 + (1−0)*.5 + (3−1−2+0)*.5*.5 = 1.5 252 | assert torch.allclose(val, torch.tensor([[1.5]], dtype=torch.float64)) 253 | 254 | # if you ask for out-of-bounds it should clamp to [0,1] and then pick the corner: 255 | # e.g. (−1, −1)→(0,0), (2,3)→(1,1) 256 | oob = torch.tensor([[-1.0, -1.0], [2.0, 3.0]], dtype=torch.float64) 257 | val_oob = bc._interp(grid, oob) 258 | assert torch.allclose(val_oob, torch.tensor([0.0, 3.0], dtype=torch.float64)) 259 | 260 | 261 | def test_imshow_with_existing_axes(): 262 | cop = tvc.BiCop(num_step_grid=32) 263 | us = torch.rand(100, 2) 264 | cop.fit(us, mtd_kde="fastKDE") 265 | fig, outer_ax = plt.subplots() 266 | fig2, ax2 = cop.imshow(is_log_pdf=False, ax=outer_ax, cmap="viridis") 267 | # should have returned the same axes object 268 | assert ax2 is outer_ax 269 | plt.close(fig) 270 | plt.close(fig2) 271 | 272 | 273 | def test_independent_copula_properties(): 274 | for cop in (tvc.BiCop(num_step_grid=64), tvc.BiCop(num_step_grid=16)): 275 | # before fit, should be independent 276 | # CDF(u,v) = u·v, PDF(u,v)=1, hfunc_r(u,v)=u, hfunc_l(u,v)=v 277 | us = torch.rand(1000, 2, dtype=torch.float64) 278 | cdf = cop.cdf(us) 279 | assert torch.allclose(cdf, (us[:, 0] * us[:, 1]).unsqueeze(1)) 280 | pdf = cop.pdf(us) 281 | assert torch.allclose(pdf, torch.ones_like(pdf)) 282 | logpdf = cop.log_pdf(us) 283 | assert torch.allclose(logpdf, torch.zeros_like(logpdf)) 284 | hr = cop.hfunc_r(us) 285 | hl = cop.hfunc_l(us) 286 | assert torch.allclose(hr, us[:, [0]]) 287 | assert torch.allclose(hl, us[:, [1]]) 288 | hir = cop.hinv_r(us) 289 | hil = cop.hinv_l(us) 290 | assert torch.allclose(hir, us[:, [0]]) 291 | assert torch.allclose(hil, us[:, [1]]) 292 | -------------------------------------------------------------------------------- /examples/pred_intvl/models.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | from blitz.modules import BayesianLinear 6 | from blitz.utils import variational_estimator 7 | from torch.utils.data import DataLoader 8 | 9 | 10 | class EncoderRegressor(nn.Module): 11 | def __init__( 12 | self, 13 | input_dim: int = 8, 14 | latent_dim: int = 10, 15 | hidden_dim1: int = 128, 16 | hidden_dim2: int = 64, 17 | hidden_dim3: int = 32, 18 | hidden_dim4: int = 16, 19 | hidden_dim5: int = 8, 20 | head_dim: int = 16, 21 | p_drop: float = 0.3, 22 | ): 23 | super().__init__() 24 | # * MLP encoder mapping ℝ⁸ → ℝ^{latent_dim} 25 | self.encoder = nn.Sequential( 26 | nn.Linear(input_dim, hidden_dim1), 27 | nn.BatchNorm1d(hidden_dim1), 28 | nn.LeakyReLU(0.1, inplace=True), 29 | nn.Dropout(p_drop), 30 | # 31 | nn.Linear(hidden_dim1, hidden_dim2), 32 | nn.BatchNorm1d(hidden_dim2), 33 | nn.LeakyReLU(0.1, inplace=True), 34 | nn.Dropout(p_drop), 35 | # 36 | nn.Linear(hidden_dim2, latent_dim), 37 | nn.BatchNorm1d(latent_dim), 38 | nn.LeakyReLU(0.1, inplace=True), 39 | nn.Dropout(p_drop), 40 | # 41 | nn.Linear(latent_dim, hidden_dim3), 42 | nn.BatchNorm1d(hidden_dim3), 43 | nn.LeakyReLU(0.1, inplace=True), 44 | nn.Dropout(p_drop), 45 | # 46 | nn.Linear(hidden_dim3, hidden_dim4), 47 | nn.BatchNorm1d(hidden_dim4), 48 | nn.LeakyReLU(0.1, inplace=True), 49 | nn.Dropout(p_drop), 50 | # 51 | nn.Linear(hidden_dim4, hidden_dim5), 52 | nn.BatchNorm1d(hidden_dim5), 53 | nn.LeakyReLU(0.1, inplace=True), 54 | nn.Dropout(p_drop), 55 | # 56 | nn.Linear(hidden_dim5, latent_dim), 57 | 
nn.BatchNorm1d(latent_dim), 58 | nn.Tanh(), # keeps latent coords in [–1,1] 59 | ) 60 | # --- head: ℝ^{latent_dim} → ℝ (house value) --- 61 | self.head = nn.Sequential( 62 | nn.Linear(latent_dim, head_dim), 63 | nn.LeakyReLU(0.1, inplace=True), 64 | nn.Dropout(p_drop / 2), 65 | nn.Linear(head_dim, 1), 66 | ) 67 | 68 | def forward(self, x: torch.Tensor): 69 | # x: [batch, 8] 70 | z = self.encoder(x) # [batch, latent_dim] 71 | y_hat = self.head(z) # [batch, num_outputs] 72 | return y_hat, z 73 | 74 | 75 | def train_epoch(model, loader, optim, criterion, device): 76 | model.train() 77 | total_loss = 0.0 78 | for x, y in loader: 79 | x, y = x.to(device), y.to(device) 80 | optim.zero_grad() # * claer out grad from last step 81 | y_hat, _ = model(x) 82 | # * if using MSELoss, we want y to be float and same shape 83 | loss = criterion( 84 | y_hat, 85 | y.float() if isinstance(criterion, nn.MSELoss) else y, 86 | ) 87 | loss.backward() # * compute new grad 88 | optim.step() # * update weights with fresh grad 89 | total_loss += loss.item() * x.size(0) 90 | return total_loss / len(loader.dataset) 91 | 92 | 93 | @torch.no_grad() 94 | def eval_epoch(model, loader, criterion, device): 95 | model.eval() 96 | total_loss = 0.0 97 | for x, y in loader: 98 | x, y = x.to(device), y.to(device) 99 | y_hat, _ = model(x) 100 | # * if using MSELoss, we want y to be float and same shape 101 | total_loss += criterion( 102 | y_hat, 103 | y.float() if isinstance(criterion, nn.MSELoss) else y, 104 | ).item() * x.size(0) 105 | return total_loss / len(loader.dataset) 106 | 107 | 108 | def train_base_model( 109 | train_loader, 110 | val_loader, 111 | input_dim, 112 | lr=1e-3, 113 | weight_decay=0.91, 114 | p_drop=0.5, 115 | latent_dim=10, 116 | num_epoch=100, 117 | patience=5, 118 | device=None, 119 | ): 120 | model = EncoderRegressor(input_dim=input_dim, latent_dim=latent_dim, p_drop=p_drop).to(device) 121 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 122 | criterion = nn.MSELoss() 123 | best_state = copy.deepcopy(model.state_dict()) 124 | best_loss = float("inf") 125 | wait = 0 126 | for epoch in range(1, num_epoch + 1): 127 | train_loss = train_epoch(model, train_loader, optimizer, criterion, device) 128 | val_loss = eval_epoch(model, val_loader, criterion, device) 129 | if val_loss < best_loss: 130 | best_loss = val_loss 131 | best_state = copy.deepcopy(model.state_dict()) 132 | wait = 0 133 | else: 134 | wait += 1 135 | if wait > patience: 136 | # print(f"Early stopping at epoch {epoch}") 137 | break 138 | # print(f"Epoch {epoch:2d} — train loss: {loss:.4f}") 139 | model.load_state_dict(best_state) 140 | return model.to(device) 141 | 142 | 143 | @torch.no_grad() 144 | def extract_features(model, loader, device): 145 | model.eval() 146 | all_z, all_y = [], [] 147 | for x, y in loader: 148 | x = x.to(device) 149 | _, z = model(x) 150 | all_z.append(z.cpu()) 151 | all_y.append(y.cpu().float()) # ensure shape [B,1] 152 | Z = torch.cat(all_z, dim=0) # [N, latent_dim] 153 | Y = torch.cat(all_y, dim=0) # [N, 1] 154 | return Z, Y 155 | 156 | 157 | def train_ensemble( 158 | M, 159 | train_loader: DataLoader, 160 | val_loader: DataLoader, 161 | input_dim, 162 | latent_dim, 163 | device, 164 | num_epochs, 165 | lr=1e-3, 166 | weight_decay=0.91, 167 | patience=5, 168 | ): 169 | ensemble = nn.ModuleList() 170 | for m in range(M): 171 | torch.manual_seed(m) 172 | model_m = train_base_model( 173 | train_loader=train_loader, 174 | val_loader=val_loader, 175 | input_dim=input_dim, 176 | 
latent_dim=latent_dim, 177 | lr=lr, 178 | weight_decay=weight_decay, 179 | p_drop=0.0, 180 | num_epoch=num_epochs, 181 | patience=patience, 182 | device=device, 183 | ) 184 | ensemble.append(model_m.cpu()) 185 | return ensemble.to(device) 186 | 187 | 188 | @torch.no_grad() 189 | def ensemble_pred_intvl(ensemble, x, device, alpha=0.05): 190 | preds = [] 191 | for model in ensemble: 192 | model = model.to(device) 193 | y_hat, _ = model(x.to(device)) # [batch,1] 194 | preds.append(y_hat.cpu()) 195 | P = torch.stack(preds, dim=0) # [M, batch, 1] 196 | mean = P.mean(dim=0) # [batch,1] 197 | lo = P.quantile(alpha / 2, dim=0) # [batch,1] 198 | hi = P.quantile(1 - alpha / 2, dim=0) # [batch,1] 199 | return ( 200 | mean.cpu().flatten().numpy(), 201 | lo.cpu().flatten().numpy(), 202 | hi.cpu().flatten().numpy(), 203 | ) 204 | 205 | 206 | @torch.no_grad() 207 | def mc_dropout_pred_intvl(model, x, T=100, device=None, alpha=0.05): 208 | """ 209 | x: Tensor [batch,1,28,28] 210 | returns mean, lower, upper: each [batch,1] 211 | """ 212 | model.eval() 213 | x = x.to(device) 214 | 215 | # * re-enable dropout layers 216 | for m in model.modules(): 217 | if isinstance(m, nn.Dropout): 218 | m.train() 219 | 220 | # * collect T predictions 221 | preds = [] 222 | for _ in range(T): 223 | y_hat, _ = model(x) # [batch,1] 224 | preds.append(y_hat.cpu()) 225 | P = torch.stack(preds, dim=0) # [T, batch, 1] 226 | 227 | mu = P.mean(dim=0) # [batch,1] 228 | lower = P.quantile(alpha / 2, dim=0) # [batch,1] 229 | upper = P.quantile(1 - alpha / 2, dim=0) # [batch,1] 230 | return ( 231 | mu.cpu().flatten().numpy(), 232 | lower.cpu().flatten().numpy(), 233 | upper.cpu().flatten().numpy(), 234 | ) 235 | 236 | 237 | @variational_estimator 238 | class BayesianEncoderRegressor(nn.Module): 239 | def __init__( 240 | self, 241 | input_dim: int, 242 | latent_dim=10, 243 | hidden_dim1=128, 244 | hidden_dim2=64, 245 | hidden_dim3=32, 246 | hidden_dim4=16, 247 | hidden_dim5=8, 248 | head_dim=16, 249 | num_outputs=1, 250 | ): 251 | super().__init__() 252 | # -- Encoder: vector 2 latent z 253 | self._encoder = nn.Sequential( 254 | BayesianLinear(input_dim, hidden_dim1), 255 | nn.LeakyReLU(inplace=True), 256 | BayesianLinear(hidden_dim1, hidden_dim2), 257 | nn.LeakyReLU(inplace=True), 258 | BayesianLinear(hidden_dim2, hidden_dim3), 259 | nn.LeakyReLU(inplace=True), 260 | BayesianLinear(hidden_dim3, hidden_dim4), 261 | nn.LeakyReLU(inplace=True), 262 | BayesianLinear(hidden_dim4, hidden_dim5), 263 | nn.LeakyReLU(inplace=True), 264 | BayesianLinear(hidden_dim5, latent_dim), 265 | nn.Tanh(), # keeps latent coords in [–1,1] 266 | ) 267 | # -- Head: z 2 y 268 | self._head = nn.Sequential( 269 | BayesianLinear(latent_dim, head_dim), 270 | nn.LeakyReLU(0.1, inplace=True), 271 | BayesianLinear(head_dim, num_outputs), 272 | ) 273 | 274 | def forward(self, x): 275 | # x: [batch, input_dim] 276 | z = self._encoder(x) 277 | y = self._head(z) 278 | return y 279 | 280 | def encode(self, x): 281 | return self._encoder(x) 282 | 283 | 284 | def train_bnn( 285 | train_loader: DataLoader, 286 | val_loader: DataLoader, 287 | input_dim, 288 | latent_dim, 289 | num_epochs, 290 | lr=1e-3, 291 | weight_decay=0.91, 292 | device=None, 293 | patience=5, 294 | ): 295 | """Trains one BayesianEncoderRegressor with ELBO-loss and early stopping on 296 | validation MSE. 297 | 298 | Returns the model with best val-MSE. 
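    Example (a minimal sketch; ``train_loader``, ``val_loader``, ``x_batch`` and the
    dimensions below are illustrative assumptions):

        >>> bnn = train_bnn(train_loader, val_loader, input_dim=8, latent_dim=10,
        ...                 num_epochs=50, device="cpu")
        >>> mu, lo, hi = bnn_pred_intvl(bnn, x_batch, T=200, alpha=0.05, device="cpu")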
299 | """ 300 | model = BayesianEncoderRegressor( 301 | input_dim=input_dim, 302 | latent_dim=latent_dim, 303 | ).to(device) 304 | optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 305 | criterion = nn.MSELoss() 306 | best_state = copy.deepcopy(model.state_dict()) 307 | best_val_loss = float("inf") 308 | wait = 0 309 | for epoch in range(num_epochs): 310 | model.train() 311 | for x, y in train_loader: 312 | x, y = x.to(device), y.to(device) 313 | optim.zero_grad() 314 | # * sample_elbo will call forward(x) → y 315 | loss = model.sample_elbo( 316 | inputs=x, 317 | labels=y, 318 | criterion=criterion, 319 | sample_nbr=3, 320 | complexity_cost_weight=1 / len(train_loader.dataset), 321 | ) 322 | loss.backward() 323 | optim.step() 324 | model.eval() 325 | val_losses = [] 326 | with torch.no_grad(): 327 | for x, y in val_loader: 328 | x, y = x.to(device), y.to(device) 329 | preds = [] 330 | for _ in range(10): 331 | preds.append(model(x)) 332 | y_hat = torch.stack(preds, 0).mean(dim=0).to(device) 333 | val_losses.append(criterion(y_hat, y.float()).item() * x.size(0)) 334 | val_loss = sum(val_losses) / len(val_loader.dataset) 335 | if val_loss < best_val_loss - 1e-6: 336 | best_val_loss = val_loss 337 | best_state = copy.deepcopy(model.state_dict()) 338 | wait = 0 339 | else: 340 | wait += 1 341 | if wait >= patience: 342 | # print(f"Early stopping at epoch {epoch}") 343 | break 344 | model.load_state_dict(best_state) 345 | return model.to(device) 346 | 347 | 348 | @torch.no_grad() 349 | def bnn_pred_intvl(model, x, T=200, alpha=0.05, device=None): 350 | model.train() # keep Bayesian layers sampling 351 | preds = [] 352 | for _ in range(T): 353 | out = model(x.to(device)) 354 | # out might be y, or (y,z), or (y,z,kl) — we only need the first thing 355 | y_hat = out if isinstance(out, torch.Tensor) else out[0] 356 | preds.append(y_hat.cpu()) 357 | P = torch.stack(preds, 0) # [T, batch, 1] 358 | mu = P.mean(dim=0) # [batch,1] 359 | lower = P.quantile(alpha / 2, dim=0) 360 | upper = P.quantile(1 - alpha / 2, dim=0) 361 | return ( 362 | mu.cpu().flatten().numpy(), 363 | lower.cpu().flatten().numpy(), 364 | upper.cpu().flatten().numpy(), 365 | ) 366 | -------------------------------------------------------------------------------- /examples/vcae/vcae/model.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import copy 3 | from dataclasses import asdict 4 | from typing import Optional 5 | 6 | import pytorch_lightning as pl 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torch.utils.data import DataLoader, Dataset, random_split 11 | from torchvision import transforms 12 | from torchvision.datasets import MNIST, SVHN 13 | 14 | import torchvinecopulib as tvc 15 | 16 | from .config import DEVICE, Config 17 | 18 | 19 | class LitAutoencoder(pl.LightningModule, abc.ABC): 20 | def __init__(self, config: Config) -> None: 21 | """Initialize the autoencoder with the given configuration.""" 22 | super().__init__() 23 | self.save_hyperparameters(asdict(config)) 24 | self.flat_dim = int(torch.prod(torch.tensor(self.hparams["dims"]))) 25 | 26 | # Placeholders for data attributes 27 | self.data_test: Optional[Dataset] = None 28 | self.data_train: Optional[Dataset] = None 29 | self.data_val: Optional[Dataset] = None 30 | 31 | # Placeholder for the vine copula 32 | self.vine: Optional[tvc.VineCop] = None 33 | 34 | # Call subclass-defined builders 35 | self.encoder = self.build_encoder() 36 | 
self.decoder = self.build_decoder() 37 | self.transform = self.build_transform() 38 | 39 | @property 40 | @abc.abstractmethod 41 | def dataset_cls(self) -> type: 42 | """Subclasses must return the dataset class (e.g., MNIST, SVHN).""" 43 | raise NotImplementedError 44 | 45 | @property 46 | @abc.abstractmethod 47 | def dataset_kwargs(self) -> dict: 48 | """Subclasses must return a dictionary of keyword arguments for the dataset.""" 49 | raise NotImplementedError 50 | 51 | @abc.abstractmethod 52 | def build_encoder(self) -> nn.Module: 53 | """Subclasses must return an nn.Module mapping x -> z""" 54 | raise NotImplementedError 55 | 56 | @abc.abstractmethod 57 | def build_decoder(self) -> nn.Module: 58 | """Subclasses must return an nn.Module mapping z -> x̂""" 59 | raise NotImplementedError 60 | 61 | @abc.abstractmethod 62 | def build_transform(self) -> transforms.Compose: 63 | """Subclasses must return a torchvision transforms.Compose for data preprocessing.""" 64 | raise NotImplementedError 65 | 66 | def copy_with_config(self, new_config: Config) -> "LitAutoencoder": 67 | """Create a copy of the model with a new configuration (for refit).""" 68 | new_model = self.__class__(new_config) 69 | new_model.encoder.load_state_dict(self.encoder.state_dict()) 70 | new_model.decoder.load_state_dict(self.decoder.state_dict()) 71 | if self.vine is not None: 72 | new_model.vine = copy.deepcopy(self.vine) 73 | return new_model 74 | 75 | def set_vine(self, vine: tvc.VineCop) -> None: 76 | if not isinstance(vine, tvc.VineCop): 77 | raise ValueError("Vine must be of type tvc.VineCop for tvc.") 78 | latent_size: int = self.hparams["latent_size"] 79 | if not vine.num_dim == latent_size: 80 | raise ValueError( 81 | f"Vine dimension {vine.num_dim} does not match latent size {latent_size}." 82 | ) 83 | self.vine = vine 84 | self.add_module("vine", vine) 85 | 86 | def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]: 87 | latent_size: int = self.hparams["latent_size"] 88 | dims: tuple[int, ...] 
= self.hparams["dims"] 89 | z = self.encoder(x) 90 | x_hat = self.decoder(z) 91 | return x_hat.view(-1, *dims), z.view(-1, latent_size) 92 | 93 | def compute_loss(self, x: torch.Tensor) -> torch.Tensor: 94 | x_hat, z = self(x) 95 | loss = F.mse_loss(x_hat, x) 96 | if self.hparams["vine_lambda"] > 0: 97 | if self.vine is None: 98 | raise ValueError("Vine must be set before computing the loss.") 99 | vine_loss = -self.vine.log_pdf(z).mean() 100 | vine_lambda: float = self.hparams["vine_lambda"] 101 | loss += vine_lambda * vine_loss 102 | # use_mmd: bool = self.hparams["use_mmd"] 103 | # if use_mmd: 104 | # mmd_sigmas: list = self.hparams["mmd_sigmas"] 105 | # mmd_lambda: float = self.hparams["mmd_lambda"] 106 | # z_vine = self.vine.sample(x.shape[0]) 107 | # z_vine = torch.tensor(z_vine, dtype=z.dtype, device=x.device) 108 | # x_vine = self.decoder(z_vine) 109 | # mmd_loss = mmd(x, x_vine, sigmas=mmd_sigmas) 110 | # loss += mmd_lambda * mmd_loss 111 | return loss 112 | 113 | def training_step( 114 | self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int 115 | ) -> torch.Tensor: 116 | """Training step to compute loss on training data.""" 117 | x, _ = batch 118 | x.to(DEVICE) 119 | loss = self.compute_loss(x) 120 | self.log("train_loss", loss, prog_bar=True) 121 | return loss 122 | 123 | def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: 124 | """Validation step to compute loss on validation data.""" 125 | x, _ = batch 126 | x.to(DEVICE) 127 | loss = self.compute_loss(x) 128 | self.log("val_loss", loss, prog_bar=True) 129 | 130 | def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: 131 | """Test step to compute loss on test data.""" 132 | x, _ = batch 133 | x.to(DEVICE) 134 | loss = self.compute_loss(x) 135 | self.log("test_loss", loss, prog_bar=True) 136 | 137 | def configure_optimizers(self) -> torch.optim.Optimizer: 138 | """Returns the optimizer for training.""" 139 | learning_rate: float = self.hparams["learning_rate"] 140 | return torch.optim.Adam(self.parameters(), lr=learning_rate) 141 | 142 | def prepare_data(self) -> None: 143 | """Download the dataset if not already present.""" 144 | data_dir = self.hparams["data_dir"] 145 | self.dataset_cls(data_dir, download=True, **self.dataset_kwargs["train"]) 146 | self.dataset_cls(data_dir, download=True, **self.dataset_kwargs["test"]) 147 | 148 | def setup(self, stage=None) -> None: 149 | """Setup the datasets for training, validation, and testing.""" 150 | data_dir = self.hparams["data_dir"] 151 | 152 | if stage in ("fit", None): 153 | data_full = self.dataset_cls( 154 | data_dir, transform=self.transform, **self.dataset_kwargs["train"] 155 | ) 156 | n_total = len(data_full) 157 | n_val = int(self.hparams["val_train_split"] * n_total) 158 | n_train = n_total - n_val 159 | 160 | generator = torch.Generator().manual_seed(self.hparams["seed"]) 161 | self.data_train, self.data_val = random_split( 162 | data_full, [n_train, n_val], generator=generator 163 | ) 164 | 165 | if stage in ("test", None): 166 | self.data_test = self.dataset_cls( 167 | data_dir, transform=self.transform, **self.dataset_kwargs["test"] 168 | ) 169 | 170 | def train_dataloader(self) -> DataLoader: 171 | """Returns the training dataloader.""" 172 | if self.data_train is None: 173 | self.setup(stage="fit") 174 | assert self.data_train is not None 175 | batch_size: int = self.hparams["batch_size"] 176 | num_workers: int = self.hparams["num_workers"] 177 | return DataLoader( 178 | 
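# pin_memory and persistent_workers (below) speed up host-to-device transfer
# and avoid re-spawning worker processes every epoch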
self.data_train, 179 | batch_size=batch_size, 180 | pin_memory=True, 181 | persistent_workers=True, 182 | num_workers=num_workers, 183 | ) 184 | 185 | def val_dataloader(self) -> DataLoader: 186 | """Returns the validation dataloader.""" 187 | if self.data_val is None: 188 | self.setup(stage="fit") 189 | assert self.data_val is not None 190 | batch_size: int = self.hparams["batch_size"] 191 | num_workers: int = self.hparams["num_workers"] 192 | return DataLoader( 193 | self.data_val, 194 | batch_size=batch_size, 195 | pin_memory=True, 196 | persistent_workers=True, 197 | num_workers=num_workers, 198 | ) 199 | 200 | def test_dataloader(self) -> DataLoader: 201 | """Returns the test dataloader.""" 202 | if self.data_test is None: 203 | self.setup(stage="test") 204 | assert self.data_test is not None 205 | batch_size: int = self.hparams["batch_size"] 206 | num_workers: int = self.hparams["num_workers"] 207 | return DataLoader( 208 | self.data_test, 209 | batch_size=batch_size, 210 | pin_memory=True, 211 | persistent_workers=True, 212 | num_workers=num_workers, 213 | ) 214 | 215 | def get_data( 216 | self, stage: str = "fit" 217 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 218 | """Extracts representations, labels, data, decoded outputs, and samples (e.g., to compute metrics).""" 219 | if stage == "fit" or stage is None: 220 | data_loader = self.train_dataloader() 221 | elif stage == "test": 222 | data_loader = self.test_dataloader() 223 | representations = [] 224 | decoded = [] 225 | labels = [] 226 | data = [] 227 | samples = [] 228 | encoder_device = next(self.encoder.parameters()).device 229 | decoder_device = next(self.decoder.parameters()).device 230 | for batch in data_loader: 231 | x, y = batch 232 | x = x.to(encoder_device) 233 | with torch.no_grad(): 234 | z = self.encoder(x).to(decoder_device) 235 | x_hat = self.decoder(z) 236 | if self.vine is not None: 237 | sample = self.vine.sample(x.shape[0]) 238 | sample = self.decoder( 239 | torch.tensor(sample, dtype=z.dtype, device=decoder_device) 240 | ) 241 | decoded.append(x_hat) 242 | representations.append(z) 243 | labels.append(y) 244 | data.append(x) 245 | if self.vine is not None: 246 | samples.append(sample) 247 | 248 | # Concatenate into a single tensor 249 | representations_tensor = torch.cat(representations, dim=0) 250 | labels_tensor = torch.cat(labels, dim=0) 251 | data_tensor = torch.cat(data, dim=0).flatten(start_dim=1).flatten(start_dim=1) 252 | decoded_tensor = torch.cat(decoded, dim=0).flatten(start_dim=1) 253 | samples_tensor = ( 254 | torch.cat(samples, dim=0).flatten(start_dim=1) if self.vine is not None else None 255 | ) 256 | 257 | return representations_tensor, labels_tensor, data_tensor, decoded_tensor, samples_tensor 258 | 259 | def learn_vine(self, n_samples: int = 5000) -> None: 260 | """Learn the vine copula from a subset of representations.""" 261 | if self.data_train is None: 262 | self.setup(stage="fit") 263 | representations, _, _, _, _ = self.get_data(stage="fit") 264 | 265 | representations_subset = representations[ 266 | torch.randperm(representations.shape[0])[:n_samples] 267 | ] 268 | vine_tvc = tvc.VineCop( 269 | num_dim=representations_subset.shape[1], 270 | is_cop_scale=False, 271 | num_step_grid=30, 272 | ).to(DEVICE) 273 | vine_tvc.fit( 274 | obs=representations_subset, 275 | mtd_kde="tll", 276 | ) 277 | self.set_vine(vine_tvc) 278 | 279 | 280 | class LitMNISTAutoencoder(LitAutoencoder): 281 | def __init__(self, config: Config) -> None: 282 | 
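"""Initialize the MNIST autoencoder with the given configuration."""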
super().__init__(config) 283 | 284 | def build_transform(self) -> transforms.Compose: 285 | """Returns a torchvision transforms.Compose for MNIST preprocessing.""" 286 | return transforms.Compose([transforms.ToTensor()]) 287 | 288 | @property 289 | def dataset_cls(self) -> type: 290 | """Returns the dataset class for MNIST.""" 291 | return MNIST 292 | 293 | @property 294 | def dataset_kwargs(self) -> dict: 295 | """Returns a dictionary of keyword arguments for the MNIST dataset.""" 296 | return { 297 | "train": {"train": True}, 298 | "test": {"train": False}, 299 | } 300 | 301 | def build_encoder(self) -> nn.Module: 302 | """Returns a fully connected encoder for MNIST.""" 303 | # Encoder: flatten → hidden → latent 304 | latent_size: int = self.hparams["latent_size"] 305 | hidden_size: int = self.hparams["hidden_size"] 306 | return nn.Sequential( 307 | nn.Flatten(), 308 | nn.Linear(self.flat_dim, hidden_size), 309 | nn.ReLU(), 310 | nn.Linear(hidden_size, hidden_size // 2), 311 | nn.ReLU(), 312 | nn.Linear(hidden_size // 2, latent_size), 313 | ).to(DEVICE) 314 | 315 | def build_decoder(self) -> nn.Module: 316 | """Returns a fully connected decoder for MNIST.""" 317 | # Decoder: latent → hidden → image 318 | latent_size: int = self.hparams["latent_size"] 319 | hidden_size: int = self.hparams["hidden_size"] 320 | return nn.Sequential( 321 | nn.Linear(latent_size, hidden_size // 2), 322 | nn.ReLU(), 323 | nn.Linear(hidden_size // 2, hidden_size), 324 | nn.ReLU(), 325 | nn.Linear(hidden_size, self.flat_dim), 326 | nn.Sigmoid(), # Ensure output in [0,1] range 327 | ).to(DEVICE) 328 | 329 | 330 | class LitSVHNAutoencoder(LitAutoencoder): 331 | def __init__(self, config: Config) -> None: 332 | super().__init__(config) 333 | 334 | def build_transform(self) -> transforms.Compose: 335 | """Returns a torchvision transforms.Compose for SVHN preprocessing.""" 336 | return transforms.Compose([transforms.ToTensor()]) 337 | 338 | @property 339 | def dataset_cls(self) -> type: 340 | """Returns the dataset class for SVHN.""" 341 | return SVHN 342 | 343 | @property 344 | def dataset_kwargs(self) -> dict: 345 | """Returns a dictionary of keyword arguments for the SVHN dataset.""" 346 | return { 347 | "train": {"split": "train"}, 348 | "test": {"split": "test"}, 349 | } 350 | 351 | def build_encoder(self) -> nn.Module: 352 | """Returns a convolutional encoder for SVHN.""" 353 | latent_size = self.hparams["latent_size"] 354 | hidden_size = self.hparams["hidden_size"] 355 | return nn.Sequential( 356 | nn.Conv2d( 357 | 3, hidden_size // 4, kernel_size=4, stride=2, padding=1 358 | ), # → [B, 32, 16, 16] 359 | nn.ReLU(), 360 | nn.Conv2d( 361 | hidden_size // 4, hidden_size // 2, kernel_size=4, stride=2, padding=1 362 | ), # → [B, 64, 8, 8] 363 | nn.ReLU(), 364 | nn.Conv2d( 365 | hidden_size // 2, hidden_size, kernel_size=4, stride=2, padding=1 366 | ), # → [B, 128, 4, 4] 367 | nn.ReLU(), 368 | nn.Flatten(), 369 | nn.Linear(hidden_size * 4 * 4, latent_size), 370 | ).to(DEVICE) 371 | 372 | def build_decoder(self) -> nn.Module: 373 | """Returns a convolutional decoder for SVHN.""" 374 | latent_size = self.hparams["latent_size"] 375 | hidden_size = self.hparams["hidden_size"] 376 | return nn.Sequential( 377 | nn.Linear(latent_size, hidden_size * 4 * 4), 378 | nn.ReLU(), 379 | nn.Unflatten(1, (hidden_size, 4, 4)), 380 | nn.ConvTranspose2d( 381 | hidden_size, hidden_size // 2, kernel_size=4, stride=2, padding=1 382 | ), # → [B, 64, 8, 8] 383 | nn.ReLU(), 384 | nn.ConvTranspose2d( 385 | hidden_size // 2, hidden_size // 4, 
kernel_size=4, stride=2, padding=1 386 | ), # → [B, 32, 16, 16] 387 | nn.ReLU(), 388 | nn.ConvTranspose2d( 389 | hidden_size // 4, 3, kernel_size=4, stride=2, padding=1 390 | ), # → [B, 3, 32, 32] 391 | nn.Sigmoid(), # For pixel values in [0,1] 392 | ).to(DEVICE) -------------------------------------------------------------------------------- /torchvinecopulib/util/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | torchvinecopulib.util 3 | ---------------------- 4 | Utility routines for copula‐based dependence measures, 1D KDE CDF/PPF, and root-finding via the Interpolate Truncate and Project (ITP) method. 5 | 6 | Decorators 7 | ----------- 8 | * torch.compile() for solve_ITP() 9 | * torch.no_grad() for solve_ITP(), kendall_tau(), mutual_info(), ferreira_tail_dep_coeff(), chatterjee_xi() 10 | 11 | References 12 | ----------- 13 | - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D., & O’Brien, J. P. (2016). A fast and objective multidimensional kernel density estimation method: fastKDE. Computational Statistics & Data Analysis, 101, 148-160. 14 | - O’Brien, T. A., Collins, W. D., Rauscher, S. A., & Ringler, T. D. (2014). Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Computational Statistics & Data Analysis, 79, 222-234. 15 | - Purkayastha, S., & Song, P. X. K. (2024). fastMI: A fast and consistent copula-based nonparametric estimator of mutual information. Journal of Multivariate Analysis, 201, 105270. 16 | - Ferreira, M. S. (2013). Nonparametric estimation of the tail-dependence coefficient. 17 | - Chatterjee, S. (2021). A new coefficient of correlation. Journal of the American Statistical Association, 116(536), 2009-2022. 18 | - Lin, Z., & Han, F. (2023). On boosting the power of Chatterjee’s rank correlation. Biometrika, 110(2), 283-299. 19 | - Oliveira, I. F., & Takahashi, R. H. (2020). An enhancement of the bisection method average performance preserving minmax optimality. ACM Transactions on Mathematical Software (TOMS), 47(1), 1-24. 20 | """ 21 | 22 | import enum 23 | from pprint import pformat 24 | 25 | import fastkde 26 | import torch 27 | from scipy.stats import kendalltau 28 | 29 | _EPS = 1e-10 30 | 31 | 32 | @torch.no_grad() 33 | def kendall_tau(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 34 | """Compute Kendall's tau correlation coefficient and p-value. Moves inputs to CPU and delegates 35 | to SciPy’s ``kendalltau``. 36 | 37 | Args: 38 | x (torch.Tensor): shape (n, 1) 39 | y (torch.Tensor): shape (n, 1) 40 | Returns: 41 | torch.Tensor: Kendall's tau correlation coefficient and p-value 42 | """ 43 | return torch.as_tensor( 44 | kendalltau(x.view(-1).cpu(), y.view(-1).cpu()), 45 | dtype=x.dtype, 46 | device=x.device, 47 | ) 48 | 49 | 50 | @torch.no_grad() 51 | def mutual_info(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 52 | """Estimate mutual information using ``fastKDE``. Moves inputs to CPU and delegates to 53 | ``fastKDE.pdf``. 54 | 55 | - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D., & O’Brien, J. P. (2016). A fast and objective multidimensional kernel density estimation method: fastKDE. Computational Statistics & Data Analysis, 101, 148-160. 56 | - O’Brien, T. A., Collins, W. D., Rauscher, S. A., & Ringler, T. D. (2014). Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Computational Statistics & Data Analysis, 79, 222-234. 
57 | - Purkayastha, S., & Song, P. X. K. (2024). fastMI: A fast and consistent copula-based nonparametric estimator of mutual information. Journal of Multivariate Analysis, 201, 105270. 58 | 59 | Args: 60 | x (torch.Tensor): shape (n, 1) 61 | y (torch.Tensor): shape (n, 1) 62 | Returns: 63 | torch.Tensor: Estimated mutual information 64 | """ 65 | x = x.clamp(_EPS, 1.0 - _EPS).view(-1).cpu() 66 | y = y.clamp(_EPS, 1.0 - _EPS).view(-1).cpu() 67 | joint = torch.as_tensor(fastkde.pdf(x, y).values, dtype=x.dtype, device=x.device) 68 | margin_x = torch.as_tensor(fastkde.pdf(x).values, dtype=x.dtype, device=x.device) 69 | margin_y = torch.as_tensor(fastkde.pdf(y).values, dtype=x.dtype, device=x.device) 70 | return ( 71 | joint[joint > 0.0].log().mean() 72 | - margin_x[margin_x > 0.0].log().mean() 73 | - margin_y[margin_y > 0.0].log().mean() 74 | ) 75 | 76 | 77 | @torch.no_grad() 78 | def ferreira_tail_dep_coeff(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 79 | """Estimate tail dependence coefficient (λ), modifed from Ferreira's method, symmetric for 80 | (x,y), (y,1-x), (1-x,1-y), (1-y,x), (y,x), (1-x,y), (1-y,1-x), (x,1-y). 81 | 82 | - Ferreira, M. S. (2013). Nonparametric estimation of the tail-dependence coefficient. 83 | 84 | Args: 85 | x (torch.Tensor): shape (n, 1) 86 | y (torch.Tensor): shape (n, 1) 87 | Returns: 88 | torch.Tensor: Estimated tail dependence coefficient 89 | """ 90 | return ( 91 | 3.0 92 | - ( 93 | 1.0 94 | - torch.stack([torch.max(x, y), torch.max(1.0 - x, y)], dim=1) 95 | .mean(dim=0) 96 | .clamp(0.5, 0.6666666666666666) 97 | .min() 98 | ).reciprocal() 99 | ) 100 | 101 | 102 | @torch.no_grad() 103 | def chatterjee_xi(x: torch.Tensor, y: torch.Tensor, M: int = 1) -> torch.Tensor: 104 | """Estimate Chatterjee's rank correlation coefficient (ξ) 105 | 106 | - Chatterjee, S. (2021). A new coefficient of correlation. Journal of the American Statistical Association, 116(536), 2009-2022. 107 | - Lin, Z., & Han, F. (2023). On boosting the power of Chatterjee’s rank correlation. Biometrika, 110(2), 283-299. 108 | 109 | Args: 110 | x (torch.Tensor): shape (n, 1) 111 | y (torch.Tensor): shape (n, 1) 112 | M (int, optional): num of nearest-neighbor. Defaults to 1. 113 | Returns: 114 | torch.Tensor: Estimated Chatterjee's rank correlation coefficient 115 | """ 116 | # * ranks of x in the order of (ranks of y) 117 | # * ranks of y in the order of (ranks of x) 118 | xrank, yrank = ( 119 | x.argsort(dim=0).argsort(dim=0) + 1, 120 | y.argsort(dim=0).argsort(dim=0) + 1, 121 | ) 122 | xrank, yrank = xrank[yrank.argsort(dim=0)], yrank[xrank.argsort(dim=0)] 123 | 124 | # * the inner sum inside the numerator term, ∑min(Ri, Rjm(i)) 125 | # * max for symmetry as following Remark 1 in Chatterjee (2021) 126 | def xy_sum(m: int) -> tuple: 127 | return ( 128 | (torch.min(xrank[:-m], xrank[m:])).sum() + xrank[-m:].sum(), 129 | (torch.min(yrank[:-m], yrank[m:])).sum() + yrank[-m:].sum(), 130 | ) 131 | 132 | # * whole eq. 3 in Lin and Han (2023) 133 | n = x.shape[0] 134 | return -2.0 + 24.0 * ( 135 | torch.as_tensor([xy_sum(m) for m in range(1, M + 1)], device=x.device, dtype=x.dtype) 136 | .sum(dim=0) 137 | .max() 138 | ) / (M * (1.0 + n) * (1.0 + M + 4.0 * n)) 139 | 140 | 141 | class ENUM_FUNC_BIDEP(enum.Enum): 142 | """ 143 | Enum wrapper for bivariate dependence functions. 
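    A minimal usage sketch (the tensors below are illustrative):

        >>> u = torch.rand(500, 1, dtype=torch.float64)
        >>> v = torch.rand(500, 1, dtype=torch.float64)
        >>> tau, pval = ENUM_FUNC_BIDEP.kendall_tau(u, v)  # delegates to kendall_tau(u, v)
        >>> xi = ENUM_FUNC_BIDEP.chatterjee_xi(u, v, M=2)  # extra kwargs are forwarded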
144 | """ 145 | 146 | chatterjee_xi = enum.member(chatterjee_xi) 147 | ferreira_tail_dep_coeff = enum.member(ferreira_tail_dep_coeff) 148 | kendall_tau = enum.member(kendall_tau) 149 | mutual_info = enum.member(mutual_info) 150 | 151 | def __call__(self, x: torch.Tensor, y: torch.Tensor, **kw): 152 | return self.value(x, y, **kw) 153 | 154 | 155 | class kdeCDFPPF1D(torch.nn.Module): 156 | _EPS = _EPS 157 | 158 | def __init__( 159 | self, 160 | x: torch.Tensor, 161 | num_step_grid: int = None, 162 | x_min: float = None, 163 | x_max: float = None, 164 | pad: float = 0.1, 165 | ): 166 | """1D KDE CDF/PPF using ``fastKDE`` + Simpson's rule. Given a sample ``x``, fits a kernel 167 | density estimate via ``fastKDE`` on a grid of size ``num_step_grid`` (power of two plus 168 | one). Precomputes PDF, CDF, and their finite‐difference slopes for fast interpolation. 169 | 170 | - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D., & O’Brien, J. P. (2016). A fast and objective multidimensional kernel density estimation method: fastKDE. Computational Statistics & Data Analysis, 101, 148-160. 171 | - O’Brien, T. A., Collins, W. D., Rauscher, S. A., & Ringler, T. D. (2014). Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Computational Statistics & Data Analysis, 79, 222-234. 172 | 173 | Args: 174 | x (torch.Tensor): input sample to fit the KDE. 175 | num_step_grid (int, optional): number of grid points for the KDE, should be power of 2 plus 1. Defaults to None. 176 | x_min (float, optional): minimum value of the grid. Defaults to x.min() - pad. 177 | x_max (float, optional): maximum value of the grid. Defaults to x.max() + pad. 178 | pad (float, optional): padding to extend beyond the min/max when ``x_min``/``x_max`` is None. Defaults to 1.0. 179 | """ 180 | super().__init__() 181 | self.num_obs = x.shape[0] 182 | self.x_min = x_min if x_min is not None else x.min().item() - pad 183 | self.x_max = x_max if x_max is not None else x.max().item() + pad 184 | # * power of 2 plus 1 185 | if num_step_grid is None: 186 | num_step_grid = int(2 ** torch.log2(torch.tensor(x.numel())).ceil().item()) + 1 187 | self.num_step_grid = num_step_grid 188 | # * fastkde 189 | res = fastkde.pdf(x.view(-1).cpu().numpy(), num_points=num_step_grid) 190 | xs = torch.from_numpy(res.var0.values).to(dtype=torch.float64) 191 | pdfs = torch.from_numpy(res.values).to(dtype=torch.float64).clamp_min(self._EPS) 192 | N = pdfs.shape[0] 193 | ws = torch.ones(N, dtype=torch.float64) 194 | ws[1:-1:2] = 4 195 | ws[2:-1:2] = 2 196 | h = xs[1] - xs[0] 197 | cdf = torch.cumsum(pdfs * ws, dim=0) * (h / 3) 198 | cdf = cdf / cdf[-1] 199 | slope_fwd = (cdf[1:] - cdf[:-1]) / h 200 | slope_inv = h / (cdf[1:] - cdf[:-1]) 201 | slope_pdf = (pdfs[1:] - pdfs[:-1]) / h 202 | self.register_buffer("grid_x", xs) 203 | self.register_buffer("grid_pdf", pdfs) 204 | self.register_buffer("grid_cdf", cdf) 205 | self.register_buffer("slope_fwd", slope_fwd) 206 | self.register_buffer("slope_inv", slope_inv) 207 | self.register_buffer("slope_pdf", slope_pdf) 208 | self.h = h 209 | # ! device agnostic 210 | self.register_buffer("_dd", torch.tensor([], dtype=torch.float64)) 211 | self.negloglik = -self.log_pdf(x).mean() 212 | 213 | @property 214 | def device(self): 215 | return self._dd.device 216 | 217 | @property 218 | def dtype(self): 219 | return self._dd.dtype 220 | 221 | def cdf(self, x: torch.Tensor) -> torch.Tensor: 222 | """Compute the CDF of the fitted KDE at ``x``. 
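        The value is a linear interpolation of the precomputed grid CDF (a Simpson-weighted
        cumulative sum of the fastKDE density); inputs below ``x_min`` map to 0 and inputs
        above ``x_max`` map to 1. A minimal round-trip sketch (the sample is illustrative):

            >>> kde = kdeCDFPPF1D(torch.randn(1000, dtype=torch.float64))
            >>> q = kde.cdf(torch.tensor([0.0]))
            >>> x0 = kde.ppf(q)  # approximately recovers 0.0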
223 | 224 | Args: 225 | x (torch.Tensor): Points at which to evaluate the CDF. 226 | Returns: 227 | torch.Tensor: CDF values at ``x``, clamped to [0, 1]. 228 | """ 229 | # ! device agnostic 230 | x = x.to(device=self.device, dtype=self.dtype) 231 | x_clamped = x.clamp(self.x_min, self.x_max) 232 | idx = torch.searchsorted(self.grid_x, x_clamped, right=False) 233 | idx = idx.clamp(1, self.grid_cdf.numel() - 1) 234 | y = (self.grid_cdf[idx - 1]) + (self.slope_fwd[idx - 1]) * ( 235 | x_clamped - self.grid_x[idx - 1] 236 | ) 237 | y = torch.where(x < self.x_min, torch.zeros_like(y), y) 238 | y = torch.where(x > self.x_max, torch.ones_like(y), y) 239 | return y.clamp(0.0, 1.0) 240 | 241 | def ppf(self, q: torch.Tensor) -> torch.Tensor: 242 | """Compute the PPF (quantile function) of the fitted KDE at ``q``. 243 | 244 | Args: 245 | q (torch.Tensor): Quantiles at which to evaluate the PPF. 246 | Returns: 247 | torch.Tensor: PPF values at ``q``, clamped to [x_min, x_max]. 248 | """ 249 | # ! device agnostic 250 | q = q.to(device=self.device, dtype=self.dtype) 251 | q_clamped = q.clamp(0.0, 1.0) 252 | idx = torch.searchsorted(self.grid_cdf, q_clamped, right=False) 253 | idx = idx.clamp(1, self.grid_cdf.numel() - 1) 254 | x = (self.grid_x[idx - 1]) + (self.slope_inv[idx - 1]) * ( 255 | q_clamped - self.grid_cdf[idx - 1] 256 | ) 257 | x = torch.where(q < 0.0, torch.full_like(x, self.x_min), x) 258 | x = torch.where(q > 1.0, torch.full_like(x, self.x_max), x) 259 | return x.clamp(self.x_min, self.x_max) 260 | 261 | def pdf(self, x: torch.Tensor) -> torch.Tensor: 262 | """Compute the PDF of the fitted KDE at ``x``. 263 | 264 | Args: 265 | x (torch.Tensor): Points at which to evaluate the PDF. 266 | Returns: 267 | torch.Tensor: PDF values at ``x``, clamped to [0, ∞). 268 | """ 269 | # ! device agnostic 270 | x = x.to(device=self.device, dtype=self.dtype) 271 | x_clamped = x.clamp(self.x_min, self.x_max) 272 | idx = torch.searchsorted(self.grid_x, x_clamped, right=False) 273 | idx = idx.clamp(1, self.grid_pdf.numel() - 1) 274 | f = self.grid_pdf[idx - 1] + (self.slope_pdf[idx - 1]) * (x_clamped - self.grid_x[idx - 1]) 275 | f = torch.where((x < self.x_min) | (x > self.x_max), torch.zeros_like(f), f) 276 | return f.clamp_min(0.0) 277 | 278 | def log_pdf(self, x: torch.Tensor) -> torch.Tensor: 279 | """Compute the log PDF of the fitted KDE at ``x``. 280 | 281 | Args: 282 | x (torch.Tensor): Points at which to evaluate the log PDF. 283 | Returns: 284 | torch.Tensor: Log PDF values at ``x``, guaranteed to be finite. 285 | """ 286 | return self.pdf(x).log().nan_to_num(posinf=0.0, neginf=-13.815510557964274) 287 | 288 | def forward(self, x: torch.Tensor) -> torch.Tensor: 289 | """Average negative log-likelihood of the fitted KDE at ``x``. 290 | 291 | Args: 292 | x (torch.Tensor): Points at which to evaluate the negative log-likelihood. 293 | Returns: 294 | torch.Tensor: Negative log-likelihood values at ``x``, averaged over the batch. 295 | """ 296 | return -self.log_pdf(x).mean() 297 | 298 | def __str__(self): 299 | """String representation of the ``kdeCDFPPF1D`` object. 300 | 301 | Returns: 302 | str: String representation of the ``kdeCDFPPF1D`` object. 
303 | """ 304 | header = self.__class__.__name__ 305 | params = { 306 | "num_obs": int(self.num_obs), 307 | "negloglik": float(self.negloglik.round(decimals=4)), 308 | "x_min": float(round(self.x_min, 4)), 309 | "x_max": float(round(self.x_max, 4)), 310 | "num_step_grid": int(self.num_step_grid), 311 | "dtype": self.dtype, 312 | "device": self.device, 313 | } 314 | params_str = pformat(params, sort_dicts=False, underscore_numbers=True) 315 | return f"{header}\n{params_str[1:-1]}\n\n" 316 | 317 | 318 | # @torch.compile 319 | @torch.no_grad() 320 | def solve_ITP( 321 | fun: callable, 322 | x_a: torch.Tensor, 323 | x_b: torch.Tensor, 324 | epsilon: float = _EPS, 325 | num_iter_max: int = 31, 326 | k_1: float = 0.2, 327 | ) -> torch.Tensor: 328 | """Root-finding for ``fun`` via the Interpolate Truncate and Project (ITP) method within 329 | [``x_a``, ``x_b``], with guaranteed average performance strictly better than the bisection 330 | method under any continuous distribution. 331 | 332 | - Oliveira, I. F., & Takahashi, R. H. (2020). An enhancement of the bisection method average performance preserving minmax optimality. ACM Transactions on Mathematical Software (TOMS), 47(1), 1-24. 333 | https://en.wikipedia.org/wiki/ITP_method 334 | https://docs.rs/kurbo/latest/kurbo/common/fn.solve_itp.html 335 | 336 | Args: 337 | fun (callable): function to find the root of. 338 | x_a (torch.Tensor): lower bound of the interval to search. 339 | x_b (torch.Tensor): upper bound of the interval to search. 340 | epsilon (float, optional): convergence tolerance. Defaults to _EPS. 341 | num_iter_max (int, optional): maximum number of iterations. Defaults to 31. 342 | k_1 (float, optional): scaling factor for the truncation step. Defaults to 0.2. 343 | Returns: 344 | torch.Tensor: approximated root of the function `fun` in the interval [x_a, x_b]. 345 | """ 346 | y_a, y_b = fun(x_a), fun(x_b) 347 | # * corner cases 348 | x_a = torch.where(condition=y_b.abs() < epsilon, input=x_b - epsilon * num_iter_max, other=x_a) 349 | x_b = torch.where(condition=y_a.abs() < epsilon, input=x_a + epsilon * num_iter_max, other=x_b) 350 | y_a, y_b, x_wid = fun(x_a), fun(x_b), x_b - x_a 351 | eps_2 = torch.as_tensor(epsilon * 2.0, device=x_a.device, dtype=x_a.dtype) 352 | eps_scale = epsilon * 2.0 ** ( 353 | (x_wid / epsilon).max().clamp_min(1.0).log2().ceil().clamp_min(1.0).int() 354 | ) 355 | x_half = torch.empty_like(x_wid) 356 | rho = torch.empty_like(x_wid) 357 | sigma = torch.empty_like(x_wid) 358 | delta = torch.empty_like(x_wid) 359 | for _ in range(num_iter_max): 360 | if (x_wid < eps_2).all(): 361 | break 362 | # * update parameters 363 | x_half.copy_(0.5 * (x_a + x_b)) 364 | rho.copy_(eps_scale - 0.5 * x_wid) 365 | # * interpolation 366 | x_f = (y_b * x_a - y_a * x_b) / (y_b - y_a) 367 | sigma.copy_(x_half - x_f) 368 | # ! here k2 = 2 hardwired for efficiency. 
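        # * with k_2 fixed at 2, the truncation radius is delta = k_1 * (x_b - x_a)^2:
        # * the interpolated point x_f is nudged toward the midpoint by delta, or replaced
        # * by the midpoint itself once it already lies within delta of x_half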
369 | delta.copy_(k_1 * x_wid.square()) 370 | # * truncation 371 | x_t = torch.where( 372 | condition=delta <= sigma.abs(), 373 | input=x_f + torch.copysign(delta, sigma), 374 | other=x_half, 375 | ) 376 | # * projection 377 | x_itp = torch.where( 378 | condition=rho >= (x_t - x_half).abs(), 379 | input=x_t, 380 | other=x_half - torch.copysign(rho, sigma), 381 | ) 382 | # * update interval 383 | y_itp = fun(x_itp) 384 | idx = y_itp > 0.0 385 | x_b[idx], y_b[idx] = x_itp[idx], y_itp[idx] 386 | idx = ~idx 387 | x_a[idx], y_a[idx] = x_itp[idx], y_itp[idx] 388 | x_wid = x_b - x_a 389 | eps_scale *= 0.5 390 | return 0.5 * (x_a + x_b) 391 | -------------------------------------------------------------------------------- /torchvinecopulib/bicop/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | torchvinecopulib.bicop 3 | ----------------------- 4 | Provides ``BiCop`` (``torch.nn.Module``) for estimating, evaluating, and sampling 5 | from bivariate copulas via ``tll`` or ``fastKDE`` approaches. 6 | 7 | Decorators 8 | ----------- 9 | * torch.compile() for bilinear interpolation 10 | * torch.no_grad() for fit(), hinv_l(), hinv_r(), sample(), and imshow() 11 | 12 | Key Features 13 | ------------- 14 | - Fits a copula density on a uniform [0,1]² grid and caches PDF/CDF/h‐functions 15 | - Device‐agnostic: all buffers live on the same device/dtype you fit on 16 | - Fast bilinear interpolation compiled with ``torch.compile`` 17 | - Convenient ``.cdf()``, ``.pdf()``, ``.hfunc_*()``, ``.hinv_*()``, and ``.sample()`` APIs 18 | - Plotting helpers: ``.imshow()`` and ``.plot(contour|surface)`` 19 | 20 | Usage 21 | ------ 22 | >>> from torchvinecopulib.bicop import BiCop 23 | >>> cop = BiCop(num_step_grid=256) 24 | >>> cop.fit(obs) # obs: Tensor of shape (n,2) in [0,1]² 25 | >>> u = torch.rand(10, 2) 26 | >>> cdf_vals = cop.cdf(u) 27 | >>> samples = cop.sample(1000, is_sobol=True) 28 | 29 | References 30 | ----------- 31 | - Nagler, T., Schellhase, C., & Czado, C. (2017). Nonparametric estimation of simplified vine copula models: comparison of methods. Dependence Modeling, 5(1), 99-120. 32 | - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D. & O’Brien, J. P. A fast and objective multidimensional kernel density estimation method: fastKDE. Comput. Stat. Data Anal. 101, 148–160 (2016). http://dx.doi.org/10.1016/j.csda.2016.02.014 33 | - O’Brien, T. A., Collins, W. D., Rauscher, S. A. & Ringler, T. D. Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Comput. Stat. Data Anal. 79, 222–234 (2014). http://dx.doi.org/10.1016/j.csda.2014.06.002 34 | """ 35 | 36 | from pprint import pformat 37 | from typing import Optional, cast 38 | 39 | import matplotlib.pyplot as plt 40 | import numpy as np 41 | import pyvinecopulib as pv 42 | import torch 43 | from fastkde import pdf as fkpdf 44 | from matplotlib.colors import LinearSegmentedColormap 45 | from mpl_toolkits.mplot3d.axes3d import Axes3D 46 | from scipy.stats import kendalltau, norm 47 | 48 | from ..util import _EPS, solve_ITP 49 | 50 | __all__ = [ 51 | "BiCop", 52 | ] 53 | 54 | 55 | class BiCop(torch.nn.Module): 56 | # ! hinv 57 | _EPS: float = _EPS 58 | 59 | def __init__( 60 | self, 61 | num_step_grid: int = 128, 62 | ): 63 | """Initializes the bivariate copula (BiCop) class. By default an independent bicop. 
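        Until ``fit()`` is called, all evaluations fall back to the independence copula:
        ``cdf(u) = u₁·u₂``, ``pdf(u) = 1``, ``hfunc_l(u) = u₂`` and ``hfunc_r(u) = u₁``.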
64 | 65 | Args: 66 | num_step_grid (int, optional): number of steps per dimension for the precomputed grids (must be a power of 2). Defaults to 128. 67 | """ 68 | super().__init__() 69 | # * by default an independent bicop, otherwise cache grids from KDE 70 | self.is_indep = True 71 | self.mtd_kde = "tll" # default method for estimating the copula density 72 | self.num_step_grid = num_step_grid 73 | self.register_buffer("tau", torch.zeros(2, dtype=torch.float64)) 74 | self.register_buffer("num_obs", torch.empty((), dtype=torch.int)) 75 | self.register_buffer("negloglik", torch.zeros((), dtype=torch.float64)) 76 | self.register_buffer( 77 | "_pdf_grid", 78 | torch.ones(num_step_grid, num_step_grid, dtype=torch.float64), 79 | ) 80 | self.register_buffer( 81 | "_cdf_grid", 82 | torch.empty(num_step_grid, num_step_grid, dtype=torch.float64), 83 | ) 84 | self.register_buffer( 85 | "_hfunc_l_grid", 86 | torch.empty(num_step_grid, num_step_grid, dtype=torch.float64), 87 | ) 88 | self.register_buffer( 89 | "_hfunc_r_grid", 90 | torch.empty(num_step_grid, num_step_grid, dtype=torch.float64), 91 | ) 92 | # ! device agnostic 93 | self.register_buffer("_dd", torch.tensor([], dtype=torch.float64)) 94 | 95 | @property 96 | def device(self) -> torch.device: 97 | """Get the device of the bicop model (all internal buffers). 98 | 99 | Returns: 100 | torch.device: The device on which the registered buffers reside. 101 | """ 102 | return self._dd.device 103 | 104 | @property 105 | def dtype(self) -> torch.dtype: 106 | """Get the data type of the bicop model (all internal buffers). Should be torch.float64. 107 | 108 | Returns: 109 | torch.dtype: The data type of the registered buffers. 110 | """ 111 | return self._dd.dtype 112 | 113 | @torch.no_grad() 114 | def reset(self) -> None: 115 | """Reinitialize state and zero all statistics and precomputed grids. 116 | 117 | Sets the bicop back to independent bicop and clears accumulated 118 | metrics (``tau``, ``num_obs``, ``negloglik``) as well as all grid buffers 119 | (``_pdf_grid``, ``_cdf_grid``, ``_hfunc_l_grid``, ``_hfunc_r_grid``). 120 | """ 121 | self.is_indep = True 122 | self.tau.zero_() 123 | self.num_obs.zero_() 124 | self.negloglik.zero_() 125 | self._pdf_grid.zero_() 126 | self._cdf_grid.zero_() 127 | self._hfunc_l_grid.zero_() 128 | self._hfunc_r_grid.zero_() 129 | 130 | @torch.no_grad() 131 | def fit( 132 | self, 133 | obs: torch.Tensor, 134 | mtd_kde: str = "tll", 135 | mtd_tll: str = "constant", 136 | num_iter_max: int = 17, 137 | is_tau_est: bool = False, 138 | ) -> None: 139 | """Estimate and cache PDF/CDF/h-function grids from bivariate copula observations. 140 | 141 | This method computes KDE-based bicopula densities on a uniform [0,1]² grid and populates internal buffers 142 | (``_pdf_grid``, ``_cdf_grid``, ``_hfunc_l_grid``, ``_hfunc_r_grid``, ``negloglik``). 143 | 144 | - Nagler, T., Schellhase, C., & Czado, C. (2017). Nonparametric estimation of simplified vine copula models: comparison of methods. Dependence Modeling, 5(1), 99-120. 145 | - O’Brien, T. A., Kashinath, K., Cavanaugh, N. R., Collins, W. D. & O’Brien, J. P. A fast and objective multidimensional kernel density estimation method: fastKDE. Comput. Stat. Data Anal. 101, 148–160 (2016). http://dx.doi.org/10.1016/j.csda.2016.02.014 146 | - O’Brien, T. A., Collins, W. D., Rauscher, S. A. & Ringler, T. D. Reducing the computational cost of the ECF using a nuFFT: A fast and objective probability density estimation method. Comput. Stat. Data Anal. 79, 222–234 (2014). 
http://dx.doi.org/10.1016/j.csda.2014.06.002 147 | 148 | 149 | Args: 150 | obs (torch.Tensor): shape (n, 2) bicop obs in [0, 1]². 151 | mtd_kde (str, optional): Method for estimating the copula density. One of ("tll", "fastKDE"). Defaults to "tll". 152 | mtd_tll (str, optional): fit method for the transformation local-likelihood (TLL) nonparametric family, used only when ``mtd_kde="tll"``, one of ("constant", "linear", or "quadratic"). Defaults to "constant". 153 | num_iter_max (int, optional): num of Sinkhorn/IPF iters for grid normalization, used only when ``mtd_kde="fastKDE"``. Defaults to 17. 154 | is_tau_est (bool, optional): If True, compute and store Kendall’s τ. Defaults to ``False``. 155 | """ 156 | # ! device agnostic 157 | device, dtype = self.device, self.dtype 158 | self.is_indep = False 159 | self.mtd_kde = mtd_kde 160 | self.num_obs.copy_(obs.shape[0]) 161 | # * assuming already in [0, 1] 162 | obs = obs.clamp(min=0.0, max=1.0) 163 | if is_tau_est: 164 | self.tau.copy_( 165 | torch.as_tensor( 166 | kendalltau(obs[:, 0].cpu(), obs[:, 1].cpu()), 167 | device=device, 168 | dtype=dtype, 169 | ) 170 | ) 171 | self._target = self.num_step_grid - 1.0 # * marginal target 172 | self.step_grid = 1.0 / self._target 173 | # ! pdf 174 | if mtd_kde == "tll": 175 | controls = pv.FitControlsBicop( 176 | family_set=[pv.tll], 177 | num_threads=torch.get_num_threads(), 178 | nonparametric_method=mtd_tll, 179 | ) 180 | cop = pv.Bicop.from_data(data=obs.cpu().numpy(), controls=controls) 181 | axis = torch.linspace( 182 | _EPS, 183 | 1.0 - _EPS, 184 | steps=self.num_step_grid, 185 | device="cpu", 186 | dtype=torch.float64, 187 | ) 188 | pdf_grid = ( 189 | torch.from_numpy(cop.pdf(torch.cartesian_prod(axis, axis).view(-1, 2).numpy())) 190 | .view(self.num_step_grid, self.num_step_grid) 191 | .to(device=device, dtype=dtype) 192 | ) 193 | elif mtd_kde == "fastKDE": 194 | pdf_grid = torch.from_numpy( 195 | fkpdf( 196 | obs[:, 0].cpu(), 197 | obs[:, 1].cpu(), 198 | num_points=self.num_step_grid * 2 + 1, 199 | ).values 200 | ).to(device=device, dtype=dtype) 201 | # * padding/trimming after fastkde.pdf 202 | H, W = pdf_grid.shape 203 | if H < self.num_step_grid: 204 | pdf_grid = torch.cat( 205 | [ 206 | pdf_grid, 207 | torch.zeros(self.num_step_grid - H, W, dtype=dtype, device=device), 208 | ], 209 | dim=0, 210 | ) 211 | H, W = pdf_grid.shape 212 | if W < self.num_step_grid: 213 | pdf_grid = torch.cat( 214 | [ 215 | pdf_grid, 216 | torch.zeros(H, self.num_step_grid - W, dtype=dtype, device=device), 217 | ], 218 | dim=1, 219 | ) 220 | pdf_grid = pdf_grid[: self.num_step_grid, : self.num_step_grid].clamp_min(0.0) 221 | pdf_grid = pdf_grid.view(self.num_step_grid, self.num_step_grid).T 222 | # * normalization: Sinkhorn / iterative proportional fitting (IPF) 223 | for _ in range(num_iter_max): 224 | pdf_grid *= self._target / pdf_grid.sum(dim=0, keepdim=True) 225 | pdf_grid *= self._target / pdf_grid.sum(dim=1, keepdim=True) 226 | pdf_grid /= pdf_grid.sum() * self.step_grid**2 227 | else: 228 | raise NotImplementedError 229 | self._pdf_grid = pdf_grid 230 | # * negloglik 231 | self.negloglik = -self.log_pdf(obs=obs).nan_to_num(posinf=0.0, neginf=0.0).sum() 232 | # ! cdf 233 | self._cdf_grid = ((self._pdf_grid * self.step_grid**2).cumsum(dim=0).cumsum(dim=1)).clamp_( 234 | 0.0, 1.0 235 | ) 236 | # ! 
h functions 237 | self._hfunc_l_grid = (self._pdf_grid * self.step_grid).cumsum(dim=1).clamp_(0.0, 1.0) 238 | self._hfunc_r_grid = (self._pdf_grid * self.step_grid).cumsum(dim=0).clamp_(0.0, 1.0) 239 | 240 | # @torch.compile 241 | def _interp(self, grid: torch.Tensor, obs: torch.Tensor) -> torch.Tensor: 242 | """Bilinearly interpolate values on a 2D grid at given sample points. 243 | 244 | Args: 245 | grid (torch.Tensor): Precomputed grid of values (e.g., PDF/CDF/h‐function), shape (m,m). 246 | obs (torch.Tensor): Points in [0,1]² where to interpolate (rows are (u₁,u₂)), shape (n,2). 247 | 248 | Returns: 249 | torch.Tensor: Interpolated grid values at each observation, clamped ≥0, shape (n,1). 250 | """ 251 | idx = obs.clamp(self._EPS, 1 - self._EPS) / self.step_grid 252 | i0 = idx.floor().long() 253 | di = idx - i0 254 | i1 = torch.minimum( 255 | i0 + 1, 256 | torch.full_like(input=i0, fill_value=self._target, device=idx.device), 257 | ) 258 | g00 = grid[i0[:, 0], i0[:, 1]] 259 | g10 = grid[i1[:, 0], i0[:, 1]] 260 | g01 = grid[i0[:, 0], i1[:, 1]] 261 | g11 = grid[i1[:, 0], i1[:, 1]] 262 | return ( 263 | g00 264 | + (g10 - g00) * di[:, 0] 265 | + (g01 - g00) * di[:, 1] 266 | + (g11 - g01 - g10 + g00) * di[:, 0] * di[:, 1] 267 | ).clamp_min(0.0) 268 | 269 | def cdf(self, obs: torch.Tensor) -> torch.Tensor: 270 | """Evaluate the copula CDF at given points. For independent copula, returns u₁·u₂. 271 | 272 | Args: 273 | obs (torch.Tensor): Points in [0,1]² where to evaluate the CDF (rows are (u₁,u₂)), shape (n,2). 274 | 275 | Returns: 276 | torch.Tensor: CDF values at each observation, shape (n,1). 277 | """ 278 | # ! device agnostic 279 | obs = obs.to(device=self.device, dtype=self.dtype) 280 | if self.is_indep: 281 | return obs.prod(dim=1, keepdim=True) 282 | return self._interp(grid=self._cdf_grid, obs=obs).unsqueeze(dim=1) 283 | 284 | def hfunc_l(self, obs: torch.Tensor) -> torch.Tensor: 285 | """Evaluate the left h-function at given points. Computes H(u₂ | u₁):= ∂/∂u₁ C(u₁,u₂) for 286 | the fitted copula. For independent copula, returns u₂. 287 | 288 | Args: 289 | obs (torch.Tensor): Points in [0,1]² where to evaluate the left h-function (rows are (u₁,u₂)), shape (n,2). 290 | 291 | Returns: 292 | torch.Tensor: Left h-function values at each observation, shape (n,1). 293 | """ 294 | # ! device agnostic 295 | obs = obs.to(device=self.device, dtype=self.dtype) 296 | if self.is_indep: 297 | return obs[:, [1]] 298 | return self._interp(grid=self._hfunc_l_grid, obs=obs).unsqueeze(dim=1) 299 | 300 | def hfunc_r(self, obs: torch.Tensor) -> torch.Tensor: 301 | """Evaluate the right h-function at given points. Computes H(u₁ | u₂):= ∂/∂u₂ C(u₁,u₂) for 302 | the fitted copula. For independent copula, returns u₁. 303 | 304 | Args: 305 | obs (torch.Tensor): Points in [0,1]² where to evaluate the right h-function (rows are (u₁,u₂)), shape (n,2). 306 | 307 | Returns: 308 | torch.Tensor: Right h-function values at each observation, shape (n,1). 309 | """ 310 | # ! device agnostic 311 | obs = obs.to(device=self.device, dtype=self.dtype) 312 | if self.is_indep: 313 | return obs[:, [0]] 314 | return self._interp(grid=self._hfunc_r_grid, obs=obs).unsqueeze(dim=1) 315 | 316 | @torch.no_grad() 317 | def hinv_l(self, obs: torch.Tensor) -> torch.Tensor: 318 | """Invert the left h‐function via root‐finding: find u₂ given (u₁, p). Solves H(u₂ | u₁) = p 319 | by ITP between 0 and 1. 320 | 321 | Args: 322 | obs (torch.Tensor): Points in [0,1]² where to evaluate the left h-function (rows are (u₁,u₂)), shape (n,2). 
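        A minimal sketch for conditional sampling (assumes ``cop`` has already been fitted):

            >>> u1 = torch.rand(100, 1, dtype=torch.float64)
            >>> p = torch.rand(100, 1, dtype=torch.float64)
            >>> u2 = cop.hinv_l(torch.hstack([u1, p]))  # then hfunc_l(torch.hstack([u1, u2])) ≈ p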
323 | 324 | Returns: 325 | torch.Tensor: Solutions u₂ ∈ [0,1], shape (n,1). 326 | """ 327 | # ! device agnostic 328 | obs = obs.to(device=self.device, dtype=self.dtype) 329 | if self.is_indep: 330 | return obs[:, [1]] 331 | # * via root-finding 332 | u_l = obs[:, [0]] 333 | p = obs[:, [1]] 334 | return solve_ITP( 335 | fun=lambda u_r: self.hfunc_l(obs=torch.hstack([u_l, u_r])) - p, 336 | x_a=torch.zeros_like(p), 337 | x_b=torch.ones_like(p), 338 | ).clamp(min=0.0, max=1.0) 339 | 340 | @torch.no_grad() 341 | def hinv_r(self, obs: torch.Tensor) -> torch.Tensor: 342 | """Invert the right h‐function via root‐finding: find u₁ given (u₂, p). Solves H(u₁ | u₂) = 343 | p by ITP between 0 and 1. 344 | 345 | Args: 346 | obs (torch.Tensor): Points in [0,1]² where to evaluate the right h-function (rows are (u₁,u₂)), shape (n,2). 347 | Returns: 348 | torch.Tensor: Solutions u₁ ∈ [0,1], shape (n,1). 349 | """ 350 | # ! device agnostic 351 | obs = obs.to(device=self.device, dtype=self.dtype) 352 | if self.is_indep: 353 | return obs[:, [0]] 354 | # * via root-finding 355 | u_r = obs[:, [1]] 356 | p = obs[:, [0]] 357 | return solve_ITP( 358 | fun=lambda u_l: self.hfunc_r(obs=torch.hstack([u_l, u_r])) - p, 359 | x_a=torch.zeros_like(p), 360 | x_b=torch.ones_like(p), 361 | ).clamp(min=0.0, max=1.0) 362 | 363 | def pdf(self, obs: torch.Tensor) -> torch.Tensor: 364 | """Evaluate the copula PDF at given points. For independent copula, returns 1. 365 | 366 | Args: 367 | obs (torch.Tensor): Points in [0,1]² where to evaluate the PDF (rows are (u₁,u₂)), shape (n,2). 368 | Returns: 369 | torch.Tensor: PDF values at each observation, shape (n,1). 370 | """ 371 | # ! device agnostic 372 | obs = obs.to(device=self.device, dtype=self.dtype) 373 | if self.is_indep: 374 | return torch.ones_like(obs[:, [0]]) 375 | return self._interp(grid=self._pdf_grid, obs=obs).unsqueeze(dim=1) 376 | 377 | def log_pdf(self, obs: torch.Tensor) -> torch.Tensor: 378 | """Evaluate the copula log-PDF at given points, with safe handling of inf/nan. For 379 | independent copula, returns 0. 380 | 381 | Args: 382 | obs (torch.Tensor): Points in [0,1]² where to evaluate the log-PDF (rows are (u₁,u₂)), shape (n,2). 383 | Returns: 384 | torch.Tensor: log-PDF values at each observation, shape (n,1). 385 | """ 386 | # ! device agnostic 387 | obs = obs.to(device=self.device, dtype=self.dtype) 388 | if self.is_indep: 389 | return torch.zeros_like(obs[:, [0]]) 390 | return self.pdf(obs=obs).log().nan_to_num(posinf=0.0, neginf=-13.815510557964274) 391 | 392 | @torch.no_grad() 393 | def sample( 394 | self, num_sample: int = 100, seed: int = 42, is_sobol: bool = False 395 | ) -> torch.Tensor: 396 | """Sample from the copula by inverse Rosenblatt transform. Uses Sobol sequence if 397 | ``is_sobol=True``, otherwise uniform RNG. For independent copula, returns uniform samples in 398 | [0,1]². 399 | 400 | Args: 401 | num_sample (int, optional): number of samples to generate. Defaults to 100. 402 | seed (int, optional): random seed for reproducibility. Defaults to 42. 403 | is_sobol (bool, optional): If True, use Sobol sampling. Defaults to False. 404 | Returns: 405 | torch.Tensor: Generated samples, shape (num_sample, 2). 406 | """ 407 | # ! 
device agnostic 408 | device, dtype = self.device, self.dtype 409 | if is_sobol: 410 | obs = ( 411 | torch.quasirandom.SobolEngine(dimension=2, scramble=True, seed=seed) 412 | .draw(n=num_sample, dtype=dtype) 413 | .to(device=device) 414 | ) 415 | else: 416 | torch.manual_seed(seed=seed) 417 | obs = torch.rand(size=(num_sample, 2), dtype=dtype, device=device) 418 | if not self.is_indep: 419 | obs[:, [1]] = self.hinv_l(obs=obs) 420 | return obs 421 | 422 | def __str__(self) -> str: 423 | """String representation of the BiCop class. 424 | 425 | Returns: 426 | str: String representation of the BiCop class. 427 | """ 428 | return f"""{self.__class__.__name__}\n{ 429 | pformat( 430 | object={ 431 | "is_indep": self.is_indep, 432 | "num_obs": self.num_obs, 433 | "negloglik": self.negloglik.round(decimals=4), 434 | "num_step_grid": self.num_step_grid, 435 | "tau": self.tau.round(decimals=4), 436 | "mtd_kde": self.mtd_kde, 437 | "dtype": self._dd.dtype, 438 | "device": self._dd.device, 439 | }, 440 | compact=True, 441 | sort_dicts=False, 442 | underscore_numbers=True, 443 | ) 444 | }""" 445 | 446 | @torch.no_grad() 447 | def imshow( 448 | self, 449 | is_log_pdf: bool = False, 450 | ax: plt.Axes | None = None, 451 | cmap: str = "inferno", 452 | xlabel: str = r"$u_{left}$", 453 | ylabel: str = r"$u_{right}$", 454 | title: str = "Estimated bivariate copula density", 455 | colorbartitle: str = "Density", 456 | **imshow_kwargs: dict, 457 | ) -> tuple[plt.Figure, plt.Axes]: 458 | """Display the (log-)PDF grid as a heatmap. 459 | 460 | Args: 461 | is_log_pdf (bool, optional): If True, plot log-PDF. Defaults to False. 462 | ax (plt.Axes, optional): Matplotlib Axes object to plot on. If None, a new figure and axes are created. Defaults to None. 463 | cmap (str, optional): Colormap for the plot. Defaults to "inferno". 464 | xlabel (str, optional): X-axis label. Defaults to r"$u_{left}$". 465 | ylabel (str, optional): Y-axis label. Defaults to r"$u_{right}$". 466 | title (str, optional): Plot title. Defaults to "Estimated bivariate copula density". 467 | colorbartitle (str, optional): Colorbar title. Defaults to "Density". 468 | **imshow_kwargs: Additional keyword arguments for imshow. 469 | Returns: 470 | tuple[plt.Figure, plt.Axes]: The figure and axes objects. 471 | """ 472 | if ax is None: 473 | fig, ax = plt.subplots() 474 | else: 475 | fig = ax.figure 476 | im = ax.imshow( 477 | X=self._pdf_grid.log().nan_to_num(posinf=0.0, neginf=-13.815510557964274).cpu() 478 | if is_log_pdf 479 | else self._pdf_grid.cpu(), 480 | extent=(0, 1, 0, 1), 481 | origin="lower", 482 | cmap=cmap, 483 | **imshow_kwargs, 484 | ) 485 | ax.set_xlabel(xlabel=xlabel) 486 | ax.set_ylabel(ylabel=ylabel) 487 | ax.set_title(label=title) 488 | ax.set_xlim(0, 1) 489 | ax.set_ylim(0, 1) 490 | plt.colorbar(im, ax=ax, label=colorbartitle) 491 | return fig, ax 492 | 493 | @torch.no_grad() 494 | def plot( 495 | self, 496 | plot_type: str = "surface", 497 | margin_type: str = "unif", 498 | xylim: Optional[tuple[float, float]] = None, 499 | grid_size: Optional[int] = None, 500 | ) -> tuple[plt.Figure, plt.Axes]: 501 | """Plot the bivariate copula density. 502 | 503 | Args: 504 | plot_type (str, optional): Type of plot, either "contour" or "surface". Defaults to "surface". 505 | margin_type (str, optional): Type of margin, either "unif" or "norm". Defaults to "unif". 506 | xylim (tuple[float, float], optional): Limits for x and y axes. Defaults to None. 507 | grid_size (int, optional): Size of the grid for the plot. Defaults to None. 
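        A usage sketch (assumes ``cop`` has been fitted):

            >>> fig, ax = cop.plot(plot_type="contour", margin_type="norm")
            >>> fig, ax = cop.plot(plot_type="surface", margin_type="unif", grid_size=60)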
508 | Returns: 509 | tuple[plt.Figure, plt.Axes]: The figure and axes objects. 510 | """ 511 | # * validate inputs 512 | if plot_type not in ["contour", "surface"]: 513 | raise ValueError("Unknown type") 514 | elif plot_type == "contour" and grid_size is None: 515 | grid_size = 100 516 | elif plot_type == "surface" and grid_size is None: 517 | grid_size = 40 518 | # * margin type and grid points 519 | if margin_type not in ["unif", "norm"]: 520 | raise ValueError("Unknown margin type") 521 | elif margin_type == "unif": 522 | if xylim is None: 523 | xylim = (1e-2, 1 - 1e-2) 524 | if plot_type == "contour": 525 | points = np.linspace(1e-5, 1 - 1e-5, grid_size) 526 | else: 527 | points = np.linspace(1, grid_size, grid_size) / (grid_size + 1) 528 | g = np.meshgrid(points, points) 529 | points = g[0][0] 530 | adj = 1 531 | levels = [0.2, 0.6, 1, 1.5, 2, 3, 5, 10, 20] 532 | xlabel, ylabel = "u1", "u2" 533 | elif margin_type == "norm": 534 | if xylim is None: 535 | xylim = (-3, 3) 536 | points = norm.cdf(np.linspace(xylim[0], xylim[1], grid_size)) 537 | g = np.meshgrid(points, points) 538 | points = norm.ppf(g[0][0]) 539 | adj = np.outer(norm.pdf(points), norm.pdf(points)) 540 | levels = [0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5] 541 | xlabel, ylabel = "z1", "z2" 542 | 543 | # * evaluate on grid 544 | g_tensor = torch.from_numpy(np.stack(g, axis=-1).reshape(-1, 2)).to( 545 | device=self.device, dtype=self.dtype 546 | ) 547 | vals = self.pdf(g_tensor).cpu().numpy() 548 | cop = np.reshape(vals, (grid_size, grid_size)) 549 | 550 | # * adjust for margins 551 | dens = cop * adj 552 | if len(np.unique(dens)) == 1: 553 | dens[0] = 1.000001 * dens[0] 554 | if margin_type == "unif": 555 | zlim = (0, max(3, 1.1 * max(dens.ravel()))) 556 | elif margin_type == "norm": 557 | zlim = (0, max(0.4, 1.1 * max(dens.ravel()))) 558 | 559 | # * create a colormap 560 | jet_colors = LinearSegmentedColormap.from_list( 561 | name="jet_colors", 562 | colors=[ 563 | "#00007F", 564 | "blue", 565 | "#007FFF", 566 | "cyan", 567 | "#7FFF7F", 568 | "yellow", 569 | "#FF7F00", 570 | "red", 571 | "#7F0000", 572 | ], 573 | N=100, 574 | ) 575 | 576 | # * plot 577 | if plot_type == "contour": 578 | fig, ax = plt.subplots() 579 | contour = ax.contour(points, points, dens, levels=levels, cmap="gray") 580 | ax.clabel(contour, inline=True, fontsize=8, fmt="%1.2f") 581 | ax.set_aspect("equal") 582 | ax.grid(True) 583 | elif plot_type == "surface": 584 | fig = plt.figure() 585 | ax = cast(Axes3D, fig.add_subplot(111, projection="3d")) 586 | ax.view_init(elev=30, azim=-110) 587 | X, Y = np.meshgrid(points, points) 588 | ax.plot_surface(X, Y, dens, cmap=jet_colors, edgecolor="none", shade=False) 589 | ax.set_zlim(zlim) 590 | ax.set_box_aspect([1, 1, 1]) 591 | ax.xaxis.pane.fill = False 592 | ax.yaxis.pane.fill = False 593 | ax.zaxis.pane.fill = False 594 | ax.grid(False) 595 | ax.set_xlabel(xlabel) 596 | ax.set_ylabel(ylabel) 597 | ax.set_xlim(xylim) 598 | ax.set_ylim(xylim) 599 | fig.tight_layout() 600 | plt.draw_if_interactive() 601 | return fig, ax 602 | --------------------------------------------------------------------------------