├── .gitignore
├── PACE
│   ├── src
│   │   ├── eval_PACE.sh
│   │   ├── train_PACE.sh
│   │   ├── train_ViT.sh
│   │   ├── generate_data.py
│   │   ├── config.py
│   │   ├── evaluate.py
│   │   ├── augment.py
│   │   ├── main.py
│   │   ├── run_quantitative_theta.py
│   │   ├── model.py
│   │   └── utils.py
│   └── environment_PACE.yml
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | PACE/dataset
2 | PACE/results
3 | PACE/src/__pycache__
4 | PACE/ckpt
--------------------------------------------------------------------------------
/PACE/src/eval_PACE.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python run_quantitative_theta.py --task Color --name ViT-PACE --num_epochs 1 &
2 | CUDA_VISIBLE_DEVICES=1 python run_quantitative_theta.py --task flower102 --name ViT-PACE --num_epochs 1 &
3 | CUDA_VISIBLE_DEVICES=2 python run_quantitative_theta.py --task cub2011 --name ViT-PACE --num_epochs 1 &
4 | CUDA_VISIBLE_DEVICES=3 python run_quantitative_theta.py --task cars --name ViT-PACE --num_epochs 1
--------------------------------------------------------------------------------
/PACE/src/train_PACE.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main.py --train --task Color --name ViT-PACE --num_epochs 1 --pretrain_epoch 5 &
2 | CUDA_VISIBLE_DEVICES=1 python main.py --train --task flower102 --name ViT-PACE --num_epochs 1 --pretrain_epoch 10 &
3 | CUDA_VISIBLE_DEVICES=0 python main.py --train --task cub2011 --name ViT-PACE --num_epochs 1 --pretrain_epoch 10 &
4 | CUDA_VISIBLE_DEVICES=1 python main.py --train --task cars --name ViT-PACE --num_epochs 1 --pretrain_epoch 20
--------------------------------------------------------------------------------
/PACE/src/train_ViT.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python main.py --train --task Color --name ViT-base --num_epochs 5 --lr 1e-3 --require_grad
2 | CUDA_VISIBLE_DEVICES=1 python main.py --train --task flower102 --name ViT-base --num_epochs 10 --lr 1e-3 --require_grad
3 | CUDA_VISIBLE_DEVICES=1 python main.py --train --task cub2011 --name ViT-base --num_epochs 10 --lr 1e-3 --require_grad
4 | CUDA_VISIBLE_DEVICES=1 python main.py --train --task cars --name ViT-base --num_epochs 20 --weight_decay 0.01 --lr 5e-5 --require_grad --seed 2025 --train_batch_size 32 --eval_batch_size 64
5 | 
6 | 
--------------------------------------------------------------------------------
/PACE/src/generate_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import os
4 | from skimage.transform import resize
5 | from skimage.util import random_noise
6 | import random
7 | 
8 | # Create directories for saving the images
9 | if not os.path.exists('../dataset/Color/class0'):
10 |     os.makedirs('../dataset/Color/class0')
11 | 
12 | if not os.path.exists('../dataset/Color/class1'):
13 |     os.makedirs('../dataset/Color/class1')
14 | 
15 | # Image dimensions
16 | original_size = (2, 2)
17 | target_size = (224, 224)
18 | 
19 | # Number of images to generate for each class
20 | num_images = 1000
21 | 
22 | # Color codes for red, yellow, blue, green, and black in RGB
23 | color_dict = {'red': [1, 0, 0], 'yellow': [1, 1, 0], 'blue': [0, 0, 1], 'green': [0, 1, 0], 'black': [0, 0, 0]}
24 | 
25 | # Function to create an image
26 | def create_image(colors, num_black):
27 |     # Start with an all-black 2x2 image, then color 4 - num_black of its cells
28 |     image = np.zeros((*original_size, 3))
29 |     colored_locs = random.sample([(i, j) for i in range(2) for j in range(2)], 4 - num_black)
30 |     for idx, loc in enumerate(colored_locs):
31 |         image[loc] = colors[idx % 2]
32 | 
33 |     return image
34 | 
35 | # Generate images
36 | for i in range(num_images):
37 |     # Class 0: red+yellow
38 |     num_black = random.randint(0, 2)
39 |     image0 = create_image([color_dict['red'], color_dict['yellow']], num_black)
40 |     image0 = random_noise(image0, mode='gaussian')
41 |     image0_resized = resize(image0, target_size)
42 |     plt.imsave(f'../dataset/Color/class0/{i}.png', image0_resized)
43 | 
44 |     # Class 1: blue+green
45 |     num_black = random.randint(0, 2)
46 |     image1 = create_image([color_dict['blue'], color_dict['green']], num_black)
47 |     image1 = random_noise(image1, mode='gaussian')
48 |     image1_resized = resize(image1, target_size)
49 |     plt.imsave(f'../dataset/Color/class1/{i}.png', image1_resized)
--------------------------------------------------------------------------------
/PACE/src/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser(description='PACE')
3 | 
4 | # data args
5 | parser.add_argument('--data_path', type=str, help='path of dataset',
6 |                     default='../dataset')
7 | parser.add_argument('--task', type=str, help='task name of dataset', default='toy')
8 | parser.add_argument('--save_path', type=str, help='path to save',
9 |                     default='../ckpt')
10 | parser.add_argument('--load_path', type=str, help='path to load',
11 |                     default='../ckpt')
12 | 
13 | # model args
14 | 
15 | ## for PACE
16 | parser.add_argument('--c_dim', type=int, help='dimension of the PACE concept space', default=25)
17 | parser.add_argument('--K', type=int, help='number of concept centers of PACE', default=100)
18 | parser.add_argument('--D', type=int, help='number of images', default=10000)
19 | parser.add_argument('--N', type=int, help='max number of tokens per image (196 patches + [CLS])', default=197)
20 | parser.add_argument('--alpha', type=float, help='alpha prior of PACE', default=2)
21 | parser.add_argument('--eta', type=float, help='weight of PACE', default=1)
22 | parser.add_argument('--frac', type=str, help='type of fractional model', default='fix')
23 | parser.add_argument('--layer', type=int, help='layer of PACE', default=-2)
24 | parser.add_argument('--version', type=str, help='running version', default='v0')
25 | 
26 | ## for ViT
27 | parser.add_argument('--lm', type=str, help='which backbone model', default='ViT')
28 | parser.add_argument('--b_dim', type=int, help='dimension of ViT', default=768)
29 | parser.add_argument('--out_dim', type=int, help='dimension of output', default=5)
30 | parser.add_argument('--name', type=str, help='model name', default='ViT-PACE')
31 | parser.add_argument('--seed', type=int, default=2021)
32 | parser.add_argument('--lr', type=float, help='learning rate', default=3e-5)
33 | parser.add_argument('--weight_decay', type=float, help='weight decay', default=0.05)
34 | parser.add_argument('--train', action='store_true', default=False)
35 | parser.add_argument('--require_grad', action='store_true', default=False)
36 | parser.add_argument('--pretrain_epoch', type=str, help='loading pretrain epoch', default='5')
37 | 
38 | # optimization args
39 | parser.add_argument('--num_epochs', type=int, help='number of epochs', default=10)
40 | parser.add_argument('--train_batch_size', type=int, help='training batch size', default=16)
41 | parser.add_argument('--eval_batch_size', type=int, help='eval batch size', default=64)
42 | parser.add_argument('--metric', type=str, help='eval metric', default='eval_accuracy')
43 | 
44 | 
45 | # for quantitative theta
46 | parser.add_argument('--load_concepts', action='store_true', default=False,
47 |                     help='load concepts and probabilities from saved path instead of inferring them')
48 | 
49 | 
50 | 
51 | 
--------------------------------------------------------------------------------
/PACE/src/evaluate.py:
--------------------------------------------------------------------------------
1 | # import metric packages
2 | import numpy as np
3 | import pandas as pd
4 | from numpy.linalg import *
5 | from scipy.linalg import sqrtm
6 | from scipy.special import gammaln, psi
7 | import torch
8 | from sklearn.decomposition import PCA
9 | import matplotlib.pyplot as plt
10 | from datasets import load_metric
11 | from sklearn.linear_model import LogisticRegression, LinearRegression
12 | from utils import *
13 | from sklearn.pipeline import make_pipeline
14 | from sklearn import preprocessing
15 | from scipy.stats import entropy
16 | from sklearn.neural_network import MLPClassifier
17 | from sklearn.preprocessing import StandardScaler
18 | from sklearn.preprocessing import Normalizer
19 | 
20 | 
21 | def to_numpy(x):
22 |     if isinstance(x, torch.Tensor):
23 |         return x.detach().cpu().numpy()
24 |     else:
25 |         return x
26 | 
27 | 
28 | def faithfulness(concept_train, pred_train, concept_test, pred_test, hard=False, prob_test=None):
29 |     # faithfulness with MLP classifier
30 |     concept_train = to_numpy(concept_train)
31 |     pred_train = to_numpy(pred_train)
32 |     concept_test = to_numpy(concept_test)
33 |     pred_test = to_numpy(pred_test)
34 |     # NOTE: two leftover debug lines here used to overwrite the train split with the
35 |     # test split, so the MLP probe was fit on test data; the probe now fits on train.
36 |     pipe = make_pipeline(
37 |         StandardScaler(),
38 |         MLPClassifier(hidden_layer_sizes=(100,), activation='relu', solver='adam', max_iter=200)
39 |     )
40 | 
41 |     clf = pipe.fit(concept_train, pred_train)
42 |     score = clf.score(concept_test, pred_test)
43 | 
44 |     return score, -1  # no soft (KL-based) score is defined for the MLP probe
45 | 
46 | 
47 | def faithfulness_linear(concept_train, pred_train, concept_test, pred_test, hard=False, prob_test=None):
48 |     # faithfulness with linear classifier
49 |     concept_train = to_numpy(concept_train)
50 |     pred_train = to_numpy(pred_train)
51 |     concept_test = to_numpy(concept_test)
52 |     pred_test = to_numpy(pred_test)
53 |     # (unused range statistics on concept_train were removed here; they had no effect)
54 | 
55 | 
56 | 
57 | 
58 |     pipe = make_pipeline(preprocessing.StandardScaler(), LogisticRegression(random_state=0, max_iter=2500))
59 |     clf = pipe.fit(concept_train, pred_train)
60 |     hard_score = clf.score(concept_test, pred_test)
61 |     prob = clf.predict_proba(concept_test)
62 |     soft_score = np.mean([entropy(prob[i], prob_test[i]) for i in range(len(prob))])
63 | 
64 |     return hard_score, soft_score
65 | 
66 | 
67 | def stability(concept_orig, concept_aug, compute=True):
68 |     assert len(concept_orig.shape) == 2
69 |     assert concept_orig.shape == concept_aug.shape
70 |     delta = np.linalg.norm(concept_orig - concept_aug, axis=1) / np.linalg.norm(concept_orig, axis=1)
71 | 
72 |     return np.mean(delta)
73 | 
74 | def sparsity(concept):
75 |     assert len(concept.shape) == 2
76 |     eps = 0.1 / concept.shape[1]
77 |     return np.mean(concept < eps)
78 | 
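# --- Added usage sketch (not in the original file) ---------------------------
# A minimal, self-contained illustration of how the metrics in this file are
# combined in run_quantitative_theta.py, using random dummy data; all shapes,
# names, and values below are illustrative assumptions only.
def _demo_metrics():
    rng = np.random.default_rng(0)
    n_train, n_test, n_concepts, n_classes = 200, 100, 25, 5
    concept_train = rng.dirichlet(np.ones(n_concepts), size=n_train)  # image-level concepts
    concept_test = rng.dirichlet(np.ones(n_concepts), size=n_test)
    concept_aug = concept_test + 0.01 * rng.standard_normal(concept_test.shape)  # augmented view
    pred_train = rng.integers(0, n_classes, size=n_train)             # predicted labels
    pred_test = rng.integers(0, n_classes, size=n_test)
    prob_test = rng.dirichlet(np.ones(n_classes), size=n_test)        # predicted probabilities
    hard, soft = faithfulness_linear(concept_train, pred_train, concept_test, pred_test,
                                     prob_test=prob_test)
    print('faithfulness (hard/soft):', hard, soft)
    print('stability:', stability(concept_test, concept_aug))
    print('sparsity:', sparsity(concept_test))
    print('parsimony:', parsimony(concept_test))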
79 | def parsimony(concept): 80 | assert len(concept.shape) == 2 81 | return concept.shape[1] 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Probabilistic Conceptual Explainers
for Foundation Models
2 | This repo contains the code and data for our PACE paper (ICML 2024):
3 | 
4 | **Probabilistic Conceptual Explainers: Trustworthy Conceptual Explanations for Vision Foundation Models**
5 | Hengyi Wang*, Shiwei Tan*, Hao Wang
6 | [[Paper](http://www.wanghao.in/paper/ICML24_PACE.pdf)] [[ICML Website](https://icml.cc/virtual/2024/poster/34650)]
7 | 
8 | and our VALC paper (EMNLP 2024 Findings):
9 | 
10 | **Variational Language Concepts for Interpreting Foundation Language Models**
11 | Hengyi Wang, Shiwei Tan, Zhiqing Hong, Desheng Zhang, Hao Wang
12 | [[Paper](http://www.wanghao.in/paper/EMNLP24_VALC.pdf)] [[ACL Website](https://aclanthology.org/2024.findings-emnlp.505/)]
13 | 
14 | ## Brief Introduction for PACE
15 | We propose five desiderata for explaining vision foundation models like ViTs (faithfulness, stability, sparsity, multi-level structure, and parsimony) and demonstrate that current methods fail to meet these criteria comprehensively. Rather than using sparse autoencoders (SAEs), we introduce a variational Bayesian explanation framework, dubbed ProbAbilistic Concept Explainers (PACE), which models the distributions of patch embeddings to provide trustworthy post-hoc conceptual explanations. PACE provides dataset-, image-, and patch-level explanations for ViTs and achieves all five desiderata in a unified framework.
16 | 
17 | ## Probabilistic Conceptual Explainers (PACE) for Vision Transformers (ViTs)
18 | PACE is compatible with *arbitrary* vision transformers.
19 | 
20 | Below are some sample concepts automatically discovered by our PACE, *without the need for concept annotation during training*.
21 | 
22 | ![More_Random_Samples_Color](https://github.com/user-attachments/assets/f39aa0c6-3427-428e-ada9-aa9880d0ca09)
23 | 
24 | **Figure 1.** Above are some sample concepts discovered by PACE in the *COLOR* dataset. See Figure 3 of [our paper](http://wanghao.in/paper/ICML24_PACE.pdf) for details on the *COLOR* dataset.
25 | 
26 | ![More_Random_Samples_Flower](https://github.com/user-attachments/assets/80bd9dcf-2514-49ca-a659-6b101d423044)
27 | 
28 | **Figure 2.** Above are some sample concepts discovered by PACE in the *Oxford Flower* dataset.
29 | 
30 | 
31 | 
32 | ### Installation
33 | ```bash
34 | conda env create -f environment_PACE.yml
35 | conda activate PACE
36 | cd src
37 | ```
38 | 
39 | ### Generate the *Color* Dataset
40 | ```bash
41 | python generate_data.py
42 | ```
43 | ### Finetune ViT on the *Color* Dataset and Real-World Datasets
44 | ```bash
45 | bash ./train_ViT.sh
46 | ```
47 | 
48 | ### Train PACE for Each Dataset
49 | ```bash
50 | bash ./train_PACE.sh
51 | ```
52 | 
53 | ### Test PACE for Each Dataset
54 | ```bash
55 | bash ./eval_PACE.sh
56 | ```
57 | (See the example sketch below for a quick way to inspect the saved PACE outputs.)
58 | 
59 | ## Variational Language Concepts (VALC) for Pretrained Language Models
60 | 
61 | Coming Soon!
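### Inspecting Saved PACE Outputs (Example)

After running `train_PACE.sh` and `eval_PACE.sh`, the learned concept parameters and the inferred image-level concepts are saved as `.npy` files. Below is a minimal sketch for loading and inspecting them; the task name, epoch number, and relative paths here are assumptions based on the defaults in `config.py` and `train_PACE.sh`:

```python
import numpy as np

task, epoch = 'Color', 1     # assumed: matches the defaults in train_PACE.sh
ckpt = 'PACE/ckpt/ViT-PACE'  # assumed: default --save_path ('../ckpt') joined with --name

mus = np.load(f'{ckpt}/{task}_mus-epoch{epoch}.npy')        # (K, c_dim) concept means
sigmas = np.load(f'{ckpt}/{task}_sigmas-epoch{epoch}.npy')  # (K, c_dim, c_dim) concept covariances
theta = np.load(f'{ckpt}/{task}_epoch{epoch}-concept_test.npy')  # image-level concept scores

print('concept means:', mus.shape, '| covariances:', sigmas.shape, '| test thetas:', theta.shape)
print('top-5 concepts for test image 0:', np.argsort(theta[0])[::-1][:5])
```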
62 | 
63 | ## Reference
64 | 
65 | ```bib
66 | @inproceedings{PACE,
67 |   title={Probabilistic Conceptual Explainers: Trustworthy Conceptual Explanations for Vision Foundation Models},
68 |   author={Hengyi Wang and
69 |           Shiwei Tan and
70 |           Hao Wang},
71 |   booktitle={International Conference on Machine Learning},
72 |   year={2024}
73 | }
74 | 
75 | @inproceedings{VALC,
76 |   title={Variational Language Concepts for Interpreting Foundation Language Models},
77 |   author={Hengyi Wang and
78 |           Shiwei Tan and
79 |           Zhiqing Hong and
80 |           Desheng Zhang and
81 |           Hao Wang},
82 |   booktitle={Findings of the Association for Computational Linguistics: EMNLP 2024},
83 |   year={2024}
84 | }
85 | ```
86 | 
--------------------------------------------------------------------------------
/PACE/src/augment.py:
--------------------------------------------------------------------------------
1 | # (removed an unused `from tkinter import image_types` auto-import; it was never used and fails on headless machines without tkinter)
2 | import numpy as np
3 | import pandas as pd
4 | from numpy.linalg import *
5 | from scipy.linalg import sqrtm
6 | from scipy.special import gammaln, psi
7 | import torch
8 | from sklearn.decomposition import PCA
9 | import matplotlib.pyplot as plt
10 | from datasets import load_metric
11 | from transformers import ViTFeatureExtractor
12 | from torchvision.transforms import (CenterCrop,
13 |                                     Compose,
14 |                                     Normalize,
15 |                                     RandomHorizontalFlip,
16 |                                     RandomResizedCrop,
17 |                                     Resize,
18 |                                     ToTensor)
19 | from torch import FloatTensor, div
20 | from torch.utils.data import DataLoader, Dataset
21 | from torchvision import transforms
22 | from torchvision.transforms.functional import InterpolationMode
23 | 
24 | import torch.nn.functional as F
25 | import os
26 | import pandas as pd
27 | from torchvision.datasets.folder import default_loader
28 | from torchvision.datasets.utils import download_url
29 | from torch.utils.data import Dataset
30 | import pickle
31 | from sklearn.cluster import KMeans
32 | from sklearn.linear_model import LogisticRegression
33 | from utils import *
34 | 
35 | 
36 | 
37 | 
38 | def cov_z(phi):
39 |     ret = np.empty((phi.shape[0], phi.shape[0]))
40 |     for i in range(phi.shape[0]):
41 |         for j in range(phi.shape[0]):
42 |             if i == j:
43 |                 ret[i,j] = phi[i]*(1-phi[i])
44 |             else:
45 |                 ret[i,j] = - phi[i]*phi[j]
46 |     return ret
47 | 
48 | 
49 | 
50 | 
51 | 
52 | def contrastive_learning(e, e_prime, temperature=0.1):
53 |     """
54 |     Args:
55 |         e (torch.Tensor): A tensor of shape (B, d) representing the original embeddings.
56 |         e_prime (torch.Tensor): A tensor of shape (B, d) representing the transformed embeddings.
57 |         temperature (float): A temperature hyperparameter to scale the logits.
58 | 
59 |     Returns:
60 |         loss (torch.Tensor): The contrastive loss value.
61 | """ 62 | # Normalize the embeddings and transformations to have unit length 63 | e = e.float() 64 | e_prime = e_prime.float() 65 | e.requires_grad_(True) 66 | 67 | e = F.normalize(e, dim=-1) 68 | e_prime = F.normalize(e_prime, dim=-1) 69 | 70 | # Compute dot product similarity between e and e' (positive pairs) 71 | positive_similarity = torch.sum(e * e_prime, dim=-1) 72 | 73 | # Compute dot product similarity between e and all other e' (negative pairs) 74 | negative_similarity = torch.mm(e, e_prime.t()) 75 | 76 | # Remove the similarity of positive pairs from negative_similarity 77 | diagonal_indices = torch.arange(e.shape[0]) 78 | negative_similarity[diagonal_indices, diagonal_indices] = float('-inf') 79 | 80 | # Compute the logits for positive and negative pairs 81 | logits = torch.cat([positive_similarity.unsqueeze(1), negative_similarity], dim=1) 82 | 83 | # Apply temperature scaling to the logits 84 | logits /= temperature 85 | 86 | # Create labels for the positive pairs (the first column in logits) 87 | labels = torch.zeros(e.shape[0], dtype=torch.long, device=e.device) 88 | 89 | # gradient of e, not use autograd 90 | grad_e = torch.zeros_like(e) 91 | grad_e_prime = torch.zeros_like(e_prime) 92 | # Compute cross entropy loss for each sample 93 | loss = F.cross_entropy(logits, labels) 94 | # Compute gradients for e and e_prime 95 | grad_e = torch.autograd.grad(loss, e, retain_graph=True)[0] 96 | 97 | return grad_e, loss 98 | 99 | def image_augment(image): 100 | b, c, h, w = image.shape 101 | contrast_transforms = transforms.Compose([transforms.RandomHorizontalFlip(), 102 | transforms.RandomResizedCrop(size=w), 103 | transforms.RandomApply([ 104 | transforms.ColorJitter(brightness=0.5, 105 | contrast=0.5, 106 | saturation=0.5, 107 | hue=0.1) 108 | ], p=0.8), 109 | transforms.RandomGrayscale(p=0.2), 110 | transforms.GaussianBlur(kernel_size=9), 111 | #transforms.ToTensor() 112 | transforms.Normalize((0.5,), (0.5,)) 113 | ]) 114 | image_trans = contrast_transforms(image).cuda() 115 | return image_trans 116 | 117 | def contrative_transform(image, model, topic_model, temperature=0.1): 118 | # image: input image batch (B, C, H, W) 119 | # model: ViT model 120 | contrast_transforms = transforms.Compose([transforms.RandomHorizontalFlip(), 121 | transforms.RandomResizedCrop(size=96), 122 | transforms.RandomApply([ 123 | transforms.ColorJitter(brightness=0.5, 124 | contrast=0.5, 125 | saturation=0.5, 126 | hue=0.1) 127 | ], p=0.8), 128 | transforms.RandomGrayscale(p=0.2), 129 | transforms.GaussianBlur(kernel_size=9), 130 | transforms.ToTensor(), 131 | transforms.Normalize((0.5,), (0.5,)) 132 | ]) 133 | image_trans = contrast_transforms(image) 134 | embed = model(image) 135 | embed_trans = model(image_trans) 136 | 137 | return embed, embed_trans 138 | -------------------------------------------------------------------------------- /PACE/environment_PACE.yml: -------------------------------------------------------------------------------- 1 | name: PACE 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - _tflow_select=2.3.0=mkl 8 | - absl-py=1.3.0=py37h06a4308_0 9 | - astor=0.8.1=py37h06a4308_0 10 | - astunparse=1.6.3=py_0 11 | - asynctest=0.13.0=py_0 12 | - backcall=0.2.0=pyhd3eb1b0_0 13 | - blas=1.0=mkl 14 | - blinker=1.4=py37h06a4308_0 15 | - brotli=1.0.9=h5eee18b_7 16 | - brotli-bin=1.0.9=h5eee18b_7 17 | - brotlipy=0.7.0=py37h27cfd23_1003 18 | - c-ares=1.19.1=h5eee18b_0 19 | - ca-certificates=2023.08.22=h06a4308_0 20 | - 
cachetools=4.2.2=pyhd3eb1b0_0 21 | - certifi=2022.12.7=py37h06a4308_0 22 | - cffi=1.15.0=py37h7f8727e_0 23 | - cryptography=39.0.1=py37h9ce1e76_0 24 | - cudatoolkit=10.1.243=h6bb024c_0 25 | - cycler=0.11.0=pyhd3eb1b0_0 26 | - dbus=1.13.18=hb2f20db_0 27 | - decorator=5.1.1=pyhd3eb1b0_0 28 | - entrypoints=0.4=py37h06a4308_0 29 | - expat=2.5.0=h6a678d5_0 30 | - fftw=3.3.9=h27cfd23_1 31 | - fontconfig=2.14.1=h52c9d5c_1 32 | - fonttools=4.25.0=pyhd3eb1b0_0 33 | - freetype=2.12.1=h4a9f257_0 34 | - frozenlist=1.3.3=py37h5eee18b_0 35 | - gast=0.4.0=pyhd3eb1b0_0 36 | - giflib=5.2.1=h5eee18b_3 37 | - glib=2.63.1=h5a9c865_0 38 | - google-auth=2.6.0=pyhd3eb1b0_0 39 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0 40 | - google-pasta=0.2.0=pyhd3eb1b0_0 41 | - grpcio=1.42.0=py37hce63b2e_0 42 | - gst-plugins-base=1.14.0=hbbd80ab_1 43 | - gstreamer=1.14.0=hb453b48_1 44 | - h5py=2.10.0=py37hd6299e0_1 45 | - hdf5=1.10.6=h3ffc7dd_1 46 | - icu=58.2=he6710b0_3 47 | - idna=3.4=py37h06a4308_0 48 | - importlib_metadata=4.11.3=hd3eb1b0_0 49 | - intel-openmp=2021.4.0=h06a4308_3561 50 | - joblib=1.1.1=py37h06a4308_0 51 | - jpeg=9e=h5eee18b_1 52 | - jupyter_client=7.4.9=py37h06a4308_0 53 | - jupyter_core=4.11.2=py37h06a4308_0 54 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0 55 | - kiwisolver=1.4.4=py37h6a678d5_0 56 | - lcms2=2.12=h3be6417_0 57 | - lerc=3.0=h295c915_0 58 | - libbrotlicommon=1.0.9=h5eee18b_7 59 | - libbrotlidec=1.0.9=h5eee18b_7 60 | - libbrotlienc=1.0.9=h5eee18b_7 61 | - libdeflate=1.17=h5eee18b_0 62 | - libedit=3.1.20221030=h5eee18b_0 63 | - libffi=3.2.1=hf484d3e_1007 64 | - libgcc-ng=11.2.0=h1234567_1 65 | - libgfortran-ng=11.2.0=h00389a5_1 66 | - libgfortran5=11.2.0=h1234567_1 67 | - libgomp=11.2.0=h1234567_1 68 | - libpng=1.6.39=h5eee18b_0 69 | - libprotobuf=3.20.3=he621ea3_0 70 | - libsodium=1.0.18=h7b6447c_0 71 | - libstdcxx-ng=11.2.0=h1234567_1 72 | - libtiff=4.5.1=h6a678d5_0 73 | - libuuid=1.41.5=h5eee18b_0 74 | - libwebp=1.2.4=h11a3e52_1 75 | - libwebp-base=1.2.4=h5eee18b_1 76 | - libxcb=1.15=h7f8727e_0 77 | - libxml2=2.9.14=h74e7548_0 78 | - lz4-c=1.9.4=h6a678d5_0 79 | - markdown=3.4.1=py37h06a4308_0 80 | - markupsafe=2.1.1=py37h7f8727e_0 81 | - matplotlib=3.5.3=py37h06a4308_0 82 | - matplotlib-base=3.5.3=py37hf590b9c_0 83 | - matplotlib-inline=0.1.6=py37h06a4308_0 84 | - mkl=2021.4.0=h06a4308_640 85 | - mkl-service=2.4.0=py37h7f8727e_0 86 | - mkl_fft=1.3.1=py37hd3c417c_0 87 | - mkl_random=1.2.2=py37h51133e4_0 88 | - munkres=1.1.4=py_0 89 | - ncurses=6.4=h6a678d5_0 90 | - numpy-base=1.21.5=py37ha15fc14_3 91 | - oauthlib=3.2.1=py37h06a4308_0 92 | - openssl=1.1.1w=h7f8727e_0 93 | - opt_einsum=3.3.0=pyhd3eb1b0_1 94 | - parso=0.8.3=pyhd3eb1b0_0 95 | - pcre=8.45=h295c915_0 96 | - pexpect=4.8.0=pyhd3eb1b0_3 97 | - pickleshare=0.7.5=pyhd3eb1b0_1003 98 | - pip=22.3.1=py37h06a4308_0 99 | - ptyprocess=0.7.0=pyhd3eb1b0_2 100 | - pyasn1=0.4.8=pyhd3eb1b0_0 101 | - pyasn1-modules=0.2.8=py_0 102 | - pycparser=2.21=pyhd3eb1b0_0 103 | - pyjwt=2.4.0=py37h06a4308_0 104 | - pyopenssl=23.0.0=py37h06a4308_0 105 | - pyparsing=3.0.9=py37h06a4308_0 106 | - pyqt=5.9.2=py37h05f1152_2 107 | - pysocks=1.7.1=py37_1 108 | - python=3.7.2=h0371630_0 109 | - python-dateutil=2.8.2=pyhd3eb1b0_0 110 | - python-flatbuffers=2.0=pyhd3eb1b0_0 111 | - qt=5.9.7=h5867ecd_1 112 | - readline=7.0=h7b6447c_5 113 | - requests-oauthlib=1.3.0=py_0 114 | - rsa=4.7.2=pyhd3eb1b0_1 115 | - scikit-learn=1.0.2=py37h51133e4_1 116 | - scipy=1.7.3=py37h6c91a56_2 117 | - setuptools=65.6.3=py37h06a4308_0 118 | - sip=4.19.8=py37hf484d3e_0 119 | - 
six=1.16.0=pyhd3eb1b0_1 120 | - sqlite=3.33.0=h62c20be_0 121 | - tensorboard=2.10.0=py37h06a4308_0 122 | - tensorboard-data-server=0.6.1=py37h52d8a92_0 123 | - tensorboard-plugin-wit=1.8.1=py37h06a4308_0 124 | - tensorflow=2.4.1=mkl_py37h2d14ff2_0 125 | - tensorflow-base=2.4.1=mkl_py37h43e0292_0 126 | - tensorflow-estimator=2.6.0=pyh7b7c402_0 127 | - termcolor=2.1.0=py37h06a4308_0 128 | - threadpoolctl=2.2.0=pyh0d69192_0 129 | - tk=8.6.12=h1ccaba5_0 130 | - tornado=6.2=py37h5eee18b_0 131 | - tqdm=4.64.1=py37h06a4308_0 132 | - typing_extensions=4.1.1=pyh06a4308_0 133 | - werkzeug=2.2.2=py37h06a4308_0 134 | - wheel=0.38.4=py37h06a4308_0 135 | - wrapt=1.14.1=py37h5eee18b_0 136 | - xz=5.4.2=h5eee18b_0 137 | - zeromq=4.3.4=h2531618_0 138 | - zlib=1.2.13=h5eee18b_0 139 | - zstd=1.5.5=hc292b87_0 140 | - pip: 141 | - accelerate==0.20.3 142 | - aiohttp==3.8.5 143 | - aiosignal==1.3.1 144 | - appdirs==1.4.4 145 | - async-timeout==4.0.3 146 | - attrs==23.1.0 147 | - chardet==5.2.0 148 | - charset-normalizer==2.1.1 149 | - click==8.1.7 150 | - craft-xai==0.0.3 151 | - datasets==2.13.1 152 | - debugpy==1.6.7.post1 153 | - dill==0.3.6 154 | - docker-pycreds==0.4.0 155 | - filelock==3.12.2 156 | - fsspec==2023.1.0 157 | - gitdb==4.0.10 158 | - gitpython==3.1.34 159 | - huggingface-hub==0.16.4 160 | - imageio==2.31.2 161 | - importlib-metadata==6.7.0 162 | - ipykernel==6.16.2 163 | - ipython==7.34.0 164 | - jedi==0.19.0 165 | - jupyter-core==4.12.0 166 | - loguru==0.7.2 167 | - mingpt==0.0.1 168 | - multidict==6.0.4 169 | - multiprocess==0.70.14 170 | - nest-asyncio==1.5.7 171 | - networkx==2.6.3 172 | - numpy==1.21.6 173 | - nvidia-cublas-cu11==11.10.3.66 174 | - nvidia-cuda-nvrtc-cu11==11.7.99 175 | - nvidia-cuda-runtime-cu11==11.7.99 176 | - nvidia-cudnn-cu11==8.5.0.96 177 | - opencv-python==4.8.0.76 178 | - packaging==23.1 179 | - pandas==1.3.5 180 | - pathtools==0.1.2 181 | - pillow==9.5.0 182 | - prompt-toolkit==3.0.39 183 | - protobuf==4.24.2 184 | - psutil==5.9.5 185 | - pyarrow==12.0.1 186 | - pygments==2.16.1 187 | - python-graphviz==0.20.1 188 | - pytz==2023.3 189 | - pywavelets==1.3.0 190 | - pyyaml==6.0.1 191 | - pyzmq==25.1.1 192 | - regex==2023.8.8 193 | - requests==2.31.0 194 | - safetensors==0.3.3 195 | - scikit-image==0.19.3 196 | - sentry-sdk==1.30.0 197 | - setproctitle==1.3.2 198 | - smmap==5.0.0 199 | - tifffile==2021.11.2 200 | - timm==0.9.7 201 | - tokenizers==0.13.3 202 | - torch==1.13.1 203 | - torchvision==0.14.1 204 | - torchviz==0.0.2 205 | - traitlets==5.9.0 206 | - transformers==4.30.2 207 | - typing-extensions==4.7.1 208 | - urllib3==2.0.4 209 | - wandb==0.15.9 210 | - wcwidth==0.2.6 211 | - xxhash==3.3.0 212 | - yarl==1.9.2 213 | - zipp==3.15.0 214 | -------------------------------------------------------------------------------- /PACE/src/main.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | from transformers import Trainer, TrainingArguments 3 | from transformers import EarlyStoppingCallback, TrainerCallback 4 | from transformers import ViTFeatureExtractor, ViTForImageClassification 5 | import torch 6 | from torch import nn 7 | from torch.autograd import Variable 8 | from torch.utils.data import DataLoader 9 | from datasets import load_metric,load_dataset 10 | import pickle 11 | import os 12 | from math import pi 13 | import numpy as np 14 | from sklearn.mixture import GaussianMixture 15 | import pickle 16 | from numpy import random 17 | import scipy.sparse as sp 18 | from scipy.special import gammaln 
19 | from tqdm import tqdm 20 | from sklearn.decomposition import PCA 21 | from sklearn import manifold 22 | import matplotlib.pyplot as plt 23 | import numpy as np 24 | import pickle 25 | import sys, re, time, string 26 | from scipy.special import gammaln, psi 27 | from numpy.linalg import * 28 | import math 29 | import pandas as pd 30 | from config import parser 31 | import torchvision.transforms as transforms 32 | import torchvision 33 | from utils import accuracy_score, dirichlet_expectation, read_tsv_file, compute_metrics, Adam, posterior_mu_sigma 34 | from utils import Cub2011, StanfordCars, MyImageDatasetFromStanfordCars, build_transform, load_dataset_by_task 35 | from model import PACE, ViTClassify 36 | from torchviz import make_dot 37 | from utils import dirichlet_expectation 38 | from torchvision.transforms.functional import InterpolationMode 39 | from augment import image_augment 40 | from transformers import TrainingArguments, Trainer, AutoFeatureExtractor 41 | from utils import MyImageDataset 42 | from datasets import load_dataset 43 | import os 44 | from PIL import Image 45 | import torch 46 | import torchvision.transforms as transforms 47 | from torch.utils.data import random_split 48 | 49 | args = parser.parse_args() 50 | 51 | #print("args", args) 52 | 53 | args.save_path = os.path.normpath(os.path.join(args.save_path, args.name)) # adjusts to windows or linux 54 | 55 | # create the directory if it doesn't exist 56 | if not os.path.exists(args.save_path): 57 | print(f"❌ Directory does not exist: {args.save_path}") 58 | os.makedirs(args.save_path, exist_ok=True) 59 | print(f"✅ Created directory: {args.save_path}") 60 | 61 | # Windows: ..\ckpt\ViT-base 62 | # Linux/macOS: ../ckpt/ViT-base 63 | 64 | np.random.seed(args.seed) 65 | torch.manual_seed(args.seed) 66 | random.seed(args.seed) 67 | 68 | class MyEarlyStoppingCallback(EarlyStoppingCallback): 69 | 70 | def __init__(self, early_stopping_patience=1, early_stopping_threshold=0.0, args=None): 71 | super(MyEarlyStoppingCallback,self).__init__(early_stopping_patience, early_stopping_threshold) 72 | self.epochs = 0 73 | self.ub = 0 74 | self.flag = True 75 | self.args = args 76 | 77 | def on_evaluate(self, args, state, control, metrics, **kwargs): 78 | metric_to_check = args.metric_for_best_model 79 | if not metric_to_check.startswith("eval_"): 80 | metric_to_check = f"eval_{metric_to_check}" 81 | metric_value = metrics.get(metric_to_check) 82 | 83 | self.check_metric_value(args, state, control, metric_value) 84 | if self.early_stopping_patience_counter >= self.early_stopping_patience: 85 | control.should_training_stop = True 86 | self.epochs += 1 87 | self.flag = True 88 | 89 | if PACE is None: 90 | return 91 | 92 | # save model 93 | args = self.args 94 | torch.save(model.state_dict(), args.save_path +'/' + args.task + '_epoch'+str(self.epochs)+'.pt') 95 | np.save(args.save_path+'/' + args.task + '_mus-epoch'+str(self.epochs) +'.npy',PACE._mus) 96 | np.save(args.save_path+'/' + args.task + '_sigmas-epoch'+str(self.epochs)+'.npy',PACE._sigmas) 97 | np.save(args.save_path+'/' + args.task + '_eta-epoch'+str(self.epochs)+'.npy',PACE._eta) 98 | 99 | 100 | class PACETrainer(Trainer): 101 | 102 | def compute_loss(self,model,inputs,return_outputs=False, **kwargs): # **args... 
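        # (added note) Custom objective: run the ViT on the batch and on an augmented
        # view of it; when a PACE explainer is attached, fit PACE's variational
        # parameters via the E/EM steps below and return PACE's loss (its _delta
        # term); otherwise fall back to the plain cross-entropy loss on the logits.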
103 |         #output = model(inputs['encodings']) # get predicted outputs and the last patch embeddings
104 |         logits, states, att = model(inputs['encodings'])
105 |         image_trans = image_augment(inputs['encodings'])
106 |         logits_trans, states_trans, att_trans = model(image_trans)
107 | 
108 |         loss = torch.nn.CrossEntropyLoss()
109 |         pred_prob = np.exp(logits.detach().cpu().numpy())/np.exp(logits.detach().cpu().numpy()).sum(axis=1)[:,None]
110 |         pred_ids = torch.argmax(logits,-1)
111 |         ViT_loss = loss(logits, inputs['labels']).sum()
112 |         if mycallback.flag is True:
113 |             mycallback.flag = False
114 | 
115 |         if PACE is not None and return_outputs is False:
116 | 
117 |             gamma_trans, phi_trans = PACE.do_e_step(states_trans, att_trans[args.layer + 1])
118 |             PACE._phi_trans = PACE._phi
119 |             gamma, phi = PACE.do_em_step(states, att[args.layer + 1], cl=True,y=pred_prob)
120 |             PACE.update_eta(logits)
121 |             pred_prob = torch.argmax(logits, dim=-1)
122 |             PACE_loss = PACE._delta
123 |             custom_loss = PACE_loss
124 |         else:
125 |             custom_loss = ViT_loss
126 | 
127 | 
128 |         pred = {'label_ids':inputs['labels'], 'predictions':pred_ids}
129 |         return (custom_loss, pred) if return_outputs else custom_loss
130 | 
131 | 
132 | 
133 | 
134 | 
135 | 
136 | 
137 | 
138 | 
139 | img_size = 224
140 | randaug_magnitude = 0
141 | 
142 | 
143 | 
144 | img_size = (224,224)
145 | transform = transforms.Compose([
146 |     transforms.ToTensor(),
147 |     transforms.Resize(img_size, interpolation=InterpolationMode.BICUBIC),
148 |     #transforms.RandAugment(num_ops=2,magnitude=randaug_magnitude),
149 |     transforms.Normalize(
150 |         mean=(0.485, 0.456, 0.406),
151 |         std=(0.229, 0.224, 0.225)
152 |     ),
153 | ])
154 | 
155 | 
156 | 
157 | 
158 | 
159 | 
160 | 
161 | # Dataset loading using load_dataset_by_task function
162 | train_dataset, test_dataset, args.out_dim = load_dataset_by_task(args.task, args.data_path)
163 | val_dataset = test_dataset
164 | 
165 | # Create data loaders for easier batch processing
166 | train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
167 | test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
168 | 
169 | model = ViTClassify(in_dim = args.b_dim, out_dim=args.out_dim,hid_dim=args.c_dim, layer=args.layer)
170 | model = model.cuda()
171 | 
172 | if 'PACE' in args.name:
173 |     PACE = PACE(d=args.c_dim,K=args.K,D=args.D,N=args.N,alpha=args.alpha,C = args.out_dim)
174 | else:
175 |     PACE = None
176 | 
177 | training_args = TrainingArguments(
178 |     output_dir='../results',          # output directory
179 |     num_train_epochs=args.num_epochs, # total number of training epochs
180 |     per_device_train_batch_size=args.train_batch_size,  # batch size per device during training
181 |     per_device_eval_batch_size=args.eval_batch_size,    # batch size for evaluation
182 |     warmup_steps=500,                 # number of warmup steps for learning rate scheduler
183 |     weight_decay=args.weight_decay,   # strength of weight decay
184 |     logging_dir='./logs',             # directory for storing logs
185 |     logging_steps=10,
186 |     seed = args.seed,
187 |     load_best_model_at_end=True,
188 |     metric_for_best_model=args.metric, # 'eval_matthews_correlation' for cola, etc.
189 |     eval_strategy='epoch',
190 |     save_strategy='epoch',
191 |     learning_rate = args.lr,
192 |     lr_scheduler_type='cosine',       # cosine learning rate scheduler
193 |     save_total_limit=3,               # limit the number of saved checkpoints
194 |     gradient_accumulation_steps=1,    # gradient accumulation steps
195 |     fp16=True,                        # use mixed precision training to accelerate
196 | )
197 | 
198 | mycallback = MyEarlyStoppingCallback(early_stopping_patience=10, args=args)
199 | 
200 | trainer = PACETrainer(
201 |     model=model,                  # the instantiated 🤗 Transformers model to be trained
202 |     args=training_args,           # training arguments, defined above
203 |     train_dataset=train_dataset,  # training dataset
204 |     eval_dataset=val_dataset,     # evaluation dataset
205 |     compute_metrics=compute_metrics,
206 |     callbacks=[mycallback],
207 | 
208 | )
209 | 
210 | test_set = DataLoader(val_dataset,batch_size=args.eval_batch_size,shuffle=False)
211 | 
212 | print('train size', len(train_dataset))
213 | print('eval size', len(val_dataset))
214 | 
215 | if args.train:
216 |     print('training...')
217 |     if not args.require_grad:  # Train PACE, otherwise train ViT
218 |         model.load_state_dict(torch.load('../ckpt/ViT-base' +'/' + args.task + '_epoch' + args.pretrain_epoch +'.pt'))
219 |     trainer.train()
220 |     torch.save(model.state_dict(), args.save_path +'/' + args.task + '_epoch'+str(args.num_epochs)+'.pt')
221 | else:
222 |     print('evaluating...')
223 |     model.load_state_dict(torch.load(args.save_path +'/' + args.task + '_epoch'+str(args.num_epochs)+'.pt'))
224 | 
225 | if PACE is not None:
226 |     if args.train:
227 |         np.save(args.save_path+'/' + args.task + '_mus-epoch'+str(args.num_epochs)+'.npy',PACE._mus)
228 |         np.save(args.save_path+'/' + args.task + '_sigmas-epoch'+str(args.num_epochs)+'.npy',PACE._sigmas)
229 |         np.save(args.save_path+'/' + args.task + '_eta-epoch'+str(args.num_epochs)+'.npy',PACE._eta)
230 |     else:
231 |         PACE._mus = np.load(args.save_path+'/' + args.task + '_mus-epoch'+str(args.num_epochs)+'.npy')
232 |         PACE._sigmas = np.load(args.save_path+'/' + args.task + '_sigmas-epoch'+str(args.num_epochs)+'.npy')
233 |         PACE._eta = np.load(args.save_path+'/' + args.task + '_eta-epoch'+str(args.num_epochs)+'.npy')
234 | 
235 |     # explain ViT
236 |     print('PACE is explaining ViT...')
237 |     for i, inputs in enumerate(test_set):
238 |         test_encodings = inputs['encodings'].cuda()
239 |         test_labels = inputs['labels'].cuda()
240 |         logits, states, att = model(test_encodings)
241 | 
242 |         # infer phi and gamma
243 |         gamma, phi = PACE.do_e_step(states, att[args.layer + 1])
244 |         # infer E[log(theta)] under the variational Dirichlet posterior
245 |         E_log_theta = dirichlet_expectation(gamma)
246 | 
247 |         # Normalize gamma so each row sums to 1
248 |         gamma = gamma / gamma.sum(axis=1, keepdims=True)
249 | 
250 |         # Collect images and gamma and then plot them
251 |         num_to_plot = 10
252 |         if i == 0:
253 |             images_to_plot = []
254 |             gammas_to_plot = []
255 |             plotted = False
256 | 
257 |         # Recover this batch's images for plotting; `dataset` was undefined in the
258 |         # original line here, so we denormalize the already-transformed tensors instead
259 |         imagenet_mean, imagenet_std = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225])
260 |         original_images = test_encodings.detach().cpu().permute(0, 2, 3, 1).numpy()
261 |         original_images = np.clip(original_images * imagenet_std + imagenet_mean, 0, 1)
262 | 
263 | 
264 |         # Move gamma to CPU and convert to numpy if it's a tensor
265 |         if isinstance(gamma, torch.Tensor):
266 |             gamma_np = gamma.detach().cpu().numpy()
267 |         else:
268 |             gamma_np = gamma
269 | 
270 |         for img, g in zip(original_images, gamma_np):
271 |             if len(images_to_plot) < num_to_plot:
272 | 
images_to_plot.append(img) 273 | gammas_to_plot.append(g) 274 | else: 275 | break 276 | 277 | # After the loop, plot the first num_to_plot images and their gamma 278 | if len(images_to_plot) >= num_to_plot and not plotted: 279 | print(f'plotting the first {num_to_plot} images and their normalized gamma, i.e., the image-level explanation theta...') 280 | fig, axes = plt.subplots(num_to_plot, 2, figsize=(8, 4 * num_to_plot)) 281 | for idx in range(num_to_plot): 282 | # Show image 283 | axes[idx, 0].imshow(images_to_plot[idx]) 284 | axes[idx, 0].axis('off') 285 | axes[idx, 0].set_title(f"Test Image {idx+1}") 286 | # Show gamma as bar plot 287 | axes[idx, 1].bar(np.arange(len(gammas_to_plot[idx])), gammas_to_plot[idx]) 288 | axes[idx, 1].set_title(f"Gamma {idx+1}") 289 | plt.tight_layout() 290 | plt.savefig(args.save_path + '/' + args.task + '_epoch' + str(args.num_epochs) + 'gamma_test_images.pdf') 291 | print('plotting done') 292 | plotted = True 293 | # plt.show() -------------------------------------------------------------------------------- /PACE/src/run_quantitative_theta.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | from transformers import pipeline 3 | from transformers import BertTokenizer, BertModel 4 | from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments 5 | from transformers import EarlyStoppingCallback, TrainerCallback 6 | from transformers import ViTFeatureExtractor, ViTForImageClassification 7 | import torch 8 | from torch import nn 9 | from torch.autograd import Variable 10 | from torch.utils.data import DataLoader 11 | from datasets import load_metric,load_dataset 12 | import pickle 13 | import os 14 | from math import pi 15 | import numpy as np 16 | from sklearn.mixture import GaussianMixture 17 | import pickle 18 | from numpy import random 19 | import scipy.sparse as sp 20 | from scipy.special import gammaln 21 | from tqdm import tqdm 22 | from sklearn.decomposition import PCA 23 | from sklearn import manifold 24 | import matplotlib.pyplot as plt 25 | import numpy as np 26 | import pickle 27 | import sys, re, time, string 28 | from scipy.special import gammaln, psi 29 | from numpy.linalg import * 30 | import math 31 | import pandas as pd 32 | import numpy as np 33 | from config import parser 34 | import torchvision.transforms as transforms 35 | import torchvision 36 | from utils import accuracy_score, dirichlet_expectation, read_tsv_file, compute_metrics, Adam, posterior_mu, posterior_mu_sigma, vis, kmeans_init 37 | from utils import run_kmeans, plot_topics 38 | from model import PACE, ViTClassify 39 | from torchviz import make_dot 40 | from utils import load_train_data, load_val_data, softmax, dirichlet_expectation 41 | from torchvision.transforms.functional import InterpolationMode 42 | #from captum.attr import Lime, LimeBase 43 | from augment import contrastive_learning, contrative_transform, image_augment 44 | import wandb 45 | from transformers import TrainingArguments, Trainer, AutoFeatureExtractor 46 | from utils import MyImageDataset 47 | from datasets import load_dataset 48 | from evaluate import stability, faithfulness, sparsity, parsimony, faithfulness_linear 49 | #import shap, lime 50 | from utils import attention_norm, topic_vis, load_dataset_by_task 51 | from PIL import Image 52 | import pdb 53 | from evaluate import to_numpy 54 | 55 | args = parser.parse_args() 56 | args.save_path = os.path.join(args.save_path, args.name) 57 | sample_path = 
os.path.join('../sample', args.name)
58 | 
59 | np.random.seed(args.seed)
60 | torch.manual_seed(args.seed)
61 | random.seed(args.seed)
62 | 
63 | # Dataset loading using load_dataset_by_task function
64 | train_dataset, test_dataset, args.out_dim = load_dataset_by_task(args.task, args.data_path)
65 | val_dataset = test_dataset
66 | 
67 | model = ViTClassify(in_dim = args.b_dim, out_dim=args.out_dim,hid_dim=args.c_dim, layer=args.layer)
68 | #print(model)
69 | 
70 | model = model.cuda()
71 | 
72 | if 'PACE' in args.name:
73 |     PACE = PACE(d=args.c_dim,K=args.K,D=args.D,N=args.N,alpha=args.alpha,C = args.out_dim)
74 | else:
75 |     PACE = None
76 | 
77 | training_args = TrainingArguments(
78 |     output_dir='./results',           # output directory
79 |     num_train_epochs=args.num_epochs, # total number of training epochs
80 |     per_device_train_batch_size=args.train_batch_size,  # batch size per device during training
81 |     per_device_eval_batch_size=args.eval_batch_size,    # batch size for evaluation
82 |     warmup_steps=0,                   # number of warmup steps for the LR scheduler (changed from 100 to 0)
83 |     weight_decay=args.weight_decay,   # strength of weight decay
84 |     logging_dir='./logs',             # directory for storing logs
85 |     logging_steps=10,
86 |     seed = args.seed,
87 |     load_best_model_at_end=True,
88 |     metric_for_best_model=args.metric, # 'eval_matthews_correlation' for cola, etc.
89 |     eval_strategy='epoch',
90 |     save_strategy='epoch',
91 |     learning_rate = args.lr,
92 |     report_to="wandb",
93 |     #resume_from_checkpoint=True,
94 |     # eval_steps=100,
95 | )
96 | 
97 | 
98 | 
99 | test_set = DataLoader(val_dataset,batch_size=args.eval_batch_size,shuffle=True)  # shuffling is fine here: concepts and predictions are collected in the same pass
100 | train_set = DataLoader(train_dataset,batch_size=args.train_batch_size,shuffle=True)
101 | 
102 | print('train size', len(train_dataset))
103 | print('eval size', len(val_dataset))
104 | 
105 | #X0 = np.load(os.path.join(args.save_path, 'X-L-2.npy'))
106 | #PACE._mus = PACE._mu0 = run_kmeans(X0, args.K)
107 | 
108 | print('evaluating')
109 | 
110 | model.load_state_dict(torch.load(args.save_path +'/' + args.task + '_epoch'+str(args.num_epochs)+'.pt'))
111 | 
112 | #args.version = 'kmeans'
113 | if PACE is not None:
114 |     PACE._mus = np.load(args.save_path+'/' + args.task + '_mus-epoch'+str(args.num_epochs)+'.npy')
115 |     PACE._sigmas = np.load(args.save_path+'/' + args.task + '_sigmas-epoch'+str(args.num_epochs)+'.npy')
116 |     PACE._eta = np.load(args.save_path+'/' + args.task + '_eta-epoch'+str(args.num_epochs)+'.npy')
117 | 
118 | 
119 | 
120 | 
121 | # temporarily CPU-bound rather than GPU-bound; needs multi-threading if multiple runs execute at the same time
122 | # numpy matrix manipulation test
123 | 
124 | x = None
125 | pos = []
126 | topic = []
127 | name = []
128 | patch_img = []
129 | full_img = []
130 | word_embed = {}
131 | word_cnt = {}
132 | top_words = [{} for _ in range(args.K)]  # maintain a priority queue of prob for tokens in each topic
133 | pred_label = []
134 | tok = []
135 | font = []
136 | topic_cnt = {}
137 | 
138 | 
139 | # for idx in range(args.K): # args.K
140 | #     if PACE is None:
141 | #         continue
142 | #     name.append(0)
143 | #     topic.append('T_'+str(idx))
144 | #     #font.append(1) # np.exp(det(PACE._sigmas[idx]))
145 | #     if x is None:
146 | #         x = PACE._mus[idx].reshape(-1,args.c_dim)
147 | #     else:
148 | #         x = np.concatenate([x,PACE._mus[idx].reshape(-1,args.c_dim)],axis=0)
149 | #     patch_img.append(np.ones((224//16,224//16,3))) # patch_img[-1].shape
150 | #     full_img.append(np.ones((224,224,3))) # full_img[-1].shape
151 | #     pos.append((-1,-1))
152 | 
153 | # x =
torch.Tensor(x).cuda() 154 | 155 | topic_cnt = dict(sorted(topic_cnt.items(), key=lambda item: item[1],reverse=True)) 156 | print(topic_cnt) 157 | concepts = [[] for _ in range(args.K)] 158 | #top_topics = list(topic_cnt)[1:6] 159 | top_topics = {} 160 | tt_cp = [x for x in top_topics] 161 | tw = {} 162 | 163 | 164 | #top_topics = list(topic_cnt)[5:10] 165 | 166 | 167 | topic_se = 24 168 | class_1 = 10 169 | class_2 = 20 170 | 171 | sample_num = 50000#5000 172 | avg_corr = 0 173 | batch_cnt = 0 174 | 175 | # test metrics for LIME model 176 | # ref https://captum.ai/api/lime.html 177 | 178 | 179 | # interprete classifier from embedding inputs 180 | 181 | concept_all = [] 182 | label_all = [] 183 | 184 | model.eval() 185 | concept_test = [] 186 | concept_aug_test = [] 187 | concept_train = [] 188 | concept_aug_train = [] 189 | prob_train = [] 190 | prob_test = [] 191 | 192 | pred_train = [] 193 | pred_test = [] 194 | embeds = [] 195 | corpus = [] 196 | patches = [] 197 | attentions = [] 198 | concept_images = [] 199 | masked_concept_images = [] 200 | concept_labels = [] 201 | left_up_att = [] 202 | right_up_att = [] 203 | left_down_att = [] 204 | right_down_att = [] 205 | 206 | model_map = {'Sedan', 'SUV', 'Convertible', 'Minivan', 'Coupe', 'Wagon', 'Hatchback', 'Van', 'Truck', 'Pickup'} 207 | make_map = {'Audi', 'BMW', 'Chevrolet', 'Dodge', 'Ford', 'Honda', 'Hyundai', 'Jeep', 'Lexus', 'Mercedes-Benz', 'Nissan', 'Porsche', 'Subaru', 'Colorota', 'Volkswagen'} 208 | 209 | dataset_model_cnt = {} 210 | dataset_make_cnt = {} 211 | 212 | # only perform inference if not loading concepts from saved path 213 | if not args.load_concepts: 214 | print("Performing inference on train and test datasets...") 215 | 216 | #train datset 217 | with torch.no_grad(): 218 | cnt = 0 219 | for id, inputs in enumerate(train_set): 220 | print('train batch', id) 221 | train_encodings = inputs['encodings'].cuda() 222 | train_labels = inputs['labels'].cuda() 223 | #test_path = inputs['path'] 224 | #test_mask = inputs['attention_mask'].cuda() 225 | #print(test_encodings.size()) 226 | logits, states, att = model(train_encodings) 227 | 228 | # get augmented outputs 229 | image_trans = image_augment(inputs['encodings']) 230 | logits_trans, states_trans, att_trans = model(image_trans) 231 | 232 | preds = logits.argmax(-1) 233 | #print('preds', preds) 234 | #logits = logits.detach().cpu().numpy() 235 | 236 | for pp in preds: 237 | pred_train.append(pp) 238 | for i in range(len(logits)): 239 | prob_train.append((torch.softmax(logits[i], dim=0)).detach().cpu().numpy()) 240 | #print('prob', prob_train[-1]) 241 | if PACE is None: 242 | continue 243 | 244 | A_o = att[args.layer] 245 | #A_e = model.effective_attention(att[args.layer+1]) 246 | #A_e = attention_norm(A_e) 247 | A_o = attention_norm(A_o) 248 | gamma, phi = PACE.do_e_step(states, A_o) # inference w/o learning, so e step instead of em step. 
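        # (added note) gamma holds the variational Dirichlet parameters for this batch;
        # dirichlet_expectation(gamma) gives E[log theta], so np.exp(E_log_theta) below
        # is the image-level concept vector collected as the train-split explanation.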
249 | #gamma_trans, phi_trans = PACE.do_e_step(states_trans, att_trans[args.layer + 1]) 250 | E_log_theta = dirichlet_expectation(gamma) 251 | concept_train.append(np.exp(E_log_theta)) 252 | #concept_train.append(phi.mean(1)) 253 | #print(phi.mean(1)[0]) 254 | # A_o_trans = att_trans[args.layer] 255 | # A_o_trans = attention_norm(A_o_trans) 256 | # gamma_trans, phi_trans = PACE.do_e_step(states_trans, A_o) 257 | # concept_aug_train.append(phi_trans.mean(1)) 258 | 259 | # test dataset 260 | with torch.no_grad(): 261 | cnt = 0 262 | for id, inputs in enumerate(test_set): 263 | print('test batch', id) 264 | test_encodings = inputs['encodings'].cuda() 265 | test_labels = inputs['labels'].cuda() 266 | #test_path = inputs['path'] 267 | #test_mask = inputs['attention_mask'].cuda() 268 | #print(test_encodings.size()) 269 | logits, states, att = model(test_encodings) 270 | 271 | # get augmented outputs 272 | image_trans = image_augment(inputs['encodings']) 273 | logits_trans, states_trans, att_trans = model(image_trans) 274 | 275 | preds = logits.argmax(-1) 276 | #logits = logits.detach().cpu().numpy() 277 | 278 | for pp in preds: 279 | pred_test.append(pp) 280 | for i in range(len(logits)): 281 | prob_test.append((torch.softmax(logits[i], dim=0)).detach().cpu().numpy()) 282 | #print('prob', prob_test[-1]) 283 | if PACE is None: 284 | continue 285 | 286 | A_o = att[args.layer] 287 | #A_e = model.effective_attention(att[args.layer+1]) 288 | #A_e = attention_norm(A_e) 289 | A_o = attention_norm(A_o) 290 | gamma, phi = PACE.do_e_step(states, A_o) # inference w/o learning, so e step instead of em step. 291 | #gamma_trans, phi_trans = PACE.do_e_step(states_trans, att_trans[args.layer + 1]) 292 | E_log_theta = dirichlet_expectation(gamma) 293 | concept_test.append(np.exp(E_log_theta)) 294 | #concept_test.append(phi.mean(1)) 295 | #print(phi.mean(1)[0]) 296 | A_o_trans = att_trans[args.layer] 297 | A_o_trans = attention_norm(A_o_trans) 298 | gamma_trans, phi_trans = PACE.do_e_step(states_trans, A_o) 299 | E_log_theta = dirichlet_expectation(gamma_trans) 300 | concept_aug_test.append(np.exp(E_log_theta)) 301 | #concept_aug_test.append(phi_trans.mean(1)) 302 | 303 | # process inferred data 304 | concept_train = np.concatenate(concept_train, axis=0) 305 | concept_test = np.concatenate(concept_test, axis=0) 306 | concept_aug_test = np.concatenate(concept_aug_test, axis=0) 307 | # concept_aug_train = np.concatenate(concept_aug_train, axis=0) 308 | pred_train = to_numpy(torch.stack(pred_train)) 309 | pred_test = to_numpy(torch.stack(pred_test)) 310 | prob_train = np.array(prob_train) 311 | prob_test = np.array(prob_test) 312 | print('prob_test', prob_test.shape) 313 | else: 314 | print("Skipping inference - will load concepts and probabilities from saved path...") 315 | 316 | # conditionally load concepts and probabilities from saved path or use inferred ones 317 | if args.load_concepts: 318 | print("Loading concepts and probabilities from saved path...") 319 | concept_train = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_train.npy')) 320 | concept_test = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_test.npy')) 321 | # concept_aug_train = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_aug_train.npy')) 322 | concept_aug_test = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_aug_test.npy')) 323 | pred_train = 
np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-pred_train.npy'))
324 |     pred_test = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-pred_test.npy'))
325 |     prob_train = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-prob_train.npy'))
326 |     prob_test = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-prob_test.npy'))
327 | else:
328 |     print("Using inferred concepts and probabilities from model...")
329 | 
330 | 
331 | # normalize concept_train and concept_test
332 | concept_train = concept_train / concept_train.sum(axis=1)[:,None]
333 | concept_test = concept_test / concept_test.sum(axis=1)[:,None]
334 | # concept_aug_train = concept_aug_train / concept_aug_train.sum(axis=1)[:,None]
335 | concept_aug_test = concept_aug_test / concept_aug_test.sum(axis=1)[:,None]
336 | 
337 | 
338 | # only save concepts and probabilities if we inferred them (not loaded from saved path)
339 | if not args.load_concepts:
340 |     print("Saving inferred concepts and probabilities...")
341 |     np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_train.npy'),concept_train)
342 |     np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_test.npy'),concept_test)
343 |     np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_aug_train.npy'),concept_aug_train)
344 |     np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_aug_test.npy'),concept_aug_test)
345 |     np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-pred_train.npy'),pred_train)
346 |     np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-pred_test.npy'),pred_test)
347 |     np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-prob_train.npy'),prob_train)
348 |     np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-prob_test.npy'),prob_test)
349 | else:
350 |     print("Skipping save since concepts and probabilities were loaded from saved path...")
351 | 
352 | #pdb.set_trace()
353 | 
354 | stability_score = stability(concept_test, concept_aug_test)
355 | 
356 | print('stability', stability_score)
357 | 
358 | fhard, fsoft = faithfulness_linear(concept_train, pred_train, concept_test, pred_test, prob_test=prob_test)
359 | #faithfulness = faithfulness(concept_test, pred_test, concept_test, pred_test)
360 | 
361 | print('faithfulness', fhard, fsoft)
362 | 
363 | sparsity_score = sparsity(concept_test)  # renamed so the metric functions are not shadowed
364 | 
365 | print('sparsity', sparsity_score)
366 | 
367 | parsimony_score = parsimony(concept_test)
368 | 
369 | print('parsimony', parsimony_score)
370 | 
371 | # log txt file
372 | with open(args.save_path + '/' + args.task + '_epoch' + str(args.num_epochs) + '.txt', 'w') as f:
373 |     f.write('stability: ' + str(stability_score) + '\n')
374 |     f.write('faithfulness: ' + str(fhard) + ' ' + str(fsoft) + '\n')
375 |     f.write('sparsity: ' + str(sparsity_score) + '\n')
376 |     f.write('parsimony: ' + str(parsimony_score) + '\n')
377 | 
378 | 
379 | 
--------------------------------------------------------------------------------
/PACE/src/model.py:
--------------------------------------------------------------------------------
1 | from transformers import pipeline
2 | from transformers import ViTConfig, ViTForImageClassification
3 | # (a duplicate ViT import was removed here)
4 | 
import torch 5 | from torch import nn 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader 8 | import pickle 9 | import os 10 | from math import pi 11 | import numpy as np 12 | from sklearn.mixture import GaussianMixture 13 | import pickle 14 | from numpy import random 15 | import scipy.sparse as sp 16 | from scipy.special import gammaln 17 | from tqdm import tqdm 18 | from sklearn.decomposition import PCA 19 | from sklearn import manifold 20 | import matplotlib.pyplot as plt 21 | import numpy as np 22 | import pickle 23 | import sys, re, time, string 24 | from scipy.special import gammaln, psi 25 | from numpy.linalg import * 26 | import math 27 | import pandas as pd 28 | from config import parser 29 | #from config_parse_args import parser_args 30 | from utils import accuracy_score, dirichlet_expectation, read_tsv_file, compute_metrics, Adam, posterior_mu_sigma 31 | import time 32 | from torch.nn import functional as F 33 | from scipy.special import softmax 34 | from augment import contrastive_learning, contrative_transform 35 | from utils import Adam 36 | 37 | args = parser.parse_args() 38 | args.save_path = os.path.join(args.save_path, args.name) 39 | 40 | 41 | 42 | class ViTClassify(nn.Module): 43 | def __init__(self, in_dim, out_dim, hid_dim, layer): 44 | super(ViTClassify, self).__init__() 45 | self.in_dim = in_dim 46 | self.out_dim = out_dim 47 | self.hid_dim = hid_dim 48 | self.layer = layer 49 | 50 | 51 | ViT_config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k", output_hidden_states=True, output_attentions=True, num_labels=out_dim) 52 | self.ViT = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', config=ViT_config) 53 | 54 | if not args.require_grad: # vit be none trainable, train PACE instead 55 | for param in self.ViT.parameters(): 56 | param.requires_grad = False 57 | self.linear = nn.Linear(in_dim, hid_dim) 58 | self.embedding = None 59 | 60 | def forward(self, encodings, labels=None): 61 | ViT_output = self.ViT(encodings) 62 | logits = ViT_output['logits'] 63 | all_states = ViT_output['hidden_states'] 64 | attention = ViT_output['attentions'] 65 | states = all_states[self.layer] 66 | hidden = self.linear(states) 67 | self.embedding = all_states[args.layer] 68 | 69 | return logits, hidden, attention 70 | class PACE: 71 | 72 | def __init__(self, d, K, D, N, alpha, C): 73 | ''' 74 | Arguments: 75 | K: number of topics 76 | d: dimension of embedding space 77 | D: number of images 78 | alpha: prior on theta 79 | eta: prior on mu and Sigma 80 | ''' 81 | self._d = d 82 | self._K = K 83 | self._D = D 84 | self._N = N 85 | self._C = C 86 | self._alpha = alpha 87 | self._updatect = 0 88 | self._gamma = None 89 | self._phi = None 90 | self._mu0 = np.random.randn(K,d) * 10 91 | self._mus = np.random.randn(K,d) * 10 92 | self._sigmas = np.array([np.eye(d) for _ in range(K)]) 93 | self._eta = np.random.randn(C,K) # (num_class, num_topic) 94 | self._updatect = 0 95 | self._eps = 1e-50 96 | self._converge = False 97 | self._m_mu = 0 98 | self._m_sigma = 0 99 | self._cnt = 0 100 | self._snum = 0 101 | self._sigma0 = np.array([np.eye(self._d) for _ in range(self._K)]) 102 | self._lr = 1e-2 103 | self._embeds = None 104 | self._delta = None 105 | self.faitful_loss = 0 106 | self.stability_loss = 0 107 | self._adam_delta = Adam(1e-4,0.9,0.9) 108 | self._reg = 1e-3 109 | self._w = None 110 | self.linear = nn.Linear(args.b_dim, args.c_dim).cuda() 111 | self._lrm = torch.empty(args.b_dim, args.c_dim).cuda() 112 | 
112 | torch.nn.init.normal_(self._lrm)
113 | self._phi_trans = None
114 | self.optimizer = None
115 |
116 | def log_p_y(self, logits):
117 |
118 | logits = logits.cpu().detach().numpy()
119 | log_p = 0
120 | B, _ = logits.shape
121 | for b in range(B):
122 | phi_mean = self._phi[b].mean(0)
123 | pred_y = softmax(logits[b])
124 | for i in range(self._C):
125 | log_p += pred_y[i] * np.log(np.exp(self._eta[i].dot(phi_mean))/np.sum(np.exp(self._eta.dot(phi_mean))))
126 |
127 | return log_p
128 |
129 |
130 | def update_eta(self, logits):
131 | logits = logits.cpu().detach().numpy()
132 | pred_y = softmax(logits, axis=1)
133 | phi_mean = self._phi.mean((0,1))
134 | for i in range(self._C):
135 | self._eta[i] = self._eta[i] - self._lr * (pred_y[:,i].mean() - np.exp(self._eta[i].dot(phi_mean))/np.sum(np.exp(self._eta.dot(phi_mean)))) * phi_mean
136 |
137 |
138 |
139 | def do_e_step(self, embeds, frac=None):
140 | batchD = len(embeds)
141 |
142 |
143 | if args.frac == 'equal':
144 | self._w = torch.Tensor(1).cuda()
145 | elif args.frac == 'fix': # fixed length attention
146 | self._w = frac.mean(1)[:,0,:]
147 |
148 | phi = random.gamma(self._K, 1./self._K, (batchD, embeds.size()[1], self._K))
149 | gamma = phi.sum(1)
150 |
151 |
152 | it = 0
153 | meanchange = 0
154 | sigma_invs = []
155 | sigma_dets = []
156 | for i in range(self._K):
157 | sigma_inv = inv(self._sigmas[i] + self._eps * np.eye(self._d))
158 | sigma_invs.append(sigma_inv)
159 | sigma_det = det(self._sigmas[i])
160 | sigma_dets.append(sigma_det)
161 |
162 | sigma_invs = torch.Tensor(np.array(sigma_invs))
163 | sigma_invs = Variable(sigma_invs, requires_grad=True).cuda()
164 | sigma_dets = torch.Tensor(np.array(sigma_dets))
165 | sigma_dets = Variable(sigma_dets, requires_grad=True).cuda()
166 | self._delta = 0
167 | self._delta = Variable(torch.Tensor(self._delta), requires_grad=True)
168 |
169 | tensor_mus = Variable(torch.Tensor(np.array(self._mus)),requires_grad=False).cuda()
170 | tensor_sigmas = Variable(torch.Tensor(np.array(self._sigmas)),requires_grad=False).cuda()
171 | # (lines 171-189 repeated lines 152-170 verbatim; the duplicated block was removed)
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 | # Iterate between gamma and phi until convergence
188 |
189 |
190 | for it in range(0,10): # train to converge.
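# NOTE (added reading aid, not in the original source): each pass of the loop below
# makes the two standard mean-field updates (cf. variational LDA, here with Gaussian
# concepts):
#     phi[b,n,k]  ~  w[b,n] * exp( E_q[log theta_bk] - spd[b,n,k]/100 ) / sqrt(|Sigma_k|)
#     gamma[b,k]  =  alpha + sum_n phi[b,n,k]
# where spd is the Mahalanobis term 0.5*(x-mu_k)^T Sigma_k^{-1} (x-mu_k) and
# E_q[log theta] = psi(gamma) - psi(sum(gamma)) comes from dirichlet_expectation().
# The /100 temperature and the abs()/eps guards are numerical stabilizers in this code.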
191 |
192 | mus = tensor_mus.unsqueeze(-1)
193 | embeds = embeds.view(-1,self._d,1)
194 |
195 |
196 | dir_exp = dirichlet_expectation(gamma)
197 | dir_exp = Variable(torch.Tensor(dir_exp),requires_grad=False).view(-1,self._K).cuda()
198 | spd_tensor = None
199 | phi_tensor = Variable(torch.Tensor(phi),requires_grad=False).cuda()
200 | for i in range(self._K):
201 | mul = torch.matmul((embeds-mus[i,:]).transpose(1,2), sigma_invs[i,:,:])
202 | spd = 0.5 * torch.matmul(mul,embeds-mus[i,:]).view(-1,self._N)
203 |
204 | if spd_tensor is None:
205 | spd_tensor = spd.unsqueeze(-1)
206 | else:
207 | spd_tensor = torch.cat([spd_tensor,spd.unsqueeze(-1)],dim=2)
208 |
209 | phi_tmp = torch.exp(dir_exp[:,i].view(-1,1)-spd/100 )/torch.sqrt(abs(sigma_dets[i])+self._eps)
210 |
211 | phi_tmp = phi_tmp * self._w
212 | phi[:,:,i] = phi_tmp.cpu().detach().numpy()
213 |
214 | phi = (phi+self._eps)/(phi.sum(-1, keepdims=True) + self._eps * self._K) # along K axis
215 |
216 | self._delta = phi_tensor * spd_tensor
217 |
218 | self._delta = self._delta.mean() * batchD
219 | gamma = self._alpha + phi.sum(1) # along n axis
220 | del dir_exp, phi_tensor, spd_tensor
221 | del sigma_invs, sigma_dets, tensor_mus, tensor_sigmas
222 | self._phi = phi
223 | self._gamma = gamma
224 |
225 | return (gamma, phi)
226 |
227 | def do_cl_e_step(self, embeds, frac=None, y = None):
228 | '''
229 | frac: attention of last hidden layer, size [B,num_heads,N,N]
230 | '''
231 | batchD = len(embeds)
232 | f_optimizer = Adam(1e-4,0.9,0.9)
233 | s_optimizer = Adam(1e-4,0.9,0.9)
234 | if args.frac == 'equal':
235 | self._w = torch.Tensor(1).cuda()
236 | elif args.frac == 'fix':
237 | self._w = frac.mean(1)[:,0,:]
238 |
239 |
240 | phi = random.gamma(self._K, 1./self._K, (batchD, embeds.size()[1], self._K))
241 | gamma = phi.sum(1)
242 |
243 |
244 | it = 0
245 | meanchange = 0
246 | sigma_invs = []
247 | sigma_dets = []
248 |
249 | for i in range(self._K):
250 | sigma_inv = inv(self._sigmas[i] + self._eps * np.eye(self._d))
251 | sigma_invs.append(sigma_inv)
252 | sigma_det = det(self._sigmas[i])
253 | sigma_dets.append(sigma_det)
254 | sigma_invs = torch.Tensor(np.array(sigma_invs))
255 | sigma_invs = Variable(sigma_invs, requires_grad=True).cuda()
256 | sigma_dets = torch.Tensor(np.array(sigma_dets))
257 | sigma_dets = Variable(sigma_dets, requires_grad=True).cuda()
258 | self._delta = 0
259 | self._delta = Variable(torch.Tensor(self._delta), requires_grad=True)
260 |
261 | tensor_mus = Variable(torch.Tensor(np.array(self._mus)),requires_grad=False).cuda()
262 | tensor_sigmas = Variable(torch.Tensor(np.array(self._sigmas)),requires_grad=False).cuda()
263 | # (lines 263-282 repeated lines 244-262 verbatim; the duplicated block was removed)
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 | # Iterate between gamma and phi until convergence
281 |
282 |
283 |
284 | for it in
range(0,10): # train to converge. 285 | mus = tensor_mus.unsqueeze(-1) 286 | embeds = embeds.view(-1,self._d,1) 287 | dir_exp = dirichlet_expectation(gamma) 288 | dir_exp = Variable(torch.Tensor(dir_exp),requires_grad=False).view(-1,self._K).cuda() 289 | spd_tensor = None 290 | phi_tensor = Variable(torch.Tensor(phi),requires_grad=False).cuda() 291 | phi_mean = torch.mean(phi_tensor, dim=1).detach().cpu().numpy() 292 | phi_fair_delta = 1/self._N * (np.einsum('bi,ij->bj',y, self._eta)- np.einsum('bi,ij->bj',np.exp(np.einsum('ij,bj->bi',self._eta, phi_mean)), self._eta)/np.sum(np.exp(np.einsum('ij,bj->bi',self._eta, phi_mean)), axis=1, keepdims=True) ) 293 | phi_fair_delta = torch.from_numpy(phi_fair_delta).cuda() 294 | phi_trans_mean = self._phi_trans.mean(1) 295 | phi_stable_delta, _ = contrastive_learning(torch.from_numpy(phi_mean).cuda(), torch.from_numpy(phi_trans_mean).cuda()) 296 | phi_stable_delta *= -1/self._N 297 | eps = 1e-10 298 | for i in range(self._K): 299 | mul = torch.matmul((embeds-mus[i,:]).transpose(1,2), sigma_invs[i,:,:]) 300 | spd = 0.5 * torch.matmul(mul,embeds-mus[i,:]).view(-1,self._N) 301 | 302 | if spd_tensor is None: 303 | spd_tensor = spd.unsqueeze(-1) 304 | else: 305 | spd_tensor = torch.cat([spd_tensor,spd.unsqueeze(-1)],dim=2) 306 | phi_tmp = torch.exp(dir_exp[:,i].view(-1,1)-spd+self._eps) 307 | phi_tmp =phi_tmp * self._w 308 | phi[:,:,i] = phi_tmp.cpu().detach().numpy() 309 | phi = (phi+self._eps)/(phi.sum(-1, keepdims=True) + self._eps * self._K) # along K axis 310 | 311 | self._delta = phi_tensor * spd_tensor 312 | 313 | self._delta = self._delta.mean() * batchD 314 | relevance_scale = 1 315 | stable_scale = 1 316 | phi = np.clip(phi, 0, 1e10) 317 | phi = (phi+self._eps)/(phi.sum(-1).reshape(batchD,self._N,1) + self._eps * self._K) 318 | # additional grad with second-order optimization 319 | f_grad = f_optimizer.update(phi_fair_delta.unsqueeze(1).detach().cpu().numpy()) 320 | s_grad = s_optimizer.update(phi_stable_delta.unsqueeze(1).detach().cpu().numpy()) 321 | phi += relevance_scale * f_grad + stable_scale * s_grad 322 | phi = (phi+self._eps)/(phi.sum(-1, keepdims=True) + self._eps * self._K) 323 | 324 | gamma = self._alpha + phi.sum(1) # along n axis 325 | del dir_exp, phi_tensor, spd_tensor 326 | del sigma_invs, sigma_dets, tensor_mus, tensor_sigmas 327 | self._phi = phi 328 | self._gamma = gamma 329 | 330 | return (gamma, phi) 331 | 332 | 333 | 334 | def update_lambda(self, embeds): 335 | ''' 336 | update variational parameter lambda for beta 337 | ''' 338 | def do_em_step(self, embeds,frac=None, cl=False, y=None): 339 | ''' 340 | first E step, 341 | then M step 342 | 343 | gamma (B,K) 344 | phi (B,N,K) 345 | embeds (B,N,d) 346 | cl: using contrastive learning 347 | ''' 348 | 349 | if cl and y is not None: 350 | gamma, phi = self.do_cl_e_step(embeds=embeds, frac=frac, y=y) 351 | else: 352 | gamma, phi = self.do_e_step(embeds,frac) 353 | embeds = embeds.cpu().detach().numpy() 354 | batchD = len(embeds) 355 | last_mus = self._mus.copy() 356 | last_sigmas = self._sigmas.copy() 357 | 358 | mu_init = 0 359 | sigma_init = 0 360 | norm = 0 361 | gamma = gamma.reshape(-1, self._K) 362 | phi = phi.reshape(-1,self._K) 363 | embeds = embeds.reshape(-1,self._d) 364 | mu_init = phi.reshape(-1,self._K,1) * embeds.reshape(-1,1,self._d) 365 | mu_init = mu_init * self._w.cpu().detach().numpy().reshape(-1,1,1) 366 | mu_init = mu_init.sum(0) 367 | delta = embeds.reshape(-1,1,self._d)-self._mus.reshape(1,-1,self._d) 368 | square = delta.reshape(-1,self._K,1,self._d) * 
delta.reshape(-1,self._K,self._d, 1) # (B*N,K,d,d)
369 | sigma_init = phi.reshape(-1,self._K,1,1) * square
370 | sigma_init = sigma_init * self._w.cpu().detach().numpy().reshape(-1,1,1,1)
371 | sigma_init = sigma_init.sum(0)
372 | norm = (phi*self._w.cpu().detach().numpy().reshape(-1,1)).sum(0)
373 |
374 | rhot = 0.75
375 |
376 | self._m_mu = self._mus * rhot * self._cnt + mu_init/(norm.reshape(-1,1)+self._eps) * (1-rhot) * batchD
377 | self._m_sigma = self._sigmas * rhot * self._cnt + sigma_init/(norm.reshape(-1,1,1)+self._eps) * (1-rhot) * batchD
378 | self._cnt = self._cnt + batchD * (1-rhot)
379 | self._mus = self._m_mu / self._cnt
380 | self._sigmas = self._m_sigma / self._cnt
381 |
382 | self._snum += 1
383 | for i in range(self._K):
384 | self._mus[i], self._sigmas[i] = posterior_mu_sigma(self._mus[i], self._sigmas[i], self._snum, self._mu0[i])
385 |
386 | if abs(last_mus - self._mus).max() < 1e-5:
387 | self._converge = True
388 |
389 | phi = phi.reshape(-1,self._N,self._K)
390 |
391 | return gamma, phi
392 |
393 |
--------------------------------------------------------------------------------
/PACE/src/utils.py:
--------------------------------------------------------------------------------
1 | # (spurious 'from tkinter import image_types' removed: it was unused and breaks headless environments)
2 | import numpy as np
3 | import pandas as pd
4 | from numpy.linalg import *
5 | from scipy.linalg import sqrtm
6 | from scipy.special import gammaln, psi
7 | import torch
8 | from sklearn.decomposition import PCA
9 | import matplotlib.pyplot as plt
10 | from datasets import load_metric
11 | from config import parser
12 | from transformers import ViTFeatureExtractor, ViTImageProcessor
13 | from datasets import load_dataset
14 | from torch.utils.data import random_split
15 | from torchvision.transforms import (CenterCrop,
16 | Compose,
17 | Normalize,
18 | RandomHorizontalFlip,
19 | RandomResizedCrop,
20 | Resize,
21 | ToTensor)
22 | from torch import FloatTensor, div
23 | from torch.utils.data import DataLoader, Dataset
24 | from torchvision import transforms
25 | from torchvision.transforms.functional import InterpolationMode
26 |
27 | import logging
28 | import os
29 | import random
30 | # (duplicate pandas/Dataset/os/PIL imports removed)
31 | from torchvision.datasets.folder import default_loader
32 | from torchvision.datasets.utils import download_url
33 |
34 | import pickle
35 | from sklearn.cluster import KMeans
36 | from torchvision.datasets import VisionDataset
37 |
38 |
39 |
40 |
41 | from PIL import Image, ImageOps
42 |
43 |
44 | def attention_norm(attention, a=1):
45 | # attention = attention.mean(1)
46 | # max(attention, 0)
47 | #attention = torch.max(attention, torch.zeros_like(attention))
48 | attention = attention - attention.min()
49 |
50 | #attention = attention ** a
51 | attention = attention / attention.sum(-1, keepdims=True)
52 |
53 | return attention
54 |
55 |
56 | class StanfordCars_simplified(torch.utils.data.Dataset):
57 | def __init__(self, root, transform = None):
58 | self.images = [os.path.join(root, file) for file in os.listdir(root)]
59 | self.transform = transform
60 |
61 | def __len__(self):
62 | return len(self.images)
63 |
64 | def __getitem__(self, index):
65 | image_file = self.images[index]
66 | image = Image.open(image_file).convert("RGB")
67 | if self.transform:
68 | image = self.transform(image)
69 | return image[None]
70 |
71 |
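# NOTE (added usage sketch, not in the original source; the path below is
# illustrative). `image[None]` in __getitem__ above prepends a batch axis, so each
# item already has shape (1, 3, H, W) and can be fed to the ViT without torch.stack:
#
#     from torchvision import transforms
#     ds = StanfordCars_simplified(
#         '../dataset/cars/cars_test',                      # hypothetical location
#         transform=transforms.Compose([transforms.Resize((224, 224)),
#                                       transforms.ToTensor()]))
#     x = ds[0]                                             # torch.Size([1, 3, 224, 224])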
72 | class StanfordCars(VisionDataset):
73 | """Stanford Cars Dataset.
74 |
75 | The Cars dataset contains 16,185 images of 196 classes of cars. The data is
76 | split into 8,144 training images and 8,041 testing images, where each class
77 | has been split roughly in a 50-50 split.
78 |
79 | .. note::
80 |
81 | This class needs scipy to load target files from `.mat` format.
82 |
83 | Args:
84 | root (string): Root directory of dataset
85 | split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
86 | transform (callable, optional): A function/transform that takes in a PIL image
87 | and returns a transformed version. E.g, ``transforms.RandomCrop``
88 | target_transform (callable, optional): A function/transform that takes in the
89 | target and transforms it.
90 | download (bool, optional): If True, downloads the dataset from the internet and
91 | puts it in root directory. If dataset is already downloaded, it is not
92 | downloaded again."""
93 |
94 | def __init__(
95 | self,
96 | root: str,
97 | split = "train",
98 | transform = None,
99 | target_transform = None,
100 | download: bool = False,
101 | ):
102 |
103 | try:
104 | import scipy.io as sio
105 | except ImportError:
106 | raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
107 |
108 | super().__init__(root, transform=transform, target_transform=target_transform)
109 |
110 | #self._split = verify_str_arg(split, "split", ("train", "test"))
111 | #self._base_folder = pathlib.Path(root) / "stanford_cars"
112 | self._split = split
113 | self._base_folder = root
114 | devkit = os.path.join(self._base_folder, "devkit")
115 |
116 |
117 | if self._split == "train":
118 | self._annotations_mat_path = os.path.join(devkit, "cars_train_annos.mat")
119 | self._images_base_path = os.path.join(self._base_folder, "cars_train")
120 | else:
121 | self._annotations_mat_path = os.path.join(self._base_folder, "cars_test_annos_withlabels.mat")
122 | self._images_base_path = os.path.join(self._base_folder, "cars_test")
123 |
124 | if download:
125 | self.download()
126 |
127 | self._samples = [
128 | (
129 | str(os.path.join(self._images_base_path, annotation["fname"])),
130 | annotation["class"] - 1, # Original target mapping starts from 1, hence -1
131 | )
132 | for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
133 | ]
134 |
135 | self.classes = sio.loadmat(str(os.path.join(devkit, "cars_meta.mat")), squeeze_me=True)["class_names"].tolist()
136 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
137 |
138 | def __len__(self):
139 | return len(self._samples)
140 |
141 | def __getitem__(self, idx: int):
142 | """Returns pil_image and class_id for given index"""
143 | image_path, target = self._samples[idx]
144 | pil_image = Image.open(image_path).convert("RGB")
145 |
146 | if self.transform is not None:
147 | pil_image = self.transform(pil_image)
148 | if self.target_transform is not None:
149 | target = self.target_transform(target)
150 | return pil_image, target
151 |
152 | def download(self): # note: downloading is not implemented here; the data must be placed under root manually
153 | if self._check_exists():
154 | return
155 | def _check_exists(self):
156 | if not os.path.isdir(os.path.join(self._base_folder, "devkit")): # _base_folder is a str, so use os.path (the original applied pathlib's `/` to it)
157 | return False
158 |
159 | return os.path.isfile(self._annotations_mat_path) and os.path.isdir(self._images_base_path)
160 |
161 |
162 |
163 |
164 | def train_transforms(examples):
165 | examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['img']]
166 | return examples
167 |
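# NOTE (added sketch, not in the original source): train_transforms/val_transforms
# expect module-level _train_transforms/_val_transforms pipelines that are not defined
# in this file. A hedged sketch of how such callbacks are usually attached to a
# Hugging Face dataset exposing an 'img' column (the dataset name is illustrative):
#
#     from datasets import load_dataset
#     ds = load_dataset('cifar10')                    # any dataset with an 'img' column
#     ds['train'].set_transform(train_transforms)     # applied lazily, batch by batch
#     ds['test'].set_transform(val_transforms)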
168 | def val_transforms(examples):
169 | examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['img']]
170 | return examples
171 |
172 |
173 |
174 | args = parser.parse_args()
175 |
176 | def accuracy_score(labels, preds):
177 | acc = (preds == labels).astype(float).mean() # np.float was removed in NumPy >= 1.24; the builtin float works for all versions
178 | return acc
179 |
180 | def compute_metrics_acc(pred):
181 | #labels = pred.label_ids
182 | labels, preds = pred.predictions
183 |
184 | #print('labels',labels)
185 | #print('preds',preds)
186 | #precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
187 | acc = accuracy_score(labels, preds)
188 | return {
189 | 'accuracy': acc,
190 | #'f1': f1,
191 | #'precision': precision,
192 | #'recall': recall
193 | }
194 |
195 | def compute_metrics(pred):
196 | #labels = pred.label_ids
197 | labels, preds = pred.predictions
198 | #metric = load_metric('glue', args.task)
199 | metric = load_metric('accuracy')
200 | return metric.compute(predictions=preds, references=labels)
201 |
202 | def dirichlet_expectation(alpha):
203 | '''
204 | E[log(theta)|alpha], where theta ~ Dir(alpha).
205 | from blei/online LDA
206 | '''
207 | if len(alpha.shape) == 1: # 1D version
208 | return psi(alpha) - psi(np.sum(alpha))
209 | return psi(alpha) - psi(np.sum(alpha,1))[:, np.newaxis]
210 |
211 |
212 | def read_tsv_file(file_path):
213 | df = pd.read_csv(file_path,sep='\t')
214 | seq = df['sentence']
215 | return seq
216 |
217 |
218 |
219 | class DynamicCrop(object):
220 | def __init__(self, is_train=True):
221 | self.is_train = is_train
222 |
223 | def __call__(self, img):
224 | w, h = img.size
225 | crop_size = min(w, h)
226 |
227 | left_margin = (w - crop_size) / 2
228 | top_margin = (h - crop_size) / 2
229 |
230 | # Random crop for training
231 | if self.is_train:
232 | left_margin = random.randint(0, w - crop_size)
233 | top_margin = random.randint(0, h - crop_size)
234 |
235 | img = img.crop((left_margin, top_margin, left_margin + crop_size, top_margin + crop_size))
236 | return img
237 |
238 | def build_transform(output_size, is_train=True):
239 | """
240 | Get the appropriate image transformation based on the training/testing phase.
241 |
242 | Parameters:
243 | - output_size (int or tuple): Size for resizing the cropped image.
244 | - is_train (bool): If True, random crop and resize are performed. Otherwise, center crop and resize.
245 |
246 | Returns:
247 | - torchvision.transforms.Compose: A composition of transformations.
248 | """
249 | return transforms.Compose([
250 | DynamicCrop(is_train=is_train),
251 | transforms.Resize(output_size),
252 | transforms.ToTensor()
253 | ])
254 |
255 |
256 | def build_transform_prev(crop_size, output_size, is_train=True):
257 | """
258 | Get the appropriate image transformation based on the training/testing phase.
259 |
260 | Parameters:
261 | - crop_size (int or tuple): Size for cropping. If int, a square crop is made.
262 | - output_size (int or tuple): Size for resizing the cropped image.
263 | - is_train (bool): If True, random crop and resize are performed. Otherwise, center crop and resize.
264 |
265 | Returns:
266 | - torchvision.transforms.Compose: A composition of transformations.
267 | """ 268 | if is_train: 269 | return transforms.Compose([ 270 | transforms.RandomCrop(crop_size), 271 | transforms.Resize(output_size), 272 | transforms.ToTensor() 273 | ]) 274 | else: 275 | return transforms.Compose([ 276 | transforms.CenterCrop(crop_size), 277 | transforms.Resize(output_size), 278 | transforms.ToTensor() 279 | ]) 280 | 281 | 282 | class Cub2011(Dataset): 283 | base_folder = 'CUB_200_2011/images' 284 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' 285 | filename = 'CUB_200_2011.tgz' 286 | tgz_md5 = '97eceeb196236b17998738112f37df78' 287 | 288 | def __init__(self, root, train=True, transform=None, loader=default_loader, download=True): 289 | self.root = os.path.expanduser(root) 290 | self.transform = transform 291 | self.loader = default_loader 292 | self.train = train 293 | 294 | if download: 295 | self._download() 296 | 297 | if not self._check_integrity(): 298 | raise RuntimeError('Dataset not found or corrupted.' + 299 | ' You can use download=True to download it') 300 | 301 | def _load_metadata(self): 302 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ', 303 | names=['img_id', 'filepath']) 304 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'), 305 | sep=' ', names=['img_id', 'target']) 306 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), 307 | sep=' ', names=['img_id', 'is_training_img']) 308 | 309 | data = images.merge(image_class_labels, on='img_id') 310 | self.data = data.merge(train_test_split, on='img_id') 311 | 312 | if self.train: 313 | self.data = self.data[self.data.is_training_img == 1] 314 | else: 315 | self.data = self.data[self.data.is_training_img == 0] 316 | 317 | def _check_integrity(self): 318 | try: 319 | self._load_metadata() 320 | except Exception: 321 | return False 322 | 323 | for index, row in self.data.iterrows(): 324 | filepath = os.path.join(self.root, self.base_folder, row.filepath) 325 | if not os.path.isfile(filepath): 326 | print(filepath) 327 | return False 328 | return True 329 | 330 | def _download(self): 331 | import tarfile 332 | 333 | if self._check_integrity(): 334 | print('Files already downloaded and verified') 335 | return 336 | 337 | download_url(self.url, self.root, self.filename, self.tgz_md5) 338 | 339 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: 340 | tar.extractall(path=self.root) 341 | 342 | def __len__(self): 343 | return len(self.data) 344 | 345 | def __getitem__(self, idx): 346 | sample = self.data.iloc[idx] 347 | path = os.path.join(self.root, self.base_folder, sample.filepath) 348 | #print('sample', sample) 349 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0 350 | img = self.loader(path) 351 | #print(img) 352 | #print(target) 353 | if self.transform is not None: 354 | img = self.transform(img)['pixel_values'][0] 355 | return {'encodings':img, 'labels':target, 'path': sample.filepath} 356 | 357 | class MyImageDataset(Dataset): 358 | """Dataset class for Image""" 359 | def __init__(self, dataset, labels, transform=None, normalize=None): 360 | super(MyImageDataset, self).__init__() 361 | assert(len(dataset) == len(labels)) 362 | self.dataset = dataset 363 | self.labels = labels 364 | self.transform = transform 365 | self.normalize = normalize 366 | 367 | #print(self.labels) 368 | 369 | def __len__(self): 370 | return len(self.dataset) 371 | 372 | def __getitem__(self, idx): 373 | data = 
self.dataset[idx] 374 | 375 | if self.transform: 376 | data = self.transform(data) 377 | img_to_tensor = transforms.ToTensor() 378 | # if data is not tensor 379 | if not isinstance(data, torch.Tensor): 380 | data = img_to_tensor(data) 381 | if self.normalize: 382 | data = self.normalize(data) 383 | 384 | return {'encodings':data, 'labels':self.labels[idx]} 385 | 386 | 387 | class ImageNetDataset(Dataset): 388 | """Dataset class for ImageNet""" 389 | def __init__(self, dataset, labels, transform=None, normalize=None): 390 | super(ImageNetDataset, self).__init__() 391 | assert(len(dataset) == len(labels)) 392 | self.dataset = dataset 393 | self.labels = labels 394 | self.transform = transform 395 | self.normalize = normalize 396 | #print(self.labels) 397 | 398 | def __len__(self): 399 | return len(self.dataset) 400 | 401 | def __getitem__(self, idx): 402 | data = self.dataset[idx] 403 | 404 | if self.transform: 405 | data = self.transform(data) 406 | img_to_tensor = transforms.ToTensor() 407 | data = img_to_tensor(data) 408 | 409 | data = div(data, 255) 410 | # why would there by (1,224,224) samples? 411 | if data.size()[0] == 1: 412 | data = data.repeat(3,1,1) 413 | if self.normalize: 414 | data = self.normalize(data) 415 | 416 | return {'encodings':data, 'labels':self.labels[idx]} 417 | 418 | class MyImageDatasetFromStanfordCars(Dataset): 419 | def __init__(self, stanford_cars_dataset, transform=None, normalize=None): 420 | super(MyImageDatasetFromStanfordCars, self).__init__() 421 | self.stanford_cars_dataset = stanford_cars_dataset 422 | self.transform = transform 423 | self.normalize = normalize 424 | 425 | def __len__(self): 426 | return len(self.stanford_cars_dataset) 427 | 428 | def __getitem__(self, idx): 429 | data, label = self.stanford_cars_dataset[idx] 430 | 431 | if self.transform: 432 | data = self.transform(data) 433 | img_to_tensor = transforms.ToTensor() 434 | # if data is not tensor 435 | if not isinstance(data, torch.Tensor): 436 | data = img_to_tensor(data) 437 | if self.normalize: 438 | data = self.normalize(data) 439 | 440 | return {'encodings': data, 'labels': label} 441 | 442 | 443 | def load_train_data(data_dir, dataset, img_size, magnitude, batch_size): 444 | with open(data_dir, 'rb') as f: 445 | ds = pickle.load(f) 446 | train_data = ds['image'] 447 | train_labels = ds['label'] 448 | transform = transforms.Compose([ 449 | transforms.Resize(img_size, interpolation=InterpolationMode.BICUBIC), 450 | transforms.RandAugment(num_ops=2,magnitude=magnitude), 451 | ]) 452 | train_dataset = dataset(train_data, train_labels, transform, 453 | normalize=transforms.Compose([ 454 | transforms.Normalize( 455 | mean=(0.485, 0.456, 0.406), 456 | std=(0.229, 0.224, 0.225) 457 | ) 458 | ]), 459 | ) 460 | train_loader = DataLoader( 461 | train_dataset, 462 | shuffle=True, 463 | batch_size=batch_size, 464 | num_workers=8, 465 | pin_memory=True, 466 | drop_last=True, 467 | ) 468 | f.close() 469 | return train_dataset, train_loader 470 | 471 | def load_val_data(data_dir, dataset, img_size, batch_size): 472 | with open(data_dir, 'rb') as f: 473 | ds = pickle.load(f) 474 | val_data = ds['image'] 475 | val_labels = ds['label'] 476 | transform = transforms.Compose([ 477 | transforms.Resize(img_size, interpolation=InterpolationMode.BICUBIC), 478 | ]) 479 | val_dataset = dataset(val_data, val_labels, transform, 480 | normalize=transforms.Compose([ 481 | transforms.Normalize( 482 | mean=(0.485, 0.456, 0.406), 483 | std=(0.229, 0.224, 0.225) 484 | ), 485 | ]) 486 | ) 487 | val_loader = 
DataLoader(
488 | val_dataset,
489 | batch_size=batch_size,
490 | shuffle=True,
491 | num_workers=4,
492 | pin_memory=True
493 | )
494 | f.close()
495 | return val_dataset, val_loader
496 |
497 | class Adam:
498 | def __init__(self, alpha=1e-4, beta1=0.9, beta2=0.9):
499 | self.alpha = alpha
500 | self.beta1 = beta1
501 | self.beta2 = beta2
502 | self.m = 0
503 | self.v = 0
504 | self.t = 0
505 | self.eps = 1e-5
506 | def update(self, g):
507 | self.t += 1
508 | self.m = self.beta1 * self.m + (1-self.beta1) * g
509 | self.v = self.beta2 * self.v + (1-self.beta2) * g**2
510 | m_hat = self.m / (1-self.beta1**self.t) # bias-correct a copy; the original overwrote self.m/self.v here, compounding the correction across steps
511 | v_hat = self.v / (1-self.beta2**self.t)
512 | return self.alpha * m_hat / (np.sqrt(v_hat+self.eps)+self.eps)
513 |
514 | def posterior_mu(mu, mu0, n, n0=1):
515 | res = (n0*mu0+n*mu)/(n+n0)
516 | return res
517 |
518 |
519 | # initializing should adapt to the data (mean, variance, norm...)
520 | def posterior_mu_sigma(mu, sigma, n, mu0, n0=500, lamda0=None):
521 | #lamda = inv(np.matmul(sigma, sigma.T))
522 | #print('size', sigma.size)
523 | # fix mu, update sigma
524 | if lamda0 is None:
525 | lamda0 = np.eye(sigma.shape[-1])
526 |
527 | post_mu = (n0*mu0+n*mu)/(n+n0)
528 | #post_sigma = (n0* lamda0 + n* sigma) / (n0+n)
529 | #print('prod', np.matmul((mu-mu0).T,mu-mu0))
530 | lamda = (n0* lamda0 + n* inv(sigma)) / (n+n0) #+ n*n0/(n0+n) * np.matmul((mu-mu0).T,mu-mu0)
531 | post_sigma = (n0*lamda0 + n*inv(lamda) )/ (n+n0)
532 | #print('post_sigma', det(post_sigma))
533 | return post_mu, post_sigma
534 |
535 |
536 | def get_tf_idf_mask(corpus_words):
from collections import Counter # (import added: Counter was used below but never imported)
from gensim.models import TfidfModel # (import added, assuming gensim's TfidfModel, whose (id, value) output matches the usage below)
537 | #corpus_words = [[w for w in seq if w is not 0] for seq in corpus_words]
538 | corpus_bow = [list(Counter(seq).items()) for seq in corpus_words]
539 |
540 | tfidf = TfidfModel(corpus_bow, normalize=False)
541 | #filter low value words
542 | low_value = 4
543 | filtered_ids = []
544 | tfidf_mask = []
545 | for i in range(0, len(corpus_bow)):
546 | bow = corpus_bow[i]
547 | low_value_words = [] #reinitialize to be safe. You can skip this.
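# NOTE (added reading aid, not in the original source): per document i, tfidf[bow]
# yields (token_id, score) pairs. The comprehension below collects the ids scoring
# under low_value; the mask then zeroes those tokens along with padding (w == 0) and
# the leading 'CLS' position (idx == 0), keeping only informative tokens.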
548 | low_value_words = [id for id, value in tfidf[bow] if value < low_value] 549 | new_mask = [] 550 | for idx, w in enumerate(corpus_words[i]): 551 | if w in low_value_words or w==0 or idx==0: # consider padding, 'CLS' token 552 | new_mask.append(0) 553 | else: 554 | new_mask.append(1) 555 | #reassign 556 | tfidf_mask.append(new_mask) 557 | #print(new_mask) 558 | return tfidf_mask 559 | 560 | 561 | def topic_vis(mus, coherence, topic): 562 | pca = PCA(n_components=2) 563 | pc = pca.fit_transform(mus) 564 | ax = plt.subplot() 565 | for i in topic: 566 | coh = coherence[i] 567 | coh = round(coh, 1) 568 | ax.annotate( str(i), (pc[i,0], pc[i,1]), size=10) # + '-' + str(coh) 569 | 570 | ax.scatter(pc[i,0], pc[i, 1]) 571 | plt.savefig('topic_scatter_vis.jpg') 572 | return ax 573 | 574 | def vis(x,name,topic,patch=None, pos=None, full_img=None, sigmas = None, div= None, coh=None, top_topics = None, indiv_coh=None, attentions=None): 575 | ''' 576 | x: features 577 | K: topic numbers 578 | topic: concept of patch/word 579 | patch: image of patch, numpy array (optional) 580 | ''' 581 | #if pos is not None: 582 | # print(pos) 583 | #print('topic', topic) 584 | # Get the Matplotlib logger 585 | mpl_logger = logging.getLogger('matplotlib') 586 | 587 | # Set the level to WARNING or higher 588 | mpl_logger.setLevel(logging.WARNING) 589 | 590 | mean = x.mean(axis=0) 591 | var = x.var(axis=0) 592 | #print(x.shape) 593 | #print(mean.shape) 594 | #print(var.shape) 595 | ''' 596 | outer_indices = [i for i in range(x.shape[0]) if (abs(x[i]-mean) > 3*np.sqrt(var)).any()] 597 | print('outer', outer_indices) 598 | x = np.delete(x, outer_indices, axis=0) 599 | name = np.delete(name, outer_indices, axis=0) 600 | topic = np.delete(topic, outer_indices, axis=0) 601 | if patch is not None: 602 | patch = np.delete(patch, outer_indices, axis=0) 603 | if pos is not None: 604 | pos = np.delete(pos, outer_indices, axis=0) 605 | if full_img is not None: 606 | full_img = np.delete(full_img, outer_indices, axis=0) 607 | #if sigmas is not None: 608 | # sigmas = np.delete(sigmas, outer_indices, axis=0) 609 | ''' 610 | pca = PCA(n_components=2) 611 | pc = pca.fit_transform(x) 612 | colors = [ 'blue', 'red','purple', 'green','lime', 'cyan', 'orange', 'pink','black', 'grey'] * 20 613 | # for i, txt in enumerate(name): 614 | # ax.annotate(txt, (pc[i,0], pc[i,1]), 615 | # color=colors[topic[i]], size=10) 616 | # ax.scatter(pc[:,0], pc[:, 1]) 617 | ''' 618 | fig = plt.figure(figsize=(10,10)) 619 | axs = fig.add_subplot(1,1,1) 620 | axs.set_xlabel('pc 1', fontsize=15) 621 | axs.set_ylabel('pc 2', fontsize=15) 622 | axs.set_title('2D PCA', fontsize=20) 623 | print(topic) 624 | for i, txt in enumerate(topic): 625 | axs.annotate(txt, (pc[i,0], pc[i,1]), 626 | size=10,color=colors[name[i]]) 627 | axs.scatter(pc[:,0], pc[:, 1]) 628 | plt.savefig('topic_vis.jpg') 629 | plt.savefig('topic_vis.pdf') 630 | 631 | scale = np.linalg.norm(pc) / np.linalg.norm(x) 632 | ''' 633 | #for i in range(args.K): 634 | # print('concept ', i, sigmas[i] * scale) 635 | 636 | # f, axarr = plt.subplots(2,2) 637 | # axarr[0,0].imshow(image_datas[0]) 638 | # Title for the whole figure 639 | 640 | 641 | grid_size = 20 642 | scale = np.sqrt(var).mean() 643 | # Define center point and range 644 | #x0, y0 = 5, 5 # replace with your desired center point 645 | pc = pc * 20 646 | tt = 24 647 | x0, y0 = int(pc[tt,0]), int(pc[tt,1]) 648 | #x0/=100 649 | #y0/=100 650 | x0, y0 = 0,0 651 | print('center', x0, y0) 652 | print('pc x', pc[100:150,0]) 653 | print('pc y', 
pc[100:150,1]) 654 | # norm pc[0,:], pc[1,:] to be in range [-10,10] 655 | # pc[0,:] = pc[0,:] - pc[0,:].mean() 656 | # pc[1,:] = pc[1,:] - pc[1,:].mean() 657 | # pc[0,:] = pc[0,:] / pc[0,:].max() * 10 658 | # pc[1,:] = pc[1,:] / pc[1,:].max() * 10 659 | 660 | r = 10 # replace with your desired range 661 | 662 | gflag = np.zeros((grid_size,grid_size)) 663 | max_att = np.zeros((grid_size,grid_size)) 664 | f, arr = plt.subplots(grid_size,grid_size,figsize=(grid_size,grid_size)) 665 | 666 | 667 | f.suptitle(f'div {div}, coh {coh}', fontsize=20) # Add your title here 668 | #print(pc) 669 | for i in range(grid_size): 670 | for j in range(grid_size): 671 | arr[i,j].axis('off') 672 | #print('patch', patch) 673 | #print nearest distance from i (i<100) to j (j>=100) 674 | 675 | 676 | for i in range(100): 677 | min_dist = 100000 678 | min_coord = (0,0) 679 | print('orig', pc[i,0], pc[i,1]) 680 | for j in range(100,pc.shape[0]): 681 | dist = np.linalg.norm(pc[i]-pc[j]) 682 | if dist < min_dist: 683 | min_dist = dist 684 | min_coord = (pc[j,0],pc[j,1]) 685 | print('i=',i,', min_dist', min_dist, min_coord) 686 | 687 | patch = np.array(patch) 688 | mean_coord = patch.mean(axis=0) 689 | for i, _ in enumerate(patch): 690 | #if topic[i] not in top_topics: 691 | # continue 692 | ax = int(pc[i,0]) - int(x0) + grid_size//2 693 | ay = int(pc[i,1]) - int(y0) + grid_size//2 694 | #print(ax, ay, grid_size) 695 | #if topic[i] == tt: 696 | # print('topic ',tt, ax, ay) 697 | if abs(ax+1/2-grid_size/2)>=grid_size/2 or abs(ay+1/2-grid_size/2)>=grid_size/2: 698 | continue 699 | if int(gflag[ax,ay]): 700 | ''' 701 | if attentions is None: 702 | continue 703 | elif attentions[i] < max_att[ax,ay]: 704 | continue 705 | else: 706 | max_att[ax,ay] = attentions[i] 707 | ''' 708 | if np.linalg.norm(patch[i]-mean_coord) < max_att[ax,ay]: 709 | continue 710 | max_att[ax,ay] = np.linalg.norm(patch[i]-mean_coord) 711 | gflag[ax,ay] = 1 712 | 713 | 714 | img = arr[ax,ay].imshow(patch[i], interpolation='nearest') 715 | img.set_cmap('hot') 716 | plt.axis('off') 717 | 718 | 719 | # Removes axis numbers 720 | 721 | #arr[ax,ay].text(-0.1, 1.1, pos[i], transform=arr[ax,ay].transAxes, 722 | # size=10, weight='bold',color=colors[name[i]]) 723 | arr[ax,ay].text(-0.1, 1.1, topic[i], transform=arr[ax,ay].transAxes, 724 | size=20, weight='bold',color=colors[name[i]]) 725 | 726 | # we want to center at (5, 5). So, we set the limits to go from 0 to 10. 
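# NOTE (added reading aid, not in the original source): placement logic for the image
# grids built in the rest of this function: each patch's 2-D PCA coordinate (scaled by
# 20) is snapped to a cell of a grid_size x grid_size canvas centered at (x0, y0);
# out-of-range cells are skipped, and when several patches collide in one cell, the
# gflag/max_att bookkeeping keeps the patch farthest from the mean patch, i.e. the
# most distinctive one.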
727 | #arr.set_xlim(0, 10) 728 | #arr.set_ylim(0, 10) 729 | 730 | # Setting x-ticks and y-ticks to clearly see the center of image at (5,5) 731 | #arr.set_xticks(np.arange(0,11,1)) 732 | #arr.set_yticks(np.arange(0,11,1)) 733 | plt.savefig('dataset_vis.jpg') 734 | plt.savefig('dataset_vis.pdf') 735 | for ii in top_topics: 736 | print(indiv_coh[ii]) 737 | 738 | 739 | grid_size = 20 740 | gflag = np.zeros((grid_size,grid_size)) 741 | max_att = np.zeros((grid_size,grid_size)) 742 | f, arr = plt.subplots(grid_size,grid_size,figsize=(20,20)) 743 | #print(pc) 744 | for i in range(grid_size): 745 | for j in range(grid_size): 746 | arr[i,j].axis('off') 747 | for i, _ in enumerate(full_img): 748 | #if topic[i] not in top_topics: 749 | # continue 750 | ax = int(pc[i,0]) - int(x0) + grid_size//2 751 | ay = int(pc[i,1]) - int(y0) + grid_size//2 752 | #arr[ax,ay].axis('off') 753 | 754 | if abs(ax+1/2-grid_size/2)>=grid_size/2 or abs(ay+1/2-grid_size/2)>=grid_size/2: 755 | continue 756 | if int(gflag[ax,ay]): 757 | if np.linalg.norm(patch[i]-mean_coord) < max_att[ax,ay]: 758 | continue 759 | max_att[ax,ay] = np.linalg.norm(patch[i]-mean_coord) 760 | gflag[ax,ay] = 1 761 | img = arr[ax,ay].imshow(full_img[i], interpolation='nearest') 762 | img.set_cmap('hot') 763 | plt.axis('off') 764 | # Removes axis numbers 765 | 766 | arr[ax,ay].text(-0.1, 1.1, pos[i], transform=arr[ax,ay].transAxes, 767 | size=10, weight='bold',color=colors[name[i]]) 768 | #arr[ax,ay].text(-0.1, 1.1, topic[i], transform=arr[ax,ay].transAxes, 769 | # size=20, weight='bold',color=colors[name[i]]) 770 | plt.savefig('dataset_image_vis.jpg') 771 | plt.savefig('dataset_image_vis.pdf') 772 | 773 | 774 | grid_size = 20 775 | gflag = np.zeros((grid_size,grid_size)) 776 | max_att = np.zeros((grid_size,grid_size)) 777 | f, arr = plt.subplots(grid_size,grid_size,figsize=(20,20)) 778 | #print(pc) 779 | for i in range(grid_size): 780 | for j in range(grid_size): 781 | arr[i,j].axis('off') 782 | for i, _ in enumerate(full_img): 783 | #if topic[i] not in top_topics: 784 | # continue 785 | if pos[i] == (-1,-1): 786 | continue 787 | ax = int(pc[i,0]) - int(x0) + grid_size//2 788 | ay = int(pc[i,1]) - int(y0) + grid_size//2 789 | #arr[ax,ay].axis('off') 790 | # set the pos in image i to be stripes of black and white 791 | for ii in range(pos[i][0]*16, pos[i][0]*16+16): 792 | for jj in range(pos[i][1]*16, pos[i][1]*16+16): 793 | if (ii+jj)%2 == 0: 794 | full_img[i][ii,jj] = 0 795 | else: 796 | full_img[i][ii,jj] = 1 797 | 798 | if abs(ax+1/2-grid_size/2)>=grid_size/2 or abs(ay+1/2-grid_size/2)>=grid_size/2: 799 | continue 800 | if int(gflag[ax,ay]): 801 | if np.linalg.norm(patch[i]-mean_coord) < max_att[ax,ay]: 802 | continue 803 | max_att[ax,ay] = np.linalg.norm(patch[i]-mean_coord) 804 | gflag[ax,ay] = 1 805 | img = arr[ax,ay].imshow(full_img[i], interpolation='nearest') 806 | img.set_cmap('hot') 807 | plt.axis('off') 808 | # Removes axis numbers 809 | 810 | arr[ax,ay].text(-0.1, 1.1, pos[i], transform=arr[ax,ay].transAxes, 811 | size=10, weight='bold',color=colors[name[i]]) 812 | #arr[ax,ay].text(-0.1, 1.1, topic[i], transform=arr[ax,ay].transAxes, 813 | # size=20, weight='bold',color=colors[name[i]]) 814 | plt.savefig('dataset_image_masked_vis.jpg') 815 | plt.savefig('dataset_image_masked_vis.pdf') 816 | 817 | def plot_topics(topic_patches, topic_attention, topic_image, topic_masked_image, topic_labels): 818 | #num_topic, num_patch = len(topic_patches), len(topic_patches[0]) 819 | #matched_class = [17, 0, -1, 10,18, 4,8, 12,6, 7] 820 | order = 
[0,6,2,8,1,5,9,3,4,7] 821 | oo = [0,1,2,3,4,5,6,7,8,9] 822 | mp = dict(zip(order,oo)) 823 | matched_class = [0,0,0,1,0,1,0,0,0,1] 824 | num_topic, num_patch = 10, 40 825 | width_per_subplot = 0.5 826 | height_per_subplot = 0.5 827 | custom_figsize=(num_patch * width_per_subplot, num_topic * height_per_subplot) 828 | f, arr = plt.subplots(num_topic, num_patch, figsize=custom_figsize) 829 | 830 | for i in range(num_topic): 831 | if matched_class[i] == -1: 832 | continue 833 | selected_idx = np.where(topic_labels[i] == matched_class[i])[0] 834 | while len(selected_idx) < num_patch: 835 | ll = selected_idx[-1].item() 836 | if ll == 199: 837 | ll -=50 838 | selected_idx = np.concatenate((selected_idx, np.array([ll+1]))) 839 | #import pdb; pdb.set_trace() 840 | topic_patches[i][:len(selected_idx)] = topic_patches[i][selected_idx] 841 | #topic_attention[i] = topic_attention[i][selected_idx] 842 | topic_image[i] = [topic_image[i][j] for j in selected_idx] 843 | topic_masked_image[i] = [topic_masked_image[i][j] for j in selected_idx] 844 | topic_labels[i][:len(selected_idx)] = topic_labels[i][selected_idx] 845 | 846 | 847 | #f, arr = plt.subplots(num_topic,num_patch,figsize=(num_topic,num_patch)) 848 | for i in range(num_topic): 849 | for j in range(num_patch): 850 | arr[i,j].axis('off') 851 | if j >= num_patch: 852 | continue 853 | #import pdb 854 | #pdb.set_trace() 855 | topic_attention[i][j] = int(1000*topic_attention[i][j])/1000 856 | #arr[i,j].text(-0.1, 1.1, topic_attention[i][j], transform=arr[i,j].transAxes, 857 | #size=10, weight='bold') 858 | img = arr[mp[i],j].imshow(topic_patches[i][j], interpolation='nearest') 859 | img.set_cmap('hot') 860 | #f.subplots_adjust(wspace=0.1, hspace=0) # Adjust the spacing between subplots here 861 | plt.savefig('topic_patch_vis.jpg') 862 | plt.savefig('topic_patch_vis.pdf') 863 | 864 | f, arr = plt.subplots(num_topic,num_patch,figsize=custom_figsize) #(num_topic,num_patch) 865 | for i in range(num_topic): 866 | for j in range(num_patch): 867 | arr[i,j].axis('off') 868 | if j >= num_patch: 869 | continue 870 | 871 | topic_attention[i][j] = int(1000*topic_attention[i][j])/1000 872 | #arr[i,j].text(-0.1, 1.1, topic_attention[i][j], transform=arr[i,j].transAxes, 873 | #size=10, weight='bold') 874 | img = arr[mp[i],j].imshow(topic_image[i][j], interpolation='nearest') 875 | img.set_cmap('hot') 876 | 877 | plt.savefig('topic_image_vis.jpg') 878 | plt.savefig('topic_image_vis.pdf') 879 | 880 | 881 | f, arr = plt.subplots(num_topic,num_patch,figsize=custom_figsize) 882 | for i in range(num_topic): 883 | for j in range(num_patch): 884 | arr[i,j].axis('off') 885 | if j >= num_patch: 886 | continue 887 | 888 | topic_attention[i][j] = int(1000*topic_attention[i][j])/1000 889 | arr[i,j].text(-0.1, 1.1, topic_labels[i][j], transform=arr[i,j].transAxes, 890 | size=5, weight='bold') 891 | img = arr[mp[i],j].imshow(topic_masked_image[i][j], interpolation='nearest') 892 | img.set_cmap('hot') 893 | 894 | plt.savefig('topic_masked_image_vis.jpg') 895 | plt.savefig('topic_masked_image_vis.pdf') 896 | return f, arr 897 | 898 | 899 | def run_kmeans(X,K): 900 | # initialize centers, return mean and variance 901 | #centers = kmeans_init(X,K) 902 | # run kmeans algorithm for 1/2 iterations and return the mean and variance 903 | kmeans = KMeans(n_clusters=args.K,random_state=0, n_init=args.K, max_iter=10) 904 | ret = kmeans.fit(X) 905 | mean = ret.cluster_centers_ 906 | variance = 0 907 | print(mean.mean(0)) 908 | print(X.mean(axis=0)) 909 | return mean#, variance 910 | 911 | 912 | 
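# NOTE (added hedged sketch, not part of the original source): it shows how k-means
# centers such as those returned by run_kmeans above can seed per-cluster Gaussian
# parameters (e.g. PACE's _mus/_sigmas); the function name is illustrative. Relies on
# the np and KMeans imports already at the top of this file.
def kmeans_gaussian_init_sketch(X, K, seed=0):
    """Estimate (means, covariances) from a K-means partition of X with shape (n, d)."""
    km = KMeans(n_clusters=K, random_state=seed, n_init=10).fit(X)
    d = X.shape[1]
    means = km.cluster_centers_
    covs = np.stack([np.eye(d) for _ in range(K)])  # identity fallback for small clusters
    for k in range(K):
        members = X[km.labels_ == k]
        if len(members) > d:  # enough points for a stable covariance estimate
            covs[k] = np.cov(members, rowvar=False) + 1e-6 * np.eye(d)
    return means, covs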
913 | def kmeans_init_prev(X,K, init_idx=None): 914 | centers = [] 915 | index = [] 916 | avg = 0 917 | if init_idx is None: 918 | init_idx = np.random.randint(0,len(X)) 919 | avg = X[init_idx] 920 | centers.append(avg) 921 | 922 | else: 923 | for idx, j in enumerate(init_idx): 924 | avg = (avg*idx + X[j])/(idx+1) 925 | centers.append(X[j]) 926 | index.append(j) 927 | #init_len = len(centers) 928 | for idx in range(K): 929 | if len(centers) == K: 930 | break 931 | dist = 0 932 | new_center = None 933 | new_index = None 934 | for j in range(len(X)): 935 | if np.linalg.norm(X[j]-avg) > dist: 936 | dist = np.linalg.norm(X[j]-avg) 937 | new_center = X[j] 938 | new_index = j 939 | print(idx, dist) 940 | avg = (avg*len(centers) + new_center)/(len(centers)+1) 941 | centers.append(new_center) 942 | index.append(new_index) 943 | 944 | 945 | 946 | return centers, index 947 | 948 | def kmeans_init(X, K, init_idx=None): 949 | centers = [] 950 | index = [] 951 | 952 | # If initial indices are not provided, select one data point randomly 953 | if init_idx is None: 954 | init_idx = [np.random.randint(0, len(X))] 955 | 956 | # Initialize the centers and indices using provided or random indices 957 | for idx in init_idx: 958 | centers.append(X[idx]) 959 | index.append(idx) 960 | 961 | # Fill in the remaining centers 962 | while len(centers) < K: 963 | avg = np.mean(centers, axis=0) # Calculate the average of current centers 964 | dist = 0 965 | new_center = None 966 | new_index = None 967 | 968 | for j, x in enumerate(X): 969 | # Check if point is not already a center and is farther from the average 970 | if j not in index and np.linalg.norm(x - avg) > dist: 971 | dist = np.linalg.norm(x - avg) 972 | new_center = x 973 | new_index = j 974 | 975 | if new_center is not None: 976 | centers.append(new_center) 977 | index.append(new_index) 978 | print(len(centers), dist) 979 | return centers, index 980 | 981 | def softmax(x): # 2D 982 | """Compute softmax values for each sets of scores in x.""" 983 | # x (B,d) 984 | e_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) 985 | return e_x / e_x.sum(axis=-1, keepdims=True) 986 | 987 | 988 | def ensure_rgb_image(image): 989 | """Ensure image is in RGB format""" 990 | if hasattr(image, 'mode'): 991 | if image.mode != 'RGB': 992 | image = image.convert('RGB') 993 | return image 994 | 995 | 996 | def load_dataset_by_task(task, data_path): 997 | """ 998 | Load dataset based on task name 999 | Args: 1000 | task: task name ('flower102', 'cub2011', 'cars', 'Color') 1001 | data_path: path to dataset 1002 | Returns: 1003 | train_dataset, test_dataset, out_dim 1004 | """ 1005 | img_size = (224, 224) 1006 | 1007 | if task == 'flower102': 1008 | dataset_name = "nelorth/oxford-flowers" 1009 | dataset = load_dataset(dataset_name) 1010 | 1011 | # ensure the image is in RGB format 1012 | train_images = [ensure_rgb_image(img) for img in dataset['train']['image']] 1013 | test_images = [ensure_rgb_image(img) for img in dataset['test']['image']] 1014 | 1015 | processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") 1016 | train_inputs = processor(train_images, return_tensors="pt") 1017 | test_inputs = processor(test_images, return_tensors="pt") 1018 | train_dataset = MyImageDataset(train_inputs['pixel_values'], dataset['train']['label']) 1019 | test_dataset = MyImageDataset(test_inputs['pixel_values'], dataset['test']['label']) 1020 | out_dim = 102 1021 | 1022 | elif task == 'cub2011': 1023 | dataset_name = "Donghyun99/CUB-200-2011" 1024 | dataset = 
load_dataset(dataset_name) 1025 | 1026 | # ensure the image is in RGB format 1027 | train_images = [ensure_rgb_image(img) for img in dataset['train']['image']] 1028 | test_images = [ensure_rgb_image(img) for img in dataset['test']['image']] 1029 | 1030 | # use the same processor to ensure consistency 1031 | processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") 1032 | train_inputs = processor(train_images, return_tensors="pt") 1033 | test_inputs = processor(test_images, return_tensors="pt") 1034 | train_dataset = MyImageDataset(train_inputs['pixel_values'], dataset['train']['label']) 1035 | test_dataset = MyImageDataset(test_inputs['pixel_values'], dataset['test']['label']) 1036 | out_dim = 200 1037 | 1038 | elif task == 'cars': 1039 | dataset_name = "tanganke/stanford_cars" 1040 | dataset = load_dataset(dataset_name) 1041 | 1042 | # ensure the image is in RGB format 1043 | train_images = [ensure_rgb_image(img) for img in dataset['train']['image']] 1044 | test_images = [ensure_rgb_image(img) for img in dataset['test']['image']] 1045 | 1046 | # use the same processor to ensure consistency 1047 | processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") 1048 | train_inputs = processor(train_images, return_tensors="pt") 1049 | test_inputs = processor(test_images, return_tensors="pt") 1050 | train_dataset = MyImageDataset(train_inputs['pixel_values'], dataset['train']['label']) 1051 | test_dataset = MyImageDataset(test_inputs['pixel_values'], dataset['test']['label']) 1052 | out_dim = 196 1053 | 1054 | elif task == 'Color': 1055 | # Define a transformation to convert the images to PyTorch tensors 1056 | transform = transforms.Compose([ 1057 | transforms.ToTensor(), 1058 | transforms.Lambda(lambda x: x[:3, ...]), 1059 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize to range [-1,1] 1060 | ]) 1061 | 1062 | # Load the images and labels 1063 | dataset = [] 1064 | labels = [] 1065 | for class_dir in [os.path.join(data_path, 'Color/class0'), 1066 | os.path.join(data_path, 'Color/class1')]: 1067 | if os.path.exists(class_dir): 1068 | for image_name in os.listdir(class_dir): 1069 | # Read image 1070 | image = Image.open(os.path.join(class_dir, image_name)) 1071 | # Add to the lists 1072 | dataset.append(image) 1073 | labels.append(int(class_dir[-1])) # class ID from the directory name 1074 | 1075 | # Convert lists to tensors 1076 | labels = torch.tensor(labels) 1077 | 1078 | # Split into train and test sets 1079 | # Pair up the data and labels 1080 | paired_data = list(zip(dataset, labels)) 1081 | 1082 | # Perform the split on the paired data 1083 | train_size = int(0.8 * len(paired_data)) # 80% for training 1084 | test_size = len(paired_data) - train_size 1085 | train_data, test_data = random_split(paired_data, [train_size, test_size]) 1086 | 1087 | train_images, train_labels = zip(*train_data) 1088 | test_images, test_labels = zip(*test_data) 1089 | 1090 | # Convert the zipped data back to lists or tensors as needed 1091 | train_images = list(train_images) 1092 | train_labels = list(train_labels) 1093 | test_images = list(test_images) 1094 | test_labels = list(test_labels) 1095 | 1096 | # Create MyImageDataset instances 1097 | train_dataset = MyImageDataset(train_images, train_labels, transform=transform) 1098 | test_dataset = MyImageDataset(test_images, test_labels, transform=transform) 1099 | out_dim = 2 1100 | 1101 | else: 1102 | raise ValueError(f"Unsupported task: {task}") 1103 | 1104 | return train_dataset, test_dataset, 
out_dim 1105 | 1106 | 1107 | ''' 1108 | def row_norms(X, squared=False): 1109 | """Row-wise (squared) Euclidean norm of X. 1110 | Equivalent to np.sqrt((X * X).sum(axis=1)), but also supports sparse 1111 | matrices and does not create an X.shape-sized temporary. 1112 | Performs no input validation. 1113 | Parameters 1114 | ---------- 1115 | X : array-like 1116 | The input array. 1117 | squared : bool, default=False 1118 | If True, return squared norms. 1119 | Returns 1120 | ------- 1121 | array-like 1122 | The row-wise (squared) Euclidean norm of X. 1123 | """ 1124 | if sparse.issparse(X): 1125 | if not isinstance(X, sparse.csr_matrix): 1126 | X = sparse.csr_matrix(X) 1127 | norms = csr_row_norms(X) 1128 | else: 1129 | norms = np.einsum("ij,ij->i", X, X) 1130 | 1131 | if not squared: 1132 | np.sqrt(norms, norms) 1133 | return norms 1134 | 1135 | def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared=False): 1136 | """Computational part of euclidean_distances 1137 | Assumes inputs are already checked. 1138 | If norms are passed as float32, they are unused. If arrays are passed as 1139 | float32, norms needs to be recomputed on upcast chunks. 1140 | TODO: use a float64 accumulator in row_norms to avoid the latter. 1141 | """ 1142 | if X_norm_squared is not None: 1143 | if X_norm_squared.dtype == np.float32: 1144 | XX = None 1145 | else: 1146 | XX = X_norm_squared.reshape(-1, 1) 1147 | elif X.dtype == np.float32: 1148 | XX = None 1149 | else: 1150 | XX = row_norms(X, squared=True)[:, np.newaxis] 1151 | 1152 | if Y is X: 1153 | YY = None if XX is None else XX.T 1154 | else: 1155 | if Y_norm_squared is not None: 1156 | if Y_norm_squared.dtype == np.float32: 1157 | YY = None 1158 | else: 1159 | YY = Y_norm_squared.reshape(1, -1) 1160 | elif Y.dtype == np.float32: 1161 | YY = None 1162 | else: 1163 | YY = row_norms(Y, squared=True)[np.newaxis, :] 1164 | 1165 | if X.dtype == np.float32: 1166 | # To minimize precision issues with float32, we compute the distance 1167 | # matrix on chunks of X and Y upcast to float64 1168 | distances = _euclidean_distances_upcast(X, XX, Y, YY) 1169 | else: 1170 | # if dtype is already float64, no need to chunk and upcast 1171 | distances = -2 * safe_sparse_dot(X, Y.T, dense_output=True) 1172 | distances += XX 1173 | distances += YY 1174 | np.maximum(distances, 0, out=distances) 1175 | 1176 | # Ensure that distances between vectors and themselves are set to 0.0. 1177 | # This may not be the case due to floating point rounding errors. 1178 | if X is Y: 1179 | np.fill_diagonal(distances, 0) 1180 | 1181 | return distances if squared else np.sqrt(distances, out=distances) 1182 | 1183 | 1184 | # ref: https://github.com/scikit-learn/scikit-learn/blob/baf0ea25d/sklearn/cluster/_kmeans.py#L154 1185 | 1186 | def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trials=None): 1187 | """Computational component for initialization of n_clusters by 1188 | k-means++. Prior validation of data is assumed. 1189 | Parameters 1190 | ---------- 1191 | X : {ndarray, sparse matrix} of shape (n_samples, n_features) 1192 | The data to pick seeds for. 1193 | n_clusters : int 1194 | The number of seeds to choose. 1195 | x_squared_norms : ndarray of shape (n_samples,) 1196 | Squared Euclidean norm of each data point. 1197 | random_state : RandomState instance 1198 | The generator used to initialize the centers. 1199 | See :term:`Glossary `. 
1200 | n_local_trials : int, default=None 1201 | The number of seeding trials for each center (except the first), 1202 | of which the one reducing inertia the most is greedily chosen. 1203 | Set to None to make the number of trials depend logarithmically 1204 | on the number of seeds (2+log(k)); this is the default. 1205 | Returns 1206 | ------- 1207 | centers : ndarray of shape (n_clusters, n_features) 1208 | The initial centers for k-means. 1209 | indices : ndarray of shape (n_clusters,) 1210 | The index location of the chosen centers in the data array X. For a 1211 | given index and center, X[index] = center. 1212 | """ 1213 | n_samples, n_features = X.shape 1214 | 1215 | centers = np.empty((n_clusters, n_features), dtype=X.dtype) 1216 | 1217 | # Set the number of local seeding trials if none is given 1218 | if n_local_trials is None: 1219 | # This is what Arthur/Vassilvitskii tried, but did not report 1220 | # specific results for other than mentioning in the conclusion 1221 | # that it helped. 1222 | n_local_trials = 2 + int(np.log(n_clusters)) 1223 | 1224 | # Pick first center randomly and track index of point 1225 | center_id = random_state.randint(n_samples) 1226 | indices = np.full(n_clusters, -1, dtype=int) 1227 | if sp.issparse(X): 1228 | centers[0] = X[center_id].toarray() 1229 | else: 1230 | centers[0] = X[center_id] 1231 | indices[0] = center_id 1232 | 1233 | # Initialize list of closest distances and calculate current potential 1234 | closest_dist_sq = _euclidean_distances( 1235 | centers[0, np.newaxis], X, Y_norm_squared=x_squared_norms, squared=True 1236 | ) 1237 | current_pot = closest_dist_sq.sum() 1238 | 1239 | # Pick the remaining n_clusters-1 points 1240 | for c in range(1, n_clusters): 1241 | # Choose center candidates by sampling with probability proportional 1242 | # to the squared distance to the closest existing center 1243 | rand_vals = random_state.uniform(size=n_local_trials) * current_pot 1244 | candidate_ids = np.searchsorted(stable_cumsum(closest_dist_sq), rand_vals) 1245 | # XXX: numerical imprecision can result in a candidate_id out of range 1246 | np.clip(candidate_ids, None, closest_dist_sq.size - 1, out=candidate_ids) 1247 | 1248 | # Compute distances to center candidates 1249 | distance_to_candidates = _euclidean_distances( 1250 | X[candidate_ids], X, Y_norm_squared=x_squared_norms, squared=True 1251 | ) 1252 | 1253 | # update closest distances squared and potential for each candidate 1254 | np.minimum(closest_dist_sq, distance_to_candidates, out=distance_to_candidates) 1255 | candidates_pot = distance_to_candidates.sum(axis=1) 1256 | 1257 | # Decide which candidate is the best 1258 | best_candidate = np.argmin(candidates_pot) 1259 | current_pot = candidates_pot[best_candidate] 1260 | closest_dist_sq = distance_to_candidates[best_candidate] 1261 | best_candidate = candidate_ids[best_candidate] 1262 | 1263 | # Permanently add best center candidate found in local tries 1264 | if sp.issparse(X): 1265 | centers[c] = X[best_candidate].toarray() 1266 | else: 1267 | centers[c] = X[best_candidate] 1268 | indices[c] = best_candidate 1269 | 1270 | return centers, indices 1271 | 1272 | 1273 | 1274 | ''' --------------------------------------------------------------------------------