├── EAL-GAN ├── data │ ├── mnist.mat │ ├── musk.mat │ ├── shuttle.mat │ ├── vowels.mat │ ├── optdigits.mat │ ├── pendigits.mat │ ├── satellite.mat │ └── satimage-2.mat ├── models │ ├── sync_batchnorm │ │ ├── comm.pyc │ │ ├── __init__.py │ │ ├── unittest.py │ │ ├── batchnorm_reimpl.py │ │ ├── replicate.py │ │ ├── comm.py │ │ └── batchnorm.py │ ├── losses.py │ ├── pyod_utils.py │ ├── layers.py │ └── EAL_GAN.py ├── README.md └── Train_EAL_GAN.py ├── EAL-GAN-image ├── src │ ├── base │ │ ├── __init__.py │ │ ├── torchvision_dataset.py │ │ ├── base_net.py │ │ ├── base_dataset.py │ │ ├── base_trainer.py │ │ └── odds_dataset.py │ ├── datasets │ │ ├── __init__.py │ │ ├── odds.py │ │ ├── main.py │ │ ├── preprocessing.py │ │ ├── cifar10.py │ │ ├── mnist.py │ │ └── fmnist.py │ ├── sync_batchnorm │ │ ├── __init__.py │ │ ├── unittest.py │ │ ├── batchnorm_reimpl.py │ │ ├── replicate.py │ │ ├── comm.py │ │ └── batchnorm.py │ ├── loss.py │ ├── my_utils.py │ ├── EAL_GAN.py │ ├── layers.py │ └── BigGAN.py └── Train_EAL_GAN.py └── README.md /EAL-GAN/data/mnist.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/mnist.mat -------------------------------------------------------------------------------- /EAL-GAN/data/musk.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/musk.mat -------------------------------------------------------------------------------- /EAL-GAN/data/shuttle.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/shuttle.mat -------------------------------------------------------------------------------- /EAL-GAN/data/vowels.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/vowels.mat -------------------------------------------------------------------------------- /EAL-GAN/data/optdigits.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/optdigits.mat -------------------------------------------------------------------------------- /EAL-GAN/data/pendigits.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/pendigits.mat -------------------------------------------------------------------------------- /EAL-GAN/data/satellite.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/satellite.mat -------------------------------------------------------------------------------- /EAL-GAN/data/satimage-2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/satimage-2.mat -------------------------------------------------------------------------------- /EAL-GAN/models/sync_batchnorm/comm.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/models/sync_batchnorm/comm.pyc -------------------------------------------------------------------------------- /EAL-GAN-image/src/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import * 2 | from .torchvision_dataset import * 3 | from .odds_dataset import * 4 | from .base_net import * 5 | from .base_trainer import * 6 | -------------------------------------------------------------------------------- /EAL-GAN/models/sync_batchnorm/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 2 | from .replicate import DataParallelWithCallback, patch_replication_callback 3 | -------------------------------------------------------------------------------- /EAL-GAN-image/src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import load_dataset 2 | from .mnist import MNIST_Dataset 3 | from .fmnist import FashionMNIST_Dataset 4 | from .cifar10 import CIFAR10_Dataset 5 | from .odds import ODDSADDataset 6 | from .preprocessing import * 7 | -------------------------------------------------------------------------------- /EAL-GAN-image/src/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /EAL-GAN/models/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` extended with a tensor closeness assertion."""

    def assertTensorClose(self, x, y):
        """Assert that tensors ``x`` and ``y`` are element-wise close.

        Closeness is decided by ``torch.allclose`` with its default
        tolerances. On failure the assertion message reports:

        * ``adiff`` - the largest absolute difference between ``x`` and ``y``;
        * ``rdiff`` - ``adiff`` relative to the entries of ``y`` (largest
          ratio), or the string ``'NaN'`` when ``y`` has no non-zero entry.

        :param x: tensor under test
        :param y: reference tensor
        :raises AssertionError: if ``torch.allclose(x, y)`` is false
        """
        adiff = float((x - y).abs().max())

        # Only divide by the non-zero entries of ``y``. Dividing by the full
        # tensor turns the reported relative difference into ``inf`` as soon
        # as a single reference value is zero, which hides the useful number.
        nonzero_y = y[y != 0]
        if nonzero_y.numel() == 0:
            rdiff = 'NaN'
        else:
            rdiff = float((adiff / nonzero_y).abs().max())

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y), message)
class BaseNet(nn.Module):
    """Base class for all neural networks."""

    def __init__(self):
        super().__init__()
        # One logger per concrete network class, named after that class.
        self.logger = logging.getLogger(self.__class__.__name__)
        self.rep_dim = None  # representation dimensionality, i.e. dim of the code layer or last layer

    def forward(self, *input):
        """
        Forward pass logic
        :return: Network output
        """
        raise NotImplementedError

    def summary(self):
        """Log the number of trainable parameters and the network structure.

        Only parameters with ``requires_grad=True`` are counted. Both messages
        are emitted at INFO level on ``self.logger``.
        """
        # ``numel()`` always yields an int. The product of an empty shape
        # (0-dim parameter) computed with numpy is the float 1.0, which made
        # the logged total a float such as ``7.0``.
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        self.logger.info('Trainable parameters: {}'.format(params))
        self.logger.info(self)
class BaseTrainer(ABC):
    """Abstract interface shared by all trainers.

    The constructor only records the training hyper-parameters as attributes;
    concrete subclasses provide the ``train`` and ``test`` logic.
    """

    def __init__(self, optimizer_name: str, lr: float, n_epochs: int, lr_milestones: tuple, batch_size: int,
                 weight_decay: float, device: str, n_jobs_dataloader: int):
        super().__init__()
        # Optimizer settings.
        self.optimizer_name, self.lr, self.weight_decay = optimizer_name, lr, weight_decay
        # Training-length / schedule settings.
        self.n_epochs, self.lr_milestones = n_epochs, lr_milestones
        # Batching, device placement and data-loader workers.
        self.batch_size, self.device, self.n_jobs_dataloader = batch_size, device, n_jobs_dataloader

    @abstractmethod
    def train(self, dataset: BaseADDataset, net: BaseNet) -> BaseNet:
        """
        Train ``net`` on the train_set of ``dataset``; must be implemented by subclasses.
        :return: Trained net
        """

    @abstractmethod
    def test(self, dataset: BaseADDataset, net: BaseNet):
        """
        Evaluate ``net`` on the test_set of ``dataset``; must be implemented by subclasses.
        """
5 | 6 | > @misc{chen2021supervised,
7 |      title={Supervised Anomaly Detection via Conditional Generative Adversarial Network and Ensemble Active Learning},
8 |     author={Zhi Chen and Jiang Duan and Li Kang and Guoping Qiu},
9 |     year={2021},
10 |     eprint={2104.11952},
11 |     archivePrefix={arXiv},
12 |     primaryClass={cs.LG}
13 | } 14 | 15 |     We have implemented two versions of the proposed EAL-GAN: (1) one for classical anomaly detection datasets, whose code is in the "EAL-GAN" folder, and (2) one for image datasets, whose code is in the "EAL-GAN-image" folder. Please note that EAL-GAN-image is sensitive to the learning rate and initialization. We print the losses of the discriminators and the generator during training; if the loss becomes NaN in the first epoch, you should re-run the code with a smaller learning rate. If that doesn't happen, you can expect a promising result. 16 | 17 | 18 | Requirements 19 | === 20 | Pytorch >1.6
21 | Python 3.7
22 | 23 | Getting started 24 | === 25 | (1) You can run the script “Train_EAL_GAN.py” to train the model proposed in our paper.
26 | (2) Some of the datasets in our paper are given in the folder “/data”.
27 | (3) models/EAL_GAN.py is the proposed model.
28 | (4) models/losses.py contains the loss functions.
29 | 30 | 31 | Acknowledgments 32 | == 33 | Some of our code (e.g., Spectral Normalization) is extracted from the [PyTorch implementation of BigGAN](https://github.com/ajbrock/BigGAN-PyTorch). 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | EAL-GAN: Supervised Anomaly Detection via Conditional Generative Adversarial Network and Ensemble Active Learning 3 | == 4 | This is the official implementation of “Supervised Anomaly Detection via Conditional Generative Adversarial Network and Ensemble Active Learning”. Our paper has been officially accepted by IEEE Transactions on Pattern Analysis and Machine Intelligence, and a preprint version of the manuscript can be found on arXiv. If you use the code in this repo, please cite the paper as follows:
5 | 6 | >@ARTICLE{chen202-EALGAN,
7 |      author={Chen, Zhi and Duan, Jiang and Kang, Li and Qiu, Guoping},
8 |      journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
9 |      title={Supervised Anomaly Detection via Conditional Generative Adversarial Network and Ensemble Active Learning},
10 |      year={2023},
11 |      volume={45},
12 |      number={6},
13 |      pages={7781-7798},
14 |      doi={10.1109/TPAMI.2022.3225476}}
15 | 16 |     We have implemented two versions of the proposed EAL-GAN: (1) one for classical anomaly detection datasets, whose code is in the "EAL-GAN" folder, and (2) one for image datasets, whose code is in the "EAL-GAN-image" folder. Please note that EAL-GAN-image is sensitive to the learning rate and initialization. We print the losses of the discriminators and the generator during training; if the loss becomes NaN in the first epoch, you should re-run the code with a smaller learning rate. If that doesn't happen, you can expect a promising result. 17 | 18 | 19 | Requirements 20 | === 21 | Pytorch >1.6
22 | Python 3.7
23 | 24 | Getting started 25 | === 26 | (1) You can run the script “Train_EAL_GAN.py” to train the model proposed in our paper.
27 | (2) Some of the datasets in our paper are given in the folder “/data”.
28 | (3) models/EAL_GAN.py is the proposed model.
29 | (4) models/losses.py contains the loss functions.
30 | 31 | 32 | Acknowledgments 33 | == 34 | some of our codes (e.g., Spectral Normalization) are extracted from the [PyTorch implementation of BigGAN]( https://github.com/ajbrock/BigGAN-PyTorch). 35 | -------------------------------------------------------------------------------- /EAL-GAN-image/src/datasets/odds.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Subset 2 | from src.base.base_dataset import BaseADDataset 3 | from src.base.odds_dataset import ODDSDataset 4 | from src.datasets.preprocessing import create_semisupervised_setting 5 | 6 | import torch 7 | 8 | 9 | class ODDSADDataset(BaseADDataset): 10 | 11 | def __init__(self, root: str, dataset_name: str, n_known_outlier_classes: int = 0, ratio_known_normal: float = 0.0, 12 | ratio_known_outlier: float = 0.0, ratio_pollution: float = 0.0, random_state=None): 13 | super().__init__(root) 14 | 15 | # Define normal and outlier classes 16 | self.n_classes = 2 # 0: normal, 1: outlier 17 | self.normal_classes = (0,) 18 | self.outlier_classes = (1,) 19 | 20 | if n_known_outlier_classes == 0: 21 | self.known_outlier_classes = () 22 | else: 23 | self.known_outlier_classes = (1,) 24 | 25 | # Get train set 26 | train_set = ODDSDataset(root=self.root, dataset_name=dataset_name, train=True, random_state=random_state, 27 | download=True) 28 | 29 | # Create semi-supervised setting 30 | idx, _, semi_targets = create_semisupervised_setting(train_set.targets.cpu().data.numpy(), self.normal_classes, 31 | self.outlier_classes, self.known_outlier_classes, 32 | ratio_known_normal, ratio_known_outlier, ratio_pollution) 33 | semi_targets = torch.tensor(semi_targets) 34 | if semi_targets.shape[0]>=train_set.semi_targets.shape[0]: 35 | semi_targets = semi_targets[0:train_set.semi_targets.shape[0]] 36 | train_set.semi_targets[idx] = semi_targets # set respective semi-supervised labels 37 | 38 | # Subset train_set to semi-supervised setup 39 | 
# The exported name must match the class defined below; the previous entry
# ('BatchNormReimpl') made ``from ... import *`` fail with AttributeError.
__all__ = ['BatchNorm2dReimpl']


class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of batch normalization, used for testing the numerical
    stability.

    Author: acgtyrant
    See also:
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        """
        :param num_features: number of channels of the expected input
        :param eps: value added to the variance before taking the square root
        :param momentum: weight of the current batch in the running statistics
        """
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset the running mean to 0 and the running variance to 1."""
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        """Reset running statistics, draw ``weight`` uniformly and zero ``bias``."""
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        """Normalize a 4-D input per channel using the statistics of this batch.

        The running mean/variance buffers are updated on every call (this
        module does not consult ``self.training``).

        :param input_: tensor whose ``size()`` unpacks into
            (batchsize, channels, height, width)
        :return: normalized tensor with the same layout as ``input_``
        """
        batchsize, channels, height, width = input_.size()
        numel = batchsize * height * width
        # Flatten to (channels, numel) so that statistics are per channel.
        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
        sum_ = input_.sum(1)
        sum_of_square = input_.pow(2).sum(1)
        mean = sum_ / numel
        # sum((x - mean)^2) expressed through the raw sums.
        sumvar = sum_of_square - sum_ * mean

        self.running_mean = (
            (1 - self.momentum) * self.running_mean
            + self.momentum * mean.detach()
        )
        # Running variance tracks the unbiased estimate ...
        unbias_var = sumvar / (numel - 1)
        self.running_var = (
            (1 - self.momentum) * self.running_var
            + self.momentum * unbias_var.detach()
        )

        # ... while the normalization itself uses the biased estimate.
        bias_var = sumvar / numel
        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))

        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNormReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * 
if __name__ == '__main__':
    # Input data location and output directory for the CSV summary.
    data_root = './data'
    save_dir = './results'

    # Datasets to evaluate; uncomment entries to enable them.
    data_names = [
        # 'mnist',
        'fmnist',
        # 'cifar10',
    ]
    data_resolution = {'mnist': 32, 'fmnist': 32, 'cifar10': 32}
    df_columns = ['Data', 'auc_mean', 'auc_std', 'gmean', 'gmean_std']

    # Container for the per-dataset results.
    roc_df = pd.DataFrame(columns=df_columns)

    args = parse_args()
    outlier_class = 0

    # ``DataFrame.to_csv`` does not create missing directories, so make sure
    # the output folder exists before the (long) training loop starts.
    os.makedirs(save_dir, exist_ok=True)
    save_path1 = os.path.join(save_dir, "AUC_EAL_GAN.csv")

    for data_name in data_names:
        # A single run per dataset is performed, hence the 1x1 result matrices;
        # the std columns are therefore 0.
        roc_mat = np.zeros([1, 1])
        gmean_mat = np.zeros([1, 1])

        # Number of image channels handed to the model via ``args``.
        args.feat_dim = 3 if data_name == 'cifar10' else 1
        dataset = load_dataset(data_name, data_root, args.normal_class, outlier_class, args.n_known_outlier_classes,
                               args.ratio_known_normal, args.ratio_known_outlier, args.ratio_pollution,
                               random_state=args.seed, resolution=data_resolution[data_name])

        cb_gan = EAL_GAN(args, dataset)
        best_auc, best_gmean = cb_gan.fit()

        print('AUC:%.4f, Gmean:%.4f ' % (best_auc, best_gmean))

        roc_mat[0, 0] = best_auc
        gmean_mat[0, 0] = best_gmean

        # Row layout follows ``df_columns``.
        roc_list = (
            [data_name]
            + np.mean(roc_mat, axis=0).tolist()
            + np.std(roc_mat, axis=0).tolist()
            + np.mean(gmean_mat, axis=0).tolist()
            + np.std(gmean_mat, axis=0).tolist()
        )
        temp_df = pd.DataFrame(roc_list).transpose()
        temp_df.columns = df_columns
        roc_df = pd.concat([roc_df, temp_df], axis=0)

        # Save the results after each dataset so partial progress is kept.
        roc_df.to_csv(save_path1, index=False, float_format='%.4f')
def loss_dis_fake(dis_fake, out_category, y, weights=None, cat_weight=None, gamma=2.0, eps=1e-12):
    """Focal-style discriminator losses for a batch of generated samples.

    :param dis_fake: discriminator logits for the fake batch; the code views
        the derived probabilities as ``(len(dis_fake), 1)``
    :param out_category: classifier output used directly as the probability of
        class 1 (assumed to lie in [0, 1] -- TODO confirm against the model)
    :param y: class labels (0/1), reshaped to ``(batch, 1)``
    :param weights: probabilities ``pt`` returned by previous calls (one
        column per call), or ``None`` for the first call
    :param cat_weight: same as ``weights`` but for the classification branch
    :param gamma: exponent of the ``(1 - p) ** gamma`` modulating factor
    :param eps: lower bound applied to the class probability before the log
    :return: ``(loss, loss_cat, weights, cat_weight)`` where the two weight
        tensors have this call's (detached) probabilities appended as a column
    """
    # step 1: the loss for GAN
    logpt = F.softplus(dis_fake)
    pt = torch.exp(-logpt)

    if weights is None:
        weights = pt.detach()
        p = weights
    else:
        # Detach before accumulating: the weights only modulate the loss and
        # must not keep the autograd graph of this call alive.
        weights = torch.cat([weights, pt.detach()], 1)
        p = torch.mean(weights, 1)
    modulator = (1 - p.view(len(dis_fake), 1)) ** gamma
    loss = torch.mean(modulator * logpt)

    # step 2: loss for classifying
    target = y.view(y.size(0), 1).float()
    batch_size = target.size(0)
    pt_cat = (1. - target) * (1 - out_category) + target * out_category
    # Clamp so that a confidently wrong prediction (pt_cat == 0) gives a large
    # but finite loss instead of inf, which would turn training into NaN.
    logpt_cat = -torch.log(pt_cat.clamp_min(eps))

    if cat_weight is None:
        cat_weight = pt_cat.detach()
        p = cat_weight
    else:
        cat_weight = torch.cat([cat_weight, pt_cat.detach()], 1)
        p = torch.mean(cat_weight, 1)
    modulator = (1 - p.view(batch_size, 1)) ** gamma
    loss_cat = torch.mean(modulator * logpt_cat)
    return loss, loss_cat, weights, cat_weight
def loss_dis_real(dis_real, out_category, y, weights=None, cat_weight=None, gamma=2.0, eps=1e-12):
    """Focal-style discriminator losses for a batch of real samples.

    :param dis_real: discriminator logits for the real batch; the code views
        the derived probabilities as ``(dis_real.shape[0], 1)``
    :param out_category: classifier output used directly as the probability of
        class 1 (assumed to lie in [0, 1] -- TODO confirm against the model)
    :param y: class labels (0/1), reshaped to ``(batch, 1)``
    :param weights: probabilities ``pt`` returned by previous calls (one
        column per call), or ``None`` for the first call
    :param cat_weight: same as ``weights`` but for the classification branch
    :param gamma: exponent of the ``(1 - p) ** gamma`` modulating factor
    :param eps: lower bound applied to the class probability before the log
    :return: ``(loss, loss_cat, weights, cat_weight)`` where the two weight
        tensors have this call's (detached) probabilities appended as a column
    """
    # step 1: the loss for GAN
    logpt = F.softplus(-dis_real)
    pt = torch.exp(-logpt)

    if weights is None:
        weights = pt.detach()
        p = weights
    else:
        # Detach before accumulating: the weights only modulate the loss and
        # must not keep the autograd graph of this call alive.
        weights = torch.cat([weights, pt.detach()], 1)
        p = torch.mean(weights, 1)
    modulator = (1 - p.view(dis_real.shape[0], 1)) ** gamma
    loss = torch.mean(modulator * logpt)

    # step 2: loss for classifying
    target = y.view(y.size(0), 1).float()
    batch_size = target.size(0)
    pt_cat = (1. - target) * (1 - out_category) + target * out_category
    # Clamp so that a confidently wrong prediction (pt_cat == 0) gives a large
    # but finite loss instead of inf, which would turn training into NaN.
    logpt_cat = -torch.log(pt_cat.clamp_min(eps))

    if cat_weight is None:
        cat_weight = pt_cat.detach()
        p = cat_weight
    else:
        cat_weight = torch.cat([cat_weight, pt_cat.detach()], 1)
        p = torch.mean(cat_weight, 1)
    modulator = (1 - p.view(batch_size, 1)) ** gamma
    loss_cat = torch.mean(modulator * logpt_cat)
    return loss, loss_cat, weights, cat_weight
def load_dataset(dataset_name, data_path, normal_class, known_outlier_class, n_known_outlier_classes: int = 0,
                 ratio_known_normal: float = 0.0, ratio_known_outlier: float = 0.0, ratio_pollution: float = 0.0,
                 random_state=None, resolution=32):
    """Build the dataset object selected by ``dataset_name``.

    'mnist', 'fmnist' and 'cifar10' map to their dedicated image dataset
    classes; every other name is handed to ``ODDSADDataset``. Note that
    ``random_state`` is only forwarded on the ODDS path, and ``normal_class``,
    ``known_outlier_class`` and ``resolution`` only on the image path.
    """
    # Arguments common to every dataset constructor.
    semi_supervised = dict(
        n_known_outlier_classes=n_known_outlier_classes,
        ratio_known_normal=ratio_known_normal,
        ratio_known_outlier=ratio_known_outlier,
        ratio_pollution=ratio_pollution,
    )

    image_datasets = {
        'mnist': MNIST_Dataset,
        'fmnist': FashionMNIST_Dataset,
        'cifar10': CIFAR10_Dataset,
    }
    image_cls = image_datasets.get(dataset_name)

    if image_cls is None:
        # Fallback: treat the name as an ODDS benchmark dataset.
        return ODDSADDataset(root=data_path,
                             dataset_name=dataset_name,
                             random_state=random_state,
                             **semi_supervised)

    return image_cls(root=data_path,
                     normal_class=normal_class,
                     known_outlier_class=known_outlier_class,
                     resolution=resolution,
                     **semi_supervised)
class CallbackContext(object):
    """Empty attribute holder handed to replication callbacks."""


def execute_replication_callbacks(modules):
    """
    Invoke the `__data_parallel_replicate__` hook on every sub-module of every replica.

    One `CallbackContext` is created per sub-module position of the first replica
    (`modules[0]`). The sub-module found at position `j` of replica `i` is called as
    `__data_parallel_replicate__(ctxs[j], i)`, so sub-modules at the same position in
    different replicas receive the same context object.

    Replicas are visited in list order, so the hooks of the first replica run before
    those of any later replica. Sub-modules without the hook are ignored.
    """
    submodule_count = sum(1 for _ in modules[0].modules())
    shared_contexts = [CallbackContext() for _ in range(submodule_count)]

    for copy_id, replica in enumerate(modules):
        for position, submodule in enumerate(replica.modules()):
            if not hasattr(submodule, '__data_parallel_replicate__'):
                continue
            submodule.__data_parallel_replicate__(shared_contexts[position], copy_id)
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object so that its `replicate`
    also runs the replication callbacks.

    The instance's current `replicate` is wrapped: the wrapper calls it, passes the
    returned replicas to `execute_replication_callbacks`, and returns them unchanged.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """
    assert isinstance(data_parallel, DataParallel)

    wrapped_replicate = data_parallel.replicate

    def replicate_then_notify(module, device_ids):
        replicas = wrapped_replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas

    # Same effect as decorating with @functools.wraps(wrapped_replicate).
    data_parallel.replicate = functools.wraps(wrapped_replicate)(replicate_then_notify)
class CallbackContext(object):
    """Empty attribute holder handed to replication callbacks."""


def execute_replication_callbacks(modules):
    """
    Invoke the `__data_parallel_replicate__` hook on every sub-module of every replica.

    One `CallbackContext` is created per sub-module position of the first replica
    (`modules[0]`). The sub-module found at position `j` of replica `i` is called as
    `__data_parallel_replicate__(ctxs[j], i)`, so sub-modules at the same position in
    different replicas receive the same context object.

    Replicas are visited in list order, so the hooks of the first replica run before
    those of any later replica. Sub-modules without the hook are ignored.
    """
    submodule_count = sum(1 for _ in modules[0].modules())
    shared_contexts = [CallbackContext() for _ in range(submodule_count)]

    for copy_id, replica in enumerate(modules):
        for position, submodule in enumerate(replica.modules()):
            if not hasattr(submodule, '__data_parallel_replicate__'):
                continue
            submodule.__data_parallel_replicate__(shared_contexts[position], copy_id)
def create_semisupervised_setting(labels, normal_classes, outlier_classes, known_outlier_classes,
                                  ratio_known_normal, ratio_known_outlier, ratio_pollution):
    """
    Create a semi-supervised data setting.

    The number of samples per group is obtained by solving a 4x4 linear system built
    from the three ratios and the number of normal samples; each solution component is
    truncated to an int. Samples are then drawn without replacement through
    `np.random.permutation` (global NumPy RNG). If a pool holds fewer samples than
    requested, all of its samples are used and the returned lists simply get shorter;
    the three lists always have the same length.

    :param labels: np.array with labels of all dataset samples
    :param normal_classes: tuple with normal class labels
    :param outlier_classes: tuple with anomaly class labels
    :param known_outlier_classes: tuple with known (labeled) anomaly class labels
    :param ratio_known_normal: the desired ratio of known (labeled) normal samples
    :param ratio_known_outlier: the desired ratio of known (labeled) anomalous samples
    :param ratio_pollution: the desired pollution ratio of the unlabeled data with unknown (unlabeled) anomalies.
    :return: tuple with list of sample indices, list of original labels, and list of semi-supervised labels
        (1 for known normal, 0 for unlabeled, -1 for known outlier)
    """
    idx_normal = np.argwhere(np.isin(labels, normal_classes)).flatten()
    idx_outlier = np.argwhere(np.isin(labels, outlier_classes)).flatten()
    idx_known_outlier_candidates = np.argwhere(np.isin(labels, known_outlier_classes)).flatten()

    n_normal = len(idx_normal)

    # Solve system of linear equations to obtain respective number of samples
    a = np.array([[1, 1, 0, 0],
                  [(1-ratio_known_normal), -ratio_known_normal, -ratio_known_normal, -ratio_known_normal],
                  [-ratio_known_outlier, -ratio_known_outlier, -ratio_known_outlier, (1-ratio_known_outlier)],
                  [0, -ratio_pollution, (1-ratio_pollution), 0]])
    b = np.array([n_normal, 0, 0, 0])
    x = np.linalg.solve(a, b)

    # Get number of samples
    n_known_normal = int(x[0])
    n_unlabeled_normal = int(x[1])
    n_unlabeled_outlier = int(x[2])
    n_known_outlier = int(x[3])

    # Sample indices (the three permutations are drawn in this fixed order so that a
    # seeded global RNG keeps producing the same selection)
    perm_normal = np.random.permutation(n_normal)
    perm_outlier = np.random.permutation(len(idx_outlier))
    perm_known_outlier = np.random.permutation(len(idx_known_outlier_candidates))

    # (selected indices, semi-supervised label) per group, in output order
    groups = (
        (idx_normal[perm_normal[:n_known_normal]], 1),
        (idx_normal[perm_normal[n_known_normal:n_known_normal+n_unlabeled_normal]], 0),
        (idx_outlier[perm_outlier[:n_unlabeled_outlier]], 0),
        (idx_known_outlier_candidates[perm_known_outlier[:n_known_outlier]], -1),
    )

    # Create final lists
    list_idx, list_labels, list_semi_labels = [], [], []
    for selected, semi_label in groups:
        list_idx += selected.tolist()
        list_labels += labels[selected].tolist()
        # Size the semi labels by what was actually sampled, not by the requested
        # count; a pool smaller than the request must not desynchronise the lists.
        list_semi_labels += [semi_label] * len(selected)

    return list_idx, list_labels, list_semi_labels
class MyCIFAR10(CIFAR10):
    """
    Torchvision CIFAR10 dataset extended with a `semi_targets` tensor.

    `__getitem__` is overridden so that each item also carries the sample's
    semi-supervised target and its index.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # One int64 entry per sample, initialised to 0.
        self.semi_targets = torch.zeros(len(self.targets), dtype=torch.int64)

    def __getitem__(self, index):
        """Return the sample stored at `index`.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target, semi_target, index)
        """
        raw = self.data[index]
        target = self.targets[index]
        semi_target = int(self.semi_targets[index])

        # Convert to a PIL Image before the transforms are applied, consistent
        # with the other datasets.
        img = Image.fromarray(raw)

        transform = self.transform
        if transform is not None:
            img = transform(img)

        target_transform = self.target_transform
        if target_transform is not None:
            target = target_transform(target)

        return img, target, semi_target, index
class MyMNIST(MNIST):
    """
    Torchvision MNIST dataset extended with a `semi_targets` tensor.

    `__getitem__` is overridden so that each item also carries the sample's
    semi-supervised target and its index.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Same shape and dtype as `targets`, initialised to 0.
        self.semi_targets = torch.zeros_like(self.targets)

    def __getitem__(self, index):
        """Return the sample stored at `index`.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target, semi_target, index)
        """
        raw = self.data[index]
        target = int(self.targets[index])
        semi_target = int(self.semi_targets[index])

        # Convert to a PIL Image (8-bit greyscale, mode 'L') before the
        # transforms are applied, consistent with the other datasets.
        img = Image.fromarray(raw.numpy(), mode='L')

        transform = self.transform
        if transform is not None:
            img = transform(img)

        target_transform = self.target_transform
        if target_transform is not None:
            target = target_transform(target)

        return img, target, semi_target, index
list(known_outlier_class) 24 | self.outlier_classes = tuple([known_outlier_class]) 25 | 26 | if n_known_outlier_classes == 0: 27 | self.known_outlier_classes = () 28 | elif n_known_outlier_classes == 1: 29 | self.known_outlier_classes = tuple([known_outlier_class]) 30 | else: 31 | self.known_outlier_classes = tuple(random.sample(self.outlier_classes, n_known_outlier_classes)) 32 | 33 | # FashionMNIST preprocessing: feature scaling to [0, 1] 34 | transform = transforms.Compose([ 35 | transforms.Resize(resolution+2), 36 | transforms.CenterCrop(resolution), 37 | transforms.ToTensor() 38 | ]) 39 | #transform = transforms.ToTensor() 40 | target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes)) 41 | 42 | 43 | # Get train set 44 | train_set = MyFashionMNIST(root=self.root, train=True, transform=transform, target_transform=target_transform, 45 | download=True) 46 | ''' 47 | # Create semi-supervised setting 48 | idx, _, semi_targets = create_semisupervised_setting(train_set.targets.cpu().data.numpy(), self.normal_classes, 49 | self.outlier_classes, self.known_outlier_classes, 50 | ratio_known_normal, ratio_known_outlier, ratio_pollution) 51 | train_set.semi_targets[idx] = torch.tensor(semi_targets) # set respective semi-supervised labels 52 | 53 | # Subset train_set to semi-supervised setup 54 | self.train_set = Subset(train_set, idx) 55 | ''' 56 | self.train_set = train_set 57 | # Get test set 58 | self.test_set = MyFashionMNIST(root=self.root, train=False, transform=transform, 59 | target_transform=target_transform, download=True) 60 | 61 | 62 | class MyFashionMNIST(FashionMNIST): 63 | """ 64 | Torchvision FashionMNIST class with additional targets for the semi-supervised setting and patch of __getitem__ 65 | method to also return the semi-supervised target as well as the index of a data sample. 
class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe.

    `put` stores a single result and wakes the waiting consumer; `get` blocks
    until a result has been stored, returns it, and clears the slot.
    """

    def __init__(self):
        self._result = None
        # Set by `put`, cleared by `get`.  Tracked separately from `_result`
        # so that the wait loop does not depend on the stored value.
        self._ready = False
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store `result` and notify the consumer.

        Raises:
            AssertionError: if the previously stored result is still pending.
        """
        with self._lock:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._ready = True
            self._cond.notify()

    def get(self):
        """Block until a result is available, then return and clear it."""
        with self._lock:
            # `Condition.wait` may return without a matching `put` (spurious
            # or stray wake-up), so the predicate must be re-checked in a loop.
            while not self._ready:
                self._cond.wait()

            self._ready = False
            res = self._result
            self._result = None
            return res
90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 
124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /EAL-GAN/models/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 
29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 
90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 
class ODDSDataset(Dataset):
    """
    ODDSDataset class for datasets from Outlier Detection DataSets (ODDS): http://odds.cs.stonybrook.edu/

    Dataset class with additional targets for the semi-supervised setting and modification of __getitem__ method
    to also return the semi-supervised target as well as the index of a data sample.
    """

    # Download locations, keyed by dataset name (the file is stored as '<name>.mat').
    urls = {
        'arrhythmia': 'https://www.dropbox.com/s/lmlwuspn1sey48r/arrhythmia.mat?dl=1',
        'cardio': 'https://www.dropbox.com/s/galg3ihvxklf0qi/cardio.mat?dl=1',
        'satellite': 'https://www.dropbox.com/s/dpzxp8jyr9h93k5/satellite.mat?dl=1',
        'satimage-2': 'https://www.dropbox.com/s/hckgvu9m6fs441p/satimage-2.mat?dl=1',
        'shuttle': 'https://www.dropbox.com/s/mk8ozgisimfn3dw/shuttle.mat?dl=1',
        'thyroid': 'https://www.dropbox.com/s/bih0e15a0fukftb/thyroid.mat?dl=1'
    }

    def __init__(self, root: str, dataset_name: str, train=True, random_state=None, download=False):
        """Load '<root>/<dataset_name>.mat', split it 60/40 and keep one side.

        Args:
            root: directory holding (or receiving) the ``.mat`` file.
            dataset_name: file stem; must be a key of ``urls`` when ``download`` is set.
            train: keep the training split if True, otherwise the test split.
            random_state: forwarded to ``train_test_split``.
            download: fetch the file first if it is not already present.
        """
        # `super(Dataset, self)` would start the MRO lookup *after* Dataset and
        # skip its initializer; the zero-argument form starts at the right place.
        super().__init__()

        self.classes = [0, 1]

        # A plain `str` check replaces the private `torch._six.string_classes`,
        # which is not available in current PyTorch releases.
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = Path(root)
        self.dataset_name = dataset_name
        self.train = train  # training set or test set
        self.file_name = self.dataset_name + '.mat'
        self.data_file = self.root / self.file_name

        if download:
            self.download()

        mat = loadmat(self.data_file)
        X = mat['X']
        y = mat['y'].ravel()

        # 60% of the data for training and 40% for testing. The split is NOT
        # stratified, so the outlier ratio of the two parts may differ.
        X_train, X_test, y_train, y_test = \
            train_test_split(X, y, test_size=0.4, random_state=random_state)

        # Standardize data (per feature Z-normalization, i.e. zero-mean and unit variance);
        # the scaler is fitted on the training part only.
        scaler = StandardScaler().fit(X_train)
        X_train_scaled = scaler.transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        # Replace NaNs in both parts with the per-feature mean of the training part.
        X_train_pandas = pd.DataFrame(X_train_scaled)
        X_test_pandas = pd.DataFrame(X_test_scaled)
        X_train_pandas.fillna(X_train_pandas.mean(), inplace=True)
        X_test_pandas.fillna(X_train_pandas.mean(), inplace=True)
        X_train_scaled = X_train_pandas.values
        X_test_scaled = X_test_pandas.values

        if self.train:
            self.data = torch.tensor(X_train_scaled, dtype=torch.float32)
            self.targets = torch.tensor(y_train, dtype=torch.int64)
        else:
            self.data = torch.tensor(X_test_scaled, dtype=torch.float32)
            self.targets = torch.tensor(y_test, dtype=torch.int64)

        # Every sample starts with semi-supervised target 0.
        self.semi_targets = torch.zeros_like(self.targets)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target, semi_target, index)
        """
        sample, target, semi_target = self.data[index], int(self.targets[index]), int(self.semi_targets[index])

        return sample, target, semi_target, index

    def __len__(self):
        return len(self.data)

    def _check_exists(self):
        """Return True if the dataset file is already present under ``root``."""
        return os.path.exists(self.data_file)

    def download(self):
        """Download the ODDS dataset if it doesn't exist in root already."""

        if self._check_exists():
            return

        # download file (raises KeyError for a dataset name without a known URL)
        download_url(self.urls[self.dataset_name], self.root, self.file_name)

        print('Done!')
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | from time import time 8 | 9 | sys.path.append( 10 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 11 | # supress warnings for clean output 12 | import warnings 13 | 14 | warnings.filterwarnings("ignore") 15 | 16 | import numpy as np 17 | import pandas as pd 18 | from sklearn.model_selection import train_test_split 19 | from scipy.io import loadmat 20 | 21 | import torch 22 | 23 | from models.pyod_utils import standardizer, AUC_and_Gmean 24 | from models.pyod_utils import precision_n_scores, gmean_scores 25 | from models.utils import parse_args 26 | from sklearn.metrics import roc_auc_score 27 | from models.EAL_GAN import EAL_GAN 28 | 29 | if __name__=='__main__': 30 | # Define data file and read X and y 31 | data_root = './data' 32 | save_dir = './results' 33 | 34 | 35 | data_names = [#'arrhythmia.mat', 36 | # 'cardio.mat', 37 | # 'glass.mat', 38 | # 'ionosphere.mat', 39 | # 'letter.mat', 40 | # 'lympho.mat', 41 | # 'mnist.mat', 42 | # 'musk.mat', 43 | 'optdigits.mat', 44 | # 'pendigits.mat', 45 | # 'pima.mat', 46 | # 'satellite.mat', 47 | # 'satimage-2.mat', 48 | # 'shuttle.mat', 49 | # 'vowels.mat', 50 | # 'annthyroid.mat', 51 | # 'campaign.mat', 52 | # 'celeba.mat', 53 | 'fraud.mat', 54 | 'donors.mat' 55 | ] 56 | # define the number of iterations 57 | n_ite = 10 58 | n_classifiers = 1 59 | 60 | df_columns = ['Data', '#Samples', '# Dimensions', 'Outlier Perc', 61 | 'AUC_Mean', 'AUC_Std', 'Gmean', 'Gmean_Std'] 62 | 63 | # initialize the container for saving the results 64 | roc_df = pd.DataFrame(columns=df_columns) 65 | gmean_df = pd.DataFrame(columns=df_columns) 66 | anomaly_ratio_df = pd.DataFrame(columns=df_columns) 67 | overall_ratio_df = pd.DataFrame(columns=df_columns) 68 | time_df = pd.DataFrame(columns=df_columns) 69 
| args = parse_args() 70 | 71 | for data_name in data_names: 72 | mat = loadmat(os.path.join(data_root, data_name)) 73 | 74 | X = mat['X'] 75 | y = mat['y'].ravel() 76 | y = y.astype(np.long) 77 | idx_norm = y == 0 78 | idx_out = y == 1 79 | 80 | outliers_fraction = np.count_nonzero(y) / len(y) 81 | print(outliers_fraction) 82 | outliers_percentage = round(outliers_fraction * 100, ndigits=4) 83 | 84 | # construct containers for saving results 85 | roc_list = [data_name[:-4], X.shape[0], X.shape[1], outliers_percentage] 86 | gmean_list = [data_name[:-4], X.shape[0], X.shape[1], outliers_percentage] 87 | anomaly_ratio_list = [data_name[:-4], X.shape[0], X.shape[1], outliers_percentage] 88 | overall_ratio_list = [data_name[:-4], X.shape[0], X.shape[1], outliers_percentage] 89 | 90 | roc_mat = np.zeros([n_ite, n_classifiers]) 91 | gmean_mat = np.zeros([n_ite, n_classifiers]) 92 | anomaly_ratio_mat = np.zeros([n_ite, n_classifiers]) 93 | overall_ratio_mat = np.zeros([n_ite, n_classifiers]) 94 | 95 | for i in range(n_ite): 96 | print("\n... 
Processing", data_name[:-4], '...', 'Iteration', i + 1) 97 | random_state = np.random.RandomState(i) 98 | 99 | # 60% data for training and 40% for testing; keep outlier ratio 100 | 101 | X_train_norm, X_test_norm, y_train_norm, y_test_norm = train_test_split(X[idx_norm], y[idx_norm], 102 | test_size=0.4, 103 | random_state=random_state) 104 | X_train_out, X_test_out, y_train_out, y_test_out = train_test_split(X[idx_out], y[idx_out], 105 | test_size=0.4, 106 | random_state=random_state) 107 | X_train = np.concatenate((X_train_norm, X_train_out)) 108 | X_test = np.concatenate((X_test_norm, X_test_out)) 109 | y_train = np.concatenate((y_train_norm, y_train_out)) 110 | y_test = np.concatenate((y_test_norm, y_test_out)) 111 | 112 | X_train_norm, X_test_norm = standardizer(X_train, X_test) 113 | 114 | X_train_pandas = pd.DataFrame(X_train_norm) 115 | X_test_pandas = pd.DataFrame(X_test_norm) 116 | X_train_pandas.fillna(X_train_pandas.mean(), inplace=True) 117 | X_test_pandas.fillna(X_train_pandas.mean(), inplace=True) 118 | X_train_norm = X_train_pandas.values 119 | X_test_norm = X_test_pandas.values 120 | 121 | #X_train_norm, X_test_norm = X_train, X_test 122 | data_x = torch.from_numpy(X_train_norm).float() 123 | data_y = torch.from_numpy(y_train).long() 124 | test_x = torch.from_numpy(X_test_norm).float() 125 | test_y = torch.from_numpy(y_test).long() 126 | #print(data_x) 127 | 128 | t0 = time() 129 | #todo: add my method 130 | eal_gan = EAL_GAN(args, data_x, data_y, test_x, test_y) 131 | best_auc, best_gmean = eal_gan.fit() 132 | 133 | t1 = time() 134 | duration = round(t1 - t0, ndigits=4) 135 | 136 | 137 | print('AUC:%.4f, Gmean:%.4f execution time: %.4f s' % (best_auc, best_gmean, duration)) 138 | 139 | roc_mat[i, 0] = best_auc 140 | gmean_mat[i, 0] = best_gmean 141 | 142 | roc_list = roc_list + np.mean(roc_mat, axis=0).tolist() + np.std(roc_mat, axis=0).tolist() + np.mean(gmean_mat, axis=0).tolist() + np.std(gmean_mat, axis=0).tolist() 143 | temp_df = 
pd.DataFrame(roc_list).transpose() 144 | temp_df.columns = df_columns 145 | roc_df = pd.concat([roc_df, temp_df], axis=0) 146 | 147 | 148 | # Save the results for each run 149 | save_path1 = os.path.join(save_dir, "AUC_EAL_GAN.csv") 150 | 151 | roc_df.to_csv(save_path1, index=False, float_format='%.4f') 152 | -------------------------------------------------------------------------------- /EAL-GAN-image/src/my_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import roc_auc_score 4 | 5 | import argparse 6 | 7 | def active_sampling(args, real_x, real_y, NetD_Ensemble, need_sample=True): 8 | if need_sample: 9 | pt = None 10 | for i in range(args.ensemble_num): 11 | netD = NetD_Ensemble[i] 12 | pt_i = netD(real_x, mode=1) #get the confidence on real data 13 | if i==0: 14 | pt = pt_i.detach() 15 | else: 16 | pt += pt_i.detach() 17 | pt /= args.ensemble_num 18 | pt = pt.view(pt.shape[0],) 19 | batch_size_selected = int(real_x.shape[0]*args.active_rate) 20 | batch_size_selected = max(1, batch_size_selected) 21 | pt = torch.abs(pt-0.5) # select the instance with low margin value 22 | _, idx = torch.sort(pt, descending=False) 23 | X = real_x[idx[0:batch_size_selected]].detach() 24 | Y = real_y[idx[0:batch_size_selected]].detach() 25 | X_unlabeled = real_x[idx[batch_size_selected:]].detach() 26 | Y = Y.view(Y.shape[0],) 27 | #X = X.view(X.shape[0], -1) 28 | return X, Y, X_unlabeled, idx[0:batch_size_selected] 29 | else: 30 | batch_size_selected = int(real_x.shape[0]*args.active_rate) 31 | batch_size_selected = max(1, batch_size_selected) 32 | X = real_x[0:batch_size_selected].detach() 33 | Y = real_y[0:batch_size_selected].detach() 34 | X_unlabeled = None 35 | idx = torch.tensor(np.arange(batch_size_selected)).view(-1,).long() 36 | return X, Y, X_unlabeled, idx 37 | 38 | def get_gmean(y, y_pred, threshold=0.5): 39 | """Utility function to calculate precision @ rank n. 
def _str2bool(value):
    """Convert a command-line token such as 'true' / 'False' / '0' to a bool."""
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r.' % (value,))


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Parameters
    ----------
    argv : list of str, optional (default=None)
        Argument strings to parse. When None, argparse reads ``sys.argv[1:]``,
        which is the behaviour of the original zero-argument call.

    Returns
    -------
    args : argparse.Namespace
        The parsed settings.
    """
    ################################################################################
    # Settings
    ################################################################################
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset_name', default='cardio', type=str)
    parser.add_argument('--net_name', default='cardio_mlp', type=str)
    parser.add_argument('--xp_path', default='./log', type=str)
    parser.add_argument('--data_path', default='./data', type=str)
    parser.add_argument('--load_config', default=None,
                        help='Config JSON-file path (default: None).')
    parser.add_argument('--load_model', default=None,
                        help='Model file path (default: None).')
    parser.add_argument('--eta', type=float, default=1.0, help='Deep SAD hyperparameter eta (must be 0 < eta).')
    parser.add_argument('--ratio_known_normal', type=float, default=0.9,
                        help='Ratio of known (labeled) normal training examples.')
    parser.add_argument('--ratio_known_outlier', type=float, default=0.5,
                        help='Ratio of known (labeled) anomalous training examples.')
    parser.add_argument('--ratio_pollution', type=float, default=0.0,
                        help='Pollution ratio of unlabeled training data with unknown (unlabeled) anomalies.')
    parser.add_argument('--device', type=str, default='cuda', help='Computation device to use ("cpu", "cuda", "cuda:2", etc.).')
    parser.add_argument('--seed', type=int, default=0, help='Set seed. If -1, use randomization.')
    parser.add_argument('--weight_decay', type=float, default=1e-6,
                        help='Weight decay (L2 penalty) hyperparameter for Deep SAD objective.')

    parser.add_argument('--n_jobs_dataloader', type=int, default=0,
                        help='Number of workers for data loading. 0 means that the data will be loaded in the main process.')
    parser.add_argument('--normal_class', type=int, default=0,
                        help='Specify the normal class of the dataset (all other classes are considered anomalous).')
    parser.add_argument('--known_outlier_class', type=int, default=1,
                        help='Specify the known outlier class of the dataset for semi-supervised anomaly detection.')
    # The help fragments previously ran together without separating spaces.
    parser.add_argument('--n_known_outlier_classes', type=int, default=1,
                        help='Number of known outlier classes. '
                             'If 0, no anomalies are known. '
                             'If 1, outlier class as specified in --known_outlier_class option. '
                             'If > 1, the specified number of outlier classes will be sampled at random.')

    # parameter for GAN
    parser.add_argument('--channel', type=int, default=64,
                        help='capacity of the generator and discriminator')
    parser.add_argument('--resolution', type=int, default=32,
                        help='resolution of the images')
    parser.add_argument('--gamma', type=float, default=0.1,
                        help='gamma for the scheduler')
    parser.add_argument('--step_size', type=int, default=10,
                        help='step_size for the scheduler to adjust learning rate')
    parser.add_argument('--max_epochs', type=int, default=20,
                        help='Stop training generator after stop_epochs.')
    parser.add_argument('--lr_g', type=float, default=0.0001,
                        help='Learning rate of generator.')
    parser.add_argument('--lr_d', type=float, default=0.0001,
                        help='Learning rate of discriminator.')
    parser.add_argument('--active_rate', type=float, default=0.05,
                        help='the proportion of instances that need to be labeled.')
    parser.add_argument('--batch_size', type=int, default=200,
                        help='batch size.')
    parser.add_argument('--dim_z', type=int, default=128,
                        help='dim for latent noise.')
    parser.add_argument('--ensemble_num', type=int, default=5,
                        help='the number of dis in ensemble.')
    # `type=bool` would turn every non-empty string -- including "False" --
    # into True, so the token is parsed explicitly instead.
    parser.add_argument('--cuda', type=_str2bool, default=True,
                        help='if GPU used')
    parser.add_argument('--feat_dim', type=int, default=3,
                        help='channel number of images')

    return parser.parse_args(argv)
class EAL_GAN(nn.Module):
    """Conditional GAN with an ensemble of discriminators trained on actively sampled data.

    One generator and ``args.ensemble_num`` discriminators are created, each
    with its own Adam optimizer and StepLR scheduler. ``fit`` trains for
    ``args.max_epochs`` epochs and returns the test AUC / G-mean of the epoch
    with the best ``train_auc * train_gmean``.
    """

    def __init__(self, args, dataset):
        """
        Args:
            args: parsed settings; reads ``cuda``, ``channel``, ``dim_z``,
                ``resolution``, ``feat_dim``, ``lr_g``, ``lr_d``, ``step_size``,
                ``gamma`` and ``ensemble_num`` here.
            dataset: object whose ``loaders(batch_size=..., num_workers=...)``
                returns ``(train_loader, test_loader)``.
        """
        super().__init__()
        self.args = args
        self.dataset = dataset
        self.device = 'cuda' if args.cuda else 'cpu'

        self.generator = Generator(G_ch=args.channel, dim_z=args.dim_z, resolution=args.resolution,
                                   n_classes=2, output_channel=args.feat_dim)
        self.optimizer_g = Adam(self.generator.parameters(), lr=args.lr_g, betas=(0.00, 0.99))
        self.scheduler_g = torch.optim.lr_scheduler.StepLR(self.optimizer_g, step_size=args.step_size,
                                                           gamma=args.gamma)
        if args.cuda:
            self.generator = nn.DataParallel(self.generator).cuda()

        # Plain Python lists: the discriminators are not registered as
        # sub-modules, their state dicts are saved explicitly in fit().
        self.NetD_Ensemble = []
        self.Opti_Ensemble = []
        self.schD_Ensemble = []
        # NOTE(review): these per-discriminator learning rates are drawn but
        # every optimizer below uses args.lr_d -- confirm which one is intended.
        # The draw is kept so the global NumPy RNG stream is unchanged.
        lr_ds = np.random.rand(args.ensemble_num)*(args.lr_d*5-args.lr_d)+args.lr_d  # learning rate
        for index in range(args.ensemble_num):
            netD = Discriminator(D_ch=args.channel, resolution=args.resolution, n_classes=2,
                                 input_channel=args.feat_dim)
            optimizerD = Adam(netD.parameters(), lr=args.lr_d, betas=(0.00, 0.99))
            schedule_D = torch.optim.lr_scheduler.StepLR(optimizerD, step_size=args.step_size, gamma=args.gamma)
            if args.cuda:
                netD = nn.DataParallel(netD).cuda()

            self.NetD_Ensemble += [netD]
            self.Opti_Ensemble += [optimizerD]
            self.schD_Ensemble += [schedule_D]

    def fit(self):
        """Train for ``args.max_epochs`` epochs.

        Whenever ``train_auc * train_gmean`` improves, the generator and all
        discriminators are checkpointed to ``./logs/checkpoint_best.pth``.

        Returns:
            tuple: (test AUC, test G-mean) of the best epoch; (0.0, 0.0) if no
            epoch ever improved on the initial score of 0.
        """
        import os  # local import: only needed to create the checkpoint directory

        self.iterations = 0

        self.z, self.y = utils.prepare_z_y(self.args.batch_size, dim_z=self.args.dim_z, nclasses=2,
                                           device=self.device)
        z_for_sample, y_for_sample = utils.prepare_z_y(self.args.batch_size, dim_z=self.args.dim_z, nclasses=2,
                                                       device=self.device)
        self.sample = functools.partial(utils.sample, G=self.generator, z_=z_for_sample, y_=y_for_sample)
        self.train_history = defaultdict(list)

        checkpoint_path = './logs/checkpoint_best.pth'
        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

        Best_Measure_Recorded = 0
        # Initialised so the return below cannot hit an unbound name when no
        # epoch beats the initial score (e.g. NaN/zero metrics or max_epochs=0).
        Best_AUC, Best_Gmean = 0.0, 0.0
        for epoch in range(self.args.max_epochs):

            auc_train, gmean_train, auc_test, gmean_test = self.train_one_epoch(epoch)
            if auc_train*gmean_train > Best_Measure_Recorded:
                Best_Measure_Recorded = auc_train*gmean_train
                Best_AUC = auc_test
                Best_Gmean = gmean_test

                states = {
                    'epoch': epoch,
                    'gen_dict': self.generator.state_dict(),
                    'auc_train': auc_train,
                    'auc_test': auc_test
                }
                for i in range(self.args.ensemble_num):
                    netD = self.NetD_Ensemble[i]
                    states['dis_dict'+str(i)] = netD.state_dict()

                torch.save(states, checkpoint_path)
            print('Training for epoch %d: Train_AUC=%.4f train_Gmean=%.4f Test_AUC=%.4f Test_Gmean=%.4f'
                  % (epoch + 1, auc_train, gmean_train, auc_test, gmean_test))

        return Best_AUC, Best_Gmean

    def train_one_epoch(self, epoch=1):
        """Run one pass over the training loader, then evaluate on train and test data.

        For every batch each discriminator is updated on the actively sampled
        real data and on detached generated data; afterwards the generator is
        updated against all discriminators. ``epoch`` is not used.

        Returns:
            tuple: (train AUC, train G-mean, test AUC, test G-mean)
        """
        train_loader, test_loader = self.dataset.loaders(batch_size=self.args.batch_size, num_workers=0)
        for (inputs, targets, _semi_targets, _) in tqdm(train_loader):
            self.iterations += 1
            # step 1: update the discriminators
            if self.args.cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()

            self.z.sample_()
            self.y.sample_()
            generated_x = self.generator(self.z, self.y)
            generated_x = generated_x.detach()

            # Weight state threaded through the loss calls across the ensemble.
            real_weights = None
            fake_weights = None
            real_cat_weights = None
            fake_cat_weights = None

            dis_loss = 0.0
            # select p% of the training data, label them (the very first batch
            # is taken in order, before the discriminators have been updated)
            real_x_selected, real_y_selected, _, index_selected = active_sampling(
                self.args, inputs, targets, self.NetD_Ensemble, need_sample=(self.iterations > 1))

            for i in range(self.args.ensemble_num):
                optimizer = self.Opti_Ensemble[i]
                netD = self.NetD_Ensemble[i]
                out_real_fake, out_real_categoy = netD(real_x_selected, real_y_selected)

                loss_adv_real, real_loss_cat, real_weights, real_cat_weights = loss_dis_real(
                    out_real_fake, out_real_categoy, real_y_selected, real_weights, real_cat_weights)
                real_loss = loss_adv_real+real_loss_cat

                # train on fake data
                output, out_fake_category = netD(generated_x, self.y.detach())
                loss_adv_fake, loss_cat_fake, fake_weights, fake_cat_weights = loss_dis_fake(
                    output, out_fake_category, self.y, fake_weights, fake_cat_weights)
                fake_loss = loss_adv_fake + loss_cat_fake
                sum_loss = real_loss+fake_loss

                optimizer.zero_grad()
                sum_loss.backward()
                optimizer.step()

                # Record plain floats: keeping the loss tensors alive in the
                # history makes memory grow over the whole run.
                sum_loss_value = sum_loss.item()
                dis_loss += sum_loss_value
                self.train_history['discriminator_loss_'+str(i)].append(sum_loss_value)
            self.train_history['discriminator_loss'].append(dis_loss)

            # train the generator
            self.z.sample_()
            self.y.sample_()
            generated_x = self.generator(self.z, self.y)

            gen_loss = 0
            gen_weights = None
            gen_cat_weights = None
            for i in range(self.args.ensemble_num):
                netD = self.NetD_Ensemble[i]
                output, out_category = netD(generated_x, self.y)
                # The generator is scored with the "real" loss.
                loss, loss_cat, gen_weights, gen_cat_weights = loss_dis_real(
                    output, out_category, self.y, gen_weights, gen_cat_weights)
                gen_loss += (loss+loss_cat)
            self.optimizer_g.zero_grad()
            gen_loss.backward()
            self.optimizer_g.step()
            gen_loss_value = gen_loss.item()
            self.train_history['generator_loss'].append(gen_loss_value)

            if self.iterations % 10 == 0:
                # The two values were previously printed under each other's label.
                print("dis_loss=%.4f gen_loss=%.4f" % (dis_loss, gen_loss_value))

        # NOTE(review): the LR schedulers built in __init__ are never stepped --
        # confirm whether that is intended.
        if self.args.cuda:
            torch.cuda.empty_cache()

        y_train, y_scores = self.predict(train_loader, self.NetD_Ensemble)
        # NaN scores are replaced by 0 before computing the metrics.
        y_scores_pandas = pd.DataFrame(y_scores)
        y_scores_pandas.fillna(0, inplace=True)
        y_scores = y_scores_pandas.values

        auc, gmean = AUC_and_Gmean(y_train, y_scores)
        self.train_history['train_auc'].append(auc)
        self.train_history['train_Gmean'].append(gmean)

        y_test, y_scores_test = self.predict(test_loader, self.NetD_Ensemble)
        y_scores_pandas = pd.DataFrame(y_scores_test)
        y_scores_pandas.fillna(0, inplace=True)
        y_scores_test = y_scores_pandas.values

        test_auc, test_gmean = AUC_and_Gmean(y_test, y_scores_test)

        return auc, gmean, test_auc, test_gmean

    def predict(self, data_loader, dis_Ensemble=None):
        """Score every sample of ``data_loader`` with the discriminator ensemble.

        Args:
            data_loader: yields ``(x, y, semi_target, index)`` batches.
            dis_Ensemble: discriminators to use. When None, ``self.Best_Ensemble``
                is used if it has been set, otherwise ``self.NetD_Ensemble``.

        Returns:
            tuple: (labels, scores) as NumPy arrays concatenated over all
            batches; a score is the mean ``mode=1`` output of the ensemble.
        """
        if dis_Ensemble is None:
            # Best_Ensemble is never assigned in this class (its loading code
            # is disabled), so fall back to the live ensemble.
            dis_Ensemble = getattr(self, 'Best_Ensemble', self.NetD_Ensemble)

        p = []
        y = []
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for (real_x, real_y, _, index) in tqdm(data_loader):
                if self.args.cuda:
                    real_x = real_x.cuda()
                final_pt = 0
                for i in range(self.args.ensemble_num):
                    pt = dis_Ensemble[i](real_x, mode=1).cpu().detach().numpy()
                    final_pt = pt if i == 0 else final_pt + pt

                final_pt /= self.args.ensemble_num

                p += [final_pt]
                y += [real_y.cpu().detach().numpy()]
        p = np.concatenate(p, 0)
        y = np.concatenate(y, 0)
        return y, p
MAX_INT = np.iinfo(np.int32).max
MIN_INT = -1 * MAX_INT


def check_parameter(param, low=MIN_INT, high=MAX_INT, param_name='',
                    include_left=False, include_right=False):
    """Check if an input is within the defined range.

    Parameters
    ----------
    param : int, float
        The input parameter to check.

    low : int, float
        The lower bound of the range.

    high : int, float
        The higher bound of the range.

    param_name : str, optional (default='')
        The name of the parameter.

    include_left : bool, optional (default=False)
        Whether includes the lower bound (lower bound <=).

    include_right : bool, optional (default=False)
        Whether includes the higher bound (<= higher bound).

    Returns
    -------
    within_range : bool or raise errors
        True when the parameter is within the requested range.

    Raises
    ------
    TypeError
        If ``param``, ``low`` or ``high`` is not numerical.
    ValueError
        If both bounds are left at their defaults, if ``low > high``, or if
        ``param`` lies outside the range.
    """
    # `np.float` (an alias of the builtin float) was removed in NumPy 1.24,
    # which made every call fail with AttributeError. The builtin is used
    # directly; np.floating additionally admits NumPy scalars like float32.
    numeric_types = (numbers.Integral, np.integer, float, np.floating)

    # param, low and high should all be numerical
    if not isinstance(param, numeric_types):
        raise TypeError('{param_name} is set to {param} Not numerical'.format(
            param=param, param_name=param_name))

    if not isinstance(low, numeric_types):
        raise TypeError('low is set to {low}. Not numerical'.format(low=low))

    if not isinstance(high, numeric_types):
        raise TypeError('high is set to {high}. Not numerical'.format(
            high=high))

    # at least one of the bounds should be specified; the defaults act as
    # sentinels, so identity (not equality) is the intended comparison
    if low is MIN_INT and high is MAX_INT:
        raise ValueError('Neither low nor high bounds is defined')

    # if wrong bound values are used
    if low > high:
        raise ValueError(
            'Lower bound > Higher bound')

    # value check under the requested open/closed bound combination
    below = param < low if include_left else param <= low
    above = param > high if include_right else param >= high
    if below or above:
        raise ValueError(
            '{param_name} is set to {param}. '
            'Not in the range of {left}{low}, {high}{right}.'.format(
                param=param, low=low, high=high, param_name=param_name,
                left='[' if include_left else '(',
                right=']' if include_right else ')'))

    return True
def score_to_label(pred_scores, outliers_fraction=0.1):
    """Binarize raw outlier scores using a percentile cut-off.

    Parameters
    ----------
    pred_scores : list or numpy array of shape (n_samples,)
        Raw outlier scores. Larger values are treated as more outlying.

    outliers_fraction : float in (0,1)
        Fraction of samples expected to be outliers; the cut-off is the
        ``100 * (1 - outliers_fraction)`` percentile of the scores.

    Returns
    -------
    outlier_labels : numpy array of shape (n_samples,)
        Integer labels: 1 where the score is strictly above the cut-off,
        0 otherwise.
    """
    # validate the inputs first: scores as a 1-d array, fraction inside (0, 1)
    scores = column_or_1d(pred_scores)
    check_parameter(outliers_fraction, 0, 1)

    cutoff = percentile(scores, 100 * (1 - outliers_fraction))
    return (scores > cutoff).astype('int')
def get_gmean(y, y_pred, threshold=0.5):
    """Compute the G-mean of a thresholded score vector.

    Scores are binarised with ``y_pred >= threshold``. The result is
    ``sqrt(TPR * TNR)``, where TPR is the fraction of samples with ``y == 1``
    predicted as 1 and TNR is the fraction of samples with ``y == 0``
    predicted as 0.

    Parameters
    ----------
    y : list or numpy array of shape (n_samples,) or (n_samples, 1)
        The ground truth. Binary (0: inliers, 1: outliers).

    y_pred : list or numpy array with the same number of elements as ``y``
        The raw outlier scores as returned by a fitted model.

    threshold : float, optional (default=0.5)
        Scores greater than or equal to this value are labelled 1.

    Returns
    -------
    Gmean : float
        Geometric mean of TPR and TNR. If ``y`` contains no 1s or no 0s the
        corresponding rate is 0/0 and the result is NaN.
    """
    # np.asarray lets plain Python lists through (the original called
    # ``.reshape`` directly and therefore only worked on arrays).
    y = np.asarray(y).reshape(-1,)
    y_pred = np.asarray(y_pred).reshape(-1,)
    y_pred = (y_pred >= threshold).astype('int')

    ones_all = (y == 1).sum()
    ones_correct = ((y == 1) & (y_pred == 1)).sum()
    zeros_all = (y == 0).sum()
    zeros_correct = ((y == 0) & (y_pred == 0)).sum()

    Gmean = np.sqrt((1.0 * ones_correct / ones_all)
                    * (1.0 * zeros_correct / zeros_all))
    return Gmean
def power_iteration(W, u_, update=True, eps=1e-12):
    """Run one power-iteration step for every vector in ``u_`` against ``W``.

    For each ``u`` the function computes ``v = normalize(u @ W)`` and then
    ``u = normalize(v @ W.t())``, removing components along the vectors found
    earlier in the same call via ``gram_schmidt``. It then evaluates the
    product ``v @ W.t() @ u.t()`` as the singular-value estimate.

    Parameters
    ----------
    W : tensor
        2-D matrix to analyse (assumes ``u @ W`` is a valid matmul for every
        ``u`` in ``u_`` -- TODO confirm shapes against the caller).
    u_ : sequence of tensors
        Current left-vector estimates, one per singular value requested.
    update : bool
        When true, each ``u_[i]`` is overwritten in place with the new ``u``.
    eps : float
        Epsilon forwarded to ``F.normalize``.

    Returns
    -------
    (svs, us, vs) : tuple of lists
        Singular-value estimates, updated ``u`` vectors and ``v`` vectors, in
        the same order as ``u_``.
    """
    # Lists holding singular vectors and values
    us, vs, svs = [], [], []
    for i, u in enumerate(u_):
        # Run one step of the power iteration
        with torch.no_grad():
            v = torch.matmul(u, W)
            # Run Gram-Schmidt to subtract components of all other singular vectors
            v = F.normalize(gram_schmidt(v, vs), eps=eps)
            # Add to the list
            vs += [v]
            # Update the other singular vector
            u = torch.matmul(v, W.t())
            # Run Gram-Schmidt to subtract components of all other singular vectors
            u = F.normalize(gram_schmidt(u, us), eps=eps)
            # Add to the list
            us += [u]
            if update:
                # In-place write so the caller's tensor keeps the new estimate.
                u_[i][:] = u
        # Compute this singular value and add it to the list
        # NOTE(review): indentation was reconstructed from a flattened listing;
        # this line is placed outside the no_grad block so the estimate stays
        # differentiable w.r.t. W -- confirm against the upstream file.
        svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
        #svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
    return svs, us, vs
class SNEmbedding(nn.Embedding, SN):
    """Embedding layer whose weight is spectrally normalized.

    ``num_embeddings`` (rather than ``embedding_dim``) is handed to ``SN`` as
    the size of the tracked singular vectors, for convenience.

    The forward pass calls ``F.embedding`` with only the indices and the
    normalized weight, so ``padding_idx``, ``max_norm``, ``norm_type``,
    ``scale_grad_by_freq`` and ``sparse`` are stored by ``nn.Embedding`` but
    not forwarded to the lookup.
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                 sparse=False, _weight=None,
                 num_svs=1, num_itrs=1, eps=1e-12):
        # Set up the ordinary embedding table first ...
        nn.Embedding.__init__(
            self,
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
        )
        # ... then register the spectral-norm buffers on top of it.
        SN.__init__(self, num_svs=num_svs, num_itrs=num_itrs,
                    num_outputs=num_embeddings, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.embedding(x, normalized_weight)
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    """Normalize ``x`` with given statistics as one scale-and-shift.

    Computes ``x * scale - shift`` where ``scale = rsqrt(var + eps)``
    (times ``gain`` when given) and ``shift = mean * scale`` (minus ``bias``
    when given). This is the fused form of
    ``(x - mean) / sqrt(var + eps) * gain + bias``.

    Parameters
    ----------
    x : tensor
        Input to normalize.
    mean, var : tensor
        Statistics; they must broadcast against ``x``.
    gain, bias : tensor or None
        Optional multiplicative gain and additive bias folded into the
        scale and shift.
    eps : float
        Added to ``var`` before the reciprocal square root.

    Returns
    -------
    tensor
        The normalized (and optionally scaled/shifted) input.
    """
    inv_std = torch.rsqrt(var + eps)
    # Fold the gain into the multiplier when one is supplied.
    multiplier = inv_std if gain is None else inv_std * gain
    offset = mean * multiplier
    # Fold the bias into the offset (subtracted later, hence the minus).
    if bias is not None:
        offset = offset - bias
    return x * multiplier - offset
200 | var = (m2 - m **2) 201 | # Cast back to float 16 if necessary 202 | var = var.type(x.type()) 203 | m = m.type(x.type()) 204 | # Return mean and variance for updating stored mean/var if requested 205 | if return_mean_var: 206 | return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze() 207 | else: 208 | return fused_bn(x, m, var, gain, bias, eps) 209 | 210 | 211 | # My batchnorm, supports standing stats 212 | class myBN(nn.Module): 213 | def __init__(self, num_channels, eps=1e-5, momentum=0.1): 214 | super(myBN, self).__init__() 215 | # momentum for updating running stats 216 | self.momentum = momentum 217 | # epsilon to avoid dividing by 0 218 | self.eps = eps 219 | # Momentum 220 | self.momentum = momentum 221 | # Register buffers 222 | self.register_buffer('stored_mean', torch.zeros(num_channels)) 223 | self.register_buffer('stored_var', torch.ones(num_channels)) 224 | self.register_buffer('accumulation_counter', torch.zeros(1)) 225 | # Accumulate running means and vars 226 | self.accumulate_standing = False 227 | 228 | # reset standing stats 229 | def reset_stats(self): 230 | self.stored_mean[:] = 0 231 | self.stored_var[:] = 0 232 | self.accumulation_counter[:] = 0 233 | 234 | def forward(self, x, gain, bias): 235 | if self.training: 236 | out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps) 237 | # If accumulating standing stats, increment them 238 | if self.accumulate_standing: 239 | self.stored_mean[:] = self.stored_mean + mean.data 240 | self.stored_var[:] = self.stored_var + var.data 241 | self.accumulation_counter += 1.0 242 | # If not accumulating standing stats, take running averages 243 | else: 244 | self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum 245 | self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum 246 | return out 247 | # If not in training mode, use the stored statistics 248 | else: 249 | mean = self.stored_mean.view(1, -1, 1, 1) 250 | 
class ccbn(nn.Module):
    """Class-conditional normalization layer.

    Two layers built with ``which_linear(input_size, output_size)`` map the
    conditioning vector ``y`` to a per-sample, per-channel gain (offset by 1)
    and bias. ``x`` is normalized according to the configured backend and
    then modulated by that gain and bias.

    Parameters
    ----------
    output_size : int
        Number of channels of ``x``.
    input_size : int
        Size of the conditioning vector ``y``.
    which_linear : callable
        Factory called as ``which_linear(input_size, output_size)``.
    eps : float
        Epsilon used by the normalization.
    momentum : float
        Momentum for the running statistics.
    cross_replica : bool
        Use ``SyncBN2d`` as the backend.
    mybn : bool
        Use ``myBN`` as the backend (only when ``cross_replica`` is false).
    norm_style : str
        One of ``'bn'``, ``'in'``, ``'gn'`` or ``'nonorm'``; used when neither
        ``cross_replica`` nor ``mybn`` is set.
    """

    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style='bn',):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Norm style?
        self.norm_style = norm_style

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ['bn', 'in']:
            # Running statistics consumed by F.batch_norm / F.instance_norm.
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        """Normalize ``x`` and modulate it with gain/bias predicted from ``y``.

        Raises
        ------
        ValueError
            If ``norm_style`` is not one of the supported values and neither
            ``mybn`` nor ``cross_replica`` is enabled.
        """
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # The dedicated BN modules apply gain and bias themselves.
        if self.mybn or self.cross_replica:
            return self.bn(x, gain=gain, bias=bias)

        if self.norm_style == 'bn':
            # Use the configured momentum instead of a hard-coded 0.1 so the
            # constructor argument is honoured (the default is unchanged).
            out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                               self.training, self.momentum, self.eps)
        elif self.norm_style == 'in':
            out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                  self.training, self.momentum, self.eps)
        elif self.norm_style == 'gn':
            # The attribute is ``norm_style``; the previous ``self.normstyle``
            # raised AttributeError whenever this branch ran.
            out = groupnorm(x, self.norm_style)
        elif self.norm_style == 'nonorm':
            out = x
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(
                "Unsupported norm_style %r; expected 'bn', 'in', 'gn' or 'nonorm'."
                % (self.norm_style,))
        return out * gain + bias

    def extra_repr(self):
        s = 'out: {output_size}, in: {input_size},'
        s +=' cross_replica={cross_replica}'
        return s.format(**self.__dict__)
346 | self.cross_replica = cross_replica 347 | # Use my batchnorm? 348 | self.mybn = mybn 349 | 350 | if self.cross_replica: 351 | self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False) 352 | elif mybn: 353 | self.bn = myBN(output_size, self.eps, self.momentum) 354 | # Register buffers if neither of the above 355 | else: 356 | self.register_buffer('stored_mean', torch.zeros(output_size)) 357 | self.register_buffer('stored_var', torch.ones(output_size)) 358 | 359 | def forward(self, x, y=None): 360 | if self.cross_replica or self.mybn: 361 | gain = self.gain.view(1,-1,1,1) 362 | bias = self.bias.view(1,-1,1,1) 363 | return self.bn(x, gain=gain, bias=bias) 364 | else: 365 | return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain, 366 | self.bias, self.training, self.momentum, self.eps) 367 | -------------------------------------------------------------------------------- /EAL-GAN-image/src/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
class _SynchronizedBatchNorm(_BatchNorm):
    """Batch norm whose training statistics are reduced across replicas.

    When the module has been replicated (``__data_parallel_replicate__`` was
    called) and is in training mode, every replica sends its per-channel sum
    and square-sum to the master copy, which computes the mean and inverse
    standard deviation and broadcasts them back. Otherwise the module falls
    back to ``F.batch_norm``.

    ``forward`` additionally accepts an optional ``gain`` and ``bias`` that
    are applied as ``normalized * gain + bias``.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        # Populated by __data_parallel_replicate__.
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input, gain=None, bias=None):
        """Normalize ``input`` and optionally apply ``gain`` and ``bias``.

        Parameters
        ----------
        input : tensor
            Viewed as ``(B, C, -1)`` on the synchronized path.
        gain, bias : tensor or None
            Optional modulation; on the synchronized path their last
            dimension is squeezed so they broadcast against ``(B, C, -1)``.

        Returns
        -------
        tensor
            Output with the same shape as ``input``.
        """
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            out = F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)
            if gain is not None:
                # The gain is multiplicative, exactly as on the synchronized
                # path below; this branch previously *added* it, so the two
                # paths produced different outputs.
                out = out * gain
            if bias is not None:
                out = out + bias
            return out

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), input.size(1), -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        message = _ChildMessage(input_sum, input_ssum, sum_size)
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(message)
        else:
            mean, inv_std = self._slave_pipe.run_slave(message)

        # Compute the output.
        if gain is not None:
            output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1))
            # Guard the bias so a gain-only call does not fail on None.squeeze.
            if bias is not None:
                output = output + bias.squeeze(-1)
        elif self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Mark this copy as a replica and wire it to the master's sync object."""
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it.

        ``intermediates`` is a sequence of ``(identifier, _ChildMessage)``
        pairs; the return value is a list of
        ``(identifier, _MasterMessage)`` pairs in device order.
        """

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            # Two broadcast tensors (mean, inv_std) per device.
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        # Running variance tracks the unbiased estimate; the returned
        # inverse std uses the biased one.
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        return mean, torch.rsqrt(bias_var + self.eps)
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Synchronized batch normalization for 4D inputs (mini-batches of 3D inputs).

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    Unlike the built-in PyTorch BatchNorm2d, the mean and standard deviation
    are reduced across all devices during training. Under `nn.DataParallel`
    the built-in layer normalizes each device's slice with that device's own
    statistics, which is fast and simple but can be inaccurate; this version
    computes the statistics over all training samples on all devices.

    When the module is not replicated across devices, or is in evaluation
    mode, it falls back to PyTorch's own batch-norm function.

    Statistics are computed per channel over `(N, H, W)` slices (Spatial
    BatchNorm); gamma and beta are learnable vectors of size C. During
    training a running estimate of mean and variance is kept (default
    momentum 0.1) and used for normalization during evaluation.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer
            learnable affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        # Reject anything that is not (N, C, H, W) before deferring to the parent check.
        ndim = input.dim()
        if ndim != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(ndim))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
312 | 313 | During training, this layer keeps a running estimate of its computed mean 314 | and variance. The running sum is kept with a default momentum of 0.1. 315 | 316 | During evaluation, this running mean/variance is used for normalization. 317 | 318 | Because the BatchNorm is done over the `C` dimension, computing statistics 319 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 320 | or Spatio-temporal BatchNorm 321 | 322 | Args: 323 | num_features: num_features from an expected input of 324 | size batch_size x num_features x depth x height x width 325 | eps: a value added to the denominator for numerical stability. 326 | Default: 1e-5 327 | momentum: the value used for the running_mean and running_var 328 | computation. Default: 0.1 329 | affine: a boolean value that when set to ``True``, gives the layer learnable 330 | affine parameters. Default: ``True`` 331 | 332 | Shape: 333 | - Input: :math:`(N, C, D, H, W)` 334 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 335 | 336 | Examples: 337 | >>> # With Learnable Parameters 338 | >>> m = SynchronizedBatchNorm3d(100) 339 | >>> # Without Learnable Parameters 340 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 341 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 342 | >>> output = m(input) 343 | """ 344 | 345 | def _check_input_dim(self, input): 346 | if input.dim() != 5: 347 | raise ValueError('expected 5D input (got {}D input)' 348 | .format(input.dim())) 349 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /EAL-GAN/models/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
def _sum_ft(tensor):
    """Sum over the first and last dimensions."""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """Add new dimensions at the front and the tail."""
    return tensor.unsqueeze(0).unsqueeze(-1)


# Per-replica statistics handed to the master copy.
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
# Global statistics handed back to every replica.  The first field is called
# ``sum`` but `_data_parallel_master` fills it with the mean.
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    """Batch normalization whose training statistics are reduced across replicas.

    While the module is not marked as parallel (or is in evaluation mode) it
    defers to ``F.batch_norm``.  Once `__data_parallel_replicate__` has marked
    a copy as parallel, each copy sends its per-channel sum and square-sum
    through the ``SyncMaster`` pipe and normalizes with the shared result.

    ``forward`` additionally accepts an optional multiplicative ``gain`` and
    additive ``bias`` supplied by the caller.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input, gain=None, bias=None):
        """Normalize ``input`` and apply the optional ``gain`` / ``bias``.

        Args:
            input: tensor whose dim 0 is the batch and dim 1 the channels.
            gain: optional scale applied to the normalized activations.
            bias: optional shift applied after the scale.

        Returns:
            A tensor with the same shape as ``input``.
        """
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            out = F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)
            if gain is not None:
                # The gain is a scale: the synchronized branch below multiplies
                # by it, so this branch must multiply too (it used to add).
                out = out * gain
            if bias is not None:
                out = out + bias
            return out

        # NOTE(review): with ``affine=True`` the two branches still differ --
        # the branch above applies weight/bias *and* gain/bias, while the one
        # below uses gain/bias *instead of* weight/bias.  Left as is.

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), input.size(1), -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        message = _ChildMessage(input_sum, input_ssum, sum_size)
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(message)
        else:
            mean, inv_std = self._slave_pipe.run_slave(message)

        # Compute the output.
        if gain is not None:
            # ``squeeze(-1)`` drops one trailing axis so the caller's tensors
            # line up with the (B, C, -1) view of the input.
            output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1))
            if bias is not None:
                # A gain without a bias is valid; previously this crashed on
                # ``None.squeeze``.
                output = output + bias.squeeze(-1)
        elif self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Mark this copy as parallel and wire it to the master's pipe.

        NOTE(review): presumably invoked by the package's replication
        callback; the caller is not part of this file.
        """
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it.

        Args:
            intermediates: sequence of ``(identifier, _ChildMessage)`` pairs,
                one per replica.

        Returns:
            A list of ``(identifier, _MasterMessage)`` pairs in device order.
        """
        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        # Flatten to [sum_0, ssum_0, sum_1, ssum_1, ...].
        to_reduce = [stat for _, msg in intermediates for stat in msg[:2]]
        target_gpus = [msg.sum.get_device() for _, msg in intermediates]

        sum_size = sum(msg.sum_size for _, msg in intermediates)
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        # ``broadcasted`` is flat: two tensors (mean, inv_std) per device.
        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and inverse standard-deviation from sum and square-sum.

        Also updates ``running_mean`` / ``running_var``; the running variance
        uses the unbiased estimate while the returned inverse std uses the
        biased one.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        # Detached so the running statistics never join the autograd graph.
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.detach()
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.detach()
        return mean, torch.rsqrt(bias_var + self.eps)
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly
    the same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` has exactly four dimensions."""
        # The base-class hook is intentionally not chained:
        # `_SynchronizedBatchNorm` does not override it, so the call would
        # land on torch's `_BatchNorm._check_input_dim`, an abstract stub that
        # raises NotImplementedError on current PyTorch even for valid input.
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
312 | 313 | During training, this layer keeps a running estimate of its computed mean 314 | and variance. The running sum is kept with a default momentum of 0.1. 315 | 316 | During evaluation, this running mean/variance is used for normalization. 317 | 318 | Because the BatchNorm is done over the `C` dimension, computing statistics 319 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 320 | or Spatio-temporal BatchNorm 321 | 322 | Args: 323 | num_features: num_features from an expected input of 324 | size batch_size x num_features x depth x height x width 325 | eps: a value added to the denominator for numerical stability. 326 | Default: 1e-5 327 | momentum: the value used for the running_mean and running_var 328 | computation. Default: 0.1 329 | affine: a boolean value that when set to ``True``, gives the layer learnable 330 | affine parameters. Default: ``True`` 331 | 332 | Shape: 333 | - Input: :math:`(N, C, D, H, W)` 334 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 335 | 336 | Examples: 337 | >>> # With Learnable Parameters 338 | >>> m = SynchronizedBatchNorm3d(100) 339 | >>> # Without Learnable Parameters 340 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 341 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 342 | >>> output = m(input) 343 | """ 344 | 345 | def _check_input_dim(self, input): 346 | if input.dim() != 5: 347 | raise ValueError('expected 5D input (got {}D input)' 348 | .format(input.dim())) 349 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /EAL-GAN/models/EAL_GAN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import functools 4 | import random 5 | import os 6 | from collections import defaultdict 7 | import matplotlib.pyplot as plt 8 | import matplotlib.font_manager 9 | 10 | from numpy import 
class Generator(nn.Module):
    """Label-conditioned MLP generator.

    The class label is embedded into a ``dim_z // 2`` vector, concatenated
    with the noise vector and pushed through ``input_fc -> ReLU ->
    hidden_number x (Linear -> ReLU) -> output_fc``.

    Args:
        dim_z: size of the noise vector.
        hidden_dim: width of every hidden layer.
        output_dim: size of a generated sample.
        n_classes: number of label values the embedding supports.
        hidden_number: number of extra hidden Linear+ReLU pairs.
        init: weight initialisation scheme: ``'ortho'``, ``'N02'``,
            ``'glorot'`` or ``'xavier'``.
        SN_used: build the input/output layers with spectrally normalized
            ``SNLinear`` instead of ``nn.Linear``.
    """

    def __init__(self, dim_z=64, hidden_dim=128, output_dim=128, n_classes=2, hidden_number=1,
                 init='ortho', SN_used=True):
        super(Generator, self).__init__()
        self.hidden_dim = hidden_dim

        self.n_classes = n_classes
        self.init = init
        self.shared_dim = dim_z // 2

        if SN_used:
            self.which_linear = functools.partial(SNLinear,
                                                  num_svs=1, num_itrs=1)
        else:
            self.which_linear = nn.Linear
        # The spectrally normalized embedding was disabled in the original, so
        # a plain embedding is the only variant.  Binding it unconditionally
        # guarantees the attribute exists when ``SN_used`` is true.
        self.which_embedding = nn.Embedding

        self.shared = self.which_embedding(n_classes, self.shared_dim)
        self.input_fc = self.which_linear(dim_z + self.shared_dim, self.hidden_dim)

        self.output_fc = self.which_linear(self.hidden_dim, output_dim)

        self.model = nn.Sequential(self.input_fc,
                                   nn.ReLU())

        # NOTE(review): the hidden layers are always plain ``nn.Linear``, even
        # when ``SN_used`` is true.  Kept as is so parameter names and
        # existing checkpoints stay valid.
        for index in range(hidden_number):
            middle_fc = nn.Linear(self.hidden_dim, self.hidden_dim)
            self.model.add_module('hidden-layers-{0}'.format(index), middle_fc)
            self.model.add_module('ReLu-{0}'.format(index), nn.ReLU())

        self.model.add_module('output_layer', self.output_fc)

        self.init_weights()

    def init_weights(self):
        """Initialise every Linear/Embedding weight and record ``param_count``."""
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum(p.data.nelement() for p in module.parameters())
        # The original 'G''s' was two adjacent literals and printed "Gs".
        print("Param count for G's initialized parameters: %d" % self.param_count)

    def forward(self, z, y):
        """Generate samples from noise ``z`` conditioned on integer labels ``y``.

        ``y`` is embedded and concatenated to ``z`` along dim 1 before the MLP.
        """
        label_embedding = self.shared(y)
        h = torch.cat([z, label_embedding], 1)
        return self.model(h)
class EAL_GAN(nn.Module):
    """Conditional GAN with an ensemble of discriminators.

    Builds one ``Generator`` and ``args.ensemble_num`` ``Discriminator``
    networks, each with its own Adam optimizer, and trains them on
    ``data_x`` / ``data_y``.  The category head of the discriminators
    (``mode=2``) is averaged over the ensemble to score samples.

    Attributes read from ``args``: ``cuda``, ``dim_z``, ``batch_size``,
    ``lr_g``, ``lr_d``, ``gen_layer``, ``dis_layer``, ``init_type``,
    ``SN_used``, ``ensemble_num``, ``max_epochs``, ``data_name``, ``print``.
    """

    def __init__(self, args, data_x, data_y, test_x, test_y, visualize=False):
        super(EAL_GAN, self).__init__()

        self.device = torch.device("cuda" if args.cuda else "cpu")
        z, y = prepare_z_y(20, args.dim_z, 2, device=self.device)
        y = y * 0 + 1
        self.y = y.long()
        self.noise = z

        self.args = args
        self.data_x = data_x
        self.data_y = data_y
        self.test_x = test_x
        self.test_y = test_y
        self.iterations = 0
        self.visualize = visualize

        self.feature_size = data_x.shape[1]
        self.data_size = data_x.shape[0]
        self.batch_size = min(args.batch_size, self.data_size)
        self.hidden_dim = self.feature_size * 2
        self.dim_z = args.dim_z

        manualSeed = random.randint(1, 10000)
        random.seed(manualSeed)
        torch.manual_seed(manualSeed)

        # The original hard-coded device_ids=[0, 1], which fails on a host
        # with a single GPU.  Use at most those two devices, limited to what
        # is actually present.
        device_ids = list(range(min(2, torch.cuda.device_count()))) if args.cuda else []

        # 1: prepare Generator
        self.netG = Generator(dim_z=self.dim_z, hidden_dim=self.hidden_dim, output_dim=self.feature_size, n_classes=2,
                              hidden_number=args.gen_layer, init=args.init_type, SN_used=args.SN_used)
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=args.lr_g, betas=(0.00, 0.99))

        if args.cuda:
            self.netG = nn.DataParallel(self.netG, device_ids=device_ids)
        self.netG = self.netG.to(self.device)

        # 2: create ensemble of discriminator
        # NOTE: kept as plain lists (not nn.ModuleList), so the discriminators
        # are not registered as sub-modules of this module.
        self.NetD_Ensemble = []
        self.opti_Ensemble = []
        # One learning rate per member, drawn uniformly from [lr_d, 5 * lr_d).
        lr_ds = np.random.rand(args.ensemble_num) * (args.lr_d * 5 - args.lr_d) + args.lr_d
        for index in range(args.ensemble_num):
            netD = Discriminator(input_dim=self.feature_size, hidden_dim=self.hidden_dim, output_dim=1, n_classes=2,
                                 hidden_number=args.dis_layer, init=args.init_type, SN_used=args.SN_used)
            optimizerD = optim.Adam(netD.parameters(), lr=lr_ds[index], betas=(0.00, 0.99))
            if args.cuda:
                netD = nn.DataParallel(netD, device_ids=device_ids)
            netD = netD.to(self.device)

            self.NetD_Ensemble += [netD]
            self.opti_Ensemble += [optimizerD]

    def fit(self):
        """Train for ``args.max_epochs`` epochs and restore the best checkpoint.

        An epoch is "best" when ``train_Gmean * train_AUC`` exceeds every
        earlier epoch; its weights are written to
        ``./log/<data_name>/checkpoint_best.pth``.

        Returns:
            ``(best_auc, best_gmean)``: the test AUC and G-mean measured at
            the best epoch.
        """
        log_dir = os.path.join('./log/', self.args.data_name)
        os.makedirs(log_dir, exist_ok=True)
        checkpoint_path = os.path.join(log_dir, 'checkpoint_best.pth')

        z, y = prepare_z_y(self.batch_size, self.dim_z, 2, device=self.device)
        # Start iteration
        Best_Measure_Recorded = -1
        best_auc = 0
        best_gmean = 0
        checkpoint_written = False
        self.train_history = defaultdict(list)
        for epoch in range(self.args.max_epochs):
            train_AUC, train_Gmean, test_auc, test_gmean = self.train_one_epoch(z, y, epoch)
            if train_Gmean * train_AUC > Best_Measure_Recorded:
                Best_Measure_Recorded = train_Gmean * train_AUC
                best_auc = test_auc
                best_gmean = test_gmean
                states = {
                    'epoch': epoch,
                    'gen_dict': self.netG.state_dict(),
                    'opti_gen': self.optimizerG.state_dict(),
                    # Stored as a built-in float so the checkpoint holds only
                    # tensors and plain Python values.
                    'max_auc': float(train_AUC)
                }
                for i in range(self.args.ensemble_num):
                    states['dis_dict' + str(i)] = self.NetD_Ensemble[i].state_dict()
                    states['opti_dis' + str(i)] = self.opti_Ensemble[i].state_dict()

                torch.save(states, checkpoint_path)
                checkpoint_written = True
            if self.args.print:
                print('Training for epoch %d: Train_AUC=%.4f train_Gmean=%.4f Test_AUC=%.4f Test_Gmean=%.4f' % (
                    epoch + 1, train_AUC, train_Gmean, test_auc, test_gmean))

        # step 1: load the best models
        # Only reload when this call actually wrote the file.  Otherwise
        # (zero epochs, or a NaN measure on every epoch) a checkpoint left by
        # a previous run would silently replace the current weights.
        if checkpoint_written:
            states = torch.load(checkpoint_path, map_location=self.device)
            self.netG.load_state_dict(states['gen_dict'])
            for i in range(self.args.ensemble_num):
                self.NetD_Ensemble[i].load_state_dict(states['dis_dict' + str(i)])
        self.Best_Ensemble = list(self.NetD_Ensemble)

        return best_auc, best_gmean

    def predict(self, test_x, test_y, dis_Ensemble=None):
        """Score ``test_x`` with the ensemble-averaged category head.

        Args:
            test_x: samples, sliced along dim 0 in chunks of ``self.batch_size``.
            test_y: labels aligned with ``test_x``; returned unchanged as numpy.
            dis_Ensemble: discriminators to use; defaults to
                ``self.Best_Ensemble`` (set by `fit`).

        Returns:
            ``(y, p)``: numpy labels and the flat numpy array of mean
            ``mode=2`` outputs over the first ``args.ensemble_num`` members.
        """
        ensemble = self.Best_Ensemble if dis_Ensemble is None else dis_Ensemble
        n_members = self.args.ensemble_num
        data_size = test_x.shape[0]

        scores = []
        labels = []
        # Inference only: without no_grad every member after the first kept
        # its autograd graph alive while the scores were accumulated.
        with torch.no_grad():
            for start in range(0, data_size, self.batch_size):
                stop = min(data_size, start + self.batch_size)
                batch_x = test_x[start:stop]

                total = 0
                for i in range(n_members):
                    total = total + ensemble[i](batch_x, mode=2)

                scores.append((total / n_members).view(-1))
                labels.append(test_y[start:stop])

        p = torch.cat(scores, 0).cpu().detach().numpy()
        y = torch.cat(labels, 0).cpu().detach().numpy()
        return y, p

    @staticmethod
    def _nan_to_zero(scores):
        """Return ``scores`` as a 2-D array with NaN entries replaced by 0."""
        return pd.DataFrame(scores).fillna(0).values

    def train_one_epoch(self, z, y, epoch=1):
        """Run one pass over the training data and evaluate the ensemble.

        For each mini-batch every discriminator is updated on the actively
        selected real samples and on freshly generated samples, then the
        generator is updated against all discriminators.

        Args:
            z: noise holder; ``z.sample_()`` refreshes it in place.
            y: label holder; ``y.sample_()`` refreshes it in place.
            epoch: current epoch index (unused here).

        Returns:
            ``(train_auc, train_gmean, test_auc, test_gmean)``.
        """
        data_size = self.data_x.shape[0]
        batch_size = min(self.args.batch_size, data_size)
        num_batches = (data_size + batch_size - 1) // batch_size

        # NOTE(review): a permutation was drawn here in the original but never
        # applied (the shuffling lines were commented out), so batches are
        # taken in stored order.  The draw is kept only so the global RNG is
        # consumed exactly as before.
        _ = torch.randperm(data_size)

        x_empirical, y_empirical = [], []

        for index in range(num_batches):
            # step 1: train the ensemble of discriminator
            self.iterations += 1

            end_pos = min(data_size, (index + 1) * batch_size)
            real_x = self.data_x[index * batch_size: end_pos]
            real_y = self.data_y[index * batch_size: end_pos]

            real_weights = None
            fake_weights = None
            real_cat_weights = None
            fake_cat_weights = None

            z.sample_()
            y.sample_()
            generated_x = self.netG(z, y).detach()

            dis_loss = 0
            # select p% of the training data, label them
            real_x_selected, real_y_selected, _ = active_sampling_V1(self.args, real_x, real_y, self.NetD_Ensemble,
                                                                     need_sample=(self.iterations > 1))
            x_empirical += [real_x_selected]
            y_empirical += [real_y_selected]
            x_empirical += [generated_x]
            # Snapshot the labels: ``y`` is resampled in place below, so
            # storing the object itself left every recorded entry pointing at
            # the most recent draw instead of the labels that produced
            # ``generated_x``.
            y_empirical += [y.detach().clone()]

            for i in range(self.args.ensemble_num):
                optimizer = self.opti_Ensemble[i]
                netD = self.NetD_Ensemble[i]

                # train on real data
                out_real_fake, out_real_categoy = netD(real_x_selected, real_y_selected)
                loss1, real_loss_cat, real_weights, real_cat_weights = loss_dis_real(out_real_fake, out_real_categoy,
                                                                                     real_y_selected, real_weights,
                                                                                     real_cat_weights)
                real_loss = loss1 + real_loss_cat

                # train on fake data
                output, out_fake_category = netD(generated_x, y.detach())
                loss2, loss_cat_fake, fake_weights, fake_cat_weights = loss_dis_fake(output, out_fake_category, y,
                                                                                     fake_weights, fake_cat_weights)
                fake_loss = loss2 + loss_cat_fake
                sum_loss = real_loss + fake_loss

                # History keeps detached values; storing the live tensors
                # retained every batch's autograd graph for the whole run.
                dis_loss = dis_loss + sum_loss.detach()
                self.train_history['discriminator_loss_' + str(i)].append(sum_loss.detach())

                optimizer.zero_grad()
                # retain_graph is kept: the weights returned by the loss
                # helpers are fed to the next member, and whether they share
                # graph nodes is not visible from this file.
                sum_loss.backward(retain_graph=True)
                optimizer.step()
            self.train_history['discriminator_loss'].append(dis_loss)

            # step 2: train the generator
            z.sample_()
            y.sample_()
            generated_x = self.netG(z, y)
            gen_loss = 0
            gen_weights = None
            gen_cat_weights = None
            for i in range(self.args.ensemble_num):
                netD = self.NetD_Ensemble[i]
                output, out_category = netD(generated_x, y)
                # The generator is scored with the "real" discriminator loss.
                loss, loss_cat, gen_weights, gen_cat_weights = loss_dis_real(output, out_category, y, gen_weights,
                                                                             gen_cat_weights)
                gen_loss += (loss + loss_cat)
            self.train_history['generator_loss'].append(gen_loss.detach())
            self.optimizerG.zero_grad()
            gen_loss.backward()
            self.optimizerG.step()

        x_empirical = torch.cat(x_empirical, 0)
        y_empirical = torch.cat(y_empirical, 0)
        y_train, y_pred_train = self.predict(x_empirical, y_empirical, self.NetD_Ensemble)
        y_pred_train = self._nan_to_zero(y_pred_train)

        auc, gmean = AUC_and_Gmean(y_train, y_pred_train)
        self.train_history['train_auc'].append(auc)
        self.train_history['train_Gmean'].append(gmean)

        y_test, y_pred_test = self.predict(self.test_x, self.test_y, self.NetD_Ensemble)
        y_pred_test = self._nan_to_zero(y_pred_test)
        test_auc, test_gmean = AUC_and_Gmean(y_test, y_pred_test)

        return auc, gmean, test_auc, test_gmean
def power_iteration(W, u_, update=True, eps=1e-12):
    """Run one power-iteration step for every tracked singular vector of ``W``.

    Args:
        W: weight matrix (``SN.W_`` passes the weight flattened to 2-D).
        u_: list of current left singular-vector estimates, one row vector
            per tracked singular value.
        update: when True, write the refined estimates back into the tensors
            of ``u_`` in place.
        eps: epsilon forwarded to ``F.normalize`` to avoid division by zero.

    Returns:
        Tuple ``(svs, us, vs)``: the singular-value estimates and the refined
        left (``us``) and right (``vs``) singular-vector estimates.
    """
    # Lists holding singular vectors and values
    us, vs, svs = [], [], []
    for i, u in enumerate(u_):
        # The vector updates are bookkeeping only, so keep them out of autograd.
        with torch.no_grad():
            v = torch.matmul(u, W)
            # Run Gram-Schmidt to subtract components of all other singular vectors
            v = F.normalize(gram_schmidt(v, vs), eps=eps)
            vs += [v]
            # Update the other singular vector
            u = torch.matmul(v, W.t())
            # Run Gram-Schmidt to subtract components of all other singular vectors
            u = F.normalize(gram_schmidt(u, us), eps=eps)
            us += [u]
            if update:
                u_[i][:] = u
        # Compute this singular value (outside no_grad, so it stays
        # differentiable with respect to W) and add it to the list.
        svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
        # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
    return svs, us, vs


# Convenience passthrough function
class identity(nn.Module):
    """Module that returns its input unchanged."""

    def forward(self, input):
        return input


# Spectral normalization base class
class SN(object):
    """Mixin adding spectral normalization to a layer that owns ``self.weight``.

    It is meant to be combined with an ``nn.Module`` subclass (it relies on
    ``register_buffer``) and must be initialised after that module.
    """

    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        # Number of power iterations per step
        self.num_itrs = num_itrs
        # Number of singular values
        self.num_svs = num_svs
        # Transposed?
        self.transpose = transpose
        # Epsilon value for avoiding divide-by-0
        self.eps = eps
        # Register a singular vector for each sv
        for i in range(self.num_svs):
            self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
            self.register_buffer('sv%d' % i, torch.ones(1))

    # Singular vectors (u side)
    @property
    def u(self):
        """List of the registered ``u<i>`` buffers."""
        return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]

    # Singular values;
    # note that these buffers are just for logging and are not used in training.
    @property
    def sv(self):
        """List of the registered ``sv<i>`` buffers."""
        return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]

    # Compute the spectrally-normalized weight
    def W_(self):
        """Return ``self.weight`` divided by its first singular-value estimate.

        NOTE(review): with ``num_itrs == 0`` no estimate is computed and the
        final line fails because ``svs`` is unbound.
        """
        W_mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            W_mat = W_mat.t()
        # Apply num_itrs power iterations
        for _ in range(self.num_itrs):
            svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
        # Update the svs
        if self.training:
            # Make sure to do this in a no_grad() context or you'll get memory leaks!
            with torch.no_grad():
                for i, sv in enumerate(svs):
                    self.sv[i][:] = sv
        return self.weight / svs[0]


# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
    """``nn.Conv2d`` whose weight is spectrally normalized on every forward."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
                           padding, dilation, groups, bias)
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        return F.conv2d(x, self.W_(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
    """``nn.Linear`` whose weight is spectrally normalized on every forward."""

    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features, bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        return F.linear(x, self.W_(), self.bias)


# Embedding layer with spectral norm
# We use num_embeddings as the dim instead of embedding_dim here
# for convenience sake
class SNEmbedding(nn.Embedding, SN):
    """``nn.Embedding`` whose weight is spectrally normalized on every forward."""

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                 sparse=False, _weight=None,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
                              max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight)
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        return F.embedding(x, self.W_())


# A non-local block as used in SA-GAN
# Note that the implementation as described in the paper is largely incorrect;
# refer to the released code for the actual implementation.
class Attention(nn.Module):
    """Self-attention block; the output is ``gamma * attention(x) + x``."""

    def __init__(self, ch, which_conv=SNConv2d, name='attention'):
        super(Attention, self).__init__()
        # Channel multiplier
        self.ch = ch
        self.which_conv = which_conv
        self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
        self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
        # Learnable gain parameter; starts at 0 so the block is initially an identity
        self.gamma = P(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        """Apply attention to ``x``; ``y`` is accepted but unused."""
        # Apply convs; phi and g are max-pooled 2x2, which quarters their positions
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), [2, 2])
        g = F.max_pool2d(self.g(x), [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
        return self.gamma * o + x


# Fused batchnorm op
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    """Normalize ``x`` with the given statistics, folding in gain and bias.

    Equivalent to ``((x - mean) / sqrt(var + eps)) * gain + bias``, with
    ``gain`` and ``bias`` skipped when they are None.
    """
    # Prepare scale
    scale = torch.rsqrt(var + eps)
    # If a gain is provided, use it
    if gain is not None:
        scale = scale * gain
    # Prepare shift
    shift = mean * scale
    # If bias is provided, use it
    if bias is not None:
        shift = shift - bias
    return x * scale - shift
    # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.


# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    """Batch-normalize ``x`` over dimensions 0, 2 and 3 using batch statistics.

    Returns the normalized tensor, or ``(out, mean, var)`` with the statistics
    squeezed when ``return_mean_var`` is True.
    """
    # Cast x to float32 if necessary
    float_x = x.float()
    # Mean of x
    m = torch.mean(float_x, [0, 2, 3], keepdim=True)
    # Mean of x squared
    m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
    # Calculate variance as mean of squared minus mean squared.
    var = (m2 - m ** 2)
    # Cast back to the input type (e.g. float 16) if necessary
    var = var.type(x.type())
    m = m.type(x.type())
    # Return mean and variance for updating stored mean/var if requested
    if return_mean_var:
        return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
    else:
        return fused_bn(x, m, var, gain, bias, eps)


# My batchnorm, supports standing stats
class myBN(nn.Module):
    """Batch norm with externally supplied gain/bias and optional standing stats."""

    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        # momentum for updating running stats
        self.momentum = momentum
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Register buffers
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.register_buffer('accumulation_counter', torch.zeros(1))
        # Accumulate running means and vars
        self.accumulate_standing = False

    # reset standing stats
    def reset_stats(self):
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        """Normalize ``x``; batch stats in training mode, stored stats otherwise."""
        if self.training:
            out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
            # Update the buffers outside autograd. Otherwise the buffers start
            # requiring grad and chain each step's graph onto the previous one.
            with torch.no_grad():
                # If accumulating standing stats, increment them
                if self.accumulate_standing:
                    self.stored_mean[:] = self.stored_mean + mean.data
                    self.stored_var[:] = self.stored_var + var.data
                    self.accumulation_counter += 1.0
                # If not accumulating standing stats, take running averages
                else:
                    self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
                    self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
            return out
        # If not in training mode, use the stored statistics
        else:
            mean = self.stored_mean.view(1, -1, 1, 1)
            var = self.stored_var.view(1, -1, 1, 1)
            # If using standing stats, divide them by the accumulation counter
            if self.accumulate_standing:
                mean = mean / self.accumulation_counter
                var = var / self.accumulation_counter
            return fused_bn(x, mean, var, gain, bias, self.eps)


# Simple function to handle groupnorm norm stylization
def groupnorm(x, norm_style):
    """Apply ``F.group_norm`` with a group count derived from ``norm_style``.

    ``'..._ch_<n>'`` means ``n`` channels per group, ``'..._grp_<n>'`` means
    ``n`` groups; anything else falls back to 16 groups.
    """
    # If number of channels specified in norm_style:
    if 'ch' in norm_style:
        ch = int(norm_style.split('_')[-1])
        groups = max(int(x.shape[1]) // ch, 1)
    # If number of groups specified in norm style
    elif 'grp' in norm_style:
        groups = int(norm_style.split('_')[-1])
    # If neither, default to groups = 16
    else:
        groups = 16
    return F.group_norm(x, groups)


# Class-conditional bn
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
# if you want to make this more readable/usable).
class ccbn(nn.Module):
    """Class-conditional normalization: gain and bias are predicted from ``y``."""

    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style='bn',):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Norm style?
        self.norm_style = norm_style

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ['bn', 'in']:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        """Normalize ``x`` and apply the gain and bias predicted from ``y``.

        Raises:
            ValueError: if ``norm_style`` is not 'bn', 'in', 'gn' or 'nonorm'
                (only checked when neither mybn nor cross_replica is used).
        """
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # If using my batchnorm
        if self.mybn or self.cross_replica:
            return self.bn(x, gain=gain, bias=bias)
        if self.norm_style == 'bn':
            out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                               self.training, 0.1, self.eps)
        elif self.norm_style == 'in':
            out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                  self.training, 0.1, self.eps)
        elif self.norm_style == 'gn':
            # The attribute is norm_style; the former `self.normstyle` raised
            # AttributeError on this path.
            out = groupnorm(x, self.norm_style)
        elif self.norm_style == 'nonorm':
            out = x
        else:
            raise ValueError('Unrecognized norm_style: %r' % (self.norm_style,))
        return out * gain + bias

    def extra_repr(self):
        s = 'out: {output_size}, in: {input_size},'
        s += ' cross_replica={cross_replica}'
        return s.format(**self.__dict__)


# Normal, non-class-conditional BN
class bn(nn.Module):
    """Batch norm with learnable per-channel gain and bias (no conditioning)."""

    def __init__(self, output_size, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False):
        super(bn, self).__init__()
        self.output_size = output_size
        # Prepare gain and bias layers
        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        # Register buffers if neither of the above
        else:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y=None):
        """Normalize ``x``; ``y`` is accepted for interface parity and ignored."""
        if self.cross_replica or self.mybn:
            gain = self.gain.view(1, -1, 1, 1)
            bias = self.bias.view(1, -1, 1, 1)
            return self.bn(x, gain=gain, bias=bias)
        else:
            return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                                self.bias, self.training, self.momentum, self.eps)


# Generator blocks
# Note that this class assumes the kernel size and padding (and any other
# settings) have been selected in the main generator module and passed in
# through the which_conv arg.
# Similar rules apply with which_bn (the input
# size [which is actually the number of channels of the conditional info] must
# be preselected)
class GBlock(nn.Module):
    """Residual generator block: (bn, activation, [upsample], conv) twice plus a shortcut.

    ``which_conv`` and ``which_bn`` are factories called with channel counts;
    ``activation`` is applied after each norm, and ``upsample`` (if given) is
    applied to both the main path and the shortcut.
    """

    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=bn, activation=None,
                 upsample=None):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        # upsample layers
        self.upsample = upsample
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        # Store a real bool: `a or upsample` would otherwise keep the upsample
        # callable itself, and an nn.Module there would be registered as a
        # spurious submodule named `learnable_sc`.
        self.learnable_sc = bool(in_channels != out_channels or upsample)
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)
        # Batchnorm layers
        self.bn1 = self.which_bn(in_channels)
        self.bn2 = self.which_bn(out_channels)

    def forward(self, x, y):
        """Return the block output for input ``x`` and conditioning ``y``."""
        h = self.activation(self.bn1(x, y))
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv1(h)
        h = self.activation(self.bn2(h, y))
        h = self.conv2(h)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x


# Residual block for the discriminator
class DBlock(nn.Module):
    """Residual discriminator block with optional downsampling.

    ``wide`` selects ``out_channels`` (True) or ``in_channels`` (False) as the
    hidden width. ``preactivation`` applies a ReLU to the input of the main
    path and changes the order of conv and downsample on the shortcut.
    """

    def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
                 preactivation=False, activation=None, downsample=None,):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
        self.hidden_channels = self.out_channels if wide else self.in_channels
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample

        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
        self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
        self.learnable_sc = bool((in_channels != out_channels) or downsample)
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)

    def shortcut(self, x):
        """Shortcut path: conv then downsample if preactivated, else the reverse."""
        if self.preactivation:
            if self.learnable_sc:
                x = self.conv_sc(x)
            if self.downsample:
                x = self.downsample(x)
        else:
            if self.downsample:
                x = self.downsample(x)
            if self.learnable_sc:
                x = self.conv_sc(x)
        return x

    def forward(self, x):
        if self.preactivation:
            # h = self.activation(x) # NOT TODAY SATAN
            # Andy's note: This line *must* be an out-of-place ReLU or it
            # will negatively affect the shortcut connection.
            h = F.relu(x)
        else:
            h = x
        h = self.conv1(h)
        h = self.conv2(self.activation(h))
        if self.downsample:
            h = self.downsample(h)

        return h + self.shortcut(x)

# dogball
# --------------------------------------------------------------------------------
# /EAL-GAN-image/src/BigGAN.py:
# --------------------------------------------------------------------------------
import numpy as np
import math
import functools

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P

import src.layers as layers
from src.sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d


# Architectures for G
# Attention is passed in in the format '32_64' to mean applying an attention
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
def _attention_resolutions(attention):
    """Parse an underscore-separated spec such as '32_64' into a set of ints.

    Empty tokens are ignored, so an empty string means "no attention".
    """
    return {int(item) for item in attention.split('_') if item}


def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    """Return the generator architecture table, keyed by output resolution.

    Each entry lists per-block input/output channel counts (multiples of
    ``ch``), upsample flags, block resolutions, and a resolution -> bool map
    telling where attention is applied. ``ksize`` and ``dilation`` are
    accepted but unused.
    """
    # Parse the spec once instead of once per resolution.
    attn_at = _attention_resolutions(attention)
    arch = {}
    arch[512] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
                 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
                 'upsample': [True] * 7,
                 'resolution': [8, 16, 32, 64, 128, 256, 512],
                 'attention': {2 ** i: (2 ** i in attn_at) for i in range(3, 10)}}
    arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]],
                 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]],
                 'upsample': [True] * 6,
                 'resolution': [8, 16, 32, 64, 128, 256],
                 'attention': {2 ** i: (2 ** i in attn_at) for i in range(3, 9)}}
    arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
                 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
                 'upsample': [True] * 5,
                 'resolution': [8, 16, 32, 64, 128],
                 'attention': {2 ** i: (2 ** i in attn_at) for i in range(3, 8)}}
    arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]],
                'out_channels': [ch * item for item in [16, 8, 4, 2]],
                'upsample': [True] * 4,
                'resolution': [8, 16, 32, 64],
                'attention': {2 ** i: (2 ** i in attn_at) for i in range(3, 7)}}
    arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]],
                'out_channels': [ch * item for item in [4, 4, 4]],
                'upsample': [True] * 3,
                'resolution': [8, 16, 32],
                'attention': {2 ** i: (2 ** i in attn_at) for i in range(3, 6)}}

    return arch


class Generator(nn.Module):
    """BigGAN-style class-conditional generator.

    ``forward(z, y)`` maps latents ``z`` and class labels ``y`` to images in
    ``[-1, 1]`` (the output goes through ``tanh``). The constructor also
    builds ``self.optim`` unless ``no_optim`` is True.
    """

    def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128,
                 G_kernel_size=3, G_attn='64', n_classes=1000,
                 num_G_SVs=1, num_G_SV_itrs=1,
                 G_shared=True, shared_dim=0, hier=True,
                 cross_replica=False, mybn=False,
                 G_activation=nn.ReLU(inplace=False),
                 G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
                 BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
                 G_init='ortho', skip_init=False, no_optim=False,
                 G_param='SN', norm_style='bn',
                 **kwargs):
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size?
        self.kernel_size = G_kernel_size
        # Attention?
        self.attention = G_attn
        # number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        # Use shared embeddings?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding? Unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space?
        self.hier = hier
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm?
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = G_fp16
        # Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # If using hierarchical latents, adjust z
        if self.hier:
            # Number of places z slots into
            self.num_slots = len(self.arch['in_channels']) + 1
            self.z_chunk_size = (self.dim_z // self.num_slots)
            # Recalculate latent dimensionality for even splitting into chunks
            self.dim_z = self.z_chunk_size * self.num_slots
        else:
            self.num_slots = 1
            self.z_chunk_size = 0

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                  eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        # We use a non-spectral-normed embedding here regardless;
        # For some reason applying SN to G's embedding seems to randomly cripple G
        self.which_embedding = nn.Embedding
        bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                     else self.which_embedding)
        self.which_bn = functools.partial(layers.ccbn,
                                          which_linear=bn_linear,
                                          cross_replica=self.cross_replica,
                                          mybn=self.mybn,
                                          input_size=(self.shared_dim + self.z_chunk_size if self.G_shared
                                                      else self.n_classes),
                                          norm_style=self.norm_style,
                                          eps=self.BN_eps)

        # Prepare model
        # If not using shared embeddings, self.shared is just a passthrough
        self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
                       else layers.identity())
        # First linear layer
        self.linear = self.which_linear(self.dim_z // self.num_slots,
                                        self.arch['in_channels'][0] * (self.bottom_width ** 2))

        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        # while the inner loop is over a given block
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv=self.which_conv,
                                           which_bn=self.which_bn,
                                           activation=self.activation,
                                           upsample=(functools.partial(F.interpolate, scale_factor=2)
                                                     if self.arch['upsample'][index] else None))]]

            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                    cross_replica=self.cross_replica,
                                                    mybn=self.mybn),
                                          self.activation,
                                          self.which_conv(self.arch['out_channels'][-1], 3))

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print('Using fp16 adam in G...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0,
                                      eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0,
                                    eps=self.adam_eps)

        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

    # Initialize
    def init_weights(self):
        """Initialize conv/linear/embedding weights per ``self.init`` and count params."""
        self.param_count = 0
        for module in self.modules():
            if (isinstance(module, nn.Conv2d)
                    or isinstance(module, nn.Linear)
                    or isinstance(module, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        # Double quotes: the former 'G''s' concatenated two literals and printed "Gs".
        print("Param count for G's initialized parameters: %d" % self.param_count)

    # Note on this forward function: y is passed through self.shared here,
    # so callers supply the raw labels rather than pre-embedded vectors.
    def forward(self, z, y):
        """Generate images from latents ``z`` and labels ``y``."""
        y = self.shared(y)
        # If hierarchical, concatenate zs and ys
        if self.hier:
            zs = torch.split(z, self.z_chunk_size, 1)
            z = zs[0]
            ys = [torch.cat([y, item], 1) for item in zs[1:]]
        else:
            ys = [y] * len(self.blocks)

        # First linear layer
        h = self.linear(z)
        # Reshape
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)

        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h, ys[index])

        # Apply batchnorm-relu-conv-tanh at output
        return torch.tanh(self.output_layer(h))


# Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    """Return the discriminator architecture table, keyed by input resolution.

    Same layout as ``G_arch`` but with ``downsample`` flags; the first block
    takes 3 input channels. ``ksize`` and ``dilation`` are accepted but unused.
    """
    # Parse the spec once instead of once per resolution.
    attn_at = _attention_resolutions(attention)
    arch = {}
    arch[256] = {'in_channels': [3] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
                 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
                 'downsample': [True] * 6 + [False],
                 'resolution': [128, 64, 32, 16, 8, 4, 4],
                 'attention': {2 ** i: 2 ** i in attn_at for i in range(2, 8)}}
    arch[128] = {'in_channels': [3] + [ch * item for item in [1, 2, 4, 8, 16]],
                 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
                 'downsample': [True] * 5 + [False],
                 'resolution': [64, 32, 16, 8, 4, 4],
                 'attention': {2 ** i: 2 ** i in attn_at for i in range(2, 8)}}
    arch[64] = {'in_channels': [3] + [ch * item for item in [1, 2, 4, 8]],
                'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
                'downsample': [True] * 4 + [False],
                'resolution': [32, 16, 8, 4, 4],
                'attention': {2 ** i: 2 ** i in attn_at for i in range(2, 7)}}
    arch[32] = {'in_channels': [3] + [item * ch for item in [4, 4, 4]],
                'out_channels': [item * ch for item in [4, 4, 4, 4]],
                'downsample': [True, True, False, False],
                'resolution': [16, 16, 16, 16],
                'attention': {2 ** i: 2 ** i in attn_at for i in range(2, 6)}}
    return arch


class Discriminator(nn.Module):
    """Projection discriminator with an additional sigmoid category head.

    ``forward`` returns ``(out, out_cat)`` in mode 0 and only ``out_cat``
    otherwise. The constructor also builds ``self.optim``.
    """

    def __init__(self, D_ch=64, D_wide=True, resolution=128,
                 D_kernel_size=3, D_attn='64', n_classes=1000,
                 num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
                 SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
                 D_init='ortho', skip_init=False, D_param='SN', **kwargs):
        super(Discriminator, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        else:
            # Fail with a clear message instead of a later AttributeError on
            # the missing self.which_conv.
            raise ValueError("Unsupported D_param %r: only 'SN' is implemented" % (D_param,))
        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv=self.which_conv,
                                           wide=self.D_wide,
                                           activation=self.activation,
                                           preactivation=(index > 0),
                                           downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                                     self.which_conv)]
        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
        # Second head: a single logit that forward() passes through a sigmoid
        self.output_cat = self.which_linear(self.arch['out_channels'][-1], 1)
        # Embedding for projection discrimination
        self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

    # Initialize
    def init_weights(self):
        """Initialize conv/linear/embedding weights per ``self.init`` and count params."""
        self.param_count = 0
        for module in self.modules():
            if (isinstance(module, nn.Conv2d)
                    or isinstance(module, nn.Linear)
                    or isinstance(module, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        # Double quotes: the former 'D''s' concatenated two literals and printed "Ds".
        print("Param count for D's initialized parameters: %d" % self.param_count)

    def forward(self, x, y=None, mode=0):
        """Score ``x``.

        With ``mode == 0`` returns ``(out, out_cat)`` and needs labels ``y``
        for the projection term; any other mode returns only ``out_cat`` and
        does not use ``y``.
        """
        # Stick x into h for cleaner for loops without flow control
        h = x
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # torch.sigmoid replaces the deprecated F.sigmoid (same values)
        out_cat = torch.sigmoid(self.output_cat(h))
        if mode == 0:
            # Get initial class-unconditional output
            out = self.linear(h)
            # Get projection of final featureset onto class vectors and add to evidence
            out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
            return out, out_cat
        else:
            return out_cat

# --------------------------------------------------------------------------------