├── .gitignore
├── PACE
│   ├── src
│   │   ├── eval_PACE.sh
│   │   ├── train_PACE.sh
│   │   ├── train_ViT.sh
│   │   ├── generate_data.py
│   │   ├── config.py
│   │   ├── evaluate.py
│   │   ├── augment.py
│   │   ├── main.py
│   │   ├── run_quantitative_theta.py
│   │   ├── model.py
│   │   └── utils.py
│   └── environment_PACE.yml
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | PACE/dataset
2 | PACE/results
3 | PACE/src/__pycache__
4 | PACE/ckpt
--------------------------------------------------------------------------------
/PACE/src/eval_PACE.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python run_quantitative_theta.py --task Color --name ViT-PACE --num_epochs 1 &
2 | CUDA_VISIBLE_DEVICES=1 python run_quantitative_theta.py --task flower102 --name ViT-PACE --num_epochs 1 &
3 | CUDA_VISIBLE_DEVICES=2 python run_quantitative_theta.py --task cub2011 --name ViT-PACE --num_epochs 1 &
4 | CUDA_VISIBLE_DEVICES=3 python run_quantitative_theta.py --task cars --name ViT-PACE --num_epochs 1
--------------------------------------------------------------------------------
/PACE/src/train_PACE.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main.py --train --task Color --name ViT-PACE --num_epochs 1 --pretrain_epoch 5 &
2 | CUDA_VISIBLE_DEVICES=1 python main.py --train --task flower102 --name ViT-PACE --num_epochs 1 --pretrain_epoch 10 &
3 | CUDA_VISIBLE_DEVICES=0 python main.py --train --task cub2011 --name ViT-PACE --num_epochs 1 --pretrain_epoch 10 &
4 | CUDA_VISIBLE_DEVICES=1 python main.py --train --task cars --name ViT-PACE --num_epochs 1 --pretrain_epoch 20
--------------------------------------------------------------------------------
/PACE/src/train_ViT.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python main.py --train --task Color --name ViT-base --num_epochs 5 --lr 1e-3 --require_grad
2 | CUDA_VISIBLE_DEVICES=1 python main.py --train --task flower102 --name ViT-base --num_epochs 10 --lr 1e-3 --require_grad
3 | CUDA_VISIBLE_DEVICES=1 python main.py --train --task cub2011 --name ViT-base --num_epochs 10 --lr 1e-3 --require_grad
4 | CUDA_VISIBLE_DEVICES=1 python main.py --train --task cars --name ViT-base --num_epochs 20 --weight_decay 0.01 --lr 5e-5 --require_grad --seed 2025 --train_batch_size 32 --eval_batch_size 64
--------------------------------------------------------------------------------
/PACE/src/generate_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import os
4 | from skimage.transform import resize
5 | from skimage.util import random_noise
6 | import random
7 |
8 | # Create directories for saving the images
9 | if not os.path.exists('../dataset/Color/class0'):
10 | os.makedirs('../dataset/Color/class0')
11 |
12 | if not os.path.exists('../dataset/Color/class1'):
13 | os.makedirs('../dataset/Color/class1')
14 |
15 | # Image dimensions
16 | original_size = (2, 2)
17 | target_size = (224, 224)
18 |
19 | # Number of images to generate for each class
20 | num_images = 1000
21 |
22 | # Color codes for red, yellow, blue, green, and black in RGB
23 | color_dict = {'red': [1, 0, 0], 'yellow': [1, 1, 0], 'blue': [0, 0, 1], 'green': [0, 1, 0], 'black': [0, 0, 0]}
24 |
25 | # Function to create an image
26 | def create_image(colors, num_black):
27 |     # Start from an all-black image, then color 4 - num_black randomly chosen cells
28 | image = np.zeros((*original_size, 3))
29 | colored_locs = random.sample([(i, j) for i in range(2) for j in range(2)], 4-num_black)
30 | for idx, loc in enumerate(colored_locs):
31 | image[loc] = colors[idx%2]
32 |
33 | return image
34 |
35 | # Generate images
36 | for i in range(num_images):
37 | # Class 0: red+yellow
38 | num_black = random.randint(0, 2)
39 | image0 = create_image([color_dict['red'], color_dict['yellow']], num_black)
40 | image0 = random_noise(image0, mode='gaussian')
41 | image0_resized = resize(image0, target_size)
42 | plt.imsave(f'../dataset/Color/class0/{i}.png', image0_resized)
43 |
44 | # Class 1: blue+green
45 | num_black = random.randint(0, 2)
46 | image1 = create_image([color_dict['blue'], color_dict['green']], num_black)
47 | image1 = random_noise(image1, mode='gaussian')
48 | image1_resized = resize(image1, target_size)
49 | plt.imsave(f'../dataset/Color/class1/{i}.png', image1_resized)
--------------------------------------------------------------------------------
/PACE/src/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser(description='PACE')
3 |
4 | # data args
5 | parser.add_argument('--data_path', type=str, help='path of dataset',
6 | default='../dataset')
7 | parser.add_argument('--task',type=str,help='task name of dataset',default='toy')
8 | parser.add_argument('--save_path', type=str, help='path to save',
9 | default='../ckpt')
10 | parser.add_argument('--load_path', type=str, help='path to load',
11 | default='../ckpt')
12 |
13 | # model args
14 |
15 | ## for PACE
16 | parser.add_argument('--c_dim', type=int, help='dimension of PACE',default=25)
17 | parser.add_argument('--K', type=int, help='number of centers of PACE',default=100)
18 | parser.add_argument('--D', type=int, help='number of images',default=10000)
19 | parser.add_argument('--N', type=int, help='max number of patch tokens per image (196 patches + CLS)',default=197)
20 | parser.add_argument('--alpha', type=float, help='alpha prior of PACE',default=2)
21 | parser.add_argument('--eta', type=float, help='weight of PACE',default=1)
22 | parser.add_argument('--frac', type=str,help='type of fractional model', default='fix')
23 | parser.add_argument('--layer', type=int, help='layer of PACE',default=-2)
24 | parser.add_argument('--version', type=str, help='running version',default='v0')
25 |
26 | ## for ViT
27 | parser.add_argument('--lm', type=str, help='which language model', default='ViT')
28 | parser.add_argument('--b_dim', type=int, help='dimension of ViT',default=768)
29 | parser.add_argument('--out_dim', type=int, help='dimension of output',default=5)
30 | parser.add_argument('--name', type=str, help='model name', default='ViT-PACE')
31 | parser.add_argument('--seed',type=int, default=2021)
32 | parser.add_argument('--lr', type=float, help='learning rate', default=3e-5)
33 | parser.add_argument('--weight_decay', type=float, help='weight decay', default=0.05)
34 | parser.add_argument('--train', action='store_true', default=False)
35 | parser.add_argument('--require_grad', action='store_true', default=False)
36 | parser.add_argument('--pretrain_epoch', type=str, help='loading pretrain epoch', default='5')
37 |
38 | # optimization args
39 | parser.add_argument('--num_epochs', type=int, help='number of epochs',default=10)
40 | parser.add_argument('--train_batch_size', type=int, help='training batch size',default=16)
41 | parser.add_argument('--eval_batch_size', type=int, help='eval batch size',default=64)
42 | parser.add_argument('--metric', type=str, help='eval metric',default='eval_accuracy')
43 |
44 |
45 | # for quantitative theta
46 | parser.add_argument('--load_concepts', action='store_true', default=False,
47 | help='load concepts and probabilities from saved path instead of inferring them')
48 |
--------------------------------------------------------------------------------
/PACE/src/evaluate.py:
--------------------------------------------------------------------------------
1 | # import metric packages
2 | import numpy as np
3 | import pandas as pd
4 | from numpy.linalg import *
5 | from scipy.linalg import sqrtm
6 | from scipy.special import gammaln, psi
7 | import torch
8 | from sklearn.decomposition import PCA
9 | import matplotlib.pyplot as plt
10 | from datasets import load_metric
11 | from sklearn.linear_model import LogisticRegression, LinearRegression
12 | from utils import *
13 | from sklearn.pipeline import make_pipeline
14 | from sklearn import preprocessing
15 | from scipy.stats import entropy
16 | from sklearn.neural_network import MLPClassifier
17 | from sklearn.preprocessing import StandardScaler
18 | from sklearn.preprocessing import Normalizer
19 |
20 |
21 | def to_numpy(x):
22 | if isinstance(x, torch.Tensor):
23 | return x.detach().cpu().numpy()
24 | else:
25 | return x
26 |
27 |
28 | def faithfulness(concept_train, pred_train, concept_test, pred_test, hard=False, prob_test=None):
29 | # faithfulness with MLP classifier
30 | concept_train = to_numpy(concept_train)
31 | pred_train = to_numpy(pred_train)
32 | concept_test = to_numpy(concept_test)
33 | pred_test = to_numpy(pred_test)
36 | pipe = make_pipeline(
37 | StandardScaler(),
38 |         MLPClassifier(hidden_layer_sizes=(100,), max_iter=200)
39 | )
40 |
41 | clf = pipe.fit(concept_train, pred_train)
42 | score = clf.score(concept_test, pred_test)
43 |
44 | return score, -1
45 |
46 |
47 | def faithfulness_linear(concept_train, pred_train, concept_test, pred_test, hard=False, prob_test=None):
48 | # faithfulness with linear classifier
49 | concept_train = to_numpy(concept_train)
50 | pred_train = to_numpy(pred_train)
51 | concept_test = to_numpy(concept_test)
52 | pred_test = to_numpy(pred_test)
58 | pipe = make_pipeline(preprocessing.StandardScaler(), LogisticRegression(random_state=0, max_iter=2500))
59 | clf = pipe.fit(concept_train, pred_train)
60 | hard_score = clf.score(concept_test, pred_test)
61 | prob = clf.predict_proba(concept_test)
62 | soft_score = np.mean([entropy(prob[i], prob_test[i]) for i in range(len(prob))])
63 |
64 | return hard_score, soft_score
65 |
66 |
67 | def stability(concept_orig, concept_aug):
68 | assert len(concept_orig.shape) == 2
69 | assert concept_orig.shape == concept_aug.shape
70 | delta = np.linalg.norm(concept_orig - concept_aug, axis=1) / np.linalg.norm(concept_orig, axis=1)
71 |
72 | return np.mean(delta)
73 |
74 | def sparsity(concept):
75 | assert len(concept.shape) == 2
76 | eps = 0.1 / concept.shape[1]
77 | return np.mean(concept < eps)
78 |
79 | def parsimony(concept):
80 | assert len(concept.shape) == 2
81 | return concept.shape[1]
82 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Probabilistic Conceptual Explainers for Foundation Models
2 | This repo contains the code and data for our PACE paper (ICML 2024):
3 |
4 | **Probabilistic Conceptual Explainers: Trustworthy Conceptual Explanations for Vision Foundation Models**
5 | Hengyi Wang*, Shiwei Tan*, Hao Wang
6 | [[Paper](http://www.wanghao.in/paper/ICML24_PACE.pdf)] [[ICML Website](https://icml.cc/virtual/2024/poster/34650)]
7 |
8 | and our VALC (EMNLP 2024 Findings paper):
9 |
10 | **Variational Language Concepts for Interpreting Foundation Language Models**
11 | Hengyi Wang, Shiwei Tan, Zhiqing Hong, Desheng Zhang, Hao Wang
12 | [[Paper](http://www.wanghao.in/paper/EMNLP24_VALC.pdf)] [[ACL Website](https://aclanthology.org/2024.findings-emnlp.505/)]
13 |
14 | ## Brief Introduction for PACE
15 | We propose five desiderata for explaining vision foundation models such as ViTs - faithfulness, stability, sparsity, multi-level structure, and parsimony - and demonstrate the inadequacy of current methods in meeting these criteria comprehensively. Rather than using sparse autoencoders (SAEs), we introduce a variational Bayesian explanation framework, dubbed ProbAbilistic Concept Explainers (PACE), which models the distributions of patch embeddings to provide trustworthy post-hoc conceptual explanations. PACE provides dataset-, image-, and patch-level explanations for ViTs and achieves all five desiderata in a unified framework.
16 |
17 | ## Probabilistic Conceptual Explainers (PACE) for Vision Transformers (ViTs)
18 | PACE is compatible with *arbitrary* vision transformers.
19 |
20 | Below are some sample concepts automatically discovered by our PACE, *without the need for concept annotation during training*.
21 |
22 | 
23 |
24 | **Figure 1.** Above are some sample concepts discovered by PACE in the *COLOR* dataset. See Figure 3 of [our paper](http://wanghao.in/paper/ICML24_PACE.pdf) for details on the *COLOR* dataset.
25 |
26 | 
27 |
28 | **Figure 2.** Above are some sample concepts discovered by PACE in the *Oxford Flower* dataset.
29 |
30 |
31 |
32 | ### Installation
33 | ```bash
34 | conda env create -f environment_PACE.yml
35 | conda activate PACE
36 | cd src
37 | ```
38 |
39 | ### Generate the *Color* Dataset
40 | ```bash
41 | python generate_data.py
42 | ```
43 | ### Fine-tune ViT on the Color Dataset and Real-World Datasets
44 | ```bash
45 | bash ./train_ViT.sh
46 | ```
47 |
48 | ### Train PACE for Each Dataset
49 | ```bash
50 | bash ./train_PACE.sh
51 | ```
52 |
53 | ### Test PACE for Each Dataset
54 | ```bash
55 | bash ./eval_PACE.sh
56 | ```
57 |
58 |
59 | ## Variational Language Concepts (VALC) for Pretrained Language Models
60 |
61 | Coming Soon!
62 |
63 | ## Reference
64 |
65 | ```bib
66 | @inproceedings{PACE,
67 | title={Probabilistic Conceptual Explainers: Trustworthy Conceptual Explanations for Vision Foundation Models},
68 | author={Hengyi Wang and
69 | Shiwei Tan and
70 | Hao Wang},
71 | booktitle={International Conference on Machine Learning},
72 | year={2024}
73 | }
74 |
75 | @inproceedings{VALC,
76 | title={Variational Language Concepts for Interpreting Foundation Language Models},
77 | author={Hengyi Wang and
78 | Shiwei Tan and
79 | Zhiqing Hong and
80 | Desheng Zhang and
81 | Hao Wang},
82 | booktitle={Findings of the Association for Computational Linguistics: EMNLP 2024},
83 | year={2024}
84 | }
85 | ```
86 |
--------------------------------------------------------------------------------
/PACE/src/augment.py:
--------------------------------------------------------------------------------
2 | import numpy as np
3 | import pandas as pd
4 | from numpy.linalg import *
5 | from scipy.linalg import sqrtm
6 | from scipy.special import gammaln, psi
7 | import torch
8 | from sklearn.decomposition import PCA
9 | import matplotlib.pyplot as plt
10 | from datasets import load_metric
11 | from transformers import ViTFeatureExtractor
12 | from torchvision.transforms import (CenterCrop,
13 | Compose,
14 | Normalize,
15 | RandomHorizontalFlip,
16 | RandomResizedCrop,
17 | Resize,
18 | ToTensor)
19 | from torch import FloatTensor, div
20 | from torch.utils.data import DataLoader, Dataset
21 | from torchvision import transforms
22 | from torchvision.transforms.functional import InterpolationMode
23 |
24 | import torch.nn.functional as F
25 | import os
26 | import pandas as pd
27 | from torchvision.datasets.folder import default_loader
28 | from torchvision.datasets.utils import download_url
29 | from torch.utils.data import Dataset
30 | import pickle
31 | from sklearn.cluster import KMeans
32 | from sklearn.linear_model import LogisticRegression
33 | from utils import *
34 |
35 |
36 |
37 |
38 | def cov_z(phi):
39 | ret = np.empty((phi.shape[0], phi.shape[0]))
40 | for i in range(phi.shape[0]):
41 | for j in range(phi.shape[0]):
42 | if i == j:
43 | ret[i,j] = phi[i]*(1-phi[i])
44 | else:
45 | ret[i,j] = - phi[i]*phi[j]
46 | return ret
47 |
48 |
49 |
50 |
51 |
52 | def contrastive_learning(e, e_prime, temperature=0.1):
53 | """
54 | Args:
55 | e (torch.Tensor): A tensor of shape (B, d) representing the original embeddings.
56 | e_prime (torch.Tensor): A tensor of shape (B, d) representing the transformed embeddings.
57 | temperature (float): A temperature hyperparameter to scale the logits.
58 |
59 | Returns:
60 | loss (torch.Tensor): The contrastive loss value.
61 | """
62 | # Normalize the embeddings and transformations to have unit length
63 | e = e.float()
64 | e_prime = e_prime.float()
65 | e.requires_grad_(True)
66 |
67 | e = F.normalize(e, dim=-1)
68 | e_prime = F.normalize(e_prime, dim=-1)
69 |
70 | # Compute dot product similarity between e and e' (positive pairs)
71 | positive_similarity = torch.sum(e * e_prime, dim=-1)
72 |
73 | # Compute dot product similarity between e and all other e' (negative pairs)
74 | negative_similarity = torch.mm(e, e_prime.t())
75 |
76 | # Remove the similarity of positive pairs from negative_similarity
77 | diagonal_indices = torch.arange(e.shape[0])
78 | negative_similarity[diagonal_indices, diagonal_indices] = float('-inf')
79 |
80 | # Compute the logits for positive and negative pairs
81 | logits = torch.cat([positive_similarity.unsqueeze(1), negative_similarity], dim=1)
82 |
83 | # Apply temperature scaling to the logits
84 | logits /= temperature
85 |
86 | # Create labels for the positive pairs (the first column in logits)
87 | labels = torch.zeros(e.shape[0], dtype=torch.long, device=e.device)
88 |
89 |     # Compute the InfoNCE-style cross-entropy loss (positive pair = class 0)
90 |     loss = F.cross_entropy(logits, labels)
91 |     # Gradient of the loss w.r.t. the normalized embeddings e, via autograd
92 |     grad_e = torch.autograd.grad(loss, e, retain_graph=True)[0]
96 |
97 | return grad_e, loss
98 |
99 | def image_augment(image):
100 | b, c, h, w = image.shape
101 | contrast_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
102 | transforms.RandomResizedCrop(size=w),
103 | transforms.RandomApply([
104 | transforms.ColorJitter(brightness=0.5,
105 | contrast=0.5,
106 | saturation=0.5,
107 | hue=0.1)
108 | ], p=0.8),
109 | transforms.RandomGrayscale(p=0.2),
110 | transforms.GaussianBlur(kernel_size=9),
111 | #transforms.ToTensor()
112 | transforms.Normalize((0.5,), (0.5,))
113 | ])
114 | image_trans = contrast_transforms(image).cuda()
115 | return image_trans
116 |
117 | def contrastive_transform(image, model, topic_model, temperature=0.1):
118 |     # image: input image batch (B, C, H, W), already a tensor; model: ViT model
119 |     b, c, h, w = image.shape
120 |     contrast_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
121 |                                           transforms.RandomResizedCrop(size=w),
122 | transforms.RandomApply([
123 | transforms.ColorJitter(brightness=0.5,
124 | contrast=0.5,
125 | saturation=0.5,
126 | hue=0.1)
127 | ], p=0.8),
128 | transforms.RandomGrayscale(p=0.2),
129 | transforms.GaussianBlur(kernel_size=9),
131 | transforms.Normalize((0.5,), (0.5,))
132 | ])
133 | image_trans = contrast_transforms(image)
134 | embed = model(image)
135 | embed_trans = model(image_trans)
136 |
137 | return embed, embed_trans
138 |
--------------------------------------------------------------------------------
/PACE/environment_PACE.yml:
--------------------------------------------------------------------------------
1 | name: PACE
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - _tflow_select=2.3.0=mkl
8 | - absl-py=1.3.0=py37h06a4308_0
9 | - astor=0.8.1=py37h06a4308_0
10 | - astunparse=1.6.3=py_0
11 | - asynctest=0.13.0=py_0
12 | - backcall=0.2.0=pyhd3eb1b0_0
13 | - blas=1.0=mkl
14 | - blinker=1.4=py37h06a4308_0
15 | - brotli=1.0.9=h5eee18b_7
16 | - brotli-bin=1.0.9=h5eee18b_7
17 | - brotlipy=0.7.0=py37h27cfd23_1003
18 | - c-ares=1.19.1=h5eee18b_0
19 | - ca-certificates=2023.08.22=h06a4308_0
20 | - cachetools=4.2.2=pyhd3eb1b0_0
21 | - certifi=2022.12.7=py37h06a4308_0
22 | - cffi=1.15.0=py37h7f8727e_0
23 | - cryptography=39.0.1=py37h9ce1e76_0
24 | - cudatoolkit=10.1.243=h6bb024c_0
25 | - cycler=0.11.0=pyhd3eb1b0_0
26 | - dbus=1.13.18=hb2f20db_0
27 | - decorator=5.1.1=pyhd3eb1b0_0
28 | - entrypoints=0.4=py37h06a4308_0
29 | - expat=2.5.0=h6a678d5_0
30 | - fftw=3.3.9=h27cfd23_1
31 | - fontconfig=2.14.1=h52c9d5c_1
32 | - fonttools=4.25.0=pyhd3eb1b0_0
33 | - freetype=2.12.1=h4a9f257_0
34 | - frozenlist=1.3.3=py37h5eee18b_0
35 | - gast=0.4.0=pyhd3eb1b0_0
36 | - giflib=5.2.1=h5eee18b_3
37 | - glib=2.63.1=h5a9c865_0
38 | - google-auth=2.6.0=pyhd3eb1b0_0
39 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
40 | - google-pasta=0.2.0=pyhd3eb1b0_0
41 | - grpcio=1.42.0=py37hce63b2e_0
42 | - gst-plugins-base=1.14.0=hbbd80ab_1
43 | - gstreamer=1.14.0=hb453b48_1
44 | - h5py=2.10.0=py37hd6299e0_1
45 | - hdf5=1.10.6=h3ffc7dd_1
46 | - icu=58.2=he6710b0_3
47 | - idna=3.4=py37h06a4308_0
48 | - importlib_metadata=4.11.3=hd3eb1b0_0
49 | - intel-openmp=2021.4.0=h06a4308_3561
50 | - joblib=1.1.1=py37h06a4308_0
51 | - jpeg=9e=h5eee18b_1
52 | - jupyter_client=7.4.9=py37h06a4308_0
53 | - jupyter_core=4.11.2=py37h06a4308_0
54 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0
55 | - kiwisolver=1.4.4=py37h6a678d5_0
56 | - lcms2=2.12=h3be6417_0
57 | - lerc=3.0=h295c915_0
58 | - libbrotlicommon=1.0.9=h5eee18b_7
59 | - libbrotlidec=1.0.9=h5eee18b_7
60 | - libbrotlienc=1.0.9=h5eee18b_7
61 | - libdeflate=1.17=h5eee18b_0
62 | - libedit=3.1.20221030=h5eee18b_0
63 | - libffi=3.2.1=hf484d3e_1007
64 | - libgcc-ng=11.2.0=h1234567_1
65 | - libgfortran-ng=11.2.0=h00389a5_1
66 | - libgfortran5=11.2.0=h1234567_1
67 | - libgomp=11.2.0=h1234567_1
68 | - libpng=1.6.39=h5eee18b_0
69 | - libprotobuf=3.20.3=he621ea3_0
70 | - libsodium=1.0.18=h7b6447c_0
71 | - libstdcxx-ng=11.2.0=h1234567_1
72 | - libtiff=4.5.1=h6a678d5_0
73 | - libuuid=1.41.5=h5eee18b_0
74 | - libwebp=1.2.4=h11a3e52_1
75 | - libwebp-base=1.2.4=h5eee18b_1
76 | - libxcb=1.15=h7f8727e_0
77 | - libxml2=2.9.14=h74e7548_0
78 | - lz4-c=1.9.4=h6a678d5_0
79 | - markdown=3.4.1=py37h06a4308_0
80 | - markupsafe=2.1.1=py37h7f8727e_0
81 | - matplotlib=3.5.3=py37h06a4308_0
82 | - matplotlib-base=3.5.3=py37hf590b9c_0
83 | - matplotlib-inline=0.1.6=py37h06a4308_0
84 | - mkl=2021.4.0=h06a4308_640
85 | - mkl-service=2.4.0=py37h7f8727e_0
86 | - mkl_fft=1.3.1=py37hd3c417c_0
87 | - mkl_random=1.2.2=py37h51133e4_0
88 | - munkres=1.1.4=py_0
89 | - ncurses=6.4=h6a678d5_0
90 | - numpy-base=1.21.5=py37ha15fc14_3
91 | - oauthlib=3.2.1=py37h06a4308_0
92 | - openssl=1.1.1w=h7f8727e_0
93 | - opt_einsum=3.3.0=pyhd3eb1b0_1
94 | - parso=0.8.3=pyhd3eb1b0_0
95 | - pcre=8.45=h295c915_0
96 | - pexpect=4.8.0=pyhd3eb1b0_3
97 | - pickleshare=0.7.5=pyhd3eb1b0_1003
98 | - pip=22.3.1=py37h06a4308_0
99 | - ptyprocess=0.7.0=pyhd3eb1b0_2
100 | - pyasn1=0.4.8=pyhd3eb1b0_0
101 | - pyasn1-modules=0.2.8=py_0
102 | - pycparser=2.21=pyhd3eb1b0_0
103 | - pyjwt=2.4.0=py37h06a4308_0
104 | - pyopenssl=23.0.0=py37h06a4308_0
105 | - pyparsing=3.0.9=py37h06a4308_0
106 | - pyqt=5.9.2=py37h05f1152_2
107 | - pysocks=1.7.1=py37_1
108 | - python=3.7.2=h0371630_0
109 | - python-dateutil=2.8.2=pyhd3eb1b0_0
110 | - python-flatbuffers=2.0=pyhd3eb1b0_0
111 | - qt=5.9.7=h5867ecd_1
112 | - readline=7.0=h7b6447c_5
113 | - requests-oauthlib=1.3.0=py_0
114 | - rsa=4.7.2=pyhd3eb1b0_1
115 | - scikit-learn=1.0.2=py37h51133e4_1
116 | - scipy=1.7.3=py37h6c91a56_2
117 | - setuptools=65.6.3=py37h06a4308_0
118 | - sip=4.19.8=py37hf484d3e_0
119 | - six=1.16.0=pyhd3eb1b0_1
120 | - sqlite=3.33.0=h62c20be_0
121 | - tensorboard=2.10.0=py37h06a4308_0
122 | - tensorboard-data-server=0.6.1=py37h52d8a92_0
123 | - tensorboard-plugin-wit=1.8.1=py37h06a4308_0
124 | - tensorflow=2.4.1=mkl_py37h2d14ff2_0
125 | - tensorflow-base=2.4.1=mkl_py37h43e0292_0
126 | - tensorflow-estimator=2.6.0=pyh7b7c402_0
127 | - termcolor=2.1.0=py37h06a4308_0
128 | - threadpoolctl=2.2.0=pyh0d69192_0
129 | - tk=8.6.12=h1ccaba5_0
130 | - tornado=6.2=py37h5eee18b_0
131 | - tqdm=4.64.1=py37h06a4308_0
132 | - typing_extensions=4.1.1=pyh06a4308_0
133 | - werkzeug=2.2.2=py37h06a4308_0
134 | - wheel=0.38.4=py37h06a4308_0
135 | - wrapt=1.14.1=py37h5eee18b_0
136 | - xz=5.4.2=h5eee18b_0
137 | - zeromq=4.3.4=h2531618_0
138 | - zlib=1.2.13=h5eee18b_0
139 | - zstd=1.5.5=hc292b87_0
140 | - pip:
141 | - accelerate==0.20.3
142 | - aiohttp==3.8.5
143 | - aiosignal==1.3.1
144 | - appdirs==1.4.4
145 | - async-timeout==4.0.3
146 | - attrs==23.1.0
147 | - chardet==5.2.0
148 | - charset-normalizer==2.1.1
149 | - click==8.1.7
150 | - craft-xai==0.0.3
151 | - datasets==2.13.1
152 | - debugpy==1.6.7.post1
153 | - dill==0.3.6
154 | - docker-pycreds==0.4.0
155 | - filelock==3.12.2
156 | - fsspec==2023.1.0
157 | - gitdb==4.0.10
158 | - gitpython==3.1.34
159 | - huggingface-hub==0.16.4
160 | - imageio==2.31.2
161 | - importlib-metadata==6.7.0
162 | - ipykernel==6.16.2
163 | - ipython==7.34.0
164 | - jedi==0.19.0
165 | - jupyter-core==4.12.0
166 | - loguru==0.7.2
167 | - mingpt==0.0.1
168 | - multidict==6.0.4
169 | - multiprocess==0.70.14
170 | - nest-asyncio==1.5.7
171 | - networkx==2.6.3
172 | - numpy==1.21.6
173 | - nvidia-cublas-cu11==11.10.3.66
174 | - nvidia-cuda-nvrtc-cu11==11.7.99
175 | - nvidia-cuda-runtime-cu11==11.7.99
176 | - nvidia-cudnn-cu11==8.5.0.96
177 | - opencv-python==4.8.0.76
178 | - packaging==23.1
179 | - pandas==1.3.5
180 | - pathtools==0.1.2
181 | - pillow==9.5.0
182 | - prompt-toolkit==3.0.39
183 | - protobuf==4.24.2
184 | - psutil==5.9.5
185 | - pyarrow==12.0.1
186 | - pygments==2.16.1
187 | - python-graphviz==0.20.1
188 | - pytz==2023.3
189 | - pywavelets==1.3.0
190 | - pyyaml==6.0.1
191 | - pyzmq==25.1.1
192 | - regex==2023.8.8
193 | - requests==2.31.0
194 | - safetensors==0.3.3
195 | - scikit-image==0.19.3
196 | - sentry-sdk==1.30.0
197 | - setproctitle==1.3.2
198 | - smmap==5.0.0
199 | - tifffile==2021.11.2
200 | - timm==0.9.7
201 | - tokenizers==0.13.3
202 | - torch==1.13.1
203 | - torchvision==0.14.1
204 | - torchviz==0.0.2
205 | - traitlets==5.9.0
206 | - transformers==4.30.2
207 | - typing-extensions==4.7.1
208 | - urllib3==2.0.4
209 | - wandb==0.15.9
210 | - wcwidth==0.2.6
211 | - xxhash==3.3.0
212 | - yarl==1.9.2
213 | - zipp==3.15.0
214 |
--------------------------------------------------------------------------------
/PACE/src/main.py:
--------------------------------------------------------------------------------
1 | from transformers import Trainer, TrainingArguments, AutoFeatureExtractor
2 | from transformers import EarlyStoppingCallback, TrainerCallback
3 | from transformers import ViTFeatureExtractor, ViTForImageClassification
4 | import torch
5 | from torch import nn
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader, random_split
8 | from datasets import load_metric, load_dataset
9 | import pickle
10 | import os
11 | import sys, re, time, string
12 | import math
13 | from math import pi
14 | import numpy as np
15 | from numpy import random
16 | from numpy.linalg import *
17 | import pandas as pd
18 | import scipy.sparse as sp
19 | from scipy.special import gammaln, psi
20 | from sklearn.mixture import GaussianMixture
21 | from sklearn.decomposition import PCA
22 | from sklearn import manifold
23 | from tqdm import tqdm
24 | import matplotlib.pyplot as plt
25 | from PIL import Image
26 | import torchvision
27 | import torchvision.transforms as transforms
28 | from torchvision.transforms.functional import InterpolationMode
29 | from torchviz import make_dot
30 | from config import parser
31 | from utils import accuracy_score, dirichlet_expectation, read_tsv_file, compute_metrics, Adam, posterior_mu_sigma
32 | from utils import Cub2011, StanfordCars, MyImageDatasetFromStanfordCars, build_transform, load_dataset_by_task
33 | from utils import MyImageDataset
34 | from model import PACE, ViTClassify
35 | from augment import image_augment
36 | 
49 | args = parser.parse_args()
50 |
51 | #print("args", args)
52 |
53 | args.save_path = os.path.normpath(os.path.join(args.save_path, args.name)) # adjusts to windows or linux
54 |
55 | # create the directory if it doesn't exist
56 | if not os.path.exists(args.save_path):
57 | print(f"❌ Directory does not exist: {args.save_path}")
58 | os.makedirs(args.save_path, exist_ok=True)
59 | print(f"✅ Created directory: {args.save_path}")
60 |
61 | # Windows: ..\ckpt\ViT-base
62 | # Linux/macOS: ../ckpt/ViT-base
63 |
64 | np.random.seed(args.seed)
65 | torch.manual_seed(args.seed)
66 | random.seed(args.seed)
67 |
68 | class MyEarlyStoppingCallback(EarlyStoppingCallback):
69 |
70 | def __init__(self, early_stopping_patience=1, early_stopping_threshold=0.0, args=None):
71 | super(MyEarlyStoppingCallback,self).__init__(early_stopping_patience, early_stopping_threshold)
72 | self.epochs = 0
73 | self.ub = 0
74 | self.flag = True
75 | self.args = args
76 |
77 | def on_evaluate(self, args, state, control, metrics, **kwargs):
78 | metric_to_check = args.metric_for_best_model
79 | if not metric_to_check.startswith("eval_"):
80 | metric_to_check = f"eval_{metric_to_check}"
81 | metric_value = metrics.get(metric_to_check)
82 |
83 | self.check_metric_value(args, state, control, metric_value)
84 | if self.early_stopping_patience_counter >= self.early_stopping_patience:
85 | control.should_training_stop = True
86 | self.epochs += 1
87 | self.flag = True
88 |
89 | if PACE is None:
90 | return
91 |
92 | # save model
93 | args = self.args
94 | torch.save(model.state_dict(), args.save_path +'/' + args.task + '_epoch'+str(self.epochs)+'.pt')
95 | np.save(args.save_path+'/' + args.task + '_mus-epoch'+str(self.epochs) +'.npy',PACE._mus)
96 | np.save(args.save_path+'/' + args.task + '_sigmas-epoch'+str(self.epochs)+'.npy',PACE._sigmas)
97 | np.save(args.save_path+'/' + args.task + '_eta-epoch'+str(self.epochs)+'.npy',PACE._eta)
98 |
99 |
100 | class PACETrainer(Trainer):
101 |
102 |     def compute_loss(self,model,inputs,return_outputs=False, **kwargs):
103 | #output = model(inputs['encodings']) # get predict outputs and last word embeddings
104 | logits, states, att = model(inputs['encodings'])
105 | image_trans = image_augment(inputs['encodings'])
106 | logits_trans, states_trans, att_trans = model(image_trans)
107 |
108 | loss = torch.nn.CrossEntropyLoss()
109 | pred_prob = np.exp(logits.detach().cpu().numpy())/np.exp(logits.detach().cpu().numpy()).sum(axis=1)[:,None]
110 | pred_ids = torch.argmax(logits,-1)
111 | ViT_loss = loss(logits, inputs['labels']).sum()
112 | if mycallback.flag is True:
113 | mycallback.flag = False
114 |
115 | if PACE is not None and return_outputs is False:
116 |
117 | gamma_trans, phi_trans = PACE.do_e_step(states_trans, att_trans[args.layer + 1])
118 | PACE._phi_trans = PACE._phi
119 | gamma, phi = PACE.do_em_step(states, att[args.layer + 1], cl=True,y=pred_prob)
120 | PACE.update_eta(logits)
121 | pred_prob = torch.argmax(logits, dim=-1)
122 | PACE_loss = PACE._delta
123 |             custom_loss = PACE_loss
124 |         else:
125 |             custom_loss = ViT_loss
126 | 
127 | 
128 |         pred = {'label_ids':inputs['labels'], 'predictions':pred_ids}
129 |         return (custom_loss, pred) if return_outputs else custom_loss
130 |
131 | 
132 | img_size = (224, 224)
133 | randaug_magnitude = 0
145 | transform = transforms.Compose([
146 | transforms.ToTensor(),
147 | transforms.Resize(img_size, interpolation=InterpolationMode.BICUBIC),
148 | #transforms.RandAugment(num_ops=2,magnitude=randaug_magnitude),
149 | transforms.Normalize(
150 | mean=(0.485, 0.456, 0.406),
151 | std=(0.229, 0.224, 0.225)
152 | ),
153 | ])
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 | # Dataset loading using load_dataset_by_task function
162 | train_dataset, test_dataset, args.out_dim = load_dataset_by_task(args.task, args.data_path)
163 | val_dataset = test_dataset
164 |
165 | # Create data loaders for easier batch processing
166 | train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
167 | test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
168 |
169 | model = ViTClassify(in_dim = args.b_dim, out_dim=args.out_dim,hid_dim=args.c_dim, layer=args.layer)
170 | model = model.cuda()
171 |
172 | if 'PACE' in args.name:
173 | PACE = PACE(d=args.c_dim,K=args.K,D=args.D,N=args.N,alpha=args.alpha,C = args.out_dim)
174 | else:
175 | PACE = None
176 |
177 | training_args = TrainingArguments(
178 | output_dir='../results', # output directory
179 | num_train_epochs=args.num_epochs, # total number of training epochs
180 | per_device_train_batch_size=args.train_batch_size, # batch size per device during training
181 | per_device_eval_batch_size=args.eval_batch_size, # batch size for evaluation
182 | warmup_steps=500, # number of warmup steps for learning rate scheduler
183 | weight_decay=args.weight_decay, # strength of weight decay
184 | logging_dir='./logs', # directory for storing logs
185 | logging_steps=10,
186 | seed = args.seed,
187 | load_best_model_at_end=True,
188 | metric_for_best_model=args.metric, # 'eval_matthews_correlation' for cola, etc.
189 |     evaluation_strategy='epoch',
190 | save_strategy='epoch',
191 | learning_rate = args.lr,
192 | lr_scheduler_type='cosine', # cosine learning rate scheduler
193 | save_total_limit=3, # limit the number of saved checkpoints
194 | gradient_accumulation_steps=1, # gradient accumulation steps
195 | fp16=True, # use mixed precision training to accelerate
196 | )
197 |
198 | mycallback = MyEarlyStoppingCallback(early_stopping_patience=10, args=args)
199 |
200 | trainer = PACETrainer(
201 | model=model, # the instantiated 🤗 Transformers model to be trained
202 | args=training_args, # training arguments, defined above
203 | train_dataset=train_dataset, # training dataset
204 | eval_dataset=val_dataset, # evaluation dataset
205 | compute_metrics=compute_metrics,
206 | callbacks=[mycallback],
207 |
208 | )
209 |
210 | test_set = DataLoader(val_dataset,batch_size=args.eval_batch_size,shuffle=False)
211 |
212 | print('train size', len(train_dataset))
213 | print('eval size', len(val_dataset))
214 |
215 | if args.train:
216 | print('training...')
217 | if not args.require_grad: # Train PACE, otherwise train ViT
218 | model.load_state_dict(torch.load('../ckpt/ViT-base' +'/' + args.task + '_epoch' + args.pretrain_epoch +'.pt'))
219 | trainer.train()
220 | torch.save(model.state_dict(), args.save_path +'/' + args.task + '_epoch'+str(args.num_epochs)+'.pt')
221 | else:
222 | print('evaluating...')
223 | model.load_state_dict(torch.load(args.save_path +'/' + args.task + '_epoch'+str(args.num_epochs)+'.pt'))
224 |
225 | if PACE is not None:
226 | if args.train:
227 | np.save(args.save_path+'/' + args.task + '_mus-epoch'+str(args.num_epochs)+'.npy',PACE._mus)
228 | np.save(args.save_path+'/' + args.task + '_sigmas-epoch'+str(args.num_epochs)+'.npy',PACE._sigmas)
229 | np.save(args.save_path+'/' + args.task + '_eta-epoch'+str(args.num_epochs)+'.npy',PACE._eta)
230 | else:
231 | PACE._mus = np.load(args.save_path+'/' + args.task + '_mus-epoch'+str(args.num_epochs)+'.npy')
232 | PACE._sigmas = np.load(args.save_path+'/' + args.task + '_sigmas-epoch'+str(args.num_epochs)+'.npy')
233 | PACE._eta = np.load(args.save_path+'/' + args.task + '_eta-epoch'+str(args.num_epochs)+'.npy')
234 |
235 | # explain ViT
236 | print('PACE is explaining ViT...')
237 | for i, inputs in enumerate(test_set):
238 | test_encodings = inputs['encodings'].cuda()
239 | test_labels = inputs['labels'].cuda()
240 | logits, states, att = model(test_encodings)
241 |
242 | # infer phi and gamma
243 | gamma, phi = PACE.do_e_step(states, att[args.layer + 1])
244 | # infer E[log(theta)], which is the expectation of log(theta)
245 | E_log_theta = dirichlet_expectation(gamma)
246 |
247 | # Normalize gamma so each row sums to 1
248 | gamma = gamma / gamma.sum(axis=1, keepdims=True)
249 |
250 | # Collect images and gamma and then plot them
251 | num_to_plot = 10
252 | if i == 0:
253 | images_to_plot = []
254 | gammas_to_plot = []
255 | plotted = False
256 |
257 |     # Recover displayable images for this batch by un-normalizing the encodings
258 |     mean = torch.tensor([0.485, 0.456, 0.406], device=test_encodings.device).view(1, 3, 1, 1)
259 |     std = torch.tensor([0.229, 0.224, 0.225], device=test_encodings.device).view(1, 3, 1, 1)
260 |     unnormalized = (test_encodings * std + mean).clamp(0, 1)
261 |     # (H, W, C) float arrays in [0, 1], suitable for imshow
262 |     original_images = unnormalized.permute(0, 2, 3, 1).cpu().numpy()
263 |
264 | # Move gamma to CPU and convert to numpy if it's a tensor
265 | if isinstance(gamma, torch.Tensor):
266 | gamma_np = gamma.detach().cpu().numpy()
267 | else:
268 | gamma_np = gamma
269 |
270 | for img, g in zip(original_images, gamma_np):
271 | if len(images_to_plot) < num_to_plot:
272 | images_to_plot.append(img)
273 | gammas_to_plot.append(g)
274 | else:
275 | break
276 |
277 | # After the loop, plot the first num_to_plot images and their gamma
278 | if len(images_to_plot) >= num_to_plot and not plotted:
279 | print(f'plotting the first {num_to_plot} images and their normalized gamma, i.e., the image-level explanation theta...')
280 | fig, axes = plt.subplots(num_to_plot, 2, figsize=(8, 4 * num_to_plot))
281 | for idx in range(num_to_plot):
282 | # Show image
283 | axes[idx, 0].imshow(images_to_plot[idx])
284 | axes[idx, 0].axis('off')
285 | axes[idx, 0].set_title(f"Test Image {idx+1}")
286 | # Show gamma as bar plot
287 | axes[idx, 1].bar(np.arange(len(gammas_to_plot[idx])), gammas_to_plot[idx])
288 | axes[idx, 1].set_title(f"Gamma {idx+1}")
289 | plt.tight_layout()
290 |         plt.savefig(args.save_path + '/' + args.task + '_epoch' + str(args.num_epochs) + '_gamma_test_images.pdf')
291 | print('plotting done')
292 | plotted = True
293 | # plt.show()
--------------------------------------------------------------------------------
/PACE/src/run_quantitative_theta.py:
--------------------------------------------------------------------------------
1 | from transformers import pipeline
2 | from transformers import BertTokenizer, BertModel
3 | from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments, AutoFeatureExtractor
4 | from transformers import EarlyStoppingCallback, TrainerCallback
5 | from transformers import ViTFeatureExtractor, ViTForImageClassification
6 | import torch
7 | from torch import nn
8 | from torch.autograd import Variable
9 | from torch.utils.data import DataLoader
10 | from datasets import load_metric, load_dataset
11 | import pickle
12 | import os
13 | import sys, re, time, string
14 | import math
15 | from math import pi
16 | import numpy as np
17 | from numpy import random
18 | from numpy.linalg import *
19 | import pandas as pd
20 | import scipy.sparse as sp
21 | from scipy.special import gammaln, psi
22 | from sklearn.mixture import GaussianMixture
23 | from sklearn.decomposition import PCA
24 | from sklearn import manifold
25 | from tqdm import tqdm
26 | import matplotlib.pyplot as plt
27 | from PIL import Image
28 | import pdb
29 | import wandb
30 | import torchvision
31 | import torchvision.transforms as transforms
32 | from torchvision.transforms.functional import InterpolationMode
33 | from torchviz import make_dot
34 | from config import parser
35 | from utils import accuracy_score, dirichlet_expectation, read_tsv_file, compute_metrics, Adam, posterior_mu, posterior_mu_sigma, vis, kmeans_init
36 | from utils import run_kmeans, plot_topics
37 | from utils import load_train_data, load_val_data, softmax
38 | from utils import MyImageDataset
39 | from utils import attention_norm, topic_vis, load_dataset_by_task
40 | from model import PACE, ViTClassify
41 | from augment import contrastive_learning, contrastive_transform, image_augment
42 | from evaluate import stability, faithfulness, sparsity, parsimony, faithfulness_linear, to_numpy
43 | #from captum.attr import Lime, LimeBase
44 | #import shap, lime
54 |
55 | args = parser.parse_args()
56 | args.save_path = os.path.join(args.save_path, args.name)
57 | sample_path = os.path.join('../sample', args.name)
58 |
59 | np.random.seed(args.seed)
60 | torch.manual_seed(args.seed)
61 | random.seed(args.seed)
62 |
63 | # Dataset loading using load_dataset_by_task function
64 | train_dataset, test_dataset, args.out_dim = load_dataset_by_task(args.task, args.data_path)
65 | val_dataset = test_dataset
66 |
67 | model = ViTClassify(in_dim = args.b_dim, out_dim=args.out_dim,hid_dim=args.c_dim, layer=args.layer)
68 | #print(model)
69 |
70 | model = model.cuda()
71 |
72 | if 'PACE' in args.name:
73 | PACE = PACE(d=args.c_dim,K=args.K,D=args.D,N=args.N,alpha=args.alpha,C = args.out_dim)
74 | else:
75 | PACE = None
76 |
77 | training_args = TrainingArguments(
78 | output_dir='./results', # output directory
79 | num_train_epochs=args.num_epochs, # total number of training epochs
80 | per_device_train_batch_size=args.train_batch_size, # batch size per device during training
81 | per_device_eval_batch_size=args.eval_batch_size, # batch size for evaluation
82 |     warmup_steps=0,                  # number of warmup steps for the learning rate scheduler
83 | weight_decay=args.weight_decay, # strength of weight decay
84 | logging_dir='./logs', # directory for storing logs
85 | logging_steps=10,
86 | seed = args.seed,
87 | load_best_model_at_end=True,
88 | metric_for_best_model=args.metric, # 'eval_matthews_correlation' for cola, etc.
89 |     evaluation_strategy='epoch',
90 | save_strategy='epoch',
91 | learning_rate = args.lr,
92 | report_to="wandb",
93 | #resume_from_checkpoint=True,
94 | # eval_steps=100,
95 | )
96 |
97 |
98 |
99 | test_set = DataLoader(val_dataset,batch_size=args.eval_batch_size,shuffle=True)
100 | train_set = DataLoader(train_dataset,batch_size=args.train_batch_size,shuffle=True)
101 |
102 | print('train size', len(train_dataset))
103 | print('eval size', len(val_dataset))
104 |
105 | #X0 = np.load(os.path.join(args.save_path, 'X-L-2.npy'))
106 | #PACE._mus = PACE._mu0 = run_kmeans(X0, args.K)
107 |
108 | print('evaluating')
109 |
110 | model.load_state_dict(torch.load(args.save_path +'/' + args.task + '_epoch'+str(args.num_epochs)+'.pt'))
111 |
112 | #args.version = 'kmeans'
113 | if PACE is not None:
114 | PACE._mus = np.load(args.save_path+'/' + args.task + '_mus-epoch'+str(args.num_epochs)+'.npy')
115 | PACE._sigmas = np.load(args.save_path+'/' + args.task + '_sigmas-epoch'+str(args.num_epochs)+'.npy')
116 | PACE._eta = np.load(args.save_path+'/' + args.task + '_eta-epoch'+str(args.num_epochs)+'.npy')
117 |
118 |
119 |
120 |
121 | # temporarily CPU-bound rather than GPU-bound; needs multithreading if multiple runs execute simultaneously
122 | # numpy matrix manipulation test
123 |
124 | x = None
125 | pos = []
126 | topic = []
127 | name = []
128 | patch_img = []
129 | full_img = []
130 | word_embed = {}
131 | word_cnt = {}
132 | top_words = [{} for _ in range(args.K)] # maintain a priority queue of prob for tokens in each topic
133 | pred_label = []
134 | tok = []
135 | font = []
136 | topic_cnt = {}
137 |
138 |
139 | # for idx in range(args.K): # args.K
140 | # if PACE is None:
141 | # continue
142 | # name.append(0)
143 | # topic.append('T_'+str(idx))
144 | # #font.append(1) # np.exp(det(PACE._sigmas[idx]))
145 | # if x is None:
146 | # x = PACE._mus[idx].reshape(-1,args.c_dim)
147 | # else:
148 | # x = np.concatenate([x,PACE._mus[idx].reshape(-1,args.c_dim)],axis=0)
149 | # patch_img.append(np.ones((224//16,224//16,3))) # patch_img[-1].shape
150 | # full_img.append(np.ones((224,224,3))) # full_img[-1].shape
151 | # pos.append((-1,-1))
152 |
153 | # x = torch.Tensor(x).cuda()
154 |
155 | topic_cnt = dict(sorted(topic_cnt.items(), key=lambda item: item[1],reverse=True))
156 | print(topic_cnt)
157 | concepts = [[] for _ in range(args.K)]
158 | #top_topics = list(topic_cnt)[1:6]
159 | top_topics = {}
160 | tt_cp = [x for x in top_topics]
161 | tw = {}
162 |
163 |
164 | #top_topics = list(topic_cnt)[5:10]
165 |
166 |
167 | topic_se = 24
168 | class_1 = 10
169 | class_2 = 20
170 |
171 | sample_num = 50000#5000
172 | avg_corr = 0
173 | batch_cnt = 0
174 |
175 | # test metrics for LIME model
176 | # ref https://captum.ai/api/lime.html
177 |
178 |
179 | # interprete classifier from embedding inputs
180 |
181 | concept_all = []
182 | label_all = []
183 |
184 | model.eval()
185 | concept_test = []
186 | concept_aug_test = []
187 | concept_train = []
188 | concept_aug_train = []
189 | prob_train = []
190 | prob_test = []
191 |
192 | pred_train = []
193 | pred_test = []
194 | embeds = []
195 | corpus = []
196 | patches = []
197 | attentions = []
198 | concept_images = []
199 | masked_concept_images = []
200 | concept_labels = []
201 | left_up_att = []
202 | right_up_att = []
203 | left_down_att = []
204 | right_down_att = []
205 |
206 | model_map = {'Sedan', 'SUV', 'Convertible', 'Minivan', 'Coupe', 'Wagon', 'Hatchback', 'Van', 'Truck', 'Pickup'}
207 | make_map = {'Audi', 'BMW', 'Chevrolet', 'Dodge', 'Ford', 'Honda', 'Hyundai', 'Jeep', 'Lexus', 'Mercedes-Benz', 'Nissan', 'Porsche', 'Subaru', 'Toyota', 'Volkswagen'}
208 |
209 | dataset_model_cnt = {}
210 | dataset_make_cnt = {}
211 |
212 | # only perform inference if not loading concepts from saved path
213 | if not args.load_concepts:
214 | print("Performing inference on train and test datasets...")
215 |
216 |     # train dataset
217 | with torch.no_grad():
218 | cnt = 0
219 | for id, inputs in enumerate(train_set):
220 | print('train batch', id)
221 | train_encodings = inputs['encodings'].cuda()
222 | train_labels = inputs['labels'].cuda()
223 | #test_path = inputs['path']
224 | #test_mask = inputs['attention_mask'].cuda()
225 | #print(test_encodings.size())
226 | logits, states, att = model(train_encodings)
227 |
228 | # get augmented outputs
229 | image_trans = image_augment(inputs['encodings'])
230 | logits_trans, states_trans, att_trans = model(image_trans)
231 |
232 | preds = logits.argmax(-1)
233 | #print('preds', preds)
234 | #logits = logits.detach().cpu().numpy()
235 |
236 | for pp in preds:
237 | pred_train.append(pp)
238 | for i in range(len(logits)):
239 | prob_train.append((torch.softmax(logits[i], dim=0)).detach().cpu().numpy())
240 | #print('prob', prob_train[-1])
241 | if PACE is None:
242 | continue
243 |
244 | A_o = att[args.layer]
245 | #A_e = model.effective_attention(att[args.layer+1])
246 | #A_e = attention_norm(A_e)
247 | A_o = attention_norm(A_o)
248 | gamma, phi = PACE.do_e_step(states, A_o) # inference w/o learning, so e step instead of em step.
249 | #gamma_trans, phi_trans = PACE.do_e_step(states_trans, att_trans[args.layer + 1])
250 | E_log_theta = dirichlet_expectation(gamma)
251 | concept_train.append(np.exp(E_log_theta))
252 | #concept_train.append(phi.mean(1))
253 | #print(phi.mean(1)[0])
254 | # A_o_trans = att_trans[args.layer]
255 | # A_o_trans = attention_norm(A_o_trans)
256 | # gamma_trans, phi_trans = PACE.do_e_step(states_trans, A_o)
257 | # concept_aug_train.append(phi_trans.mean(1))
258 |
259 | # test dataset
260 | with torch.no_grad():
261 | cnt = 0
262 | for id, inputs in enumerate(test_set):
263 | print('test batch', id)
264 | test_encodings = inputs['encodings'].cuda()
265 | test_labels = inputs['labels'].cuda()
266 | #test_path = inputs['path']
267 | #test_mask = inputs['attention_mask'].cuda()
268 | #print(test_encodings.size())
269 | logits, states, att = model(test_encodings)
270 |
271 | # get augmented outputs
272 | image_trans = image_augment(inputs['encodings'])
273 | logits_trans, states_trans, att_trans = model(image_trans)
274 |
275 | preds = logits.argmax(-1)
276 | #logits = logits.detach().cpu().numpy()
277 |
278 | for pp in preds:
279 | pred_test.append(pp)
280 | for i in range(len(logits)):
281 | prob_test.append((torch.softmax(logits[i], dim=0)).detach().cpu().numpy())
282 | #print('prob', prob_test[-1])
283 | if PACE is None:
284 | continue
285 |
286 | A_o = att[args.layer]
287 | #A_e = model.effective_attention(att[args.layer+1])
288 | #A_e = attention_norm(A_e)
289 | A_o = attention_norm(A_o)
290 | gamma, phi = PACE.do_e_step(states, A_o) # inference w/o learning, so e step instead of em step.
291 | #gamma_trans, phi_trans = PACE.do_e_step(states_trans, att_trans[args.layer + 1])
292 | E_log_theta = dirichlet_expectation(gamma)
293 | concept_test.append(np.exp(E_log_theta))
294 | #concept_test.append(phi.mean(1))
295 | #print(phi.mean(1)[0])
296 | A_o_trans = att_trans[args.layer]
297 | A_o_trans = attention_norm(A_o_trans)
298 |         gamma_trans, phi_trans = PACE.do_e_step(states_trans, A_o_trans)
299 | E_log_theta = dirichlet_expectation(gamma_trans)
300 | concept_aug_test.append(np.exp(E_log_theta))
301 | #concept_aug_test.append(phi_trans.mean(1))
302 |
303 | # process inferred data
304 | concept_train = np.concatenate(concept_train, axis=0)
305 | concept_test = np.concatenate(concept_test, axis=0)
306 | concept_aug_test = np.concatenate(concept_aug_test, axis=0)
307 | # concept_aug_train = np.concatenate(concept_aug_train, axis=0)
308 | pred_train = to_numpy(torch.stack(pred_train))
309 | pred_test = to_numpy(torch.stack(pred_test))
310 | prob_train = np.array(prob_train)
311 | prob_test = np.array(prob_test)
312 | print('prob_test', prob_test.shape)
313 | else:
314 | print("Skipping inference - will load concepts and probabilities from saved path...")
315 |
316 | # conditionally load concepts and probabilities from saved path or use inferred ones
317 | if args.load_concepts:
318 | print("Loading concepts and probabilities from saved path...")
319 | concept_train = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_train.npy'))
320 | concept_test = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_test.npy'))
321 | # concept_aug_train = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_aug_train.npy'))
322 | concept_aug_test = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_aug_test.npy'))
323 | pred_train = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-pred_train.npy'))
324 | pred_test = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-pred_test.npy'))
325 | prob_train = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-prob_train.npy'))
326 | prob_test = np.load(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-prob_test.npy'))
327 | else:
328 | print("Using inferred concepts and probabilities from model...")
329 |
330 |
331 | # normalize concept_train and concept_test
332 | concept_train = concept_train / concept_train.sum(axis=1)[:,None]
333 | concept_test = concept_test / concept_test.sum(axis=1)[:,None]
334 | # concept_aug_train = concept_aug_train / concept_aug_train.sum(axis=1)[:,None]
335 | concept_aug_test = concept_aug_test / concept_aug_test.sum(axis=1)[:,None]
336 |
337 |
338 | # only save concepts and probabilities if we inferred them (not loaded from saved path)
339 | if not args.load_concepts:
340 | print("Saving inferred concepts and probabilities...")
341 | np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_train.npy'),concept_train)
342 | np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_test.npy'),concept_test)
343 |     # np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_aug_train.npy'),concept_aug_train)
344 | np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-concept_aug_test.npy'),concept_aug_test)
345 | np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-pred_train.npy'),pred_train)
346 | np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-pred_test.npy'),pred_test)
347 | np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-prob_train.npy'),prob_train)
348 | np.save(os.path.join(args.save_path, str(args.task) +'_epoch'+str(args.num_epochs) + '-prob_test.npy'),prob_test)
349 | else:
350 | print("Skipping save since concepts and probabilities were loaded from saved path...")
351 |
352 | #pdb.set_trace()
353 |
354 | stability_score = stability(concept_test, concept_aug_test)
355 |
356 | print('stability', stability_score)
357 |
358 | fhard, fsoft = faithfulness_linear(concept_train, pred_train, concept_test, pred_test, prob_test=prob_test)
359 | #faithfulness = faithfulness(concept_test, pred_test, concept_test, pred_test)
360 |
361 | print('faithfulness', fhard, fsoft)
362 |
363 | sparsity = sparsity(concept_test)
364 |
365 | print('sparsity', sparsity)
366 |
367 | parsimony = parsimony(concept_test)
368 |
369 | print('parsimony', parsimony)
370 |
371 | # log txt file
372 | with open(args.save_path + '/' + args.task + '_epoch' + str(args.num_epochs) + '.txt', 'w') as f:
373 | f.write('stability: ' + str(stability_score) + '\n')
374 | f.write('faithfulness: ' + str(fhard) + ' ' + str(fsoft) + '\n')
375 | f.write('sparsity: ' + str(sparsity) + '\n')
376 | f.write('parsimony: ' + str(parsimony) + '\n')
377 |
378 |
379 |
--------------------------------------------------------------------------------
/PACE/src/model.py:
--------------------------------------------------------------------------------
1 | from transformers import pipeline
2 | from transformers import ViTConfig, ViTForImageClassification
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader
8 | import pickle
9 | import os
10 | import sys, re, time, string
11 | import math
12 | from math import pi
13 | import numpy as np
14 | from numpy import random
15 | from numpy.linalg import *
16 | import pandas as pd
17 | import scipy.sparse as sp
18 | from scipy.special import gammaln, psi
19 | from scipy.special import softmax
20 | from sklearn.mixture import GaussianMixture
21 | from sklearn.decomposition import PCA
22 | from sklearn import manifold
23 | from tqdm import tqdm
24 | import matplotlib.pyplot as plt
25 | from config import parser
26 | #from config_parse_args import parser_args
27 | from utils import accuracy_score, dirichlet_expectation, read_tsv_file, compute_metrics, Adam, posterior_mu_sigma
28 | from augment import contrastive_learning, contrastive_transform
36 |
37 | args = parser.parse_args()
38 | args.save_path = os.path.join(args.save_path, args.name)
39 |
40 |
41 |
42 | class ViTClassify(nn.Module):
43 | def __init__(self, in_dim, out_dim, hid_dim, layer):
44 | super(ViTClassify, self).__init__()
45 | self.in_dim = in_dim
46 | self.out_dim = out_dim
47 | self.hid_dim = hid_dim
48 | self.layer = layer
49 |
50 |
51 | ViT_config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k", output_hidden_states=True, output_attentions=True, num_labels=out_dim)
52 | self.ViT = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', config=ViT_config)
53 |
54 |         if not args.require_grad: # freeze the ViT and train PACE instead
55 | for param in self.ViT.parameters():
56 | param.requires_grad = False
57 | self.linear = nn.Linear(in_dim, hid_dim)
58 | self.embedding = None
59 |
60 | def forward(self, encodings, labels=None):
61 | ViT_output = self.ViT(encodings)
62 | logits = ViT_output['logits']
63 | all_states = ViT_output['hidden_states']
64 | attention = ViT_output['attentions']
65 | states = all_states[self.layer]
66 | hidden = self.linear(states)
67 |         self.embedding = all_states[self.layer]
68 |
69 | return logits, hidden, attention
70 | class PACE:
71 |
72 | def __init__(self, d, K, D, N, alpha, C):
73 | '''
74 | Arguments:
75 | K: number of topics
76 | d: dimension of embedding space
77 | D: number of images
78 | alpha: prior on theta
79 | eta: prior on mu and Sigma
80 | '''
81 | self._d = d
82 | self._K = K
83 | self._D = D
84 | self._N = N
85 | self._C = C
86 | self._alpha = alpha
87 | self._updatect = 0
88 | self._gamma = None
89 | self._phi = None
90 | self._mu0 = np.random.randn(K,d) * 10
91 | self._mus = np.random.randn(K,d) * 10
92 | self._sigmas = np.array([np.eye(d) for _ in range(K)])
93 | self._eta = np.random.randn(C,K) # (num_class, num_topic)
94 | self._updatect = 0
95 | self._eps = 1e-50
96 | self._converge = False
97 | self._m_mu = 0
98 | self._m_sigma = 0
99 | self._cnt = 0
100 | self._snum = 0
101 | self._sigma0 = np.array([np.eye(self._d) for _ in range(self._K)])
102 | self._lr = 1e-2
103 | self._embeds = None
104 | self._delta = None
105 | self.faitful_loss = 0
106 | self.stability_loss = 0
107 | self._adam_delta = Adam(1e-4,0.9,0.9)
108 | self._reg = 1e-3
109 | self._w = None
110 | self.linear = nn.Linear(args.b_dim, args.c_dim).cuda()
111 | self._lrm = torch.empty(args.b_dim, args.c_dim).cuda()
112 | torch.nn.init.normal_(self._lrm)
113 | self._phi_trans = None
114 | self.optimizer = None
115 |
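    # Hedged usage sketch (illustrative, not from this repo): how a PACE
    # instance is typically driven alongside ViTClassify. `vit`, `train_loader`,
    # and the dimensions below are assumptions; d must match the hidden size
    # fed to PACE (e.g. 768 for ViT-base).
    #
    # >>> pace = PACE(d=768, K=10, D=len(train_loader.dataset), N=197, alpha=0.1, C=2)
    # >>> for batch in train_loader:
    # ...     logits, hidden, attention = vit(batch['encodings'].cuda())
    # ...     gamma, phi = pace.do_em_step(hidden, frac=attention[-1])
    # ...     pace.update_eta(logits)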
116 | def log_p_y(self, logits):
117 |
118 | logits = logits.cpu().detach().numpy()
119 | log_p = 0
120 | B, _ = logits.shape
121 | for b in range(B):
122 | phi_mean = self._phi[b].mean(0)
123 | pred_y = softmax(logits[b])
124 | for i in range(self._C):
125 | log_p += pred_y[i] * np.log(np.exp(self._eta[i].dot(phi_mean))/np.sum(np.exp(self._eta.dot(phi_mean))))
126 |
127 | return log_p
128 |
129 |
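    # log_p_y evaluates the expected class log-likelihood of the softmax
    # regression head on topic proportions: with phi_bar the mean topic vector
    # of an image,
    #     p(y = c | phi_bar) = exp(eta_c . phi_bar) / sum_c' exp(eta_c' . phi_bar),
    # and the method accumulates sum_b sum_c p_hat_b(c) * log p(c | phi_bar_b),
    # weighting each class by the classifier's predicted probability p_hat.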
130 | def update_eta(self, logits):
131 | logits = logits.cpu().detach().numpy()
132 | pred_y = softmax(logits, axis=1)
133 | phi_mean = self._phi.mean((0,1))
134 | for i in range(self._C):
135 | self._eta[i] = self._eta[i] - self._lr * (pred_y[:,i].mean() - np.exp(self._eta[i].dot(phi_mean))/np.sum(np.exp(self._eta.dot(phi_mean)))) * phi_mean
136 |
137 |
138 |
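    # update_eta above is one gradient step on the same objective: the update
    # direction is (batch-mean predicted probability of class c minus the
    # topic-regression probability p(c | phi_bar)) scaled by phi_bar, so eta
    # moves until predicting labels from topic proportions agrees with the
    # ViT classifier's average prediction.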
139 | def do_e_step(self, embeds, frac=None):
140 | batchD = len(embeds)
141 |
142 |
143 |         if args.frac == 'equal':  # weight every patch equally
144 |             self._w = torch.ones(1).cuda()
145 |         elif args.frac == 'fix':  # attention-based patch weights (mean over heads, CLS row)
146 |             self._w = frac.mean(1)[:,0,:]
147 |
148 | phi = random.gamma(self._K, 1./self._K, (batchD, embeds.size()[1], self._K))
149 | gamma = phi.sum(1)
150 |
151 |
152 |         sigma_invs = []
153 |         sigma_dets = []
154 |         for i in range(self._K):
155 |             sigma_inv = inv(self._sigmas[i] + self._eps * np.eye(self._d))
156 |             sigma_invs.append(sigma_inv)
157 |             sigma_det = det(self._sigmas[i])
158 |             sigma_dets.append(sigma_det)
159 | 
160 |         sigma_invs = Variable(torch.Tensor(np.array(sigma_invs)), requires_grad=True).cuda()
161 |         sigma_dets = Variable(torch.Tensor(np.array(sigma_dets)), requires_grad=True).cuda()
162 |         self._delta = Variable(torch.zeros(1), requires_grad=True)
163 | 
164 |         # Iterate between gamma and phi until convergence
165 |         tensor_mus = Variable(torch.Tensor(np.array(self._mus)), requires_grad=False).cuda()
166 |         tensor_sigmas = Variable(torch.Tensor(np.array(self._sigmas)), requires_grad=False).cuda()
190 |         for it in range(0, 10):  # fixed number of variational updates
191 |
192 | mus = tensor_mus.unsqueeze(-1)
193 | embeds = embeds.view(-1,self._d,1)
194 |
195 |
196 | dir_exp = dirichlet_expectation(gamma)
197 | dir_exp = Variable(torch.Tensor(dir_exp),requires_grad=False).view(-1,self._K).cuda()
198 | spd_tensor = None
199 | phi_tensor = Variable(torch.Tensor(phi),requires_grad=False).cuda()
200 | for i in range(self._K):
201 | mul = torch.matmul((embeds-mus[i,:]).transpose(1,2), sigma_invs[i,:,:])
202 | spd = 0.5 * torch.matmul(mul,embeds-mus[i,:]).view(-1,self._N)
203 |
204 | if spd_tensor is None:
205 | spd_tensor = spd.unsqueeze(-1)
206 | else:
207 | spd_tensor = torch.cat([spd_tensor,spd.unsqueeze(-1)],dim=2)
208 |
209 | phi_tmp = torch.exp(dir_exp[:,i].view(-1,1)-spd/100 )/torch.sqrt(abs(sigma_dets[i])+self._eps)
210 |
211 |                 phi_tmp = phi_tmp * self._w
212 | phi[:,:,i] = phi_tmp.cpu().detach().numpy()
213 |
214 | phi = (phi+self._eps)/(phi.sum(-1, keepdims=True) + self._eps * self._K) # along K axis
215 |
216 | self._delta = phi_tensor * spd_tensor
217 |
218 | self._delta = self._delta.mean() * batchD
219 | gamma = self._alpha + phi.sum(1) # along n axis
220 | del dir_exp, phi_tensor, spd_tensor
221 | del sigma_invs, sigma_dets, tensor_mus, tensor_sigmas
222 | self._phi = phi
223 | self._gamma = gamma
224 |
225 | return (gamma, phi)
226 |
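    # Summary of the E-step above (variational updates for an LDA-style model
    # with Gaussian concepts over patch embeddings x_n):
    #     phi_{n,k} proportional to w_n * exp(E[log theta_k] - (x_n - mu_k)^T Sigma_k^{-1} (x_n - mu_k) / 2) / sqrt(|Sigma_k|)
    #     gamma_k = alpha + sum_n phi_{n,k}
    # where E[log theta_k] comes from dirichlet_expectation(gamma); the code
    # also rescales the quadratic term to keep the exponentials stable.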
227 | def do_cl_e_step(self, embeds, frac=None, y = None):
228 | '''
229 | frac: attention of last hidden layer, size [B,num_heads,N,N]
230 | '''
231 | batchD = len(embeds)
232 | f_optimizer = Adam(1e-4,0.9,0.9)
233 | s_optimizer = Adam(1e-4,0.9,0.9)
234 |         if args.frac == 'equal':  # weight every patch equally
235 |             self._w = torch.ones(1).cuda()
236 |         elif args.frac == 'fix':  # attention-based patch weights (mean over heads, CLS row)
237 |             self._w = frac.mean(1)[:,0,:]
238 |
239 |
240 | phi = random.gamma(self._K, 1./self._K, (batchD, embeds.size()[1], self._K))
241 | gamma = phi.sum(1)
242 |
243 |
244 |         sigma_invs = []
245 |         sigma_dets = []
246 |         for i in range(self._K):
247 |             sigma_inv = inv(self._sigmas[i] + self._eps * np.eye(self._d))
248 |             sigma_invs.append(sigma_inv)
249 |             sigma_det = det(self._sigmas[i])
250 |             sigma_dets.append(sigma_det)
251 | 
252 |         sigma_invs = Variable(torch.Tensor(np.array(sigma_invs)), requires_grad=True).cuda()
253 |         sigma_dets = Variable(torch.Tensor(np.array(sigma_dets)), requires_grad=True).cuda()
254 |         self._delta = Variable(torch.zeros(1), requires_grad=True)
255 | 
256 |         # Iterate between gamma and phi until convergence
257 |         tensor_mus = Variable(torch.Tensor(np.array(self._mus)), requires_grad=False).cuda()
258 |         tensor_sigmas = Variable(torch.Tensor(np.array(self._sigmas)), requires_grad=False).cuda()
259 | 
284 |         for it in range(0, 10):  # fixed number of variational updates
285 | mus = tensor_mus.unsqueeze(-1)
286 | embeds = embeds.view(-1,self._d,1)
287 | dir_exp = dirichlet_expectation(gamma)
288 | dir_exp = Variable(torch.Tensor(dir_exp),requires_grad=False).view(-1,self._K).cuda()
289 | spd_tensor = None
290 | phi_tensor = Variable(torch.Tensor(phi),requires_grad=False).cuda()
291 | phi_mean = torch.mean(phi_tensor, dim=1).detach().cpu().numpy()
292 | phi_fair_delta = 1/self._N * (np.einsum('bi,ij->bj',y, self._eta)- np.einsum('bi,ij->bj',np.exp(np.einsum('ij,bj->bi',self._eta, phi_mean)), self._eta)/np.sum(np.exp(np.einsum('ij,bj->bi',self._eta, phi_mean)), axis=1, keepdims=True) )
293 | phi_fair_delta = torch.from_numpy(phi_fair_delta).cuda()
294 | phi_trans_mean = self._phi_trans.mean(1)
295 | phi_stable_delta, _ = contrastive_learning(torch.from_numpy(phi_mean).cuda(), torch.from_numpy(phi_trans_mean).cuda())
296 | phi_stable_delta *= -1/self._N
297 | eps = 1e-10
298 | for i in range(self._K):
299 | mul = torch.matmul((embeds-mus[i,:]).transpose(1,2), sigma_invs[i,:,:])
300 | spd = 0.5 * torch.matmul(mul,embeds-mus[i,:]).view(-1,self._N)
301 |
302 | if spd_tensor is None:
303 | spd_tensor = spd.unsqueeze(-1)
304 | else:
305 | spd_tensor = torch.cat([spd_tensor,spd.unsqueeze(-1)],dim=2)
306 | phi_tmp = torch.exp(dir_exp[:,i].view(-1,1)-spd+self._eps)
307 |                 phi_tmp = phi_tmp * self._w
308 | phi[:,:,i] = phi_tmp.cpu().detach().numpy()
309 | phi = (phi+self._eps)/(phi.sum(-1, keepdims=True) + self._eps * self._K) # along K axis
310 |
311 | self._delta = phi_tensor * spd_tensor
312 |
313 | self._delta = self._delta.mean() * batchD
314 | relevance_scale = 1
315 | stable_scale = 1
316 | phi = np.clip(phi, 0, 1e10)
317 | phi = (phi+self._eps)/(phi.sum(-1).reshape(batchD,self._N,1) + self._eps * self._K)
318 |             # additional gradients smoothed with an adaptive (Adam-style) optimizer
319 | f_grad = f_optimizer.update(phi_fair_delta.unsqueeze(1).detach().cpu().numpy())
320 | s_grad = s_optimizer.update(phi_stable_delta.unsqueeze(1).detach().cpu().numpy())
321 | phi += relevance_scale * f_grad + stable_scale * s_grad
322 | phi = (phi+self._eps)/(phi.sum(-1, keepdims=True) + self._eps * self._K)
323 |
324 | gamma = self._alpha + phi.sum(1) # along n axis
325 | del dir_exp, phi_tensor, spd_tensor
326 | del sigma_invs, sigma_dets, tensor_mus, tensor_sigmas
327 | self._phi = phi
328 | self._gamma = gamma
329 |
330 | return (gamma, phi)
331 |
332 |
333 |
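    # do_cl_e_step augments the plain E-step with two extra gradients on phi:
    #   * phi_fair_delta: a faithfulness term that pushes the mean topic vector
    #     toward values that predict the observed labels y through eta;
    #   * phi_stable_delta: a stability term from contrastive_learning that
    #     aligns phi for an image with phi of its augmented view (self._phi_trans).
    # Both are smoothed by the lightweight Adam helper from utils before being
    # added to phi, after which phi is re-normalized over the K concepts.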
334 |     def update_lambda(self, embeds):
335 |         '''
336 |         update variational parameter lambda for beta (placeholder, not implemented here)
337 |         '''
338 |     def do_em_step(self, embeds, frac=None, cl=False, y=None):
339 | '''
340 | first E step,
341 | then M step
342 |
343 | gamma (B,K)
344 | phi (B,N,K)
345 | embeds (B,N,d)
346 | cl: using contrastive learning
347 | '''
348 |
349 | if cl and y is not None:
350 | gamma, phi = self.do_cl_e_step(embeds=embeds, frac=frac, y=y)
351 | else:
352 | gamma, phi = self.do_e_step(embeds,frac)
353 | embeds = embeds.cpu().detach().numpy()
354 | batchD = len(embeds)
355 | last_mus = self._mus.copy()
356 | last_sigmas = self._sigmas.copy()
357 |
358 | mu_init = 0
359 | sigma_init = 0
360 | norm = 0
361 | gamma = gamma.reshape(-1, self._K)
362 | phi = phi.reshape(-1,self._K)
363 | embeds = embeds.reshape(-1,self._d)
364 | mu_init = phi.reshape(-1,self._K,1) * embeds.reshape(-1,1,self._d)
365 | mu_init = mu_init * self._w.cpu().detach().numpy().reshape(-1,1,1)
366 | mu_init = mu_init.sum(0)
367 | delta = embeds.reshape(-1,1,self._d)-self._mus.reshape(1,-1,self._d)
368 | square = delta.reshape(-1,self._K,1,self._d) * delta.reshape(-1,self._K,self._d, 1) # (B*N,K,d,d)
369 | sigma_init = phi.reshape(-1,self._K,1,1) * square
370 | sigma_init = sigma_init * self._w.cpu().detach().numpy().reshape(-1,1,1,1)
371 | sigma_init = sigma_init.sum(0)
372 | norm = (phi*self._w.cpu().detach().numpy().reshape(-1,1)).sum(0)
373 |
374 | rhot = 0.75
375 |
376 | self._m_mu = self._mus * rhot * self._cnt + mu_init/(norm.reshape(-1,1)+self._eps) * (1-rhot) * batchD
377 | self._m_sigma = self._sigmas * rhot * self._cnt + sigma_init/(norm.reshape(-1,1,1)+self._eps) * (1-rhot) * batchD
378 | self._cnt = self._cnt + batchD * (1-rhot)
379 | self._mus = self._m_mu / self._cnt
380 | self._sigmas = self._m_sigma / self._cnt
381 |
382 | self._snum += 1
383 | for i in range(self._K):
384 | self._mus[i], self._sigmas[i] = posterior_mu_sigma(self._mus[i], self._sigmas[i], self._snum, self._mu0[i])
385 |
386 | if abs(last_mus - self._mus).max() < 1e-5:
387 | self._converge = True
388 |
389 | phi = phi.reshape(-1,self._N,self._K)
390 |
391 | return gamma, phi
392 |
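# M-step notes for do_em_step: mu_k and Sigma_k are re-estimated as
# phi-weighted (and attention-weighted, via self._w) moments of the batch
# embeddings, blended into a running average at rate rhot = 0.75, and then
# shrunk toward the prior (mu0, identity) by posterior_mu_sigma, a
# conjugate-style posterior update with n0 pseudo-observations.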
393 |
--------------------------------------------------------------------------------
/PACE/src/utils.py:
--------------------------------------------------------------------------------
1 | from collections import Counter  # needed by get_tf_idf_mask
2 | import numpy as np
3 | import pandas as pd
4 | from numpy.linalg import *
5 | from scipy.linalg import sqrtm
6 | from scipy.special import gammaln, psi
7 | import torch
8 | from sklearn.decomposition import PCA
9 | import matplotlib.pyplot as plt
10 | from datasets import load_metric  # deprecated in newer `datasets` releases; see the `evaluate` package
11 | from config import parser
12 | from transformers import ViTFeatureExtractor, ViTImageProcessor
13 | from datasets import load_dataset
14 | from torch.utils.data import random_split
15 | from torchvision.transforms import (CenterCrop,
16 |                                     Compose,
17 |                                     Normalize,
18 |                                     RandomHorizontalFlip,
19 |                                     RandomResizedCrop,
20 |                                     Resize,
21 |                                     ToTensor)
22 | from torch import FloatTensor, div
23 | from torch.utils.data import DataLoader, Dataset
24 | from torchvision import transforms
25 | from torchvision.transforms.functional import InterpolationMode
26 | 
27 | import logging
28 | import os
29 | import random
30 | from torchvision.datasets.folder import default_loader
31 | from torchvision.datasets.utils import download_url
32 | import pickle
33 | from sklearn.cluster import KMeans
34 | from torchvision.datasets import VisionDataset
35 | from gensim.models import TfidfModel  # needed by get_tf_idf_mask (assumes gensim is installed)
36 | from PIL import Image, ImageOps
42 |
43 |
44 | def attention_norm(attention, a=1):
45 | # attention = attention.mean(1)
46 | # max(attention, 0)
47 | #attention = torch.max(attention, torch.zeros_like(attention))
48 | attention = attention - attention.min()
49 |
50 | #attention = attention ** a
51 | attention = attention / attention.sum(-1, keepdims=True)
52 |
53 | return attention
54 |
55 |
56 | class StanfordCars_simplified(torch.utils.data.Dataset):
57 | def __init__(self, root, transform = None):
58 | self.images = [os.path.join(root, file) for file in os.listdir(root)]
59 | self.transform = transform
60 |
61 | def __len__(self):
62 | return len(self.images)
63 |
64 | def __getitem__(self, index):
65 | image_file = self.images[index]
66 | image = Image.open(image_file).convert("RGB")
67 | if self.transform:
68 | image = self.transform(image)
69 | return image[None]
70 |
71 |
72 | class StanfordCars(VisionDataset):
73 | """`Stanford Cars `_ Dataset
74 |
75 | The Cars dataset contains 16,185 images of 196 classes of cars. The data is
76 | split into 8,144 training images and 8,041 testing images, where each class
77 | has been split roughly in a 50-50 split
78 |
79 | .. note::
80 |
81 | This class needs `scipy `_ to load target files from `.mat` format.
82 |
83 | Args:
84 | root (string): Root directory of dataset
85 | split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
86 | transform (callable, optional): A function/transform that takes in an PIL image
87 | and returns a transformed version. E.g, ``transforms.RandomCrop``
88 | target_transform (callable, optional): A function/transform that takes in the
89 | target and transforms it.
90 | download (bool, optional): If True, downloads the dataset from the internet and
91 | puts it in root directory. If dataset is already downloaded, it is not
92 | downloaded again."""
93 |
94 | def __init__(
95 | self,
96 | root: str,
97 | split = "train",
98 | transform = None,
99 | target_transform = None,
100 | download: bool = False,
101 | ):
102 |
103 | try:
104 | import scipy.io as sio
105 | except ImportError:
106 | raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
107 |
108 | super().__init__(root, transform=transform, target_transform=target_transform)
109 |
110 | #self._split = verify_str_arg(split, "split", ("train", "test"))
111 | #self._base_folder = pathlib.Path(root) / "stanford_cars"
112 | self._split = split
113 | self._base_folder = root
114 | devkit = os.path.join(self._base_folder, "devkit")
115 |
116 |
117 | if self._split == "train":
118 | self._annotations_mat_path = os.path.join(devkit , "cars_train_annos.mat")
119 | self._images_base_path = os.path.join(self._base_folder , "cars_train")
120 | else:
121 | self._annotations_mat_path = os.path.join(self._base_folder , "cars_test_annos_withlabels.mat")
122 | self._images_base_path = os.path.join(self._base_folder , "cars_test")
123 |
124 | if download:
125 | self.download()
126 |
127 | self._samples = [
128 | (
129 | str(os.path.join(self._images_base_path , annotation["fname"])),
130 | annotation["class"] - 1, # Original target mapping starts from 1, hence -1
131 | )
132 | for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
133 | ]
134 |
135 | self.classes = sio.loadmat(str(os.path.join(devkit , "cars_meta.mat")), squeeze_me=True)["class_names"].tolist()
136 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
137 |
138 | def __len__(self):
139 | return len(self._samples)
140 |
141 | def __getitem__(self, idx: int):
142 | """Returns pil_image and class_id for given index"""
143 | image_path, target = self._samples[idx]
144 | pil_image = Image.open(image_path).convert("RGB")
145 |
146 | if self.transform is not None:
147 | pil_image = self.transform(pil_image)
148 | if self.target_transform is not None:
149 | target = self.target_transform(target)
150 | return pil_image, target
151 |
152 |     def download(self):
153 |         if self._check_exists():
154 |             return
155 | 
156 |     def _check_exists(self):
157 |         # _base_folder is a plain string here, so use os.path checks rather than pathlib methods
158 |         if not os.path.isdir(os.path.join(self._base_folder, "devkit")):
159 |             return False
160 |         return os.path.exists(self._annotations_mat_path) and os.path.isdir(self._images_base_path)
160 |
161 |
162 |
163 |
164 | # NOTE: _train_transforms and _val_transforms are assumed to be defined at
165 | # module level elsewhere; they are not defined in this file.
166 | def train_transforms(examples):
167 |     examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['img']]
168 |     return examples
169 | 
170 | def val_transforms(examples):
171 |     examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['img']]
172 |     return examples
173 | 
174 | args = parser.parse_args()
175 |
176 | def accuracy_score(labels, preds):
177 |     acc = (preds == labels).astype(float).mean()  # the np.float alias was removed in NumPy 1.24
178 |     return acc
179 |
180 | def compute_metrics_acc(pred):
181 | #labels = pred.label_ids
182 | labels, preds = pred.predictions
183 |
184 | #print('labels',labels)
185 | #print('preds',preds)
186 | #precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
187 | acc = accuracy_score(labels, preds)
188 | return {
189 | 'accuracy': acc,
190 | #'f1': f1,
191 | #'precision': precision,
192 | #'recall': recall
193 | }
194 |
195 | def compute_metrics(pred):
196 | #labels = pred.label_ids
197 | labels, preds = pred.predictions
198 | #metric = load_metric('glue', args.task)
199 | metric =load_metric('accuracy')
200 | return metric.compute(predictions=preds, references=labels)
201 |
202 | def dirichlet_expectation(alpha):
203 | '''
204 | E[log(theta)|alpha], where theta ~ Dir(alpha).
205 | from blei/online LDA
206 | '''
207 | if len(alpha.shape) == 1: # 1D version
208 | return psi(alpha) - psi(np.sum(alpha))
209 | return psi(alpha) - psi(np.sum(alpha,1))[:, np.newaxis]
210 |
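# Sanity check (sketch): a symmetric alpha gives a uniform expectation, e.g.
# >>> dirichlet_expectation(np.ones(4))   # psi(1) - psi(4) in every entry
# array([-1.83333333, -1.83333333, -1.83333333, -1.83333333])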
211 |
212 | def read_tsv_file(file_path):
213 | df = pd.read_csv(file_path,sep='\t')
214 | seq = df['sentence']
215 | return seq
216 |
217 |
218 |
219 | class DynamicCrop(object):
220 | def __init__(self, is_train=True):
221 | self.is_train = is_train
222 |
223 | def __call__(self, img):
224 | w, h = img.size
225 | crop_size = min(w, h)
226 |
227 | left_margin = (w - crop_size) / 2
228 | top_margin = (h - crop_size) / 2
229 |
230 | # Random crop for training
231 | if self.is_train:
232 | left_margin = random.randint(0, w - crop_size)
233 | top_margin = random.randint(0, h - crop_size)
234 |
235 | img = img.crop((left_margin, top_margin, left_margin + crop_size, top_margin + crop_size))
236 | return img
237 |
238 | def build_transform(output_size, is_train=True):
239 | """
240 | Get the appropriate image transformation based on the training/testing phase.
241 |
242 | Parameters:
243 | - output_size (int or tuple): Size for resizing the cropped image.
244 | - is_train (bool): If True, random crop and resize are performed. Otherwise, center crop and resize.
245 |
246 | Returns:
247 | - torchvision.transforms.Compose: A composition of transformations.
248 | """
249 | return transforms.Compose([
250 | DynamicCrop(is_train=is_train),
251 | transforms.Resize(output_size),
252 | transforms.ToTensor()
253 | ])
254 |
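# Usage sketch (file name illustrative): square-crop, then resize to the ViT
# input resolution.
# >>> tf = build_transform(output_size=(224, 224), is_train=True)
# >>> x = tf(Image.open('some_image.jpg'))   # torch.Tensor of shape (C, 224, 224)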
255 |
256 | def build_transform_prev(crop_size, output_size, is_train=True):
257 | """
258 | Get the appropriate image transformation based on the training/testing phase.
259 |
260 | Parameters:
261 | - crop_size (int or tuple): Size for cropping. If int, a square crop is made.
262 | - output_size (int or tuple): Size for resizing the cropped image.
263 | - train (bool): If True, random crop and resize are performed. Otherwise, center crop and resize.
264 |
265 | Returns:
266 | - torchvision.transforms.Compose: A composition of transformations.
267 | """
268 | if is_train:
269 | return transforms.Compose([
270 | transforms.RandomCrop(crop_size),
271 | transforms.Resize(output_size),
272 | transforms.ToTensor()
273 | ])
274 | else:
275 | return transforms.Compose([
276 | transforms.CenterCrop(crop_size),
277 | transforms.Resize(output_size),
278 | transforms.ToTensor()
279 | ])
280 |
281 |
282 | class Cub2011(Dataset):
283 | base_folder = 'CUB_200_2011/images'
284 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
285 | filename = 'CUB_200_2011.tgz'
286 | tgz_md5 = '97eceeb196236b17998738112f37df78'
287 |
288 | def __init__(self, root, train=True, transform=None, loader=default_loader, download=True):
289 | self.root = os.path.expanduser(root)
290 | self.transform = transform
291 | self.loader = default_loader
292 | self.train = train
293 |
294 | if download:
295 | self._download()
296 |
297 | if not self._check_integrity():
298 | raise RuntimeError('Dataset not found or corrupted.' +
299 | ' You can use download=True to download it')
300 |
301 | def _load_metadata(self):
302 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
303 | names=['img_id', 'filepath'])
304 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
305 | sep=' ', names=['img_id', 'target'])
306 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
307 | sep=' ', names=['img_id', 'is_training_img'])
308 |
309 | data = images.merge(image_class_labels, on='img_id')
310 | self.data = data.merge(train_test_split, on='img_id')
311 |
312 | if self.train:
313 | self.data = self.data[self.data.is_training_img == 1]
314 | else:
315 | self.data = self.data[self.data.is_training_img == 0]
316 |
317 | def _check_integrity(self):
318 | try:
319 | self._load_metadata()
320 | except Exception:
321 | return False
322 |
323 | for index, row in self.data.iterrows():
324 | filepath = os.path.join(self.root, self.base_folder, row.filepath)
325 | if not os.path.isfile(filepath):
326 | print(filepath)
327 | return False
328 | return True
329 |
330 | def _download(self):
331 | import tarfile
332 |
333 | if self._check_integrity():
334 | print('Files already downloaded and verified')
335 | return
336 |
337 | download_url(self.url, self.root, self.filename, self.tgz_md5)
338 |
339 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
340 | tar.extractall(path=self.root)
341 |
342 | def __len__(self):
343 | return len(self.data)
344 |
345 | def __getitem__(self, idx):
346 | sample = self.data.iloc[idx]
347 | path = os.path.join(self.root, self.base_folder, sample.filepath)
348 | #print('sample', sample)
349 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0
350 | img = self.loader(path)
351 | #print(img)
352 | #print(target)
353 | if self.transform is not None:
354 | img = self.transform(img)['pixel_values'][0]
355 | return {'encodings':img, 'labels':target, 'path': sample.filepath}
356 |
357 | class MyImageDataset(Dataset):
358 | """Dataset class for Image"""
359 | def __init__(self, dataset, labels, transform=None, normalize=None):
360 | super(MyImageDataset, self).__init__()
361 | assert(len(dataset) == len(labels))
362 | self.dataset = dataset
363 | self.labels = labels
364 | self.transform = transform
365 | self.normalize = normalize
366 |
367 | #print(self.labels)
368 |
369 | def __len__(self):
370 | return len(self.dataset)
371 |
372 | def __getitem__(self, idx):
373 | data = self.dataset[idx]
374 |
375 | if self.transform:
376 | data = self.transform(data)
377 | img_to_tensor = transforms.ToTensor()
378 | # if data is not tensor
379 | if not isinstance(data, torch.Tensor):
380 | data = img_to_tensor(data)
381 | if self.normalize:
382 | data = self.normalize(data)
383 |
384 | return {'encodings':data, 'labels':self.labels[idx]}
385 |
386 |
387 | class ImageNetDataset(Dataset):
388 | """Dataset class for ImageNet"""
389 | def __init__(self, dataset, labels, transform=None, normalize=None):
390 | super(ImageNetDataset, self).__init__()
391 | assert(len(dataset) == len(labels))
392 | self.dataset = dataset
393 | self.labels = labels
394 | self.transform = transform
395 | self.normalize = normalize
396 | #print(self.labels)
397 |
398 | def __len__(self):
399 | return len(self.dataset)
400 |
401 | def __getitem__(self, idx):
402 | data = self.dataset[idx]
403 |
404 | if self.transform:
405 | data = self.transform(data)
406 | img_to_tensor = transforms.ToTensor()
407 | data = img_to_tensor(data)
408 |
409 | data = div(data, 255)
410 |         # grayscale (1,224,224) samples occasionally appear; repeat to 3 channels
411 | if data.size()[0] == 1:
412 | data = data.repeat(3,1,1)
413 | if self.normalize:
414 | data = self.normalize(data)
415 |
416 | return {'encodings':data, 'labels':self.labels[idx]}
417 |
418 | class MyImageDatasetFromStanfordCars(Dataset):
419 | def __init__(self, stanford_cars_dataset, transform=None, normalize=None):
420 | super(MyImageDatasetFromStanfordCars, self).__init__()
421 | self.stanford_cars_dataset = stanford_cars_dataset
422 | self.transform = transform
423 | self.normalize = normalize
424 |
425 | def __len__(self):
426 | return len(self.stanford_cars_dataset)
427 |
428 | def __getitem__(self, idx):
429 | data, label = self.stanford_cars_dataset[idx]
430 |
431 | if self.transform:
432 | data = self.transform(data)
433 | img_to_tensor = transforms.ToTensor()
434 | # if data is not tensor
435 | if not isinstance(data, torch.Tensor):
436 | data = img_to_tensor(data)
437 | if self.normalize:
438 | data = self.normalize(data)
439 |
440 | return {'encodings': data, 'labels': label}
441 |
442 |
443 | def load_train_data(data_dir, dataset, img_size, magnitude, batch_size):
444 | with open(data_dir, 'rb') as f:
445 | ds = pickle.load(f)
446 | train_data = ds['image']
447 | train_labels = ds['label']
448 | transform = transforms.Compose([
449 | transforms.Resize(img_size, interpolation=InterpolationMode.BICUBIC),
450 | transforms.RandAugment(num_ops=2,magnitude=magnitude),
451 | ])
452 | train_dataset = dataset(train_data, train_labels, transform,
453 | normalize=transforms.Compose([
454 | transforms.Normalize(
455 | mean=(0.485, 0.456, 0.406),
456 | std=(0.229, 0.224, 0.225)
457 | )
458 | ]),
459 | )
460 | train_loader = DataLoader(
461 | train_dataset,
462 | shuffle=True,
463 | batch_size=batch_size,
464 | num_workers=8,
465 | pin_memory=True,
466 | drop_last=True,
467 | )
468 | f.close()
469 | return train_dataset, train_loader
470 |
471 | def load_val_data(data_dir, dataset, img_size, batch_size):
472 | with open(data_dir, 'rb') as f:
473 | ds = pickle.load(f)
474 | val_data = ds['image']
475 | val_labels = ds['label']
476 | transform = transforms.Compose([
477 | transforms.Resize(img_size, interpolation=InterpolationMode.BICUBIC),
478 | ])
479 | val_dataset = dataset(val_data, val_labels, transform,
480 | normalize=transforms.Compose([
481 | transforms.Normalize(
482 | mean=(0.485, 0.456, 0.406),
483 | std=(0.229, 0.224, 0.225)
484 | ),
485 | ])
486 | )
487 | val_loader = DataLoader(
488 | val_dataset,
489 | batch_size=batch_size,
490 | shuffle=True,
491 | num_workers=4,
492 | pin_memory=True
493 | )
494 | f.close()
495 | return val_dataset, val_loader
496 |
497 | class Adam:
498 | def __init__(self, alpha=1e-4, beta1=0.9, beta2=0.9):
499 | self.alpha = alpha
500 | self.beta1 = beta1
501 | self.beta2 = beta2
502 | self.m = 0
503 | self.v = 0
504 | self.t = 0
505 | self.eps = 1e-5
506 | def update(self, g):
507 | self.t += 1
508 | self.m = self.beta1 * self.m + (1-self.beta1) * g
509 | self.v = self.beta2 * self.v + (1-self.beta2) * g**2
510 |         m_hat = self.m / (1 - self.beta1**self.t)  # bias correction; keep the running moments intact
511 |         v_hat = self.v / (1 - self.beta2**self.t)
512 |         return self.alpha * m_hat / (np.sqrt(v_hat + self.eps) + self.eps)
513 |
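# Usage sketch: update() returns the scaled step itself rather than new
# parameter values, matching how PACE.do_cl_e_step applies it:
# >>> opt = Adam(alpha=1e-4, beta1=0.9, beta2=0.9)
# >>> phi = phi + opt.update(grad)   # grad: np.ndarray of the same shape as phi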
514 | def posterior_mu(mu, mu0, n, n0=1):
515 | res = (n0*mu0+n*mu)/(n+n0)
516 | return res
517 |
518 |
519 | # initializing should adapt to the data (mean, variance, norm...)
520 | def posterior_mu_sigma(mu, sigma, n, mu0, n0=500, lamda0=None):
521 | #lamda = inv(np.matmul(sigma, sigma.T))
522 | #print('size', sigma.size)
523 | # fix mu, update sigma
524 | if lamda0 is None:
525 | lamda0 = np.eye(sigma.shape[-1])
526 |
527 | post_mu = (n0*mu0+n*mu)/(n+n0)
528 | #post_sigma = (n0* lamda0 + n* sigma) / (n0+n)
529 | #print('prod', np.matmul((mu-mu0).T,mu-mu0))
530 | lamda = (n0* lamda0 + n* inv(sigma)) / (n+n0) #+ n*n0/(n0+n) * np.matmul((mu-mu0).T,mu-mu0)
531 | post_sigma = (n0*lamda0 + n*inv(lamda) )/ (n+n0)
532 | #print('post_sigma', det(post_sigma))
533 | return post_mu, post_sigma
534 |
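# Interpretation (sketch): treating (mu0, lamda0) as a conjugate Gaussian
# prior worth n0 pseudo-observations and (mu, sigma) as sample moments from n
# observations, the posterior mean is the weighted average
#     post_mu = (n0 * mu0 + n * mu) / (n0 + n),
# and the covariance is shrunk toward lamda0 the same way, which keeps early,
# noisy concept estimates well-conditioned.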
535 |
536 | def get_tf_idf_mask(corpus_words):
537 | #corpus_words = [[w for w in seq if w is not 0] for seq in corpus_words]
538 | corpus_bow = [list(Counter(seq).items()) for seq in corpus_words]
539 |
540 | tfidf = TfidfModel(corpus_bow, normalize=False)
541 | #filter low value words
542 | low_value = 4
543 | filtered_ids = []
544 | tfidf_mask = []
545 | for i in range(0, len(corpus_bow)):
546 | bow = corpus_bow[i]
547 | low_value_words = [] #reinitialize to be safe. You can skip this.
548 | low_value_words = [id for id, value in tfidf[bow] if value < low_value]
549 | new_mask = []
550 | for idx, w in enumerate(corpus_words[i]):
551 | if w in low_value_words or w==0 or idx==0: # consider padding, 'CLS' token
552 | new_mask.append(0)
553 | else:
554 | new_mask.append(1)
555 | #reassign
556 | tfidf_mask.append(new_mask)
557 | #print(new_mask)
558 | return tfidf_mask
559 |
560 |
561 | def topic_vis(mus, coherence, topic):
562 | pca = PCA(n_components=2)
563 | pc = pca.fit_transform(mus)
564 | ax = plt.subplot()
565 | for i in topic:
566 | coh = coherence[i]
567 | coh = round(coh, 1)
568 | ax.annotate( str(i), (pc[i,0], pc[i,1]), size=10) # + '-' + str(coh)
569 |
570 | ax.scatter(pc[i,0], pc[i, 1])
571 | plt.savefig('topic_scatter_vis.jpg')
572 | return ax
573 |
574 | def vis(x, name, topic, patch=None, pos=None, full_img=None, sigmas=None, div=None, coh=None, top_topics=None, indiv_coh=None, attentions=None):
575 |     '''
576 |     x: patch features, shape (num_patches, feature_dim)
577 |     name: per-patch index used to pick the annotation color
578 |     topic: concept assignment of each patch/word
579 |     patch: image patches as a numpy array (optional)
580 |     '''
581 | #if pos is not None:
582 | # print(pos)
583 | #print('topic', topic)
584 | # Get the Matplotlib logger
585 | mpl_logger = logging.getLogger('matplotlib')
586 |
587 | # Set the level to WARNING or higher
588 | mpl_logger.setLevel(logging.WARNING)
589 |
590 | mean = x.mean(axis=0)
591 | var = x.var(axis=0)
592 | #print(x.shape)
593 | #print(mean.shape)
594 | #print(var.shape)
595 | '''
596 | outer_indices = [i for i in range(x.shape[0]) if (abs(x[i]-mean) > 3*np.sqrt(var)).any()]
597 | print('outer', outer_indices)
598 | x = np.delete(x, outer_indices, axis=0)
599 | name = np.delete(name, outer_indices, axis=0)
600 | topic = np.delete(topic, outer_indices, axis=0)
601 | if patch is not None:
602 | patch = np.delete(patch, outer_indices, axis=0)
603 | if pos is not None:
604 | pos = np.delete(pos, outer_indices, axis=0)
605 | if full_img is not None:
606 | full_img = np.delete(full_img, outer_indices, axis=0)
607 | #if sigmas is not None:
608 | # sigmas = np.delete(sigmas, outer_indices, axis=0)
609 | '''
610 | pca = PCA(n_components=2)
611 | pc = pca.fit_transform(x)
612 | colors = [ 'blue', 'red','purple', 'green','lime', 'cyan', 'orange', 'pink','black', 'grey'] * 20
613 | # for i, txt in enumerate(name):
614 | # ax.annotate(txt, (pc[i,0], pc[i,1]),
615 | # color=colors[topic[i]], size=10)
616 | # ax.scatter(pc[:,0], pc[:, 1])
617 | '''
618 | fig = plt.figure(figsize=(10,10))
619 | axs = fig.add_subplot(1,1,1)
620 | axs.set_xlabel('pc 1', fontsize=15)
621 | axs.set_ylabel('pc 2', fontsize=15)
622 | axs.set_title('2D PCA', fontsize=20)
623 | print(topic)
624 | for i, txt in enumerate(topic):
625 | axs.annotate(txt, (pc[i,0], pc[i,1]),
626 | size=10,color=colors[name[i]])
627 | axs.scatter(pc[:,0], pc[:, 1])
628 | plt.savefig('topic_vis.jpg')
629 | plt.savefig('topic_vis.pdf')
630 |
631 | scale = np.linalg.norm(pc) / np.linalg.norm(x)
632 | '''
633 | #for i in range(args.K):
634 | # print('concept ', i, sigmas[i] * scale)
635 |
636 | # f, axarr = plt.subplots(2,2)
637 | # axarr[0,0].imshow(image_datas[0])
638 | # Title for the whole figure
639 |
640 |
641 | grid_size = 20
642 | scale = np.sqrt(var).mean()
643 | # Define center point and range
644 | #x0, y0 = 5, 5 # replace with your desired center point
645 | pc = pc * 20
646 | tt = 24
647 | x0, y0 = int(pc[tt,0]), int(pc[tt,1])
648 | #x0/=100
649 | #y0/=100
650 | x0, y0 = 0,0
651 | print('center', x0, y0)
652 | print('pc x', pc[100:150,0])
653 | print('pc y', pc[100:150,1])
654 | # norm pc[0,:], pc[1,:] to be in range [-10,10]
655 | # pc[0,:] = pc[0,:] - pc[0,:].mean()
656 | # pc[1,:] = pc[1,:] - pc[1,:].mean()
657 | # pc[0,:] = pc[0,:] / pc[0,:].max() * 10
658 | # pc[1,:] = pc[1,:] / pc[1,:].max() * 10
659 |
660 | r = 10 # replace with your desired range
661 |
662 | gflag = np.zeros((grid_size,grid_size))
663 | max_att = np.zeros((grid_size,grid_size))
664 | f, arr = plt.subplots(grid_size,grid_size,figsize=(grid_size,grid_size))
665 |
666 |
667 | f.suptitle(f'div {div}, coh {coh}', fontsize=20) # Add your title here
668 | #print(pc)
669 | for i in range(grid_size):
670 | for j in range(grid_size):
671 | arr[i,j].axis('off')
672 | #print('patch', patch)
673 | #print nearest distance from i (i<100) to j (j>=100)
674 |
675 |
676 | for i in range(100):
677 | min_dist = 100000
678 | min_coord = (0,0)
679 | print('orig', pc[i,0], pc[i,1])
680 | for j in range(100,pc.shape[0]):
681 | dist = np.linalg.norm(pc[i]-pc[j])
682 | if dist < min_dist:
683 | min_dist = dist
684 | min_coord = (pc[j,0],pc[j,1])
685 | print('i=',i,', min_dist', min_dist, min_coord)
686 |
687 | patch = np.array(patch)
688 | mean_coord = patch.mean(axis=0)
689 | for i, _ in enumerate(patch):
690 | #if topic[i] not in top_topics:
691 | # continue
692 | ax = int(pc[i,0]) - int(x0) + grid_size//2
693 | ay = int(pc[i,1]) - int(y0) + grid_size//2
694 | #print(ax, ay, grid_size)
695 | #if topic[i] == tt:
696 | # print('topic ',tt, ax, ay)
697 | if abs(ax+1/2-grid_size/2)>=grid_size/2 or abs(ay+1/2-grid_size/2)>=grid_size/2:
698 | continue
699 | if int(gflag[ax,ay]):
700 | '''
701 | if attentions is None:
702 | continue
703 | elif attentions[i] < max_att[ax,ay]:
704 | continue
705 | else:
706 | max_att[ax,ay] = attentions[i]
707 | '''
708 | if np.linalg.norm(patch[i]-mean_coord) < max_att[ax,ay]:
709 | continue
710 | max_att[ax,ay] = np.linalg.norm(patch[i]-mean_coord)
711 | gflag[ax,ay] = 1
712 |
713 |
714 | img = arr[ax,ay].imshow(patch[i], interpolation='nearest')
715 | img.set_cmap('hot')
716 | plt.axis('off')
717 |
718 |
719 | # Removes axis numbers
720 |
721 | #arr[ax,ay].text(-0.1, 1.1, pos[i], transform=arr[ax,ay].transAxes,
722 | # size=10, weight='bold',color=colors[name[i]])
723 | arr[ax,ay].text(-0.1, 1.1, topic[i], transform=arr[ax,ay].transAxes,
724 | size=20, weight='bold',color=colors[name[i]])
725 |
726 | # we want to center at (5, 5). So, we set the limits to go from 0 to 10.
727 | #arr.set_xlim(0, 10)
728 | #arr.set_ylim(0, 10)
729 |
730 | # Setting x-ticks and y-ticks to clearly see the center of image at (5,5)
731 | #arr.set_xticks(np.arange(0,11,1))
732 | #arr.set_yticks(np.arange(0,11,1))
733 | plt.savefig('dataset_vis.jpg')
734 | plt.savefig('dataset_vis.pdf')
735 | for ii in top_topics:
736 | print(indiv_coh[ii])
737 |
738 |
739 | grid_size = 20
740 | gflag = np.zeros((grid_size,grid_size))
741 | max_att = np.zeros((grid_size,grid_size))
742 | f, arr = plt.subplots(grid_size,grid_size,figsize=(20,20))
743 | #print(pc)
744 | for i in range(grid_size):
745 | for j in range(grid_size):
746 | arr[i,j].axis('off')
747 | for i, _ in enumerate(full_img):
748 | #if topic[i] not in top_topics:
749 | # continue
750 | ax = int(pc[i,0]) - int(x0) + grid_size//2
751 | ay = int(pc[i,1]) - int(y0) + grid_size//2
752 | #arr[ax,ay].axis('off')
753 |
754 | if abs(ax+1/2-grid_size/2)>=grid_size/2 or abs(ay+1/2-grid_size/2)>=grid_size/2:
755 | continue
756 | if int(gflag[ax,ay]):
757 | if np.linalg.norm(patch[i]-mean_coord) < max_att[ax,ay]:
758 | continue
759 | max_att[ax,ay] = np.linalg.norm(patch[i]-mean_coord)
760 | gflag[ax,ay] = 1
761 | img = arr[ax,ay].imshow(full_img[i], interpolation='nearest')
762 | img.set_cmap('hot')
763 | plt.axis('off')
764 | # Removes axis numbers
765 |
766 | arr[ax,ay].text(-0.1, 1.1, pos[i], transform=arr[ax,ay].transAxes,
767 | size=10, weight='bold',color=colors[name[i]])
768 | #arr[ax,ay].text(-0.1, 1.1, topic[i], transform=arr[ax,ay].transAxes,
769 | # size=20, weight='bold',color=colors[name[i]])
770 | plt.savefig('dataset_image_vis.jpg')
771 | plt.savefig('dataset_image_vis.pdf')
772 |
773 |
774 | grid_size = 20
775 | gflag = np.zeros((grid_size,grid_size))
776 | max_att = np.zeros((grid_size,grid_size))
777 | f, arr = plt.subplots(grid_size,grid_size,figsize=(20,20))
778 | #print(pc)
779 | for i in range(grid_size):
780 | for j in range(grid_size):
781 | arr[i,j].axis('off')
782 | for i, _ in enumerate(full_img):
783 | #if topic[i] not in top_topics:
784 | # continue
785 | if pos[i] == (-1,-1):
786 | continue
787 | ax = int(pc[i,0]) - int(x0) + grid_size//2
788 | ay = int(pc[i,1]) - int(y0) + grid_size//2
789 | #arr[ax,ay].axis('off')
790 | # set the pos in image i to be stripes of black and white
791 | for ii in range(pos[i][0]*16, pos[i][0]*16+16):
792 | for jj in range(pos[i][1]*16, pos[i][1]*16+16):
793 | if (ii+jj)%2 == 0:
794 | full_img[i][ii,jj] = 0
795 | else:
796 | full_img[i][ii,jj] = 1
797 |
798 | if abs(ax+1/2-grid_size/2)>=grid_size/2 or abs(ay+1/2-grid_size/2)>=grid_size/2:
799 | continue
800 | if int(gflag[ax,ay]):
801 | if np.linalg.norm(patch[i]-mean_coord) < max_att[ax,ay]:
802 | continue
803 | max_att[ax,ay] = np.linalg.norm(patch[i]-mean_coord)
804 | gflag[ax,ay] = 1
805 | img = arr[ax,ay].imshow(full_img[i], interpolation='nearest')
806 | img.set_cmap('hot')
807 | plt.axis('off')
808 | # Removes axis numbers
809 |
810 | arr[ax,ay].text(-0.1, 1.1, pos[i], transform=arr[ax,ay].transAxes,
811 | size=10, weight='bold',color=colors[name[i]])
812 | #arr[ax,ay].text(-0.1, 1.1, topic[i], transform=arr[ax,ay].transAxes,
813 | # size=20, weight='bold',color=colors[name[i]])
814 | plt.savefig('dataset_image_masked_vis.jpg')
815 | plt.savefig('dataset_image_masked_vis.pdf')
816 |
817 | def plot_topics(topic_patches, topic_attention, topic_image, topic_masked_image, topic_labels):
818 | #num_topic, num_patch = len(topic_patches), len(topic_patches[0])
819 | #matched_class = [17, 0, -1, 10,18, 4,8, 12,6, 7]
820 | order = [0,6,2,8,1,5,9,3,4,7]
821 | oo = [0,1,2,3,4,5,6,7,8,9]
822 | mp = dict(zip(order,oo))
823 | matched_class = [0,0,0,1,0,1,0,0,0,1]
824 | num_topic, num_patch = 10, 40
825 | width_per_subplot = 0.5
826 | height_per_subplot = 0.5
827 | custom_figsize=(num_patch * width_per_subplot, num_topic * height_per_subplot)
828 | f, arr = plt.subplots(num_topic, num_patch, figsize=custom_figsize)
829 |
830 | for i in range(num_topic):
831 | if matched_class[i] == -1:
832 | continue
833 | selected_idx = np.where(topic_labels[i] == matched_class[i])[0]
834 | while len(selected_idx) < num_patch:
835 | ll = selected_idx[-1].item()
836 | if ll == 199:
837 | ll -=50
838 | selected_idx = np.concatenate((selected_idx, np.array([ll+1])))
839 | #import pdb; pdb.set_trace()
840 | topic_patches[i][:len(selected_idx)] = topic_patches[i][selected_idx]
841 | #topic_attention[i] = topic_attention[i][selected_idx]
842 | topic_image[i] = [topic_image[i][j] for j in selected_idx]
843 | topic_masked_image[i] = [topic_masked_image[i][j] for j in selected_idx]
844 | topic_labels[i][:len(selected_idx)] = topic_labels[i][selected_idx]
845 |
846 |
847 | #f, arr = plt.subplots(num_topic,num_patch,figsize=(num_topic,num_patch))
848 | for i in range(num_topic):
849 | for j in range(num_patch):
850 | arr[i,j].axis('off')
851 | if j >= num_patch:
852 | continue
853 | #import pdb
854 | #pdb.set_trace()
855 | topic_attention[i][j] = int(1000*topic_attention[i][j])/1000
856 | #arr[i,j].text(-0.1, 1.1, topic_attention[i][j], transform=arr[i,j].transAxes,
857 | #size=10, weight='bold')
858 | img = arr[mp[i],j].imshow(topic_patches[i][j], interpolation='nearest')
859 | img.set_cmap('hot')
860 | #f.subplots_adjust(wspace=0.1, hspace=0) # Adjust the spacing between subplots here
861 | plt.savefig('topic_patch_vis.jpg')
862 | plt.savefig('topic_patch_vis.pdf')
863 |
864 | f, arr = plt.subplots(num_topic,num_patch,figsize=custom_figsize) #(num_topic,num_patch)
865 | for i in range(num_topic):
866 | for j in range(num_patch):
867 | arr[i,j].axis('off')
868 | if j >= num_patch:
869 | continue
870 |
871 | topic_attention[i][j] = int(1000*topic_attention[i][j])/1000
872 | #arr[i,j].text(-0.1, 1.1, topic_attention[i][j], transform=arr[i,j].transAxes,
873 | #size=10, weight='bold')
874 | img = arr[mp[i],j].imshow(topic_image[i][j], interpolation='nearest')
875 | img.set_cmap('hot')
876 |
877 | plt.savefig('topic_image_vis.jpg')
878 | plt.savefig('topic_image_vis.pdf')
879 |
880 |
881 | f, arr = plt.subplots(num_topic,num_patch,figsize=custom_figsize)
882 | for i in range(num_topic):
883 | for j in range(num_patch):
884 | arr[i,j].axis('off')
885 | if j >= num_patch:
886 | continue
887 |
888 | topic_attention[i][j] = int(1000*topic_attention[i][j])/1000
889 | arr[i,j].text(-0.1, 1.1, topic_labels[i][j], transform=arr[i,j].transAxes,
890 | size=5, weight='bold')
891 | img = arr[mp[i],j].imshow(topic_masked_image[i][j], interpolation='nearest')
892 | img.set_cmap('hot')
893 |
894 | plt.savefig('topic_masked_image_vis.jpg')
895 | plt.savefig('topic_masked_image_vis.pdf')
896 | return f, arr
897 |
898 |
899 | def run_kmeans(X, K):
900 |     # Run k-means for a few iterations and return the cluster means,
901 |     # which initialize the concept (topic) centers.
902 |     kmeans = KMeans(n_clusters=K, random_state=0, n_init=K, max_iter=10)
903 |     ret = kmeans.fit(X)
904 |     mean = ret.cluster_centers_
905 |     print(mean.mean(0))
906 |     print(X.mean(axis=0))
907 |     return mean
910 |
911 |
912 |
913 | def kmeans_init_prev(X,K, init_idx=None):
914 | centers = []
915 | index = []
916 | avg = 0
917 | if init_idx is None:
918 | init_idx = np.random.randint(0,len(X))
919 | avg = X[init_idx]
920 | centers.append(avg)
921 |
922 | else:
923 | for idx, j in enumerate(init_idx):
924 | avg = (avg*idx + X[j])/(idx+1)
925 | centers.append(X[j])
926 | index.append(j)
927 | #init_len = len(centers)
928 | for idx in range(K):
929 | if len(centers) == K:
930 | break
931 | dist = 0
932 | new_center = None
933 | new_index = None
934 | for j in range(len(X)):
935 | if np.linalg.norm(X[j]-avg) > dist:
936 | dist = np.linalg.norm(X[j]-avg)
937 | new_center = X[j]
938 | new_index = j
939 | print(idx, dist)
940 | avg = (avg*len(centers) + new_center)/(len(centers)+1)
941 | centers.append(new_center)
942 | index.append(new_index)
943 |
944 |
945 |
946 | return centers, index
947 |
948 | def kmeans_init(X, K, init_idx=None):
949 | centers = []
950 | index = []
951 |
952 | # If initial indices are not provided, select one data point randomly
953 | if init_idx is None:
954 | init_idx = [np.random.randint(0, len(X))]
955 |
956 | # Initialize the centers and indices using provided or random indices
957 | for idx in init_idx:
958 | centers.append(X[idx])
959 | index.append(idx)
960 |
961 | # Fill in the remaining centers
962 | while len(centers) < K:
963 | avg = np.mean(centers, axis=0) # Calculate the average of current centers
964 | dist = 0
965 | new_center = None
966 | new_index = None
967 |
968 | for j, x in enumerate(X):
969 | # Check if point is not already a center and is farther from the average
970 | if j not in index and np.linalg.norm(x - avg) > dist:
971 | dist = np.linalg.norm(x - avg)
972 | new_center = x
973 | new_index = j
974 |
975 | if new_center is not None:
976 | centers.append(new_center)
977 | index.append(new_index)
978 | print(len(centers), dist)
979 | return centers, index
980 |
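# kmeans_init is a farthest-point heuristic: each new center is the point
# farthest from the running mean of the centers picked so far (a cheaper
# cousin of k-means++, which samples proportionally to squared distance).
# >>> centers, idx = kmeans_init(X, K=10)   # X: (n_samples, d) np.ndarray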
981 | def softmax(x): # 2D
982 | """Compute softmax values for each sets of scores in x."""
983 | # x (B,d)
984 | e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
985 | return e_x / e_x.sum(axis=-1, keepdims=True)
986 |
987 |
988 | def ensure_rgb_image(image):
989 | """Ensure image is in RGB format"""
990 | if hasattr(image, 'mode'):
991 | if image.mode != 'RGB':
992 | image = image.convert('RGB')
993 | return image
994 |
995 |
996 | def load_dataset_by_task(task, data_path):
997 | """
998 | Load dataset based on task name
999 | Args:
1000 | task: task name ('flower102', 'cub2011', 'cars', 'Color')
1001 | data_path: path to dataset
1002 | Returns:
1003 | train_dataset, test_dataset, out_dim
1004 | """
1005 | img_size = (224, 224)
1006 |
1007 |     # The three fine-grained benchmarks share identical loading logic and
1008 |     # differ only in the HuggingFace dataset name and the number of classes.
1009 |     hf_tasks = {
1010 |         'flower102': ("nelorth/oxford-flowers", 102),
1011 |         'cub2011': ("Donghyun99/CUB-200-2011", 200),
1012 |         'cars': ("tanganke/stanford_cars", 196),
1013 |     }
1014 | 
1015 |     if task in hf_tasks:
1016 |         dataset_name, out_dim = hf_tasks[task]
1017 |         dataset = load_dataset(dataset_name)
1018 | 
1019 |         # ensure every image is in RGB format
1020 |         train_images = [ensure_rgb_image(img) for img in dataset['train']['image']]
1021 |         test_images = [ensure_rgb_image(img) for img in dataset['test']['image']]
1022 | 
1023 |         # use the same processor for every task to ensure consistency
1024 |         processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
1025 |         train_inputs = processor(train_images, return_tensors="pt")
1026 |         test_inputs = processor(test_images, return_tensors="pt")
1027 |         train_dataset = MyImageDataset(train_inputs['pixel_values'], dataset['train']['label'])
1028 |         test_dataset = MyImageDataset(test_inputs['pixel_values'], dataset['test']['label'])
1053 |
1054 | elif task == 'Color':
1055 | # Define a transformation to convert the images to PyTorch tensors
1056 | transform = transforms.Compose([
1057 | transforms.ToTensor(),
1058 | transforms.Lambda(lambda x: x[:3, ...]),
1059 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize to range [-1,1]
1060 | ])
1061 |
1062 | # Load the images and labels
1063 | dataset = []
1064 | labels = []
1065 | for class_dir in [os.path.join(data_path, 'Color/class0'),
1066 | os.path.join(data_path, 'Color/class1')]:
1067 | if os.path.exists(class_dir):
1068 | for image_name in os.listdir(class_dir):
1069 | # Read image
1070 | image = Image.open(os.path.join(class_dir, image_name))
1071 | # Add to the lists
1072 | dataset.append(image)
1073 | labels.append(int(class_dir[-1])) # class ID from the directory name
1074 |
1075 | # Convert lists to tensors
1076 | labels = torch.tensor(labels)
1077 |
1078 | # Split into train and test sets
1079 | # Pair up the data and labels
1080 | paired_data = list(zip(dataset, labels))
1081 |
1082 | # Perform the split on the paired data
1083 | train_size = int(0.8 * len(paired_data)) # 80% for training
1084 | test_size = len(paired_data) - train_size
1085 | train_data, test_data = random_split(paired_data, [train_size, test_size])
1086 |
1087 | train_images, train_labels = zip(*train_data)
1088 | test_images, test_labels = zip(*test_data)
1089 |
1090 | # Convert the zipped data back to lists or tensors as needed
1091 | train_images = list(train_images)
1092 | train_labels = list(train_labels)
1093 | test_images = list(test_images)
1094 | test_labels = list(test_labels)
1095 |
1096 | # Create MyImageDataset instances
1097 | train_dataset = MyImageDataset(train_images, train_labels, transform=transform)
1098 | test_dataset = MyImageDataset(test_images, test_labels, transform=transform)
1099 | out_dim = 2
1100 |
1101 | else:
1102 | raise ValueError(f"Unsupported task: {task}")
1103 |
1104 | return train_dataset, test_dataset, out_dim
1105 |
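# Usage sketch (task and path illustrative; data_path is only consulted for
# the synthetic 'Color' task):
# >>> train_ds, test_ds, out_dim = load_dataset_by_task('cub2011', '../dataset')
# >>> loader = DataLoader(train_ds, batch_size=32, shuffle=True)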
1106 |
1107 | '''
1108 | def row_norms(X, squared=False):
1109 | """Row-wise (squared) Euclidean norm of X.
1110 | Equivalent to np.sqrt((X * X).sum(axis=1)), but also supports sparse
1111 | matrices and does not create an X.shape-sized temporary.
1112 | Performs no input validation.
1113 | Parameters
1114 | ----------
1115 | X : array-like
1116 | The input array.
1117 | squared : bool, default=False
1118 | If True, return squared norms.
1119 | Returns
1120 | -------
1121 | array-like
1122 | The row-wise (squared) Euclidean norm of X.
1123 | """
1124 | if sparse.issparse(X):
1125 | if not isinstance(X, sparse.csr_matrix):
1126 | X = sparse.csr_matrix(X)
1127 | norms = csr_row_norms(X)
1128 | else:
1129 | norms = np.einsum("ij,ij->i", X, X)
1130 |
1131 | if not squared:
1132 | np.sqrt(norms, norms)
1133 | return norms
1134 |
1135 | def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared=False):
1136 | """Computational part of euclidean_distances
1137 | Assumes inputs are already checked.
1138 | If norms are passed as float32, they are unused. If arrays are passed as
1139 | float32, norms needs to be recomputed on upcast chunks.
1140 | TODO: use a float64 accumulator in row_norms to avoid the latter.
1141 | """
1142 | if X_norm_squared is not None:
1143 | if X_norm_squared.dtype == np.float32:
1144 | XX = None
1145 | else:
1146 | XX = X_norm_squared.reshape(-1, 1)
1147 | elif X.dtype == np.float32:
1148 | XX = None
1149 | else:
1150 | XX = row_norms(X, squared=True)[:, np.newaxis]
1151 |
1152 | if Y is X:
1153 | YY = None if XX is None else XX.T
1154 | else:
1155 | if Y_norm_squared is not None:
1156 | if Y_norm_squared.dtype == np.float32:
1157 | YY = None
1158 | else:
1159 | YY = Y_norm_squared.reshape(1, -1)
1160 | elif Y.dtype == np.float32:
1161 | YY = None
1162 | else:
1163 | YY = row_norms(Y, squared=True)[np.newaxis, :]
1164 |
1165 | if X.dtype == np.float32:
1166 | # To minimize precision issues with float32, we compute the distance
1167 | # matrix on chunks of X and Y upcast to float64
1168 | distances = _euclidean_distances_upcast(X, XX, Y, YY)
1169 | else:
1170 | # if dtype is already float64, no need to chunk and upcast
1171 | distances = -2 * safe_sparse_dot(X, Y.T, dense_output=True)
1172 | distances += XX
1173 | distances += YY
1174 | np.maximum(distances, 0, out=distances)
1175 |
1176 | # Ensure that distances between vectors and themselves are set to 0.0.
1177 | # This may not be the case due to floating point rounding errors.
1178 | if X is Y:
1179 | np.fill_diagonal(distances, 0)
1180 |
1181 | return distances if squared else np.sqrt(distances, out=distances)
1182 |
1183 |
1184 | # ref: https://github.com/scikit-learn/scikit-learn/blob/baf0ea25d/sklearn/cluster/_kmeans.py#L154
1185 |
1186 | def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
1187 | """Computational component for initialization of n_clusters by
1188 | k-means++. Prior validation of data is assumed.
1189 | Parameters
1190 | ----------
1191 | X : {ndarray, sparse matrix} of shape (n_samples, n_features)
1192 | The data to pick seeds for.
1193 | n_clusters : int
1194 | The number of seeds to choose.
1195 | x_squared_norms : ndarray of shape (n_samples,)
1196 | Squared Euclidean norm of each data point.
1197 | random_state : RandomState instance
1198 | The generator used to initialize the centers.
1199 | See :term:`Glossary `.
1200 | n_local_trials : int, default=None
1201 | The number of seeding trials for each center (except the first),
1202 | of which the one reducing inertia the most is greedily chosen.
1203 | Set to None to make the number of trials depend logarithmically
1204 | on the number of seeds (2+log(k)); this is the default.
1205 | Returns
1206 | -------
1207 | centers : ndarray of shape (n_clusters, n_features)
1208 | The initial centers for k-means.
1209 | indices : ndarray of shape (n_clusters,)
1210 | The index location of the chosen centers in the data array X. For a
1211 | given index and center, X[index] = center.
1212 | """
1213 | n_samples, n_features = X.shape
1214 |
1215 | centers = np.empty((n_clusters, n_features), dtype=X.dtype)
1216 |
1217 | # Set the number of local seeding trials if none is given
1218 | if n_local_trials is None:
1219 | # This is what Arthur/Vassilvitskii tried, but did not report
1220 | # specific results for other than mentioning in the conclusion
1221 | # that it helped.
1222 | n_local_trials = 2 + int(np.log(n_clusters))
1223 |
1224 | # Pick first center randomly and track index of point
1225 | center_id = random_state.randint(n_samples)
1226 | indices = np.full(n_clusters, -1, dtype=int)
1227 | if sp.issparse(X):
1228 | centers[0] = X[center_id].toarray()
1229 | else:
1230 | centers[0] = X[center_id]
1231 | indices[0] = center_id
1232 |
1233 | # Initialize list of closest distances and calculate current potential
1234 | closest_dist_sq = _euclidean_distances(
1235 | centers[0, np.newaxis], X, Y_norm_squared=x_squared_norms, squared=True
1236 | )
1237 | current_pot = closest_dist_sq.sum()
1238 |
1239 | # Pick the remaining n_clusters-1 points
1240 | for c in range(1, n_clusters):
1241 | # Choose center candidates by sampling with probability proportional
1242 | # to the squared distance to the closest existing center
1243 | rand_vals = random_state.uniform(size=n_local_trials) * current_pot
1244 | candidate_ids = np.searchsorted(stable_cumsum(closest_dist_sq), rand_vals)
1245 | # XXX: numerical imprecision can result in a candidate_id out of range
1246 | np.clip(candidate_ids, None, closest_dist_sq.size - 1, out=candidate_ids)
1247 |
1248 | # Compute distances to center candidates
1249 | distance_to_candidates = _euclidean_distances(
1250 | X[candidate_ids], X, Y_norm_squared=x_squared_norms, squared=True
1251 | )
1252 |
1253 | # update closest distances squared and potential for each candidate
1254 | np.minimum(closest_dist_sq, distance_to_candidates, out=distance_to_candidates)
1255 | candidates_pot = distance_to_candidates.sum(axis=1)
1256 |
1257 | # Decide which candidate is the best
1258 | best_candidate = np.argmin(candidates_pot)
1259 | current_pot = candidates_pot[best_candidate]
1260 | closest_dist_sq = distance_to_candidates[best_candidate]
1261 | best_candidate = candidate_ids[best_candidate]
1262 |
1263 | # Permanently add best center candidate found in local tries
1264 | if sp.issparse(X):
1265 | centers[c] = X[best_candidate].toarray()
1266 | else:
1267 | centers[c] = X[best_candidate]
1268 | indices[c] = best_candidate
1269 |
1270 | return centers, indices
1271 |
1272 |
1273 |
1274 | '''
--------------------------------------------------------------------------------