├── requirements.txt
├── .gitignore
├── img
│   └── madgan.png
├── madgan
│   ├── __init__.py
│   ├── constants.py
│   ├── __main__.py
│   ├── data.py
│   ├── anomaly.py
│   ├── train.py
│   ├── models.py
│   └── engine.py
└── README.md

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
# Dependencies inferred from the package imports.
torch
pandas
numpy
typer

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.mypy_cache/
/data/
.vscode/
models/
__pycache__/

--------------------------------------------------------------------------------
/img/madgan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Guillem96/madgan-pytorch/HEAD/img/madgan.png

--------------------------------------------------------------------------------
/madgan/__init__.py:
--------------------------------------------------------------------------------
from madgan import data
from madgan import engine
from madgan import models

--------------------------------------------------------------------------------
/madgan/constants.py:
--------------------------------------------------------------------------------
REAL_LABEL = 0
FAKE_LABEL = 1
WINDOW_SIZE = 256
WINDOW_STRIDE = 32
LATENT_SPACE_DIM = 32

--------------------------------------------------------------------------------
/madgan/__main__.py:
--------------------------------------------------------------------------------
import typer

from madgan.train import train

_madgan_cli = typer.Typer(name="MAD-GAN CLI")
_madgan_cli.command(name="train")(train)


@_madgan_cli.callback()
def main() -> None:
    """MAD-GAN Command Line Interface."""


if __name__ == "__main__":
    _madgan_cli()

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MAD-GAN PyTorch 🧠🎨

A modern PyTorch implementation of Multivariate Anomaly Detection with GAN
(MAD-GAN).

This implementation is based on the model described in the MAD-GAN paper
(https://arxiv.org/pdf/1901.04997.pdf).

## Model picture 🖼

![](img/madgan.png)

## Train the MAD-GAN ⛹️‍♀️

To train the MAD-GAN neural network, you need a preprocessed dataset in CSV
format (more formats will come soon).

The CSV should look like this:

```
feature1,feature2,feature3,featureN
0.1,-0.2,-0.7,0.8
0.1,-0.2,-0.7,0.8
0.1,-0.2,-0.7,0.8
...
```

> Note that for now time-based windows are not supported (support is coming
> soon).

Then, to train the model, run the following CLI command:

```
$ python -m madgan train \
    data/dataset.csv \
    --batch-size 32 \
    --epochs 8 \
    --model-dir models/madgan  # Training checkpoints will be stored here
```

## Use a trained model

A dedicated CLI command for inference is still TBD.
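In the meantime, the building blocks are already in the package. The snippet
below is a minimal sketch, not a supported entry point: it assumes a checkpoint
pair from the last of the 8 training epochs under `models/madgan` (following
the `generator_{epoch}.pt` naming pattern used by `train`), and an arbitrary
`anomaly_threshold` of 1.0 that you should tune on labeled validation data.

```python
import pandas as pd

from madgan import constants
from madgan.anomaly import AnomalyDetector
from madgan.data import WindowDataset, prepare_dataloader
from madgan.models import Discriminator, Generator

generator = Generator.from_pretrained("models/madgan/generator_7.pt")
discriminator = Discriminator.from_pretrained("models/madgan/discriminator_7.pt")

detector = AnomalyDetector(discriminator=discriminator,
                           generator=generator,
                           latent_space_dim=constants.LATENT_SPACE_DIM,
                           anomaly_threshold=1.0)

df = pd.read_csv("data/dataset.csv")
dataset = WindowDataset(df,
                        window_size=constants.WINDOW_SIZE,
                        window_slide=constants.WINDOW_STRIDE)
dataloader = prepare_dataloader(dataset, batch_size=32)

for batch in dataloader:
    # 1 flags windows whose anomaly score exceeds the threshold.
    is_anomaly = detector.predict(batch)
    print(is_anomaly)
```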

## References 📖

[1] [MAD-GAN: Multivariate Anomaly Detection for Time Series Data with Generative Adversarial Networks](https://arxiv.org/pdf/1901.04997.pdf)

--------------------------------------------------------------------------------
/madgan/data.py:
--------------------------------------------------------------------------------
from typing import Any, Callable, Iterator, Optional, Sequence

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler


class WindowDataset(Dataset):
    """Dataset to iterate using sliding windows over a pandas DataFrame.

    Args:
        df (pd.DataFrame): DataFrame sorted by time.
        window_size (int): Number of elements per window.
        window_slide (int): Step size between consecutive windows.
    """

    def __init__(self, df: pd.DataFrame, window_size: int,
                 window_slide: int) -> None:
        self.windows = _window_array(df.values, window_size, window_slide)

    def __getitem__(self, index: int) -> torch.Tensor:
        # Cast to float32 so the windows match the models' parameter dtype.
        return torch.as_tensor(self.windows[index].copy(),
                               dtype=torch.float32)

    def __len__(self) -> int:
        return self.windows.shape[0]


class LatentSpaceIterator(object):
    """Infinite iterator that samples random latent-space windows."""

    def __init__(self,
                 noise_shape: Sequence[int],
                 noise_type: str = "uniform") -> None:
        self.noise_fn: Callable[..., torch.Tensor]
        self.noise_shape = noise_shape
        if noise_type == "uniform":
            self.noise_fn = torch.rand
        elif noise_type == "normal":
            self.noise_fn = torch.randn
        else:
            raise ValueError(f"Unexpected noise type {noise_type}")

    def __iter__(self) -> Iterator[torch.Tensor]:
        return self

    def __next__(self) -> torch.Tensor:
        return self.noise_fn(*self.noise_shape)


def prepare_dataloader(ds: Dataset,
                       batch_size: int,
                       is_distributed: bool = False,
                       **kwargs: Any) -> DataLoader:
    """Creates a dataloader for training.

    Args:
        ds (Dataset): Training dataset.
        batch_size (int): DataLoader batch size.
        is_distributed (bool): Is the training distributed over multiple
            nodes? Defaults to False.

    Returns:
        DataLoader: Data iterator ready to use.
    """
    sampler: Optional[DistributedSampler] = (DistributedSampler(ds)
                                             if is_distributed else None)
    return DataLoader(ds, batch_size=batch_size, sampler=sampler, **kwargs)


def _window_array(array: np.ndarray, window_size: int,
                  window_slide: int) -> np.ndarray:
    # Keep only complete windows, so the result has the uniform shape
    # (n_windows, window_size, n_features).
    return np.stack([
        array[i:i + window_size]
        for i in range(0, array.shape[0] - window_size + 1, window_slide)
    ])
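
# Illustrative example (comment only): a DataFrame with 10 rows and 3 feature
# columns, windowed with window_size=4 and window_slide=2, yields full windows
# starting at rows 0, 2, 4 and 6:
#
#   ds = WindowDataset(df, window_size=4, window_slide=2)
#   len(ds)      # -> 4
#   ds[0].shape  # -> torch.Size([4, 3])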

--------------------------------------------------------------------------------
/madgan/anomaly.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class AnomalyDetector(object):

    def __init__(self,
                 *,
                 discriminator: nn.Module,
                 generator: nn.Module,
                 latent_space_dim: int,
                 res_weight: float = .2,
                 anomaly_threshold: float = 1.0) -> None:
        self.discriminator = discriminator
        self.generator = generator
        self.anomaly_threshold = anomaly_threshold
        self.latent_space_dim = latent_space_dim
        self.res_weight = res_weight

    def predict(self, tensor: torch.Tensor) -> torch.Tensor:
        return (self.predict_proba(tensor) > self.anomaly_threshold).int()

    def predict_proba(self, tensor: torch.Tensor) -> torch.Tensor:
        # Weighted combination of the discriminator score and the
        # reconstruction error, one score per window.
        discriminator_score = self.compute_anomaly_score(tensor)
        discriminator_score *= 1 - self.res_weight
        reconstruction_loss = self.compute_reconstruction_loss(tensor)
        reconstruction_loss *= self.res_weight
        return reconstruction_loss + discriminator_score

    def compute_anomaly_score(self, tensor: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            # Average the per-timestep discriminator outputs into a single
            # score per window.
            discriminator_score = self.discriminator(tensor).mean(dim=(1, 2))
        return discriminator_score

    def compute_reconstruction_loss(self,
                                    tensor: torch.Tensor) -> torch.Tensor:
        best_reconstruct = self._generate_best_reconstruction(tensor)
        return (best_reconstruct - tensor).abs().sum(dim=(1, 2))

    def _generate_best_reconstruction(self,
                                      tensor: torch.Tensor) -> torch.Tensor:
        # Find the latent vectors whose generated samples best match the
        # given input, then return the corresponding reconstruction.
        max_iters = 10

        Z = torch.empty(
            (tensor.size(0), tensor.size(1), self.latent_space_dim),
            device=tensor.device,
            requires_grad=True)
        nn.init.normal_(Z, std=0.05)

        optimizer = torch.optim.RMSprop(params=[Z], lr=0.1)
        loss_fn = nn.MSELoss(reduction="none")
        normalized_target = F.normalize(tensor, dim=1, p=2)

        for _ in range(max_iters):
            optimizer.zero_grad()
            generated_samples = self.generator(Z)
            normalized_input = F.normalize(generated_samples, dim=1, p=2)
            reconstruction_error = loss_fn(normalized_input,
                                           normalized_target).sum(dim=(1, 2))
            # Optimize all windows jointly; each window only influences its
            # own latent vectors, so summing to a scalar is safe.
            reconstruction_error.sum().backward()
            optimizer.step()

        with torch.no_grad():
            best_reconstruct = self.generator(Z)
        return best_reconstruct
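
# Illustrative usage (comment only; see the README for a fuller example,
# which assumes already-trained models):
#
#   detector = AnomalyDetector(discriminator=discriminator,
#                              generator=generator,
#                              latent_space_dim=32)
#   scores = detector.predict_proba(windows)  # shape: (batch_size,)
#   flags = detector.predict(windows)         # 1 = anomaly, 0 = normal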

--------------------------------------------------------------------------------
/madgan/train.py:
--------------------------------------------------------------------------------
from pathlib import Path
from typing import Iterator, Tuple

import pandas as pd
import torch

import madgan
from madgan import constants

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(
    input_data: str,
    batch_size: int = 32,
    epochs: int = 8,
    lr: float = 1e-4,
    hidden_dim: int = 512,
    window_size: int = constants.WINDOW_SIZE,
    window_stride: int = constants.WINDOW_STRIDE,
    random_seed: int = 0,
    model_dir: Path = Path("models/madgan"),
) -> None:
    """Trains MAD-GAN on a preprocessed CSV time-series dataset."""
    madgan.engine.set_seed(random_seed)

    model_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(input_data)
    train_dl, test_dl = _prepare_data(df=df,
                                      batch_size=batch_size,
                                      window_size=window_size,
                                      window_stride=window_stride)
    # The noise feature dimension has to match the generator's latent space
    # dimension, since the generator consumes latent windows directly.
    latent_space = madgan.data.LatentSpaceIterator(noise_shape=[
        batch_size,
        window_size,
        constants.LATENT_SPACE_DIM,
    ])

    generator = madgan.models.Generator(
        latent_space_dim=constants.LATENT_SPACE_DIM,
        hidden_units=hidden_dim,
        output_dim=df.shape[-1])
    generator.to(DEVICE)

    discriminator = madgan.models.Discriminator(input_dim=df.shape[-1],
                                                hidden_units=hidden_dim,
                                                add_batch_mean=True)
    discriminator.to(DEVICE)

    discriminator_optim = torch.optim.Adam(discriminator.parameters(), lr=lr)
    generator_optim = torch.optim.Adam(generator.parameters(), lr=lr)

    criterion_fn = torch.nn.BCELoss()

    for epoch in range(epochs):
        madgan.engine.train_one_epoch(
            generator=generator,
            discriminator=discriminator,
            loss_fn=criterion_fn,
            real_dataloader=train_dl,
            latent_dataloader=latent_space,
            discriminator_optimizer=discriminator_optim,
            generator_optimizer=generator_optim,
            normal_label=constants.REAL_LABEL,
            anomaly_label=constants.FAKE_LABEL,
            epoch=epoch)

        madgan.engine.evaluate(generator=generator,
                               discriminator=discriminator,
                               real_dataloader=test_dl,
                               latent_dataloader=latent_space,
                               loss_fn=criterion_fn,
                               normal_label=constants.REAL_LABEL,
                               anomaly_label=constants.FAKE_LABEL)

        generator.save(model_dir / f"generator_{epoch}.pt")
        discriminator.save(model_dir / f"discriminator_{epoch}.pt")


def _prepare_data(
    df: pd.DataFrame,
    batch_size: int,
    window_size: int,
    window_stride: int,
) -> Tuple[Iterator[torch.Tensor], Iterator[torch.Tensor]]:
    dataset = madgan.data.WindowDataset(df,
                                        window_size=window_size,
                                        window_slide=window_stride)

    # Random 80/20 train/test split at the window level.
    indices = torch.randperm(len(dataset))
    train_len = int(len(dataset) * .8)
    train_dataset = torch.utils.data.Subset(dataset,
                                            indices[:train_len].tolist())
    test_dataset = torch.utils.data.Subset(dataset,
                                           indices[train_len:].tolist())

    train_dl = madgan.data.prepare_dataloader(train_dataset,
                                              batch_size=batch_size)
    test_dl = madgan.data.prepare_dataloader(test_dataset,
                                             batch_size=batch_size)
    return train_dl, test_dl
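
# The `train` function is exposed through the Typer CLI defined in
# madgan/__main__.py, so a typical invocation looks like:
#
#   python -m madgan train data/dataset.csv --batch-size 32 --epochs 8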

--------------------------------------------------------------------------------
/madgan/models.py:
--------------------------------------------------------------------------------
from pathlib import Path
from typing import Optional, Protocol, Union

import torch
import torch.nn as nn


class SerializableModule(Protocol):

    def save(self, fpath: Union[str, Path]) -> None:
        ...

    @classmethod
    def from_pretrained(
            cls, fpath: Union[str, Path],
            map_location: Optional[torch.device]) -> "SerializableModule":
        ...


class Generator(nn.Module):

    def __init__(self,
                 latent_space_dim: int,
                 hidden_units: int,
                 output_dim: int,
                 n_lstm_layers: int = 2) -> None:
        super().__init__()
        self.latent_space_dim = latent_space_dim
        self.hidden_units = hidden_units
        self.n_lstm_layers = n_lstm_layers
        self.output_dim = output_dim

        self.lstm = nn.LSTM(input_size=self.latent_space_dim,
                            hidden_size=self.hidden_units,
                            num_layers=self.n_lstm_layers,
                            batch_first=True,
                            dropout=.1)

        self.linear = nn.Linear(in_features=self.hidden_units,
                                out_features=self.output_dim)
        nn.init.trunc_normal_(self.linear.bias)
        nn.init.trunc_normal_(self.linear.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rnn_output, _ = self.lstm(x)
        return self.linear(rnn_output)

    def save(self, fpath: Union[Path, str]) -> None:
        chkp = {
            "config": {
                "latent_space_dim": self.latent_space_dim,
                "hidden_units": self.hidden_units,
                "n_lstm_layers": self.n_lstm_layers,
                "output_dim": self.output_dim
            },
            "weights": self.state_dict(),
        }
        torch.save(chkp, fpath)

    @classmethod
    def from_pretrained(
            cls,
            fpath: Union[Path, str],
            map_location: Optional[torch.device] = None) -> "Generator":
        chkp = torch.load(fpath, map_location=map_location)
        model = cls(**chkp.pop("config"))
        model.load_state_dict(chkp.pop("weights"))
        model.eval()
        return model


class Discriminator(nn.Module):

    def __init__(self,
                 input_dim: int,
                 hidden_units: int,
                 n_lstm_layers: int = 2,
                 add_batch_mean: bool = False) -> None:
        super().__init__()
        self.add_batch_mean = add_batch_mean
        self.hidden_units = hidden_units
        self.input_dim = input_dim
        self.n_lstm_layers = n_lstm_layers

        # Concatenating the batch mean along the feature axis doubles the
        # number of input features seen by the LSTM.
        input_features = (self.input_dim * 2
                          if self.add_batch_mean else self.input_dim)
        self.lstm = nn.LSTM(input_size=input_features,
                            hidden_size=self.hidden_units,
                            num_layers=self.n_lstm_layers,
                            batch_first=True,
                            dropout=.1)

        self.linear = nn.Linear(in_features=self.hidden_units,
                                out_features=1)
        nn.init.trunc_normal_(self.linear.bias)
        nn.init.trunc_normal_(self.linear.weight)

        self.activation = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.add_batch_mean:
            # Mini-batch discrimination: append the batch mean to every
            # sample so low-variance (collapsed) batches are easy to spot.
            bs = x.size(0)
            batch_mean = x.mean(0, keepdim=True).repeat(bs, 1, 1)
            x = torch.cat([x, batch_mean], dim=-1)

        rnn_output, _ = self.lstm(x)
        # One probability per timestep.
        return self.activation(self.linear(rnn_output))

    def save(self, fpath: Union[Path, str]) -> None:
        chkp = {
            "config": {
                "add_batch_mean": self.add_batch_mean,
                "hidden_units": self.hidden_units,
                "input_dim": self.input_dim,
                "n_lstm_layers": self.n_lstm_layers
            },
            "weights": self.state_dict(),
        }
        torch.save(chkp, fpath)

    @classmethod
    def from_pretrained(
            cls,
            fpath: Union[Path, str],
            map_location: Optional[torch.device] = None) -> "Discriminator":
        chkp = torch.load(fpath, map_location=map_location)
        model = cls(**chkp.pop("config"))
        model.load_state_dict(chkp.pop("weights"))
        model.eval()
        return model
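
# Shape summary (illustrative), with batch size B, window length T, F input
# features and latent dimension L:
#
#   Generator:     (B, T, L) -> (B, T, F) raw feature values
#   Discriminator: (B, T, F) -> (B, T, 1) probabilities in [0, 1]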

--------------------------------------------------------------------------------
/madgan/engine.py:
--------------------------------------------------------------------------------
import random
from typing import Callable, Dict, Iterator

import numpy as np
import torch
import torch.nn as nn

LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def set_seed(seed: int = 0) -> None:
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def train_one_epoch(generator: nn.Module,
                    discriminator: nn.Module,
                    loss_fn: LossFn,
                    real_dataloader: Iterator[torch.Tensor],
                    latent_dataloader: Iterator[torch.Tensor],
                    discriminator_optimizer: torch.optim.Optimizer,
                    generator_optimizer: torch.optim.Optimizer,
                    normal_label: int = 0,
                    anomaly_label: int = 1,
                    epoch: int = 0,
                    log_every: int = 30) -> None:
    """Trains a GAN for a single epoch.

    Args:
        generator (nn.Module): Torch module implementing the GAN generator.
        discriminator (nn.Module): Torch module implementing the GAN
            discriminator.
        loss_fn (LossFn): Loss function; should return a reduced value.
        real_dataloader (Iterator[torch.Tensor]): Iterator over real data
            samples.
        latent_dataloader (Iterator[torch.Tensor]): Iterator over latent
            space samples fed to the generator.
        discriminator_optimizer (torch.optim.Optimizer): Optimizer for the
            discriminator.
        generator_optimizer (torch.optim.Optimizer): Optimizer for the
            generator module.
        normal_label (int): Label for samples with normal behaviour
            (real or non-anomalous). Defaults to 0.
        anomaly_label (int): Label that identifies generated samples
            (anomalies when running inference). Defaults to 1.
        epoch (int, optional): Current epoch (just for logging purposes).
            Defaults to 0.
        log_every (int, optional): Log the training progress every n steps.
            Defaults to 30.
    """
    generator.train()
    discriminator.train()

    device = next(discriminator.parameters()).device

    for i, (real, z) in enumerate(zip(real_dataloader, latent_dataloader)):
        real = real.to(device)
        z = z.to(device)

        # Generate fake samples with the generator
        fake = generator(z)

        # The discriminator emits one probability per timestep, so the
        # labels are expanded to (batch_size, sequence_length).
        real_labels = torch.full(real.shape[:2],
                                 float(normal_label),
                                 device=device)
        fake_labels = torch.full(fake.shape[:2],
                                 float(anomaly_label),
                                 device=device)
        all_labels = torch.cat([real_labels.view(-1), fake_labels.view(-1)])

        # Update discriminator
        discriminator_optimizer.zero_grad()
        discriminator.train()
        real_logits = discriminator(real)
        fake_logits = discriminator(fake.detach())
        d_logits = torch.cat([real_logits.view(-1), fake_logits.view(-1)])

        # Discriminator tries to identify the true nature of each sample
        # (anomaly or normal)
        d_real_loss = loss_fn(real_logits.view(-1), real_labels.view(-1))
        d_fake_loss = loss_fn(fake_logits.view(-1), fake_labels.view(-1))
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()

        discriminator_optimizer.step()

        # Update generator
        generator_optimizer.zero_grad()
        discriminator.eval()

        g_logits = discriminator(fake)
        # The generator improves by fooling the discriminator into
        # assigning its samples the "normal" label.
        cheat_loss = loss_fn(
            g_logits.view(-1),
            torch.full_like(g_logits.view(-1), float(normal_label)))
        cheat_loss.backward()
        generator_optimizer.step()

        if (i + 1) % log_every == 0:
            # Fraction of real and fake timesteps classified correctly.
            discriminator_acc = ((d_logits.detach() > .5).float() ==
                                 all_labels).float().mean()
            # Fraction of fake timesteps that pass as normal.
            generator_acc = ((g_logits.detach().view(-1) > .5).float() ==
                             float(normal_label)).float().mean()

            log = {
                "generator_loss": cheat_loss.item(),
                "discriminator_loss": d_loss.item(),
                "discriminator_acc": discriminator_acc.item(),
                "generator_acc": generator_acc.item(),
            }

            header = f"Epoch [{epoch}] Step [{i}]"
            log_metrics = " ".join(
                f"{metric_name}:{metric_value:.4f}"
                for metric_name, metric_value in log.items())
            print(header, log_metrics)


@torch.no_grad()
def evaluate(generator: nn.Module,
             discriminator: nn.Module,
             loss_fn: LossFn,
             real_dataloader: Iterator[torch.Tensor],
             latent_dataloader: Iterator[torch.Tensor],
             normal_label: int = 0,
             anomaly_label: int = 1) -> Dict[str, float]:
    """Evaluates a trained GAN.

    Reports the real and fake losses for the discriminator, as well as the
    accuracies.

    Args:
        generator (nn.Module): Torch module implementing the GAN generator.
        discriminator (nn.Module): Torch module implementing the GAN
            discriminator.
        loss_fn (LossFn): Loss function; should return a reduced value.
        real_dataloader (Iterator[torch.Tensor]): Iterator over real data
            samples.
        latent_dataloader (Iterator[torch.Tensor]): Iterator over latent
            space samples fed to the generator.
        normal_label (int): Label for samples with normal behaviour
            (real or non-anomalous). Defaults to 0.
        anomaly_label (int): Label that identifies generated samples
            (anomalies when running inference). Defaults to 1.

    Returns:
        Dict[str, float]: Metrics averaged over the evaluation batches.
    """
    generator.eval()
    discriminator.eval()

    device = next(discriminator.parameters()).device

    agg_metrics: Dict[str, float] = {}
    n_batches = 0
    for real, z in zip(real_dataloader, latent_dataloader):
        real = real.to(device)
        z = z.to(device)

        # Generate fake samples with the generator
        fake = generator(z)

        # One label per timestep, mirroring the discriminator output.
        real_labels = torch.full(real.shape[:2],
                                 float(normal_label),
                                 device=device)
        fake_labels = torch.full(fake.shape[:2],
                                 float(anomaly_label),
                                 device=device)
        all_labels = torch.cat([real_labels.view(-1), fake_labels.view(-1)])

        # Try to classify the real and generated samples
        real_logits = discriminator(real)
        fake_logits = discriminator(fake)
        d_logits = torch.cat([real_logits.view(-1), fake_logits.view(-1)])

        # Discriminator tries to identify the true nature of each sample
        # (anomaly or normal)
        d_real_loss = loss_fn(real_logits.view(-1), real_labels.view(-1))
        d_fake_loss = loss_fn(fake_logits.view(-1), fake_labels.view(-1))
        d_loss = d_real_loss + d_fake_loss

        discriminator_acc = ((d_logits > .5).float() ==
                             all_labels).float().mean()
        generator_acc = ((fake_logits.view(-1) > .5).float() ==
                         float(normal_label)).float().mean()

        log = {
            "discriminator_real_loss": d_real_loss.item(),
            "discriminator_fake_loss": d_fake_loss.item(),
            "discriminator_loss": d_loss.item(),
            "discriminator_acc": discriminator_acc.item(),
            "generator_acc": generator_acc.item(),
        }

        n_batches += 1
        for metric_name, metric_value in log.items():
            agg_metrics[metric_name] = (agg_metrics.get(metric_name, 0.) +
                                        metric_value)

    # Average the accumulated per-batch metrics.
    agg_metrics = {
        metric_name: metric_value / max(n_batches, 1)
        for metric_name, metric_value in agg_metrics.items()
    }

    log_metrics = " ".join(
        f"{metric_name}:{metric_value:.4f}"
        for metric_name, metric_value in agg_metrics.items())
    print("Evaluation metrics:", log_metrics)
    return agg_metrics

--------------------------------------------------------------------------------