├── .gitignore
├── image
│   ├── acc_0.png
│   ├── acc_1.png
│   ├── model.png
│   ├── manual_1.png
│   ├── z_pca_0.png
│   ├── z_pca_1.png
│   ├── z_tsne_0.png
│   ├── z_tsne_1.png
│   ├── gause_I0k1.png
│   ├── vae_loss_1.png
│   └── variable_define.png
├── requirements.txt
├── README.md
├── main.py
├── gmm_module.py
├── vae_module.py
└── tool.py
/.gitignore:
--------------------------------------------------------------------------------
1 | pth/
2 | model/
3 | __pycache__/
--------------------------------------------------------------------------------
/image/acc_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/acc_0.png
--------------------------------------------------------------------------------
/image/acc_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/acc_1.png
--------------------------------------------------------------------------------
/image/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/model.png
--------------------------------------------------------------------------------
/image/manual_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/manual_1.png
--------------------------------------------------------------------------------
/image/z_pca_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/z_pca_0.png
--------------------------------------------------------------------------------
/image/z_pca_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/z_pca_1.png
--------------------------------------------------------------------------------
/image/z_tsne_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/z_tsne_0.png
--------------------------------------------------------------------------------
/image/z_tsne_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/z_tsne_1.png
--------------------------------------------------------------------------------
/image/gause_I0k1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/gause_I0k1.png
--------------------------------------------------------------------------------
/image/vae_loss_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/vae_loss_1.png
--------------------------------------------------------------------------------
/image/variable_define.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/is0383kk/Pytorch_VAE-GMM/HEAD/image/variable_define.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # PyTorch VAE-GMM Requirements
2 | # Core deep learning framework
3 | torch>=1.5.1
4 | torchvision>=0.6.1
5 |
6 | # Numerical computing
7 | numpy>=1.19.0
8 |
9 | # Statistical computing
10 | scipy>=1.5.0
11 |
12 | # Machine learning utilities
13 | scikit-learn>=0.23.0
14 |
15 | # Plotting and visualization
16 | matplotlib>=3.3.0
17 |
18 | # Optional: For better performance with numerical operations
19 | # Uncomment if needed
20 | # mkl>=2021.1.1
21 | # mkl-service>=2.3.0
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Variational Auto-Encoder (VAE) + Gaussian Mixture Model (GMM)
2 |
3 | Implementation of a mutual learning model between a VAE and a GMM.
4 | The idea of integrating probabilistic generative models is based on this paper: [Neuro-SERKET: Development of Integrative Cognitive System through the Composition of Deep Probabilistic Generative Models](https://arxiv.org/abs/1910.08918).
5 | Symbol Emergence in Robotics tool KIT (SERKET) is a framework that allows integration and partitioning of probabilistic generative models.
6 |
7 | This is the graphical model of the VAE+GMM model:
8 |
9 |
10 |
11 |
12 |
13 | The VAE and the GMM share the latent variable x.
14 | x is a variable that follows a multivariate normal distribution and is estimated by the VAE.
15 |
16 | Training proceeds in the following sequence (a minimal sketch of this loop follows the list):
17 |
18 | 1. The VAE estimates the latent variables (x) and sends them to the GMM.
19 | 2. The GMM clusters the latent variables (x) sent from the VAE and sends the mean and variance parameters of the Gaussian distributions back to the VAE.
20 | 3. Return to step 1.
21 |
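A minimal sketch of this loop, assuming `train_loader`, `all_loader`, `device`, and `model_dir` are already set up (see `train_model()` in `main.py` for the full version):

```python
import gmm_module
import vae_module

gmm_mu, gmm_var = None, None  # the first iteration uses a standard normal prior
for it in range(2):  # number of mutual learning iterations
    # 1. VAE: estimate latent variables x (using GMM parameters as prior when available)
    x_d, label, _ = vae_module.train(
        iteration=it, gmm_mu=gmm_mu, gmm_var=gmm_var, epoch=100,
        train_loader=train_loader, all_loader=all_loader,
        model_dir=model_dir, device=device,
    )
    # 2. GMM: cluster x and return the per-sample mean/variance of the assigned Gaussian
    gmm_mu, gmm_var, _ = gmm_module.train(
        iteration=it, x_d=x_d, label=label, K=10,
        epoch=100, model_dir=model_dir,
    )
    # 3. gmm_mu / gmm_var become the VAE prior in the next iteration
```
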
22 | What this repo contains:
23 |
24 | - `main.py`: Main script for training the model.
25 | - `vae_module.py`: The VAE training routine, called from `main.py`.
26 | - `gmm_module.py`: The GMM training routine, called from `main.py`.
27 | - `tool.py`: Utility functions used throughout the program.
28 |
29 | # How to run
30 |
31 | Install the required libraries using the following commands.
32 | ※ Install PyTorch first (replace XXX with your CUDA version, e.g. `cu129`).
33 | ※ My environment: **PyTorch==2.8.0+cu129, CUDA==12.9**
34 |
35 | ```bash
36 | $ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cuXXX
37 | $ pip install -r requirements.txt
38 | ```
39 |
40 | You can train the VAE+GMM model by running `main.py`.
41 |
42 | ```bash
43 | $ python main.py
44 | ```
45 |
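`main.py` also accepts a few optional command-line arguments (defined in `parse_arguments()`); for example:

```bash
$ python main.py --vae-iter 50 --gmm-iter 50 --seed 1 --no-cuda
```
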
46 | - `train_model()` trains the VAE+GMM model.
47 | - `vae_module.decode()` reconstructs images from the parameters of the posterior distribution estimated by the GMM.
48 |
49 | ```python:main.py
50 | def main() -> None:
51 |     """Main function to orchestrate the VAE-GMM training process."""
52 |     # Parse arguments and setup configuration
53 |     config = parse_arguments()
54 |
55 |     # Setup environment
56 |     setup_directories(config)
57 |     device = setup_device_and_seed(config)
58 |
59 |     # Create data loaders
60 |     train_loader, all_loader, train_size = create_data_loaders(config)
61 |
62 |     # Train the VAE-GMM model
63 |     train_model(config, train_loader, all_loader, device)
64 |
65 |     # Reconstruct images from trained model
66 |     print("\nGenerating reconstructed images...")
67 |     vae_module.decode(
68 |         iteration=1,  # Use model from iteration 1
69 |         decode_k=1,  # Use cluster 1 for reconstruction
70 |         sample_num=16,  # Generate 16 samples
71 |         model_dir=config.debug_dir,
72 |         device=device,
73 |     )
74 | ```
75 |
76 | # Changes with and without mutual learning (for MNIST)
77 |
78 | ## Latent space on VAE
79 |
80 | Left: without mutual learning / Right: with mutual learning
81 | Plotted using t-SNE:
82 |
83 |
84 |
85 |
86 | Plotted using PCA:
87 |
88 |
89 |
90 |
91 | ## ELBO of VAE
92 |
93 | The red line is the ELBO before mutual learning; the blue line is the ELBO after mutual learning.
94 | The horizontal axis is the VAE training epoch; the vertical axis is the ELBO of the VAE.
95 | (In general, the higher the ELBO, the better.)
96 |
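For reference, the plotted value corresponds to the per-sample ELBO, $\mathbb{E}_{q(x \mid o)}[\log p(o \mid x)] - D_{\mathrm{KL}}\big(q(x \mid o)\,\|\,p(x)\big)$, writing $o$ for an input image and $x$ for the latent variable; $p(x)$ is the standard normal prior in the first iteration and the Gaussian estimated by the GMM in later iterations.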
97 |
98 |
99 |
100 |
101 | ## Clustering performance (in GMM)
102 |
103 | Clustering performance measured by accuracy (this evaluates the GMM clustering within VAE+GMM).
104 | Left: without mutual learning / Right: with mutual learning
105 | The horizontal axis is the GMM training epoch; the vertical axis is the accuracy.
106 |
107 |
108 |
109 |
110 |
111 | # Image reconstruction with the VAE decoder from Gaussian parameters estimated by the GMM
112 |
113 | The GMM performs clustering on the latent variables of the VAE.
114 | By sampling random variables from the posterior distribution estimated by the GMM and feeding them into the VAE decoder, images can be reconstructed (sketched below).
115 |
116 | The "x" markers represent the mean parameter of the Gaussian distribution of each cluster.
117 | In this example, random variables are sampled from the Gaussian distribution of cluster K=1.
118 |
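A minimal sketch of this procedure, assuming `model_dir` and `device` are set up as in `main.py` (the actual `vae_module.decode()` instead samples around the cluster mean with a small fixed covariance):

```python
import numpy as np
import torch

import tool
import vae_module

# Load the GMM parameters saved after mutual learning iteration 1
mu_gmm, lambda_gmm, _ = tool.get_param(iteration=1, model_dir=model_dir)

k = 1  # cluster to reconstruct from
cov = np.linalg.inv(lambda_gmm[k])  # covariance = inverse of the precision matrix
z = np.random.multivariate_normal(mean=mu_gmm[k], cov=cov, size=16)

# Feed the sampled latent variables into the trained VAE decoder
model = vae_module.load_trained_model(iteration=1, model_dir=model_dir, device=device)
with torch.no_grad():
    images = model.decode(torch.from_numpy(z.astype(np.float32)).to(device))  # [16, 784]
```
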
119 |
120 |
121 |
122 | Reconstructed images obtained by feeding the sampled random variables into the VAE decoder:
123 |
124 |
125 |
126 |
127 | # Special Thanks
128 |
129 | The implementation of the GMM is based on
130 | [【Python】4.4.2:ガウス混合モデルにおける推論:ギブスサンプリング【緑ベイズ入門のノート】](https://www.anarchive-beta.com/entry/2020/11/28/210948) (a Japanese article on Gibbs sampling for Gaussian mixture models).
131 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | PyTorch implementation of mutual learning between VAE and GMM models.
3 |
4 | This module orchestrates training, handles data loading and directory setup
5 | for the VAE-GMM mutual learning system on MNIST dataset.
6 | """
7 |
8 | import argparse
9 | import os
10 | from dataclasses import dataclass
11 | from typing import Tuple
12 |
13 | import torch
14 | from torch.utils.data import DataLoader, Subset
15 | from torchvision import datasets, transforms
16 |
17 | import gmm_module
18 | import vae_module
19 | from tool import visualize_gmm
20 |
21 |
22 | @dataclass
23 | class Config:
24 | """Configuration class for VAE-GMM training parameters."""
25 |
26 | # Training parameters
27 | vae_iter: int = 100
28 | gmm_iter: int = 100
29 | mutual_iterations: int = 2
30 |
31 | # Hardware settings
32 | use_cuda: bool = True
33 | seed: int = 1
34 |
35 | # Data parameters
36 | data_dir: str = "./../data"
37 | train_fraction: float = 1 / 6 # Use 1/6 of MNIST training set (10,000 samples)
38 | batch_size: int = 10 # Small batch size for GMM training stability
39 |
40 | # Model parameters
41 | num_clusters: int = 10 # K=10 for MNIST digit classes
42 |
43 | # Directory paths
44 | model_dir: str = "./model"
45 | debug_dir: str = "./model/debug"
46 |
47 | @property
48 | def graph_dir(self) -> str:
49 | """Directory for saving graphs and visualizations."""
50 | return os.path.join(self.debug_dir, "graph")
51 |
52 | @property
53 | def pth_dir(self) -> str:
54 | """Directory for saving VAE model states (.pth files)."""
55 | return os.path.join(self.debug_dir, "pth")
56 |
57 | @property
58 | def npy_dir(self) -> str:
59 | """Directory for saving GMM parameters (.npy files)."""
60 | return os.path.join(self.debug_dir, "npy")
61 |
62 | @property
63 | def recon_dir(self) -> str:
64 | """Directory for saving reconstructed images."""
65 | return os.path.join(self.debug_dir, "recon")
66 |
67 |
68 | def setup_directories(config: Config) -> None:
69 | """Create necessary directories for model outputs."""
70 | directories = [
71 | config.model_dir,
72 | config.debug_dir,
73 | config.graph_dir,
74 | config.pth_dir,
75 | config.npy_dir,
76 | config.recon_dir,
77 | ]
78 |
79 | for directory in directories:
80 | os.makedirs(directory, exist_ok=True)
81 |
82 |
83 | def setup_device_and_seed(config: Config) -> torch.device:
84 | """Setup device and random seed for reproducible results."""
85 | torch.manual_seed(config.seed)
86 |
87 | if config.use_cuda and torch.cuda.is_available():
88 | device = torch.device("cuda")
89 | print(f"Using CUDA device: {torch.cuda.get_device_name()}")
90 | else:
91 | device = torch.device("cpu")
92 | print("Using CPU device")
93 |
94 | return device
95 |
96 |
97 | def create_data_loaders(config: Config) -> Tuple[DataLoader, DataLoader, int]:
98 | """
99 | Create training and full dataset loaders for MNIST.
100 |
101 | Returns:
102 | Tuple of (train_loader, all_loader, train_size)
103 | """
104 | # Load MNIST dataset
105 | trainval_dataset = datasets.MNIST(
106 | config.data_dir, train=True, transform=transforms.ToTensor(), download=True
107 | )
108 |
109 | # Calculate dataset sizes
110 | n_samples = len(trainval_dataset)
111 | train_size = int(n_samples * config.train_fraction)
112 |
113 | # Create subsets
114 | train_indices = list(range(train_size))
115 | train_dataset = Subset(trainval_dataset, train_indices)
116 |
117 | # Data loader settings
118 | kwargs = {"num_workers": 1, "pin_memory": True} if config.use_cuda else {}
119 |
120 | # Create data loaders
121 | train_loader = DataLoader(
122 | train_dataset, batch_size=config.batch_size, shuffle=False, **kwargs
123 | )
124 |
125 | # All data loader for sending latent variables to GMM
126 | all_loader = DataLoader(
127 | train_dataset, batch_size=train_size, shuffle=False, **kwargs
128 | )
129 |
130 | print(
131 | f"Dataset info: {train_size} samples, "
132 | f"VAE iterations: {config.vae_iter}, "
133 | f"GMM iterations: {config.gmm_iter}"
134 | )
135 |
136 | return train_loader, all_loader, train_size
137 |
138 |
139 | def train_model(
140 | config: Config,
141 | train_loader: DataLoader,
142 | all_loader: DataLoader,
143 | device: torch.device,
144 | ) -> None:
145 | """
146 | Main training loop for VAE-GMM mutual learning.
147 |
148 | Args:
149 | config: Configuration object with training parameters
150 | train_loader: DataLoader for training batches
151 | all_loader: DataLoader for all data (used for GMM training)
152 | """
153 | print("Starting VAE-GMM mutual learning training...")
154 |
155 | # Initialize GMM parameters for first iteration
156 | gmm_mu = None
157 | gmm_var = None
158 |
159 | for iteration in range(config.mutual_iterations):
160 | print(f"---------- Mutual Learning Iteration: {iteration + 1} ----------")
161 |
162 | # VAE Training Phase
163 | print("Training VAE...")
164 | x_d, label, loss_list = vae_module.train(
165 | iteration=iteration,
166 | gmm_mu=gmm_mu, # GMM mean parameters (None for first iteration)
167 | gmm_var=gmm_var, # GMM variance parameters (None for first iteration)
168 | epoch=config.vae_iter,
169 | train_loader=train_loader,
170 | all_loader=all_loader,
171 | model_dir=config.debug_dir,
172 | device=device,
173 | )
174 |
175 | # GMM Training Phase
176 | print("Training GMM...")
177 | gmm_mu, gmm_var, max_acc = gmm_module.train(
178 | iteration=iteration,
179 | x_d=x_d, # Latent variables from VAE
180 | model_dir=config.debug_dir,
181 | label=label, # MNIST labels for accuracy calculation
182 | K=config.num_clusters,
183 | epoch=config.gmm_iter,
184 | )
185 |
186 | # Visualize latent space
187 | print("Plotting latent space...")
188 | vae_module.plot_latent(
189 | iteration=iteration,
190 | all_loader=all_loader,
191 | model_dir=config.debug_dir,
192 | device=device,
193 | )
194 |
195 | print(f"Iteration {iteration + 1} completed. Max accuracy: {max_acc:.4f}")
196 |
197 |
198 | def plot_dist(iteration: int, decode_k: int, sample_num: int, model_dir: str) -> None:
199 | """
200 | Visualize Gaussian distributions from GMM parameters.
201 |
202 | Args:
203 | iteration: Which iteration's model to load
204 | decode_k: Cluster number of Gaussian distribution
205 | sample_num: Number of samples for the random variable
206 | model_dir: Directory containing the model files
207 | """
208 | visualize_gmm(
209 | iteration=iteration,
210 | decode_k=decode_k,
211 | sample_num=sample_num,
212 | model_dir=model_dir,
213 | )
214 |
215 |
216 | def parse_arguments() -> Config:
217 | """Parse command line arguments and return configuration."""
218 | parser = argparse.ArgumentParser(description="VAE-GMM MNIST Mutual Learning")
219 | parser.add_argument(
220 | "--vae-iter",
221 | type=int,
222 | default=100,
223 | help="Number of VAE training iterations (default: 100)",
224 | )
225 | parser.add_argument(
226 | "--gmm-iter",
227 | type=int,
228 | default=100,
229 | help="Number of GMM training iterations (default: 100)",
230 | )
231 | parser.add_argument(
232 | "--no-cuda", action="store_true", default=False, help="Disable CUDA training"
233 | )
234 | parser.add_argument(
235 | "--seed",
236 | type=int,
237 | default=1,
238 | help="Random seed for reproducibility (default: 1)",
239 | )
240 |
241 | args = parser.parse_args()
242 |
243 | # Create config from arguments
244 | config = Config(
245 | vae_iter=args.vae_iter,
246 | gmm_iter=args.gmm_iter,
247 | use_cuda=not args.no_cuda,
248 | seed=args.seed,
249 | )
250 |
251 | return config
252 |
253 |
254 | def main() -> None:
255 | """Main function to orchestrate the VAE-GMM training process."""
256 | # Parse arguments and setup configuration
257 | config = parse_arguments()
258 |
259 | # Setup environment
260 | setup_directories(config)
261 | device = setup_device_and_seed(config)
262 |
263 | # Create data loaders
264 | train_loader, all_loader, train_size = create_data_loaders(config)
265 |
266 | # Train the VAE-GMM model
267 | train_model(config, train_loader, all_loader, device)
268 |
269 | # Reconstruct images from trained model
270 | print("\nGenerating reconstructed images...")
271 | vae_module.decode(
272 | iteration=1, # Use model from iteration 1
273 | decode_k=1, # Use cluster 1 for reconstruction
274 | sample_num=16, # Generate 16 samples
275 | model_dir=config.debug_dir,
276 | device=device,
277 | )
278 |
279 | print("Training and reconstruction completed!")
280 |
281 | # Optional: Visualize distributions (uncomment to use)
282 | # print("Visualizing Gaussian distributions...")
283 | # plot_dist(
284 | # iteration=0,
285 | # decode_k=1,
286 | # sample_num=16,
287 | # model_dir=config.debug_dir
288 | # )
289 |
290 |
291 | if __name__ == "__main__":
292 | main()
293 |
--------------------------------------------------------------------------------
/gmm_module.py:
--------------------------------------------------------------------------------
1 | """
2 | Gaussian Mixture Model (GMM) module for the VAE-GMM mutual learning system.
3 |
4 | This module provides GMM implementation with Gibbs sampling for clustering
5 | VAE latent variables and estimating parameters for the VAE prior distribution.
6 | """
7 |
8 | from dataclasses import dataclass
9 | from typing import Tuple, List
10 |
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import torch
14 | from scipy.stats import dirichlet, wishart
15 |
16 | from tool import calc_acc
17 |
18 |
19 | @dataclass
20 | class GMMConfig:
21 | """Configuration class for GMM hyperparameters."""
22 |
23 | # Prior hyperparameters
24 | beta: float = 1.0 # Precision parameter for mean prior
25 | nu_factor: float = 1.0 # Degrees of freedom factor (nu = dim * nu_factor)
26 | w_scale: float = 0.55 # Scale factor for precision matrix prior
27 | alpha: float = 0.3 # Dirichlet concentration parameter
28 |
29 | # Training parameters
30 | log_interval: int = 20 # Logging interval during training
31 |
32 | # Numerical stability
33 | eps: float = 1e-7 # Small epsilon for numerical stability
34 |
35 |
36 | def initialize_gmm_parameters(
37 | num_data: int,
38 | data_dim: int,
39 | num_clusters: int,
40 | config: GMMConfig
41 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
42 | """
43 | Initialize GMM parameters by sampling from priors.
44 |
45 | Args:
46 | num_data: Number of data points
47 | data_dim: Dimensionality of data
48 | num_clusters: Number of clusters (K)
49 | config: GMM configuration
50 |
51 | Returns:
52 | Tuple of (mu, lambda, pi, prior_parameters)
53 | """
54 | # Prior parameters
55 | m_d = np.zeros(data_dim) # Prior mean
56 | w_dd = np.identity(data_dim) * config.w_scale # Prior precision matrix scale
57 | nu = int(data_dim * config.nu_factor) # Degrees of freedom
58 | alpha_k = np.full(num_clusters, config.alpha) # Dirichlet parameters
59 |
60 | # Initialize cluster parameters by sampling from priors
61 | mu_kd = np.empty((num_clusters, data_dim))
62 | lambda_kdd = np.empty((num_clusters, data_dim, data_dim))
63 |
64 | for k in range(num_clusters):
65 | # Sample precision matrix from Wishart distribution
66 | lambda_kdd[k] = wishart.rvs(df=nu, scale=w_dd, size=1)
67 |
68 | # Sample mean from Normal-Wishart prior
69 | cov = np.linalg.inv(config.beta * lambda_kdd[k])
70 | mu_kd[k] = np.random.multivariate_normal(mean=m_d, cov=cov).flatten()
71 |
72 | # Sample mixing coefficients from Dirichlet distribution
73 | pi_k = dirichlet.rvs(alpha=alpha_k, size=1).flatten()
74 |
75 | # Store prior parameters for later use
76 | prior_params = {
77 | 'm_d': m_d,
78 | 'w_dd': w_dd,
79 | 'nu': nu,
80 | 'alpha_k': alpha_k
81 | }
82 |
83 | return mu_kd, lambda_kdd, pi_k, prior_params
84 |
85 |
86 | def compute_responsibilities(
87 | x_d: np.ndarray,
88 | mu_kd: np.ndarray,
89 | lambda_kdd: np.ndarray,
90 | pi_k: np.ndarray,
91 | config: GMMConfig
92 | ) -> np.ndarray:
93 | """
94 | Compute responsibilities (posterior probabilities) for each data point.
95 |
96 | Args:
97 | x_d: Data points [num_data, data_dim]
98 | mu_kd: Cluster means [num_clusters, data_dim]
99 | lambda_kdd: Cluster precision matrices [num_clusters, data_dim, data_dim]
100 | pi_k: Mixing coefficients [num_clusters]
101 | config: GMM configuration
102 |
103 | Returns:
104 | Responsibilities [num_data, num_clusters]
105 | """
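    # eta_dk ∝ pi_k * N(x_d | mu_k, Lambda_k^{-1}), computed in log space.
    # The constant -(D/2) * log(2*pi) is omitted because it cancels in the
    # normalization over k below.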
106 | num_data, num_clusters = len(x_d), len(mu_kd)
107 | eta_dk = np.zeros((num_data, num_clusters))
108 |
109 | for k in range(num_clusters):
110 | # Compute quadratic form for each data point
111 | diff = x_d - mu_kd[k] # [num_data, data_dim]
112 | quadratic_form = np.sum(diff @ lambda_kdd[k] * diff, axis=1)
113 |
114 | # Log probability components
115 | log_prob = -0.5 * quadratic_form
116 | log_prob += 0.5 * np.log(np.linalg.det(lambda_kdd[k]) + config.eps)
117 | log_prob += np.log(pi_k[k] + config.eps)
118 |
119 | eta_dk[:, k] = np.exp(log_prob)
120 |
121 | # Normalize to get responsibilities
122 | eta_dk /= np.sum(eta_dk, axis=1, keepdims=True)
123 | return eta_dk
124 |
125 |
126 | def sample_cluster_assignments(eta_dk: np.ndarray) -> Tuple[np.ndarray, List[int]]:
127 | """
128 | Sample cluster assignments from responsibilities.
129 |
130 | Args:
131 | eta_dk: Responsibilities [num_data, num_clusters]
132 |
133 | Returns:
134 | Tuple of (assignment_matrix, predicted_labels)
135 | """
136 | num_data, num_clusters = eta_dk.shape
137 | z_dk = np.zeros((num_data, num_clusters))
138 | pred_labels = []
139 |
140 | for d in range(num_data):
141 | # Sample from multinomial distribution
142 | assignment = np.random.multinomial(n=1, pvals=eta_dk[d], size=1).flatten()
143 | z_dk[d] = assignment
144 | pred_labels.append(np.argmax(assignment))
145 |
146 | return z_dk, pred_labels
147 |
148 |
149 | def update_gmm_parameters(
150 | x_d: np.ndarray,
151 | z_dk: np.ndarray,
152 | prior_params: dict,
153 | config: GMMConfig
154 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
155 | """
156 | Update GMM parameters using Gibbs sampling.
157 |
158 | Args:
159 | x_d: Data points [num_data, data_dim]
160 | z_dk: Cluster assignments [num_data, num_clusters]
161 | prior_params: Dictionary containing prior parameters
162 | config: GMM configuration
163 |
164 | Returns:
165 | Tuple of (updated_mu, updated_lambda, updated_pi)
166 | """
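    # Normal-Wishart posterior (Gibbs) updates computed below for each cluster k:
    #   N_k      = sum_d z_dk
    #   beta_hat = N_k + beta
    #   m_hat    = (sum_d z_dk * x_d + beta * m) / beta_hat
    #   W_hat^-1 = W^-1 + sum_d z_dk (x_d - m_hat)(x_d - m_hat)^T + beta (m - m_hat)(m - m_hat)^T
    #   nu_hat   = N_k + nu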
167 | num_data, data_dim = x_d.shape
168 | num_clusters = z_dk.shape[1]
169 |
170 | # Extract prior parameters
171 | m_d = prior_params['m_d']
172 | w_dd = prior_params['w_dd']
173 | nu = prior_params['nu']
174 | alpha_k = prior_params['alpha_k']
175 |
176 | # Initialize parameter arrays
177 | mu_kd = np.empty((num_clusters, data_dim))
178 | lambda_kdd = np.empty((num_clusters, data_dim, data_dim))
179 |
180 | for k in range(num_clusters):
181 | # Posterior parameters for Normal-Wishart
182 | cluster_size = np.sum(z_dk[:, k])
183 | beta_hat = cluster_size + config.beta
184 |
185 | # Posterior mean
186 | weighted_sum = np.sum(z_dk[:, k][:, np.newaxis] * x_d, axis=0)
187 | m_hat = (weighted_sum + config.beta * m_d) / beta_hat
188 |
189 | # Posterior scale matrix for Wishart
190 | centered_data = x_d - m_hat # [num_data, data_dim]
191 | weighted_scatter = (z_dk[:, k][:, np.newaxis, np.newaxis] *
192 | centered_data[:, :, np.newaxis] @
193 | centered_data[:, np.newaxis, :]).sum(axis=0)
194 |
195 | prior_term = config.beta * np.outer(m_d - m_hat, m_d - m_hat)
196 | w_hat = np.linalg.inv(np.linalg.inv(w_dd) + weighted_scatter + prior_term)
197 |
198 | nu_hat = cluster_size + nu
199 |
200 | # Sample precision matrix from Wishart
201 | lambda_kdd[k] = wishart.rvs(df=nu_hat, scale=w_hat, size=1)
202 |
203 | # Sample mean from multivariate normal
204 | cov = np.linalg.inv(beta_hat * lambda_kdd[k])
205 | mu_kd[k] = np.random.multivariate_normal(mean=m_hat, cov=cov).flatten()
206 |
207 | # Update mixing coefficients
208 | alpha_hat = np.sum(z_dk, axis=0) + alpha_k
209 | pi_k = dirichlet.rvs(alpha=alpha_hat, size=1).flatten()
210 |
211 | return mu_kd, lambda_kdd, pi_k
212 |
213 |
214 | def compute_data_parameters(
215 | mu_kd: np.ndarray,
216 | lambda_kdd: np.ndarray,
217 | pred_labels: List[int]
218 | ) -> Tuple[np.ndarray, np.ndarray]:
219 | """
220 | Compute mean and variance for each data point based on cluster assignments.
221 |
222 | Args:
223 | mu_kd: Cluster means [num_clusters, data_dim]
224 | lambda_kdd: Cluster precision matrices [num_clusters, data_dim, data_dim]
225 | pred_labels: Predicted cluster labels for each data point
226 |
227 | Returns:
228 | Tuple of (data_means, data_variances)
229 | """
230 | num_data = len(pred_labels)
231 | data_dim = mu_kd.shape[1]
232 |
233 | mu_d = np.zeros((num_data, data_dim))
234 | var_d = np.zeros((num_data, data_dim))
235 |
236 | for d in range(num_data):
237 | cluster_idx = pred_labels[d]
238 | mu_d[d] = mu_kd[cluster_idx]
239 | # Extract diagonal of inverse precision matrix (variances)
240 | var_d[d] = np.diag(np.linalg.inv(lambda_kdd[cluster_idx]))
241 |
242 | return mu_d, var_d
243 |
244 |
245 | def save_gmm_parameters(
246 | mu_kd: np.ndarray,
247 | lambda_kdd: np.ndarray,
248 | pi_k: np.ndarray,
249 | iteration: int,
250 | model_dir: str
251 | ) -> None:
252 | """Save GMM parameters to files."""
253 | np.save(f"{model_dir}/npy/mu_{iteration}.npy", mu_kd)
254 | np.save(f"{model_dir}/npy/lambda_{iteration}.npy", lambda_kdd)
255 | np.save(f"{model_dir}/npy/pi_{iteration}.npy", pi_k)
256 |
257 |
258 | def plot_accuracy_curve(
259 | accuracy_history: np.ndarray,
260 | iteration: int,
261 | num_clusters: int,
262 | model_dir: str
263 | ) -> None:
264 | """
265 | Plot and save accuracy curve during GMM training.
266 |
267 | Args:
268 | accuracy_history: Accuracy values over epochs
269 | iteration: Current iteration number
270 | num_clusters: Number of clusters
271 | model_dir: Directory to save plot
272 | """
273 | plt.figure(figsize=(10, 6))
274 | plt.plot(range(len(accuracy_history)), accuracy_history,
275 | marker='o', markersize=3, linewidth=2, color='blue')
276 | plt.xlabel("Epoch", fontsize=12)
277 | plt.ylabel("Clustering Accuracy", fontsize=12)
278 | plt.title(f"GMM Training Progress - Iteration {iteration} (K={num_clusters})",
279 | fontsize=14)
280 | plt.grid(True, alpha=0.3)
281 | plt.ylim(0, 1)
282 |
283 | # Add max accuracy annotation
284 | max_acc = np.max(accuracy_history)
285 | max_epoch = np.argmax(accuracy_history)
286 | plt.annotate(f'Max: {max_acc:.3f}',
287 | xy=(max_epoch, max_acc),
288 | xytext=(max_epoch + len(accuracy_history) * 0.1, max_acc + 0.05),
289 | arrowprops=dict(arrowstyle='->', color='red', alpha=0.7),
290 | fontsize=10, color='red')
291 |
292 | plt.tight_layout()
293 | plt.savefig(f"{model_dir}/graph/acc_{iteration}.png", dpi=150, bbox_inches='tight')
294 | plt.close()
295 |
296 |
297 | def train(
298 | iteration: int,
299 | x_d: np.ndarray,
300 | label: np.ndarray,
301 | K: int,
302 | epoch: int = 100,
303 | model_dir: str = "vae_gmm"
304 | ) -> Tuple[torch.Tensor, torch.Tensor, float]:
305 | """
306 | Train GMM using Gibbs sampling for clustering VAE latent variables.
307 |
308 | Args:
309 | iteration: Current mutual learning iteration
310 | x_d: Latent variables from VAE [num_data, latent_dim]
311 | label: Ground truth labels for accuracy calculation
312 | K: Number of clusters
313 | epoch: Number of training epochs
314 | model_dir: Directory to save results
315 |
316 | Returns:
317 | Tuple of (data_means, data_variances, max_accuracy)
318 | """
319 | print(f"GMM Training Start - Iteration {iteration}")
320 |
321 | # Data dimensions
322 | num_data, data_dim = x_d.shape
323 | config = GMMConfig()
324 |
325 | # Initialize parameters
326 | mu_kd, lambda_kdd, pi_k, prior_params = initialize_gmm_parameters(
327 | num_data, data_dim, K, config
328 | )
329 |
330 | # Training loop
331 | accuracy_history = np.zeros(epoch)
332 | max_accuracy = 0.0
333 | pred_labels = []
334 |
335 | for epoch_idx in range(epoch):
336 |         # Compute responsibilities given the current parameters
337 | eta_dk = compute_responsibilities(x_d, mu_kd, lambda_kdd, pi_k, config)
338 |
339 | # Sample cluster assignments
340 | z_dk, pred_labels = sample_cluster_assignments(eta_dk)
341 |
342 |         # Gibbs step: sample new parameters from their conditional posteriors
343 | mu_kd, lambda_kdd, pi_k = update_gmm_parameters(
344 | x_d, z_dk, prior_params, config
345 | )
346 |
347 | # Calculate and store accuracy
348 | accuracy = calc_acc(pred_labels, label)[0]
349 | accuracy_history[epoch_idx] = np.round(accuracy, 3)
350 | max_accuracy = max(max_accuracy, accuracy_history[epoch_idx])
351 |
352 | # Log progress
353 | if (epoch_idx == 0 or (epoch_idx + 1) % config.log_interval == 0 or
354 | epoch_idx == epoch - 1):
355 | print(f"====> Epoch: {epoch_idx + 1}/{epoch}, "
356 | f"Accuracy: {accuracy_history[epoch_idx]:.3f}, "
357 | f"Max Accuracy: {max_accuracy:.3f}")
358 |
359 | # Compute final data parameters
360 | mu_d, var_d = compute_data_parameters(mu_kd, lambda_kdd, pred_labels)
361 |
362 | # Save results
363 | save_gmm_parameters(mu_kd, lambda_kdd, pi_k, iteration, model_dir)
364 | plot_accuracy_curve(accuracy_history, iteration, K, model_dir)
365 |
366 | print(f"GMM training completed. Final accuracy: {accuracy_history[-1]:.3f}")
367 |
368 | return torch.from_numpy(mu_d), torch.from_numpy(var_d), max_accuracy
369 |
--------------------------------------------------------------------------------
/vae_module.py:
--------------------------------------------------------------------------------
1 | """
2 | Variational Autoencoder (VAE) module for the VAE-GMM mutual learning system.
3 |
4 | This module provides the VAE implementation with encoder-decoder architecture,
5 | training functionality, and utilities for latent space visualization and image reconstruction.
6 | """
7 |
8 | from dataclasses import dataclass
9 | from typing import Optional, Tuple
10 |
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import torch
14 | import torch.utils.data
15 | from torch import nn, optim
16 | from torch.nn import functional as F
17 | from torch.utils.data import DataLoader
18 | from torchvision.utils import save_image
19 |
20 | from tool import get_param, sample, visualize_ls
21 |
22 |
23 | @dataclass
24 | class VAEConfig:
25 | """Configuration class for VAE parameters."""
26 |
27 | # Architecture parameters
28 | latent_dim: int = 12 # Dimension of latent variable
29 | input_dim: int = 784 # MNIST image dimension (28x28)
30 | hidden_dim: int = 256 # Hidden layer dimension
31 |
32 | # Training parameters
33 | learning_rate: float = 1e-3
34 | reconstruction_loss: str = "bce" # Binary cross-entropy for MNIST
35 |
36 | # Visualization parameters
37 | log_interval: int = 50 # Logging interval during training
38 |
39 |
40 | class VAE(nn.Module):
41 | """
42 | Variational Autoencoder with encoder-decoder architecture for MNIST.
43 |
44 | The VAE includes:
45 | - Encoder: Maps input images to latent space (mean and log-variance)
46 | - Decoder: Reconstructs images from latent representations
47 | - Reparameterization: Enables backpropagation through stochastic sampling
48 | """
49 |
50 | def __init__(self, config: VAEConfig) -> None:
51 | """
52 | Initialize VAE network layers.
53 |
54 | Args:
55 | config: VAE configuration containing architecture parameters
56 | """
57 | super(VAE, self).__init__()
58 | self.config = config
59 |
60 | # Encoder layers
61 | self.encoder_hidden = nn.Linear(config.input_dim, config.hidden_dim)
62 | self.encoder_mu = nn.Linear(config.hidden_dim, config.latent_dim)
63 | self.encoder_logvar = nn.Linear(config.hidden_dim, config.latent_dim)
64 |
65 | # Decoder layers
66 | self.decoder_hidden = nn.Linear(config.latent_dim, config.hidden_dim)
67 | self.decoder_output = nn.Linear(config.hidden_dim, config.input_dim)
68 |
69 | def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
70 | """
71 | Encode input to latent space parameters.
72 |
73 | Args:
74 | x: Input tensor [batch_size, input_dim]
75 |
76 | Returns:
77 | Tuple of (mean, log_variance) tensors [batch_size, latent_dim]
78 | """
79 | hidden = F.relu(self.encoder_hidden(x))
80 | mu = self.encoder_mu(hidden)
81 | logvar = self.encoder_logvar(hidden)
82 | return mu, logvar
83 |
84 | def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
85 | """
86 | Reparameterization trick for backpropagation through stochastic sampling.
87 |
88 | Args:
89 | mu: Mean tensor [batch_size, latent_dim]
90 | logvar: Log-variance tensor [batch_size, latent_dim]
91 |
92 | Returns:
93 | Sampled latent vector [batch_size, latent_dim]
94 | """
95 | std = torch.exp(0.5 * logvar)
96 | eps = torch.randn_like(std)
97 | return mu + eps * std
98 |
99 | def decode(self, z: torch.Tensor) -> torch.Tensor:
100 | """
101 | Decode latent representation to output space.
102 |
103 | Args:
104 | z: Latent tensor [batch_size, latent_dim]
105 |
106 | Returns:
107 | Reconstructed output [batch_size, input_dim]
108 | """
109 | hidden = F.relu(self.decoder_hidden(z))
110 | return torch.sigmoid(self.decoder_output(hidden))
111 |
112 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
113 | """
114 | Forward pass through VAE.
115 |
116 | Args:
117 | x: Input tensor [batch_size, ...]
118 |
119 | Returns:
120 | Tuple of (reconstruction, mu, logvar, latent_sample)
121 | """
122 | x_flat = x.view(-1, self.config.input_dim)
123 | mu, logvar = self.encode(x_flat)
124 | z = self.reparameterize(mu, logvar)
125 | reconstruction = self.decode(z)
126 | return reconstruction, mu, logvar, z
127 |
128 | def compute_loss(
129 | self,
130 | reconstruction: torch.Tensor,
131 | target: torch.Tensor,
132 | mu: torch.Tensor,
133 | logvar: torch.Tensor,
134 | gmm_mu: Optional[torch.Tensor] = None,
135 | gmm_var: Optional[torch.Tensor] = None,
136 | device: torch.device = torch.device("cpu")
137 | ) -> torch.Tensor:
138 | """
139 | Compute VAE loss (reconstruction + KL divergence).
140 |
141 | Reference: Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
142 | https://arxiv.org/abs/1312.6114
143 |
144 | Args:
145 | reconstruction: Reconstructed output [batch_size, input_dim]
146 | target: Original input [batch_size, input_dim]
147 | mu: Encoder mean [batch_size, latent_dim]
148 | logvar: Encoder log-variance [batch_size, latent_dim]
149 | gmm_mu: GMM mean parameters (None for standard normal prior)
150 | gmm_var: GMM variance parameters (None for standard normal prior)
151 | device: Device for computation
152 |
153 | Returns:
154 | Total loss tensor
155 | """
156 | # Reconstruction loss (Binary Cross-Entropy)
157 | target_flat = target.view(-1, self.config.input_dim)
158 | reconstruction_loss = F.binary_cross_entropy(
159 | reconstruction, target_flat, reduction="sum"
160 | )
161 |
162 | # KL divergence
163 | if gmm_mu is not None and gmm_var is not None:
164 | # Use GMM parameters as prior (mutual learning iterations > 0)
165 | kl_divergence = self._compute_gmm_kl_divergence(
166 | mu, logvar, gmm_mu, gmm_var, device
167 | )
168 | else:
169 | # Use standard normal prior (first iteration)
170 | kl_divergence = self._compute_standard_kl_divergence(mu, logvar)
171 |
172 | return reconstruction_loss + kl_divergence
173 |
174 | def _compute_standard_kl_divergence(
175 | self, mu: torch.Tensor, logvar: torch.Tensor
176 | ) -> torch.Tensor:
177 | """Compute KL divergence with standard normal prior N(0,I)."""
178 | return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
179 |
180 | def _compute_gmm_kl_divergence(
181 | self,
182 | mu: torch.Tensor,
183 | logvar: torch.Tensor,
184 | gmm_mu: torch.Tensor,
185 | gmm_var: torch.Tensor,
186 | device: torch.device
187 | ) -> torch.Tensor:
188 | """Compute KL divergence with GMM prior."""
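        # Closed-form KL between diagonal Gaussians, summed over latent dimensions:
        #   KL( N(mu, sigma_q^2) || N(mu_p, sigma_p^2) )
        #     = 0.5 * sum( sigma_q^2/sigma_p^2 + (mu - mu_p)^2/sigma_p^2
        #                  + log sigma_p^2 - log sigma_q^2 - 1 )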
189 | # Convert GMM parameters to appropriate format
190 | prior_mu = gmm_mu.to(device).expand_as(mu)
191 | prior_var = gmm_var.to(device).expand_as(logvar)
192 | prior_logvar = prior_var.log()
193 |
194 | # KL divergence components
195 | var_division = logvar.exp() / prior_var # σ²_q / σ²_p
196 | diff = mu - prior_mu # μ_q - μ_p
197 | diff_term = diff * diff / prior_var # (μ_q - μ_p)² / σ²_p
198 | logvar_division = prior_logvar - logvar # log(σ²_p) - log(σ²_q)
199 |
200 | kl_per_sample = 0.5 * (
201 | (var_division + diff_term + logvar_division).sum(1) - self.config.latent_dim
202 | )
203 | return kl_per_sample.sum()
204 |
205 |
206 | def save_model_and_losses(
207 | model: VAE, loss_list: np.ndarray, iteration: int, model_dir: str
208 | ) -> None:
209 | """Save trained model and loss history."""
210 | # Save model state
211 | model_path = f"{model_dir}/pth/vae_{iteration}.pth"
212 | torch.save(model.state_dict(), model_path)
213 |
214 | # Save loss history
215 | loss_path = f"{model_dir}/npy/loss_{iteration}.npy"
216 | np.save(loss_path, loss_list)
217 |
218 |
219 | def plot_training_losses(
220 | loss_list: np.ndarray, iteration: int, model_dir: str
221 | ) -> None:
222 | """Plot and save training loss curves."""
223 | plt.figure()
224 | plt.plot(range(len(loss_list)), loss_list, color="blue", label="ELBO")
225 |
226 | # Compare with first iteration if available
227 | if iteration != 0:
228 | try:
229 | loss_0 = np.load(f"{model_dir}/npy/loss_0.npy")
230 | plt.plot(range(len(loss_0)), loss_0, color="red", label="ELBO_I0")
231 | except FileNotFoundError:
232 | pass # Skip comparison if first iteration data not available
233 |
234 | plt.xlabel("Epoch")
235 | plt.ylabel("ELBO")
236 | plt.legend(loc="lower right")
237 | plt.title(f"VAE Training Loss - Iteration {iteration}")
238 |
239 | # Save plot
240 | plot_path = f"{model_dir}/graph/vae_loss_{iteration}.png"
241 | plt.savefig(plot_path)
242 | plt.close()
243 |
244 |
245 | def train_single_epoch(
246 | model: VAE,
247 | train_loader: DataLoader,
248 | optimizer: optim.Optimizer,
249 | iteration: int,
250 | gmm_mu: Optional[torch.Tensor],
251 | gmm_var: Optional[torch.Tensor],
252 | device: torch.device
253 | ) -> float:
254 | """Train VAE for one epoch."""
255 | model.train()
256 | total_loss = 0.0
257 |
258 | for batch_idx, (data, _) in enumerate(train_loader):
259 | data = data.to(device)
260 | optimizer.zero_grad()
261 |
262 | # Forward pass
263 | reconstruction, mu, logvar, latent = model(data)
264 |
265 | # Compute loss
266 | if iteration == 0:
267 | # First iteration: use standard normal prior
268 | loss = model.compute_loss(
269 | reconstruction, data, mu, logvar, device=device
270 | )
271 | else:
272 |             # Subsequent iterations: use the GMM prior. Slice the per-sample
273 |             # parameters to match this batch (train_loader is not shuffled).
274 |             start = batch_idx * train_loader.batch_size
275 |             end = start + data.size(0)
276 |             batch_gmm_mu = gmm_mu[start:end] if gmm_mu is not None else None
277 |             batch_gmm_var = gmm_var[start:end] if gmm_var is not None else None
278 |             loss = model.compute_loss(reconstruction, data, mu, logvar, batch_gmm_mu, batch_gmm_var, device)
279 |
280 | # Backward pass
281 | loss = loss.mean()
282 | loss.backward()
283 | optimizer.step()
284 |
285 | total_loss += loss.item()
286 |
287 | return total_loss
288 |
289 |
290 | def train(
291 | iteration: int,
292 | gmm_mu: Optional[torch.Tensor],
293 | gmm_var: Optional[torch.Tensor],
294 | epoch: int,
295 | train_loader: DataLoader,
296 | all_loader: DataLoader,
297 | model_dir: str = "./vae_gmm",
298 | device: torch.device = torch.device("cpu")
299 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
300 | """
301 | Train VAE model for specified number of epochs.
302 |
303 | Args:
304 | iteration: Current mutual learning iteration
305 | gmm_mu: GMM mean parameters (None for first iteration)
306 | gmm_var: GMM variance parameters (None for first iteration)
307 | epoch: Number of training epochs
308 | train_loader: DataLoader for training batches
309 | all_loader: DataLoader for all data (used for latent variable extraction)
310 | model_dir: Directory to save model and results
311 | device: Device for computation
312 |
313 | Returns:
314 | Tuple of (latent_variables, labels, loss_history)
315 | """
316 | print(f"VAE Training Start - Iteration {iteration}")
317 |
318 | # Initialize model and optimizer
319 | config = VAEConfig()
320 | model = VAE(config).to(device)
321 | optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
322 |
323 | # Track losses
324 | loss_history = np.zeros(epoch)
325 |
326 | # Training loop
327 | for epoch_idx in range(epoch):
328 | epoch_loss = train_single_epoch(
329 | model, train_loader, optimizer, iteration,
330 | gmm_mu, gmm_var, device
331 | )
332 |
333 | # Store negative loss (ELBO)
334 | avg_loss = epoch_loss / len(train_loader.dataset)
335 | loss_history[epoch_idx] = -avg_loss
336 |
337 | # Log progress
338 | if (epoch_idx == 0 or (epoch_idx + 1) % config.log_interval == 0 or
339 | epoch_idx == epoch - 1):
340 | print(f"====> Epoch: {epoch_idx + 1}/{epoch}, Average loss: {avg_loss:.4f}")
341 |
342 | # Save model and results
343 | save_model_and_losses(model, loss_history, iteration, model_dir)
344 | plot_training_losses(loss_history, iteration, model_dir)
345 |
346 | # Extract latent variables for all data
347 | latent_vars, labels = extract_latent_variables(
348 | model, all_loader, device
349 | )
350 |
351 | return latent_vars, labels, loss_history
352 |
353 |
354 | def load_trained_model(
355 | iteration: int, model_dir: str, device: torch.device
356 | ) -> VAE:
357 | """Load a trained VAE model from checkpoint."""
358 | config = VAEConfig()
359 | model = VAE(config).to(device)
360 | model_path = f"{model_dir}/pth/vae_{iteration}.pth"
361 | model.load_state_dict(torch.load(model_path, map_location=device))
362 | model.eval()
363 | return model
364 |
365 |
366 | def extract_latent_variables(
367 | model: VAE, data_loader: DataLoader, device: torch.device
368 | ) -> Tuple[np.ndarray, np.ndarray]:
369 | """
370 | Extract latent variables and labels from all data using trained VAE.
371 |
372 | Args:
373 | model: Trained VAE model
374 | data_loader: DataLoader containing all data
375 | device: Device for computation
376 |
377 | Returns:
378 | Tuple of (latent_variables, labels) as numpy arrays
379 | """
380 | model.eval()
381 | with torch.no_grad():
382 | for batch_idx, (data, labels) in enumerate(data_loader):
383 | data = data.to(device)
384 | _, _, _, latent_vars = model(data)
385 | latent_vars = latent_vars.cpu()
386 | labels = labels.cpu()
387 | # Only process the first (and typically only) batch for all_loader
388 | break
389 |
390 | return latent_vars.detach().numpy(), labels.detach().numpy()
391 |
392 |
393 | def decode(
394 | iteration: int,
395 | decode_k: int,
396 | sample_num: int,
397 | model_dir: str = "./vae_gmm",
398 | device: torch.device = torch.device("cpu")
399 | ) -> None:
400 | """
401 | Reconstruct images from GMM parameters using trained VAE decoder.
402 |
403 | Args:
404 | iteration: Which iteration's model to load
405 | decode_k: Cluster number of Gaussian distribution for sampling
406 | sample_num: Number of samples to generate
407 | model_dir: Directory containing model files
408 | device: Device for computation
409 | """
410 | print(f"Reconstructing {sample_num} images from GMM cluster {decode_k}")
411 |
412 | # Load trained model
413 | model = load_trained_model(iteration, model_dir, device)
414 |
415 | # Get GMM parameters and generate samples
416 | config = VAEConfig()
417 | mu_gmm, lambda_gmm, pi_gmm = get_param(iteration, model_dir=model_dir)
418 | manual_sample, _ = sample(
419 | iteration=iteration,
420 | x_dim=config.latent_dim,
421 | mu_gmm=mu_gmm,
422 | lambda_gmm=lambda_gmm,
423 | sample_num=sample_num,
424 | sample_k=decode_k,
425 | model_dir=model_dir,
426 | )
427 |
428 | # Convert to tensor and generate reconstructions
429 | samples = torch.from_numpy(manual_sample.astype(np.float32)).to(device)
430 |
431 | with torch.no_grad():
432 | reconstructions = model.decode(samples).cpu()
433 | # Save reconstructed images
434 | output_path = f"{model_dir}/recon/manual_{decode_k}.png"
435 | save_image(
436 | reconstructions.view(sample_num, 1, 28, 28),
437 | output_path
438 | )
439 | print(f"Saved reconstructed images to: {output_path}")
440 |
441 |
442 | def plot_latent(
443 | iteration: int,
444 | all_loader: DataLoader,
445 | model_dir: str = "./vae_gmm",
446 | device: torch.device = torch.device("cpu")
447 | ) -> None:
448 | """
449 | Visualize VAE latent space using all data points.
450 |
451 | Args:
452 | iteration: Current iteration number
453 | all_loader: DataLoader containing all data
454 | model_dir: Directory for saving visualizations
455 | device: Device for computation
456 | """
457 | print("Plotting latent space visualization")
458 |
459 | # Load trained model
460 | model = load_trained_model(iteration, model_dir, device)
461 |
462 | # Extract latent variables for visualization
463 | latent_vars, labels = extract_latent_variables(model, all_loader, device)
464 |
465 | # Create visualization using tool function
466 | visualize_ls(iteration, latent_vars, labels, model_dir)
467 |
--------------------------------------------------------------------------------
/tool.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualization and utility tools for the VAE-GMM mutual learning system.
3 |
4 | This module provides comprehensive visualization capabilities, sampling functions,
5 | parameter loading utilities, and accuracy calculation for evaluating the performance
6 | of the VAE-GMM mutual learning system.
7 | """
8 |
9 | from dataclasses import dataclass
10 | from typing import List, Tuple, Optional
11 |
12 | import matplotlib.pyplot as plt
13 | import numpy as np
14 | from scipy.stats import multivariate_normal
15 | from sklearn.manifold import TSNE
16 |
17 |
18 | @dataclass
19 | class VisualizationConfig:
20 | """Configuration class for visualization parameters."""
21 |
22 | # Figure dimensions
23 | figure_size_2d: Tuple[int, int] = (12, 9)
24 | figure_size_latent: Tuple[int, int] = (10, 10)
25 |
26 | # Grid resolution for contour plots
27 | grid_resolution: int = 900
28 |
29 | # Visualization parameters
30 | marker_size: int = 100
31 | scatter_size: int = 100
32 | alpha: float = 0.5
33 | font_size_title: int = 20
34 | font_size_labels: int = 17
35 |
36 | # Colors for different clusters/classes
37 |     colors: Optional[List[str]] = None
38 |
39 | # Dimensionality reduction
40 | use_tsne: bool = True # If False, could use PCA
41 | tsne_perplexity: float = 30.0
42 | random_state: int = 0
43 |
44 | # File format
45 | dpi: int = 150
46 | file_format: str = "png"
47 |
48 | def __post_init__(self):
49 | if self.colors is None:
50 | self.colors = [
51 | "red", "green", "blue", "orange", "purple",
52 | "yellow", "black", "cyan", "#a65628", "#f781bf"
53 | ]
54 |
55 |
56 | @dataclass
57 | class SamplingConfig:
58 | """Configuration class for sampling parameters."""
59 |
60 | # Manual sampling parameters
61 | manual_sigma: float = 0.1 # Standard deviation for manual sampling
62 |
63 | # Default latent dimension
64 | default_latent_dim: int = 12
65 |
66 | # Default number of clusters
67 | default_num_clusters: int = 10
68 |
69 |
70 | def load_gmm_parameters(iteration: int, model_dir: str = "./vae_gmm") -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
71 | """
72 | Load GMM parameters from saved files.
73 |
74 | Args:
75 | iteration: Iteration number to load
76 | model_dir: Directory containing saved parameters
77 |
78 | Returns:
79 | Tuple of (means, precision_matrices, mixing_coefficients)
80 | """
81 | mu_path = f"{model_dir}/npy/mu_{iteration}.npy"
82 | lambda_path = f"{model_dir}/npy/lambda_{iteration}.npy"
83 | pi_path = f"{model_dir}/npy/pi_{iteration}.npy"
84 |
85 | mu_gmm = np.load(mu_path)
86 | lambda_gmm = np.load(lambda_path)
87 | pi_gmm = np.load(pi_path)
88 |
89 | return mu_gmm, lambda_gmm, pi_gmm
90 |
91 |
92 | def sample_from_gmm_cluster(
93 | cluster_mean: np.ndarray,
94 | cluster_precision: np.ndarray,
95 | sample_num: int,
96 | config: SamplingConfig
97 | ) -> Tuple[np.ndarray, np.ndarray]:
98 | """
99 | Sample from a specific GMM cluster using two different methods.
100 |
101 | Args:
102 | cluster_mean: Mean vector of the cluster
103 | cluster_precision: Precision matrix of the cluster
104 | sample_num: Number of samples to generate
105 | config: Sampling configuration
106 |
107 | Returns:
108 | Tuple of (manual_samples, gmm_samples)
109 | - manual_samples: Samples using fixed covariance
110 | - gmm_samples: Samples using estimated covariance from GMM
111 | """
112 | latent_dim = len(cluster_mean)
113 |
114 | # Manual sampling with fixed small covariance
115 | manual_cov = config.manual_sigma * np.identity(latent_dim, dtype=float)
116 | manual_samples = np.random.multivariate_normal(
117 | mean=cluster_mean,
118 | cov=manual_cov,
119 | size=sample_num
120 | )
121 |
122 | # GMM-based sampling using estimated covariance
123 | gmm_cov = np.linalg.inv(cluster_precision)
124 | gmm_samples = np.random.multivariate_normal(
125 | mean=cluster_mean,
126 | cov=gmm_cov,
127 | size=sample_num
128 | )
129 |
130 | return manual_samples, gmm_samples
131 |
132 |
133 | def sample(
134 | iteration: int,
135 | x_dim: int,
136 | mu_gmm: np.ndarray,
137 | lambda_gmm: np.ndarray,
138 | sample_num: int,
139 | sample_k: int,
140 | model_dir: str = "./vae_gmm"
141 | ) -> Tuple[np.ndarray, np.ndarray]:
142 | """
143 | Sample random variables for VAE decoder input from GMM posterior.
144 |
145 | Args:
146 | iteration: Current iteration (for compatibility)
147 | x_dim: Latent space dimensionality
148 | mu_gmm: GMM means [num_clusters, latent_dim]
149 | lambda_gmm: GMM precision matrices [num_clusters, latent_dim, latent_dim]
150 | sample_num: Number of samples to generate
151 | sample_k: Cluster index to sample from
152 | model_dir: Model directory (for compatibility)
153 |
154 | Returns:
155 | Tuple of (manual_samples, gmm_samples)
156 | """
157 | config = SamplingConfig()
158 |
159 | return sample_from_gmm_cluster(
160 | cluster_mean=mu_gmm[sample_k],
161 | cluster_precision=lambda_gmm[sample_k],
162 | sample_num=sample_num,
163 | config=config
164 | )
165 |
166 |
167 | def create_2d_grid(
168 | means_2d: np.ndarray,
169 | precisions_2d: np.ndarray,
170 | config: VisualizationConfig
171 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
172 | """
173 | Create a 2D grid for contour plotting.
174 |
175 | Args:
176 | means_2d: 2D cluster means [num_clusters, 2]
177 | precisions_2d: 2D precision matrices [num_clusters, 2, 2]
178 | config: Visualization configuration
179 |
180 | Returns:
181 | Tuple of (x1_grid, x2_grid, grid_points)
182 | """
183 | # Calculate grid bounds based on cluster parameters
184 | std_factor = 3.0 # Number of standard deviations to include
185 |
186 | x1_std = np.sqrt(1.0 / precisions_2d[:, 0, 0])
187 | x2_std = np.sqrt(1.0 / precisions_2d[:, 1, 1])
188 |
189 | x1_min = np.min(means_2d[:, 0] - std_factor * x1_std)
190 | x1_max = np.max(means_2d[:, 0] + std_factor * x1_std)
191 | x2_min = np.min(means_2d[:, 1] - std_factor * x2_std)
192 | x2_max = np.max(means_2d[:, 1] + std_factor * x2_std)
193 |
194 | # Create grid
195 | x1_line = np.linspace(x1_min, x1_max, config.grid_resolution)
196 | x2_line = np.linspace(x2_min, x2_max, config.grid_resolution)
197 | x1_grid, x2_grid = np.meshgrid(x1_line, x2_line)
198 |
199 | # Grid points for density evaluation
200 | grid_points = np.stack([x1_grid.flatten(), x2_grid.flatten()], axis=1)
201 |
202 | return x1_grid, x2_grid, grid_points
203 |
204 |
205 | def compute_2d_density(
206 | grid_points: np.ndarray,
207 | cluster_mean_2d: np.ndarray,
208 | cluster_cov_2d: np.ndarray,
209 | mixing_coeff: float
210 | ) -> np.ndarray:
211 | """
212 | Compute 2D density for a single cluster.
213 |
214 | Args:
215 | grid_points: Grid points for evaluation [num_points, 2]
216 | cluster_mean_2d: 2D cluster mean [2]
217 | cluster_cov_2d: 2D cluster covariance [2, 2]
218 | mixing_coeff: Mixing coefficient for this cluster
219 |
220 | Returns:
221 | Density values at grid points
222 | """
223 | density = multivariate_normal.pdf(
224 | x=grid_points,
225 | mean=cluster_mean_2d,
226 | cov=cluster_cov_2d
227 | )
228 | return density * mixing_coeff
229 |
230 |
231 | def visualize_gmm(
232 | iteration: int,
233 | decode_k: int,
234 | sample_num: int,
235 | model_dir: str = "./vae_gmm"
236 | ) -> None:
237 | """
238 | Visualize GMM clusters and samples in 2D space.
239 |
240 | Args:
241 | iteration: Current iteration number
242 | decode_k: Cluster index to visualize
243 | sample_num: Number of samples to generate
244 | model_dir: Directory containing model files
245 | """
246 | config = VisualizationConfig()
247 | sampling_config = SamplingConfig()
248 |
249 | # Load GMM parameters
250 | mu_gmm, lambda_gmm, pi_gmm = load_gmm_parameters(iteration, model_dir)
251 |
252 | # Generate samples from specified cluster
253 | manual_samples, gmm_samples = sample(
254 | iteration=iteration,
255 | x_dim=sampling_config.default_latent_dim,
256 | mu_gmm=mu_gmm,
257 | lambda_gmm=lambda_gmm,
258 | sample_num=sample_num,
259 | sample_k=decode_k,
260 | model_dir=model_dir
261 | )
262 |
263 | # Extract 2D projections for visualization
264 | num_clusters = len(mu_gmm)
265 | means_2d = mu_gmm[:, :2] # First 2 dimensions
266 | precisions_2d = lambda_gmm[:, :2, :2] # First 2x2 block
267 |
268 | # Create visualization grid
269 | x1_grid, x2_grid, grid_points = create_2d_grid(means_2d, precisions_2d, config)
270 |
271 | # Compute density for the selected cluster
272 | cluster_cov_2d = np.linalg.inv(precisions_2d[decode_k])
273 | density = compute_2d_density(
274 | grid_points, means_2d[decode_k], cluster_cov_2d, pi_gmm[decode_k]
275 | )
276 |
277 | # Create the plot
278 | plt.figure(figsize=config.figure_size_2d)
279 |
280 | # Plot samples
281 | plt.scatter(
282 | manual_samples[:, 0], manual_samples[:, 1],
283 | label=f"Cluster {decode_k + 1} samples",
284 | s=config.scatter_size, alpha=0.7
285 | )
286 |
287 | # Plot cluster centers
288 | plt.scatter(
289 | means_2d[:, 0], means_2d[:, 1],
290 | color="red", s=config.marker_size, marker="x",
291 | label="Cluster centers", linewidths=3
292 | )
293 |
294 | # Plot density contours
295 | density_reshaped = density.reshape(x1_grid.shape)
296 | contours = plt.contour(
297 | x1_grid, x2_grid, density_reshaped,
298 | alpha=config.alpha, linestyles="dashed"
299 | )
300 |
301 | # Formatting
302 | plt.suptitle("Gaussian Mixture Model Visualization", fontsize=config.font_size_title)
303 | plt.title(f"Samples: {sample_num}, Cluster: {decode_k + 1}", fontsize=14)
304 | plt.xlabel("Latent Dimension 1", fontsize=12)
305 | plt.ylabel("Latent Dimension 2", fontsize=12)
306 | plt.legend()
307 | plt.grid(True, alpha=0.3)
308 | plt.colorbar(contours, label="Density")
309 |
310 | # Save the plot
311 | output_path = f"{model_dir}/graph/gaussian_I{iteration}_k{decode_k}.{config.file_format}"
312 | plt.savefig(output_path, dpi=config.dpi, bbox_inches='tight')
313 | plt.close()
314 |
315 | print(f"GMM visualization saved to: {output_path}")
316 |
317 |
318 | def visualize_latent_space(
319 | iteration: int,
320 | latent_vars: np.ndarray,
321 | labels: np.ndarray,
322 | save_dir: str,
323 | config: Optional[VisualizationConfig] = None
324 | ) -> None:
325 | """
326 | Visualize VAE latent space using dimensionality reduction.
327 |
328 | Args:
329 | iteration: Current iteration number
330 | latent_vars: Latent variables [num_samples, latent_dim]
331 | labels: Ground truth labels [num_samples]
332 | save_dir: Directory to save visualization
333 | config: Visualization configuration
334 | """
335 | if config is None:
336 | config = VisualizationConfig()
337 |
338 | # Apply dimensionality reduction
339 | if config.use_tsne:
340 | reducer = TSNE(
341 | n_components=2,
342 | random_state=config.random_state,
343 | perplexity=min(config.tsne_perplexity, len(latent_vars) - 1)
344 | )
345 | points_2d = reducer.fit_transform(latent_vars)
346 | else:
347 | # Could implement PCA here if needed
348 | raise NotImplementedError("PCA option not implemented yet")
349 |
350 | # Create the plot
351 | plt.figure(figsize=config.figure_size_latent)
352 |
353 | # Plot each class with different colors and number markers
354 | unique_labels = np.unique(labels)
355 | for label in unique_labels:
356 | mask = labels == label
357 | plt.scatter(
358 | points_2d[mask, 0], points_2d[mask, 1],
359 | c=config.colors[label % len(config.colors)],
360 | s=config.scatter_size,
361 | marker=f"${label}$",
362 | label=f"Class {label}",
363 | alpha=0.7
364 | )
365 |
366 | # Formatting
367 | plt.title("VAE Latent Space Visualization", fontsize=config.font_size_title)
368 | plt.xlabel("t-SNE Component 1", fontsize=12)
369 | plt.ylabel("t-SNE Component 2", fontsize=12)
370 | plt.tick_params(labelsize=12)
371 | plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
372 | plt.grid(True, alpha=0.3)
373 |
374 | # Save the plot
375 | output_path = f"{save_dir}/graph/latent_space_{iteration}.{config.file_format}"
376 | plt.savefig(output_path, dpi=config.dpi, bbox_inches='tight')
377 | plt.close()
378 |
379 | print(f"Latent space visualization saved to: {output_path}")
380 |
381 |
382 | def visualize_ls(iteration: int, z: np.ndarray, labels: np.ndarray, save_dir: str) -> None:
383 | """
384 | Legacy function for latent space visualization (maintained for compatibility).
385 |
386 | Args:
387 | iteration: Current iteration number
388 | z: Latent variables
389 | labels: Ground truth labels
390 | save_dir: Directory to save visualization
391 | """
392 | visualize_latent_space(iteration, z, labels, save_dir)
393 |
394 |
395 | def get_param(iteration: int, model_dir: str = "./vae_gmm") -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
396 | """
397 | Legacy function for loading GMM parameters (maintained for compatibility).
398 |
399 | Args:
400 | iteration: Iteration number to load
401 | model_dir: Directory containing saved parameters
402 |
403 | Returns:
404 | Tuple of (means, precision_matrices, mixing_coefficients)
405 | """
406 | return load_gmm_parameters(iteration, model_dir)
407 |
408 |
409 | def compute_clustering_accuracy_hungarian(
410 | predicted_labels: np.ndarray,
411 | true_labels: np.ndarray
412 | ) -> Tuple[float, np.ndarray]:
413 | """
414 |     Compute clustering accuracy by searching over cluster-label permutations.
415 |
416 | This function finds the best permutation of cluster labels to maximize
417 | accuracy against ground truth labels.
418 |
419 | Args:
420 | predicted_labels: Predicted cluster assignments
421 | true_labels: Ground truth labels
422 |
423 | Returns:
424 | Tuple of (best_accuracy, best_permuted_labels)
425 | """
426 | num_clusters = int(np.max(predicted_labels)) + 1
427 | num_samples = len(predicted_labels)
428 |
429 | best_accuracy = 0.0
430 | best_labels = predicted_labels.copy()
431 |
432 | # Try all possible permutations (brute force for small K)
433 | # For larger K, this should use the Hungarian algorithm
434 | if num_clusters <= 10: # Brute force feasible for small number of clusters
435 | from itertools import permutations
436 |
437 | for perm in permutations(range(num_clusters)):
438 | # Create permutation mapping
439 | permuted_labels = np.zeros_like(predicted_labels)
440 | for original_label, new_label in enumerate(perm):
441 | mask = predicted_labels == original_label
442 | permuted_labels[mask] = new_label
443 |
444 | # Calculate accuracy
445 | accuracy = np.sum(permuted_labels == true_labels) / num_samples
446 |
447 | if accuracy > best_accuracy:
448 | best_accuracy = accuracy
449 | best_labels = permuted_labels.copy()
450 | else:
451 | # For larger number of clusters, use the original iterative approach
452 | best_accuracy, best_labels = calc_acc(predicted_labels, true_labels)
453 |
454 | return best_accuracy, best_labels
455 |
456 |
457 | def calc_acc(results: np.ndarray, correct: np.ndarray) -> Tuple[float, np.ndarray]:
458 | """
459 | Calculate clustering accuracy by finding optimal label permutation.
460 |
461 | This function iteratively tries swapping pairs of cluster labels to find
462 | the permutation that maximizes accuracy against ground truth.
463 |
464 | Args:
465 | results: Predicted cluster labels
466 | correct: Ground truth labels
467 |
468 | Returns:
469 | Tuple of (maximum_accuracy, optimal_labels)
470 | """
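    # Example: results = [0, 0, 1, 1], correct = [1, 1, 0, 0]
    # -> swapping labels 0 and 1 gives accuracy 1.0.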
471 |     results = np.asarray(results)  # accept plain Python lists as well as arrays
472 |     correct = np.asarray(correct)
473 |     num_clusters, num_samples = int(np.max(results)) + 1, len(results)
474 |     max_accuracy, current_labels = 0.0, results.copy()
475 |
476 | # Iteratively improve by swapping labels
477 | improved = True
478 | max_iterations = 100 # Prevent infinite loops
479 | iteration_count = 0
480 |
481 | while improved and iteration_count < max_iterations:
482 | improved = False
483 | iteration_count += 1
484 |
485 | # Try all pairs of cluster labels
486 | for i in range(num_clusters):
487 | for j in range(i + 1, num_clusters):
488 | # Create temporary result with swapped labels
489 | temp_result = current_labels.copy()
490 |
491 | # Swap labels i and j
492 | mask_i = current_labels == i
493 | mask_j = current_labels == j
494 | temp_result[mask_i] = j
495 | temp_result[mask_j] = i
496 |
497 | # Calculate accuracy
498 | accuracy = np.sum(temp_result == correct) / num_samples
499 |
500 | # Update if improved
501 | if accuracy > max_accuracy:
502 | max_accuracy = accuracy
503 | current_labels = temp_result.copy()
504 | improved = True
505 |
506 | return max_accuracy, current_labels
--------------------------------------------------------------------------------