├── .gitignore
├── image
│   ├── acc_0.png
│   ├── acc_1.png
│   ├── model.png
│   ├── manual_1.png
│   ├── z_pca_0.png
│   ├── z_pca_1.png
│   ├── z_tsne_0.png
│   ├── z_tsne_1.png
│   ├── gause_I0k1.png
│   ├── vae_loss_1.png
│   └── variable_define.png
├── requirements.txt
├── README.md
├── main.py
├── gmm_module.py
├── vae_module.py
└── tool.py

/.gitignore:
--------------------------------------------------------------------------------
1 | pth/
2 | model/
3 | __pycache__/
--------------------------------------------------------------------------------
/image/acc_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/acc_0.png
--------------------------------------------------------------------------------
/image/acc_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/acc_1.png
--------------------------------------------------------------------------------
/image/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/model.png
--------------------------------------------------------------------------------
/image/manual_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/manual_1.png
--------------------------------------------------------------------------------
/image/z_pca_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/z_pca_0.png
--------------------------------------------------------------------------------
/image/z_pca_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/z_pca_1.png
--------------------------------------------------------------------------------
/image/z_tsne_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/z_tsne_0.png
--------------------------------------------------------------------------------
/image/z_tsne_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/z_tsne_1.png
--------------------------------------------------------------------------------
/image/gause_I0k1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/gause_I0k1.png
--------------------------------------------------------------------------------
/image/vae_loss_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/vae_loss_1.png
--------------------------------------------------------------------------------
/image/variable_define.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/variable_define.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # PyTorch VAE-GMM Requirements
2 | # Core deep learning framework
3 | torch>=1.5.1
4 | torchvision>=0.6.1
5 | 
6 | # Numerical computing
7 | numpy>=1.19.0
8 | 
9 | # Statistical computing
10 | scipy>=1.5.0
11 | 
12 | # Machine learning utilities
13 | scikit-learn>=0.23.0
14 | 
15 | # Plotting and visualization
16 | matplotlib>=3.3.0
17 | 
18 | # Optional: for better performance with numerical operations
19 | # Uncomment if needed
20 | # mkl>=2021.1.1
21 | # mkl-service>=2.3.0
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Variational Auto-Encoder (VAE) + Gaussian Mixture Model (GMM)
2 | 
3 | Implementation of a mutual learning model between a VAE and a GMM.
4 | The idea of integrating probabilistic generative models is based on this paper: [Neuro-SERKET: Development of Integrative Cognitive System through the Composition of Deep Probabilistic Generative Models](https://arxiv.org/abs/1910.08918).
5 | Symbol Emergence in Robotics tool KIT (SERKET) is a framework that allows integration and partitioning of probabilistic generative models.
6 | 
7 | This is the graphical model of the VAE+GMM model:
8 | 
9 | 
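Roughly, the joint structure can be sketched as follows (a paraphrase of the figure rather than its exact notation; here o denotes an observed image, x the shared latent variable, and k a cluster index):

```
k ~ Cat(π)                    # GMM mixing weights
x | k ~ N(μ_k, Λ_k^{-1})      # Gaussian component, estimated by the GMM
o | x ~ p_θ(o | x)            # VAE decoder
```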
10 | 11 |
12 | 
13 | The VAE and the GMM share the latent variable x.
14 | x follows a multivariate normal distribution and is estimated by the VAE.
15 | 
16 | Training proceeds in the following sequence.
17 | 
18 | 1. The VAE estimates the latent variables (x) and sends them to the GMM.
19 | 2. The GMM clusters the latent variables (x) sent from the VAE and sends the mean and variance parameters of each Gaussian distribution back to the VAE, where they are used as the VAE's prior.
20 | 3. Return to step 1.
21 | 
22 | What this repo contains:
23 | 
24 | - `main.py`: Main script for training the model.
25 | - `vae_module.py`: The VAE training module, called from main.py.
26 | - `gmm_module.py`: The GMM training module, called from main.py.
27 | - `tool.py`: Utility functions used throughout the program.
28 | 
29 | # How to run
30 | 
31 | Install the required libraries using the following commands.
32 | ※ Install PyTorch first (XXX should match your CUDA version).
33 | ※ My environment: **PyTorch==2.8.0+cu129, CUDA==12.9**
34 | 
35 | ```bash
36 | $ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cuXXX
37 | $ pip install -r requirements.txt
38 | ```
39 | 
40 | You can train the VAE+GMM model by running `main.py`.
41 | 
42 | ```bash
43 | $ python main.py
44 | ```
45 | 
46 | - `train_model()` trains the VAE+GMM model.
47 | - `vae_module.decode()` reconstructs images from the parameters of the posterior distribution estimated by the GMM.
48 | 
49 | ```python:main.py
50 | def main() -> None:
51 |     """Main function to orchestrate the VAE-GMM training process."""
52 |     # Parse arguments and setup configuration
53 |     config = parse_arguments()
54 | 
55 |     # Setup environment
56 |     setup_directories(config)
57 |     device = setup_device_and_seed(config)
58 | 
59 |     # Create data loaders
60 |     train_loader, all_loader, train_size = create_data_loaders(config)
61 | 
62 |     # Train the VAE-GMM model
63 |     train_model(config, train_loader, all_loader, device)
64 | 
65 |     # Reconstruct images from trained model
66 |     print("\nGenerating reconstructed images...")
67 |     vae_module.decode(
68 |         iteration=1,  # Use model from iteration 1
69 |         decode_k=1,  # Use cluster 1 for reconstruction
70 |         sample_num=16,  # Generate 16 samples
71 |         model_dir=config.debug_dir,
72 |         device=device,
73 |     )
74 | ```
75 | 
76 | # Changes with and without mutual learning (for MNIST)
77 | 
78 | ## Latent space on VAE
79 | 
80 | Left: without mutual learning / Right: with mutual learning
81 | Plots using t-SNE
82 | 
83 | 
84 | 85 |
86 | Plot using PCA 87 |
88 | 89 |
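For reference, the bundled `tool.py` currently draws the latent-space plots with t-SNE only (its PCA branch raises `NotImplementedError`). A PCA version like the one shown above could be produced with scikit-learn along these lines; this is a minimal sketch, assuming `latent_vars` and `labels` come from `vae_module.extract_latent_variables`:

```python
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


def plot_latent_pca(latent_vars, labels, out_path="latent_pca.png"):
    """Project VAE latent variables to 2-D with PCA and color points by digit label."""
    points_2d = PCA(n_components=2).fit_transform(latent_vars)
    plt.figure(figsize=(10, 10))
    scatter = plt.scatter(points_2d[:, 0], points_2d[:, 1], c=labels, cmap="tab10", s=10)
    plt.colorbar(scatter, label="MNIST label")
    plt.xlabel("PCA component 1")
    plt.ylabel("PCA component 2")
    plt.savefig(out_path)
    plt.close()
```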
90 | 
91 | ## ELBO of VAE
92 | 
93 | The red line is the ELBO without mutual learning (iteration 0); the blue line is the ELBO with mutual learning.
94 | The horizontal axis is the VAE training epoch, and the vertical axis is the ELBO.
95 | (In general, the higher the ELBO, the better.)
96 | 
97 | 
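For reference, the ELBO plotted here is simply the negative of the training loss computed in `vae_module.VAE.compute_loss` (reconstruction BCE plus KL divergence), averaged over the dataset. A minimal sketch of that bookkeeping (the helper name is mine, not part of the repo):

```python
import torch
import torch.nn.functional as F


def elbo_per_sample(reconstruction, target, mu, logvar):
    """Negative VAE loss per sample: -(BCE + KL(q(x|image) || N(0, I)))."""
    bce = F.binary_cross_entropy(reconstruction, target, reduction="sum")
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return -(bce + kl) / target.size(0)
```

In mutual-learning iterations after the first, the KL term is taken against the Gaussian parameters sent by the GMM instead of N(0, I).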
98 | 99 |
100 | 
101 | ## Clustering performance (in GMM)
102 | 
103 | Clustering performance measured by accuracy (this evaluates the GMM clustering within the VAE+GMM model).
104 | Left: without mutual learning / Right: with mutual learning
105 | The horizontal axis is the GMM training epoch, and the vertical axis is the clustering accuracy.
106 | 
107 | 
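The accuracy above is computed by `calc_acc` in `tool.py`, which searches over swaps of the predicted cluster ids to best align them with the ground-truth digit labels. A small illustrative call with made-up labels (the arrays here are hypothetical):

```python
import numpy as np

from tool import calc_acc

# Toy example: three clusters whose ids are permuted relative to the true labels.
pred = np.array([0, 0, 1, 1, 2, 2])
true = np.array([2, 2, 0, 0, 1, 1])
best_acc, remapped = calc_acc(pred, true)
print(best_acc)  # 1.0 once the label permutation is found
print(remapped)  # [2 2 0 0 1 1]
```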
108 | 109 |
110 | 
111 | # Image reconstruction from Gaussian distribution parameters estimated by GMM using VAE decoder
112 | 
113 | The GMM clusters the latent variables of the VAE.
114 | By sampling random variables from the posterior distribution estimated by the GMM and feeding them to the VAE decoder, images can be reconstructed.
115 | 
116 | "x" represents the mean parameter of the normal distribution for each cluster.
117 | In this example, a random variable is sampled from the Gaussian distribution of cluster K=1.
118 | 
119 | 
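This is essentially what `vae_module.decode()` does internally. A minimal sketch of the same steps using the helpers in this repo (the paths assume the default `config.debug_dir = ./model/debug` and a VAE saved for iteration 1):

```python
import numpy as np
import torch

import vae_module
from tool import get_param, sample

device = torch.device("cpu")
model_dir = "./model/debug"

# Load the trained VAE and the GMM parameters saved at mutual-learning iteration 1.
model = vae_module.load_trained_model(iteration=1, model_dir=model_dir, device=device)
mu_gmm, lambda_gmm, _ = get_param(1, model_dir=model_dir)

# Draw 16 latent samples around the mean of cluster k=1 and decode them to images.
z, _ = sample(iteration=1, x_dim=12, mu_gmm=mu_gmm, lambda_gmm=lambda_gmm,
              sample_num=16, sample_k=1, model_dir=model_dir)
with torch.no_grad():
    images = model.decode(torch.from_numpy(z.astype(np.float32)).to(device))
images = images.view(-1, 1, 28, 28)  # [16, 1, 28, 28] reconstructed digits
```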
120 | 121 |
122 | Reconstructed images obtained by feeding the sampled random variables to the VAE decoder:
123 | 
124 | 125 |
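The 2-D Gaussian visualization shown further above (the contour plot) can be regenerated from the saved GMM parameters with `tool.visualize_gmm`, which `main.py` also wraps as `plot_dist`. A short usage sketch (the directory assumes the default `config.debug_dir`):

```python
from tool import visualize_gmm

# Plot cluster k=1 of the GMM fitted at mutual-learning iteration 0,
# together with 16 samples drawn around its mean.
visualize_gmm(iteration=0, decode_k=1, sample_num=16, model_dir="./model/debug")
```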
126 | 127 | # Special Thanks 128 | 129 | The implementation of GMM is based on 130 | [【Python】4.4.2:ガウス混合モデルにおける推論:ギブスサンプリング【緑ベイズ入門のノート】](https://www.anarchive-beta.com/entry/2020/11/28/210948) 131 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch implementation of mutual learning between VAE and GMM models. 3 | 4 | This module orchestrates training, handles data loading and directory setup 5 | for the VAE-GMM mutual learning system on MNIST dataset. 6 | """ 7 | 8 | import argparse 9 | import os 10 | from dataclasses import dataclass 11 | from typing import Tuple 12 | 13 | import torch 14 | from torch.utils.data import DataLoader, Subset 15 | from torchvision import datasets, transforms 16 | 17 | import gmm_module 18 | import vae_module 19 | from tool import visualize_gmm 20 | 21 | 22 | @dataclass 23 | class Config: 24 | """Configuration class for VAE-GMM training parameters.""" 25 | 26 | # Training parameters 27 | vae_iter: int = 100 28 | gmm_iter: int = 100 29 | mutual_iterations: int = 2 30 | 31 | # Hardware settings 32 | use_cuda: bool = True 33 | seed: int = 1 34 | 35 | # Data parameters 36 | data_dir: str = "./../data" 37 | train_fraction: float = 1 / 6 # Use 1/6 of MNIST training set (10,000 samples) 38 | batch_size: int = 10 # Small batch size for GMM training stability 39 | 40 | # Model parameters 41 | num_clusters: int = 10 # K=10 for MNIST digit classes 42 | 43 | # Directory paths 44 | model_dir: str = "./model" 45 | debug_dir: str = "./model/debug" 46 | 47 | @property 48 | def graph_dir(self) -> str: 49 | """Directory for saving graphs and visualizations.""" 50 | return os.path.join(self.debug_dir, "graph") 51 | 52 | @property 53 | def pth_dir(self) -> str: 54 | """Directory for saving VAE model states (.pth files).""" 55 | return os.path.join(self.debug_dir, "pth") 56 | 57 | @property 58 | def npy_dir(self) -> str: 59 | """Directory for saving GMM parameters (.npy files).""" 60 | return os.path.join(self.debug_dir, "npy") 61 | 62 | @property 63 | def recon_dir(self) -> str: 64 | """Directory for saving reconstructed images.""" 65 | return os.path.join(self.debug_dir, "recon") 66 | 67 | 68 | def setup_directories(config: Config) -> None: 69 | """Create necessary directories for model outputs.""" 70 | directories = [ 71 | config.model_dir, 72 | config.debug_dir, 73 | config.graph_dir, 74 | config.pth_dir, 75 | config.npy_dir, 76 | config.recon_dir, 77 | ] 78 | 79 | for directory in directories: 80 | os.makedirs(directory, exist_ok=True) 81 | 82 | 83 | def setup_device_and_seed(config: Config) -> torch.device: 84 | """Setup device and random seed for reproducible results.""" 85 | torch.manual_seed(config.seed) 86 | 87 | if config.use_cuda and torch.cuda.is_available(): 88 | device = torch.device("cuda") 89 | print(f"Using CUDA device: {torch.cuda.get_device_name()}") 90 | else: 91 | device = torch.device("cpu") 92 | print("Using CPU device") 93 | 94 | return device 95 | 96 | 97 | def create_data_loaders(config: Config) -> Tuple[DataLoader, DataLoader, int]: 98 | """ 99 | Create training and full dataset loaders for MNIST. 
100 | 101 | Returns: 102 | Tuple of (train_loader, all_loader, train_size) 103 | """ 104 | # Load MNIST dataset 105 | trainval_dataset = datasets.MNIST( 106 | config.data_dir, train=True, transform=transforms.ToTensor(), download=True 107 | ) 108 | 109 | # Calculate dataset sizes 110 | n_samples = len(trainval_dataset) 111 | train_size = int(n_samples * config.train_fraction) 112 | 113 | # Create subsets 114 | train_indices = list(range(train_size)) 115 | train_dataset = Subset(trainval_dataset, train_indices) 116 | 117 | # Data loader settings 118 | kwargs = {"num_workers": 1, "pin_memory": True} if config.use_cuda else {} 119 | 120 | # Create data loaders 121 | train_loader = DataLoader( 122 | train_dataset, batch_size=config.batch_size, shuffle=False, **kwargs 123 | ) 124 | 125 | # All data loader for sending latent variables to GMM 126 | all_loader = DataLoader( 127 | train_dataset, batch_size=train_size, shuffle=False, **kwargs 128 | ) 129 | 130 | print( 131 | f"Dataset info: {train_size} samples, " 132 | f"VAE iterations: {config.vae_iter}, " 133 | f"GMM iterations: {config.gmm_iter}" 134 | ) 135 | 136 | return train_loader, all_loader, train_size 137 | 138 | 139 | def train_model( 140 | config: Config, 141 | train_loader: DataLoader, 142 | all_loader: DataLoader, 143 | device: torch.device, 144 | ) -> None: 145 | """ 146 | Main training loop for VAE-GMM mutual learning. 147 | 148 | Args: 149 | config: Configuration object with training parameters 150 | train_loader: DataLoader for training batches 151 | all_loader: DataLoader for all data (used for GMM training) 152 | """ 153 | print("Starting VAE-GMM mutual learning training...") 154 | 155 | # Initialize GMM parameters for first iteration 156 | gmm_mu = None 157 | gmm_var = None 158 | 159 | for iteration in range(config.mutual_iterations): 160 | print(f"---------- Mutual Learning Iteration: {iteration + 1} ----------") 161 | 162 | # VAE Training Phase 163 | print("Training VAE...") 164 | x_d, label, loss_list = vae_module.train( 165 | iteration=iteration, 166 | gmm_mu=gmm_mu, # GMM mean parameters (None for first iteration) 167 | gmm_var=gmm_var, # GMM variance parameters (None for first iteration) 168 | epoch=config.vae_iter, 169 | train_loader=train_loader, 170 | all_loader=all_loader, 171 | model_dir=config.debug_dir, 172 | device=device, 173 | ) 174 | 175 | # GMM Training Phase 176 | print("Training GMM...") 177 | gmm_mu, gmm_var, max_acc = gmm_module.train( 178 | iteration=iteration, 179 | x_d=x_d, # Latent variables from VAE 180 | model_dir=config.debug_dir, 181 | label=label, # MNIST labels for accuracy calculation 182 | K=config.num_clusters, 183 | epoch=config.gmm_iter, 184 | ) 185 | 186 | # Visualize latent space 187 | print("Plotting latent space...") 188 | vae_module.plot_latent( 189 | iteration=iteration, 190 | all_loader=all_loader, 191 | model_dir=config.debug_dir, 192 | device=device, 193 | ) 194 | 195 | print(f"Iteration {iteration + 1} completed. Max accuracy: {max_acc:.4f}") 196 | 197 | 198 | def plot_dist(iteration: int, decode_k: int, sample_num: int, model_dir: str) -> None: 199 | """ 200 | Visualize Gaussian distributions from GMM parameters. 
201 | 202 | Args: 203 | iteration: Which iteration's model to load 204 | decode_k: Cluster number of Gaussian distribution 205 | sample_num: Number of samples for the random variable 206 | model_dir: Directory containing the model files 207 | """ 208 | visualize_gmm( 209 | iteration=iteration, 210 | decode_k=decode_k, 211 | sample_num=sample_num, 212 | model_dir=model_dir, 213 | ) 214 | 215 | 216 | def parse_arguments() -> Config: 217 | """Parse command line arguments and return configuration.""" 218 | parser = argparse.ArgumentParser(description="VAE-GMM MNIST Mutual Learning") 219 | parser.add_argument( 220 | "--vae-iter", 221 | type=int, 222 | default=100, 223 | help="Number of VAE training iterations (default: 100)", 224 | ) 225 | parser.add_argument( 226 | "--gmm-iter", 227 | type=int, 228 | default=100, 229 | help="Number of GMM training iterations (default: 100)", 230 | ) 231 | parser.add_argument( 232 | "--no-cuda", action="store_true", default=False, help="Disable CUDA training" 233 | ) 234 | parser.add_argument( 235 | "--seed", 236 | type=int, 237 | default=1, 238 | help="Random seed for reproducibility (default: 1)", 239 | ) 240 | 241 | args = parser.parse_args() 242 | 243 | # Create config from arguments 244 | config = Config( 245 | vae_iter=args.vae_iter, 246 | gmm_iter=args.gmm_iter, 247 | use_cuda=not args.no_cuda, 248 | seed=args.seed, 249 | ) 250 | 251 | return config 252 | 253 | 254 | def main() -> None: 255 | """Main function to orchestrate the VAE-GMM training process.""" 256 | # Parse arguments and setup configuration 257 | config = parse_arguments() 258 | 259 | # Setup environment 260 | setup_directories(config) 261 | device = setup_device_and_seed(config) 262 | 263 | # Create data loaders 264 | train_loader, all_loader, train_size = create_data_loaders(config) 265 | 266 | # Train the VAE-GMM model 267 | train_model(config, train_loader, all_loader, device) 268 | 269 | # Reconstruct images from trained model 270 | print("\nGenerating reconstructed images...") 271 | vae_module.decode( 272 | iteration=1, # Use model from iteration 1 273 | decode_k=1, # Use cluster 1 for reconstruction 274 | sample_num=16, # Generate 16 samples 275 | model_dir=config.debug_dir, 276 | device=device, 277 | ) 278 | 279 | print("Training and reconstruction completed!") 280 | 281 | # Optional: Visualize distributions (uncomment to use) 282 | # print("Visualizing Gaussian distributions...") 283 | # plot_dist( 284 | # iteration=0, 285 | # decode_k=1, 286 | # sample_num=16, 287 | # model_dir=config.debug_dir 288 | # ) 289 | 290 | 291 | if __name__ == "__main__": 292 | main() 293 | -------------------------------------------------------------------------------- /gmm_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gaussian Mixture Model (GMM) module for the VAE-GMM mutual learning system. 3 | 4 | This module provides GMM implementation with Gibbs sampling for clustering 5 | VAE latent variables and estimating parameters for the VAE prior distribution. 
6 | """ 7 | 8 | from dataclasses import dataclass 9 | from typing import Tuple, List 10 | 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import torch 14 | from scipy.stats import dirichlet, wishart 15 | 16 | from tool import calc_acc 17 | 18 | 19 | @dataclass 20 | class GMMConfig: 21 | """Configuration class for GMM hyperparameters.""" 22 | 23 | # Prior hyperparameters 24 | beta: float = 1.0 # Precision parameter for mean prior 25 | nu_factor: float = 1.0 # Degrees of freedom factor (nu = dim * nu_factor) 26 | w_scale: float = 0.55 # Scale factor for precision matrix prior 27 | alpha: float = 0.3 # Dirichlet concentration parameter 28 | 29 | # Training parameters 30 | log_interval: int = 20 # Logging interval during training 31 | 32 | # Numerical stability 33 | eps: float = 1e-7 # Small epsilon for numerical stability 34 | 35 | 36 | def initialize_gmm_parameters( 37 | num_data: int, 38 | data_dim: int, 39 | num_clusters: int, 40 | config: GMMConfig 41 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 42 | """ 43 | Initialize GMM parameters by sampling from priors. 44 | 45 | Args: 46 | num_data: Number of data points 47 | data_dim: Dimensionality of data 48 | num_clusters: Number of clusters (K) 49 | config: GMM configuration 50 | 51 | Returns: 52 | Tuple of (mu, lambda, pi, prior_parameters) 53 | """ 54 | # Prior parameters 55 | m_d = np.zeros(data_dim) # Prior mean 56 | w_dd = np.identity(data_dim) * config.w_scale # Prior precision matrix scale 57 | nu = int(data_dim * config.nu_factor) # Degrees of freedom 58 | alpha_k = np.full(num_clusters, config.alpha) # Dirichlet parameters 59 | 60 | # Initialize cluster parameters by sampling from priors 61 | mu_kd = np.empty((num_clusters, data_dim)) 62 | lambda_kdd = np.empty((num_clusters, data_dim, data_dim)) 63 | 64 | for k in range(num_clusters): 65 | # Sample precision matrix from Wishart distribution 66 | lambda_kdd[k] = wishart.rvs(df=nu, scale=w_dd, size=1) 67 | 68 | # Sample mean from Normal-Wishart prior 69 | cov = np.linalg.inv(config.beta * lambda_kdd[k]) 70 | mu_kd[k] = np.random.multivariate_normal(mean=m_d, cov=cov).flatten() 71 | 72 | # Sample mixing coefficients from Dirichlet distribution 73 | pi_k = dirichlet.rvs(alpha=alpha_k, size=1).flatten() 74 | 75 | # Store prior parameters for later use 76 | prior_params = { 77 | 'm_d': m_d, 78 | 'w_dd': w_dd, 79 | 'nu': nu, 80 | 'alpha_k': alpha_k 81 | } 82 | 83 | return mu_kd, lambda_kdd, pi_k, prior_params 84 | 85 | 86 | def compute_responsibilities( 87 | x_d: np.ndarray, 88 | mu_kd: np.ndarray, 89 | lambda_kdd: np.ndarray, 90 | pi_k: np.ndarray, 91 | config: GMMConfig 92 | ) -> np.ndarray: 93 | """ 94 | Compute responsibilities (posterior probabilities) for each data point. 
95 | 96 | Args: 97 | x_d: Data points [num_data, data_dim] 98 | mu_kd: Cluster means [num_clusters, data_dim] 99 | lambda_kdd: Cluster precision matrices [num_clusters, data_dim, data_dim] 100 | pi_k: Mixing coefficients [num_clusters] 101 | config: GMM configuration 102 | 103 | Returns: 104 | Responsibilities [num_data, num_clusters] 105 | """ 106 | num_data, num_clusters = len(x_d), len(mu_kd) 107 | eta_dk = np.zeros((num_data, num_clusters)) 108 | 109 | for k in range(num_clusters): 110 | # Compute quadratic form for each data point 111 | diff = x_d - mu_kd[k] # [num_data, data_dim] 112 | quadratic_form = np.sum(diff @ lambda_kdd[k] * diff, axis=1) 113 | 114 | # Log probability components 115 | log_prob = -0.5 * quadratic_form 116 | log_prob += 0.5 * np.log(np.linalg.det(lambda_kdd[k]) + config.eps) 117 | log_prob += np.log(pi_k[k] + config.eps) 118 | 119 | eta_dk[:, k] = np.exp(log_prob) 120 | 121 | # Normalize to get responsibilities 122 | eta_dk /= np.sum(eta_dk, axis=1, keepdims=True) 123 | return eta_dk 124 | 125 | 126 | def sample_cluster_assignments(eta_dk: np.ndarray) -> Tuple[np.ndarray, List[int]]: 127 | """ 128 | Sample cluster assignments from responsibilities. 129 | 130 | Args: 131 | eta_dk: Responsibilities [num_data, num_clusters] 132 | 133 | Returns: 134 | Tuple of (assignment_matrix, predicted_labels) 135 | """ 136 | num_data, num_clusters = eta_dk.shape 137 | z_dk = np.zeros((num_data, num_clusters)) 138 | pred_labels = [] 139 | 140 | for d in range(num_data): 141 | # Sample from multinomial distribution 142 | assignment = np.random.multinomial(n=1, pvals=eta_dk[d], size=1).flatten() 143 | z_dk[d] = assignment 144 | pred_labels.append(np.argmax(assignment)) 145 | 146 | return z_dk, pred_labels 147 | 148 | 149 | def update_gmm_parameters( 150 | x_d: np.ndarray, 151 | z_dk: np.ndarray, 152 | prior_params: dict, 153 | config: GMMConfig 154 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 155 | """ 156 | Update GMM parameters using Gibbs sampling. 
157 | 158 | Args: 159 | x_d: Data points [num_data, data_dim] 160 | z_dk: Cluster assignments [num_data, num_clusters] 161 | prior_params: Dictionary containing prior parameters 162 | config: GMM configuration 163 | 164 | Returns: 165 | Tuple of (updated_mu, updated_lambda, updated_pi) 166 | """ 167 | num_data, data_dim = x_d.shape 168 | num_clusters = z_dk.shape[1] 169 | 170 | # Extract prior parameters 171 | m_d = prior_params['m_d'] 172 | w_dd = prior_params['w_dd'] 173 | nu = prior_params['nu'] 174 | alpha_k = prior_params['alpha_k'] 175 | 176 | # Initialize parameter arrays 177 | mu_kd = np.empty((num_clusters, data_dim)) 178 | lambda_kdd = np.empty((num_clusters, data_dim, data_dim)) 179 | 180 | for k in range(num_clusters): 181 | # Posterior parameters for Normal-Wishart 182 | cluster_size = np.sum(z_dk[:, k]) 183 | beta_hat = cluster_size + config.beta 184 | 185 | # Posterior mean 186 | weighted_sum = np.sum(z_dk[:, k][:, np.newaxis] * x_d, axis=0) 187 | m_hat = (weighted_sum + config.beta * m_d) / beta_hat 188 | 189 | # Posterior scale matrix for Wishart 190 | centered_data = x_d - m_hat # [num_data, data_dim] 191 | weighted_scatter = (z_dk[:, k][:, np.newaxis, np.newaxis] * 192 | centered_data[:, :, np.newaxis] @ 193 | centered_data[:, np.newaxis, :]).sum(axis=0) 194 | 195 | prior_term = config.beta * np.outer(m_d - m_hat, m_d - m_hat) 196 | w_hat = np.linalg.inv(np.linalg.inv(w_dd) + weighted_scatter + prior_term) 197 | 198 | nu_hat = cluster_size + nu 199 | 200 | # Sample precision matrix from Wishart 201 | lambda_kdd[k] = wishart.rvs(df=nu_hat, scale=w_hat, size=1) 202 | 203 | # Sample mean from multivariate normal 204 | cov = np.linalg.inv(beta_hat * lambda_kdd[k]) 205 | mu_kd[k] = np.random.multivariate_normal(mean=m_hat, cov=cov).flatten() 206 | 207 | # Update mixing coefficients 208 | alpha_hat = np.sum(z_dk, axis=0) + alpha_k 209 | pi_k = dirichlet.rvs(alpha=alpha_hat, size=1).flatten() 210 | 211 | return mu_kd, lambda_kdd, pi_k 212 | 213 | 214 | def compute_data_parameters( 215 | mu_kd: np.ndarray, 216 | lambda_kdd: np.ndarray, 217 | pred_labels: List[int] 218 | ) -> Tuple[np.ndarray, np.ndarray]: 219 | """ 220 | Compute mean and variance for each data point based on cluster assignments. 
221 | 222 | Args: 223 | mu_kd: Cluster means [num_clusters, data_dim] 224 | lambda_kdd: Cluster precision matrices [num_clusters, data_dim, data_dim] 225 | pred_labels: Predicted cluster labels for each data point 226 | 227 | Returns: 228 | Tuple of (data_means, data_variances) 229 | """ 230 | num_data = len(pred_labels) 231 | data_dim = mu_kd.shape[1] 232 | 233 | mu_d = np.zeros((num_data, data_dim)) 234 | var_d = np.zeros((num_data, data_dim)) 235 | 236 | for d in range(num_data): 237 | cluster_idx = pred_labels[d] 238 | mu_d[d] = mu_kd[cluster_idx] 239 | # Extract diagonal of inverse precision matrix (variances) 240 | var_d[d] = np.diag(np.linalg.inv(lambda_kdd[cluster_idx])) 241 | 242 | return mu_d, var_d 243 | 244 | 245 | def save_gmm_parameters( 246 | mu_kd: np.ndarray, 247 | lambda_kdd: np.ndarray, 248 | pi_k: np.ndarray, 249 | iteration: int, 250 | model_dir: str 251 | ) -> None: 252 | """Save GMM parameters to files.""" 253 | np.save(f"{model_dir}/npy/mu_{iteration}.npy", mu_kd) 254 | np.save(f"{model_dir}/npy/lambda_{iteration}.npy", lambda_kdd) 255 | np.save(f"{model_dir}/npy/pi_{iteration}.npy", pi_k) 256 | 257 | 258 | def plot_accuracy_curve( 259 | accuracy_history: np.ndarray, 260 | iteration: int, 261 | num_clusters: int, 262 | model_dir: str 263 | ) -> None: 264 | """ 265 | Plot and save accuracy curve during GMM training. 266 | 267 | Args: 268 | accuracy_history: Accuracy values over epochs 269 | iteration: Current iteration number 270 | num_clusters: Number of clusters 271 | model_dir: Directory to save plot 272 | """ 273 | plt.figure(figsize=(10, 6)) 274 | plt.plot(range(len(accuracy_history)), accuracy_history, 275 | marker='o', markersize=3, linewidth=2, color='blue') 276 | plt.xlabel("Epoch", fontsize=12) 277 | plt.ylabel("Clustering Accuracy", fontsize=12) 278 | plt.title(f"GMM Training Progress - Iteration {iteration} (K={num_clusters})", 279 | fontsize=14) 280 | plt.grid(True, alpha=0.3) 281 | plt.ylim(0, 1) 282 | 283 | # Add max accuracy annotation 284 | max_acc = np.max(accuracy_history) 285 | max_epoch = np.argmax(accuracy_history) 286 | plt.annotate(f'Max: {max_acc:.3f}', 287 | xy=(max_epoch, max_acc), 288 | xytext=(max_epoch + len(accuracy_history) * 0.1, max_acc + 0.05), 289 | arrowprops=dict(arrowstyle='->', color='red', alpha=0.7), 290 | fontsize=10, color='red') 291 | 292 | plt.tight_layout() 293 | plt.savefig(f"{model_dir}/graph/acc_{iteration}.png", dpi=150, bbox_inches='tight') 294 | plt.close() 295 | 296 | 297 | def train( 298 | iteration: int, 299 | x_d: np.ndarray, 300 | label: np.ndarray, 301 | K: int, 302 | epoch: int = 100, 303 | model_dir: str = "vae_gmm" 304 | ) -> Tuple[torch.Tensor, torch.Tensor, float]: 305 | """ 306 | Train GMM using Gibbs sampling for clustering VAE latent variables. 
307 | 308 | Args: 309 | iteration: Current mutual learning iteration 310 | x_d: Latent variables from VAE [num_data, latent_dim] 311 | label: Ground truth labels for accuracy calculation 312 | K: Number of clusters 313 | epoch: Number of training epochs 314 | model_dir: Directory to save results 315 | 316 | Returns: 317 | Tuple of (data_means, data_variances, max_accuracy) 318 | """ 319 | print(f"GMM Training Start - Iteration {iteration}") 320 | 321 | # Data dimensions 322 | num_data, data_dim = x_d.shape 323 | config = GMMConfig() 324 | 325 | # Initialize parameters 326 | mu_kd, lambda_kdd, pi_k, prior_params = initialize_gmm_parameters( 327 | num_data, data_dim, K, config 328 | ) 329 | 330 | # Training loop 331 | accuracy_history = np.zeros(epoch) 332 | max_accuracy = 0.0 333 | pred_labels = [] 334 | 335 | for epoch_idx in range(epoch): 336 | # E-step: Compute responsibilities 337 | eta_dk = compute_responsibilities(x_d, mu_kd, lambda_kdd, pi_k, config) 338 | 339 | # Sample cluster assignments 340 | z_dk, pred_labels = sample_cluster_assignments(eta_dk) 341 | 342 | # M-step: Update parameters using Gibbs sampling 343 | mu_kd, lambda_kdd, pi_k = update_gmm_parameters( 344 | x_d, z_dk, prior_params, config 345 | ) 346 | 347 | # Calculate and store accuracy 348 | accuracy = calc_acc(pred_labels, label)[0] 349 | accuracy_history[epoch_idx] = np.round(accuracy, 3) 350 | max_accuracy = max(max_accuracy, accuracy_history[epoch_idx]) 351 | 352 | # Log progress 353 | if (epoch_idx == 0 or (epoch_idx + 1) % config.log_interval == 0 or 354 | epoch_idx == epoch - 1): 355 | print(f"====> Epoch: {epoch_idx + 1}/{epoch}, " 356 | f"Accuracy: {accuracy_history[epoch_idx]:.3f}, " 357 | f"Max Accuracy: {max_accuracy:.3f}") 358 | 359 | # Compute final data parameters 360 | mu_d, var_d = compute_data_parameters(mu_kd, lambda_kdd, pred_labels) 361 | 362 | # Save results 363 | save_gmm_parameters(mu_kd, lambda_kdd, pi_k, iteration, model_dir) 364 | plot_accuracy_curve(accuracy_history, iteration, K, model_dir) 365 | 366 | print(f"GMM training completed. Final accuracy: {accuracy_history[-1]:.3f}") 367 | 368 | return torch.from_numpy(mu_d), torch.from_numpy(var_d), max_accuracy 369 | -------------------------------------------------------------------------------- /vae_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Variational Autoencoder (VAE) module for the VAE-GMM mutual learning system. 3 | 4 | This module provides the VAE implementation with encoder-decoder architecture, 5 | training functionality, and utilities for latent space visualization and image reconstruction. 
6 | """ 7 | 8 | from dataclasses import dataclass 9 | from typing import Optional, Tuple 10 | 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import torch 14 | import torch.utils.data 15 | from torch import nn, optim 16 | from torch.nn import functional as F 17 | from torch.utils.data import DataLoader 18 | from torchvision.utils import save_image 19 | 20 | from tool import get_param, sample, visualize_ls 21 | 22 | 23 | @dataclass 24 | class VAEConfig: 25 | """Configuration class for VAE parameters.""" 26 | 27 | # Architecture parameters 28 | latent_dim: int = 12 # Dimension of latent variable 29 | input_dim: int = 784 # MNIST image dimension (28x28) 30 | hidden_dim: int = 256 # Hidden layer dimension 31 | 32 | # Training parameters 33 | learning_rate: float = 1e-3 34 | reconstruction_loss: str = "bce" # Binary cross-entropy for MNIST 35 | 36 | # Visualization parameters 37 | log_interval: int = 50 # Logging interval during training 38 | 39 | 40 | class VAE(nn.Module): 41 | """ 42 | Variational Autoencoder with encoder-decoder architecture for MNIST. 43 | 44 | The VAE includes: 45 | - Encoder: Maps input images to latent space (mean and log-variance) 46 | - Decoder: Reconstructs images from latent representations 47 | - Reparameterization: Enables backpropagation through stochastic sampling 48 | """ 49 | 50 | def __init__(self, config: VAEConfig) -> None: 51 | """ 52 | Initialize VAE network layers. 53 | 54 | Args: 55 | config: VAE configuration containing architecture parameters 56 | """ 57 | super(VAE, self).__init__() 58 | self.config = config 59 | 60 | # Encoder layers 61 | self.encoder_hidden = nn.Linear(config.input_dim, config.hidden_dim) 62 | self.encoder_mu = nn.Linear(config.hidden_dim, config.latent_dim) 63 | self.encoder_logvar = nn.Linear(config.hidden_dim, config.latent_dim) 64 | 65 | # Decoder layers 66 | self.decoder_hidden = nn.Linear(config.latent_dim, config.hidden_dim) 67 | self.decoder_output = nn.Linear(config.hidden_dim, config.input_dim) 68 | 69 | def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 70 | """ 71 | Encode input to latent space parameters. 72 | 73 | Args: 74 | x: Input tensor [batch_size, input_dim] 75 | 76 | Returns: 77 | Tuple of (mean, log_variance) tensors [batch_size, latent_dim] 78 | """ 79 | hidden = F.relu(self.encoder_hidden(x)) 80 | mu = self.encoder_mu(hidden) 81 | logvar = self.encoder_logvar(hidden) 82 | return mu, logvar 83 | 84 | def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: 85 | """ 86 | Reparameterization trick for backpropagation through stochastic sampling. 87 | 88 | Args: 89 | mu: Mean tensor [batch_size, latent_dim] 90 | logvar: Log-variance tensor [batch_size, latent_dim] 91 | 92 | Returns: 93 | Sampled latent vector [batch_size, latent_dim] 94 | """ 95 | std = torch.exp(0.5 * logvar) 96 | eps = torch.randn_like(std) 97 | return mu + eps * std 98 | 99 | def decode(self, z: torch.Tensor) -> torch.Tensor: 100 | """ 101 | Decode latent representation to output space. 102 | 103 | Args: 104 | z: Latent tensor [batch_size, latent_dim] 105 | 106 | Returns: 107 | Reconstructed output [batch_size, input_dim] 108 | """ 109 | hidden = F.relu(self.decoder_hidden(z)) 110 | return torch.sigmoid(self.decoder_output(hidden)) 111 | 112 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 113 | """ 114 | Forward pass through VAE. 115 | 116 | Args: 117 | x: Input tensor [batch_size, ...] 
118 | 119 | Returns: 120 | Tuple of (reconstruction, mu, logvar, latent_sample) 121 | """ 122 | x_flat = x.view(-1, self.config.input_dim) 123 | mu, logvar = self.encode(x_flat) 124 | z = self.reparameterize(mu, logvar) 125 | reconstruction = self.decode(z) 126 | return reconstruction, mu, logvar, z 127 | 128 | def compute_loss( 129 | self, 130 | reconstruction: torch.Tensor, 131 | target: torch.Tensor, 132 | mu: torch.Tensor, 133 | logvar: torch.Tensor, 134 | gmm_mu: Optional[torch.Tensor] = None, 135 | gmm_var: Optional[torch.Tensor] = None, 136 | device: torch.device = torch.device("cpu") 137 | ) -> torch.Tensor: 138 | """ 139 | Compute VAE loss (reconstruction + KL divergence). 140 | 141 | Reference: Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 142 | https://arxiv.org/abs/1312.6114 143 | 144 | Args: 145 | reconstruction: Reconstructed output [batch_size, input_dim] 146 | target: Original input [batch_size, input_dim] 147 | mu: Encoder mean [batch_size, latent_dim] 148 | logvar: Encoder log-variance [batch_size, latent_dim] 149 | gmm_mu: GMM mean parameters (None for standard normal prior) 150 | gmm_var: GMM variance parameters (None for standard normal prior) 151 | device: Device for computation 152 | 153 | Returns: 154 | Total loss tensor 155 | """ 156 | # Reconstruction loss (Binary Cross-Entropy) 157 | target_flat = target.view(-1, self.config.input_dim) 158 | reconstruction_loss = F.binary_cross_entropy( 159 | reconstruction, target_flat, reduction="sum" 160 | ) 161 | 162 | # KL divergence 163 | if gmm_mu is not None and gmm_var is not None: 164 | # Use GMM parameters as prior (mutual learning iterations > 0) 165 | kl_divergence = self._compute_gmm_kl_divergence( 166 | mu, logvar, gmm_mu, gmm_var, device 167 | ) 168 | else: 169 | # Use standard normal prior (first iteration) 170 | kl_divergence = self._compute_standard_kl_divergence(mu, logvar) 171 | 172 | return reconstruction_loss + kl_divergence 173 | 174 | def _compute_standard_kl_divergence( 175 | self, mu: torch.Tensor, logvar: torch.Tensor 176 | ) -> torch.Tensor: 177 | """Compute KL divergence with standard normal prior N(0,I).""" 178 | return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 179 | 180 | def _compute_gmm_kl_divergence( 181 | self, 182 | mu: torch.Tensor, 183 | logvar: torch.Tensor, 184 | gmm_mu: torch.Tensor, 185 | gmm_var: torch.Tensor, 186 | device: torch.device 187 | ) -> torch.Tensor: 188 | """Compute KL divergence with GMM prior.""" 189 | # Convert GMM parameters to appropriate format 190 | prior_mu = gmm_mu.to(device).expand_as(mu) 191 | prior_var = gmm_var.to(device).expand_as(logvar) 192 | prior_logvar = prior_var.log() 193 | 194 | # KL divergence components 195 | var_division = logvar.exp() / prior_var # σ²_q / σ²_p 196 | diff = mu - prior_mu # μ_q - μ_p 197 | diff_term = diff * diff / prior_var # (μ_q - μ_p)² / σ²_p 198 | logvar_division = prior_logvar - logvar # log(σ²_p) - log(σ²_q) 199 | 200 | kl_per_sample = 0.5 * ( 201 | (var_division + diff_term + logvar_division).sum(1) - self.config.latent_dim 202 | ) 203 | return kl_per_sample.sum() 204 | 205 | 206 | def save_model_and_losses( 207 | model: VAE, loss_list: np.ndarray, iteration: int, model_dir: str 208 | ) -> None: 209 | """Save trained model and loss history.""" 210 | # Save model state 211 | model_path = f"{model_dir}/pth/vae_{iteration}.pth" 212 | torch.save(model.state_dict(), model_path) 213 | 214 | # Save loss history 215 | loss_path = f"{model_dir}/npy/loss_{iteration}.npy" 216 | np.save(loss_path, 
loss_list)
217 | 
218 | 
219 | def plot_training_losses(
220 |     loss_list: np.ndarray, iteration: int, model_dir: str
221 | ) -> None:
222 |     """Plot and save training loss curves."""
223 |     plt.figure()
224 |     plt.plot(range(len(loss_list)), loss_list, color="blue", label="ELBO")
225 | 
226 |     # Compare with first iteration if available
227 |     if iteration != 0:
228 |         try:
229 |             loss_0 = np.load(f"{model_dir}/npy/loss_0.npy")
230 |             plt.plot(range(len(loss_0)), loss_0, color="red", label="ELBO_I0")
231 |         except FileNotFoundError:
232 |             pass  # Skip comparison if first iteration data not available
233 | 
234 |     plt.xlabel("Epoch")
235 |     plt.ylabel("ELBO")
236 |     plt.legend(loc="lower right")
237 |     plt.title(f"VAE Training Loss - Iteration {iteration}")
238 | 
239 |     # Save plot
240 |     plot_path = f"{model_dir}/graph/vae_loss_{iteration}.png"
241 |     plt.savefig(plot_path)
242 |     plt.close()
243 | 
244 | 
245 | def train_single_epoch(
246 |     model: VAE,
247 |     train_loader: DataLoader,
248 |     optimizer: optim.Optimizer,
249 |     iteration: int,
250 |     gmm_mu: Optional[torch.Tensor],
251 |     gmm_var: Optional[torch.Tensor],
252 |     device: torch.device
253 | ) -> float:
254 |     """Train VAE for one epoch."""
255 |     model.train()
256 |     total_loss = 0.0
257 | 
258 |     for batch_idx, (data, _) in enumerate(train_loader):
259 |         data = data.to(device)
260 |         optimizer.zero_grad()
261 | 
262 |         # Forward pass
263 |         reconstruction, mu, logvar, latent = model(data)
264 | 
265 |         # Compute loss
266 |         if iteration == 0:
267 |             # First iteration: use standard normal prior
268 |             loss = model.compute_loss(
269 |                 reconstruction, data, mu, logvar, device=device
270 |             )
271 |         else:
272 |             # Subsequent iterations: use the GMM prior rows for this batch (loader is unshuffled)
273 |             s = batch_idx * train_loader.batch_size
274 |             batch_gmm_mu = gmm_mu[s : s + data.size(0)] if gmm_mu is not None else None
275 |             batch_gmm_var = gmm_var[s : s + data.size(0)] if gmm_var is not None else None
276 |             loss = model.compute_loss(
277 |                 reconstruction, data, mu, logvar, batch_gmm_mu, batch_gmm_var, device
278 |             )
279 | 
280 |         # Backward pass
281 |         loss = loss.mean()
282 |         loss.backward()
283 |         optimizer.step()
284 | 
285 |         total_loss += loss.item()
286 | 
287 |     return total_loss
288 | 
289 | 
290 | def train(
291 |     iteration: int,
292 |     gmm_mu: Optional[torch.Tensor],
293 |     gmm_var: Optional[torch.Tensor],
294 |     epoch: int,
295 |     train_loader: DataLoader,
296 |     all_loader: DataLoader,
297 |     model_dir: str = "./vae_gmm",
298 |     device: torch.device = torch.device("cpu")
299 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
300 |     """
301 |     Train VAE model for specified number of epochs.
302 | 303 | Args: 304 | iteration: Current mutual learning iteration 305 | gmm_mu: GMM mean parameters (None for first iteration) 306 | gmm_var: GMM variance parameters (None for first iteration) 307 | epoch: Number of training epochs 308 | train_loader: DataLoader for training batches 309 | all_loader: DataLoader for all data (used for latent variable extraction) 310 | model_dir: Directory to save model and results 311 | device: Device for computation 312 | 313 | Returns: 314 | Tuple of (latent_variables, labels, loss_history) 315 | """ 316 | print(f"VAE Training Start - Iteration {iteration}") 317 | 318 | # Initialize model and optimizer 319 | config = VAEConfig() 320 | model = VAE(config).to(device) 321 | optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) 322 | 323 | # Track losses 324 | loss_history = np.zeros(epoch) 325 | 326 | # Training loop 327 | for epoch_idx in range(epoch): 328 | epoch_loss = train_single_epoch( 329 | model, train_loader, optimizer, iteration, 330 | gmm_mu, gmm_var, device 331 | ) 332 | 333 | # Store negative loss (ELBO) 334 | avg_loss = epoch_loss / len(train_loader.dataset) 335 | loss_history[epoch_idx] = -avg_loss 336 | 337 | # Log progress 338 | if (epoch_idx == 0 or (epoch_idx + 1) % config.log_interval == 0 or 339 | epoch_idx == epoch - 1): 340 | print(f"====> Epoch: {epoch_idx + 1}/{epoch}, Average loss: {avg_loss:.4f}") 341 | 342 | # Save model and results 343 | save_model_and_losses(model, loss_history, iteration, model_dir) 344 | plot_training_losses(loss_history, iteration, model_dir) 345 | 346 | # Extract latent variables for all data 347 | latent_vars, labels = extract_latent_variables( 348 | model, all_loader, device 349 | ) 350 | 351 | return latent_vars, labels, loss_history 352 | 353 | 354 | def load_trained_model( 355 | iteration: int, model_dir: str, device: torch.device 356 | ) -> VAE: 357 | """Load a trained VAE model from checkpoint.""" 358 | config = VAEConfig() 359 | model = VAE(config).to(device) 360 | model_path = f"{model_dir}/pth/vae_{iteration}.pth" 361 | model.load_state_dict(torch.load(model_path, map_location=device)) 362 | model.eval() 363 | return model 364 | 365 | 366 | def extract_latent_variables( 367 | model: VAE, data_loader: DataLoader, device: torch.device 368 | ) -> Tuple[np.ndarray, np.ndarray]: 369 | """ 370 | Extract latent variables and labels from all data using trained VAE. 371 | 372 | Args: 373 | model: Trained VAE model 374 | data_loader: DataLoader containing all data 375 | device: Device for computation 376 | 377 | Returns: 378 | Tuple of (latent_variables, labels) as numpy arrays 379 | """ 380 | model.eval() 381 | with torch.no_grad(): 382 | for batch_idx, (data, labels) in enumerate(data_loader): 383 | data = data.to(device) 384 | _, _, _, latent_vars = model(data) 385 | latent_vars = latent_vars.cpu() 386 | labels = labels.cpu() 387 | # Only process the first (and typically only) batch for all_loader 388 | break 389 | 390 | return latent_vars.detach().numpy(), labels.detach().numpy() 391 | 392 | 393 | def decode( 394 | iteration: int, 395 | decode_k: int, 396 | sample_num: int, 397 | model_dir: str = "./vae_gmm", 398 | device: torch.device = torch.device("cpu") 399 | ) -> None: 400 | """ 401 | Reconstruct images from GMM parameters using trained VAE decoder. 
402 | 403 | Args: 404 | iteration: Which iteration's model to load 405 | decode_k: Cluster number of Gaussian distribution for sampling 406 | sample_num: Number of samples to generate 407 | model_dir: Directory containing model files 408 | device: Device for computation 409 | """ 410 | print(f"Reconstructing {sample_num} images from GMM cluster {decode_k}") 411 | 412 | # Load trained model 413 | model = load_trained_model(iteration, model_dir, device) 414 | 415 | # Get GMM parameters and generate samples 416 | config = VAEConfig() 417 | mu_gmm, lambda_gmm, pi_gmm = get_param(iteration, model_dir=model_dir) 418 | manual_sample, _ = sample( 419 | iteration=iteration, 420 | x_dim=config.latent_dim, 421 | mu_gmm=mu_gmm, 422 | lambda_gmm=lambda_gmm, 423 | sample_num=sample_num, 424 | sample_k=decode_k, 425 | model_dir=model_dir, 426 | ) 427 | 428 | # Convert to tensor and generate reconstructions 429 | samples = torch.from_numpy(manual_sample.astype(np.float32)).to(device) 430 | 431 | with torch.no_grad(): 432 | reconstructions = model.decode(samples).cpu() 433 | # Save reconstructed images 434 | output_path = f"{model_dir}/recon/manual_{decode_k}.png" 435 | save_image( 436 | reconstructions.view(sample_num, 1, 28, 28), 437 | output_path 438 | ) 439 | print(f"Saved reconstructed images to: {output_path}") 440 | 441 | 442 | def plot_latent( 443 | iteration: int, 444 | all_loader: DataLoader, 445 | model_dir: str = "./vae_gmm", 446 | device: torch.device = torch.device("cpu") 447 | ) -> None: 448 | """ 449 | Visualize VAE latent space using all data points. 450 | 451 | Args: 452 | iteration: Current iteration number 453 | all_loader: DataLoader containing all data 454 | model_dir: Directory for saving visualizations 455 | device: Device for computation 456 | """ 457 | print("Plotting latent space visualization") 458 | 459 | # Load trained model 460 | model = load_trained_model(iteration, model_dir, device) 461 | 462 | # Extract latent variables for visualization 463 | latent_vars, labels = extract_latent_variables(model, all_loader, device) 464 | 465 | # Create visualization using tool function 466 | visualize_ls(iteration, latent_vars, labels, model_dir) 467 | -------------------------------------------------------------------------------- /tool.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualization and utility tools for the VAE-GMM mutual learning system. 3 | 4 | This module provides comprehensive visualization capabilities, sampling functions, 5 | parameter loading utilities, and accuracy calculation for evaluating the performance 6 | of the VAE-GMM mutual learning system. 
7 | """ 8 | 9 | from dataclasses import dataclass 10 | from typing import List, Tuple, Optional 11 | 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | from scipy.stats import multivariate_normal 15 | from sklearn.manifold import TSNE 16 | 17 | 18 | @dataclass 19 | class VisualizationConfig: 20 | """Configuration class for visualization parameters.""" 21 | 22 | # Figure dimensions 23 | figure_size_2d: Tuple[int, int] = (12, 9) 24 | figure_size_latent: Tuple[int, int] = (10, 10) 25 | 26 | # Grid resolution for contour plots 27 | grid_resolution: int = 900 28 | 29 | # Visualization parameters 30 | marker_size: int = 100 31 | scatter_size: int = 100 32 | alpha: float = 0.5 33 | font_size_title: int = 20 34 | font_size_labels: int = 17 35 | 36 | # Colors for different clusters/classes 37 | colors: List[str] = None 38 | 39 | # Dimensionality reduction 40 | use_tsne: bool = True # If False, could use PCA 41 | tsne_perplexity: float = 30.0 42 | random_state: int = 0 43 | 44 | # File format 45 | dpi: int = 150 46 | file_format: str = "png" 47 | 48 | def __post_init__(self): 49 | if self.colors is None: 50 | self.colors = [ 51 | "red", "green", "blue", "orange", "purple", 52 | "yellow", "black", "cyan", "#a65628", "#f781bf" 53 | ] 54 | 55 | 56 | @dataclass 57 | class SamplingConfig: 58 | """Configuration class for sampling parameters.""" 59 | 60 | # Manual sampling parameters 61 | manual_sigma: float = 0.1 # Standard deviation for manual sampling 62 | 63 | # Default latent dimension 64 | default_latent_dim: int = 12 65 | 66 | # Default number of clusters 67 | default_num_clusters: int = 10 68 | 69 | 70 | def load_gmm_parameters(iteration: int, model_dir: str = "./vae_gmm") -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 71 | """ 72 | Load GMM parameters from saved files. 73 | 74 | Args: 75 | iteration: Iteration number to load 76 | model_dir: Directory containing saved parameters 77 | 78 | Returns: 79 | Tuple of (means, precision_matrices, mixing_coefficients) 80 | """ 81 | mu_path = f"{model_dir}/npy/mu_{iteration}.npy" 82 | lambda_path = f"{model_dir}/npy/lambda_{iteration}.npy" 83 | pi_path = f"{model_dir}/npy/pi_{iteration}.npy" 84 | 85 | mu_gmm = np.load(mu_path) 86 | lambda_gmm = np.load(lambda_path) 87 | pi_gmm = np.load(pi_path) 88 | 89 | return mu_gmm, lambda_gmm, pi_gmm 90 | 91 | 92 | def sample_from_gmm_cluster( 93 | cluster_mean: np.ndarray, 94 | cluster_precision: np.ndarray, 95 | sample_num: int, 96 | config: SamplingConfig 97 | ) -> Tuple[np.ndarray, np.ndarray]: 98 | """ 99 | Sample from a specific GMM cluster using two different methods. 
100 | 101 | Args: 102 | cluster_mean: Mean vector of the cluster 103 | cluster_precision: Precision matrix of the cluster 104 | sample_num: Number of samples to generate 105 | config: Sampling configuration 106 | 107 | Returns: 108 | Tuple of (manual_samples, gmm_samples) 109 | - manual_samples: Samples using fixed covariance 110 | - gmm_samples: Samples using estimated covariance from GMM 111 | """ 112 | latent_dim = len(cluster_mean) 113 | 114 | # Manual sampling with fixed small covariance 115 | manual_cov = config.manual_sigma * np.identity(latent_dim, dtype=float) 116 | manual_samples = np.random.multivariate_normal( 117 | mean=cluster_mean, 118 | cov=manual_cov, 119 | size=sample_num 120 | ) 121 | 122 | # GMM-based sampling using estimated covariance 123 | gmm_cov = np.linalg.inv(cluster_precision) 124 | gmm_samples = np.random.multivariate_normal( 125 | mean=cluster_mean, 126 | cov=gmm_cov, 127 | size=sample_num 128 | ) 129 | 130 | return manual_samples, gmm_samples 131 | 132 | 133 | def sample( 134 | iteration: int, 135 | x_dim: int, 136 | mu_gmm: np.ndarray, 137 | lambda_gmm: np.ndarray, 138 | sample_num: int, 139 | sample_k: int, 140 | model_dir: str = "./vae_gmm" 141 | ) -> Tuple[np.ndarray, np.ndarray]: 142 | """ 143 | Sample random variables for VAE decoder input from GMM posterior. 144 | 145 | Args: 146 | iteration: Current iteration (for compatibility) 147 | x_dim: Latent space dimensionality 148 | mu_gmm: GMM means [num_clusters, latent_dim] 149 | lambda_gmm: GMM precision matrices [num_clusters, latent_dim, latent_dim] 150 | sample_num: Number of samples to generate 151 | sample_k: Cluster index to sample from 152 | model_dir: Model directory (for compatibility) 153 | 154 | Returns: 155 | Tuple of (manual_samples, gmm_samples) 156 | """ 157 | config = SamplingConfig() 158 | 159 | return sample_from_gmm_cluster( 160 | cluster_mean=mu_gmm[sample_k], 161 | cluster_precision=lambda_gmm[sample_k], 162 | sample_num=sample_num, 163 | config=config 164 | ) 165 | 166 | 167 | def create_2d_grid( 168 | means_2d: np.ndarray, 169 | precisions_2d: np.ndarray, 170 | config: VisualizationConfig 171 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 172 | """ 173 | Create a 2D grid for contour plotting. 
174 | 175 | Args: 176 | means_2d: 2D cluster means [num_clusters, 2] 177 | precisions_2d: 2D precision matrices [num_clusters, 2, 2] 178 | config: Visualization configuration 179 | 180 | Returns: 181 | Tuple of (x1_grid, x2_grid, grid_points) 182 | """ 183 | # Calculate grid bounds based on cluster parameters 184 | std_factor = 3.0 # Number of standard deviations to include 185 | 186 | x1_std = np.sqrt(1.0 / precisions_2d[:, 0, 0]) 187 | x2_std = np.sqrt(1.0 / precisions_2d[:, 1, 1]) 188 | 189 | x1_min = np.min(means_2d[:, 0] - std_factor * x1_std) 190 | x1_max = np.max(means_2d[:, 0] + std_factor * x1_std) 191 | x2_min = np.min(means_2d[:, 1] - std_factor * x2_std) 192 | x2_max = np.max(means_2d[:, 1] + std_factor * x2_std) 193 | 194 | # Create grid 195 | x1_line = np.linspace(x1_min, x1_max, config.grid_resolution) 196 | x2_line = np.linspace(x2_min, x2_max, config.grid_resolution) 197 | x1_grid, x2_grid = np.meshgrid(x1_line, x2_line) 198 | 199 | # Grid points for density evaluation 200 | grid_points = np.stack([x1_grid.flatten(), x2_grid.flatten()], axis=1) 201 | 202 | return x1_grid, x2_grid, grid_points 203 | 204 | 205 | def compute_2d_density( 206 | grid_points: np.ndarray, 207 | cluster_mean_2d: np.ndarray, 208 | cluster_cov_2d: np.ndarray, 209 | mixing_coeff: float 210 | ) -> np.ndarray: 211 | """ 212 | Compute 2D density for a single cluster. 213 | 214 | Args: 215 | grid_points: Grid points for evaluation [num_points, 2] 216 | cluster_mean_2d: 2D cluster mean [2] 217 | cluster_cov_2d: 2D cluster covariance [2, 2] 218 | mixing_coeff: Mixing coefficient for this cluster 219 | 220 | Returns: 221 | Density values at grid points 222 | """ 223 | density = multivariate_normal.pdf( 224 | x=grid_points, 225 | mean=cluster_mean_2d, 226 | cov=cluster_cov_2d 227 | ) 228 | return density * mixing_coeff 229 | 230 | 231 | def visualize_gmm( 232 | iteration: int, 233 | decode_k: int, 234 | sample_num: int, 235 | model_dir: str = "./vae_gmm" 236 | ) -> None: 237 | """ 238 | Visualize GMM clusters and samples in 2D space. 
239 | 240 | Args: 241 | iteration: Current iteration number 242 | decode_k: Cluster index to visualize 243 | sample_num: Number of samples to generate 244 | model_dir: Directory containing model files 245 | """ 246 | config = VisualizationConfig() 247 | sampling_config = SamplingConfig() 248 | 249 | # Load GMM parameters 250 | mu_gmm, lambda_gmm, pi_gmm = load_gmm_parameters(iteration, model_dir) 251 | 252 | # Generate samples from specified cluster 253 | manual_samples, gmm_samples = sample( 254 | iteration=iteration, 255 | x_dim=sampling_config.default_latent_dim, 256 | mu_gmm=mu_gmm, 257 | lambda_gmm=lambda_gmm, 258 | sample_num=sample_num, 259 | sample_k=decode_k, 260 | model_dir=model_dir 261 | ) 262 | 263 | # Extract 2D projections for visualization 264 | num_clusters = len(mu_gmm) 265 | means_2d = mu_gmm[:, :2] # First 2 dimensions 266 | precisions_2d = lambda_gmm[:, :2, :2] # First 2x2 block 267 | 268 | # Create visualization grid 269 | x1_grid, x2_grid, grid_points = create_2d_grid(means_2d, precisions_2d, config) 270 | 271 | # Compute density for the selected cluster 272 | cluster_cov_2d = np.linalg.inv(precisions_2d[decode_k]) 273 | density = compute_2d_density( 274 | grid_points, means_2d[decode_k], cluster_cov_2d, pi_gmm[decode_k] 275 | ) 276 | 277 | # Create the plot 278 | plt.figure(figsize=config.figure_size_2d) 279 | 280 | # Plot samples 281 | plt.scatter( 282 | manual_samples[:, 0], manual_samples[:, 1], 283 | label=f"Cluster {decode_k + 1} samples", 284 | s=config.scatter_size, alpha=0.7 285 | ) 286 | 287 | # Plot cluster centers 288 | plt.scatter( 289 | means_2d[:, 0], means_2d[:, 1], 290 | color="red", s=config.marker_size, marker="x", 291 | label="Cluster centers", linewidths=3 292 | ) 293 | 294 | # Plot density contours 295 | density_reshaped = density.reshape(x1_grid.shape) 296 | contours = plt.contour( 297 | x1_grid, x2_grid, density_reshaped, 298 | alpha=config.alpha, linestyles="dashed" 299 | ) 300 | 301 | # Formatting 302 | plt.suptitle("Gaussian Mixture Model Visualization", fontsize=config.font_size_title) 303 | plt.title(f"Samples: {sample_num}, Cluster: {decode_k + 1}", fontsize=14) 304 | plt.xlabel("Latent Dimension 1", fontsize=12) 305 | plt.ylabel("Latent Dimension 2", fontsize=12) 306 | plt.legend() 307 | plt.grid(True, alpha=0.3) 308 | plt.colorbar(contours, label="Density") 309 | 310 | # Save the plot 311 | output_path = f"{model_dir}/graph/gaussian_I{iteration}_k{decode_k}.{config.file_format}" 312 | plt.savefig(output_path, dpi=config.dpi, bbox_inches='tight') 313 | plt.close() 314 | 315 | print(f"GMM visualization saved to: {output_path}") 316 | 317 | 318 | def visualize_latent_space( 319 | iteration: int, 320 | latent_vars: np.ndarray, 321 | labels: np.ndarray, 322 | save_dir: str, 323 | config: Optional[VisualizationConfig] = None 324 | ) -> None: 325 | """ 326 | Visualize VAE latent space using dimensionality reduction. 
327 | 328 | Args: 329 | iteration: Current iteration number 330 | latent_vars: Latent variables [num_samples, latent_dim] 331 | labels: Ground truth labels [num_samples] 332 | save_dir: Directory to save visualization 333 | config: Visualization configuration 334 | """ 335 | if config is None: 336 | config = VisualizationConfig() 337 | 338 | # Apply dimensionality reduction 339 | if config.use_tsne: 340 | reducer = TSNE( 341 | n_components=2, 342 | random_state=config.random_state, 343 | perplexity=min(config.tsne_perplexity, len(latent_vars) - 1) 344 | ) 345 | points_2d = reducer.fit_transform(latent_vars) 346 | else: 347 | # Could implement PCA here if needed 348 | raise NotImplementedError("PCA option not implemented yet") 349 | 350 | # Create the plot 351 | plt.figure(figsize=config.figure_size_latent) 352 | 353 | # Plot each class with different colors and number markers 354 | unique_labels = np.unique(labels) 355 | for label in unique_labels: 356 | mask = labels == label 357 | plt.scatter( 358 | points_2d[mask, 0], points_2d[mask, 1], 359 | c=config.colors[label % len(config.colors)], 360 | s=config.scatter_size, 361 | marker=f"${label}$", 362 | label=f"Class {label}", 363 | alpha=0.7 364 | ) 365 | 366 | # Formatting 367 | plt.title("VAE Latent Space Visualization", fontsize=config.font_size_title) 368 | plt.xlabel("t-SNE Component 1", fontsize=12) 369 | plt.ylabel("t-SNE Component 2", fontsize=12) 370 | plt.tick_params(labelsize=12) 371 | plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') 372 | plt.grid(True, alpha=0.3) 373 | 374 | # Save the plot 375 | output_path = f"{save_dir}/graph/latent_space_{iteration}.{config.file_format}" 376 | plt.savefig(output_path, dpi=config.dpi, bbox_inches='tight') 377 | plt.close() 378 | 379 | print(f"Latent space visualization saved to: {output_path}") 380 | 381 | 382 | def visualize_ls(iteration: int, z: np.ndarray, labels: np.ndarray, save_dir: str) -> None: 383 | """ 384 | Legacy function for latent space visualization (maintained for compatibility). 385 | 386 | Args: 387 | iteration: Current iteration number 388 | z: Latent variables 389 | labels: Ground truth labels 390 | save_dir: Directory to save visualization 391 | """ 392 | visualize_latent_space(iteration, z, labels, save_dir) 393 | 394 | 395 | def get_param(iteration: int, model_dir: str = "./vae_gmm") -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 396 | """ 397 | Legacy function for loading GMM parameters (maintained for compatibility). 398 | 399 | Args: 400 | iteration: Iteration number to load 401 | model_dir: Directory containing saved parameters 402 | 403 | Returns: 404 | Tuple of (means, precision_matrices, mixing_coefficients) 405 | """ 406 | return load_gmm_parameters(iteration, model_dir) 407 | 408 | 409 | def compute_clustering_accuracy_hungarian( 410 | predicted_labels: np.ndarray, 411 | true_labels: np.ndarray 412 | ) -> Tuple[float, np.ndarray]: 413 | """ 414 | Compute clustering accuracy using Hungarian algorithm approximation. 415 | 416 | This function finds the best permutation of cluster labels to maximize 417 | accuracy against ground truth labels. 
418 | 419 | Args: 420 | predicted_labels: Predicted cluster assignments 421 | true_labels: Ground truth labels 422 | 423 | Returns: 424 | Tuple of (best_accuracy, best_permuted_labels) 425 | """ 426 | num_clusters = int(np.max(predicted_labels)) + 1 427 | num_samples = len(predicted_labels) 428 | 429 | best_accuracy = 0.0 430 | best_labels = predicted_labels.copy() 431 | 432 | # Try all possible permutations (brute force for small K) 433 | # For larger K, this should use the Hungarian algorithm 434 | if num_clusters <= 10: # Brute force feasible for small number of clusters 435 | from itertools import permutations 436 | 437 | for perm in permutations(range(num_clusters)): 438 | # Create permutation mapping 439 | permuted_labels = np.zeros_like(predicted_labels) 440 | for original_label, new_label in enumerate(perm): 441 | mask = predicted_labels == original_label 442 | permuted_labels[mask] = new_label 443 | 444 | # Calculate accuracy 445 | accuracy = np.sum(permuted_labels == true_labels) / num_samples 446 | 447 | if accuracy > best_accuracy: 448 | best_accuracy = accuracy 449 | best_labels = permuted_labels.copy() 450 | else: 451 | # For larger number of clusters, use the original iterative approach 452 | best_accuracy, best_labels = calc_acc(predicted_labels, true_labels) 453 | 454 | return best_accuracy, best_labels 455 | 456 | 457 | def calc_acc(results: np.ndarray, correct: np.ndarray) -> Tuple[float, np.ndarray]: 458 | """ 459 | Calculate clustering accuracy by finding optimal label permutation. 460 | 461 | This function iteratively tries swapping pairs of cluster labels to find 462 | the permutation that maximizes accuracy against ground truth. 463 | 464 | Args: 465 | results: Predicted cluster labels 466 | correct: Ground truth labels 467 | 468 | Returns: 469 | Tuple of (maximum_accuracy, optimal_labels) 470 | """ 471 | num_clusters = int(np.max(results)) + 1 472 | num_samples = len(results) 473 | max_accuracy = 0.0 474 | current_labels = results.copy() 475 | 476 | # Iteratively improve by swapping labels 477 | improved = True 478 | max_iterations = 100 # Prevent infinite loops 479 | iteration_count = 0 480 | 481 | while improved and iteration_count < max_iterations: 482 | improved = False 483 | iteration_count += 1 484 | 485 | # Try all pairs of cluster labels 486 | for i in range(num_clusters): 487 | for j in range(i + 1, num_clusters): 488 | # Create temporary result with swapped labels 489 | temp_result = current_labels.copy() 490 | 491 | # Swap labels i and j 492 | mask_i = current_labels == i 493 | mask_j = current_labels == j 494 | temp_result[mask_i] = j 495 | temp_result[mask_j] = i 496 | 497 | # Calculate accuracy 498 | accuracy = np.sum(temp_result == correct) / num_samples 499 | 500 | # Update if improved 501 | if accuracy > max_accuracy: 502 | max_accuracy = accuracy 503 | current_labels = temp_result.copy() 504 | improved = True 505 | 506 | return max_accuracy, current_labels --------------------------------------------------------------------------------