# ├── tta
#     ├── __init__.py
#     ├── models
#     │   ├── linear.py
#     │   ├── lenet.py
#     │   ├── __init__.py
#     │   └── resnet.py
#     ├── common.py
#     ├── utils.py
#     ├── datasets
#     │   ├── waterbirds.py
#     │   ├── cxr
#     │   │   ├── mimic.py
#     │   │   ├── chexpert.py
#     │   │   └── __init__.py
#     │   ├── coco.py
#     │   ├── __init__.py
#     │   └── mnist.py
#     ├── restore.py
#     ├── visualize.py
#     ├── train.py
#     └── cli.py
# ├── pyrightconfig.json
# ├── requirements.txt
# ├── setup.py
# ├── Pipfile
# ├── docs
#     └── experiments.md
# ├── scripts
#     ├── manova.py
#     ├── superpose.py
#     ├── matching.py
#     ├── andrew.py
#     ├── baseline.py
#     ├── freeze.py
#     ├── tree.py
#     └── merge.py
# ├── README.md
# ├── requirements-long.txt
# ├── .gitignore
# └── Makefile

# ------------------------------------------------------------------------------
# /tta/__init__.py:
# ------------------------------------------------------------------------------
# (empty)

# ------------------------------------------------------------------------------
# /pyrightconfig.json:
# ------------------------------------------------------------------------------
# {
#     "venvPath": ".",
#     "venv": ".venv"
# }

# ------------------------------------------------------------------------------
# /requirements.txt:
# ------------------------------------------------------------------------------
# jax>=0.3.15
# jaxlib
# flax
# matplotlib
# optax
# pandas
# scikit-learn
# torch

# ------------------------------------------------------------------------------
# /setup.py:
# ------------------------------------------------------------------------------
from setuptools import setup, find_packages

# Runtime dependencies come from requirements.txt; blank lines and comments
# are skipped so they never reach setuptools as empty requirement strings.
with open('requirements.txt') as f:
    required = [
        line.strip()
        for line in f.read().splitlines()
        if line.strip() and not line.lstrip().startswith('#')
    ]

setup(
    name="Test time label shift adaptation",
    version="0.1",
    packages=find_packages(),
    install_requires=required,
)

# ------------------------------------------------------------------------------
# /tta/models/linear.py:
# ------------------------------------------------------------------------------
"""Implementation of a linear model (a single dense layer)."""

import flax.linen as nn


class Linear(nn.Module):
    """Single dense layer mapping inputs to ``num_outputs`` logits."""

    num_outputs: int

    @nn.compact
    def __call__(self, x, train: bool):
        # `train` is accepted for interface parity with the other models.
        del train
        x = nn.Dense(features=self.num_outputs)(x)
        return x

# ------------------------------------------------------------------------------
# /Pipfile:
# ------------------------------------------------------------------------------
# [[source]]
# url = "https://pypi.org/simple"
# verify_ssl = true
# name = "pypi"
#
# [packages]
# numpy = "*"
# pandas = "*"
# scikit-learn = "*"
# jax = {extras = ["tpu"]}
# flax = "*"
# torch = "*"
# torchvision = "*"
# click = "*"
# matplotlib = "*"
# tensorflow-cpu = "*"
# statsmodels = "*"
# cxr-foundation = "*"
#
# [dev-packages]
# black = {extras = ["jupyter"]}
#
# [requires]
# python_version = "3.10"

# ------------------------------------------------------------------------------
# /docs/experiments.md:
# ------------------------------------------------------------------------------
# MNIST:
# [X] noise = 0.0 TRC-01
# [ ] noise = 0.5 TRC-02
# [ ] noise = 1.0 TRC-03
#
# CheXpert:
# [X] Embedding TRC-01
# [ ] Pixel TRC-01
# [ ] Domain 1/2/4 TRC-01
# [ ] Domain 10
# [ ] Tau = 0 TRC-04
# [ ] Tau = 1 TRC-05
#
# MIMIC:
# [ ] Embedding
# [ ] Tau = 0 TRC-04
# [ ] Tau = 1 TRC-05

# ------------------------------------------------------------------------------
# /tta/common.py:
# ------------------------------------------------------------------------------
from typing import Tuple, Dict, Union, Literal

import jax.numpy as jnp


# Adaptation method descriptors: a tag followed by method-specific settings.
AdaptationNull = Tuple[Literal["Null"]]
AdaptationOracle = Tuple[Literal["Oracle"]]
AdaptationGMTL = Tuple[Literal["GMTL"], float]
AdaptationEM = Tuple[Literal["EM"], float, bool, bool]
Adaptation = Union[AdaptationNull, AdaptationOracle, AdaptationGMTL, AdaptationEM]

Curves = Dict[
    Tuple[Adaptation, bool, int],
    jnp.ndarray,
]

Sweeps = Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

# ------------------------------------------------------------------------------
# /tta/models/lenet.py:
# ------------------------------------------------------------------------------
"""Implementation of LeNet."""

import flax.linen as nn


class LeNet(nn.Module):
    """LeNet-style CNN (sigmoid activations, average pooling) producing ``num_outputs`` logits."""

    num_outputs: int

    @nn.compact
    def __call__(self, x, train: bool):
        # `train` is accepted for interface parity with the other models.
        del train
        x = nn.Conv(features=6, kernel_size=(5, 5), padding=((2, 2), (2, 2)))(x)
        x = nn.sigmoid(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding=((0, 0), (0, 0)))
        x = nn.Conv(features=16, kernel_size=(5, 5), padding=[(0, 0), (0, 0)])(x)
        x = nn.sigmoid(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding=((0, 0), (0, 0)))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=120)(x)
        x = nn.sigmoid(x)
        x = nn.Dense(features=84)(x)
        x = nn.sigmoid(x)
        x = nn.Dense(features=self.num_outputs)(x)
        return x

# ------------------------------------------------------------------------------
# /scripts/manova.py:
# ------------------------------------------------------------------------------
import sys

import numpy as np
import pandas as pd
from statsmodels.multivariate.manova import MANOVA


def fit():
    """Print a MANOVA of the CheXpert embedding features against patient attributes.

    The "split" column is dropped, AGE_AT_CXR is binarised at its median, and
    rows whose GENDER code is not 0/1 or whose PRIMARY_RACE / ETHNICITY code
    is negative are excluded first.
    """
    # Read everything we need, then release the archive's file handle.
    with np.load("data/CheXpert/data_matrix.npz", allow_pickle=True) as data_matrix:
        Y = data_matrix["features"]
        X = data_matrix["attributes"]
        columns = data_matrix["columns"]
    X = pd.DataFrame(X, columns=columns)
    Y_names = [f"Y{i}" for i in range(Y.shape[1])]
    Y = pd.DataFrame(Y, columns=Y_names)

    X = X.drop(columns=["split"])
    cutoff = np.median(X["AGE_AT_CXR"])
    X["AGE_AT_CXR"] = (X["AGE_AT_CXR"] > cutoff).astype(int)
    mask = (X["GENDER"] == 0) | (X["GENDER"] == 1)
    mask &= X["PRIMARY_RACE"] >= 0
    mask &= X["ETHNICITY"] >= 0
    X = X.loc[mask]
    Y = Y.loc[mask]

    model = MANOVA(Y, X)
    results = model.mv_test()
    print(results)


if __name__ == "__main__":
    sys.setrecursionlimit(15000)
    fit()

# ------------------------------------------------------------------------------
# /scripts/superpose.py:
# ------------------------------------------------------------------------------
from pathlib import Path

import numpy as np
import click


@click.command()
@click.option("--source", type=click.Path(path_type=Path), required=True)
@click.option("--target", type=click.Path(path_type=Path), required=True)
def cmd(source, target):
    """Copy the "Null" sweeps of SOURCE into TARGET as "Null-unconfounded".

    Each archive entry is a (sweeps, ylabel) pair; y-labels of matching
    entries must agree. TARGET is rewritten in place.
    """
    # Fully load both archives and close them before TARGET is overwritten;
    # keeping TARGET open while np.savez rewrites it leaks the handle (and
    # fails outright on platforms that lock open files).
    with np.load(source, allow_pickle=True) as source_npz:
        all_source_sweeps = dict(source_npz)
    with np.load(target, allow_pickle=True) as target_npz:
        all_target_sweeps = dict(target_npz)

    for sweep_type, (source_sweeps, source_ylabel) in all_source_sweeps.items():
        target_sweeps, target_ylabel = all_target_sweeps[sweep_type]
        assert source_ylabel == target_ylabel
        for ((algo, *_), argmax_joint, batch_size), sweep in source_sweeps.items():
            if algo == "Null":
                key = ("Null-unconfounded",), argmax_joint, batch_size
                target_sweeps[key] = sweep
        all_target_sweeps[sweep_type] = (target_sweeps, target_ylabel)

    np.savez(target, **all_target_sweeps)


if __name__ == "__main__":
    cmd()

# ------------------------------------------------------------------------------
# /tta/utils.py:
# ------------------------------------------------------------------------------
from typing import Tuple
import sys

import torch
from torch.utils.data import Dataset, random_split


class Dataset(Dataset):
    """torch ``Dataset`` that also declares ``__len__``; subclasses must implement it."""

    def __len__(self):
        raise NotImplementedError


class Tee:
    """File-like object that mirrors every write to stdout and to a file."""

    def __init__(self, fname, mode="w"):
        self.stdout = sys.stdout
        self.file = open(fname, mode)

    def write(self, message):
        self.stdout.write(message)
        self.file.write(message)
        # Flush on every write so the log file is never behind the console.
        self.flush()

    def flush(self):
        self.stdout.flush()
        self.file.flush()

    def close(self):
        """Close the mirrored file; the captured stdout stream is left open."""
        self.file.close()


def split_dataset(dataset: Dataset, n: int, seed: int = 2022) -> Tuple[Dataset, Dataset]:
    """
    Return a pair of datasets corresponding to a random split of the given
    dataset, with n datapoints in the first dataset and the rest in the last,
    using the given random seed (2022 by default, the previously fixed value).
    """
    assert 0 <= n <= len(dataset)

    generator = torch.Generator().manual_seed(seed)
    first, second = random_split(dataset, (n, len(dataset) - n), generator)

    return first, second

# ------------------------------------------------------------------------------
# /README.md:
# ------------------------------------------------------------------------------
# # Beyond Invariance: Test-Time Label-Shift Adaptation for Distributions with "Spurious" Correlations
#
# Available: https://arxiv.org/abs/2211.15646
#
# To cite our paper:
#
# ```
# @misc{sun2022invariance,
#   title={Beyond Invariance: Test-Time Label-Shift Adaptation for Distributions with "Spurious" Correlations},
#   author={Qingyao Sun and Kevin Murphy and Sayna Ebrahimi and Alexander D'Amour},
#   year={2022},
#   eprint={2211.15646},
#   archivePrefix={arXiv},
#   primaryClass={stat.ML}
# }
# ```
#
# ## PyTorch Implementation (WIP)
#
# We have a work-in-progress PyTorch implementation available at https://github.com/nalzok/BalancingGroups .
#
# ## How to Reproduce
#
# 1. Install [Pipenv](https://pipenv.pypa.io/en/latest/) and [Pyenv](https://github.com/pyenv/pyenv#installation).
# 2. Install dependencies with `PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html pipenv install --deploy`.
# 3. Run experiments with `make paper-mnist`, `make paper-chexpert-embedding`, `make paper-chexpert-pixel`, and `make tree`.
# 4. Aggregate experimental results and generate figures with `make merge`.
# ------------------------------------------------------------------------------
# /scripts/matching.py:
# ------------------------------------------------------------------------------
from typing import Dict
from pathlib import Path

import pandas as pd
import numpy as np


def match(labels: pd.DataFrame, datastore: Dict[str, np.ndarray], output: Path):
    """Join per-image embeddings with their label rows and save a data matrix.

    Args:
        labels: label table indexed by image id; its "Unnamed: 0" and
            "patient_id" columns are dropped. Every index value must be a key
            of ``datastore``.
        datastore: mapping from image id to a 1376-element embedding.
        output: path of the ``.npz`` file to write.

    The columns "split", "GENDER", "PRIMARY_RACE" and "ETHNICITY" are
    integer-encoded with ``pd.factorize(sort=True)``; their category labels
    are saved under ``uniques``. The archive also holds ``features``
    (N x 1376 floats), ``attributes`` (N x n_columns ints) and ``columns``.
    """
    labels = labels.drop(columns=["Unnamed: 0", "patient_id"])

    uniques = {}
    for col in ("split", "GENDER", "PRIMARY_RACE", "ETHNICITY"):
        labels[col], uniques[col] = pd.factorize(labels[col], sort=True)

    N = len(labels.index)
    features = np.empty((N, 1376), dtype=float)
    attributes = np.empty((N, len(labels.columns)), dtype=int)
    # itertuples() yields (index, *row values); the index is the image id.
    for i, (image_id, *image_attributes) in enumerate(labels.itertuples()):
        features[i] = datastore[image_id]
        attributes[i] = image_attributes

    np.savez(
        output,
        features=features,
        attributes=attributes,
        columns=labels.columns,
        uniques=uniques,
    )


if __name__ == "__main__":
    root = Path("data/CheXpert")
    labels = pd.read_csv(root / "labels.csv", index_col="image_id")
    output = Path("data/CheXpert/data_matrix.npz")
    # Use the archive as a context manager so its file handle is released.
    with np.load(root / "embeddings.npz") as datastore:
        match(labels, datastore, output)

# ------------------------------------------------------------------------------
# /requirements-long.txt:
# ------------------------------------------------------------------------------
# absl-py==1.3.0
# astunparse==1.6.3
# cachetools==5.2.0
# certifi==2022.9.24
# charset-normalizer==2.1.1
# chex==0.1.5
# click==8.1.3
# commonmark==0.9.1
# contourpy==1.0.6
# cycler==0.11.0
# dm-tree==0.1.7
# flatbuffers==22.10.26
# flax==0.6.1
# fonttools==4.38.0
# gast==0.4.0
# google-auth==2.14.1
# google-auth-oauthlib==0.4.6
# google-pasta==0.2.0
# grpcio==1.50.0
# h5py==3.7.0
# idna==3.4
# importlib-metadata==5.0.0
# jax[tpu]==0.3.24
# jaxlib==0.3.24
# joblib==1.2.0
# keras==2.10.0
# keras-preprocessing==1.1.2
# kiwisolver==1.4.4
# libclang==14.0.6
# libtpu-nightly==0.1.dev20221103
# markdown==3.4.1
# markupsafe==2.1.1
# matplotlib==3.6.2
# msgpack==1.0.4
# numpy==1.23.4
# nvidia-cublas-cu11==11.10.3.66
# nvidia-cuda-nvrtc-cu11==11.7.99
# nvidia-cuda-runtime-cu11==11.7.99
# nvidia-cudnn-cu11==8.5.0.96
# oauthlib==3.2.2
# opt-einsum==3.3.0
# optax==0.1.3
# packaging==21.3
# pandas==1.5.1
# pillow==9.3.0
# protobuf==3.19.6
# pyasn1==0.4.8
# pyasn1-modules==0.2.8
# pygments==2.13.0
# pyparsing==3.0.9
# python-dateutil==2.8.2
# pytz==2022.6
# pyyaml==6.0
# requests==2.28.1
# requests-oauthlib==1.3.1
# rich==12.6.0
# rsa==4.9
# scikit-learn==1.1.3
# scipy==1.9.3
# setuptools==65.5.1
# six==1.16.0
# tensorboard==2.10.1
# tensorboard-data-server==0.6.1
# tensorboard-plugin-wit==1.8.1
# tensorflow-cpu==2.10.0
# tensorflow-estimator==2.10.0
# tensorflow-io-gcs-filesystem==0.27.0
# termcolor==2.1.0
# threadpoolctl==3.1.0
# toolz==0.12.0
# torch==1.13.0
# torchvision==0.14.0
# typing-extensions==4.4.0
# urllib3==1.26.12
# werkzeug==2.2.2
# wheel==0.38.3
# wrapt==1.14.1
# zipp==3.10.0

# ------------------------------------------------------------------------------
# /scripts/andrew.py:
# ------------------------------------------------------------------------------
import glob
import os

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from cxr_foundation import constants, train_lib


# Write each CheXpert embedding to its own single-record TFRecord file,
# keyed by image id, for consumption by cxr_foundation.train_lib.
embeddings = {
    "all": np.load("data/CheXpert/embeddings.npz"),
}

for split, npz in embeddings.items():
    dirname = os.path.join("./data", "inputs", "chexpert", split)
    os.makedirs(dirname, exist_ok=True)
    for i, (image_id, embedding) in enumerate(npz.items()):
        fname = os.path.join(dirname, f"{i:08d}.tfrecord")
        example = tf.train.Example()
        example.features.feature[constants.IMAGE_ID_KEY].bytes_list.value[:] = [
            image_id.encode("utf-8")
        ]
        example.features.feature[constants.EMBEDDING_KEY].float_list.value[
            :
        ] = embedding
        with tf.io.TFRecordWriter(fname) as writer:
            writer.write(example.SerializeToString())


labels_df = pd.read_csv("data/CheXpert/labels.csv")
labels_df.drop(columns=["Unnamed: 0"], inplace=True)


# Binary SEX label: Male -> 1, Female -> 0. Any other GENDER value (missing,
# unknown, ...) becomes NaN and is dropped below, rather than being silently
# counted as female.
labels_df["SEX"] = labels_df["GENDER"].map({"Male": 1, "Female": 0})

# dropna=False so the printout also shows how many rows will be discarded.
print(labels_df.SEX.value_counts(dropna=False))

df = labels_df

df = df[~df["SEX"].isna()]
model = train_lib.create_model(["SEX"], hidden_layer_sizes=[])
training_df, tune_df = train_test_split(df, test_size=0.2)
training_labels = dict(zip(training_df["image_id"], training_df["SEX"].astype(int)))
filenames = glob.glob(os.path.join("./data/inputs/chexpert/*/", "*.tfrecord"))
training_data = train_lib.get_dataset(filenames, labels=training_labels)
tune_labels = dict(zip(tune_df["image_id"], tune_df["SEX"].astype(int)))
tune_data = train_lib.get_dataset(filenames, labels=tune_labels).batch(1).cache()
model.fit(
    x=training_data.batch(512).prefetch(tf.data.AUTOTUNE).cache(),
    validation_data=tune_data,
    epochs=100,
)

# ------------------------------------------------------------------------------
# /tta/models/__init__.py:
# ------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
from flax import linen as nn

from tta.models.linear import Linear
from tta.models.lenet import LeNet
from tta.models.resnet import ResNet


class AdaptiveNN(nn.Module):
    """Backbone emitting C*K logits, with calibration and prior-ratio adaptation.

    Attributes:
        C, K: the output is reshaped to (batch, C, K) in ``adapted_prob``.
        model: "Linear", "LeNet", or "ResNet<depth>" (e.g. "ResNet18").
    """

    C: int
    K: int
    model: str

    def setup(self):
        self.M = self.C * self.K

        if self.model == 'Linear':
            self.net = Linear(num_outputs=self.M)
        elif self.model == 'LeNet':
            self.net = LeNet(num_outputs=self.M)
        elif self.model.startswith('ResNet'):
            # The depth follows the "ResNet" prefix, e.g. "ResNet50" -> 50.
            self.num_layers = int(self.model[6:])
            self.net = ResNet(num_outputs=self.M, num_layers=self.num_layers)
        else:
            raise ValueError(f'Unknown network architecture {self.model}')

        # Calibration parameters: per-logit bias and a scalar temperature.
        self.b = self.param('b', jax.nn.initializers.zeros, (self.M,))
        self.T = self.param('T', jax.nn.initializers.ones, ())
        # Source/target priors over the M joint classes, initialised uniform
        # and stored in the non-trainable 'prior' collection.
        self.source_prior = self.variable('prior', 'source',
                                          jax.nn.initializers.constant(1/self.M,),
                                          None,
                                          (self.M,))
        self.target_prior = self.variable('prior', 'target',
                                          jax.nn.initializers.constant(1/self.M,),
                                          None,
                                          (self.M,))

    def raw_logit(self, x, train: bool):
        """Return the backbone's uncalibrated logits."""
        logit = self.net(x, train)

        return logit

    def calibrated_logit(self, x, train: bool):
        """Return temperature-scaled, bias-corrected logits.

        Gradients are stopped at the backbone output, so only ``T`` and ``b``
        receive gradients through this path.
        """
        logit = self.raw_logit(x, train)
        logit = jax.lax.stop_gradient(logit)

        # bias corrected temperature scaling
        logit = logit/self.T + self.b

        return logit

    def adapted_prob(self, x, train: bool):
        """Return probabilities reweighted by target/source prior, shaped (batch, C, K)."""
        logit = self.calibrated_logit(x, train)

        # adaptation: p_target(m|x) ∝ w_m * p_source(m|x), with
        # w = target_prior / source_prior; subtracting the max keeps exp stable.
        w = self.target_prior.value / self.source_prior.value
        logit_max = jnp.max(logit, axis=-1, keepdims=True)
        unnormalized = w * jnp.exp(logit - jax.lax.stop_gradient(logit_max))
        prob = unnormalized / jnp.sum(unnormalized, axis=-1, keepdims=True)

        prob = prob.reshape((-1, self.C, self.K))

        return prob

# ------------------------------------------------------------------------------
# /tta/datasets/waterbirds.py:
# ------------------------------------------------------------------------------
from collections import Counter
import numpy as np
import torch
from torchvision import transforms as T

from tta.datasets import MultipleDomainDataset


class MultipleDomainWaterbirds(MultipleDomainDataset):
    """WILDS Waterbirds exposed as three domains (train / val / test).

    Each domain is stored as ``(subset, joint_M)``, where ``joint_M`` is the
    empirical P(Y) of the subset times a fixed P(Z|Y) for that domain.
    """

    domain_names = ['train', 'val', 'test']

    def __init__(self, root, generator):
        input_shape = (1, 224, 224, 3)
        C = 2
        K = 2
        # Used as indices into `domain_names` / `conditionals` below.
        confounder_strength = np.array([0, 1, 2])
        super().__init__(input_shape, C, K, confounder_strength)

        if root is None:
            raise ValueError('Data directory not specified!')

        self.generator = generator

        from wilds.datasets.waterbirds_dataset import WaterbirdsDataset
        self.waterbirds = WaterbirdsDataset(root_dir=root)

        # make Z compliant in shape
        self.waterbirds._metadata_array = self.waterbirds._metadata_array[:, 0]

        # P(Z|Y)
        conditionals = [
            np.array([[0.95, 0.05], [0.05, 0.95]]),
            np.array([[0.95, 0.05], [0.05, 0.95]]),
            np.array([[0.5, 0.5], [0.5, 0.5]]),
        ]

        # ImageNet augmentation; `permute` converts CHW tensors to HWC.
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
        permute = T.Lambda(lambda x: x.permute(1, 2, 0))
        random_transform = T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
            permute,
        ])
        deterministic_transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            normalize,
            permute,
        ])

        for env in self.confounder_strength:
            conditional = torch.from_numpy(conditionals[env])
            domain_name = self.domain_names[env]
            transform = random_transform if domain_name == 'train' else deterministic_transform
            domain = self.waterbirds.get_subset(domain_name, transform=transform)

            # Count labels from the subset's label array. Iterating the subset
            # itself would load and transform every image just to read y.
            labels = torch.as_tensor(domain.y_array, dtype=torch.long)
            y_count = torch.bincount(labels, minlength=C).to(torch.float32)
            y_freq = y_count / len(domain)
            joint_M = y_freq[:, np.newaxis] * conditional

            self.domains.append((domain, joint_M))
# ------------------------------------------------------------------------------
# /scripts/baseline.py:
# ------------------------------------------------------------------------------
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate, unreplicate
import flax.linen as nn
import optax
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score


def load_data(data_matrix, column):
    """Return embedding features X and a binary target Y derived from ``column``.

    ``data_matrix`` provides "features", "attributes" and "columns" (as in
    data/CheXpert/data_matrix.npz). Rows whose value cannot be mapped to a
    binary label are dropped.
    """
    X = data_matrix["features"]

    attributes = data_matrix["attributes"]
    columns = data_matrix["columns"]
    (index_Y,) = np.flatnonzero(columns == column)
    Y = attributes[:, index_Y]

    mask = np.ones_like(Y, dtype=bool)
    if column == "split" or column == "GENDER":
        mask = (Y == 0) | (Y == 1)
    elif column == "PRIMARY_RACE":
        mask = Y >= 0
        Y = (Y == 19).astype(int)  # uniques["PRIMARY_RACE"][19] == 'WHITE'
    elif column == "ETHNICITY":
        mask = Y >= 0
        Y = (Y == 2).astype(int)  # uniques["ETHNICITY"][2] == 'Non-Hispanic/Non-Latino'
    elif column == "AGE_AT_CXR":
        cutoff = np.median(Y)
        Y = (Y > cutoff).astype(int)
    else:
        # 0 = no mention
        # 1 = positive
        # 2 = uncertain
        # 3 = negative
        mask = (Y == 1) | (Y == 3)
        # Maps 1 (positive) -> 0 and 3 (negative) -> 1. Not in-place, so a
        # caller-owned `attributes` array is never modified through the view.
        Y = Y // 2

    X = X[mask]
    Y = Y[mask]

    return X, Y


@partial(
    jax.pmap,
    axis_name="batch",
    static_broadcasted_argnums=(2, 3),
    donate_argnums=(0, 1),
)
def train_step(params, opt_state, model, tx, X, Y):
    """Run one full-batch optimizer step, data-parallel over local devices.

    Returns the updated (params, opt_state), the device-averaged loss, and
    each device's class-1 probabilities for its shard.
    """

    @partial(jax.value_and_grad, has_aux=True)
    def loss_grad_fn(params, x, y):
        logit = model.apply(params, x)
        loss = optax.softmax_cross_entropy_with_integer_labels(logit, y)
        prob = jax.nn.softmax(logit)
        score = prob[:, 1]
        return jnp.mean(loss), score

    (loss, score), grads = loss_grad_fn(params, X, Y)
    # Average loss and gradients across devices. Without the gradient pmean
    # each replica applies its own shard's gradient and the replicated
    # parameters drift apart.
    loss = jax.lax.pmean(loss, axis_name="batch")
    grads = jax.lax.pmean(grads, axis_name="batch")
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    return params, opt_state, loss, score


@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(1,))
def test_step(params, model, X, Y):
    """Return the device-averaged loss and per-device class-1 probabilities."""
    logit = model.apply(params, X)
    loss = optax.softmax_cross_entropy_with_integer_labels(logit, Y)
    loss = jax.lax.pmean(jnp.mean(loss), axis_name="batch")
    prob = jax.nn.softmax(logit)
    score = prob[:, 1]

    return loss, score


def baseline(data_matrix, column):
    """Fit a linear softmax probe predicting ``column`` and print a LaTeX table row.

    The row contains the column name, test AUC, the mean test label, and the
    final training loss.
    """
    device_count = jax.local_device_count()
    X, Y = load_data(data_matrix, column)

    X, X_test, Y, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

    # Drop the first (len % device_count) rows so every array splits evenly
    # into one shard per local device.
    X = jnp.array(X[X.shape[0] % device_count :]).reshape(
        device_count, -1, *X.shape[1:]
    )
    Y = jnp.array(Y[Y.shape[0] % device_count :]).reshape(
        device_count, -1, *Y.shape[1:]
    )
    X_test = jnp.array(X_test[X_test.shape[0] % device_count :]).reshape(
        device_count, -1, *X_test.shape[1:]
    )
    Y_test = jnp.array(Y_test[Y_test.shape[0] % device_count :]).reshape(
        device_count, -1, *Y_test.shape[1:]
    )

    model = nn.Dense(features=2)
    key = jax.random.PRNGKey(42)
    # Initialise with the actual feature width rather than a hard-coded size.
    dummy = jnp.empty((1, X.shape[-1]))
    params = model.init(key, dummy)

    learning_rate = 1e-3
    tx = optax.adam(learning_rate=learning_rate)
    opt_state = tx.init(params)

    params = replicate(params)
    opt_state = replicate(opt_state)
    loss = float("inf")
    for _ in range(1001):
        params, opt_state, loss, _ = train_step(params, opt_state, model, tx, X, Y)

    _, score = test_step(params, model, X_test, Y_test)
    auc = roc_auc_score(Y_test.reshape(-1), score.reshape(-1))
    print(rf"{column.replace('_', chr(92)+'_')} & {auc:.3f} & {np.mean(Y_test):.3f} & {unreplicate(loss):.3f} \\")


if __name__ == "__main__":
    data_matrix = np.load("data/CheXpert/data_matrix.npz", allow_pickle=True)
    columns = data_matrix["columns"]
    for column in columns:
        baseline(data_matrix, column)

# ------------------------------------------------------------------------------
# /scripts/freeze.py:
# ------------------------------------------------------------------------------
from typing import Tuple, Set
from pathlib import Path

import click
import numpy as np
import pandas as pd
import torch

from tta.utils import Dataset
from tta.datasets.mnist import MultipleDomainMNIST
from tta.datasets.cxr.chexpert import MultipleDomainCheXpert
from tta.datasets import MultipleDomainDataset, split


@click.command()
@click.option("--seed", type=int, required=True)
def main(seed: int):
    """Export MNIST and CheXpert-embedding splits to frozen/<name>/ plus frozen/<name>.csv.

    Each example is saved as ``<k>.npy`` (k is a running index over all
    splits). The CSV records filename, split id (0 = train, 1 = calibration,
    2+ = test domains), y and a.
    """
    for getter in [get_mnist, get_chexpert]:
        name, train_domains_set, dataset = getter(seed)

        train_fraction = 0.9
        train_calibration_fraction = 0.1
        calibration_domains_set = set()
        calibration_fraction = 0.0

        (train, _), (calibration, _), test_splits = split(
            dataset,
            train_domains_set,
            train_fraction,
            train_calibration_fraction,
            calibration_domains_set,
            calibration_fraction,
        )

        splits = [(0, train), (1, calibration)]
        for i, (test, _) in enumerate(test_splits):
            splits.append((i + 2, test))

        # Create the per-dataset output directory once, not once per example.
        output_dir = Path("frozen") / name
        output_dir.mkdir(parents=True, exist_ok=True)

        metadata_records = []
        for split_id, ds in splits:
            X, Y, _, Z = dataset2np(ds)
            for embedding, y, z in zip(X, Y, Z):
                filename = f"{len(metadata_records)}.npy"
                np.save(output_dir / filename, embedding)

                metadata_records.append({
                    "filename": filename,
                    "split": split_id,
                    "y": y,
                    "a": z,
                })

        metadata = pd.DataFrame.from_records(metadata_records)
        metadata.index.name = "id"

        metadata_path = Path("frozen") / f"{name}.csv"
        metadata_path.parent.mkdir(parents=True, exist_ok=True)
        metadata.to_csv(metadata_path)


def get_mnist(seed) -> Tuple[str, Set[int], MultipleDomainDataset]:
    """Build the MNIST dataset (domain 1, no rotation, no noise) and its export name."""
    generator = torch.Generator().manual_seed(seed)

    train_domain = 1
    train_domains_set = {train_domain}
    dataset_apply_rotation = False
    dataset_feature_noise = 0
    dataset_label_noise = 0

    root = Path("data/mnist")
    dataset = MultipleDomainMNIST(
        root,
        train_domains_set,
        generator,
        dataset_apply_rotation,
        dataset_feature_noise,
        dataset_label_noise,
    )

    name = f"mnist_rot{dataset_apply_rotation}_noise{dataset_label_noise}_domain{train_domain}_seed{seed}"

    return name, train_domains_set, dataset


def get_chexpert(seed) -> Tuple[str, Set[int], MultipleDomainDataset]:
    """Build the CheXpert embedding dataset (Y = EFFUSION, Z = GENDER) and its export name."""
    generator = torch.Generator().manual_seed(seed)

    train_domain = 1
    train_domains_set = {train_domain}
    dataset_y_column = "EFFUSION"
    dataset_z_column = "GENDER"
    dataset_target_domain_count = 512
    dataset_source_domain_count = 65536
    dataset_use_embedding = True

    root = Path("data/CheXpert")
    dataset = MultipleDomainCheXpert(
        root,
        train_domains_set,
        generator,
        dataset_y_column,
        dataset_z_column,
        dataset_use_embedding,
        dataset_target_domain_count,
        dataset_source_domain_count,
    )

    name = f"chexpert-embedding_{dataset_y_column}_{dataset_z_column}_domain{train_domain}_size{dataset_source_domain_count}_seed{seed}"

    return name, train_domains_set, dataset


def dataset2np(dataset: Dataset) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Materialise ``dataset`` into stacked arrays returned as (X, Y, Y_tilde, Z).

    Items are unpacked as ``(x, y_tilde, y, z_flattened)``; note that the
    return order swaps the two label arrays.
    """
    X, Y, Y_tilde, Z = [], [], [], []
    for x, y_tilde, y, z_flattened in dataset:
        X.append(x)
        Y.append(y)
        Y_tilde.append(y_tilde)
        Z.append(z_flattened)

    X = np.stack(X)
    Y = np.stack(Y)
    Y_tilde = np.stack(Y_tilde)
    Z = np.stack(Z)

    return X, Y, Y_tilde, Z


if __name__ == "__main__":
    main()

# ------------------------------------------------------------------------------
# /tta/datasets/cxr/mimic.py:
# ------------------------------------------------------------------------------
from typing import Optional
from hashlib import sha256

import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype
import torch

from tta.datasets.cxr import MultipleDomainCXR


class MultipleDomainMIMIC(MultipleDomainCXR):
    """MIMIC-CXR embeddings exposed as a multi-domain dataset.

    Domains are produced by ``MultipleDomainCXR.build`` over 21 confounder
    strengths in [0, 1] and cached under ``root / 'cached'``.
    """

    def __init__(self, root, train_domains, generator, Y_col: str, Z_col: str, use_embedding: bool,
                 target_domain_count: int, source_domain_count: Optional[int]):
        if len(train_domains) != 1:
            raise NotImplementedError(
                "Training on multiple source distributions is not supported yet."
            )
        if not use_embedding:
            raise NotImplementedError(
                "Using raw images for MIMIC is not supported yet."
            )
        train_domain = next(iter(train_domains))
        patient_col = "subject_id"

        input_shape = (1, 1376)
        C = 2
        K = 2
        confounder_strength = np.linspace(0, 1, 21)

        # Fingerprint the configuration (including the generator state) so the
        # cache key changes whenever any of these inputs does.
        m = sha256()
        m.update(self.__class__.__name__.encode())
        m.update(str(sorted(train_domains)).encode())
        m.update(generator.get_state().numpy().data.hex().encode())
        m.update(Y_col.encode())
        m.update(Z_col.encode())
        m.update(str(target_domain_count).encode())
        m.update(str(source_domain_count).encode())

        m.update(str(input_shape).encode())
        m.update(str(C).encode())
        m.update(str(K).encode())
        m.update(confounder_strength.data.hex().encode())
        m.update(str(train_domain).encode())
        hexdigest = m.hexdigest()

        super().__init__(input_shape, C, K, confounder_strength, train_domain, hexdigest)

        # Validate `root` before it is used to build the cache path; otherwise a
        # missing root surfaces as an opaque TypeError from `None / 'cached'`.
        if root is None:
            raise ValueError('Data directory not specified!')

        cache_key = f'{train_domain}_{Y_col}_{Z_col}_{target_domain_count}_{source_domain_count}_{hexdigest}'
        cache_file = root / 'cached' / f'{cache_key}.pt'
        if cache_file.is_file():
            # NOTE: The torch.Generator state won't be the same if we load from cache
            print(f'Loading cached datasets from {cache_file}')
            self.domains = torch.load(cache_file)
            return

        print('Building datasets... (this may take a while)')

        self.generator = generator
        self.train_domains = train_domains

        labels_raw: pd.DataFrame = pd.read_csv(root / "mimic_labels_raw.csv", index_col="dicom_id")
        mimic_attributes: pd.DataFrame = pd.read_csv(root / "mimic_attributes.csv", index_col="dicom_id")
        labels = labels_raw.join(mimic_attributes, rsuffix="_attr")
        datastore = np.load(root / "mimic.npz")

        # Pneumonia
        # 0 = negative - 24303
        # -1 = uncertain - 19441
        # 1 = positive - 17222
        #
        # Pleural Effusion
        # 0 = negative - 27645
        # -1 = uncertain - 6202
        # 1 = positive - 57721
        #
        # Edema
        # 0 = negative - 25991
        # -1 = uncertain - 14244
        # 1 = positive - 29331
        #
        # gender
        # M - 130468
        # F - 112001
        relevant_columns = {Y_col, Z_col}
        # Values outside these categories (e.g. -1 = uncertain) become NaN and
        # the corresponding rows are dropped below.
        pathology_dtype = CategoricalDtype(categories=(0.0, 1.0))
        gender_dtype = CategoricalDtype(categories=("F", "M"))
        for column in ("Pneumonia", "Pleural Effusion", "Edema", "gender"):
            if column not in relevant_columns:
                continue

            if column in {"Pneumonia", "Pleural Effusion", "Edema"}:
                labels[column] = labels[column].astype(pathology_dtype)
            elif column == "gender":
                labels[column] = labels[column].astype(gender_dtype)

            nlevels = len(labels[column].dtype.categories)
            if nlevels != 2:
                raise NotImplementedError(f"Column {column} has {nlevels} != 2 levels.")

            labels = labels.loc[~labels[column].isna()]
            labels[column] = labels[column].cat.codes

        self.domains = self.build(generator, datastore, labels, Y_col, Z_col, patient_col, target_domain_count, source_domain_count)

        if use_embedding:
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            print(f'Saving cached datasets to {cache_file}')
            torch.save(self.domains, cache_file)
# --------------------------------------------------------------------------------
# /tta/datasets/cxr/chexpert.py:
# --------------------------------------------------------------------------------
from typing import Optional
from hashlib import sha256
import re

import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype
import torch
from PIL import Image
import torchvision.transforms as T

from tta.datasets.cxr import MultipleDomainCXR


class MultipleDomainCheXpert(MultipleDomainCXR):
    """CheXpert arranged as a family of label-shift domains.

    Domains are produced by ``MultipleDomainCXR.build`` from the binarized
    ``Y_col`` / ``Z_col`` columns of ``labels.csv``, one domain per entry of
    ``confounder_strength``. When ``use_embedding`` is true the result is
    cached under ``root / 'cached'`` keyed by a SHA-256 digest of every
    constructor input (including the generator state).
    """

    def __init__(self, root, train_domains, generator, Y_col: str, Z_col: str, use_embedding: bool,
                 target_domain_count: int, source_domain_count: Optional[int]):
        """Load cached domains or build them from ``root``.

        Args:
            root: Data directory supporting ``/`` path joins (e.g. ``pathlib.Path``);
                must contain ``labels.csv`` and either ``embeddings.npz`` or the
                ``CheXpert-v1.0-small/`` images.
            train_domains: Set containing exactly one source-domain index.
            generator: ``torch.Generator`` used for sampling; its state is hashed
                into the cache key.
            Y_col: Label column, one of ``"PNEUMONIA"``, ``"EFFUSION"``, ``"GENDER"``.
            Z_col: Confounder column, same choices as ``Y_col``.
            use_embedding: Use precomputed embeddings (input shape ``(1, 1376)``)
                instead of images (``(1, 224, 224, 3)``). Only embedding-based
                datasets are cached.
            target_domain_count: Forwarded to ``build``.
            source_domain_count: Forwarded to ``build``.

        Raises:
            NotImplementedError: If ``train_domains`` does not have exactly one
                element, or a selected column does not have two levels.
            ValueError: If ``root`` is ``None``.
        """
        if len(train_domains) != 1:
            raise NotImplementedError(
                "Training on multiple source distributions is not supported yet."
            )
        # Must come before `root` is used to build the cache path; otherwise a
        # None root fails with an unhelpful TypeError instead of this message.
        if root is None:
            raise ValueError('Data directory not specified!')

        train_domain = next(iter(train_domains))
        patient_col = "patient_id"

        if use_embedding:
            input_shape = (1, 1376)
        else:
            input_shape = (1, 224, 224, 3)
        C = 2
        K = 2
        confounder_strength = np.linspace(0, 1, 21)

        m = sha256()
        m.update(self.__class__.__name__.encode())
        m.update(str(sorted(train_domains)).encode())
        m.update(generator.get_state().numpy().data.hex().encode())
        m.update(Y_col.encode())
        m.update(Z_col.encode())
        m.update(str(use_embedding).encode())
        m.update(str(target_domain_count).encode())
        m.update(str(source_domain_count).encode())

        m.update(str(input_shape).encode())
        m.update(str(C).encode())
        m.update(str(K).encode())
        m.update(confounder_strength.data.hex().encode())
        m.update(str(train_domain).encode())
        hexdigest = m.hexdigest()

        super().__init__(input_shape, C, K, confounder_strength, train_domain, hexdigest)

        cache_key = f'{train_domain}_{Y_col}_{Z_col}_{use_embedding}_{target_domain_count}_{source_domain_count}_{hexdigest}'
        cache_file = root / 'cached' / f'{cache_key}.pt'
        if cache_file.is_file():
            # NOTE: The torch.Generator state won't be the same if we load from cache
            print(f'Loading cached datasets from {cache_file}')
            self.domains = torch.load(cache_file)
            return

        print('Building datasets... (this may take a while)')

        self.generator = generator
        self.use_embedding = use_embedding
        self.train_domains = train_domains

        labels: pd.DataFrame = pd.read_csv(root / "labels.csv", index_col="image_id")
        if use_embedding:
            datastore = np.load(root / "embeddings.npz")
        else:
            datastore = CheXpertImages(root)

        # PNEUMONIA
        # 0 = no mention - 15933
        # 1 = positive - 4657
        # 2 = uncertain - 2054
        # 3 = negative - 167855
        #
        # EFFUSION
        # 0 = no mention - 9527
        # 1 = positive - 76726
        # 2 = uncertain - 25371
        # 3 = negative - 78875
        relevant_columns = {Y_col, Z_col}
        # Keeping only (negative=3, positive=1) maps them to codes 0 / 1; other
        # values become NaN and the rows are dropped below.
        pathology_dtype = CategoricalDtype(categories=(3, 1))
        # Female -> 0, Male -> 1
        gender_dtype = CategoricalDtype(categories=("Female", "Male"))
        for column in ("PNEUMONIA", "EFFUSION", "GENDER"):
            if column not in relevant_columns:
                continue

            if column in {"PNEUMONIA", "EFFUSION"}:
                labels[column] = labels[column].astype(pathology_dtype)
            elif column == "GENDER":
                labels[column] = labels[column].astype(gender_dtype)

            nlevels = len(labels[column].dtype.categories)
            if nlevels != 2:
                raise NotImplementedError(f"Column {column} has {nlevels} != 2 levels.")

            # Explicit copy so the following column assignment does not act on
            # a possible view of the filtered frame.
            labels = labels.loc[~labels[column].isna()].copy()
            labels[column] = labels[column].cat.codes

        self.domains = self.build(generator, datastore, labels, Y_col, Z_col, patient_col, target_domain_count, source_domain_count)

        if use_embedding:
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            print(f'Saving cached datasets to {cache_file}')
            torch.save(self.domains, cache_file)


class CheXpertImages:
    """Lazy image store: maps a ``CheXpert-v1.0/...`` path to an (H, W, C) tensor.

    Keys are redirected to the downsampled ``CheXpert-v1.0-small/`` copy under
    ``root``; images are resized to 224x224 and returned channels-last.
    """

    def __init__(self, root):
        self.root = root
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Lambda(lambda x: x.permute(1, 2, 0)),  # (C, H, W) -> (H, W, C)
        ])
        # Escape the dot so only the literal "CheXpert-v1.0/" prefix matches.
        self.pattern = re.compile(r"^CheXpert-v1\.0/")

    def __getitem__(self, key):
        key = self.pattern.sub("CheXpert-v1.0-small/", key)
        # The context manager closes the file handle; RGB conversion makes the
        # channel dimension match input_shape (1, 224, 224, 3) even for
        # single-channel JPEGs.
        with Image.open(self.root / key) as image:
            return self.transform(image.convert("RGB"))


# --------------------------------------------------------------------------------
# /tta/datasets/coco.py:
# --------------------------------------------------------------------------------
# Forked from https://github.com/facebookresearch/DomainBed/blob/main/domainbed/datasets.py
from pathlib import Path
from hashlib import sha256

import numpy as np
import torch
from torch.utils.data import TensorDataset
from torchvision.transforms import ToTensor
from PIL import Image

from tta.datasets import MultipleDomainDataset


class ColoredCOCO(MultipleDomainDataset):
    """COCO images whose background is replaced by a category-dependent color.

    For each image, the largest annotation (area >= 10000) of one of
    ``self.categories`` gives the label; everything outside its mask is painted
    with a color from ``self.backgrounds`` drawn from P(background | category).
    One domain is built per entry of ``confounder_strength`` from disjoint,
    interleaved slices of a shuffled image list. Domains are cached under
    ``root / 'cached'`` keyed by a digest of ``annFile``.
    """

    def __init__(self, root: Path, annFile: Path, generator: torch.Generator):
        """Load cached domains or build them from the COCO annotations.

        Raises:
            ValueError: If ``root`` is ``None``.
        """
        # Checked first: `root` is needed for the cache path below.
        if root is None:
            raise ValueError('Data directory not specified!')

        self.categories = [
            'boat',
            'airplane',
            'truck',
            'dog',
            'zebra',
            'horse',
            'bird',
            'train',
            'bus',
        ]

        self.backgrounds = [
            (  0, 100,   0),
            (188, 143, 143),
            (255,   0,   0),
            (255, 215,   0),
            (  0, 255,   0),
            ( 65, 105, 225),
            (  0, 225, 225),
            (  0,   0, 255),
            (255,  20, 147),
        ]

        input_shape = (1, 64, 64, 3)
        C = len(self.categories)
        K = len(self.backgrounds)
        confounder_strength = np.array([0.9, 0.8, 0.1])

        m = sha256()
        m.update(str(annFile).encode())
        cache_key = m.hexdigest()

        # MultipleDomainDataset.__init__ also requires `train_domain` and
        # `hexdigest`; omitting them raised a TypeError on every construction.
        # This dataset has no designated training domain, hence None.
        super().__init__(input_shape, C, K, confounder_strength, None, cache_key)

        cache_file = root / 'cached' / f'{cache_key}.pt'
        if cache_file.is_file():
            # NOTE: The torch.Generator state won't be the same if we load from cache
            print(f'Loading cached datasets from {cache_file}')
            self.domains = torch.load(cache_file)
            return

        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)

        self.cat_ids = self.coco.getCatIds(catNms=self.categories)
        self.image_ids_set = set()
        for cat_id in self.cat_ids:
            self.image_ids_set.update(self.coco.getImgIds(catIds=cat_id))
        self.image_ids = list(self.image_ids_set)

        self.generator = generator

        shuffle = torch.randperm(len(self.image_ids), generator=self.generator)

        # Row-stochastic P(background | category): mostly the diagonal color
        # (confounding1) or the color one column over (confounding2).
        independent = np.ones((C, K)) * 1/K
        confounding1 = np.eye(C, K)
        confounding1 = 0.75 * confounding1 + 0.25 * independent
        confounding2 = np.roll(confounding1, shift=1, axis=1)

        for i, strength in enumerate(self.confounder_strength):
            indices = shuffle[i::len(self.confounder_strength)]
            prob = torch.from_numpy(strength * confounding1 + (1-strength) * confounding2)
            domain = self.dataset_transform(indices, prob)
            self.domains.append((domain, prob))  # FIXME: prob should be joint

        cache_file.parent.mkdir(parents=True, exist_ok=True)
        print(f'Saving cached datasets to {cache_file}')
        torch.save(self.domains, cache_file)

    def dataset_transform(self, indices: torch.Tensor, prob: torch.Tensor) -> TensorDataset:
        """Render the images at ``indices`` with backgrounds sampled from ``prob``.

        Args:
            indices: Positions into ``self.image_ids``.
            prob: (C, K) matrix whose row ``c`` is P(background | category c).

        Returns:
            ``TensorDataset(X, Y, Y, Z)`` with X of shape (N, 64, 64, 3), the
            category index used for both label slots, and the background index.
            Images without a qualifying annotation are skipped.
        """
        X, Y, Z = [], [], []
        p = torch.cumsum(prob, dim=1)
        to_tensor = ToTensor()

        for sample_idx in indices:
            image_id = self.image_ids[sample_idx]
            image_json, = self.coco.loadImgs(image_id)
            anns = self.coco.loadAnns(self.coco.getAnnIds(
                imgIds=image_id,
                catIds=self.cat_ids,
                areaRng=(10000, float('inf'))
            ))

            # Largest-area annotation (first one wins on ties).
            ann = max(anns, key=lambda candidate: candidate['area'], default=None)
            if ann is None:
                continue

            cat_idx = self.cat_ids.index(ann['category_id'])
            # Inverse-CDF sample of the background index; shape (1,).
            background_idx = torch.searchsorted(p[cat_idx], torch.rand(1, generator=self.generator))
            background_color = self.backgrounds[int(background_idx)]

            mask = 255 * self.coco.annToMask(ann)
            mask = Image.fromarray(mask)

            # Open only images that are actually used, and close the handle.
            with Image.open(self.root / image_json['file_name']) as raw:
                image = raw.convert('RGB')

            background = Image.new('RGB', image.size, background_color)
            image = Image.composite(image, background, mask)
            image = image.resize((64, 64))
            # (C, H, W) -> (H, W, C)
            image = to_tensor(image).permute(1, 2, 0)

            X.append(image)
            Y.append(cat_idx)
            Z.append(background_idx)

        X = torch.stack(X)
        Y = torch.tensor(Y, dtype=torch.long)
        Z = torch.cat(Z)

        return TensorDataset(X, Y, Y, Z)


# --------------------------------------------------------------------------------
# /tta/datasets/__init__.py:
# --------------------------------------------------------------------------------
# Forked from https://github.com/facebookresearch/DomainBed/blob/main/domainbed/datasets.py

from typing import Set, Tuple, List, Optional

import numpy as np
import torch
from torch.utils.data import ConcatDataset, Subset

from tta.utils import Dataset, split_dataset


class MultipleDomainDataset:
    """Base container for a family of domains, each a (dataset, joint_M) pair."""

    def __init__(
        self, input_shape, C, K, confounder_strength, train_domain, hexdigest
    ) -> None:
        super().__init__()

        self.input_shape: Tuple[int, ...] = input_shape
        self.C: int = C  # number of label classes
        self.K: int = K  # number of confounder values
        self.confounder_strength: np.ndarray = confounder_strength
        # Index of the source domain, or None when the dataset has none.
        self.train_domain: Optional[int] = train_domain
        self.hexdigest: str = hexdigest
        self.domains: List[Tuple[Dataset, torch.Tensor]] = []


def _joint_distribution(dataset: Dataset, template: torch.Tensor) -> torch.Tensor:
    """Return the empirical joint of (y, z) over ``dataset``, shaped/typed like ``template``.

    Samples are unpacked as 4-tuples and the 3rd and 4th items are used as (y, z).
    """
    joint = torch.zeros_like(template)
    for _, _, y, z in dataset:
        joint[y][z] += 1
    joint /= torch.sum(joint)
    return joint


def split(
    dataset: MultipleDomainDataset,
    train_domains: Set[int],
    train_fraction: float,
    train_calibration_fraction: float,
    calibration_domains: Set[int],
    calibration_fraction: float,
) -> Tuple[
    Tuple[Dataset, torch.Tensor],
    Tuple[Dataset, torch.Tensor],
    List[Tuple[Dataset, torch.Tensor]],
]:
    """Partition every domain of ``dataset`` into train / calibration / test parts.

    - Domains in ``train_domains`` are split into test plus a train part of
      ``int(len(domain) * train_fraction)`` samples, from which
      ``int(len(domain) * train_calibration_fraction)`` samples (a fraction of
      the whole domain) are moved to calibration.
    - Domains in ``calibration_domains`` give
      ``int(len(domain) * calibration_fraction)`` samples to calibration and
      the rest to test.
    - All other domains are used entirely as test.

    Returns:
        ``(train, joint_M_train)``, ``(calibration, joint_M_calibration)`` with
        empirical joints recomputed from the concatenated splits, and a list of
        ``(test_dataset, joint_M)`` pairs (one per domain, original joint_M).

    Raises:
        NotImplementedError: If the joint matrices are not (2, 2).
    """
    # Fail fast, before doing any splitting work.
    joint_shape = dataset.domains[0][1].shape
    if joint_shape != (2, 2):
        raise NotImplementedError(f"(C, K) = {joint_shape} != (2, 2)")

    train_splits = []
    calibration_splits = []
    test_splits = []

    for i, (domain, joint_M) in enumerate(dataset.domains):
        if i in train_domains:
            # For source domains, we split it into train + calibration + test
            train, test = split_dataset(domain, int(len(domain) * train_fraction))
            calibration, train = split_dataset(
                train, int(len(domain) * train_calibration_fraction)
            )

            train_splits.append(train)
            calibration_splits.append(calibration)
            test_splits.append((test, joint_M))
        elif i in calibration_domains:
            # For calibration domains, we split it into calibration + test
            calibration, test = split_dataset(
                domain, int(len(domain) * calibration_fraction)
            )

            calibration_splits.append(calibration)
            test_splits.append((test, joint_M))
        else:
            # For target domains, all samples are used as test
            test_splits.append((domain, joint_M))

    template = dataset.domains[0][1]

    train = ConcatDataset(train_splits)
    joint_M_train = _joint_distribution(train, template)

    calibration = ConcatDataset(calibration_splits)
    joint_M_calibration = _joint_distribution(calibration, template)

    return (train, joint_M_train), (calibration, joint_M_calibration), test_splits
def subsample( 86 | dataset: Dataset, 87 | joint_M: torch.Tensor, 88 | subsample_what: str, 89 | generator: torch.Generator, 90 | ) -> Tuple[Dataset, torch.Tensor]: 91 | joint_M_count = torch.zeros_like(joint_M, dtype=torch.long) 92 | M = [] 93 | for _, _, y, z in dataset: 94 | joint_M_count[y][z] += 1 95 | m = y * joint_M.shape[-1] + z 96 | M.append(m) 97 | M = torch.ByteTensor(M) 98 | 99 | count_per_group = torch.min(joint_M_count).item() 100 | Y = M // joint_M.shape[-1] 101 | joint_Y_count = torch.sum(joint_M_count, dim=1) 102 | count_per_class = torch.min(joint_Y_count).item() 103 | 104 | if subsample_what == "groups": 105 | indices_list = [] 106 | for m in range(np.prod(joint_M_count.shape)): 107 | weights = (M == m).float() 108 | indices_m = torch.multinomial( 109 | weights, count_per_group, replacement=False, generator=generator 110 | ) 111 | indices_list.extend(indices_m) 112 | 113 | elif subsample_what == "classes": 114 | indices_list = [] 115 | for y in range(np.prod(joint_Y_count.shape)): 116 | weights = (Y == y).float() 117 | indices_y = torch.multinomial( 118 | weights, count_per_class, replacement=False, generator=generator 119 | ) 120 | indices_list.extend(indices_y) 121 | 122 | else: 123 | raise ValueError(f"Unknown setting {subsample_what = }") 124 | 125 | subset = Subset(dataset, indices_list) 126 | 127 | joint_M_actual = torch.zeros_like(joint_M_count) 128 | for _, _, y, z in subset: 129 | joint_M_actual[y][z] += 1 130 | joint_Y_actual = torch.sum(joint_M_actual, dim=1) 131 | 132 | # Sanity check 133 | if subsample_what == "groups": 134 | joint_M_expected = count_per_group * torch.ones_like(joint_M_count) 135 | if not torch.allclose(joint_M_actual, joint_M_expected): 136 | raise ValueError(f"{joint_M_actual = }, {joint_M_expected = }") 137 | 138 | elif subsample_what == "classes": 139 | joint_Y_expected = count_per_class * torch.ones_like(joint_Y_count) 140 | if not torch.allclose(joint_Y_actual, joint_Y_expected): 141 | raise 
ValueError(f"{joint_Y_actual = }, {joint_Y_expected = }") 142 | 143 | else: 144 | raise ValueError(f"Unknown setting {subsample_what = }") 145 | 146 | joint_M_actual = joint_M_actual.float() / torch.sum(joint_M_actual) 147 | 148 | return subset, joint_M_actual 149 | -------------------------------------------------------------------------------- /tta/models/resnet.py: -------------------------------------------------------------------------------- 1 | """Implementation of ResNet.""" 2 | 3 | import functools 4 | from typing import Tuple, Callable, Any, Optional, Union, Dict 5 | 6 | import flax.linen as nn 7 | import jax 8 | import jax.numpy as jnp 9 | 10 | 11 | class IdentityLayer(nn.Module): 12 | """Identity layer, convenient for giving a name to an array.""" 13 | 14 | @nn.compact 15 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 16 | return x 17 | 18 | 19 | class ResidualBlock(nn.Module): 20 | """Bottleneck ResNet block.""" 21 | filters: int 22 | strides: Tuple[int, int] = (1, 1) 23 | dtype: jnp.dtype = jnp.float32 24 | bottleneck: bool = True 25 | 26 | @nn.compact 27 | def __call__(self, x: jnp.ndarray, train: bool = True) -> jnp.ndarray: 28 | needs_projection = x.shape[-1] != self.filters * 4 or self.strides != (1, 1) 29 | nout = self.filters * 4 if self.bottleneck else self.filters 30 | 31 | batch_norm = functools.partial( 32 | nn.BatchNorm, 33 | use_running_average=not train, 34 | momentum=0.9, 35 | epsilon=1e-5, 36 | dtype=self.dtype) 37 | conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) 38 | 39 | residual = x 40 | if needs_projection: 41 | residual = conv(nout, (1, 1), self.strides, name='proj_conv')(residual) 42 | residual = batch_norm(name='proj_bn')(residual) 43 | 44 | if self.bottleneck: 45 | x = conv(self.filters, (1, 1), name='conv1')(x) 46 | x = batch_norm(name='bn1')(x) 47 | x = IdentityLayer(name='relu1')(nn.relu(x)) 48 | 49 | y = conv( 50 | self.filters, (3, 3), 51 | self.strides, 52 | padding=[(1, 1), (1, 1)], 53 | 
name='conv2')(x) 54 | y = batch_norm(name='bn2')(y) 55 | y = IdentityLayer(name='relu2')(nn.relu(y)) 56 | 57 | if self.bottleneck: 58 | y = conv(nout, (1, 1), name='conv3')(y) 59 | else: 60 | y = conv(nout, (3, 3), padding=[(1, 1), (1, 1)], name='conv3')(y) 61 | y = batch_norm(name='bn3', scale_init=jax.nn.initializers.zeros)(y) 62 | y = IdentityLayer(name='relu3')(nn.relu(residual + y)) 63 | return y 64 | 65 | 66 | class ResNet(nn.Module): 67 | """ResNet architecture. 68 | 69 | Attributes: 70 | num_outputs: Num output classes. If None, a dict of intermediate feature 71 | maps is returned. 72 | num_filters: Num filters. 73 | num_layers: Num layers. 74 | kernel_init: Kernel initialization. 75 | bias_init: Bias initialization. 76 | dtype: Data type, e.g. jnp.float32. 77 | """ 78 | num_outputs: Optional[int] 79 | num_filters: int = 64 80 | num_layers: int = 50 81 | kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_normal() 82 | bias_init: Callable[..., Any] = jax.nn.initializers.zeros 83 | dtype: jnp.dtype = jnp.float32 84 | 85 | @nn.compact 86 | def __call__( 87 | self, 88 | x: jnp.ndarray, 89 | train: bool = False, 90 | debug: bool = False) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]: 91 | """Applies ResNet model to the inputs. 92 | 93 | Args: 94 | x: Inputs to the model. 95 | train: Whether it is training or not. 96 | debug: Whether the debug mode is enabled. debug=True enables model 97 | specific logging/storing some values using jax.host_callback. 98 | 99 | Returns: 100 | Un-normalized logits. 
101 | """ 102 | if self.num_layers not in BLOCK_SIZE_OPTIONS: 103 | raise ValueError('Please provide a valid number of layers') 104 | block_sizes, bottleneck = BLOCK_SIZE_OPTIONS[self.num_layers] 105 | x = nn.Conv( 106 | self.num_filters, 107 | kernel_size=(7, 7), 108 | strides=(2, 2), 109 | padding=[(3, 3), (3, 3)], 110 | use_bias=False, 111 | dtype=self.dtype, 112 | name='stem_conv')(x) 113 | x = nn.BatchNorm( 114 | use_running_average=not train, 115 | momentum=0.9, 116 | epsilon=1e-5, 117 | dtype=self.dtype, 118 | name='init_bn')(x) 119 | x = IdentityLayer(name='init_relu')(nn.relu(x)) 120 | x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=[(1, 1), (1, 1)]) 121 | 122 | residual_block = functools.partial( 123 | ResidualBlock, dtype=self.dtype, bottleneck=bottleneck) 124 | representations = {'stem': x} 125 | for i, block_size in enumerate(block_sizes): 126 | for j in range(block_size): 127 | strides = (2, 2) if i > 0 and j == 0 else (1, 1) 128 | filters = self.num_filters * 2**i 129 | x = residual_block(filters=filters, strides=strides)(x, train) 130 | representations[f'stage_{i + 1}'] = x 131 | 132 | # Head. 133 | if self.num_outputs: 134 | x = jnp.mean(x, axis=(1, 2)) 135 | x = IdentityLayer(name='pre_logits')(x) 136 | x = nn.Dense( 137 | self.num_outputs, 138 | kernel_init=self.kernel_init, 139 | bias_init=self.bias_init, 140 | dtype=self.dtype, 141 | name='output_projection')(x) 142 | return x 143 | else: 144 | return representations 145 | 146 | 147 | # A dictionary mapping the number of layers in a resnet to the number of 148 | # blocks in each stage of the model. The second argument indicates whether we 149 | # use bottleneck layers or not. 150 | BLOCK_SIZE_OPTIONS = { 151 | 5: ([1], True), # Only strided blocks. Total stride 4. 152 | 8: ([1, 1], True), # Only strided blocks. Total stride 8. 153 | 11: ([1, 1, 1], True), # Only strided blocks. Total stride 16. 154 | 14: ([1, 1, 1, 1], True), # Only strided blocks. Total stride 32. 
155 | 9: ([1, 1, 1, 1], False), # Only strided blocks. Total stride 32. 156 | 18: ([2, 2, 2, 2], False), 157 | 26: ([2, 2, 2, 2], True), 158 | 34: ([3, 4, 6, 3], False), 159 | 50: ([3, 4, 6, 3], True), 160 | 101: ([3, 4, 23, 3], True), 161 | 152: ([3, 8, 36, 3], True), 162 | 200: ([3, 24, 36, 3], True) 163 | } 164 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/linux,python,c,macos,vim 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=linux,python,c,macos,vim 3 | 4 | ### C ### 5 | # Prerequisites 6 | *.d 7 | 8 | # Object files 9 | *.o 10 | *.ko 11 | *.obj 12 | *.elf 13 | 14 | # Linker output 15 | *.ilk 16 | *.map 17 | *.exp 18 | 19 | # Precompiled Headers 20 | *.gch 21 | *.pch 22 | 23 | # Libraries 24 | *.lib 25 | *.a 26 | *.la 27 | *.lo 28 | 29 | # Shared objects (inc. Windows DLLs) 30 | *.dll 31 | *.so 32 | *.so.* 33 | *.dylib 34 | 35 | # Executables 36 | *.exe 37 | *.out 38 | *.app 39 | *.i*86 40 | *.x86_64 41 | *.hex 42 | 43 | # Debug files 44 | *.dSYM/ 45 | *.su 46 | *.idb 47 | *.pdb 48 | 49 | # Kernel Module Compile Results 50 | *.mod* 51 | *.cmd 52 | .tmp_versions/ 53 | modules.order 54 | Module.symvers 55 | Mkfile.old 56 | dkms.conf 57 | 58 | ### Linux ### 59 | *~ 60 | 61 | # temporary files which can be created if a process still has a handle open of a deleted file 62 | .fuse_hidden* 63 | 64 | # KDE directory preferences 65 | .directory 66 | 67 | # Linux trash folder which might appear on any partition or disk 68 | .Trash-* 69 | 70 | # .nfs files are created when an open file is removed but is still being accessed 71 | .nfs* 72 | 73 | ### macOS ### 74 | # General 75 | .DS_Store 76 | .AppleDouble 77 | .LSOverride 78 | 79 | # Icon must end with two \r 80 | Icon 81 | 82 | 83 | # Thumbnails 84 | ._* 85 | 86 | # Files that might appear in the root of a 
volume 87 | .DocumentRevisions-V100 88 | .fseventsd 89 | .Spotlight-V100 90 | .TemporaryItems 91 | .Trashes 92 | .VolumeIcon.icns 93 | .com.apple.timemachine.donotpresent 94 | 95 | # Directories potentially created on remote AFP share 96 | .AppleDB 97 | .AppleDesktop 98 | Network Trash Folder 99 | Temporary Items 100 | .apdisk 101 | 102 | ### macOS Patch ### 103 | # iCloud generated files 104 | *.icloud 105 | 106 | ### Python ### 107 | # Byte-compiled / optimized / DLL files 108 | __pycache__/ 109 | *.py[cod] 110 | *$py.class 111 | 112 | # C extensions 113 | 114 | # Distribution / packaging 115 | .Python 116 | build/ 117 | develop-eggs/ 118 | dist/ 119 | downloads/ 120 | eggs/ 121 | .eggs/ 122 | lib/ 123 | lib64/ 124 | parts/ 125 | sdist/ 126 | var/ 127 | wheels/ 128 | share/python-wheels/ 129 | *.egg-info/ 130 | .installed.cfg 131 | *.egg 132 | MANIFEST 133 | 134 | # PyInstaller 135 | # Usually these files are written by a python script from a template 136 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
137 | *.manifest 138 | *.spec 139 | 140 | # Installer logs 141 | pip-log.txt 142 | pip-delete-this-directory.txt 143 | 144 | # Unit test / coverage reports 145 | htmlcov/ 146 | .tox/ 147 | .nox/ 148 | .coverage 149 | .coverage.* 150 | .cache 151 | nosetests.xml 152 | coverage.xml 153 | *.cover 154 | *.py,cover 155 | .hypothesis/ 156 | .pytest_cache/ 157 | cover/ 158 | 159 | # Translations 160 | *.mo 161 | *.pot 162 | 163 | # Django stuff: 164 | *.log 165 | local_settings.py 166 | db.sqlite3 167 | db.sqlite3-journal 168 | 169 | # Flask stuff: 170 | instance/ 171 | .webassets-cache 172 | 173 | # Scrapy stuff: 174 | .scrapy 175 | 176 | # Sphinx documentation 177 | docs/_build/ 178 | 179 | # PyBuilder 180 | .pybuilder/ 181 | target/ 182 | 183 | # Jupyter Notebook 184 | .ipynb_checkpoints 185 | 186 | # IPython 187 | profile_default/ 188 | ipython_config.py 189 | 190 | # pyenv 191 | # For a library or package, you might want to ignore these files since the code is 192 | # intended to run in multiple environments; otherwise, check them in: 193 | # .python-version 194 | 195 | # pipenv 196 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 197 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 198 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 199 | # install all needed dependencies. 200 | #Pipfile.lock 201 | 202 | # poetry 203 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 204 | # This is especially recommended for binary packages to ensure reproducibility, and is more 205 | # commonly ignored for libraries. 206 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 207 | #poetry.lock 208 | 209 | # pdm 210 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
211 | #pdm.lock 212 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 213 | # in version control. 214 | # https://pdm.fming.dev/#use-with-ide 215 | .pdm.toml 216 | 217 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 218 | __pypackages__/ 219 | 220 | # Celery stuff 221 | celerybeat-schedule 222 | celerybeat.pid 223 | 224 | # SageMath parsed files 225 | *.sage.py 226 | 227 | # Environments 228 | .env 229 | .venv 230 | env/ 231 | venv/ 232 | ENV/ 233 | env.bak/ 234 | venv.bak/ 235 | 236 | # Spyder project settings 237 | .spyderproject 238 | .spyproject 239 | 240 | # Rope project settings 241 | .ropeproject 242 | 243 | # mkdocs documentation 244 | /site 245 | 246 | # mypy 247 | .mypy_cache/ 248 | .dmypy.json 249 | dmypy.json 250 | 251 | # Pyre type checker 252 | .pyre/ 253 | 254 | # pytype static type analyzer 255 | .pytype/ 256 | 257 | # Cython debug symbols 258 | cython_debug/ 259 | 260 | # PyCharm 261 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 262 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 263 | # and can be added to the global gitignore or merged into this file. For a more nuclear 264 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
265 | #.idea/ 266 | 267 | ### Vim ### 268 | # Swap 269 | [._]*.s[a-v][a-z] 270 | !*.svg # comment out if you don't need vector files 271 | [._]*.sw[a-p] 272 | [._]s[a-rt-v][a-z] 273 | [._]ss[a-gi-z] 274 | [._]sw[a-p] 275 | 276 | # Session 277 | Session.vim 278 | Sessionx.vim 279 | 280 | # Temporary 281 | .netrwhist 282 | # Auto-generated tag files 283 | tags 284 | # Persistent undo 285 | [._]*.un~ 286 | 287 | # End of https://www.toptal.com/developers/gitignore/api/linux,python,c,macos,vim 288 | 289 | data/** 290 | pretrained/** 291 | jit_cache/** 292 | joblog*.txt 293 | logs/**.txt 294 | plots/**.png 295 | plots/**.pdf 296 | merged/**.png 297 | merged/**.pdf 298 | npz/**.npz 299 | checkpoints/** 300 | frozen/** 301 | -------------------------------------------------------------------------------- /tta/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | # Forked from https://github.com/facebookresearch/DomainBed/blob/main/domainbed/datasets.py 2 | from collections import Counter 3 | from hashlib import sha256 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import TensorDataset 8 | from torchvision.datasets import MNIST 9 | from torchvision import transforms as T 10 | from PIL import Image 11 | 12 | from tta.datasets import MultipleDomainDataset 13 | 14 | 15 | class MultipleDomainMNIST(MultipleDomainDataset): 16 | 17 | def __init__(self, root, train_domains, generator, apply_rotation: bool, feature_noise: float, label_noise: float): 18 | if len(train_domains) != 1: 19 | raise NotImplementedError( 20 | "Training on multiple source distributions is not supported yet." 
21 | ) 22 | train_domain = next(iter(train_domains)) 23 | 24 | self.colors = torch.ByteTensor([ 25 | (1, 0, 0), 26 | (0, 1, 0), 27 | ]) 28 | if apply_rotation: 29 | self.angles = torch.ShortTensor([0, 15]) 30 | else: 31 | self.angles = torch.ShortTensor([0]) 32 | self.Z = torch.LongTensor([(c_idx, r_idx) for c_idx in range(len(self.colors)) for r_idx in range(len(self.angles))]) 33 | 34 | input_shape = (1, 28, 28, 3) 35 | C = 2 36 | K = len(self.Z) 37 | confounder_strength = np.linspace(0, 1, 21) 38 | 39 | m = sha256() 40 | m.update(self.__class__.__name__.encode()) 41 | m.update(str(sorted(train_domains)).encode()) 42 | m.update(generator.get_state().numpy().data.hex().encode()) 43 | m.update(str(apply_rotation).encode()) 44 | m.update(str(feature_noise).encode()) 45 | m.update(str(label_noise).encode()) 46 | m.update(self.colors.numpy().data.hex().encode()) 47 | m.update(self.angles.numpy().data.hex().encode()) 48 | 49 | m.update(str(input_shape).encode()) 50 | m.update(str(C).encode()) 51 | m.update(str(K).encode()) 52 | m.update(confounder_strength.data.hex().encode()) 53 | m.update(str(train_domain).encode()) 54 | hexdigest = m.hexdigest() 55 | 56 | super().__init__(input_shape, C, K, confounder_strength, train_domain, hexdigest) 57 | 58 | cache_key = f'{train_domain}_{apply_rotation}_{feature_noise}_{label_noise}_{hexdigest}' 59 | cache_file = root / 'cached' / f'{cache_key}.pt' 60 | if cache_file.is_file(): 61 | # NOTE: The torch.Generator state won't be the same if we load from cache 62 | print(f'Loading cached datasets from {cache_file}') 63 | self.domains = torch.load(cache_file) 64 | return 65 | 66 | print('Building datasets... 
(this may take a while)') 67 | if root is None: 68 | raise ValueError('Data directory not specified!') 69 | 70 | self.generator = generator 71 | self.train_domains = train_domains 72 | self.feature_noise = feature_noise 73 | self.label_noise = label_noise 74 | 75 | original_dataset_tr = MNIST(root, train=True, download=True) 76 | original_dataset_te = MNIST(root, train=False, download=True) 77 | 78 | original_images = torch.cat((original_dataset_tr.data, 79 | original_dataset_te.data)) 80 | 81 | original_labels = torch.cat((original_dataset_tr.targets, 82 | original_dataset_te.targets)) 83 | original_labels = (original_labels < 5).long() 84 | 85 | shuffle = torch.randperm(len(original_images), generator=self.generator) 86 | 87 | original_images = original_images[shuffle] 88 | original_labels = original_labels[shuffle] 89 | 90 | # P(Z|Y) 91 | if apply_rotation: 92 | anchor1 = np.array([[0.5, 0.5, 0.0, 0.0], [0.0, 0.0, 0.5, 0.5]]) 93 | anchor2 = np.array([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.0, 0.0]]) 94 | else: 95 | anchor1 = np.array([[1.0, 0.0], [0.0, 1.0]]) 96 | anchor2 = np.array([[0.0, 1.0], [1.0, 0.0]]) 97 | 98 | for i, strength in enumerate(self.confounder_strength): 99 | offset = 0 if i in train_domains else 1 100 | images = original_images[offset::2] 101 | labels = original_labels[offset::2] 102 | conditional = torch.from_numpy(strength * anchor1 + (1-strength) * anchor2) 103 | domain = self.shift(images, labels, conditional) 104 | 105 | counter = Counter(labels.numpy()) 106 | y_count = torch.zeros(C) 107 | for label in counter: 108 | y_count[label] += counter[label] 109 | y_freq = y_count / len(labels) 110 | joint_M = y_freq[:, np.newaxis] * conditional 111 | 112 | self.domains.append((domain, joint_M)) 113 | 114 | cache_file.parent.mkdir(parents=True, exist_ok=True) 115 | print(f'Saving cached datasets to {cache_file}') 116 | torch.save(self.domains, cache_file) 117 | 118 | 119 | def shift(self, images, y_tilde, conditional): 120 | lookup_table = 
torch.cumsum(conditional, dim=1) 121 | to_tensor = T.ToTensor() 122 | N = y_tilde.size(0) 123 | 124 | # inject noise to Y 125 | if self.label_noise > 0: 126 | weights = torch.ones((N, self.C)) 127 | weights[torch.arange(N), y_tilde] += 1/self.label_noise - 2 128 | else: 129 | weights = torch.zeros((N, self.C)) 130 | weights[torch.arange(N), y_tilde] = 1 131 | y = torch.multinomial(weights, 1, generator=self.generator).squeeze(dim=-1) 132 | 133 | # generate Z condition on Y 134 | values = torch.rand((N, 1), generator=self.generator) 135 | z_idx = torch.searchsorted(lookup_table[y], values).squeeze(dim=-1) 136 | z = self.Z[z_idx] 137 | z_flattened = len(self.angles) * z[:, 0] + z[:, 1] 138 | 139 | # transform X based on Z 140 | x = torch.empty((N, *self.input_shape[1:])) 141 | for i, (image, (color_idx, angle_idx)) in enumerate(zip(images, z)): 142 | color = self.colors[color_idx] 143 | angle = self.angles[angle_idx] 144 | 145 | image = color * image.unsqueeze(-1) 146 | image = Image.fromarray(image.numpy()) 147 | image = image.rotate(angle.item(), resample=Image.BILINEAR) 148 | image = to_tensor(image) 149 | image = image.permute(1, 2, 0) 150 | 151 | noise = self.feature_noise * torch.randn(image.size(), generator=self.generator) 152 | image = torch.clamp(image + noise, 0, 1) 153 | 154 | x[i] = image 155 | 156 | return TensorDataset(x, y_tilde, y, z_flattened) 157 | -------------------------------------------------------------------------------- /tta/datasets/cxr/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import softmax 3 | import torch 4 | from torch.utils.data import TensorDataset 5 | 6 | from tta.datasets import MultipleDomainDataset 7 | 8 | 9 | class MultipleDomainCXR(MultipleDomainDataset): 10 | 11 | def build(self, generator, datastore, labels, Y_col, Z_col, patient_col, target_domain_count, source_domain_count): 12 | # Pathology: 0 = Negative, 1 = Positive 13 | # 
GENDER: 0 = Female, 1 = Male 14 | labels["M"] = 2 * labels[Y_col] + labels[Z_col] 15 | print(f"histogram({Y_col}, {Z_col}) =", labels["M"].value_counts().sort_index().values) 16 | 17 | marginal_Y = labels[Y_col].value_counts(normalize=True).sort_index().values 18 | marginal_Z = labels[Z_col].value_counts(normalize=True).sort_index().values 19 | print("marginal_Y", marginal_Y) 20 | print("marginal_Z", marginal_Z) 21 | 22 | # joint distribution of Y and Z 23 | p_11_min = max(0, marginal_Y[1] + marginal_Z[1] - 1) 24 | p_11_max = min(marginal_Y[1], marginal_Z[1]) 25 | anchor1 = np.array([ 26 | [1 - marginal_Y[1] - marginal_Z[1] + p_11_min, marginal_Z[1]-p_11_min ], 27 | [marginal_Y[1] - p_11_min, p_11_min ] 28 | ]) 29 | anchor2 = np.array([ 30 | [1 - marginal_Y[1] - marginal_Z[1] + p_11_max, marginal_Z[1]-p_11_max ], 31 | [marginal_Y[1] - p_11_max, p_11_max ] 32 | ]) 33 | print("anchor1", anchor1) 34 | print("anchor2", anchor2) 35 | 36 | mask = np.ones(len(labels.index), dtype=bool) 37 | domains = [None for _ in self.confounder_strength] 38 | 39 | # Sample source domains 40 | for i, strength in enumerate(self.confounder_strength): 41 | if i != self.train_domain: 42 | continue 43 | 44 | quota = labels["M"].loc[mask].value_counts().sort_index().values - target_domain_count 45 | quota = torch.from_numpy(quota) 46 | joint_M = torch.from_numpy(strength * anchor1 + (1-strength) * anchor2) 47 | 48 | source_domain_count_max = torch.floor(torch.min(quota/joint_M.flatten())).item() 49 | if source_domain_count is None: 50 | source_domain_count = source_domain_count_max 51 | elif source_domain_count > source_domain_count_max: 52 | raise ValueError(f"Insufficient samples for the source domain: {source_domain_count} > {source_domain_count_max}") 53 | 54 | count = torch.round(source_domain_count * joint_M).long() 55 | count = self.fix_count(count, source_domain_count) 56 | count_flatten = torch.flatten(count) 57 | assert torch.all(count_flatten <= quota), f"Insufficient samples for 
the source domain: {count_flatten} > {quota}" 58 | 59 | joint_M = count / torch.sum(count) 60 | 61 | print(f"histogram(M) = {count.flatten()}") 62 | reservation = np.ceil(target_domain_count * np.maximum(anchor1, anchor2).flatten()) 63 | domain, in_sample_patients = self.sample(generator, datastore, labels, Y_col, Z_col, patient_col, mask, count, reservation) 64 | mask &= ~labels[patient_col].isin(in_sample_patients) 65 | domains[i] = (domain, joint_M) 66 | 67 | remainder = np.sum(mask) 68 | if remainder < target_domain_count: 69 | raise ValueError(f"Not enough data for target domains: {remainder} < {target_domain_count}") 70 | 71 | # Sample target domains 72 | for i, strength in enumerate(self.confounder_strength): 73 | if i == self.train_domain: 74 | continue 75 | 76 | joint_M = torch.from_numpy(strength * anchor1 + (1-strength) * anchor2) 77 | count = torch.round(target_domain_count * joint_M).long() 78 | count = self.fix_count(count, target_domain_count) 79 | joint_M = count / torch.sum(count) 80 | 81 | print(f"histogram(M) = {count.flatten()}") 82 | domain, _ = self.sample(generator, datastore, labels, Y_col, Z_col, patient_col, mask, count, None) 83 | domains[i] = (domain, joint_M) 84 | 85 | return domains 86 | 87 | 88 | def fix_count(self, count, domain_count): 89 | count = torch.flatten(count) 90 | 91 | l1, l2, l3 = torch.topk(count, 3).indices 92 | if torch.sum(count) > domain_count: 93 | count[l1] -= 1 94 | if torch.sum(count) > domain_count: 95 | count[l2] -= 1 96 | if torch.sum(count) > domain_count: 97 | count[l3] -= 1 98 | 99 | s1, s2, s3 = torch.topk(count, 3, largest=False).indices 100 | if torch.sum(count) < domain_count: 101 | count[s1] += 1 102 | if torch.sum(count) < domain_count: 103 | count[s2] += 1 104 | if torch.sum(count) < domain_count: 105 | count[s3] += 1 106 | 107 | total_count = torch.sum(count) 108 | if total_count != domain_count: 109 | raise ValueError(f"Incorrect total count: {total_count} != {domain_count}") 110 | 111 | count = 
count.reshape((2, 2)) 112 | return count 113 | 114 | 115 | def sample(self, generator, datastore, labels, Y_col, Z_col, patient_col, mask, count, reservation): 116 | random_state = 0 117 | while True: 118 | in_sample = set() 119 | for Y in range(2): 120 | for Z in range(2): 121 | masked = labels.loc[mask & (labels["M"] == 2 * Y + Z)] 122 | image_per_patient = masked.groupby(patient_col).size() 123 | weights = image_per_patient.loc[masked[patient_col]].values 124 | indices = masked.sample(int(count[Y, Z]), weights=weights, random_state=random_state) 125 | in_sample.update(indices.index) 126 | 127 | class_name = self.__class__.__name__ 128 | if class_name == "MultipleDomainCheXpert": 129 | in_sample_patients = { fname.split("/")[2] for fname in in_sample } 130 | elif class_name == "MultipleDomainMIMIC": 131 | subject_id = labels["subject_id"] 132 | in_sample_patients = { subject_id.at[dicom_id].item() for dicom_id in in_sample } 133 | else: 134 | raise NotImplementedError(f"Unknown dataset {class_name}") 135 | 136 | remainder = np.bincount(labels["M"], weights=mask & ~labels[patient_col].isin(in_sample_patients)) 137 | if reservation is None or np.all(remainder >= reservation): 138 | print(f" remainder = {remainder} >= {reservation} = target_domain_count") 139 | break 140 | 141 | random_state += 1 142 | print(f" remainder = {remainder} < {reservation} = target_domain_count") 143 | 144 | N = int(torch.sum(count)) 145 | assert len(in_sample) == N, f"Incorrect number of elements: {len(in_sample)} != {N}" 146 | 147 | x = torch.empty((N, *self.input_shape[1:])) 148 | y_tilde = torch.empty(N, dtype=torch.long) 149 | y = torch.empty(N, dtype=torch.long) 150 | z_flattened = torch.empty(N, dtype=torch.long) 151 | 152 | perm = torch.randperm(N, generator=generator) 153 | for i, key in enumerate(in_sample): 154 | x[perm[i]] = torch.Tensor(datastore[key]) 155 | row = labels.loc[key] 156 | y[perm[i]] = y_tilde[perm[i]] = row[Y_col] 157 | z_flattened[perm[i]] = row[Z_col] 158 | 
159 | return TensorDataset(x, y_tilde, y, z_flattened), in_sample_patients 160 | -------------------------------------------------------------------------------- /scripts/tree.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from itertools import count 3 | 4 | import click 5 | import numpy as np 6 | import torch 7 | from sklearn.ensemble import HistGradientBoostingClassifier 8 | from sklearn.metrics import roc_auc_score 9 | 10 | from tta.datasets.mnist import MultipleDomainMNIST 11 | from tta.datasets.cxr.chexpert import MultipleDomainCheXpert 12 | from tta.datasets import split 13 | from tta.visualize import latexify, plot 14 | 15 | 16 | @click.command() 17 | @click.option("--seed", type=int, required=True) 18 | def main(seed: int): 19 | jobs = [] 20 | 21 | for train_domain in (1,): 22 | generator = torch.Generator().manual_seed(seed) 23 | 24 | train_domains_set = {train_domain} 25 | dataset_apply_rotation = False 26 | dataset_feature_noise = 0 27 | dataset_label_noise = 0 28 | 29 | root = Path("data/mnist") 30 | dataset = MultipleDomainMNIST( 31 | root, 32 | train_domains_set, 33 | generator, 34 | dataset_apply_rotation, 35 | dataset_feature_noise, 36 | dataset_label_noise, 37 | ) 38 | 39 | prior_strength = 1 40 | config_name = f"tree_mnist_rot{dataset_apply_rotation}_noise{dataset_label_noise}_domain{train_domain}_prior{prior_strength}_seed{seed}" 41 | jobs.append((dataset, train_domains_set, dataset_label_noise, prior_strength, config_name)) 42 | 43 | for train_domain in (1,): 44 | generator = torch.Generator().manual_seed(seed) 45 | 46 | train_domains_set = {train_domain} 47 | dataset_y_column = "EFFUSION" 48 | dataset_z_column = "GENDER" 49 | dataset_target_domain_count = 512 50 | dataset_source_domain_count = 65536 51 | dataset_use_embedding = True 52 | dataset_label_noise = 0 53 | 54 | root = Path("data/CheXpert") 55 | dataset = MultipleDomainCheXpert( 56 | root, 57 | train_domains_set, 58 
| generator, 59 | dataset_y_column, 60 | dataset_z_column, 61 | dataset_use_embedding, 62 | dataset_target_domain_count, 63 | dataset_source_domain_count, 64 | ) 65 | 66 | prior_strength = 1 67 | config_name = f"tree_chexpert-embedding_{dataset_y_column}_{dataset_z_column}_domain{train_domain}_size{dataset_source_domain_count}_prior{prior_strength}_seed{seed}" 68 | jobs.append((dataset, train_domains_set, dataset_label_noise, prior_strength, config_name)) 69 | 70 | for dataset, train_domains_set, dataset_label_noise, prior_strength, config_name in jobs: 71 | auc_sweeps = make_auc_sweeps(dataset, train_domains_set, prior_strength) 72 | 73 | npz_path = Path(f"npz/{config_name}.npz") 74 | all_sweeps = { 75 | "auc": (auc_sweeps, "AUC"), 76 | } 77 | np.savez(npz_path, **all_sweeps) 78 | 79 | plot_root = Path("plots/") 80 | y_lim = (0.6, 1) 81 | plot(npz_path, dataset.confounder_strength, train_domains_set, dataset_label_noise, "", plot_root, config_name, y_lim) 82 | 83 | 84 | def make_auc_sweeps(dataset, train_domains_set, prior_strength): 85 | train_fraction = 0.9 86 | train_calibration_fraction = 0.1 87 | calibration_domains_set = set() 88 | calibration_fraction = 0.0 89 | 90 | (train, _), (calibration, _), test_splits = split( 91 | dataset, 92 | train_domains_set, 93 | train_fraction, 94 | train_calibration_fraction, 95 | calibration_domains_set, 96 | calibration_fraction, 97 | ) 98 | 99 | # Training 100 | X, Y, _, Z = dataset2np(train) 101 | M = Y * 2 + Z 102 | clf = HistGradientBoostingClassifier(random_state=0) 103 | clf = clf.fit(X, M) 104 | 105 | induced_prob = clf.predict_proba(X) 106 | source = np.mean(induced_prob, axis=0) 107 | 108 | # Calibration 109 | X, Y, _, Z = dataset2np(calibration) 110 | M = Y * 2 + Z 111 | # TODO: do calibration with gradient descent or whatever 112 | 113 | # Testing 114 | auc_erm = [] 115 | auc_oracle = [] 116 | auc_gmtl_05 = [] 117 | auc_gmtl_10 = [] 118 | auc_gmtl_20 = [] 119 | auc_em = [] 120 | 121 | print("AUC") 122 | 
auc_sweeps = {} 123 | 124 | alpha = prior_strength * 4 * source 125 | for i, (test, target_oracle) in enumerate(test_splits): 126 | target_oracle = target_oracle.numpy().flatten() 127 | X, Y, _, Z = dataset2np(test) 128 | M = Y * 2 + Z 129 | prob = clf.predict_proba(X) 130 | 131 | prob_oracle = target_oracle * prob / source 132 | prob_oracle /= np.sum(prob_oracle, axis=-1, keepdims=True) 133 | 134 | prob_gmtl_05 = source**(1-0.5) * prob / source 135 | prob_gmtl_05 /= np.sum(prob_gmtl_05, axis=-1, keepdims=True) 136 | prob_gmtl_10 = source**(1-1.0) * prob / source 137 | prob_gmtl_10 /= np.sum(prob_gmtl_10, axis=-1, keepdims=True) 138 | prob_gmtl_20 = source**(1-2.0) * prob / source 139 | prob_gmtl_20 /= np.sum(prob_gmtl_20, axis=-1, keepdims=True) 140 | 141 | target = np.copy(source) 142 | for j in count(): 143 | old = target 144 | 145 | # E step 146 | prob_em = target * prob / source 147 | normalizer = np.sum(prob_em, axis=-1, keepdims=True) 148 | prob_em = prob_em / normalizer 149 | 150 | # M step 151 | prob_em_count = np.sum(prob_em, axis=0) + (alpha - 1) 152 | target = prob_em_count / np.sum(prob_em_count) 153 | 154 | if np.allclose(target, old) or j > 10000: 155 | break 156 | 157 | erm = evaluate(prob, Y) 158 | oracle = evaluate(prob_oracle, Y) 159 | gmtl_05 = evaluate(prob_gmtl_05, Y) 160 | gmtl_10 = evaluate(prob_gmtl_10, Y) 161 | gmtl_20 = evaluate(prob_gmtl_20, Y) 162 | em = evaluate(prob_em, Y) 163 | 164 | auc_erm.append(erm) 165 | auc_oracle.append(oracle) 166 | auc_gmtl_05.append(gmtl_05) 167 | auc_gmtl_10.append(gmtl_10) 168 | auc_gmtl_20.append(gmtl_20) 169 | auc_em.append(em) 170 | print("*" if i in train_domains_set else " ", f"domain #{i:<2}, {erm = :.4f}", f"{em = :.4f}", f"{oracle = :.4f}") 171 | 172 | # Dummy 173 | auc_erm.append(None) 174 | auc_oracle.append(None) 175 | auc_gmtl_05.append(None) 176 | auc_gmtl_10.append(None) 177 | auc_gmtl_20.append(None) 178 | auc_em.append(None) 179 | 180 | batch_size = len(test_splits[0][0]) 181 | 
auc_sweeps[("Null",), False, batch_size] = auc_erm 182 | auc_sweeps[("Oracle",), False, batch_size] = auc_oracle 183 | auc_sweeps[("GMTL", 0.5), False, batch_size] = auc_gmtl_05 184 | auc_sweeps[("GMTL", 1.0), False, batch_size] = auc_gmtl_10 185 | auc_sweeps[("GMTL", 2.0), False, batch_size] = auc_gmtl_20 186 | auc_sweeps[("EM", prior_strength, False, False), False, batch_size] = auc_em 187 | 188 | return auc_sweeps 189 | 190 | 191 | def dataset2np(dataset): 192 | X, Y, Y_tilde, Z = [], [], [], [] 193 | for x, y_tilde, y, z_flattened in dataset: 194 | X.append(x) 195 | Y.append(y) 196 | Y_tilde.append(y_tilde) 197 | Z.append(z_flattened) 198 | 199 | X = np.stack(X) 200 | X = X.reshape(-1, np.prod(X.shape[1:])) 201 | Y = np.stack(Y) 202 | Y_tilde = np.stack(Y_tilde) 203 | Z = np.stack(Z) 204 | 205 | return X, Y, Y_tilde, Z 206 | 207 | 208 | def evaluate(prob_M, Y): 209 | prob_M = prob_M.reshape((-1, 2, 2)) 210 | prob_Y = np.sum(prob_M, axis=-1) 211 | score_Y = prob_Y[:, 1] # assumes binary label 212 | auc = roc_auc_score(Y, score_Y) 213 | return auc 214 | 215 | 216 | if __name__ == "__main__": 217 | latexify(width_scale_factor=2, fig_height=2) 218 | main() 219 | -------------------------------------------------------------------------------- /tta/restore.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Mapping, Any, Tuple 2 | from pathlib import Path 3 | from collections.abc import MutableMapping 4 | import logging 5 | 6 | import flax 7 | from flax.training.checkpoints import restore_checkpoint, convert_pre_linen 8 | 9 | from tta.train import TrainState 10 | 11 | 12 | # JAX team is working on type annotation for pytree: 13 | # https://github.com/google/jax/issues/1555 14 | PyTree = Union[Mapping[str, Mapping], Any] 15 | 16 | 17 | def restore_train_state(state: TrainState, checkpoint_path: Path) -> TrainState: 18 | restored_params, restored_batch_stats = load_pretrained_checkpoint( 19 | 
state.params["net"], state.batch_stats["net"], checkpoint_path 20 | ) 21 | 22 | model_params = _replace_dict(state.params["net"], restored_params) 23 | params = flax.core.unfreeze(state.params) 24 | params["net"] = model_params 25 | state = state.replace(params=flax.core.freeze(params)) 26 | 27 | model_batch_stats = _replace_dict(state.batch_stats["net"], restored_batch_stats) 28 | batch_stats = flax.core.unfreeze(state.batch_stats) 29 | batch_stats["net"] = model_batch_stats 30 | state = state.replace(batch_stats=flax.core.freeze(batch_stats)) 31 | 32 | return state 33 | 34 | 35 | def load_pretrained_checkpoint( 36 | params: PyTree, batch_stats: PyTree, checkpoint_path: Path 37 | ) -> Tuple[PyTree, PyTree]: 38 | restored_train_state = restore_checkpoint(checkpoint_path, None) 39 | if restored_train_state is None: 40 | raise ValueError( 41 | f"No checkpoint for the pretrained model is found in: {checkpoint_path}" 42 | ) 43 | 44 | if "params" in restored_train_state: 45 | # restored_train_state was trained using optax 46 | restored_params = restored_train_state["params"] 47 | else: 48 | # restored_train_state was trained using flax.optim. Note that this does 49 | # not convert the naming of pre-Linen checkpoints. 50 | restored_params = restored_train_state["optimizer"]["target"] 51 | if "params" in restored_params: # Backward compatibility. 52 | restored_params = restored_params["params"] 53 | restored_params = dict(convert_pre_linen(restored_params)) 54 | 55 | del restored_params["output_projection"] # Remove classification head 56 | del restored_params["pre_logits"] # Not sure why it's there 57 | restored_params = flax.core.freeze(restored_params) 58 | 59 | # Inspect and compare the parameters of the model with the init-model. 
60 | restored_params = inspect_params( 61 | expected_params=params, 62 | restored_params=restored_params, 63 | fail_if_extra=True, 64 | fail_if_missing=False, 65 | fail_if_shapes_mismatch=True, 66 | ) 67 | 68 | restored_batch_stats = restored_train_state["model_state"] 69 | restored_batch_stats = {k[1:]: v for k, v in restored_batch_stats.items()} 70 | restored_batch_stats = flax.traverse_util.flatten_dict( 71 | restored_batch_stats, sep="/" 72 | ) 73 | restored_batch_stats = flax.traverse_util.unflatten_dict( 74 | restored_batch_stats, sep="/" 75 | ) 76 | restored_batch_stats = flax.core.freeze(restored_batch_stats) 77 | 78 | restored_batch_stats = inspect_params( 79 | expected_params=batch_stats, 80 | restored_params=restored_batch_stats, 81 | fail_if_extra=True, 82 | fail_if_missing=True, 83 | fail_if_shapes_mismatch=True, 84 | ) 85 | 86 | return restored_params, restored_batch_stats 87 | 88 | 89 | def _replace_dict(model: PyTree, restored: PyTree) -> PyTree: 90 | """Replaces values in model dictionary with restored ones from checkpoint.""" 91 | model = flax.core.unfreeze(model) # pytype: disable=wrong-arg-types 92 | restored = flax.core.unfreeze(restored) # pytype: disable=wrong-arg-types 93 | 94 | # Flatten nested parameters to a dict of str -> tensor. Keys are tuples 95 | # from the path in the nested dictionary to the specific tensor. E.g., 96 | # {'a1': {'b1': t1, 'b2': t2}, 'a2': t3} 97 | # -> {('a1', 'b1'): t1, ('a1', 'b2'): t2, ('a2',): t3}. 
98 | restored_flat = flax.traverse_util.flatten_dict( 99 | dict(restored), keep_empty_nodes=True 100 | ) 101 | model_flat = flax.traverse_util.flatten_dict(dict(model), keep_empty_nodes=True) 102 | 103 | for m_key, m_params in restored_flat.items(): 104 | # pytype: enable=attribute-error 105 | m_key_str = "/".join(m_key) 106 | if m_key not in model_flat: 107 | raise ValueError("%s in checkpoint doesn't exist in model.", m_key_str) 108 | logging.info("Loading %s from checkpoint into model", m_key_str) 109 | model_flat[m_key] = m_params 110 | 111 | return flax.core.freeze(flax.traverse_util.unflatten_dict(model_flat)) 112 | 113 | 114 | def _flatten_params(d, parent_key="", sep="/"): 115 | """Flattens a dictionary, keeping empty leaves.""" 116 | items = [] 117 | for k, v in d.items(): 118 | path = parent_key + sep + k if parent_key else k 119 | if isinstance(v, MutableMapping): 120 | items.extend(_flatten_params(v, path, sep=sep).items()) 121 | else: 122 | items.append((path, v)) 123 | # Keeps the empty dict if it was set explicitly. 
124 | if parent_key and not d: 125 | items.append((parent_key, {})) 126 | return dict(items) 127 | 128 | 129 | def inspect_params( 130 | *, 131 | expected_params: PyTree, 132 | restored_params: PyTree, 133 | fail_if_extra: bool = True, 134 | fail_if_missing: bool = True, 135 | fail_if_shapes_mismatch: bool = False, 136 | ) -> PyTree: 137 | """Inspects whether the params are consistent with the expected keys.""" 138 | 139 | expected_flat = _flatten_params(flax.core.unfreeze(expected_params)) 140 | restored_flat = _flatten_params(flax.core.unfreeze(restored_params)) 141 | missing_keys = expected_flat.keys() - restored_flat.keys() 142 | extra_keys = restored_flat.keys() - expected_flat.keys() 143 | 144 | is_shape_mismatch = False 145 | for key in restored_flat: 146 | if key in expected_flat: 147 | restored_shape = None 148 | expected_shape = None 149 | # Handle empty nodes (without trainable params) 150 | if not isinstance(restored_flat[key], dict): 151 | restored_shape = restored_flat[key].shape 152 | if not isinstance(expected_flat[key], dict): 153 | expected_shape = expected_flat[key].shape 154 | 155 | if restored_shape != expected_shape: 156 | is_shape_mismatch = True 157 | logging.warning( 158 | "Key: %s. Expected shape: %s. Restored shape: %s", 159 | key, 160 | expected_flat[key].shape, 161 | restored_flat[key].shape, 162 | ) 163 | 164 | # Adds back empty dict explicitly, to support layers without weights. 165 | # Context: FLAX ignores empty dict during serialization. 
166 | empty_keys = set() 167 | for k in missing_keys: 168 | if isinstance(expected_flat[k], dict) and not expected_flat[k]: 169 | restored_params[k] = {} # pytype: disable=unsupported-operands 170 | empty_keys.add(k) 171 | missing_keys -= empty_keys 172 | 173 | if empty_keys: 174 | logging.warning("Inspect recovered empty keys:\n%s", empty_keys) 175 | 176 | logging.info("Inspect missing keys:\n%s", missing_keys) 177 | logging.info("Inspect extra keys:\n%s", extra_keys) 178 | 179 | if fail_if_shapes_mismatch and is_shape_mismatch: 180 | raise ValueError("Shape mismatch between restored and target model") 181 | 182 | if (missing_keys and fail_if_missing) or (extra_keys and fail_if_extra): 183 | raise ValueError( 184 | f"Missing params from checkpoint: {missing_keys}.\n" 185 | f"Extra params in checkpoint: {extra_keys}." 186 | ) 187 | return restored_params 188 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: paper paper-chexpert paper-mnist paper-chexpert-embedding paper-chexpert-pixel baseline manova tree merge 2 | 3 | 4 | paper: paper-mnist paper-chexpert 5 | 6 | 7 | paper-chexpert: paper-chexpert-embedding paper-chexpert-pixel 8 | 9 | 10 | paper-mnist: 11 | for seed in $$(seq 2022 2025); do \ 12 | for rot in False; do \ 13 | for noise in 0; do \ 14 | for domain in 1; do \ 15 | for sub in none groups; do \ 16 | for tau in 0 1; do \ 17 | for train in 5000; do \ 18 | for cali in 0 1000; do \ 19 | for prior in 1; do \ 20 | pipenv run python3 \ 21 | -m tta.cli \ 22 | --config_name mnist_rot$${rot}_noise$${noise}_domain$${domain}_sub$${sub}_tau$${tau}_train$${train}_cali$${cali}_prior$${prior}_seed$${seed} \ 23 | --dataset_name MNIST \ 24 | --dataset_apply_rotation $${rot} \ 25 | --dataset_subsample_what $${sub} \ 26 | --dataset_feature_noise $${noise} \ 27 | --dataset_label_noise 0 \ 28 | --train_fit_joint True \ 29 | --train_model 
LeNet \ 30 | --train_domains $${domain} \ 31 | --train_fraction 0.9 \ 32 | --train_calibration_fraction 0.1 \ 33 | --train_batch_size 64 \ 34 | --train_epochs $${train} \ 35 | --train_decay 0.1 \ 36 | --train_patience 5 \ 37 | --train_tau $${tau} \ 38 | --train_lr 1e-3 \ 39 | --calibration_batch_size 64 \ 40 | --calibration_epochs $${cali} \ 41 | --calibration_decay 0.1 \ 42 | --calibration_patience 5 \ 43 | --calibration_tau $${tau} \ 44 | --calibration_lr 1e-3 \ 45 | --adapt_gmtl_alpha 1 \ 46 | --adapt_prior_strength $${prior} \ 47 | --adapt_symmetric_dirichlet False \ 48 | --adapt_fix_marginal False \ 49 | --test_argmax_joint False \ 50 | --test_batch_size 64 \ 51 | --test_batch_size 512 \ 52 | --seed $${seed} \ 53 | --num_workers 48 \ 54 | --plot_title "" \ 55 | --plot_only False; \ 56 | done \ 57 | done \ 58 | done \ 59 | done \ 60 | done \ 61 | done \ 62 | done \ 63 | done \ 64 | done 65 | 66 | 67 | paper-chexpert-embedding: 68 | for seed in $$(seq 2022 2025); do \ 69 | for Y_column in EFFUSION; do \ 70 | for Z_column in GENDER; do \ 71 | for domain in 1; do \ 72 | for size in 65536; do \ 73 | for sub in none groups; do \ 74 | for tau in 0 1; do \ 75 | for train in 5000; do \ 76 | for cali in 0 1000; do \ 77 | for prior in 1; do \ 78 | pipenv run python3 \ 79 | -m tta.cli \ 80 | --config_name chexpert-embedding_$${Y_column}_$${Z_column}_domain$${domain}_size$${size}_sub$${sub}_tau$${tau}_train$${train}_cali$${cali}_prior$${prior}_seed$${seed} \ 81 | --dataset_name CheXpert \ 82 | --dataset_Y_column $${Y_column} \ 83 | --dataset_Z_column $${Z_column} \ 84 | --dataset_target_domain_count 512 \ 85 | --dataset_source_domain_count $${size} \ 86 | --dataset_subsample_what $${sub} \ 87 | --dataset_use_embedding True \ 88 | --dataset_feature_noise 0 \ 89 | --dataset_label_noise 0 \ 90 | --train_fit_joint True \ 91 | --train_model Linear \ 92 | --train_domains $${domain} \ 93 | --train_fraction 0.9 \ 94 | --train_calibration_fraction 0.1 \ 95 | --train_batch_size 64 \ 
96 | --train_epochs $${train} \ 97 | --train_decay 0.1 \ 98 | --train_patience 5 \ 99 | --train_tau $${tau} \ 100 | --train_lr 1e-3 \ 101 | --calibration_batch_size 64 \ 102 | --calibration_epochs $${cali} \ 103 | --calibration_decay 0.1 \ 104 | --calibration_patience 5 \ 105 | --calibration_tau $${tau} \ 106 | --calibration_lr 1e-3 \ 107 | --adapt_gmtl_alpha 1 \ 108 | --adapt_prior_strength $${prior} \ 109 | --adapt_symmetric_dirichlet False \ 110 | --adapt_fix_marginal False \ 111 | --test_argmax_joint False \ 112 | --test_batch_size 64 \ 113 | --test_batch_size 512 \ 114 | --seed $${seed} \ 115 | --num_workers 48 \ 116 | --plot_title "" \ 117 | --plot_only False; \ 118 | done \ 119 | done \ 120 | done \ 121 | done \ 122 | done \ 123 | done \ 124 | done \ 125 | done \ 126 | done \ 127 | done 128 | 129 | 130 | paper-chexpert-pixel: 131 | for seed in $$(seq 2022 2025); do \ 132 | for Y_column in EFFUSION; do \ 133 | for Z_column in GENDER; do \ 134 | for domain in 1; do \ 135 | for size in 65536; do \ 136 | for sub in none groups; do \ 137 | for tau in 0 1; do \ 138 | for train in 5000; do \ 139 | for cali in 0 1000; do \ 140 | for prior in 1; do \ 141 | pipenv run python3 \ 142 | -m tta.cli \ 143 | --config_name chexpert-pixel_$${Y_column}_$${Z_column}_domain$${domain}_size$${size}_sub$${sub}_tau$${tau}_train$${train}_cali$${cali}_prior$${prior}_seed$${seed} \ 144 | --dataset_name CheXpert \ 145 | --dataset_Y_column $${Y_column} \ 146 | --dataset_Z_column $${Z_column} \ 147 | --dataset_target_domain_count 512 \ 148 | --dataset_source_domain_count $${size} \ 149 | --dataset_subsample_what $${sub} \ 150 | --dataset_use_embedding False \ 151 | --dataset_feature_noise 0 \ 152 | --dataset_label_noise 0 \ 153 | --train_fit_joint True \ 154 | --train_model ResNet50 \ 155 | --train_pretrained_path pretrained/ResNet50_ImageNet1k \ 156 | --train_domains $${domain} \ 157 | --train_fraction 0.9 \ 158 | --train_calibration_fraction 0.1 \ 159 | --train_batch_size 64 \ 160 | 
--train_epochs $${train} \ 161 | --train_decay 0.1 \ 162 | --train_patience 5 \ 163 | --train_tau $${tau} \ 164 | --train_lr 1e-3 \ 165 | --calibration_batch_size 64 \ 166 | --calibration_epochs $${cali} \ 167 | --calibration_decay 0.1 \ 168 | --calibration_patience 5 \ 169 | --calibration_tau $${tau} \ 170 | --calibration_lr 1e-3 \ 171 | --adapt_gmtl_alpha 1 \ 172 | --adapt_prior_strength $${prior} \ 173 | --adapt_symmetric_dirichlet False \ 174 | --adapt_fix_marginal False \ 175 | --test_argmax_joint False \ 176 | --test_batch_size 64 \ 177 | --test_batch_size 512 \ 178 | --seed $${seed} \ 179 | --num_workers 48 \ 180 | --plot_title "" \ 181 | --plot_only False; \ 182 | done \ 183 | done \ 184 | done \ 185 | done \ 186 | done \ 187 | done \ 188 | done \ 189 | done \ 190 | done \ 191 | done 192 | 193 | 194 | data/CheXpert/data_matrix.npz: 195 | pipenv run python3 -m scripts.matching 196 | 197 | 198 | baseline: data/CheXpert/data_matrix.npz 199 | pipenv run python3 -m scripts.baseline 200 | 201 | 202 | manova: 203 | pipenv run python3 -m scripts.manova 204 | 205 | 206 | tree: 207 | for seed in $$(seq 2022 2025); do \ 208 | env JAX_PLATFORMS="cpu" pipenv run python3 -m scripts.tree --seed $${seed}; \ 209 | done 210 | 211 | 212 | merge: 213 | env JAX_PLATFORMS="cpu" \ 214 | pipenv run python3 \ 215 | -m scripts.merge \ 216 | --npz_pattern "tree_mnist_rotFalse_noise0_domain1_prior1_seed????.npz" \ 217 | --merged_title "" \ 218 | --merged_name "tree_mnist-domain1-noise0" 219 | env JAX_PLATFORMS="cpu" \ 220 | pipenv run python3 \ 221 | -m scripts.merge \ 222 | --npz_pattern "tree_chexpert-embedding_EFFUSION_GENDER_domain1_size65536_prior1_seed????.npz" \ 223 | --merged_title "" \ 224 | --merged_name "tree_chexpert-embedding-domain1" 225 | for noise in 0; do \ 226 | for domain in 1; do \ 227 | for cali in 0 1000; do \ 228 | env JAX_PLATFORMS="cpu" \ 229 | pipenv run python3 \ 230 | -m scripts.merge \ 231 | --npz_pattern 
"mnist_rotFalse_noise$${noise}_domain$${domain}_sub*_tau*_train5000_cali$${cali}_prior1_seed????.npz" \ 232 | --merged_title "" \ 233 | --merged_name "mnist-domain$${domain}-noise$${noise}-cali$${cali}"; \ 234 | done \ 235 | done \ 236 | done 237 | for domain in 1; do \ 238 | for cali in 0 1000; do \ 239 | env JAX_PLATFORMS="cpu" \ 240 | pipenv run python3 \ 241 | -m scripts.merge \ 242 | --npz_pattern "chexpert-embedding_EFFUSION_GENDER_domain$${domain}_size65536_sub*_tau*_train5000_cali$${cali}_prior1_seed????.npz" \ 243 | --merged_title "" \ 244 | --merged_name "chexpert-embedding-domain$${domain}-cali$${cali}"; \ 245 | done \ 246 | done 247 | for domain in 1; do \ 248 | for cali in 0 1000; do \ 249 | env JAX_PLATFORMS="cpu" \ 250 | pipenv run python3 \ 251 | -m scripts.merge \ 252 | --npz_pattern "chexpert-pixel_EFFUSION_GENDER_domain$${domain}_size65536_sub*_tau*_train5000_cali$${cali}_prior1_seed????.npz" \ 253 | --merged_title "" \ 254 | --merged_name "chexpert-pixel-domain$${domain}-cali$${cali}"; \ 255 | done \ 256 | done 257 | 258 | 259 | freeze: 260 | for seed in $$(seq 2023 2023); do \ 261 | pipenv run python3 -m scripts.freeze --seed $${seed}; \ 262 | done 263 | -------------------------------------------------------------------------------- /tta/visualize.py: -------------------------------------------------------------------------------- 1 | from typing import Set, Tuple, Union, Optional 2 | 3 | from pathlib import Path 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def plot( 9 | npz_path: Path, 10 | confounder_strength: np.ndarray, 11 | train_domains_set: Set[int], 12 | dataset_label_noise: float, 13 | plot_title: str, 14 | plot_root: Path, 15 | config_name: str, 16 | y_lim: Optional[Tuple], 17 | ): 18 | print(f"Reading from {npz_path}") 19 | all_sweeps = np.load(npz_path, allow_pickle=True) 20 | 21 | for sweep_type, (sweeps, ylabel) in all_sweeps.items(): 22 | ylabel = ylabel.replace("Average AUC", "AUC") 23 | fig, ax = 
plt.subplots(figsize=(12, 6)) 24 | 25 | if sweep_type == "accuracy": 26 | if dataset_label_noise > 0: 27 | upper_bound = bayes_accuracy(dataset_label_noise, confounder_strength) 28 | ax.plot( 29 | confounder_strength, 30 | upper_bound, 31 | color="grey", 32 | linestyle="dashdot", 33 | label="Upper bound", 34 | ) 35 | 36 | alpha_min, alpha_max = float('inf'), -float('inf') 37 | prior_str_min, prior_str_max = float('inf'), -float('inf') 38 | batch_size_min, batch_size_max = float('inf'), -float('inf') 39 | for ((algo, *param), argmax_joint, batch_size), sweep in sweeps.items(): 40 | if algo == "GMTL": 41 | alpha, = param 42 | alpha_min = min(alpha_min, alpha) 43 | alpha_max = max(alpha_max, alpha) 44 | elif algo == "EM": 45 | prior_str, _, _ = param 46 | prior_str_min = min(prior_str_min, prior_str) 47 | prior_str_max = max(prior_str_max, prior_str) 48 | batch_size_min = min(batch_size_min, batch_size) 49 | batch_size_max = max(batch_size_max, batch_size) 50 | 51 | baseline_curves_labels = [], [] 52 | gmtl_curves_labels = [], [] 53 | em_curves_labels = [], [] 54 | tab20c = plt.get_cmap("tab20c") 55 | for ((algo, *param), argmax_joint, batch_size), sweep in sorted(sweeps.items()): 56 | del argmax_joint 57 | if algo == "Null": 58 | linestyle = "dotted" 59 | marker = "." 60 | markerfacecolor = color = tab20c.colors[19] 61 | scaler = 2 62 | label = "[Unadapted]" 63 | curves, labels = baseline_curves_labels 64 | elif algo == "Null-unconfounded": 65 | linestyle = "dotted" 66 | marker = "." 67 | markerfacecolor = color = tab20c.colors[8] 68 | scaler = 2 69 | label = "[Invariant]" 70 | curves, labels = baseline_curves_labels 71 | elif algo == "Oracle": 72 | linestyle = "dotted" 73 | marker = "." 
74 | markerfacecolor = color = tab20c.colors[16] 75 | scaler = 2 76 | label = "[Oracle]" 77 | curves, labels = baseline_curves_labels 78 | elif algo == "GMTL": 79 | alpha, = param 80 | 81 | color_min = np.array(tab20c.colors[3]) 82 | color_max = np.array(tab20c.colors[0]) 83 | if alpha_max == alpha_min: 84 | markerfacecolor = color = color_max 85 | else: 86 | multiplier = (alpha - alpha_min)/(alpha_max - alpha_min) 87 | markerfacecolor = color = color_min + multiplier * (color_max - color_min) 88 | 89 | linestyle = "dashed" 90 | marker = "^" 91 | scaler = 1 92 | label = f"[GMTL] {alpha = }" 93 | curves, labels = gmtl_curves_labels 94 | elif algo == "EM": 95 | prior_str, symmetric_dirichlet, fix_marginal = param 96 | del symmetric_dirichlet, fix_marginal 97 | 98 | color_min = np.array(tab20c.colors[7]) 99 | color_max = np.array(tab20c.colors[4]) 100 | if batch_size_max == batch_size_min: 101 | color = color_max 102 | else: 103 | multiplier = (np.log(batch_size) - np.log(batch_size_min))/(np.log(batch_size_max) - np.log(batch_size_min)) 104 | color = color_min + multiplier * (color_max - color_min) 105 | 106 | markerfacecolor_min = np.array(tab20c.colors[4]) 107 | markerfacecolor_max = np.array(tab20c.colors[7]) 108 | if prior_str_max == prior_str_min: 109 | markerfacecolor = color 110 | else: 111 | multiplier = (prior_str - prior_str_min)/(prior_str_max - prior_str_min) 112 | markerfacecolor = markerfacecolor_min + multiplier * (markerfacecolor_max - markerfacecolor_min) 113 | 114 | linestyle = "solid" 115 | marker = "s" 116 | scaler = 1 117 | label = f"[TTLSA] N = {batch_size}" 118 | curves, labels = em_curves_labels 119 | else: 120 | raise ValueError(f"Unknown adaptation algorithm {algo}") 121 | 122 | curve, = ax.plot(confounder_strength, sweep[:-1], linestyle=linestyle, marker=marker, 123 | color=color, markerfacecolor=markerfacecolor, linewidth=2*scaler, markersize=8*scaler) 124 | curves.append(curve) 125 | labels.append(label) 126 | 127 | for i in 
train_domains_set: 128 | ax.axvline(confounder_strength[i], color="black", 129 | linestyle="dotted", linewidth=3) 130 | 131 | # plt.ylim((0, 1)) 132 | if sweep_type in {"mean", "l1", "norm"}: 133 | plt.ylim((0, 1)) 134 | elif y_lim is not None: 135 | plt.ylim(y_lim) 136 | else: 137 | plt.ylim((0.5, 1)) 138 | 139 | plt.xlabel("Shift parameter") 140 | plt.ylabel(ylabel) 141 | plt.title(plot_title) 142 | plt.grid(True) 143 | legend1 = plt.legend(*baseline_curves_labels, loc="upper left", bbox_to_anchor=(0, -0.15), ncol=1, frameon=False) 144 | legend2 = plt.legend(*gmtl_curves_labels, loc="upper left", bbox_to_anchor=(1/3, -0.15), ncol=1, frameon=False) 145 | plt.legend(*em_curves_labels, loc="upper left", bbox_to_anchor=(2/3, -0.15), ncol=1, frameon=False) 146 | plt.gca().add_artist(legend1) 147 | plt.gca().add_artist(legend2) 148 | fig.tight_layout() 149 | 150 | format_axes(ax) 151 | for suffix in ("png", "pdf"): 152 | plt.savefig(plot_root / f"{config_name}_{sweep_type}.{suffix}", bbox_inches='tight', dpi=300) 153 | 154 | plt.close(fig) 155 | 156 | 157 | def bayes_accuracy( 158 | dataset_label_noise: float, confounder_strength: Union[float, np.ndarray] 159 | ) -> np.ndarray: 160 | upper_bound = np.maximum( 161 | np.maximum(1 - confounder_strength, confounder_strength), 162 | (1 - dataset_label_noise) * np.ones_like(confounder_strength), 163 | ) 164 | return upper_bound 165 | 166 | 167 | DEFAULT_WIDTH = 6.0 168 | DEFAULT_HEIGHT = 1.5 169 | 170 | # Font sizes 171 | SIZE_SMALL = 10 172 | SIZE_MEDIUM = 12 173 | SIZE_LARGE = 16 174 | 175 | SPINE_COLOR = 'gray' 176 | 177 | 178 | def latexify( 179 | width_scale_factor=1, 180 | height_scale_factor=1, 181 | fig_width=None, 182 | fig_height=None, 183 | ): 184 | f""" 185 | width_scale_factor: float, DEFAULT_WIDTH will be divided by this number, DEFAULT_WIDTH is page width: {DEFAULT_WIDTH} inches. 186 | height_scale_factor: float, DEFAULT_HEIGHT will be divided by this number, DEFAULT_HEIGHT is {DEFAULT_HEIGHT} inches. 
187 | fig_width: float, width of the figure in inches (if this is specified, width_scale_factor is ignored) 188 | fig_height: float, height of the figure in inches (if this is specified, height_scale_factor is ignored) 189 | """ 190 | if fig_width is None: 191 | fig_width = DEFAULT_WIDTH / width_scale_factor 192 | if fig_height is None: 193 | fig_height = DEFAULT_HEIGHT / height_scale_factor 194 | 195 | # use TrueType fonts so they are embedded 196 | # https://stackoverflow.com/questions/9054884/how-to-embed-fonts-in-pdfs-produced-by-matplotlib 197 | # https://jdhao.github.io/2018/01/18/mpl-plotting-notes-201801/ 198 | plt.rcParams["pdf.fonttype"] = 42 199 | 200 | # https://stackoverflow.com/a/39566040 201 | plt.rc("font", size=SIZE_MEDIUM) # controls default text sizes 202 | plt.rc("axes", titlesize=SIZE_LARGE) # fontsize of the axes title 203 | plt.rc("axes", labelsize=SIZE_LARGE) # fontsize of the x and y labels 204 | plt.rc("xtick", labelsize=SIZE_LARGE) # fontsize of the tick labels 205 | plt.rc("ytick", labelsize=SIZE_LARGE) # fontsize of the tick labels 206 | plt.rc("legend", fontsize=SIZE_LARGE) # legend fontsize 207 | plt.rc("figure", titlesize=SIZE_LARGE) # fontsize of the figure title 208 | 209 | # latexify: https://nipunbatra.github.io/blog/posts/2014-06-02-latexify.html 210 | plt.rcParams["backend"] = "ps" 211 | plt.rc("text", usetex=True) 212 | plt.rc("font", family="serif") 213 | plt.rc("figure", figsize=(fig_width, fig_height)) 214 | 215 | 216 | def format_axes(ax): 217 | for spine in ['top', 'right']: 218 | ax.spines[spine].set_visible(False) 219 | 220 | for spine in ['left', 'bottom']: 221 | ax.spines[spine].set_color(SPINE_COLOR) 222 | ax.spines[spine].set_linewidth(0.5) 223 | 224 | ax.xaxis.set_ticks_position('bottom') 225 | ax.yaxis.set_ticks_position('left') 226 | 227 | for axis in [ax.xaxis, ax.yaxis]: 228 | axis.set_tick_params(direction='out', color=SPINE_COLOR) 229 | 230 | return ax 231 | 
-------------------------------------------------------------------------------- /tta/train.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Tuple 2 | from functools import partial 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import flax 7 | from flax.training import train_state 8 | from flax.struct import field 9 | import optax 10 | 11 | from tta.models import AdaptiveNN 12 | 13 | 14 | class TrainState(train_state.TrainState): 15 | raw_fn: Callable = field(pytree_node=False) 16 | calibrated_fn: Callable = field(pytree_node=False) 17 | batch_stats: flax.core.FrozenDict[str, jnp.ndarray] 18 | prior: flax.core.FrozenDict[str, jnp.ndarray] 19 | 20 | 21 | def create_train_state(key: Any, C: int, K: int, model: str, 22 | learning_rate: float, specimen: jnp.ndarray, device_count: int) -> TrainState: 23 | net = AdaptiveNN(C=C, K=K, model=model) 24 | 25 | variables = net.init(key, specimen, True, method=net.adapted_prob) 26 | variables, params = variables.pop('params') 27 | if 'batch_stats' in variables: 28 | variables, batch_stats = variables.pop('batch_stats') 29 | else: 30 | batch_stats = {'dummy': jnp.empty(device_count)} 31 | variables, prior = variables.pop('prior') 32 | assert not variables 33 | 34 | tx = optax.adamw(learning_rate) 35 | state = TrainState.create( 36 | apply_fn=partial(net.apply, method=net.adapted_prob), 37 | params=params, 38 | tx=tx, 39 | raw_fn=partial(net.apply, method=net.raw_logit), 40 | calibrated_fn=partial(net.apply, method=net.calibrated_logit), 41 | batch_stats=batch_stats, 42 | prior=prior, 43 | ) 44 | 45 | return state 46 | 47 | 48 | @partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=(3, 4, 5), donate_argnums=(0,)) 49 | def train_step(state: TrainState, X: jnp.ndarray, M: jnp.ndarray, K: int, 50 | train_fit_joint: bool, tau: float, joint: jnp.ndarray) \ 51 | -> Tuple[TrainState, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]: 52 | 
@partial(jax.value_and_grad, has_aux=True) 53 | def loss_fn(params): 54 | variables = { 55 | 'params': params, 56 | 'batch_stats': state.batch_stats, 57 | 'prior': state.prior 58 | } 59 | logit, new_model_state = state.raw_fn( 60 | variables, X, True, mutable=['batch_stats'] 61 | ) 62 | logit = logit + tau * jnp.log(joint) 63 | 64 | if train_fit_joint: 65 | loss = optax.softmax_cross_entropy_with_integer_labels(logit, M) 66 | else: 67 | logit_YZ = logit.reshape((-1, logit.shape[-1] // K, K)) 68 | logit_Y = jax.nn.logsumexp(logit_YZ, axis=-1) 69 | logit_Z = jax.nn.logsumexp(logit_YZ, axis=-2) 70 | loss_Y = optax.softmax_cross_entropy_with_integer_labels(logit_Y, M // K) 71 | loss_Z = optax.softmax_cross_entropy_with_integer_labels(logit_Z, M % K) 72 | loss = loss_Y + loss_Z 73 | mask = jnp.arange(logit.shape[-1])[..., jnp.newaxis] == M 74 | hit = jnp.sum(mask * (jnp.argmax(logit, -1) == M), axis=-1) 75 | total = jnp.sum(mask, axis=-1) 76 | 77 | return loss.sum(), (new_model_state, hit, total) 78 | 79 | (loss, (new_model_state, hit, total)), grads = loss_fn(state.params) 80 | loss = jax.lax.psum(loss, axis_name='batch') 81 | hit = jax.lax.psum(hit, axis_name='batch') 82 | total = jax.lax.psum(total, axis_name='batch') 83 | grads = jax.lax.psum(grads, axis_name='batch') 84 | 85 | state = state.apply_gradients( 86 | grads=grads, 87 | batch_stats=new_model_state['batch_stats'], 88 | ) 89 | 90 | return state, (loss, hit, total) 91 | 92 | 93 | @partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=(3, 4, 5)) 94 | def validation_step(state: TrainState, X: jnp.ndarray, M: jnp.ndarray, K: int, 95 | train_fit_joint: bool, tau: float, joint: jnp.ndarray) -> jnp.ndarray: 96 | variables = { 97 | 'params': state.params, 98 | 'batch_stats': state.batch_stats, 99 | 'prior': state.prior 100 | } 101 | logit = state.raw_fn(variables, X, False) 102 | logit = logit + tau * jnp.log(joint) 103 | 104 | if train_fit_joint: 105 | loss = 
optax.softmax_cross_entropy_with_integer_labels(logit, M) 106 | else: 107 | logit_YZ = logit.reshape((-1, logit.shape[-1] // K, K)) 108 | logit_Y = jax.nn.logsumexp(logit_YZ, axis=-1) 109 | logit_Z = jax.nn.logsumexp(logit_YZ, axis=-2) 110 | loss_Y = optax.softmax_cross_entropy_with_integer_labels(logit_Y, M // K) 111 | loss_Z = optax.softmax_cross_entropy_with_integer_labels(logit_Z, M % K) 112 | loss = loss_Y + loss_Z 113 | 114 | loss = loss.sum() 115 | loss = jax.lax.psum(loss, axis_name='batch') 116 | 117 | return loss 118 | 119 | 120 | @partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=(3, 4, 5, 6), donate_argnums=(0,)) 121 | def calibration_step(state: TrainState, X: jnp.ndarray, M: jnp.ndarray, K: int, 122 | train_fit_joint: bool, tau: float, learning_rate: float, joint: jnp.ndarray) \ 123 | -> Tuple[TrainState, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]: 124 | @partial(jax.value_and_grad, has_aux=True) 125 | def loss_fn(params): 126 | variables = { 127 | 'params': params, 128 | 'batch_stats': state.batch_stats, 129 | 'prior': state.prior 130 | } 131 | logit, new_model_state = state.calibrated_fn( 132 | variables, X, True, mutable=['batch_stats'] 133 | ) 134 | logit = logit + tau * jnp.log(joint) 135 | 136 | if train_fit_joint: 137 | loss = optax.softmax_cross_entropy_with_integer_labels(logit, M) 138 | else: 139 | logit_YZ = logit.reshape((-1, logit.shape[-1] // K, K)) 140 | logit_Y = jax.nn.logsumexp(logit_YZ, axis=-1) 141 | logit_Z = jax.nn.logsumexp(logit_YZ, axis=-2) 142 | loss_Y = optax.softmax_cross_entropy_with_integer_labels(logit_Y, M // K) 143 | loss_Z = optax.softmax_cross_entropy_with_integer_labels(logit_Z, M % K) 144 | loss = loss_Y + loss_Z 145 | mask = jnp.arange(logit.shape[-1])[..., jnp.newaxis] == M 146 | hit = jnp.sum(mask * (jnp.argmax(logit, -1) == M), axis=-1) 147 | total = jnp.sum(mask, axis=-1) 148 | 149 | return loss.sum(), (new_model_state, hit, total) 150 | 151 | (loss, (new_model_state, hit, total)), grads = 
loss_fn(state.params) 152 | loss = jax.lax.psum(loss, axis_name='batch') 153 | hit = jax.lax.psum(hit, axis_name='batch') 154 | total = jax.lax.psum(total, axis_name='batch') 155 | grads = jax.lax.psum(grads, axis_name='batch') 156 | 157 | new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, state.params, grads) 158 | state = state.replace( 159 | params=new_params, 160 | batch_stats=new_model_state['batch_stats'], 161 | ) 162 | 163 | return state, (loss, hit, total) 164 | 165 | 166 | cross_replica_mean: Callable = jax.pmap(lambda x: jax.lax.pmean(x, 'batch'), 'batch') 167 | 168 | 169 | @partial(jax.pmap, axis_name='batch') 170 | def induce_step(state: TrainState, X: jnp.ndarray) -> jnp.ndarray: 171 | variables = { 172 | 'params': state.params, 173 | 'batch_stats': state.batch_stats, 174 | 'prior': state.prior 175 | } 176 | logit = state.calibrated_fn(variables, X, False) 177 | prob = jax.nn.softmax(logit) 178 | prob_sum = jax.lax.psum(jnp.sum(prob, axis=0), axis_name='batch') 179 | 180 | return prob_sum 181 | 182 | 183 | # @partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=(3, 4, 5, 6), donate_argnums=(0,)) 184 | @partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=(3, 4, 5, 6)) 185 | def adapt_step(state: TrainState, X: jnp.ndarray, prior_strength: float, 186 | symmetric_dirichlet: bool, fix_marginal: bool, C: int, K: int) -> TrainState: 187 | M = C * K 188 | source_prior = state.prior['source'] 189 | if symmetric_dirichlet: 190 | alpha = jnp.ones(M) 191 | else: 192 | alpha = jax.tree_util.tree_map(lambda x: x * M, source_prior) 193 | alpha = prior_strength * alpha 194 | 195 | variables = { 196 | 'params': state.params, 197 | 'batch_stats': state.batch_stats, 198 | 'prior': state.prior 199 | } 200 | logit = state.calibrated_fn(variables, X, False) 201 | prob = jax.nn.softmax(logit) 202 | 203 | init_target_prior = source_prior 204 | init_objective = jnp.sum((alpha - 1) * jnp.log(source_prior)) 205 | init_val = 
init_target_prior, init_objective, init_objective - 1 206 | 207 | def cond_fun(val): 208 | _, objective, prev_objective = val 209 | return objective > prev_objective 210 | 211 | def body_fun(val): 212 | target_prior, prev_objective, _ = val 213 | 214 | # E step 215 | target_prob = target_prior * prob / source_prior 216 | normalizer = jnp.sum(target_prob, axis=-1, keepdims=True) 217 | target_prob = target_prob / normalizer 218 | 219 | # M step 220 | target_prob_count = jax.lax.psum(jnp.sum(target_prob, axis=0), axis_name='batch') 221 | target_prior_count = target_prob_count + (alpha - 1) # add pseudocount 222 | target_prior = target_prior_count / jnp.sum(target_prior_count) 223 | 224 | # Objective 225 | log_w = jnp.log(target_prior) - jnp.log(source_prior) 226 | mle_objective_i = jax.nn.logsumexp(log_w, axis=-1, b=prob) 227 | mle_objective = jax.lax.psum(jnp.sum(mle_objective_i), axis_name='batch') 228 | regularizer = jnp.sum((alpha - 1) * jnp.log(target_prior)) 229 | objective = mle_objective + regularizer 230 | 231 | return target_prior, objective, prev_objective 232 | 233 | target_prior, _, _ = jax.lax.while_loop(cond_fun, body_fun, init_val) 234 | 235 | if fix_marginal: 236 | # Make sure the marginal distribution of Y does not change 237 | source_prior = source_prior.reshape((C, K)) 238 | target_prior = target_prior.reshape((C, K)) 239 | source_marginal = jnp.sum(source_prior, axis=-1, keepdims=True) 240 | target_marginal = jnp.sum(target_prior, axis=-1, keepdims=True) 241 | target_prior = target_prior / target_marginal * source_marginal 242 | target_prior = target_prior.flatten() 243 | 244 | prior = state.prior.unfreeze() 245 | prior['target'] = target_prior 246 | state = state.replace(prior=flax.core.frozen_dict.freeze(prior)) 247 | 248 | return state 249 | 250 | 251 | @partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=(4,)) 252 | def test_step(state: TrainState, image: jnp.ndarray, Y: jnp.ndarray, Z: jnp.ndarray, argmax_joint: bool) \ 253 | -> 
Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]: 254 | variables = { 255 | 'params': state.params, 256 | 'batch_stats': state.batch_stats, 257 | 'prior': state.prior 258 | } 259 | 260 | prob_joint = state.apply_fn(variables, image, False) 261 | _, C, K = prob_joint.shape 262 | if (C, K) != (2, 2): 263 | raise NotImplementedError(f"(C, K) = {(C, K)} != (2, 2)") 264 | 265 | prob_Y = jnp.sum(prob_joint, axis=-1) 266 | prob_Z = jnp.sum(prob_joint, axis=-2) 267 | 268 | if argmax_joint: 269 | prediction_M = jnp.argmax(prob_joint) 270 | prediction_Y = prediction_M // 2 271 | prediction_Z = prediction_M % 2 272 | else: 273 | prediction_Y = jnp.argmax(prob_Y, axis=-1) 274 | prediction_Z = jnp.argmax(prob_Z, axis=-1) 275 | 276 | score_Y = prob_Y[:, 1] # assumes binary label 277 | score_Z = prob_Z[:, 1] # assumes binary label 278 | 279 | hit_Y = jax.lax.psum(jnp.sum(prediction_Y == Y), axis_name='batch') 280 | hit_Z = jax.lax.psum(jnp.sum(prediction_Z == Z), axis_name='batch') 281 | 282 | return (score_Y, hit_Y), (score_Z, hit_Z) 283 | -------------------------------------------------------------------------------- /scripts/merge.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Dict, Tuple, Union, List 2 | from pathlib import Path 3 | from collections import defaultdict 4 | import re 5 | 6 | import click 7 | import numpy as np 8 | import jax.numpy as jnp 9 | import matplotlib.pyplot as plt 10 | 11 | from tta.common import Adaptation, Curves 12 | from tta.visualize import latexify, format_axes 13 | 14 | 15 | Dataset = Union[Literal["mnist"], Literal["chexpert"], Literal["mimic"]] 16 | ConfigKey = Tuple[bool, Dataset, int, str, float, int] 17 | AdaptKey = Tuple[Adaptation, bool, int] 18 | 19 | 20 | @click.command() 21 | @click.option("--npz_pattern", type=str, required=True) 22 | @click.option("--merged_title", type=str, required=True) 23 | @click.option("--merged_name", type=str, 
required=True) 24 | def merge( 25 | npz_pattern: str, 26 | merged_title: str, 27 | merged_name: str, 28 | ) -> None: 29 | npz_root = Path("npz/") 30 | merged_root = Path("merged/") 31 | 32 | npz_root.mkdir(parents=True, exist_ok=True) 33 | merged_root.mkdir(parents=True, exist_ok=True) 34 | 35 | npz_dict = {} 36 | for npz_path in sorted(npz_root.glob(npz_pattern), key=key): 37 | print(f"Reading from {npz_path}") 38 | npz = np.load(npz_path, allow_pickle=True) 39 | npz_dict[npz_path.stem] = npz 40 | 41 | ylabels, type2config2adapt2sweeps = collect(npz_dict) 42 | confounder_strength = np.linspace(0, 1, 21) 43 | 44 | plot( 45 | ylabels, 46 | type2config2adapt2sweeps, 47 | confounder_strength, 48 | merged_title, 49 | merged_root, 50 | merged_name, 51 | ) 52 | 53 | 54 | def key(path: Path) -> Tuple[bool, Dataset, int, int, float, int]: 55 | is_tree, dataset, domain, sub, tau, cali = parse(path.stem) 56 | mapping = { 57 | "none": 0, 58 | "classes": 1, 59 | "groups": 2, 60 | } 61 | return is_tree, dataset, domain, mapping[sub], tau, cali 62 | 63 | 64 | def collect(npz_dict: Dict[str, Dict[str, Tuple[Curves, str]]]) -> Tuple[ 65 | Dict[str, str], Dict[str, Dict[ConfigKey, Dict[AdaptKey, List[jnp.ndarray]]]], 66 | ]: 67 | example = next(iter(npz_dict.values())) 68 | ylabels = {k: v.replace("Average AUC", "AUC") for k, (_, v) in example.items()} 69 | 70 | type2config2adapt2sweeps: Dict[str, Dict[ConfigKey, Dict[AdaptKey, List[jnp.ndarray]]]] \ 71 | = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) 72 | for sweep_type in ylabels.keys(): 73 | for config, type2adapt2sweeps in npz_dict.items(): 74 | config_key = parse(config) 75 | sweeps, _ = type2adapt2sweeps[sweep_type] 76 | for adapt_key, sweep in sweeps.items(): 77 | type2config2adapt2sweeps[sweep_type][config_key][adapt_key].append(sweep) 78 | 79 | return ylabels, type2config2adapt2sweeps 80 | 81 | 82 | def parse(config: str) -> ConfigKey: 83 | if config.startswith("mnist_"): 84 | pattern = 
re.compile(r"^mnist_rot(True|False)_noise(\d*\.?\d*)_domain(\d+)_sub(none|classes|groups)_tau(\d*\.?\d*)_train(\d+)_cali(\d+)_prior(\d*\.?\d*)_seed(\d*\.?\d*)$") 85 | matching = pattern.fullmatch(config) 86 | assert matching is not None 87 | 88 | is_tree = False 89 | dataset = "mnist" 90 | rot = bool(matching.group(1)) 91 | noise = float(matching.group(2)) 92 | domain = int(matching.group(3)) 93 | sub = matching.group(4) 94 | tau = float(matching.group(5)) 95 | train = int(matching.group(6)) 96 | cali = int(matching.group(7)) 97 | prior = float(matching.group(8)) 98 | seed = int(matching.group(9)) 99 | 100 | elif config.startswith("tree_mnist"): 101 | pattern = re.compile(r"^tree_mnist_rot(True|False)_noise(\d*\.?\d*)_domain(\d+)_prior(\d*\.?\d*)_seed(\d*\.?\d*)$") 102 | matching = pattern.fullmatch(config) 103 | assert matching is not None 104 | 105 | is_tree = True 106 | dataset = "mnist" 107 | rot = bool(matching.group(1)) 108 | noise = float(matching.group(2)) 109 | domain = int(matching.group(3)) 110 | sub = "none" 111 | tau = 0 112 | cali = 0 113 | prior = float(matching.group(4)) 114 | seed = int(matching.group(5)) 115 | 116 | elif config.startswith("chexpert-") or config.startswith("mimic-"): 117 | pattern = re.compile(r"^(chexpert|mimic)-(embedding|pixel)_([a-zA-Z]+)_([a-zA-Z]+)_domain(\d+)_size(\d+)_sub(none|classes|groups)_tau(\d*\.?\d*)_train(\d+)_cali(\d+)_prior(\d*\.?\d*)_seed(\d*\.?\d*)$") 118 | matching = pattern.fullmatch(config) 119 | assert matching is not None 120 | 121 | is_tree = False 122 | dataset = matching.group(1) 123 | modality = matching.group(2) 124 | Y_column = matching.group(3) 125 | Z_column = matching.group(4) 126 | domain = int(matching.group(5)) 127 | size = int(matching.group(6)) 128 | sub = matching.group(7) 129 | tau = float(matching.group(8)) 130 | train = int(matching.group(9)) 131 | cali = int(matching.group(10)) 132 | prior = float(matching.group(11)) 133 | seed = int(matching.group(12)) 134 | 135 | elif 
config.startswith("tree_chexpert-") or config.startswith("tree_mimic-"): 136 | pattern = re.compile(r"^tree_(chexpert|mimic)-(embedding|pixel)_([a-zA-Z]+)_([a-zA-Z]+)_domain(\d+)_size(\d+)_prior(\d*\.?\d*)_seed(\d*\.?\d*)$") 137 | matching = pattern.fullmatch(config) 138 | assert matching is not None 139 | 140 | is_tree = True 141 | dataset = matching.group(1) 142 | modality = matching.group(2) 143 | Y_column = matching.group(3) 144 | Z_column = matching.group(4) 145 | domain = int(matching.group(5)) 146 | size = int(matching.group(6)) 147 | sub = "none" 148 | tau = 0 149 | cali = 0 150 | prior = float(matching.group(7)) 151 | seed = int(matching.group(8)) 152 | 153 | else: 154 | raise ValueError(f"Unknown config {config}") 155 | 156 | return is_tree, dataset, domain, sub, tau, cali 157 | 158 | 159 | def plot( 160 | ylabels: Dict[str, str], 161 | type2config2adapt2sweeps: Dict[str, Dict[ConfigKey, Dict[AdaptKey, List[jnp.ndarray]]]], 162 | confounder_strength: np.ndarray, 163 | merged_title: str, 164 | merged_root: Path, 165 | merged_name: str, 166 | ): 167 | tab20 = plt.get_cmap("tab20").colors 168 | meta_styles = { 169 | "major": { 170 | ("none", 0.0): ("solid", "o", "ERM"), 171 | ("none", 1.0): ("solid", "o", "Logit Adjustment"), 172 | ("groups", 0.0): ("dashed", "^", "SUBG"), 173 | }, 174 | } 175 | for meta_type in meta_styles.keys(): 176 | styles = meta_styles[meta_type] 177 | for sweep_type, ylabel in ylabels.items(): 178 | fig, ax = plt.subplots(figsize=(12, 6)) 179 | 180 | erm_curves_labels = [], [] 181 | invariance_curves_labels = [], [] 182 | adaptation_curves_labels = [], [] 183 | 184 | config2adapt2sweeps = type2config2adapt2sweeps[sweep_type] 185 | for (is_tree, dataset, domain, sub, tau, cali), adapt2sweeps in config2adapt2sweeps.items(): 186 | ax.axvline(confounder_strength[domain], color="black", linestyle="dotted", linewidth=3) 187 | 188 | style = styles.get((sub, tau)) 189 | if style is None: 190 | continue 191 | 192 | linestyle, marker, 
base_label = style 193 | 194 | for ((algo, *param), argmax_joint, batch_size), sweeps in adapt2sweeps.items(): 195 | assert not argmax_joint 196 | 197 | adapt_on = "ERM" if is_tree else "Logit Adjustment" 198 | if meta_type == "major": 199 | if algo == "Null" and base_label == "ERM": 200 | label = base_label 201 | curves, labels = erm_curves_labels 202 | markerfacecolor = color = "black" 203 | elif algo == "Null" and base_label != "ERM": 204 | label = base_label 205 | curves, labels = invariance_curves_labels 206 | markerfacecolor = color = tab20[6] if base_label == "Logit Adjustment" else tab20[2] 207 | elif algo == "EM" and base_label == adapt_on and batch_size >= 64: 208 | label = f"TTLSA (batch size {batch_size})" 209 | curves, labels = adaptation_curves_labels 210 | if batch_size == 64: 211 | markerfacecolor = color = tab20[19] 212 | elif batch_size >= 512: 213 | markerfacecolor = color = tab20[18] 214 | else: 215 | raise ValueError(f"Unknown batch size {batch_size}") 216 | elif algo == "Oracle" and base_label == adapt_on: 217 | label = "TTLSA (oracle)" 218 | curves, labels = adaptation_curves_labels 219 | markerfacecolor = color = tab20[0] 220 | else: 221 | continue 222 | else: 223 | raise ValueError(f"Unknown meta type {meta_type}") 224 | 225 | if algo == "TTLSA (oracle)": 226 | linewidth = 1.5 227 | markersize = 6 228 | else: 229 | linewidth = 1 230 | markersize = 4 231 | 232 | jitter = { 233 | "Null": 0, 234 | "EM": -0.005, 235 | "Oracle": 0, 236 | } 237 | confounder_strength_jitted = confounder_strength + jitter[algo] 238 | 239 | sweep_mean, sweep_std = mean_std(sweeps) 240 | curve = ax.errorbar(confounder_strength_jitted, sweep_mean, sweep_std, 241 | linestyle=linestyle, marker=marker, 242 | linewidth=linewidth, markersize=markersize, 243 | color=color, markerfacecolor=markerfacecolor, alpha=1.0) 244 | curves.append(curve) 245 | labels.append(label) 246 | 247 | if sweep_type in {"mean", "l1", "norm"}: 248 | plt.ylim((0, 1)) 249 | elif is_tree: 250 | 
auc_limit = 0.9 if dataset == "mnist" else 0.7 251 | plt.ylim((auc_limit, 1)) 252 | else: 253 | auc_limit = 0.98 if dataset == "mnist" else 0.7 254 | plt.ylim((auc_limit, 1)) 255 | 256 | plt.xlabel("Shift parameter") 257 | if cali == 0: 258 | plt.ylabel(f"{ylabel} (without calibration)") 259 | else: 260 | plt.ylabel(ylabel) 261 | 262 | plt.title(merged_title) 263 | plt.grid(True, alpha=0.5) 264 | 265 | if meta_type == "major": 266 | # HACK 267 | invariance_curves, invariance_labels = invariance_curves_labels 268 | if len(invariance_curves) >= 3 and len(invariance_labels) >= 3: 269 | invariance_curves[0], invariance_curves[1] = invariance_curves[1], invariance_curves[0] 270 | invariance_labels[0], invariance_labels[1] = invariance_labels[1], invariance_labels[0] 271 | 272 | adaptation_curves, adaptation_labels = adaptation_curves_labels 273 | if len(adaptation_curves) >= 3 and len(adaptation_labels) >= 3: 274 | adaptation_curves[-1], adaptation_curves[-2] = adaptation_curves[-2], adaptation_curves[-1] 275 | adaptation_labels[-1], adaptation_labels[-2] = adaptation_labels[-2], adaptation_labels[-1] 276 | 277 | legend1 = plt.legend(*adaptation_curves_labels, loc="upper left", bbox_to_anchor=(2/3, -0.15), ncol=1, frameon=False) 278 | legend2 = plt.legend(*invariance_curves_labels, loc="upper left", bbox_to_anchor=(1/3, -0.15), ncol=1, frameon=False) 279 | plt.legend(*erm_curves_labels, loc="upper left", bbox_to_anchor=(0, -0.15), ncol=1, frameon=False) 280 | plt.gca().add_artist(legend1) 281 | plt.gca().add_artist(legend2) 282 | else: 283 | raise ValueError(f"Unknown meta type {meta_type}") 284 | 285 | fig.tight_layout() 286 | 287 | format_axes(ax) 288 | for suffix in ("png", "pdf"): 289 | plt.savefig(merged_root / f"{merged_name}_{sweep_type}_{meta_type}.{suffix}", bbox_inches="tight", dpi=300) 290 | 291 | plt.close(fig) 292 | 293 | 294 | def mean_std(sweeps: List[jnp.ndarray]) -> Tuple[np.ndarray, np.ndarray]: 295 | sweeps_array = np.empty((len(sweeps), 
len(sweeps[0]) - 1)) 296 | for i, sweep in enumerate(sweeps): 297 | sweeps_array[i, :] = sweep[:-1] 298 | 299 | mean = np.mean(sweeps_array, axis=0) 300 | std = np.std(sweeps_array, axis=0) / np.sqrt(len(sweeps)) 301 | return mean, std 302 | 303 | 304 | if __name__ == "__main__": 305 | latexify(width_scale_factor=2, fig_height=2) 306 | np.random.seed(42) 307 | merge() 308 | -------------------------------------------------------------------------------- /tta/cli.py: -------------------------------------------------------------------------------- 1 | from types import SimpleNamespace 2 | from typing import Any, Sequence, List, Tuple, Set, Optional, Dict 3 | from pathlib import Path 4 | from hashlib import sha256 5 | import sys 6 | import random 7 | from itertools import product 8 | from pprint import pprint 9 | 10 | import jax 11 | import jax.numpy as jnp 12 | from jax.experimental.compilation_cache.compilation_cache import initialize_cache 13 | import flax 14 | from flax.training.checkpoints import save_checkpoint, restore_checkpoint 15 | from flax.jax_utils import replicate, unreplicate 16 | import numpy as np 17 | import torch 18 | from torch.utils.data import Dataset, ConcatDataset, DataLoader 19 | import click 20 | from sklearn.metrics import roc_auc_score 21 | 22 | from tta.common import Adaptation, Curves, Sweeps 23 | from tta.utils import Tee 24 | from tta.datasets import MultipleDomainDataset, split, subsample 25 | from tta.datasets.mnist import MultipleDomainMNIST 26 | from tta.datasets.coco import ColoredCOCO 27 | from tta.datasets.waterbirds import MultipleDomainWaterbirds 28 | from tta.datasets.cxr.chexpert import MultipleDomainCheXpert 29 | from tta.datasets.cxr.mimic import MultipleDomainMIMIC 30 | from tta.train import ( 31 | TrainState, 32 | create_train_state, 33 | train_step, 34 | validation_step, 35 | calibration_step, 36 | cross_replica_mean, 37 | induce_step, 38 | adapt_step, 39 | test_step, 40 | ) 41 | from tta.restore import 
restore_train_state
from tta.visualize import latexify, plot


# Command-line entry point. Each option maps onto the parameter of `cli` with
# the same (lower-cased) name; options declared with `multiple=True` arrive as
# tuples and are swept over in `main` / `baseline_fn`.
@click.command()
@click.option("--config_name", type=str, required=True)
@click.option(
    "--dataset_name",
    type=click.Choice(["MNIST", "COCO", "Waterbirds", "CheXpert", "MIMIC"]),
    required=True,
)
@click.option("--dataset_Y_column", type=str, required=False)
@click.option("--dataset_Z_column", type=str, required=False)
@click.option("--dataset_target_domain_count", type=int, required=False)
@click.option("--dataset_source_domain_count", type=int, required=False)
@click.option("--dataset_subsample_what", type=str, required=True)
@click.option("--dataset_use_embedding", type=bool, required=False)
@click.option("--dataset_apply_rotation", type=bool, required=False)
@click.option("--dataset_feature_noise", type=float, required=True)
@click.option("--dataset_label_noise", type=float, required=True)
@click.option("--train_fit_joint", type=bool, required=True)
@click.option("--train_model", type=str, required=True)
@click.option(
    "--train_pretrained_path", type=click.Path(path_type=Path), required=False
)
@click.option("--train_domains", type=str, required=True)
@click.option("--train_fraction", type=float, required=True)
@click.option("--train_calibration_fraction", type=float, required=True)
@click.option("--train_batch_size", type=int, required=True)
@click.option("--train_epochs", type=int, required=True)
@click.option("--train_decay", type=float, required=True)
@click.option("--train_patience", type=int, required=True)
@click.option("--train_tau", type=float, required=True)
@click.option("--train_lr", type=float, required=True)
@click.option("--calibration_domains", type=str, required=False)
@click.option("--calibration_fraction", type=float, required=False)
@click.option("--calibration_batch_size", type=int, required=True)
@click.option("--calibration_epochs", type=int, required=True)
@click.option("--calibration_decay", type=float, required=True)
@click.option("--calibration_patience", type=int, required=True)
@click.option("--calibration_tau", type=float, required=True)
@click.option("--calibration_lr", type=float, required=True)
@click.option("--adapt_skip_null_oracle", is_flag=True)
@click.option("--adapt_gmtl_alpha", type=float, required=False, multiple=True)
@click.option("--adapt_prior_strength", type=float, required=False, multiple=True)
@click.option("--adapt_symmetric_dirichlet", type=bool, required=False, multiple=True)
@click.option("--adapt_fix_marginal", type=bool, required=False, multiple=True)
@click.option("--test_argmax_joint", type=bool, required=True, multiple=True)
@click.option("--test_batch_size", type=int, required=True, multiple=True)
@click.option("--seed", type=int, required=True)
@click.option("--num_workers", type=int, required=True)
@click.option(
    "--plot_title", type=str, required=False, default="Performance on Each Domain"
)
@click.option("--plot_only", type=bool, required=True)
def cli(
    config_name: str,
    dataset_name: str,
    dataset_y_column: Optional[str],
    dataset_z_column: Optional[str],
    dataset_target_domain_count: Optional[int],
    dataset_source_domain_count: Optional[int],
    dataset_subsample_what: str,
    dataset_use_embedding: Optional[bool],
    dataset_apply_rotation: Optional[bool],
    dataset_feature_noise: float,
    dataset_label_noise: float,
    train_fit_joint: bool,
    train_model: str,
    train_pretrained_path: Optional[Path],
    train_domains: str,
    train_fraction: float,
    train_calibration_fraction: float,
    train_batch_size: int,
    train_epochs: int,
    train_decay: float,
    train_patience: int,
    train_tau: float,
    train_lr: float,
    calibration_domains: Optional[str],
    calibration_fraction: Optional[float],
    calibration_batch_size: int,
    calibration_epochs: int,
    calibration_decay: float,
    calibration_patience: int,
    calibration_tau: float,
    calibration_lr: float,
    adapt_skip_null_oracle: bool,
    adapt_gmtl_alpha: Sequence[float],
    adapt_prior_strength: Sequence[float],
    adapt_symmetric_dirichlet: Sequence[bool],
    adapt_fix_marginal: Sequence[bool],
    test_argmax_joint: Sequence[bool],
    test_batch_size: Sequence[int],
    seed: int,
    num_workers: int,
    plot_title: str,
    plot_only: bool,
) -> None:
    # Output locations: logs/<config>.txt (stdout tee, skipped when only
    # plotting), npz/<config>.npz (accumulated sweeps) and plots/ (passed to
    # `plot`).
    log_root = Path("logs/")
    npz_root = Path("npz/")
    plot_root = Path("plots/")

    log_root.mkdir(parents=True, exist_ok=True)
    npz_root.mkdir(parents=True, exist_ok=True)
    plot_root.mkdir(parents=True, exist_ok=True)

    log_path = log_root / f"{config_name}.txt"
    if not plot_only:
        sys.stdout = Tee(log_path)
    npz_path = npz_root / f"{config_name}.npz"

    # Seed every RNG in play: Python, NumPy, PyTorch (global + the generator
    # handed to DataLoaders / splitting) and JAX.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    key = jax.random.PRNGKey(seed)
    generator = torch.Generator().manual_seed(seed)

    train_domains_set = {int(env) for env in train_domains.split(",")}
    if len(train_domains_set) != 1:
        raise NotImplementedError(
            "Training on multiple source distributions is not supported yet."
        )

    if calibration_domains is None:
        calibration_domains_set = set()
    else:
        calibration_domains_set = {int(env) for env in calibration_domains.split(",")}

    if calibration_fraction is None:
        calibration_fraction = 1.0

    if plot_only:
        # Ugly hack: `plot` only needs `confounder_strength` from the dataset,
        # so fake it instead of loading the data.
        confounder_strength = np.linspace(0, 1, 21)
        dataset = SimpleNamespace(confounder_strength=confounder_strength)
    else:
        (
            dataset,
            (train, joint_train),
            (calibration, joint_calibration),
            eval_splits,
        ) = prepare_dataset(
            dataset_name,
            dataset_y_column,
            dataset_z_column,
            dataset_target_domain_count,
            dataset_source_domain_count,
            dataset_subsample_what,
            dataset_use_embedding,
            dataset_apply_rotation,
            dataset_feature_noise,
            dataset_label_noise,
            train_domains_set,
            train_fraction,
            train_calibration_fraction,
            calibration_domains_set,
            calibration_fraction,
            generator,
        )

        main(
            npz_path,
            dataset,
            train,
            joint_train,
            calibration,
            joint_calibration,
            eval_splits,
            train_domains_set,
            calibration_domains_set,
            dataset_label_noise,
            train_fit_joint,
            train_model,
            train_pretrained_path,
            train_batch_size,
            train_epochs,
            train_decay,
            train_patience,
            train_tau,
            train_lr,
            calibration_batch_size,
            calibration_epochs,
            calibration_decay,
            calibration_patience,
            calibration_tau,
            calibration_lr,
            adapt_skip_null_oracle,
            adapt_gmtl_alpha,
            adapt_prior_strength,
            adapt_symmetric_dirichlet,
            adapt_fix_marginal,
            test_argmax_joint,
            test_batch_size,
            key,
            generator,
            num_workers,
        )

    plot(
        npz_path,
        dataset.confounder_strength,
        train_domains_set,
        dataset_label_noise,
        plot_title,
        plot_root,
        config_name,
    )


def prepare_dataset(
    dataset_name: str,
    dataset_y_column: Optional[str],
    dataset_z_column: Optional[str],
    dataset_target_domain_count: Optional[int],
    dataset_source_domain_count: Optional[int],
    dataset_subsample_what: str,
    dataset_use_embedding: Optional[bool],
    dataset_apply_rotation: Optional[bool],
    dataset_feature_noise: float,
    dataset_label_noise: float,
    train_domains_set: Set[int],
    train_fraction: float,
    train_calibration_fraction: float,
    calibration_domains_set: Set[int],
    calibration_fraction: float,
    generator: torch.Generator,
) -> Tuple[
    MultipleDomainDataset,
    Tuple[Dataset, torch.Tensor],
    Tuple[Dataset, torch.Tensor],
    List[Tuple[Dataset, torch.Tensor]],
]:
    """Build the requested dataset and split it into train/calibration/eval sets.

    Asserts that only the options relevant to ``dataset_name`` are set, mixes
    ``dataset_subsample_what`` into ``dataset.hexdigest`` (which ``train_fn``
    hashes into the checkpoint prefix), and, unless ``dataset_subsample_what``
    is ``"none"``, subsamples the train and calibration splits.

    Returns:
        ``(dataset, (train, joint_train), (calibration, joint_calibration),
        eval_splits)``, where ``eval_splits`` is the per-domain test splits
        returned by ``split`` with the (possibly subsampled) training split
        appended as the last entry.

    Raises:
        ValueError: if ``dataset_name`` is unknown.
        NotImplementedError: if the dataset does not have C == K == 2.
    """
    if dataset_name == "MNIST":
        assert dataset_y_column is None
        assert dataset_z_column is None
        assert dataset_target_domain_count is None
        assert dataset_source_domain_count is None
        assert dataset_use_embedding is None
        assert dataset_apply_rotation is not None

        root = Path("data/mnist")
        dataset = MultipleDomainMNIST(
            root,
            train_domains_set,
            generator,
            dataset_apply_rotation,
            dataset_feature_noise,
            dataset_label_noise,
        )
    elif dataset_name == "COCO":
        assert dataset_y_column is None
        assert dataset_z_column is None
        assert dataset_target_domain_count is None
        assert dataset_source_domain_count is None
        assert dataset_use_embedding is None
        assert dataset_apply_rotation is None
        assert dataset_feature_noise == 0
        assert dataset_label_noise == 0

        root = Path("data/COCO/train2017")
        annFile = Path("data/COCO/annotations/instances_train2017.json")
        dataset = ColoredCOCO(root, annFile, generator)
    elif dataset_name == "Waterbirds":
        assert dataset_y_column is None
        assert dataset_z_column is None
        assert dataset_target_domain_count is None
        assert dataset_source_domain_count is None
        assert dataset_use_embedding is None
        assert dataset_apply_rotation is None
        assert dataset_feature_noise == 0
        assert dataset_label_noise == 0

        root = Path("data/")
        dataset = MultipleDomainWaterbirds(root, generator)
    elif dataset_name == "CheXpert":
        assert dataset_y_column is not None
        assert dataset_z_column is not None
        assert dataset_target_domain_count is not None
        assert dataset_use_embedding is not None
        assert dataset_apply_rotation is None
        assert dataset_feature_noise == 0
        assert dataset_label_noise == 0

        root = Path("data/CheXpert")
        dataset = MultipleDomainCheXpert(
            root,
            train_domains_set,
            generator,
            dataset_y_column,
            dataset_z_column,
            dataset_use_embedding,
            dataset_target_domain_count,
            dataset_source_domain_count,
        )
    elif dataset_name == "MIMIC":
        assert dataset_y_column is not None
        assert dataset_z_column is not None
        assert dataset_target_domain_count is not None
        assert dataset_use_embedding is True
        assert dataset_apply_rotation is None
        assert dataset_feature_noise == 0
        assert dataset_label_noise == 0

        root = Path("data/MIMIC")
        dataset = MultipleDomainMIMIC(
            root,
            train_domains_set,
            generator,
            dataset_y_column,
            dataset_z_column,
            dataset_use_embedding,
            dataset_target_domain_count,
            dataset_source_domain_count,
        )
    else:
        raise ValueError(f"Unknown dataset {dataset_name}")

    C, K = dataset.C, dataset.K
    if C != 2 or K != 2:
        raise NotImplementedError("Multi-label classification is not supported yet.")

    # Subsampling changes the training data, so it must change the digest that
    # keys cached checkpoints.
    m = sha256()
    m.update(dataset.hexdigest.encode())
    m.update(dataset_subsample_what.encode())
    dataset.hexdigest = m.hexdigest()

    print("domains:", [len(domain) for domain, _ in dataset.domains])
    (train, joint_train), (calibration, joint_calibration), test_splits = split(
        dataset,
        train_domains_set,
        train_fraction,
        train_calibration_fraction,
        calibration_domains_set,
        calibration_fraction,
    )
    print("train (before subsampling):", len(train))
    print(joint_train)
    print("calibration (before subsampling):", len(calibration))
    print(joint_calibration)

    if dataset_subsample_what != "none":
        train, joint_train = subsample(
            train, joint_train, dataset_subsample_what, generator
        )
        calibration, joint_calibration = subsample(
            calibration, joint_calibration, dataset_subsample_what, generator
        )
        print("train (after subsampling):", len(train))
        print(joint_train)
        print("calibration (after subsampling):", len(calibration))
        print(joint_calibration)

    (train_domain,) = train_domains_set
    test_split_train, joint_M_train = test_splits[train_domain]
    print("test_split_train:", len(test_split_train))
    print(joint_M_train)

    # The training split is evaluated too; it always sits at index -1.
    eval_splits = test_splits.copy()
    eval_splits.append((train, joint_train))

    return (
        dataset,
        (train, joint_train),
        (calibration, joint_calibration),
        eval_splits,
    )


def main(
    npz_path: Path,
    dataset: MultipleDomainDataset,
    train: ConcatDataset,
    joint_train: torch.Tensor,
    calibration: ConcatDataset,
    joint_calibration: torch.Tensor,
    eval_splits: List[Tuple[Dataset, torch.Tensor]],
    train_domains_set: Set[int],
    calibration_domains_set: Set[int],
    dataset_label_noise: float,
    train_fit_joint: bool,
    train_model: str,
    train_pretrained_path: Optional[Path],
    train_batch_size: int,
    train_epochs: int,
    train_decay: float,
    train_patience: int,
    train_tau: float,
    train_lr: float,
    calibration_batch_size: int,
    calibration_epochs: int,
    calibration_decay: float,
    calibration_patience: int,
    calibration_tau: float,
    calibration_lr: float,
    adapt_skip_null_oracle: bool,
    adapt_gmtl_alpha: Sequence[float],
    adapt_prior_strength: Sequence[float],
    adapt_symmetric_dirichlet: Sequence[bool],
    adapt_fix_marginal: Sequence[bool],
    test_argmax_joint: Sequence[bool],
    test_batch_size: Sequence[int],
    key: Any,
    generator: torch.Generator,
    num_workers: int,
) -> Dict[str, Tuple[Curves, str]]:
    """Train (or restore) a model, run every adaptation scheme, persist sweeps.

    The Null/Oracle/GMTL baselines run once; EM adaptation runs for every
    combination in the Cartesian product of ``adapt_prior_strength``,
    ``adapt_symmetric_dirichlet``, ``adapt_fix_marginal``,
    ``test_argmax_joint`` and ``test_batch_size``. If ``npz_path`` already
    exists, the new curves are merged into it (new keys overwrite old ones);
    otherwise it is written fresh.

    Returns:
        Mapping from metric name to ``(sweeps, y-axis label)``, where
        ``sweeps`` maps ``(adaptation, argmax_joint, batch_size)`` to a
        per-split metric array whose last entry is the training split.
    """
    device_count = jax.local_device_count()
    assert (
        train_batch_size % device_count == 0
    ), f"train_batch_size should be divisible by {device_count}"
    assert (
        calibration_batch_size % device_count == 0
    ), f"calibration_batch_size should be divisible by {device_count}"
    for batch_size in test_batch_size:
        assert (
            batch_size % device_count == 0
        ), f"test_batch_size should be divisible by {device_count}"

    state = train_fn(
        dataset,
        train,
        joint_train,
        calibration,
        joint_calibration,
        train_fit_joint,
        train_model,
        train_pretrained_path,
        train_batch_size,
        train_epochs,
        train_decay,
        train_patience,
        train_tau,
        train_lr,
        calibration_batch_size,
        calibration_epochs,
        calibration_decay,
        calibration_patience,
        calibration_tau,
        calibration_lr,
        key,
        generator,
        device_count,
        num_workers,
    )

    (
        mean_sweeps,
        l1_sweeps,
        auc_sweeps,
        auc_Z_sweeps,
        accuracy_sweeps,
        accuracy_Z_sweeps,
        norm_sweeps,
    ) = baseline_fn(
        state,
        dataset,
        eval_splits,
        dataset_label_noise,
        train_domains_set,
        train_batch_size,
        calibration_domains_set,
        adapt_skip_null_oracle,
        adapt_gmtl_alpha,
        generator,
        device_count,
        num_workers,
    )

    for (
        prior_strength,
        symmetric_dirichlet,
        fix_marginal,
        argmax_joint,
        batch_size,
    ) in product(
        adapt_prior_strength,
        adapt_symmetric_dirichlet,
        adapt_fix_marginal,
        test_argmax_joint,
        test_batch_size,
    ):
        adaptation = ("EM", prior_strength, symmetric_dirichlet, fix_marginal)
        state, (
            mean_sweep,
            l1_sweep,
            auc_sweep,
            auc_Z_sweep,
            accuracy_sweep,
            accuracy_Z_sweep,
            norm_sweep,
        ) = adapt_fn(
            state,
            dataset.C,
            dataset.K,
            dataset_label_noise,
            train_domains_set,
            calibration_domains_set,
            eval_splits,
            adaptation,
            argmax_joint,
            batch_size,
            device_count,
            generator,
            num_workers,
        )
        k = adaptation, argmax_joint, batch_size
        mean_sweeps[k] = mean_sweep
        l1_sweeps[k] = l1_sweep
        auc_sweeps[k] = auc_sweep
        auc_Z_sweeps[k] = auc_Z_sweep
        accuracy_sweeps[k] = accuracy_sweep
        accuracy_Z_sweeps[k] = accuracy_Z_sweep
        norm_sweeps[k] = norm_sweep

    all_sweeps = {
        "mean": (mean_sweeps, "Average probability of class 1"),
        "l1": (l1_sweeps, "Average L1 error of class 1"),
        "auc": (auc_sweeps, "AUC"),
        "auc_Z": (auc_Z_sweeps, "AUC (Z)"),
        "accuracy": (accuracy_sweeps, "Accuracy"),
        "accuracy_Z": (accuracy_Z_sweeps, "Accuracy (Z)"),
        "norm": (norm_sweeps, "Euclidean distance"),
    }
    pprint(all_sweeps)

    if npz_path.exists():
        # Close the NpzFile handle; dict(...) reads every array eagerly first.
        # NOTE: allow_pickle is needed because values are (dict, str) object
        # arrays written by this function; only load npz files we produced.
        with np.load(npz_path, allow_pickle=True) as npz:
            all_existing_sweeps = dict(npz)

        for sweep_type, (sweeps, ylabel) in all_sweeps.items():
            if sweep_type not in all_existing_sweeps:
                # Older file without this metric: just add it.
                all_existing_sweeps[sweep_type] = (sweeps, ylabel)
                continue

            existing_sweeps, existing_ylabel = all_existing_sweeps[sweep_type]
            assert ylabel == existing_ylabel

            existing_sweeps.update(sweeps)
            all_existing_sweeps[sweep_type] = existing_sweeps, existing_ylabel

        jnp.savez(npz_path, **all_existing_sweeps)

    else:
        jnp.savez(npz_path, **all_sweeps)

    return all_sweeps


def train_fn(
    dataset: MultipleDomainDataset,
    train: ConcatDataset,
    joint_train: torch.Tensor,
    calibration: ConcatDataset,
    joint_calibration: torch.Tensor,
    train_fit_joint: bool,
    train_model: str,
    train_pretrained_path: Optional[Path],
    train_batch_size: int,
    train_epochs: int,
    train_decay: float,
    train_patience: int,
    train_tau: float,
    train_lr: float,
    calibration_batch_size: int,
    calibration_epochs: int,
    calibration_decay: float,
    calibration_patience: int,
    calibration_tau: float,
    calibration_lr: float,
    key: Any,
    generator: torch.Generator,
    device_count: int,
    num_workers: int,
) -> TrainState:
    """Create a train state and train + calibrate it, or restore a cached one.

    The checkpoint prefix hashes the dataset digest, every training and
    calibration hyper-parameter and the PRNG key, so any change forces
    retraining. Training early-stops on an exponential moving average (EMA)
    of the calibration-set loss when a calibration set exists; calibration
    early-stops on an EMA of its own loss. When ``train_tau`` or
    ``calibration_tau`` is 0 the "source" prior is re-estimated with
    ``estimate_source_prior(..., "induce")`` on the calibration set.

    Returns:
        The TrainState replicated across local devices.

    Raises:
        ValueError: if the calibration set is empty but
            ``calibration_epochs > 0``.
    """
    if len(calibration) == 0 and calibration_epochs > 0:
        raise ValueError("Calibration set may not be empty")

    C, K = dataset.C, dataset.K
    key_init, key = jax.random.split(key)
    specimen = jnp.empty(dataset.input_shape)
    state = create_train_state(
        key_init,
        C,
        K,
        train_model,
        train_lr,
        specimen,
        device_count,
    )
    if train_pretrained_path is not None:
        state = restore_train_state(state, train_pretrained_path)

    # Checkpoint identity: the byte sequence hashed here must stay stable so
    # that existing checkpoints keep being found.
    m = sha256()
    m.update(dataset.hexdigest.encode())
    m.update(str(train_fit_joint).encode())
    m.update(train_model.encode())
    m.update(str(train_pretrained_path).encode())
    train_key = (
        train_batch_size,
        train_epochs,
        train_decay,
        train_patience,
        train_tau,
        train_lr,
    )
    m.update(str(train_key).encode())
    calibration_key = (
        calibration_batch_size,
        calibration_epochs,
        calibration_decay,
        calibration_patience,
        calibration_tau,
        calibration_lr,
    )
    m.update(str(calibration_key).encode())
    m.update(str(key).encode())
    hexdigest = m.hexdigest()

    prefix = f"{dataset.__class__.__name__}_{dataset.train_domain}_{train_model}_{train_tau}_{calibration_tau}_{hexdigest}_"
    restored = restore_checkpoint("checkpoints/", state, prefix=prefix)
    if restored is not state:
        print(f"Restoring checkpoint with {prefix = }")

        # HACK: backward compatibility for legacy checkpoints
        # prior = restored.prior.unfreeze()
        # print('prior["source"]', prior["source"])
        # prior["source"] = jnp.ones_like(prior["source"])
        # prior["source"] = prior["source"] / jnp.sum(prior["source"])
        # print('prior["source"]', prior["source"])
        # restored = restored.replace(prior=flax.core.frozen_dict.freeze(prior))

        return replicate(restored)
    else:
        print(f"Cannot find checkpoint with {prefix = }")

    state: TrainState = replicate(state)

    train_loader = DataLoader(
        train,
        train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        generator=generator,
    )
    # Given the check at the top, this is equivalent to "calibration is
    # non-empty"; the loader doubles as the validation set during training.
    if len(calibration) or calibration_epochs:
        calibration_loader = DataLoader(
            calibration,
            calibration_batch_size,
            shuffle=True,
            num_workers=num_workers,
            generator=generator,
        )
    else:
        calibration_loader = None

    print("===> Training")
    joint_train_jnp = replicate(jnp.asarray(joint_train.flatten().numpy()))
    epoch_loss_valid_ema = None
    min_epoch_loss_valid_ema = float("inf")
    wait = 0
    for epoch in range(train_epochs):
        epoch_loss = 0
        epoch_hit = jnp.zeros(C * K, dtype=int)
        epoch_total = jnp.zeros(C * K, dtype=int)
        for X, _, Y, Z in train_loader:
            if X.shape[0] < device_count:
                continue

            # Drop the first `remainder` rows so the batch splits evenly
            # across devices, then shard as (device_count, per_device, ...).
            remainder = X.shape[0] % device_count
            X = X[remainder:]
            Y = Y[remainder:]
            Z = Z[remainder:]

            X = jnp.array(X).reshape(device_count, -1, *X.shape[1:])
            Y = jnp.array(Y).reshape(device_count, -1, *Y.shape[1:])
            Z = jnp.array(Z).reshape(device_count, -1, *Z.shape[1:])
            M = Y * K + Z  # joint (Y, Z) label index

            state, (loss, hit, total) = train_step(
                state, X, M, K, train_fit_joint, train_tau, joint_train_jnp
            )
            epoch_loss += unreplicate(loss)
            epoch_hit += unreplicate(hit)
            epoch_total += unreplicate(total)

        if calibration_loader is None:
            # No validation data: train for the full number of epochs.
            with jnp.printoptions(precision=3):
                print(
                    f"Train epoch {epoch + 1}, loss: {epoch_loss}, hit: {epoch_hit}, total: {epoch_total}"
                )

            continue

        epoch_loss_valid = 0
        for X, _, Y, Z in calibration_loader:
            if X.shape[0] < device_count:
                continue

            remainder = X.shape[0] % device_count
            X = X[remainder:]
            Y = Y[remainder:]
            Z = Z[remainder:]

            X = jnp.array(X).reshape(device_count, -1, *X.shape[1:])
            Y = jnp.array(Y).reshape(device_count, -1, *Y.shape[1:])
            Z = jnp.array(Z).reshape(device_count, -1, *Z.shape[1:])
            M = Y * K + Z

            loss_valid = validation_step(
                state, X, M, K, train_fit_joint, train_tau, joint_train_jnp
            )
            epoch_loss_valid += unreplicate(loss_valid)

        # `train_decay` is the weight of the newest epoch in the EMA.
        if epoch_loss_valid_ema is None:
            epoch_loss_valid_ema = epoch_loss_valid
        else:
            epoch_loss_valid_ema = (
                1 - train_decay
            ) * epoch_loss_valid_ema + train_decay * epoch_loss_valid

        with jnp.printoptions(precision=3):
            print(
                f"Train epoch {epoch + 1}, loss: {epoch_loss:.2f} (val: {epoch_loss_valid:.2f}, ema: {epoch_loss_valid_ema:.2f}), hit: {epoch_hit}, total: {epoch_total}"
            )

        # Compare EMA against the best EMA seen so far (same rule as the
        # calibration loop below); comparing the raw loss against the EMA
        # minimum mixed two different quantities.
        if epoch_loss_valid_ema >= min_epoch_loss_valid_ema:
            wait += 1
        else:
            wait = 0
            min_epoch_loss_valid_ema = epoch_loss_valid_ema

        if wait > train_patience:
            print(
                f"Early stopping! {train_decay = }, {train_patience = }, {min_epoch_loss_valid_ema = }"
            )
            break

    # Sync the batch statistics across replicas so that evaluation is deterministic.
    state = state.replace(batch_stats=cross_replica_mean(state.batch_stats))

    print("===> Calibrating")
    joint_calibration_jnp = replicate(jnp.asarray(joint_calibration.flatten().numpy()))
    epoch_loss_ema = None
    min_epoch_loss_ema = float("inf")
    wait = 0
    if calibration_loader is not None:
        for epoch in range(calibration_epochs):
            epoch_loss = 0
            epoch_hit = jnp.zeros(C * K, dtype=int)
            epoch_total = jnp.zeros(C * K, dtype=int)
            for X, _, Y, Z in calibration_loader:
                if X.shape[0] < device_count:
                    continue

                remainder = X.shape[0] % device_count
                X = X[remainder:]
                Y = Y[remainder:]
                Z = Z[remainder:]

                X = jnp.array(X).reshape(device_count, -1, *X.shape[1:])
                Y = jnp.array(Y).reshape(device_count, -1, *Y.shape[1:])
                Z = jnp.array(Z).reshape(device_count, -1, *Z.shape[1:])
                M = Y * K + Z

                state, (loss, hit, total) = calibration_step(
                    state,
                    X,
                    M,
                    K,
                    train_fit_joint,
                    calibration_tau,
                    calibration_lr,
                    joint_calibration_jnp,
                )
                epoch_loss += unreplicate(loss)
                epoch_hit += unreplicate(hit)
                epoch_total += unreplicate(total)

            if epoch_loss_ema is None:
                epoch_loss_ema = epoch_loss
            else:
                epoch_loss_ema = (
                    1 - calibration_decay
                ) * epoch_loss_ema + calibration_decay * epoch_loss

            with jnp.printoptions(precision=3):
                print(
                    f"Calibration epoch {epoch + 1}, loss: {epoch_loss:.2f} (ema: {epoch_loss_ema:.2f}), hit: {epoch_hit}, total: {epoch_total}"
                )

            if epoch_loss_ema >= min_epoch_loss_ema:
                wait += 1
            else:
                wait = 0
                min_epoch_loss_ema = epoch_loss_ema

            if wait > calibration_patience:
                print(
                    f"Early stopping! {calibration_decay = }, {calibration_patience = }, {min_epoch_loss_ema = }"
                )
                break

    # Sync the batch statistics across replicas so that evaluation is deterministic.
    state = state.replace(batch_stats=cross_replica_mean(state.batch_stats))

    print("---> Temperature =", unreplicate(state.params["T"]))
    print("---> Bias =", unreplicate(state.params["b"]))

    if train_tau == 0 or calibration_tau == 0:
        # When doing logit adjustment, the source label distribution should be
        # uniform, as we effectively trained on an invariant domain. Since
        # "source" defaults to a uniform distribution, we only need to update
        # it when tau == 0.
        print("===> Estimating Source Label Prior")
        source_prior_induced = estimate_source_prior(
            calibration,
            calibration_batch_size,
            num_workers,
            generator,
            C,
            K,
            device_count,
            state,
            "induce",
        )
        source_prior_empirical = estimate_source_prior(
            train,
            train_batch_size,
            num_workers,
            generator,
            C,
            K,
            device_count,
            state,
            "count",
        )

        print("---> Induced source label prior =", source_prior_induced)
        print("---> Empirical source label prior =", source_prior_empirical)
        tvd = jnp.sum(jnp.abs(source_prior_induced - source_prior_empirical)) / 2
        print("---> Total variation distance =", tvd)

        prior = state.prior.unfreeze()
        prior["source"] = replicate(source_prior_induced)
        state = state.replace(prior=flax.core.frozen_dict.freeze(prior))

        save_checkpoint("checkpoints/", unreplicate(state), -tvd, prefix)
    else:
        save_checkpoint("checkpoints/", unreplicate(state), 0, prefix)

    return state


def estimate_source_prior(
    dataset: Dataset,
    batch_size: int,
    num_workers: int,
    generator: torch.Generator,
    C: int,
    K: int,
    device_count: int,
    state: TrainState,
    method: str,
) -> jnp.ndarray:
    """Estimate the source distribution over the joint label M = Y * K + Z.

    Args:
        method: ``"count"`` returns the empirical frequency of M in
            ``dataset`` (labels only, ``state`` unused). ``"induce"`` sums the
            per-batch outputs of ``induce_step(state, X)`` and divides by the
            number of examples used (presumably the model's predicted joint;
            see ``tta.train.induce_step``).

    Returns:
        A length ``C * K`` array that sums to 1 (for ``"count"``).

    Raises:
        ValueError: on an unknown ``method`` or when no examples were used
            (previously this silently produced a NaN prior).
    """
    loader = DataLoader(
        dataset,
        batch_size,
        shuffle=False,
        num_workers=num_workers,
        generator=generator,
    )
    if method == "count":
        source_prior = np.zeros((C * K))
        I = np.identity(C * K)
        for _, _, Y, Z in loader:
            M = Y * K + Z
            source_prior += np.sum(I[M], axis=0)  # one-hot counts per class

        total = np.sum(source_prior)
        if total == 0:
            raise ValueError("Cannot estimate the source label prior from an empty dataset")
        source_prior = jnp.array(source_prior / total)

    elif method == "induce":
        N = 0
        source_prior = jnp.zeros(C * K)
        for X, _, _, _ in loader:
            # Same guard as the training/evaluation loops: a batch smaller
            # than device_count would be emptied by the remainder trim below.
            if X.shape[0] < device_count:
                continue

            remainder = X.shape[0] % device_count
            X = X[remainder:]

            N += X.shape[0]
            X = jnp.array(X).reshape(device_count, -1, *X.shape[1:])
            source_prior = source_prior + unreplicate(induce_step(state, X))

        if N == 0:
            raise ValueError("Cannot estimate the source label prior from an empty dataset")
        source_prior = source_prior / N

    else:
        raise ValueError(f"Unknown source label prior estimation method {method}")

    return source_prior


def baseline_fn(
    state: TrainState,
    dataset: MultipleDomainDataset,
    eval_splits: List[Tuple[Dataset, torch.Tensor]],
    dataset_label_noise: float,
    train_domains_set: Set[int],
    train_batch_size: int,
    calibration_domains_set: Set[int],
    adapt_skip_null_oracle: bool,
    adapt_gmtl_alpha: Sequence[float],
    generator: torch.Generator,
    device_count: int,
    num_workers: int,
):
    """Evaluate the data-independent adaptation schemes.

    Runs ``("Null",)`` and ``("Oracle",)`` (unless ``adapt_skip_null_oracle``)
    plus ``("GMTL", alpha)`` for every alpha, always with
    ``argmax_joint=False`` and ``batch_size=train_batch_size``.

    Returns:
        Seven dicts (mean, l1, auc, auc_Z, accuracy, accuracy_Z, norm), each
        keyed by ``(adaptation, argmax_joint, batch_size)``.
    """
    print("===> Adapting & Evaluating")

    mean_sweeps = {}
    l1_sweeps = {}
    auc_sweeps = {}
    auc_Z_sweeps = {}
    accuracy_sweeps = {}
    accuracy_Z_sweeps = {}
    norm_sweeps = {}

    adaptations: List[Adaptation]
    if adapt_skip_null_oracle:
        adaptations = []
    else:
        adaptations = [("Null",), ("Oracle",)]

    adaptations.extend(("GMTL", alpha) for alpha in adapt_gmtl_alpha)
    for adaptation in adaptations:
        argmax_joint = False
        batch_size = train_batch_size  # batch size does not matter since we are not adapting on data
        state, (mean, l1, auc, auc_Z, accuracy, accuracy_Z, norm) = adapt_fn(
            state,
            dataset.C,
            dataset.K,
            dataset_label_noise,
            train_domains_set,
            calibration_domains_set,
            eval_splits,
            adaptation,
            argmax_joint,
            batch_size,
            device_count,
            generator,
            num_workers,
        )
        k = adaptation, argmax_joint, batch_size
        mean_sweeps[k] = mean
        l1_sweeps[k] = l1
        auc_sweeps[k] = auc
        auc_Z_sweeps[k] = auc_Z
        accuracy_sweeps[k] = accuracy
        accuracy_Z_sweeps[k] = accuracy_Z
        norm_sweeps[k] = norm

    return (
        mean_sweeps,
        l1_sweeps,
        auc_sweeps,
        auc_Z_sweeps,
        accuracy_sweeps,
        accuracy_Z_sweeps,
        norm_sweeps,
    )


def adapt_fn(
    state: TrainState,
    C: int,
    K: int,
    dataset_label_noise: float,
    train_domains_set: Set[int],
    calibration_domains_set: Set[int],
    eval_splits: Sequence[Tuple[Dataset, torch.Tensor]],
    adaptation: Adaptation,
    argmax_joint: bool,
    batch_size: int,
    device_count: int,
    generator: torch.Generator,
    num_workers: int,
) -> Tuple[TrainState, Sweeps]:
    """Set the target prior per ``adaptation`` and evaluate every split.

    Adaptation schemes:
        * ``("Null",)``: target prior = source prior.
        * ``("Oracle",)``: target prior = the split's true joint ``joint_M``.
        * ``("GMTL", alpha)``: target prior proportional to
          ``source ** (1 - alpha)``.
        * ``("EM", prior_strength, symmetric_dirichlet, fix_marginal)``:
          target prior updated by ``adapt_step`` on each test batch.

    For each split, computes the mean score, the L1 error against
    P(Y=1 | Y_tilde, Z) under symmetric label noise, AUC and accuracy for Y
    and Z, and the Euclidean distance between the target prior and
    ``joint_M``. Empty splits get NaN.

    Returns:
        ``(state, (mean, l1, auc, auc_Z, accuracy, accuracy_Z, norm))``; each
        array has one entry per split, the last being the training split.

    Raises:
        ValueError: on an unknown adaptation scheme.
    """
    label = f"{adaptation = }, {argmax_joint = }, {batch_size = }"
    print(f"---> {label}")

    mean_sweep = jnp.empty(len(eval_splits))
    l1_sweep = jnp.empty(len(eval_splits))
    auc_sweep = jnp.empty(len(eval_splits))
    auc_Z_sweep = jnp.empty(len(eval_splits))
    accuracy_sweep = jnp.empty(len(eval_splits))
    accuracy_Z_sweep = jnp.empty(len(eval_splits))
    norm_sweep = jnp.empty(len(eval_splits))
    for i, (eval_, joint_M) in enumerate(eval_splits):
        # happens on the source domain when train_fraction = 1.0
        if len(eval_) == 0:
            mean_sweep = mean_sweep.at[i].set(jnp.nan)
            l1_sweep = l1_sweep.at[i].set(jnp.nan)
            auc_sweep = auc_sweep.at[i].set(jnp.nan)
            auc_Z_sweep = auc_Z_sweep.at[i].set(jnp.nan)
            accuracy_sweep = accuracy_sweep.at[i].set(jnp.nan)
            accuracy_Z_sweep = accuracy_Z_sweep.at[i].set(jnp.nan)
            norm_sweep = norm_sweep.at[i].set(jnp.nan)
            continue

        if i in train_domains_set.union(calibration_domains_set):
            seen = " (seen)"
        elif i == len(eval_splits) - 1:
            seen = " (train)"
        else:
            seen = "(unseen)"

        joint_M = jnp.array(joint_M)
        # Symmetric label-noise channel P(Y_tilde | Y); assumes C == 2, which
        # prepare_dataset enforces.
        flip_prob = jnp.array(
            [
                [1 - dataset_label_noise, dataset_label_noise],
                [dataset_label_noise, 1 - dataset_label_noise],
            ]
        )
        joint = flip_prob[:, :, jnp.newaxis] * joint_M  # P(Y_tilde, Y, Z)
        prob = joint / jnp.sum(joint, axis=1, keepdims=True)
        prob = prob[:, 1, :]  # P(Y=1|Y_tilde, Z)

        # using shuffle=True so that Y contains multiple classes, otherwise AUC is not defined
        mean = l1 = hits = hits_Z = norm = 0
        capacity = len(eval_) // device_count * device_count
        epoch_Y = jnp.empty(capacity)
        epoch_score = jnp.empty(capacity)
        epoch_Z = jnp.empty(capacity)
        epoch_score_Z = jnp.empty(capacity)
        offset = 0

        eval_loader = DataLoader(
            eval_,
            batch_size,
            shuffle=True,
            num_workers=num_workers,
            generator=generator,
        )
        for X, Y_tilde, Y, Z in eval_loader:
            if X.shape[0] < device_count:
                continue

            remainder = X.shape[0] % device_count
            X = X[remainder:]
            Y_tilde = Y_tilde[remainder:]
            Y = Y[remainder:]
            Z = Z[remainder:]

            N = X.shape[0]
            X = jnp.array(X).reshape(device_count, -1, *X.shape[1:])
            Y_tilde = jnp.array(Y_tilde)  # kept flat; indexes `prob` below
            Y = jnp.array(Y).reshape(device_count, -1, *Y.shape[1:])
            Z = jnp.array(Z).reshape(device_count, -1, *Z.shape[1:])

            epoch_Y = epoch_Y.at[offset : offset + N].set(Y.flatten())
            epoch_Z = epoch_Z.at[offset : offset + N].set(Z.flatten())

            if adaptation[0] == "Null":
                prior = state.prior.unfreeze()
                prior["target"] = prior["source"]
                state = state.replace(prior=flax.core.frozen_dict.freeze(prior))
            elif adaptation[0] == "Oracle":
                prior = state.prior.unfreeze()
                prior["target"] = replicate(joint_M.flatten())
                state = state.replace(prior=flax.core.frozen_dict.freeze(prior))
            elif adaptation[0] == "GMTL":
                _, alpha = adaptation
                prior = state.prior.unfreeze()
                target = prior["source"] ** (1 - alpha)
                # Normalize along the last axis so the tempered prior sums to
                # 1 on each replica. (jnp.sum(-1, ...) would just be -1.)
                target = target / jnp.sum(target, axis=-1, keepdims=True)
                prior["target"] = target
                state = state.replace(prior=flax.core.frozen_dict.freeze(prior))
            elif adaptation[0] == "EM":
                _, prior_strength, symmetric_dirichlet, fix_marginal = adaptation
                state = adapt_step(
                    state,
                    X,
                    replicate(prior_strength),
                    symmetric_dirichlet,
                    fix_marginal,
                    C,
                    K,
                )
            else:
                raise ValueError(f"Unknown adaptation scheme {adaptation}")

            (score, hit), (score_Z, hit_Z) = test_step(state, X, Y, Z, argmax_joint)

            mean += jnp.sum(score)
            l1 += jnp.sum(jnp.abs(score.flatten() - prob[Y_tilde, Z.flatten()]))
            epoch_score = epoch_score.at[offset : offset + N].set(score.flatten())
            epoch_score_Z = epoch_score_Z.at[offset : offset + N].set(score_Z.flatten())
            hits += unreplicate(hit)
            hits_Z += unreplicate(hit_Z)
            prior = unreplicate(state.prior["target"]).reshape((C, K))
            # Weight by batch size so the division below yields a per-example mean.
            norm += N * jnp.linalg.norm(prior - joint_M)

            offset += N

        mean = mean / len(eval_)
        l1 = l1 / len(eval_)
        # Only the first `offset` slots were filled; never score placeholders.
        auc = roc_auc_score(epoch_Y[:offset], epoch_score[:offset])
        auc_Z = roc_auc_score(epoch_Z[:offset], epoch_score_Z[:offset])
        accuracy = hits / len(eval_)
        accuracy_Z = hits_Z / len(eval_)
        norm = norm / len(eval_)

        with jnp.printoptions(precision=4):
            print(
                f"[{label}] Environment {i:>2} {seen} mean {mean}, L1 {l1}, AUC {auc} ({auc_Z}), Accuracy {accuracy} ({accuracy_Z}), Norm {norm}"
            )

        # note that foo_sweep.at[-1] is the training foo
        mean_sweep = mean_sweep.at[i].set(mean)
        l1_sweep = l1_sweep.at[i].set(l1)
        auc_sweep = auc_sweep.at[i].set(auc)
        auc_Z_sweep = auc_Z_sweep.at[i].set(auc_Z)
        accuracy_sweep = accuracy_sweep.at[i].set(accuracy)
        accuracy_Z_sweep = accuracy_Z_sweep.at[i].set(accuracy_Z)
        norm_sweep = norm_sweep.at[i].set(norm)

    # Averages exclude the trailing training split.
    print(
        f"[{label}] Average response {jnp.nanmean(mean_sweep[:-1])}, "
        f"Average L1 {jnp.nanmean(l1_sweep[:-1])}, "
        f"Average AUC {jnp.nanmean(auc_sweep[:-1])} ({jnp.nanmean(auc_Z_sweep[:-1])}), "
        f"Accuracy {jnp.nanmean(accuracy_sweep[:-1])} ({jnp.nanmean(accuracy_Z_sweep[:-1])}), "
        f"Norm {jnp.nanmean(norm_sweep[:-1])}"
    )

    return state, (
        mean_sweep,
        l1_sweep,
        auc_sweep,
        auc_Z_sweep,
        accuracy_sweep,
        accuracy_Z_sweep,
        norm_sweep,
    )


if __name__ == "__main__":
    initialize_cache("jit_cache")
    latexify(width_scale_factor=2, fig_height=2)
    cli()