├── EAL-GAN
│   ├── data
│   │   ├── mnist.mat
│   │   ├── musk.mat
│   │   ├── shuttle.mat
│   │   ├── vowels.mat
│   │   ├── optdigits.mat
│   │   ├── pendigits.mat
│   │   ├── satellite.mat
│   │   └── satimage-2.mat
│   ├── models
│   │   ├── sync_batchnorm
│   │   │   ├── comm.pyc
│   │   │   ├── __init__.py
│   │   │   ├── unittest.py
│   │   │   ├── batchnorm_reimpl.py
│   │   │   ├── replicate.py
│   │   │   ├── comm.py
│   │   │   └── batchnorm.py
│   │   ├── losses.py
│   │   ├── pyod_utils.py
│   │   ├── layers.py
│   │   └── EAL_GAN.py
│   ├── README.md
│   └── Train_EAL_GAN.py
├── EAL-GAN-image
│   ├── src
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   ├── torchvision_dataset.py
│   │   │   ├── base_net.py
│   │   │   ├── base_dataset.py
│   │   │   ├── base_trainer.py
│   │   │   └── odds_dataset.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── odds.py
│   │   │   ├── main.py
│   │   │   ├── preprocessing.py
│   │   │   ├── cifar10.py
│   │   │   ├── mnist.py
│   │   │   └── fmnist.py
│   │   ├── sync_batchnorm
│   │   │   ├── __init__.py
│   │   │   ├── unittest.py
│   │   │   ├── batchnorm_reimpl.py
│   │   │   ├── replicate.py
│   │   │   ├── comm.py
│   │   │   └── batchnorm.py
│   │   ├── loss.py
│   │   ├── my_utils.py
│   │   ├── EAL_GAN.py
│   │   ├── layers.py
│   │   └── BigGAN.py
│   └── Train_EAL_GAN.py
└── README.md
/EAL-GAN/data/mnist.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/mnist.mat
--------------------------------------------------------------------------------
/EAL-GAN/data/musk.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/musk.mat
--------------------------------------------------------------------------------
/EAL-GAN/data/shuttle.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/shuttle.mat
--------------------------------------------------------------------------------
/EAL-GAN/data/vowels.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/vowels.mat
--------------------------------------------------------------------------------
/EAL-GAN/data/optdigits.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/optdigits.mat
--------------------------------------------------------------------------------
/EAL-GAN/data/pendigits.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/pendigits.mat
--------------------------------------------------------------------------------
/EAL-GAN/data/satellite.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/satellite.mat
--------------------------------------------------------------------------------
/EAL-GAN/data/satimage-2.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/data/satimage-2.mat
--------------------------------------------------------------------------------
/EAL-GAN/models/sync_batchnorm/comm.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smallcube/EAL-GAN/HEAD/EAL-GAN/models/sync_batchnorm/comm.pyc
--------------------------------------------------------------------------------
/EAL-GAN-image/src/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import *
2 | from .torchvision_dataset import *
3 | from .odds_dataset import *
4 | from .base_net import *
5 | from .base_trainer import *
6 |
--------------------------------------------------------------------------------
/EAL-GAN/models/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
2 | from .replicate import DataParallelWithCallback, patch_replication_callback
3 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .main import load_dataset
2 | from .mnist import MNIST_Dataset
3 | from .fmnist import FashionMNIST_Dataset
4 | from .cifar10 import CIFAR10_Dataset
5 | from .odds import ODDSADDataset
6 | from .preprocessing import *
7 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/EAL-GAN/models/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 | import torch
13 |
14 |
class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` with an assertion helper for comparing tensors."""

    def assertTensorClose(self, x, y, rtol=1e-05, atol=1e-08):
        """Assert that ``x`` and ``y`` are element-wise close.

        Closeness is decided by ``torch.allclose(x, y, rtol=rtol, atol=atol)``.
        The default tolerances equal ``torch.allclose``'s own defaults, so
        existing two-argument calls behave as before.

        On failure the message reports the largest absolute difference
        (``adiff``) and the largest relative difference (``rdiff``). The
        relative difference is taken only over entries where ``y`` is
        non-zero, and is reported as 'NaN' when ``y`` is all zeros.

        :param x: tensor under test.
        :param y: reference tensor.
        :param rtol: relative tolerance forwarded to ``torch.allclose``.
        :param atol: absolute tolerance forwarded to ``torch.allclose``.
        :raises AssertionError: if the tensors are not close.
        """
        if torch.allclose(x, y, rtol=rtol, atol=atol):
            return

        # Diagnostics are only needed on failure.
        diff = (x - y).abs()
        adiff = float(diff.max())
        # Broadcast |y| to the shape of the difference so the mask lines up.
        y_abs = y.abs().expand_as(diff)
        nonzero = y_abs != 0
        if not bool(nonzero.any()):
            rdiff = 'NaN'
        else:
            # Skip zero reference entries; dividing by them would report inf.
            rdiff = float((diff[nonzero] / y_abs[nonzero]).max())

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.fail(message)
29 |
30 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 | import torch
13 |
14 |
class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` with an assertion helper for comparing tensors."""

    def assertTensorClose(self, x, y, rtol=1e-05, atol=1e-08):
        """Assert that ``x`` and ``y`` are element-wise close.

        Closeness is decided by ``torch.allclose(x, y, rtol=rtol, atol=atol)``.
        The default tolerances equal ``torch.allclose``'s own defaults, so
        existing two-argument calls behave as before.

        On failure the message reports the largest absolute difference
        (``adiff``) and the largest relative difference (``rdiff``). The
        relative difference is taken only over entries where ``y`` is
        non-zero, and is reported as 'NaN' when ``y`` is all zeros.

        :param x: tensor under test.
        :param y: reference tensor.
        :param rtol: relative tolerance forwarded to ``torch.allclose``.
        :param atol: absolute tolerance forwarded to ``torch.allclose``.
        :raises AssertionError: if the tensors are not close.
        """
        if torch.allclose(x, y, rtol=rtol, atol=atol):
            return

        # Diagnostics are only needed on failure.
        diff = (x - y).abs()
        adiff = float(diff.max())
        # Broadcast |y| to the shape of the difference so the mask lines up.
        y_abs = y.abs().expand_as(diff)
        nonzero = y_abs != 0
        if not bool(nonzero.any()):
            rdiff = 'NaN'
        else:
            # Skip zero reference entries; dividing by them would report inf.
            rdiff = float((diff[nonzero] / y_abs[nonzero]).max())

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.fail(message)
29 |
30 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/base/torchvision_dataset.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import BaseADDataset
2 | from torch.utils.data import DataLoader
3 |
4 |
class TorchvisionDataset(BaseADDataset):
    """Anomaly-detection dataset backed by a dataset from ``torchvision.datasets``.

    Subclasses are responsible for assigning ``self.train_set`` and
    ``self.test_set``; this class only builds the data loaders.
    """

    def __init__(self, root: str):
        super().__init__(root)

    def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> (
            DataLoader, DataLoader):
        """Return ``(train_loader, test_loader)`` over ``train_set`` and ``test_set``.

        The training loader drops the final incomplete batch; the test loader keeps it.
        """
        shared = dict(batch_size=batch_size, num_workers=num_workers)
        train_loader = DataLoader(dataset=self.train_set, shuffle=shuffle_train, drop_last=True, **shared)
        test_loader = DataLoader(dataset=self.test_set, shuffle=shuffle_test, drop_last=False, **shared)
        return train_loader, test_loader
18 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/base/base_net.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
class BaseNet(nn.Module):
    """Base class for all neural networks."""

    def __init__(self):
        super().__init__()
        # Logger named after the concrete subclass; used by summary().
        self.logger = logging.getLogger(self.__class__.__name__)
        self.rep_dim = None  # representation dimensionality, i.e. dim of the code layer or last layer

    def forward(self, *input):
        """
        Forward pass logic
        :return: Network output
        """
        raise NotImplementedError

    def summary(self):
        """Log the number of trainable parameters and the module structure.

        Only parameters with ``requires_grad=True`` are counted.

        :return: the trainable-parameter count as an ``int``.
        """
        # numel() is exact for every shape, including 0-dim parameters
        # (np.prod(()) would yield the float 1.0).
        n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        self.logger.info('Trainable parameters: {}'.format(n_params))
        self.logger.info(self)
        return n_params
27 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/base/base_dataset.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from torch.utils.data import DataLoader
3 |
4 |
class BaseADDataset(ABC):
    """Abstract base class for anomaly-detection datasets.

    Labels follow the binary convention 0 = normal, 1 = outlier. Concrete
    subclasses fill in the class tuples and the train/test sets, and
    implement :meth:`loaders`.
    """

    def __init__(self, root: str):
        super().__init__()
        self.root = root  # root path to data

        # Binary labelling: 0 is normal, 1 is outlier.
        self.n_classes = 2
        # Tuples of the original class labels that make up the normal / outlier class.
        self.normal_classes = None
        self.outlier_classes = None

        # Expected to hold torch.utils.data.Dataset instances once a subclass assigns them.
        self.train_set = None
        self.test_set = None

    @abstractmethod
    def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> (
            DataLoader, DataLoader):
        """Return ``(train_loader, test_loader)`` built from ``train_set`` and ``test_set``."""

    def __repr__(self):
        return type(self).__name__
27 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/base/base_trainer.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from .base_dataset import BaseADDataset
3 | from .base_net import BaseNet
4 |
5 |
class BaseTrainer(ABC):
    """Abstract trainer that stores the optimisation and data-loading settings.

    Concrete trainers implement :meth:`train` and :meth:`test`.
    """

    def __init__(self, optimizer_name: str, lr: float, n_epochs: int, lr_milestones: tuple, batch_size: int,
                 weight_decay: float, device: str, n_jobs_dataloader: int):
        super().__init__()
        # Optimiser choice, learning rate and training length.
        self.optimizer_name, self.lr, self.n_epochs = optimizer_name, lr, n_epochs
        # Presumably the epochs at which the learning rate is changed — confirm in subclasses.
        self.lr_milestones = lr_milestones
        self.batch_size, self.weight_decay = batch_size, weight_decay
        # Execution device and number of DataLoader worker processes.
        self.device, self.n_jobs_dataloader = device, n_jobs_dataloader

    @abstractmethod
    def train(self, dataset: BaseADDataset, net: BaseNet) -> BaseNet:
        """Train ``net`` on ``dataset.train_set`` and return the trained network."""

    @abstractmethod
    def test(self, dataset: BaseADDataset, net: BaseNet):
        """Evaluate ``net`` on ``dataset.test_set``."""
35 |
--------------------------------------------------------------------------------
/EAL-GAN/README.md:
--------------------------------------------------------------------------------
1 |
2 | EAL-GAN: Supervised Anomaly Detection via Conditional Generative Adversarial Network and Ensemble Active Learning
3 | ==
4 | This is the official implementation of “Supervised Anomaly Detection via Conditional Generative Adversarial Network and Ensemble Active Learning”. Our paper has been submitted to the IEEE for possible publication, and a preprint version of the manuscript can be found on arXiv. If you use the code in this repo, please cite the paper as follows:
5 |
6 | > @misc{chen2021supervised,
7 | title={Supervised Anomaly Detection via Conditional Generative Adversarial Network and Ensemble Active Learning},
8 | author={Zhi Chen and Jiang Duan and Li Kang and Guoping Qiu},
9 | year={2021},
10 | eprint={2104.11952},
11 | archivePrefix={arXiv},
12 | primaryClass={cs.LG}
13 | }
14 |
15 | We have implemented two versions of the proposed EAL-GAN: (1) one for classical anomaly detection datasets, whose code is in the "EAL-GAN" folder, and (2) one for image datasets, whose code is in the "EAL-GAN-image" folder. Please note that EAL-GAN-image is sensitive to the learning rate and initialization. We print the losses of the discriminators and the generator during training; if the loss becomes NaN in the first epoch, re-run the code with a smaller learning rate. If that doesn't happen, you can expect a promising result.
16 |
17 |
18 | Requirements
19 | ===
20 | Pytorch >1.6
21 | Python 3.7
22 |
23 | Getting started
24 | ===
25 | (1) You can run the script “Train_EAL_GAN.py” to train the model proposed in our paper.
26 | (2) Some of the datasets used in our paper are provided in the folder “data”.
27 | (3) models/EAL_GAN.py implements the proposed model.
28 | (4) models/losses.py contains the loss functions.
29 |
30 |
31 | Acknowledgments
32 | ==
33 | Some of our code (e.g., Spectral Normalization) is taken from the [PyTorch implementation of BigGAN](https://github.com/ajbrock/BigGAN-PyTorch).
34 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | EAL-GAN: Supervised Anomaly Detection via Conditional Generative Adversarial Network and Ensemble Active Learning
3 | ==
4 | This is the official implementation of “Supervised Anomaly Detection via Conditional Generative Adversarial Network and Ensemble Active Learning”. Our paper has been officially accepted by IEEE Transactions on Pattern Analysis and Machine Intelligence, and a preprint version of the manuscript can be found on arXiv. If you use the code in this repo, please cite the paper as follows:
5 |
6 | >@ARTICLE{chen2023-EALGAN,
7 | author={Chen, Zhi and Duan, Jiang and Kang, Li and Qiu, Guoping},
8 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
9 | title={Supervised Anomaly Detection via Conditional Generative Adversarial Network and Ensemble Active Learning},
10 | year={2023},
11 | volume={45},
12 | number={6},
13 | pages={7781-7798},
14 | doi={10.1109/TPAMI.2022.3225476}}
15 |
16 | We have implemented two versions of the proposed EAL-GAN: (1) one for classical anomaly detection datasets, whose code is in the "EAL-GAN" folder, and (2) one for image datasets, whose code is in the "EAL-GAN-image" folder. Please note that EAL-GAN-image is sensitive to the learning rate and initialization. We print the losses of the discriminators and the generator during training; if the loss becomes NaN in the first epoch, re-run the code with a smaller learning rate. If that doesn't happen, you can expect a promising result.
17 |
18 |
19 | Requirements
20 | ===
21 | Pytorch >1.6
22 | Python 3.7
23 |
24 | Getting started
25 | ===
26 | (1) You can run the script “Train_EAL_GAN.py” (in the "EAL-GAN" or "EAL-GAN-image" folder) to train the model proposed in our paper.
27 | (2) Some of the datasets used in our paper are provided in the folder “EAL-GAN/data”.
28 | (3) EAL-GAN/models/EAL_GAN.py implements the proposed model.
29 | (4) EAL-GAN/models/losses.py contains the loss functions.
30 |
31 |
32 | Acknowledgments
33 | ==
34 | Some of our code (e.g., Spectral Normalization) is taken from the [PyTorch implementation of BigGAN](https://github.com/ajbrock/BigGAN-PyTorch).
35 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/datasets/odds.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader, Subset
2 | from src.base.base_dataset import BaseADDataset
3 | from src.base.odds_dataset import ODDSDataset
4 | from src.datasets.preprocessing import create_semisupervised_setting
5 |
6 | import torch
7 |
8 |
class ODDSADDataset(BaseADDataset):
    """Anomaly-detection dataset built on an ODDS benchmark (0 = normal, 1 = outlier).

    The training split is reduced to the semi-supervised subset selected by
    ``create_semisupervised_setting``; the test split is used as loaded.
    """

    def __init__(self, root: str, dataset_name: str, n_known_outlier_classes: int = 0, ratio_known_normal: float = 0.0,
                 ratio_known_outlier: float = 0.0, ratio_pollution: float = 0.0, random_state=None):
        super().__init__(root)

        # Binary labelling.
        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = (0,)
        self.outlier_classes = (1,)
        # The outlier class counts as "known" only when at least one known outlier class is requested.
        self.known_outlier_classes = () if n_known_outlier_classes == 0 else (1,)

        full_train = ODDSDataset(root=self.root, dataset_name=dataset_name, train=True, random_state=random_state,
                                 download=True)

        labels = full_train.targets.cpu().data.numpy()
        idx, _, semi_targets = create_semisupervised_setting(labels, self.normal_classes,
                                                             self.outlier_classes, self.known_outlier_classes,
                                                             ratio_known_normal, ratio_known_outlier, ratio_pollution)
        semi_targets = torch.tensor(semi_targets)
        n_train = full_train.semi_targets.shape[0]
        # NOTE(review): this truncates semi_targets to the training-set size but
        # leaves ``idx`` untouched. If the truncation ever shortens the tensor,
        # the indexed assignment below presumably fails on a length mismatch.
        # Confirm what create_semisupervised_setting can return.
        if semi_targets.shape[0] >= n_train:
            semi_targets = semi_targets[:n_train]
        full_train.semi_targets[idx] = semi_targets  # write the semi-supervised labels in place

        # Keep only the selected training samples.
        self.train_set = Subset(full_train, idx)
        self.test_set = ODDSDataset(root=self.root, dataset_name=dataset_name, train=False, random_state=random_state)

    def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> (DataLoader, DataLoader):
        """Return ``(train_loader, test_loader)``; only the training loader drops the last partial batch."""
        def make(dataset, shuffle, drop_last):
            return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle,
                              num_workers=num_workers, drop_last=drop_last)

        return make(self.train_set, shuffle_train, True), make(self.test_set, shuffle_test, False)
50 |
--------------------------------------------------------------------------------
/EAL-GAN/models/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : batchnorm_reimpl.py
4 | # Author : acgtyrant
5 | # Date : 11/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 |
# Export the class that actually exists in this module; listing a
# non-existent name makes ``from ... import *`` raise AttributeError.
__all__ = ['BatchNorm2dReimpl']


class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of batch normalization, used for testing the numerical
    stability.

    Author: acgtyrant
    See also:
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Per-channel affine parameters; initialised in reset_parameters().
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running mean to 0 and running variance to 1."""
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        """Reset running stats, draw ``weight`` from U(0, 1) and zero ``bias``."""
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        """Normalise a ``(batch, channels, height, width)`` tensor with batch statistics.

        The current batch's statistics are always used (there is no
        eval-mode branch). ``running_mean`` / ``running_var`` are updated
        with the momentum rule. The running variance uses the unbiased
        estimator, while the normalisation itself uses the biased one.
        """
        batchsize, channels, height, width = input_.size()
        numel = batchsize * height * width
        # Flatten to (channels, batch*height*width) so statistics are per channel.
        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
        sum_ = input_.sum(1)
        sum_of_square = input_.pow(2).sum(1)
        mean = sum_ / numel
        # Sum of squared deviations: sum(x^2) - sum(x) * mean.
        sumvar = sum_of_square - sum_ * mean

        self.running_mean = (
            (1 - self.momentum) * self.running_mean
            + self.momentum * mean.detach()
        )
        unbias_var = sumvar / (numel - 1)
        self.running_var = (
            (1 - self.momentum) * self.running_var
            + self.momentum * unbias_var.detach()
        )

        bias_var = sumvar / numel
        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))

        # Undo the flattening/permutation back to (batch, channels, height, width).
        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74 |
75 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : batchnorm_reimpl.py
4 | # Author : acgtyrant
5 | # Date : 11/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 |
# Export the class that actually exists in this module; listing a
# non-existent name makes ``from ... import *`` raise AttributeError.
__all__ = ['BatchNorm2dReimpl']


class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of batch normalization, used for testing the numerical
    stability.

    Author: acgtyrant
    See also:
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Per-channel affine parameters; initialised in reset_parameters().
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running mean to 0 and running variance to 1."""
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        """Reset running stats, draw ``weight`` from U(0, 1) and zero ``bias``."""
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        """Normalise a ``(batch, channels, height, width)`` tensor with batch statistics.

        The current batch's statistics are always used (there is no
        eval-mode branch). ``running_mean`` / ``running_var`` are updated
        with the momentum rule. The running variance uses the unbiased
        estimator, while the normalisation itself uses the biased one.
        """
        batchsize, channels, height, width = input_.size()
        numel = batchsize * height * width
        # Flatten to (channels, batch*height*width) so statistics are per channel.
        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
        sum_ = input_.sum(1)
        sum_of_square = input_.pow(2).sum(1)
        mean = sum_ / numel
        # Sum of squared deviations: sum(x^2) - sum(x) * mean.
        sumvar = sum_of_square - sum_ * mean

        self.running_mean = (
            (1 - self.momentum) * self.running_mean
            + self.momentum * mean.detach()
        )
        unbias_var = sumvar / (numel - 1)
        self.running_var = (
            (1 - self.momentum) * self.running_var
            + self.momentum * unbias_var.detach()
        )

        bias_var = sumvar / numel
        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))

        # Undo the flattening/permutation back to (batch, channels, height, width).
        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74 |
75 |
--------------------------------------------------------------------------------
/EAL-GAN-image/Train_EAL_GAN.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import sys
7 | from time import time
8 |
9 | sys.path.append(
10 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..')))
11 | # supress warnings for clean output
12 | import warnings
13 |
14 | warnings.filterwarnings("ignore")
15 |
16 | import numpy as np
17 | import pandas as pd
18 | from sklearn.model_selection import train_test_split
19 | from scipy.io import loadmat
20 |
21 | import torch
22 |
23 | from src.my_utils import parse_args
24 | from src.EAL_GAN import EAL_GAN
25 | from src.datasets.main import load_dataset
26 |
27 | if __name__=='__main__':
28 | # Define data file and read X and y
29 | data_root = './data'
30 | save_dir = './results'
31 | n_ite = 10
32 |
33 | data_names = [#'mnist',
34 | 'fmnist',
35 | #'cifar10',
36 | ]
37 | data_resolution = {'mnist':32, 'fmnist':32, 'cifar10':32}
38 | # define the number of iterations
39 | df_columns = ['Data', 'auc_mean', 'auc_std', 'gmean', 'gmean_std']
40 |
41 | # initialize the container for saving the results
42 | roc_df = pd.DataFrame(columns=df_columns)
43 | gmean_df = pd.DataFrame(columns=df_columns)
44 |
45 | args = parse_args()
46 | roc_df = pd.DataFrame(columns=df_columns)
47 | gmean_df = pd.DataFrame(columns=df_columns)
48 | outlier_class = 0
49 |
50 |
51 |
52 | for data_name in data_names:
53 | roc_list = [data_name]
54 | gmean_list = [data_name]
55 |
56 | roc_mat = np.zeros([1, 1])
57 | gmean_mat = np.zeros([1, 1])
58 |
59 | args.feat_dim = 3 if data_name=='cifar10' else 1
60 | #for i in range(0, n_ite):
61 | dataset = load_dataset(data_name, data_root, args.normal_class, outlier_class, args.n_known_outlier_classes,
62 | args.ratio_known_normal, args.ratio_known_outlier, args.ratio_pollution,
63 | random_state=args.seed, resolution=data_resolution[data_name])
64 |
65 |
66 | #todo: add my method
67 | cb_gan = EAL_GAN(args, dataset)
68 | best_auc, best_gmean = cb_gan.fit()
69 |
70 | print('AUC:%.4f, Gmean:%.4f ' % (best_auc, best_gmean))
71 |
72 | roc_mat[0, 0] = best_auc
73 | gmean_mat[0, 0] = best_gmean
74 |
75 | roc_list = roc_list + np.mean(roc_mat, axis=0).tolist() + np.std(roc_mat, axis=0).tolist() + np.mean(gmean_mat, axis=0).tolist()+np.std(gmean_mat, axis=0).tolist()
76 | temp_df = pd.DataFrame(roc_list).transpose()
77 | temp_df.columns = df_columns
78 | roc_df = pd.concat([roc_df, temp_df], axis=0)
79 |
80 | # Save the results for each run
81 | save_path1 = os.path.join(save_dir, "AUC_EAL_GAN.csv")
82 | roc_df.to_csv(save_path1, index=False, float_format='%.4f')
83 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pandas as pd
4 | import numpy as np
5 | import os
6 | import torch.nn.functional as F
7 |
def _focal_modulation(pt, history, gamma):
    """Return the focal modulating factor and the updated probability history.

    :param pt: per-sample probability of the correct outcome for the current
        call, shaped ``(N, 1)``.
    :param history: ``None`` on the first call, otherwise the ``(N, k)``
        history tensor returned by a previous call.
    :param gamma: focusing exponent.
    :return: ``(factor, history)``. Here ``factor = (1 - p) ** gamma``, where
        ``p`` is ``pt`` itself when ``history`` is ``None`` and otherwise the
        row-wise mean of the updated history. Both are detached: the factor
        acts as a constant weight, and the history must not keep the autograd
        graph alive between calls.
    """
    pt = pt.detach()
    if history is None:
        history = pt
        p = pt
    else:
        history = torch.cat([history, pt], 1)
        p = torch.mean(history, 1).view(-1, 1)
    return (1 - p) ** gamma, history


def _category_loss(out_category, y, cat_weight, gamma, eps):
    """Focal-weighted binary cross-entropy of the auxiliary classification output.

    :return: ``(loss_cat, cat_weight)`` where ``cat_weight`` is the updated,
        detached probability history.
    """
    target = y.view(y.size(0), 1).float()
    batch_size = target.size(0)
    # Accept (N,) as well as (N, 1); (N,) would otherwise broadcast against
    # the (N, 1) target into an (N, N) matrix.
    out_category = out_category.reshape(batch_size, -1)
    # Probability assigned to the true label.
    pt_cat = (1. - target) * (1 - out_category) + target * out_category
    # Clamp so a saturated wrong prediction yields a large finite loss instead of inf/NaN.
    logpt_cat = -torch.log(pt_cat.clamp(min=eps))
    factor, cat_weight = _focal_modulation(pt_cat, cat_weight, gamma)
    return torch.mean(factor * logpt_cat), cat_weight


def loss_dis_real(dis_real, out_category, y, weights=None, cat_weight=None, gamma=2.0, eps=1e-8):
    """Discriminator loss on a real batch: focal adversarial term plus focal classification term.

    :param dis_real: adversarial logits, ``(N,)`` or ``(N, 1)``.
    :param out_category: predicted probability of label 1, ``(N,)`` or ``(N, 1)``;
        assumed to lie in [0, 1] (e.g. a sigmoid output) — TODO confirm.
    :param y: binary labels with ``N`` elements.
    :param weights: ``None`` or the ``weights`` returned by a previous call.
    :param cat_weight: ``None`` or the ``cat_weight`` returned by a previous call.
    :param gamma: focal-loss focusing exponent.
    :param eps: lower bound applied to the class probability before taking the log.
    :return: ``(loss, loss_cat, weights, cat_weight)``. The two histories are
        detached and have one more column than the ones passed in.
    """
    batch_size = dis_real.shape[0]
    dis_real = dis_real.reshape(batch_size, 1)
    # softplus(-d) == -log(sigmoid(d)), computed without overflow.
    logpt = F.softplus(-dis_real)
    pt = torch.exp(-logpt)  # == sigmoid(d): probability of "real"
    factor, weights = _focal_modulation(pt, weights, gamma)
    loss = torch.mean(factor * logpt)

    loss_cat, cat_weight = _category_loss(out_category, y, cat_weight, gamma, eps)
    return loss, loss_cat, weights, cat_weight


def loss_dis_fake(dis_fake, out_category, y, weights=None, cat_weight=None, gamma=2.0, eps=1e-8):
    """Discriminator loss on a generated batch: focal adversarial term plus focal classification term.

    Parameters and return value mirror :func:`loss_dis_real`, except that the
    adversarial target is "fake".
    """
    batch_size = dis_fake.shape[0]
    dis_fake = dis_fake.reshape(batch_size, 1)
    # softplus(d) == -log(1 - sigmoid(d)), computed without overflow.
    logpt = F.softplus(dis_fake)
    pt = torch.exp(-logpt)  # == 1 - sigmoid(d): probability of "fake"
    factor, weights = _focal_modulation(pt, weights, gamma)
    loss = torch.mean(factor * logpt)

    loss_cat, cat_weight = _category_loss(out_category, y, cat_weight, gamma, eps)
    return loss, loss_cat, weights, cat_weight
88 |
89 |
--------------------------------------------------------------------------------
/EAL-GAN/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pandas as pd
4 | import numpy as np
5 | import os
6 | import torch.nn.functional as F
7 |
def _focal_modulation(pt, history, gamma):
    """Return the focal modulating factor and the updated probability history.

    :param pt: per-sample probability of the correct outcome for the current
        call, shaped ``(N, 1)``.
    :param history: ``None`` on the first call, otherwise the ``(N, k)``
        history tensor returned by a previous call.
    :param gamma: focusing exponent.
    :return: ``(factor, history)``. Here ``factor = (1 - p) ** gamma``, where
        ``p`` is ``pt`` itself when ``history`` is ``None`` and otherwise the
        row-wise mean of the updated history. Both are detached: the factor
        acts as a constant weight, and the history must not keep the autograd
        graph alive between calls.
    """
    pt = pt.detach()
    if history is None:
        history = pt
        p = pt
    else:
        history = torch.cat([history, pt], 1)
        p = torch.mean(history, 1).view(-1, 1)
    return (1 - p) ** gamma, history


def _category_loss(out_category, y, cat_weight, gamma, eps):
    """Focal-weighted binary cross-entropy of the auxiliary classification output.

    :return: ``(loss_cat, cat_weight)`` where ``cat_weight`` is the updated,
        detached probability history.
    """
    target = y.view(y.size(0), 1).float()
    batch_size = target.size(0)
    # Accept (N,) as well as (N, 1); (N,) would otherwise broadcast against
    # the (N, 1) target into an (N, N) matrix.
    out_category = out_category.reshape(batch_size, -1)
    # Probability assigned to the true label.
    pt_cat = (1. - target) * (1 - out_category) + target * out_category
    # Clamp so a saturated wrong prediction yields a large finite loss instead of inf/NaN.
    logpt_cat = -torch.log(pt_cat.clamp(min=eps))
    factor, cat_weight = _focal_modulation(pt_cat, cat_weight, gamma)
    return torch.mean(factor * logpt_cat), cat_weight


def loss_dis_real(dis_real, out_category, y, weights=None, cat_weight=None, gamma=2.0, eps=1e-8):
    """Discriminator loss on a real batch: focal adversarial term plus focal classification term.

    :param dis_real: adversarial logits, ``(N,)`` or ``(N, 1)``.
    :param out_category: predicted probability of label 1, ``(N,)`` or ``(N, 1)``;
        assumed to lie in [0, 1] (e.g. a sigmoid output) — TODO confirm.
    :param y: binary labels with ``N`` elements.
    :param weights: ``None`` or the ``weights`` returned by a previous call.
    :param cat_weight: ``None`` or the ``cat_weight`` returned by a previous call.
    :param gamma: focal-loss focusing exponent.
    :param eps: lower bound applied to the class probability before taking the log.
    :return: ``(loss, loss_cat, weights, cat_weight)``. The two histories are
        detached and have one more column than the ones passed in.
    """
    batch_size = dis_real.shape[0]
    dis_real = dis_real.reshape(batch_size, 1)
    # softplus(-d) == -log(sigmoid(d)), computed without overflow.
    logpt = F.softplus(-dis_real)
    pt = torch.exp(-logpt)  # == sigmoid(d): probability of "real"
    factor, weights = _focal_modulation(pt, weights, gamma)
    loss = torch.mean(factor * logpt)

    loss_cat, cat_weight = _category_loss(out_category, y, cat_weight, gamma, eps)
    return loss, loss_cat, weights, cat_weight


def loss_dis_fake(dis_fake, out_category, y, weights=None, cat_weight=None, gamma=2.0, eps=1e-8):
    """Discriminator loss on a generated batch: focal adversarial term plus focal classification term.

    Parameters and return value mirror :func:`loss_dis_real`, except that the
    adversarial target is "fake".
    """
    batch_size = dis_fake.shape[0]
    dis_fake = dis_fake.reshape(batch_size, 1)
    # softplus(d) == -log(1 - sigmoid(d)), computed without overflow.
    logpt = F.softplus(dis_fake)
    pt = torch.exp(-logpt)  # == 1 - sigmoid(d): probability of "fake"
    factor, weights = _focal_modulation(pt, weights, gamma)
    loss = torch.mean(factor * logpt)

    loss_cat, cat_weight = _category_loss(out_category, y, cat_weight, gamma, eps)
    return loss, loss_cat, weights, cat_weight
88 |
89 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/datasets/main.py:
--------------------------------------------------------------------------------
1 | from .mnist import MNIST_Dataset
2 | from .fmnist import FashionMNIST_Dataset
3 | from .cifar10 import CIFAR10_Dataset
4 | from .odds import ODDSADDataset
5 |
6 |
def load_dataset(dataset_name, data_path, normal_class, known_outlier_class, n_known_outlier_classes: int = 0,
                 ratio_known_normal: float = 0.0, ratio_known_outlier: float = 0.0, ratio_pollution: float = 0.0,
                 random_state=None, resolution=32):
    """Loads the dataset.

    ``'mnist'``, ``'fmnist'`` and ``'cifar10'`` build the matching torchvision-backed dataset
    class. Any other name is handed to ``ODDSADDataset`` as ``dataset_name``.

    Args:
        dataset_name: which dataset to build.
        data_path: root directory passed to the dataset class.
        normal_class: normal class label (image datasets only).
        known_outlier_class: outlier class label (image datasets only).
        n_known_outlier_classes: forwarded to every dataset class.
        ratio_known_normal: forwarded to every dataset class.
        ratio_known_outlier: forwarded to every dataset class.
        ratio_pollution: forwarded to every dataset class.
        random_state: forwarded to ``ODDSADDataset`` only.
        resolution: image side length (image datasets only).

    Returns:
        The constructed dataset object.
    """
    # Keyword arguments accepted by every dataset class.
    common_kwargs = {
        'n_known_outlier_classes': n_known_outlier_classes,
        'ratio_known_normal': ratio_known_normal,
        'ratio_known_outlier': ratio_known_outlier,
        'ratio_pollution': ratio_pollution,
    }

    image_dataset_classes = {
        'mnist': MNIST_Dataset,
        'fmnist': FashionMNIST_Dataset,
        'cifar10': CIFAR10_Dataset,
    }

    if dataset_name not in image_dataset_classes:
        return ODDSADDataset(root=data_path, dataset_name=dataset_name, random_state=random_state,
                             **common_kwargs)

    dataset_cls = image_dataset_classes[dataset_name]
    return dataset_cls(root=data_path, normal_class=normal_class, known_outlier_class=known_outlier_class,
                       resolution=resolution, **common_kwargs)
59 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
# Public names exported by this module.
__all__ = ['CallbackContext', 'execute_replication_callbacks',
           'DataParallelWithCallback', 'patch_replication_callback']
21 |
22 |
class CallbackContext:
    """Empty attribute container handed to ``__data_parallel_replicate__`` callbacks.

    Callbacks may attach arbitrary attributes to it to share information between copies.
    """
25 |
26 |
def execute_replication_callbacks(modules):
    """
    Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.

    The callback is invoked as `__data_parallel_replicate__(ctx, copy_id)`, where `copy_id`
    is the position of the copy in `modules`.

    Sub-modules are matched across copies by their position in `module.modules()`. Every
    sub-module at the same position receives the same context object, so different copies
    can share information through it.

    Copies are visited in list order, so the callbacks on the master copy (`modules[0]`)
    run before the callbacks of any slave copy.
    """
    # One context per sub-module position, counted on the master copy.
    contexts = [CallbackContext() for _ in modules[0].modules()]

    for copy_id, replica in enumerate(modules):
        for position, submodule in enumerate(replica.modules()):
            if not hasattr(submodule, '__data_parallel_replicate__'):
                continue
            submodule.__data_parallel_replicate__(contexts[position], copy_id)
48 |
49 |
class DataParallelWithCallback(DataParallel):
    """
    `DataParallel` whose `replicate` also fires the replication callbacks.

    After the stock `DataParallel.replicate` has built the copies, each sub-module's
    `__data_parallel_replicate__(ctx, copy_id)` hook is invoked (if it defines one).

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        """Replicate `module` onto `device_ids`, then run the replication callbacks on the copies."""
        replicas = super().replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas
68 |
69 |
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object so its `replicate` runs the replication callbacks.
    Useful when you have customized `DataParallel` implementation.

    The instance's current `replicate` is wrapped (keeping its metadata via `functools.wraps`)
    and stored on the instance itself.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """
    assert isinstance(data_parallel, DataParallel)

    original_replicate = data_parallel.replicate

    @functools.wraps(original_replicate)
    def replicate_with_callbacks(module, device_ids):
        replicas = original_replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas

    data_parallel.replicate = replicate_with_callbacks
95 |
--------------------------------------------------------------------------------
/EAL-GAN/models/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
# Public names exported by this module.
__all__ = ['CallbackContext', 'execute_replication_callbacks',
           'DataParallelWithCallback', 'patch_replication_callback']
21 |
22 |
class CallbackContext:
    """Empty attribute container handed to ``__data_parallel_replicate__`` callbacks.

    Callbacks may attach arbitrary attributes to it to share information between copies.
    """
25 |
26 |
def execute_replication_callbacks(modules):
    """
    Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.

    The callback is invoked as `__data_parallel_replicate__(ctx, copy_id)`, where `copy_id`
    is the position of the copy in `modules`.

    Sub-modules are matched across copies by their position in `module.modules()`. Every
    sub-module at the same position receives the same context object, so different copies
    can share information through it.

    Copies are visited in list order, so the callbacks on the master copy (`modules[0]`)
    run before the callbacks of any slave copy.
    """
    # One context per sub-module position, counted on the master copy.
    contexts = [CallbackContext() for _ in modules[0].modules()]

    for copy_id, replica in enumerate(modules):
        for position, submodule in enumerate(replica.modules()):
            if not hasattr(submodule, '__data_parallel_replicate__'):
                continue
            submodule.__data_parallel_replicate__(contexts[position], copy_id)
48 |
49 |
class DataParallelWithCallback(DataParallel):
    """
    `DataParallel` whose `replicate` also fires the replication callbacks.

    After the stock `DataParallel.replicate` has built the copies, each sub-module's
    `__data_parallel_replicate__(ctx, copy_id)` hook is invoked (if it defines one).

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        """Replicate `module` onto `device_ids`, then run the replication callbacks on the copies."""
        replicas = super().replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas
68 |
69 |
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object so its `replicate` runs the replication callbacks.
    Useful when you have customized `DataParallel` implementation.

    The instance's current `replicate` is wrapped (keeping its metadata via `functools.wraps`)
    and stored on the instance itself.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """
    assert isinstance(data_parallel, DataParallel)

    original_replicate = data_parallel.replicate

    @functools.wraps(original_replicate)
    def replicate_with_callbacks(module, device_ids):
        replicas = original_replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas

    data_parallel.replicate = replicate_with_callbacks
95 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/datasets/preprocessing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
def create_semisupervised_setting(labels, normal_classes, outlier_classes, known_outlier_classes,
                                  ratio_known_normal, ratio_known_outlier, ratio_pollution):
    """
    Create a semi-supervised data setting.

    Samples are split into four disjoint groups: known (labeled) normals, unlabeled normals,
    unlabeled outliers (pollution) and known (labeled) outliers. Semi-supervised labels are
    1 for known normals, 0 for unlabeled samples and -1 for known outliers.

    :param labels: np.array with labels of all dataset samples
    :param normal_classes: tuple with normal class labels
    :param outlier_classes: tuple with anomaly class labels
    :param known_outlier_classes: tuple with known (labeled) anomaly class labels
    :param ratio_known_normal: the desired ratio of known (labeled) normal samples
    :param ratio_known_outlier: the desired ratio of known (labeled) anomalous samples
    :param ratio_pollution: the desired pollution ratio of the unlabeled data with unknown (unlabeled) anomalies.
    :return: tuple with list of sample indices, list of original labels, and list of semi-supervised labels,
             all three of the same length. If a group has fewer available samples than
             requested, it is smaller than requested.
    """
    idx_normal = np.argwhere(np.isin(labels, normal_classes)).flatten()
    idx_outlier = np.argwhere(np.isin(labels, outlier_classes)).flatten()
    idx_known_outlier_candidates = np.argwhere(np.isin(labels, known_outlier_classes)).flatten()

    n_normal = len(idx_normal)

    # Solve for x = [known normal, unlabeled normal, unlabeled outlier, known outlier]:
    #   known normal + unlabeled normal       == n_normal
    #   known normal                          == ratio_known_normal  * sum(x)
    #   known outlier                         == ratio_known_outlier * sum(x)
    #   unlabeled outlier                     == ratio_pollution * (unlabeled normal + unlabeled outlier)
    a = np.array([[1, 1, 0, 0],
                  [(1-ratio_known_normal), -ratio_known_normal, -ratio_known_normal, -ratio_known_normal],
                  [-ratio_known_outlier, -ratio_known_outlier, -ratio_known_outlier, (1-ratio_known_outlier)],
                  [0, -ratio_pollution, (1-ratio_pollution), 0]])
    b = np.array([n_normal, 0, 0, 0])
    x = np.linalg.solve(a, b)

    # Floor to whole samples. The small tolerance keeps round-off (e.g. 9.9999999999) from
    # silently dropping a sample, and max(..., 0) guards against tiny negative solutions.
    n_known_normal, n_unlabeled_normal, n_unlabeled_outlier, n_known_outlier = (
        max(int(np.floor(v + 1e-6)), 0) for v in x)

    # Sample indices
    perm_normal = np.random.permutation(n_normal)
    perm_outlier = np.random.permutation(len(idx_outlier))
    perm_known_outlier = np.random.permutation(len(idx_known_outlier_candidates))

    idx_known_normal = idx_normal[perm_normal[:n_known_normal]].tolist()
    idx_unlabeled_normal = idx_normal[perm_normal[n_known_normal:n_known_normal+n_unlabeled_normal]].tolist()
    idx_known_outlier = idx_known_outlier_candidates[perm_known_outlier[:n_known_outlier]].tolist()
    # Known-outlier candidates overlap the outlier pool. Skip samples already drawn as known
    # outliers so that no index appears twice with conflicting semi-supervised labels.
    already_known = set(idx_known_outlier)
    idx_unlabeled_outlier = [i for i in idx_outlier[perm_outlier].tolist()
                             if i not in already_known][:n_unlabeled_outlier]

    # Get original class labels
    labels_known_normal = labels[idx_known_normal].tolist()
    labels_unlabeled_normal = labels[idx_unlabeled_normal].tolist()
    labels_unlabeled_outlier = labels[idx_unlabeled_outlier].tolist()
    labels_known_outlier = labels[idx_known_outlier].tolist()

    # Semi-supervised labels, sized from the indices actually drawn, so every returned
    # list has the same length even when a group could not be filled.
    semi_labels_known_normal = [1] * len(idx_known_normal)
    semi_labels_unlabeled_normal = [0] * len(idx_unlabeled_normal)
    semi_labels_unlabeled_outlier = [0] * len(idx_unlabeled_outlier)
    semi_labels_known_outlier = [-1] * len(idx_known_outlier)

    # Create final lists
    list_idx = idx_known_normal + idx_unlabeled_normal + idx_unlabeled_outlier + idx_known_outlier
    list_labels = labels_known_normal + labels_unlabeled_normal + labels_unlabeled_outlier + labels_known_outlier
    list_semi_labels = (semi_labels_known_normal + semi_labels_unlabeled_normal + semi_labels_unlabeled_outlier
                        + semi_labels_known_outlier)

    return list_idx, list_labels, list_semi_labels
67 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Subset
2 | from PIL import Image
3 | from torchvision.datasets import CIFAR10
4 | from src.base.torchvision_dataset import TorchvisionDataset
5 | from src.datasets.preprocessing import create_semisupervised_setting
6 |
7 | import torch
8 | import torchvision.transforms as transforms
9 | import random
10 | import numpy as np
11 |
12 |
class CIFAR10_Dataset(TorchvisionDataset):
    """CIFAR-10 exposed as a binary anomaly-detection dataset (target 0: normal, 1: outlier).

    Only ``known_outlier_class`` is mapped to target 1; every other CIFAR-10 class maps to 0.
    Images are resized to ``resolution + 2``, center-cropped to ``resolution`` and converted
    to tensors. The full train and test splits are used. The semi-supervised subsetting step
    is disabled, so the ``ratio_*`` arguments are accepted but not used here.
    """

    def __init__(self, root: str, normal_class: int = 5, known_outlier_class: int = 3, n_known_outlier_classes: int = 0,
                 ratio_known_normal: float = 0.0, ratio_known_outlier: float = 0.0, ratio_pollution: float = 0.0,
                 resolution=32):
        super().__init__(root)

        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = (normal_class,)
        self.outlier_classes = (known_outlier_class,)

        if n_known_outlier_classes == 0:
            self.known_outlier_classes = ()
        elif n_known_outlier_classes == 1:
            self.known_outlier_classes = (known_outlier_class,)
        else:
            # NOTE(review): outlier_classes holds a single class, so random.sample raises
            # ValueError for any other n_known_outlier_classes.
            self.known_outlier_classes = tuple(random.sample(self.outlier_classes, n_known_outlier_classes))

        # CIFAR-10 preprocessing: resize slightly larger, center-crop, scale to [0, 1].
        transform = transforms.Compose([
            transforms.Resize(resolution + 2),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
        ])
        # Reads self.outlier_classes at call time.
        target_transform = transforms.Lambda(lambda label: int(label in self.outlier_classes))

        split_kwargs = dict(root=self.root, transform=transform, target_transform=target_transform, download=True)
        # The create_semisupervised_setting(...) + Subset(...) narrowing of the train split is
        # disabled, so both attributes hold the complete CIFAR-10 splits.
        self.train_set = MyCIFAR10(train=True, **split_kwargs)
        self.test_set = MyCIFAR10(train=False, **split_kwargs)
60 |
61 |
class MyCIFAR10(CIFAR10):
    """
    Torchvision CIFAR10 whose items also carry a semi-supervised target and the sample index.

    Each item is ``(image, target, semi_target, index)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One int64 semi-supervised label per sample, initialised to 0.
        self.semi_targets = torch.zeros(len(self.targets), dtype=torch.int64)

    def __getitem__(self, index):
        """Return ``(image, target, semi_target, index)`` for sample ``index``.

        The raw array is converted to a PIL Image before ``transform`` is applied, as for the
        other datasets. ``target_transform`` (if set) is applied to the class label.
        """
        raw = self.data[index]
        target = self.targets[index]
        semi_target = int(self.semi_targets[index])

        image = Image.fromarray(raw)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target, semi_target, index
94 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/datasets/mnist.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Subset
2 | from PIL import Image
3 | from torchvision.datasets import MNIST
4 | from src.base.torchvision_dataset import TorchvisionDataset
5 | from src.datasets.preprocessing import create_semisupervised_setting
6 |
7 | import torch
8 | import torchvision.transforms as transforms
9 | import random
10 |
11 |
class MNIST_Dataset(TorchvisionDataset):
    """MNIST exposed as a binary anomaly-detection dataset (target 0: normal, 1: outlier).

    Only ``known_outlier_class`` is mapped to target 1; every other digit maps to 0.
    Images are resized to ``resolution + 2``, center-cropped to ``resolution`` and converted
    to tensors. The full train and test splits are used. The semi-supervised subsetting step
    is disabled, so the ``ratio_*`` arguments are accepted but not used here.
    """

    def __init__(self, root: str, normal_class: int = 0, known_outlier_class: int = 1, n_known_outlier_classes: int = 0,
                 ratio_known_normal: float = 0.0, ratio_known_outlier: float = 0.0, ratio_pollution: float = 0.0,
                 resolution=32):
        super().__init__(root)

        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = (normal_class,)
        self.outlier_classes = (known_outlier_class,)

        if n_known_outlier_classes == 0:
            self.known_outlier_classes = ()
        elif n_known_outlier_classes == 1:
            self.known_outlier_classes = (known_outlier_class,)
        else:
            # NOTE(review): outlier_classes holds a single class, so random.sample raises
            # ValueError for any other n_known_outlier_classes.
            self.known_outlier_classes = tuple(random.sample(self.outlier_classes, n_known_outlier_classes))

        # MNIST preprocessing: resize slightly larger, center-crop, scale to [0, 1].
        transform = transforms.Compose([
            transforms.Resize(resolution + 2),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
        ])
        # Reads self.outlier_classes at call time.
        target_transform = transforms.Lambda(lambda label: int(label in self.outlier_classes))

        split_kwargs = dict(root=self.root, transform=transform, target_transform=target_transform, download=True)
        # The create_semisupervised_setting(...) + Subset(...) narrowing of the train split is
        # disabled, so both attributes hold the complete MNIST splits.
        self.train_set = MyMNIST(train=True, **split_kwargs)
        self.test_set = MyMNIST(train=False, **split_kwargs)
61 |
62 |
class MyMNIST(MNIST):
    """
    Torchvision MNIST whose items also carry a semi-supervised target and the sample index.

    Each item is ``(image, target, semi_target, index)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One semi-supervised label per sample, zero-initialised with the targets' dtype.
        self.semi_targets = torch.zeros_like(self.targets)

    def __getitem__(self, index):
        """Return ``(image, target, semi_target, index)`` for sample ``index``.

        The raw tensor is converted to a grayscale ('L') PIL Image before ``transform`` is
        applied. ``target_transform`` (if set) is applied to the integer class label.
        """
        raw = self.data[index]
        target = int(self.targets[index])
        semi_target = int(self.semi_targets[index])

        image = Image.fromarray(raw.numpy(), mode='L')
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target, semi_target, index
95 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/datasets/fmnist.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Subset
2 | from PIL import Image
3 | from torchvision.datasets import FashionMNIST
4 | from src.base.torchvision_dataset import TorchvisionDataset
5 | from src.datasets.preprocessing import create_semisupervised_setting
6 |
7 | import torch
8 | import torchvision.transforms as transforms
9 | import random
10 |
11 |
class FashionMNIST_Dataset(TorchvisionDataset):
    """FashionMNIST exposed as a binary anomaly-detection dataset (target 0: normal, 1: outlier).

    Only ``known_outlier_class`` is mapped to target 1; every other class maps to 0.
    Images are resized to ``resolution + 2``, center-cropped to ``resolution`` and converted
    to tensors. The full train and test splits are used. The semi-supervised subsetting step
    is disabled, so the ``ratio_*`` arguments are accepted but not used here.
    """

    def __init__(self, root: str, normal_class: int = 0, known_outlier_class: int = 1, n_known_outlier_classes: int = 0,
                 ratio_known_normal: float = 0.0, ratio_known_outlier: float = 0.0, ratio_pollution: float = 0.0,
                 resolution=32):
        super().__init__(root)

        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = (normal_class,)
        self.outlier_classes = (known_outlier_class,)

        if n_known_outlier_classes == 0:
            self.known_outlier_classes = ()
        elif n_known_outlier_classes == 1:
            self.known_outlier_classes = (known_outlier_class,)
        else:
            # NOTE(review): outlier_classes holds a single class, so random.sample raises
            # ValueError for any other n_known_outlier_classes.
            self.known_outlier_classes = tuple(random.sample(self.outlier_classes, n_known_outlier_classes))

        # FashionMNIST preprocessing: resize slightly larger, center-crop, scale to [0, 1].
        transform = transforms.Compose([
            transforms.Resize(resolution + 2),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
        ])
        # Reads self.outlier_classes at call time.
        target_transform = transforms.Lambda(lambda label: int(label in self.outlier_classes))

        split_kwargs = dict(root=self.root, transform=transform, target_transform=target_transform, download=True)
        # The create_semisupervised_setting(...) + Subset(...) narrowing of the train split is
        # disabled, so both attributes hold the complete FashionMNIST splits.
        self.train_set = MyFashionMNIST(train=True, **split_kwargs)
        self.test_set = MyFashionMNIST(train=False, **split_kwargs)
60 |
61 |
class MyFashionMNIST(FashionMNIST):
    """
    Torchvision FashionMNIST whose items also carry a semi-supervised target and the sample index.

    Each item is ``(image, target, semi_target, index)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One semi-supervised label per sample, zero-initialised with the targets' dtype.
        self.semi_targets = torch.zeros_like(self.targets)

    def __getitem__(self, index):
        """Return ``(image, target, semi_target, index)`` for sample ``index``.

        The raw tensor is converted to a grayscale ('L') PIL Image before ``transform`` is
        applied. ``target_transform`` (if set) is applied to the integer class label.
        """
        raw = self.data[index]
        target = int(self.targets[index])
        semi_target = int(self.semi_targets[index])

        image = Image.fromarray(raw.numpy(), mode='L')
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target, semi_target, index
94 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
# Public names exported by this module.
__all__ = [
    'FutureResult',
    'SlavePipe',
    'SyncMaster',
]
16 |
17 |
class FutureResult:
    """A thread-safe, single-slot future. Used only as one-to-one pipe.

    ``None`` marks an empty slot. ``get`` waits while the slot is empty, and ``put`` refuses
    to overwrite a value that has not been fetched yet.
    """

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        # The condition is built on _lock, so holding _lock is holding the condition.
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store ``result`` in the empty slot and wake one waiting ``get``."""
        with self._lock:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Wait if the slot is empty, then return the stored value and clear the slot."""
        with self._lock:
            if self._result is None:
                self._cond.wait()
            value, self._result = self._result, None
        return value
40 |
41 |
# Per-slave entry kept by the master: the FutureResult used to deliver that slave's reply.
_MasterRegistry = collections.namedtuple('MasterRegistry', 'result')
# Fields of a SlavePipe: slave identifier, the queue shared with the master, the slave's FutureResult.
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', 'identifier queue result')
44 |
45 |
class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication: (identifier, shared queue, reply future)."""

    def run_slave(self, msg):
        """Send ``msg`` to the master, wait for the reply, acknowledge it, and return the reply."""
        identifier, shared_queue, reply_future = self
        shared_queue.put((identifier, msg))
        reply = reply_future.get()
        # Acknowledgement consumed by the master after it has handed out all replies.
        shared_queue.put(True)
        return reply
54 |
55 |
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
    call `register(id)` and obtain an `SlavePipe` to communicate with the master.
    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
    and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine to message passed
    back to each slave devices.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        # Only the callback is pickled; the queue, registry and flag are rebuilt on unpickling.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register an slave device.

        The first registration after a `run_master` call clears the previous registry.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages were first collected from each devices (including the master device), and then
        an callback will be invoked to compute the message to be sent back to each devices
        (including the master device).

        Args:
            master_msg: the message that the master want to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for _ in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        # Drain one acknowledgement per slave. The get() must stay outside the assert: under
        # `python -O` asserts are stripped, the acknowledgements would then stay queued, and the
        # next forward pass would read them as slave messages.
        for _ in range(self.nr_slaves):
            ack = self._queue.get()
            assert ack is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of slaves currently registered."""
        return len(self._registry)
138 |
--------------------------------------------------------------------------------
/EAL-GAN/models/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
# Public names exported by this module.
__all__ = [
    'FutureResult',
    'SlavePipe',
    'SyncMaster',
]
16 |
17 |
class FutureResult:
    """A thread-safe, single-slot future. Used only as one-to-one pipe.

    ``None`` marks an empty slot. ``get`` waits while the slot is empty, and ``put`` refuses
    to overwrite a value that has not been fetched yet.
    """

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        # The condition is built on _lock, so holding _lock is holding the condition.
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store ``result`` in the empty slot and wake one waiting ``get``."""
        with self._lock:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Wait if the slot is empty, then return the stored value and clear the slot."""
        with self._lock:
            if self._result is None:
                self._cond.wait()
            value, self._result = self._result, None
        return value
40 |
41 |
# Per-slave entry kept by the master: the FutureResult used to deliver that slave's reply.
_MasterRegistry = collections.namedtuple('MasterRegistry', 'result')
# Fields of a SlavePipe: slave identifier, the queue shared with the master, the slave's FutureResult.
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', 'identifier queue result')
44 |
45 |
class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication: (identifier, shared queue, reply future)."""

    def run_slave(self, msg):
        """Send ``msg`` to the master, wait for the reply, acknowledge it, and return the reply."""
        identifier, shared_queue, reply_future = self
        shared_queue.put((identifier, msg))
        reply = reply_future.get()
        # Acknowledgement consumed by the master after it has handed out all replies.
        shared_queue.put(True)
        return reply
54 |
55 |
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During replication, data parallel triggers a callback on each module; every
      slave device should call `register_slave(id)` and obtain a `SlavePipe` to
      communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages
      from the slave devices are collected and passed to the registered callback.
    - After receiving the messages, the master device gathers the information and
      determines the message passed back to each slave device.
    """

    def __init__(self, master_callback):
        """Create a master with an empty slave registry.

        Args:
            master_callback: a callback invoked once messages from all slave devices
                have been collected. It receives a list of ``(identifier, message)``
                pairs, with the master first under identifier 0. It must return a list of
                ``(identifier, reply)`` pairs whose first entry belongs to the master.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        # Only the callback is pickled; the queue, registry and activation flag
        # are runtime state that __setstate__ rebuilds via __init__.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """Register a slave device.

        If a forward pass has already run (``run_master`` was called), this starts a
        new round: the previous registrations are discarded first.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """Main entry for the master device in each forward pass.

        The messages are first collected from every device (including the master),
        then the callback computes the message sent back to each device (including
        the master).

        Args:
            master_msg: the message the master sends to itself. It is placed first
                in the list passed to `master_callback`. See `_SynchronizedBatchNorm`
                for an example.

        Returns: the message to be sent back to the master device.
        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for _ in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        # Wait until every slave has acknowledged its reply. The get() must not sit
        # inside the ``assert`` itself: asserts are stripped under ``python -O``.
        # That would skip the wait and leave stale ``True`` tokens in the queue,
        # which the next forward pass would then read as slave messages.
        for _ in range(self.nr_slaves):
            ack = self._queue.get()
            assert ack is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of slave devices registered for the current round."""
        return len(self._registry)
138 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/base/odds_dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from torch.utils.data import Dataset
3 | from scipy.io import loadmat
4 | from sklearn.model_selection import train_test_split
5 | from sklearn.preprocessing import StandardScaler, MinMaxScaler
6 | from torchvision.datasets.utils import download_url
7 |
8 | import pandas as pd
9 |
10 | import os
11 | import torch
12 | import numpy as np
13 |
14 |
class ODDSDataset(Dataset):
    """
    ODDSDataset class for datasets from Outlier Detection DataSets (ODDS): http://odds.cs.stonybrook.edu/

    Dataset class with additional targets for the semi-supervised setting and modification of __getitem__ method
    to also return the semi-supervised target as well as the index of a data sample.

    The ``.mat`` file must contain a feature matrix ``X`` and a label vector ``y``
    (0 = inlier, 1 = outlier).
    """

    urls = {
        'arrhythmia': 'https://www.dropbox.com/s/lmlwuspn1sey48r/arrhythmia.mat?dl=1',
        'cardio': 'https://www.dropbox.com/s/galg3ihvxklf0qi/cardio.mat?dl=1',
        'satellite': 'https://www.dropbox.com/s/dpzxp8jyr9h93k5/satellite.mat?dl=1',
        'satimage-2': 'https://www.dropbox.com/s/hckgvu9m6fs441p/satimage-2.mat?dl=1',
        'shuttle': 'https://www.dropbox.com/s/mk8ozgisimfn3dw/shuttle.mat?dl=1',
        'thyroid': 'https://www.dropbox.com/s/bih0e15a0fukftb/thyroid.mat?dl=1'
    }

    def __init__(self, root: str, dataset_name: str, train=True, random_state=None, download=False):
        """Load ``<root>/<dataset_name>.mat`` and keep either its train or its test split.

        Args:
            root: directory holding the ``.mat`` file; ``~`` is expanded for strings.
            dataset_name: file stem, also the key into ``urls`` when downloading.
            train: keep the 60% training split if True, otherwise the 40% test split.
            random_state: forwarded to ``train_test_split``.
            download: fetch the file from ``urls`` first if it is missing.
        """
        super().__init__()

        self.classes = [0, 1]

        self.root = self._resolve_root(root)
        self.dataset_name = dataset_name
        self.train = train  # training set or test set
        self.file_name = self.dataset_name + '.mat'
        self.data_file = self.root / self.file_name

        if download:
            self.download()

        mat = loadmat(self.data_file)
        X = mat['X']
        y = mat['y'].ravel()

        # 60% of the data for training and 40% for testing. This is a plain random
        # split, not stratified by class. The outlier ratio of each split therefore
        # only matches the full data set in expectation.
        X_train, X_test, y_train, y_test = \
            train_test_split(X, y, test_size=0.4, random_state=random_state)

        # Standardize data (per-feature Z-normalization, i.e. zero mean and unit
        # variance) using statistics of the training split only.
        scaler = StandardScaler().fit(X_train)
        X_train_scaled = scaler.transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        # Fill remaining NaNs in both splits with the training-split column means.
        X_train_pandas = pd.DataFrame(X_train_scaled)
        X_test_pandas = pd.DataFrame(X_test_scaled)
        X_train_pandas.fillna(X_train_pandas.mean(), inplace=True)
        X_test_pandas.fillna(X_train_pandas.mean(), inplace=True)
        X_train_scaled = X_train_pandas.values
        X_test_scaled = X_test_pandas.values

        if self.train:
            self.data = torch.tensor(X_train_scaled, dtype=torch.float32)
            self.targets = torch.tensor(y_train, dtype=torch.int64)
        else:
            self.data = torch.tensor(X_test_scaled, dtype=torch.float32)
            self.targets = torch.tensor(y_test, dtype=torch.int64)

        # Every sample starts with semi-supervised target 0.
        self.semi_targets = torch.zeros_like(self.targets)

    @staticmethod
    def _resolve_root(root):
        """Return ``root`` as a Path, expanding ``~`` when it is given as a string.

        A plain ``isinstance(root, str)`` replaces ``torch._six.string_classes``,
        which no longer exists in recent PyTorch releases.
        """
        if isinstance(root, str):
            root = os.path.expanduser(root)
        return Path(root)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target, semi_target, index)
        """
        sample, target, semi_target = self.data[index], int(self.targets[index]), int(self.semi_targets[index])

        return sample, target, semi_target, index

    def __len__(self):
        return len(self.data)

    def _check_exists(self):
        return os.path.exists(self.data_file)

    def download(self):
        """Download the ODDS dataset if it doesn't exist in root already."""

        if self._check_exists():
            return

        # download file
        download_url(self.urls[self.dataset_name], self.root, self.file_name)

        print('Done!')
126 |
--------------------------------------------------------------------------------
/EAL-GAN/Train_EAL_GAN.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import sys
7 | from time import time
8 |
# Make the parent of this script's directory importable. Use the real __file__
# (not the string "__file__", whose dirname is always "") so the path is anchored
# at the script location rather than the current working directory.
try:
    _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:  # interactive sessions / notebooks define no __file__
    _SCRIPT_DIR = os.getcwd()
sys.path.append(os.path.abspath(os.path.join(_SCRIPT_DIR, '..')))
# suppress warnings for clean output
import warnings

warnings.filterwarnings("ignore")
15 |
16 | import numpy as np
17 | import pandas as pd
18 | from sklearn.model_selection import train_test_split
19 | from scipy.io import loadmat
20 |
21 | import torch
22 |
23 | from models.pyod_utils import standardizer, AUC_and_Gmean
24 | from models.pyod_utils import precision_n_scores, gmean_scores
25 | from models.utils import parse_args
26 | from sklearn.metrics import roc_auc_score
27 | from models.EAL_GAN import EAL_GAN
28 |
def stratified_split(X, y, test_size, random_state):
    """Split inliers (y == 0) and outliers (y == 1) separately so both splits keep the outlier ratio.

    The inlier split is drawn first, then the outlier split, both from the same
    ``random_state``. Returns ``(X_train, X_test, y_train, y_test)`` with the
    inliers placed before the outliers in every array.
    """
    is_inlier = y == 0
    is_outlier = y == 1
    X_train_in, X_test_in, y_train_in, y_test_in = train_test_split(
        X[is_inlier], y[is_inlier], test_size=test_size, random_state=random_state)
    X_train_out, X_test_out, y_train_out, y_test_out = train_test_split(
        X[is_outlier], y[is_outlier], test_size=test_size, random_state=random_state)
    return (np.concatenate((X_train_in, X_train_out)),
            np.concatenate((X_test_in, X_test_out)),
            np.concatenate((y_train_in, y_train_out)),
            np.concatenate((y_test_in, y_test_out)))


def impute_with_train_mean(X_train, X_test):
    """Fill NaNs in both splits with the per-column means of the training split.

    Only training statistics are used, so the test split does not leak into
    preprocessing. Columns that are entirely NaN in the training split stay NaN.
    The inputs are not modified. Returns ``(X_train_filled, X_test_filled)`` as
    NumPy arrays.
    """
    train_df = pd.DataFrame(X_train)
    test_df = pd.DataFrame(X_test)
    train_means = train_df.mean()
    return train_df.fillna(train_means).values, test_df.fillna(train_means).values


if __name__ == '__main__':
    # Define data file and read X and y
    data_root = './data'
    save_dir = './results'

    # Uncomment entries to benchmark additional ODDS datasets.
    data_names = [
        # 'arrhythmia.mat',
        # 'cardio.mat',
        # 'glass.mat',
        # 'ionosphere.mat',
        # 'letter.mat',
        # 'lympho.mat',
        # 'mnist.mat',
        # 'musk.mat',
        'optdigits.mat',
        # 'pendigits.mat',
        # 'pima.mat',
        # 'satellite.mat',
        # 'satimage-2.mat',
        # 'shuttle.mat',
        # 'vowels.mat',
        # 'annthyroid.mat',
        # 'campaign.mat',
        # 'celeba.mat',
        'fraud.mat',
        'donors.mat',
    ]
    # define the number of iterations
    n_ite = 10
    n_classifiers = 1

    df_columns = ['Data', '#Samples', '# Dimensions', 'Outlier Perc',
                  'AUC_Mean', 'AUC_Std', 'Gmean', 'Gmean_Std']

    # One row per dataset. The DataFrame is rebuilt from these rows so its numeric
    # columns get numeric dtypes and ``float_format`` applies when saving.
    result_rows = []
    args = parse_args()

    os.makedirs(save_dir, exist_ok=True)
    save_path1 = os.path.join(save_dir, "AUC_EAL_GAN.csv")

    for data_name in data_names:
        mat = loadmat(os.path.join(data_root, data_name))

        X = mat['X']
        # np.long was removed in NumPy 1.24; int64 is the dtype it used to alias here.
        y = mat['y'].ravel().astype(np.int64)

        outliers_fraction = np.count_nonzero(y) / len(y)
        print(outliers_fraction)
        outliers_percentage = round(outliers_fraction * 100, ndigits=4)
        dataset_label = os.path.splitext(data_name)[0]

        roc_mat = np.zeros([n_ite, n_classifiers])
        gmean_mat = np.zeros([n_ite, n_classifiers])

        for i in range(n_ite):
            print("\n... Processing", dataset_label, '...', 'Iteration', i + 1)
            random_state = np.random.RandomState(i)

            # 60% data for training and 40% for testing; keep outlier ratio
            X_train, X_test, y_train, y_test = stratified_split(
                X, y, test_size=0.4, random_state=random_state)

            X_train_std, X_test_std = standardizer(X_train, X_test)
            X_train_std, X_test_std = impute_with_train_mean(X_train_std, X_test_std)

            data_x = torch.from_numpy(X_train_std).float()
            data_y = torch.from_numpy(y_train).long()
            test_x = torch.from_numpy(X_test_std).float()
            test_y = torch.from_numpy(y_test).long()

            t0 = time()
            eal_gan = EAL_GAN(args, data_x, data_y, test_x, test_y)
            best_auc, best_gmean = eal_gan.fit()
            duration = round(time() - t0, ndigits=4)

            print('AUC:%.4f, Gmean:%.4f execution time: %.4f s' % (best_auc, best_gmean, duration))

            roc_mat[i, 0] = best_auc
            gmean_mat[i, 0] = best_gmean

        result_rows.append(
            [dataset_label, X.shape[0], X.shape[1], outliers_percentage]
            + np.mean(roc_mat, axis=0).tolist() + np.std(roc_mat, axis=0).tolist()
            + np.mean(gmean_mat, axis=0).tolist() + np.std(gmean_mat, axis=0).tolist())

        # Save the results after each dataset, so finished datasets survive a later crash
        roc_df = pd.DataFrame(result_rows, columns=df_columns)
        roc_df.to_csv(save_path1, index=False, float_format='%.4f')
152 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/my_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn.metrics import roc_auc_score
4 |
5 | import argparse
6 |
def active_sampling(args, real_x, real_y, NetD_Ensemble, need_sample=True):
    """Select the part of a real batch whose labels are handed to the discriminators.

    Args:
        args: namespace providing ``ensemble_num`` (number of discriminators whose
            scores are averaged) and ``active_rate`` (fraction of the batch to select).
        real_x: batch of real inputs; the first dimension is the batch.
        real_y: labels aligned with ``real_x``.
        NetD_Ensemble: indexable collection of discriminators, each called as
            ``netD(real_x, mode=1)``. Each is assumed to return one confidence per
            instance, presumably in [0, 1] (TODO confirm).
        need_sample: when True, select the instances whose averaged confidence is
            closest to 0.5 (lowest margin). When False, simply take the first
            instances of the batch.

    Returns:
        ``(X, Y, X_unlabeled, idx)``: the selected inputs, their labels (flattened in
        the sampling branch), the remaining inputs (``None`` when ``need_sample`` is
        False), and the batch indices of the selected rows. ``idx`` is on the same
        device as ``real_x`` in both branches.
    """
    # Always select at least one instance, even for tiny batches or rates.
    n_selected = max(1, int(real_x.shape[0] * args.active_rate))

    if not need_sample:
        X = real_x[0:n_selected].detach()
        Y = real_y[0:n_selected].detach()
        idx = torch.arange(n_selected, dtype=torch.long, device=real_x.device)
        return X, Y, None, idx

    # Average the ensemble's confidence on the real batch. Only the scores are
    # needed, so no autograd graph is built. The out-of-place sum leaves the
    # discriminator outputs untouched.
    with torch.no_grad():
        pt = None
        for i in range(args.ensemble_num):
            score = NetD_Ensemble[i](real_x, mode=1)
            pt = score if pt is None else pt + score
    pt = pt / args.ensemble_num
    pt = pt.view(pt.shape[0],)

    # Smallest |confidence - 0.5| = most uncertain instances come first.
    margin = torch.abs(pt - 0.5)
    _, order = torch.sort(margin, descending=False)
    chosen = order[0:n_selected]
    X = real_x[chosen].detach()
    Y = real_y[chosen].detach()
    Y = Y.view(Y.shape[0],)
    X_unlabeled = real_x[order[n_selected:]].detach()
    return X, Y, X_unlabeled, chosen
37 |
def get_gmean(y, y_pred, threshold=0.5):
    """Compute the G-mean: the geometric mean of the per-class recalls.

    Scores are binarized with ``y_pred >= threshold`` (1 = outlier). Then
    G-mean = sqrt(recall on outliers * recall on inliers).

    Parameters
    ----------
    y : list or numpy array of shape (n_samples,)
        The ground truth. Binary (0: inliers, 1: outliers).

    y_pred : list or numpy array of shape (n_samples,)
        The raw outlier scores as returned by a fitted model.

    threshold : float, optional (default=0.5)
        Scores greater than or equal to this value are predicted as outliers.

    Returns
    -------
    Gmean: float
        Evaluates to nan when ``y`` contains no outliers or no inliers (0/0 recall).
    """
    # np.asarray accepts plain lists, as documented above. reshape(-1) also
    # flattens column vectors such as (n_samples, 1) score arrays.
    y = np.asarray(y).reshape(-1)
    y_pred = (np.asarray(y_pred).reshape(-1) >= threshold).astype('int')

    ones_all = (y == 1).sum()
    ones_correct = ((y == 1) & (y_pred == 1)).sum()
    zeros_all = (y == 0).sum()
    zeros_correct = ((y == 0) & (y_pred == 0)).sum()
    Gmean = np.sqrt((1.0 * ones_correct / ones_all) * (1.0 * zeros_correct / zeros_all))

    return Gmean
66 |
def AUC_and_Gmean(y_test, y_scores):
    """Return ``(auc, gmean)`` for the given labels and scores, each rounded to 4 decimals.

    The AUC comes from ``roc_auc_score``. The G-mean comes from ``get_gmean`` with
    a fixed 0.5 decision threshold.
    """
    raw_metrics = (
        roc_auc_score(y_test, y_scores),
        get_gmean(y_test, y_scores, 0.5),
    )
    return tuple(round(value, ndigits=4) for value in raw_metrics)
75 |
def parse_args(argv=None):
    """Build the experiment's command-line parser and parse the arguments.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the parsed settings.
    """

    def str2bool(value):
        """Parse a boolean option.

        ``type=bool`` cannot be used: ``bool("False")`` is True, because any
        non-empty string is truthy.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got %r.' % (value,))

    ################################################################################
    # Settings
    ################################################################################
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset_name', default='cardio', type=str)
    parser.add_argument('--net_name', default='cardio_mlp', type=str)
    parser.add_argument('--xp_path', default='./log', type=str)
    parser.add_argument('--data_path', default='./data', type=str)
    parser.add_argument('--load_config', default=None,
                        help='Config JSON-file path (default: None).')
    parser.add_argument('--load_model', default=None,
                        help='Model file path (default: None).')
    parser.add_argument('--eta', type=float, default=1.0, help='Deep SAD hyperparameter eta (must be 0 < eta).')
    parser.add_argument('--ratio_known_normal', type=float, default=0.9,
                        help='Ratio of known (labeled) normal training examples.')
    parser.add_argument('--ratio_known_outlier', type=float, default=0.5,
                        help='Ratio of known (labeled) anomalous training examples.')
    parser.add_argument('--ratio_pollution', type=float, default=0.0,
                        help='Pollution ratio of unlabeled training data with unknown (unlabeled) anomalies.')
    parser.add_argument('--device', type=str, default='cuda', help='Computation device to use ("cpu", "cuda", "cuda:2", etc.).')
    parser.add_argument('--seed', type=int, default=0, help='Set seed. If -1, use randomization.')
    parser.add_argument('--weight_decay', type=float, default=1e-6,
                        help='Weight decay (L2 penalty) hyperparameter for Deep SAD objective.')

    parser.add_argument('--n_jobs_dataloader', type=int, default=0,
                        help='Number of workers for data loading. 0 means that the data will be loaded in the main process.')
    parser.add_argument('--normal_class', type=int, default=0,
                        help='Specify the normal class of the dataset (all other classes are considered anomalous).')
    parser.add_argument('--known_outlier_class', type=int, default=1,
                        help='Specify the known outlier class of the dataset for semi-supervised anomaly detection.')
    parser.add_argument('--n_known_outlier_classes', type=int, default=1,
                        help='Number of known outlier classes. '
                             'If 0, no anomalies are known. '
                             'If 1, outlier class as specified in --known_outlier_class option. '
                             'If > 1, the specified number of outlier classes will be sampled at random.')

    # parameters for the GAN
    parser.add_argument('--channel', type=int, default=64,
                        help='capacity of the generator and discriminator')
    parser.add_argument('--resolution', type=int, default=32,
                        help='resolution of the images')
    parser.add_argument('--gamma', type=float, default=0.1,
                        help='gamma for the scheduler')
    parser.add_argument('--step_size', type=int, default=10,
                        help='step_size for the scheduler to adjust learning rate')
    parser.add_argument('--max_epochs', type=int, default=20,
                        help='Stop training generator after stop_epochs.')
    parser.add_argument('--lr_g', type=float, default=0.0001,
                        help='Learning rate of generator.')
    parser.add_argument('--lr_d', type=float, default=0.0001,
                        help='Learning rate of discriminator.')
    parser.add_argument('--active_rate', type=float, default=0.05,
                        help='the proportion of instances that need to be labeled.')
    parser.add_argument('--batch_size', type=int, default=200,
                        help='batch size.')
    parser.add_argument('--dim_z', type=int, default=128,
                        help='dim for latent noise.')
    parser.add_argument('--ensemble_num', type=int, default=5,
                        help='the number of dis in ensemble.')
    parser.add_argument('--cuda', type=str2bool, default=True,
                        help='if GPU used')
    parser.add_argument('--feat_dim', type=int, default=3,
                        help='channel number of images')

    return parser.parse_args(argv)
144 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/EAL_GAN.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import functools
3 | from numpy.random import gamma
4 | import torch
5 | import torch.nn as nn
6 | from torch.optim import Adam
7 | #import torch.optim.lr_scheduler.StepLR as StepLR
8 |
9 | import numpy as np
10 | import pandas as pd
11 | from tqdm import tqdm
12 |
13 |
14 | from src.BigGANdeep import Generator, Discriminator
15 | from src.my_utils import active_sampling, AUC_and_Gmean
16 | from src.loss import loss_dis_fake, loss_dis_real
17 | import src.utils as utils
18 |
class EAL_GAN(nn.Module):
    """Ensemble active-learning GAN (image variant).

    Holds one conditional generator and ``args.ensemble_num`` discriminators. Each
    network has its own Adam optimizer and StepLR scheduler. With ``args.cuda``, the
    networks are moved to the GPU and wrapped in ``nn.DataParallel``.

    Args:
        args: options namespace. Fields read here: ``cuda``, ``channel``, ``dim_z``,
            ``resolution``, ``feat_dim``, ``lr_g``, ``lr_d``, ``step_size``, ``gamma``,
            ``ensemble_num``, ``batch_size`` and ``max_epochs``.
        dataset: object whose ``loaders(batch_size=..., num_workers=...)`` returns
            ``(train_loader, test_loader)``. Each loader yields
            ``(inputs, targets, semi_targets, index)`` tuples.
    """

    def __init__(self, args, dataset):
        super().__init__()
        self.args = args
        self.dataset = dataset
        self.device = 'cuda' if args.cuda else 'cpu'

        generator = Generator(G_ch=args.channel, dim_z=args.dim_z, resolution=args.resolution, n_classes=2, output_channel=args.feat_dim)
        # Move to the GPU *before* building the optimizer, as the ``Module.cuda``
        # documentation recommends, so the optimizer tracks the device parameters.
        if args.cuda:
            generator = generator.cuda()
        self.optimizer_g = Adam(generator.parameters(), lr=args.lr_g, betas=(0.00, 0.99))
        self.scheduler_g = torch.optim.lr_scheduler.StepLR(self.optimizer_g, step_size=args.step_size, gamma=args.gamma)
        self.generator = nn.DataParallel(generator) if args.cuda else generator

        # NOTE(review): the discriminators live in plain lists, so they are not
        # registered as submodules of this nn.Module. The schedulers are also never
        # stepped inside this class.
        self.NetD_Ensemble = []
        self.Opti_Ensemble = []
        self.schD_Ensemble = []
        # NOTE(review): per-discriminator learning rates in [lr_d, 5*lr_d) are drawn
        # here but never used; every optimizer below gets args.lr_d. Confirm whether
        # ``lr=lr_ds[index]`` was intended. The draw is kept so the NumPy RNG stream
        # is unchanged.
        lr_ds = np.random.rand(args.ensemble_num)*(args.lr_d*5-args.lr_d)+args.lr_d
        for _ in range(args.ensemble_num):
            netD = Discriminator(D_ch=args.channel, resolution=args.resolution, n_classes=2, input_channel=args.feat_dim)
            if args.cuda:
                netD = netD.cuda()
            optimizerD = Adam(netD.parameters(), lr=args.lr_d, betas=(0.00, 0.99))
            schedule_D = torch.optim.lr_scheduler.StepLR(optimizerD, step_size=args.step_size, gamma=args.gamma)
            if args.cuda:
                netD = nn.DataParallel(netD)

            self.NetD_Ensemble += [netD]
            self.Opti_Ensemble += [optimizerD]
            self.schD_Ensemble += [schedule_D]

    def fit(self):
        """Train for ``args.max_epochs`` epochs and report the test metrics of the best epoch.

        The best epoch is the one with the highest train AUC * train G-mean; a NaN
        product counts as the worst possible value. The first epoch always becomes the
        initial best. The best epoch's weights are saved to
        ``./logs/checkpoint_best.pth``. They are not reloaded: the networks keep the
        weights of the final epoch.

        Returns:
            ``(best_auc, best_gmean)`` on the test split at the best epoch, or
            ``(None, None)`` if ``args.max_epochs`` is 0.
        """
        import os  # local import: this module's header does not import os

        self.iterations = 0

        self.z, self.y = utils.prepare_z_y(self.args.batch_size, dim_z=self.args.dim_z, nclasses=2, device=self.device)
        z_for_sample, y_for_sample = utils.prepare_z_y(self.args.batch_size, dim_z=self.args.dim_z, nclasses=2, device=self.device)
        self.sample = functools.partial(utils.sample, G=self.generator, z_=z_for_sample, y_=y_for_sample)
        self.train_history = defaultdict(list)

        checkpoint_path = './logs/checkpoint_best.pth'
        best_epoch = None
        best_measure = float('-inf')
        best_auc, best_gmean = None, None
        for epoch in range(self.args.max_epochs):
            auc_train, gmean_train, auc_test, gmean_test = self.train_one_epoch(epoch)
            measure = auc_train * gmean_train
            if np.isnan(measure):
                measure = float('-inf')
            if best_epoch is None or measure > best_measure:
                best_epoch = epoch
                best_measure = measure
                best_auc, best_gmean = auc_test, gmean_test

                states = {
                    'epoch': epoch,
                    'gen_dict': self.generator.state_dict(),
                    'auc_train': auc_train,
                    'auc_test': auc_test
                }
                for i in range(self.args.ensemble_num):
                    states['dis_dict' + str(i)] = self.NetD_Ensemble[i].state_dict()

                os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
                torch.save(states, checkpoint_path)
            print('Training for epoch %d: Train_AUC=%.4f train_Gmean=%.4f Test_AUC=%.4f Test_Gmean=%.4f' % (epoch + 1, auc_train, gmean_train, auc_test, gmean_test))

        return best_auc, best_gmean

    def train_one_epoch(self, epoch=1):
        """Run one pass over the training loader, then score the train and test splits.

        For every batch:
        1. Draw fakes and update each discriminator on the actively selected real
           instances plus the fakes.
        2. Update the generator against the whole ensemble.

        Args:
            epoch: current epoch index. It is not read in this method.

        Returns:
            ``(train_auc, train_gmean, test_auc, test_gmean)``.
        """
        train_loader, test_loader = self.dataset.loaders(batch_size=self.args.batch_size, num_workers=0)
        for (inputs, targets, _semi_targets, _) in tqdm(train_loader):
            self.iterations += 1
            if self.args.cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()

            # step 1: update the discriminators
            self.z.sample_()
            self.y.sample_()
            generated_x = self.generator(self.z, self.y)
            generated_x = generated_x.detach()

            real_weights = None
            fake_weights = None
            real_cat_weights = None
            fake_cat_weights = None

            # Accumulated as a plain float for logging. Summing the loss tensors would
            # keep every discriminator's autograd history alive.
            dis_loss = 0.0
            # select p% of the training data and label them
            real_x_selected, real_y_selected, _, index_selected = active_sampling(self.args, inputs, targets, self.NetD_Ensemble, need_sample=(self.iterations > 1))

            for i in range(self.args.ensemble_num):
                optimizer = self.Opti_Ensemble[i]
                netD = self.NetD_Ensemble[i]
                out_real_fake, out_real_category = netD(real_x_selected, real_y_selected)

                loss_adv_real, real_loss_cat, real_weights, real_cat_weights = loss_dis_real(out_real_fake, out_real_category, real_y_selected, real_weights, real_cat_weights)
                real_loss = loss_adv_real + real_loss_cat

                # train on fake data
                output, out_fake_category = netD(generated_x, self.y.detach())
                loss_adv_fake, loss_cat_fake, fake_weights, fake_cat_weights = loss_dis_fake(output, out_fake_category, self.y, fake_weights, fake_cat_weights)
                fake_loss = loss_adv_fake + loss_cat_fake
                sum_loss = real_loss + fake_loss
                dis_loss += sum_loss.item()

                self.train_history['discriminator_loss_' + str(i)].append(sum_loss.item())
                optimizer.zero_grad()
                sum_loss.backward()
                optimizer.step()
            self.train_history['discriminator_loss'].append(dis_loss)

            # step 2: train the generator against every discriminator
            self.z.sample_()
            self.y.sample_()
            generated_x = self.generator(self.z, self.y)

            gen_loss = 0
            gen_weights = None
            gen_cat_weights = None
            for i in range(self.args.ensemble_num):
                netD = self.NetD_Ensemble[i]
                output, out_category = netD(generated_x, self.y)
                # The generator is optimized with the discriminators' "real" loss on its fakes.
                loss, loss_cat, gen_weights, gen_cat_weights = loss_dis_real(output, out_category, self.y, gen_weights, gen_cat_weights)
                gen_loss += (loss + loss_cat)
            self.train_history['generator_loss'].append(float(gen_loss))
            self.optimizer_g.zero_grad()
            gen_loss.backward()
            self.optimizer_g.step()

            if self.iterations % 10 == 0:
                print("dis_loss=%.4f gen_loss=%.4f" % (dis_loss, float(gen_loss)))

            torch.cuda.empty_cache()

        # NaN scores (if any) are replaced with 0 before computing the metrics.
        y_train, y_scores = self.predict(train_loader, self.NetD_Ensemble)
        y_scores = pd.DataFrame(y_scores).fillna(0).values

        auc, gmean = AUC_and_Gmean(y_train, y_scores)
        self.train_history['train_auc'].append(auc)
        self.train_history['train_Gmean'].append(gmean)

        y_test, y_scores_test = self.predict(test_loader, self.NetD_Ensemble)
        y_scores_test = pd.DataFrame(y_scores_test).fillna(0).values

        test_auc, test_gmean = AUC_and_Gmean(y_test, y_scores_test)

        return auc, gmean, test_auc, test_gmean

    def predict(self, data_loader, dis_Ensemble=None):
        """Score every sample in ``data_loader`` with the averaged ensemble confidence.

        Args:
            data_loader: yields ``(inputs, labels, semi_targets, index)`` tuples.
            dis_Ensemble: discriminators to average. If None, ``self.Best_Ensemble`` is
                used when that attribute has been set, otherwise the current
                ``self.NetD_Ensemble``.

        Returns:
            ``(labels, scores)`` as NumPy arrays concatenated along the first axis.
        """
        if dis_Ensemble is None:
            dis_Ensemble = getattr(self, 'Best_Ensemble', self.NetD_Ensemble)
        p = []
        y = []
        # Inference only: no autograd graph is needed for the scores.
        with torch.no_grad():
            for (real_x, real_y, _, _index) in tqdm(data_loader):
                final_pt = 0
                for i in range(self.args.ensemble_num):
                    pt = dis_Ensemble[i](real_x, mode=1).cpu().numpy()
                    final_pt = pt if i == 0 else final_pt + pt
                final_pt /= self.args.ensemble_num

                p += [final_pt]
                y += [real_y.cpu().numpy()]
        p = np.concatenate(p, 0)
        y = np.concatenate(y, 0)
        return y, p
207 |
208 |
209 |
210 |
211 |
212 |
--------------------------------------------------------------------------------
/EAL-GAN/models/pyod_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """A set of utility functions to support outlier detection.
3 | """
4 | # Author: Yue Zhao
5 | # License: BSD 2 clause
6 |
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | import numpy as np
11 | from numpy import percentile
12 | import numbers
13 | import math
14 |
15 | import sklearn
16 | from sklearn.metrics import precision_score
17 | from sklearn.preprocessing import StandardScaler
18 |
19 | from sklearn.utils import column_or_1d
20 | from sklearn.utils import check_array
21 | from sklearn.utils import check_consistent_length
22 |
23 | from sklearn.utils import check_random_state
24 | from sklearn.utils.random import sample_without_replacement
25 | from sklearn.metrics import roc_auc_score
26 |
MAX_INT = np.iinfo(np.int32).max
MIN_INT = -1 * MAX_INT

# Types that check_parameter accepts as "numerical". ``np.float`` (an alias of the
# builtin ``float``) was removed in NumPy 1.24, so the builtin is listed directly.
# ``np.floating`` additionally admits NumPy float scalars such as ``np.float32``.
_NUMERIC_TYPES = (numbers.Integral, np.integer, float, np.floating)


def check_parameter(param, low=MIN_INT, high=MAX_INT, param_name='',
                    include_left=False, include_right=False):
    """Check if an input is within the defined range.

    Parameters
    ----------
    param : int, float
        The input parameter to check.

    low : int, float
        The lower bound of the range.

    high : int, float
        The higher bound of the range.

    param_name : str, optional (default='')
        The name of the parameter.

    include_left : bool, optional (default=False)
        Whether includes the lower bound (lower bound <=).

    include_right : bool, optional (default=False)
        Whether includes the higher bound (<= higher bound).

    Returns
    -------
    within_range : bool or raise errors
        Whether the parameter is within the range of (low, high)

    Raises
    ------
    TypeError
        If ``param``, ``low`` or ``high`` is not numerical.
    ValueError
        If neither bound is given, if ``low > high``, or if ``param`` is out of range.
    """

    # param, low and high should all be numerical
    if not isinstance(param, _NUMERIC_TYPES):
        raise TypeError('{param_name} is set to {param} Not numerical'.format(
            param=param, param_name=param_name))

    if not isinstance(low, _NUMERIC_TYPES):
        raise TypeError('low is set to {low}. Not numerical'.format(low=low))

    if not isinstance(high, _NUMERIC_TYPES):
        raise TypeError('high is set to {high}. Not numerical'.format(
            high=high))

    # At least one of the bounds should be specified. Identity (not equality) is
    # deliberate: only the untouched default objects trigger this error.
    if low is MIN_INT and high is MAX_INT:
        raise ValueError('Neither low nor high bounds is undefined')

    # if wrong bound values are used
    if low > high:
        raise ValueError(
            'Lower bound > Higher bound')

    # Inclusive bounds only reject values strictly outside; exclusive bounds also
    # reject the bound itself.
    below = param < low if include_left else param <= low
    above = param > high if include_right else param >= high
    if below or above:
        left_bracket = '[' if include_left else '('
        right_bracket = ']' if include_right else ')'
        raise ValueError(
            '{param_name} is set to {param}. '
            'Not in the range of {lb}{low}, {high}{rb}.'.format(
                param=param, low=low, high=high, param_name=param_name,
                lb=left_bracket, rb=right_bracket))
    return True
112 |
113 |
def check_detector(detector):
    """Verify that ``detector`` exposes both ``fit`` and ``decision_function``.

    Parameters
    ----------
    detector : pyod.models
        Detector instance for which the check is performed.

    Raises
    ------
    AttributeError
        If either method is missing.
    """
    required_methods = ('fit', 'decision_function')
    if all(hasattr(detector, method_name) for method_name in required_methods):
        return
    raise AttributeError("%s is not a detector instance." % (detector))
127 |
128 |
def standardizer(X, X_t=None, keep_scalar=False):
    """Conduct Z-normalization on data to turn input samples become zero-mean
    and unit variance.

    The scaler is always fitted on ``X`` only; ``X_t`` (if given) is transformed
    with the statistics of ``X``.

    Parameters
    ----------
    X : numpy array of shape (n_samples, n_features)
        The training samples

    X_t : numpy array of shape (n_samples_new, n_features), optional (default=None)
        The data to be converted

    keep_scalar : bool, optional (default=False)
        The flag to indicate whether to return the scalar

    Returns
    -------
    X_norm : numpy array of shape (n_samples, n_features)
        X after the Z-score normalization

    X_t_norm : numpy array of shape (n_samples_new, n_features)
        X_t after the Z-score normalization (only returned when X_t is given)

    scalar : sklearn scalar object
        The scalar used in conversion (only returned when keep_scalar is True)

    Raises
    ------
    ValueError
        If X and X_t have a different number of features.
    """
    X = check_array(X)
    scaler = StandardScaler().fit(X)

    if X_t is None:
        if keep_scalar:
            return scaler.transform(X), scaler
        return scaler.transform(X)

    X_t = check_array(X_t)
    if X.shape[1] != X_t.shape[1]:
        raise ValueError(
            "The number of input data feature should be consistent. "
            "X has {0} features and X_t has {1} features.".format(
                X.shape[1], X_t.shape[1]))
    if keep_scalar:
        return scaler.transform(X), scaler.transform(X_t), scaler
    return scaler.transform(X), scaler.transform(X_t)
175 |
176 |
def score_to_label(pred_scores, outliers_fraction=0.1):
    """Binarize raw outlier scores by thresholding at a percentile.

    A sample is labelled 1 (outlier) when its score is strictly greater than
    the ``100 * (1 - outliers_fraction)`` percentile of ``pred_scores``, and
    0 otherwise.

    Parameters
    ----------
    pred_scores : list or numpy array of shape (n_samples,)
        Raw outlier scores. Outliers are assumed to have larger values.

    outliers_fraction : float in (0,1)
        Expected fraction of outliers.

    Returns
    -------
    outlier_labels : numpy array of shape (n_samples,)
        Integer labels: 1 for outliers, 0 for inliers.
    """
    pred_scores = column_or_1d(pred_scores)
    check_parameter(outliers_fraction, 0, 1)

    cutoff = percentile(pred_scores, 100 * (1 - outliers_fraction))
    return (pred_scores > cutoff).astype('int')
203 |
204 |
def precision_n_scores(y, y_pred, n=None):
    """Compute precision @ rank n for raw outlier scores.

    The top-scoring samples are labelled as outliers with :func:`get_label_n`,
    and those labels are evaluated with ``precision_score`` against ``y``.

    Parameters
    ----------
    y : list or numpy array of shape (n_samples,)
        The ground truth. Binary (0: inliers, 1: outliers).

    y_pred : list or numpy array of shape (n_samples,)
        The raw outlier scores as returned by a fitted model.

    n : int, optional (default=None)
        The number of outliers. If not defined, it is inferred from the
        ground truth.

    Returns
    -------
    precision_at_rank_n : float
        Precision at rank n score.
    """
    binary_pred = get_label_n(y, y_pred, n)
    return precision_score(column_or_1d(y), column_or_1d(binary_pred))
234 |
235 |
def get_label_n(y, y_pred, n=None):
    """Turn raw outlier scores into binary labels by flagging the top scores.

    The outlier fraction is ``n / len(y)`` when ``n`` is given, otherwise the
    fraction of non-zero entries in ``y``. Scores strictly above the matching
    percentile are labelled 1.

    Parameters
    ----------
    y : list or numpy array of shape (n_samples,)
        The ground truth. Binary (0: inliers, 1: outliers).

    y_pred : list or numpy array of shape (n_samples,)
        The raw outlier scores as returned by a fitted model.

    n : int, optional (default=None)
        The number of outliers. If not defined, it is inferred from the
        ground truth.

    Returns
    -------
    labels : numpy array of shape (n_samples,)
        Binary labels, 0: normal points and 1: outliers.

    Examples
    --------
    >>> from pyod.utils.utility import get_label_n
    >>> y = [0, 1, 1, 0, 0]
    >>> y_pred = [0.1, 0.5, 0.3, 0.2, 0.7]
    >>> get_label_n(y, y_pred)
    array([0, 1, 0, 0, 1])
    """
    y = column_or_1d(y)
    y_pred = column_or_1d(y_pred)
    check_consistent_length(y, y_pred)

    n_total = len(y)
    if n is None:
        fraction = np.count_nonzero(y) / n_total
    else:
        fraction = n / n_total

    cutoff = percentile(y_pred, 100 * (1 - fraction))
    return (y_pred > cutoff).astype('int')
283 |
def gmean_scores(y, y_pred, threshold=0.5):
    """Compute the G-mean, sqrt(TPR * TNR), after top-n labelling of scores.

    ``y_pred`` is binarized with :func:`get_label_n`, which flags the
    ``count_nonzero(y)`` highest scores as outliers. The fixed
    ``threshold`` is NOT used (see ``get_gmean`` for threshold-based
    labelling).

    Parameters
    ----------
    y : list or numpy array of shape (n_samples,)
        The ground truth. Binary (0: inliers, 1: outliers).

    y_pred : list or numpy array of shape (n_samples,)
        The raw outlier scores as returned by a fitted model.

    threshold : float, optional (default=0.5)
        Unused; kept for backward compatibility.

    Returns
    -------
    Gmean : float
        Geometric mean of the true-positive and true-negative rates; nan when
        ``y`` contains no positives or no negatives.
    """
    # Convert to a 1-D array so the elementwise comparisons below also work
    # for plain lists (``[...] == 1`` would be a single bool).
    y = column_or_1d(y)
    y_pred = get_label_n(y, y_pred)

    ones_all = (y == 1).sum()
    ones_correct = ((y == 1) & (y_pred == 1)).sum()
    zeros_all = (y == 0).sum()
    zeros_correct = ((y == 0) & (y_pred == 0)).sum()

    Gmean = np.sqrt(1.0 * ones_correct / ones_all)
    Gmean *= np.sqrt(1.0 * zeros_correct / zeros_all)
    return Gmean
310 |
def get_gmean(y, y_pred, threshold=0.5):
    """Compute the G-mean, sqrt(TPR * TNR), for scores binarized at a threshold.

    Parameters
    ----------
    y : array-like with n_samples elements
        The ground truth. Binary (0: inliers, 1: outliers). Flattened before
        use, so column vectors such as shape (n_samples, 1) are accepted.

    y_pred : array-like with n_samples elements
        Outlier scores/probabilities. A sample is predicted as an outlier
        when its score is ``>= threshold``. Flattened like ``y``.

    threshold : float, optional (default=0.5)
        Decision threshold applied to ``y_pred``.

    Returns
    -------
    Gmean : float
        Geometric mean of the true-positive and true-negative rates; nan when
        ``y`` contains no positives or no negatives.
    """
    # np.asarray accepts plain lists as well as numpy arrays; previously the
    # inputs had to provide numpy-style ``.reshape`` / ``.astype`` methods.
    y = np.asarray(y).reshape(-1)
    y_pred = (np.asarray(y_pred).reshape(-1) >= threshold).astype('int')

    ones_all = (y == 1).sum()
    ones_correct = ((y == 1) & (y_pred == 1)).sum()
    zeros_all = (y == 0).sum()
    zeros_correct = ((y == 0) & (y_pred == 0)).sum()

    Gmean = np.sqrt((1.0 * ones_correct / ones_all) * (1.0 * zeros_correct / zeros_all))
    return Gmean
339 |
def AUC_and_Gmean(y_test, y_scores):
    """Return ``(auc, gmean)`` of *y_scores* against the labels *y_test*.

    ``auc`` is ``roc_auc_score(y_test, y_scores)`` and ``gmean`` is
    ``get_gmean(y_test, y_scores, 0.5)`` (scores thresholded at 0.5); both
    are rounded to 4 decimal places.
    """
    return (
        round(roc_auc_score(y_test, y_scores), ndigits=4),
        round(get_gmean(y_test, y_scores, 0.5), ndigits=4),
    )
348 |
--------------------------------------------------------------------------------
/EAL-GAN/models/layers.py:
--------------------------------------------------------------------------------
1 | ''' Layers
2 | This file contains various layers for the BigGAN models.
3 | '''
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import init
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | from torch.nn import Parameter as P
11 |
12 | from .sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
13 |
14 |
def proj(x, y):
    """Project the row vector(s) ``x`` onto the row vector ``y``.

    Computes ``(y x^T) * y / (y y^T)`` with ``torch.mm``, so both arguments
    must be 2-D (typically shape ``(1, n)``).
    """
    along_y = torch.mm(y, x.t())
    return along_y * y / torch.mm(y, y.t())
18 |
19 |
def gram_schmidt(x, ys):
    """Orthogonalize ``x`` against every vector in ``ys``.

    Each vector's projection (via :func:`proj`) is removed in turn from the
    running residual, in the order the vectors appear in ``ys``.
    """
    residual = x
    for basis_vec in ys:
        residual = residual - proj(residual, basis_vec)
    return residual
25 |
26 |
def power_iteration(W, u_, update=True, eps=1e-12):
    """Run one power-method step per singular vector to estimate top singular values of ``W``.

    For every ``u`` in ``u_`` (presumably ``(1, rows)`` row vectors, as
    registered by ``SN``), compute ``v = normalize(GS(u W))`` and
    ``u = normalize(GS(v W^T))``, where GS orthogonalizes against the vectors
    already produced in this call. The vector updates run under
    ``torch.no_grad()``; the singular-value estimate ``v W^T u^T`` is computed
    outside it, so it stays differentiable with respect to ``W``.

    Args:
        W: 2-D weight matrix.
        u_: list of left singular-vector estimates; overwritten in place
            when ``update`` is True.
        update: whether to write the refined ``u`` back into ``u_``.
        eps: epsilon passed to ``F.normalize``.

    Returns:
        ``(svs, us, vs)``: singular-value estimates and the refined vectors.
    """
    us, vs, svs = [], [], []
    for idx, u_init in enumerate(u_):
        with torch.no_grad():
            v = F.normalize(gram_schmidt(torch.matmul(u_init, W), vs), eps=eps)
            vs.append(v)
            u = F.normalize(gram_schmidt(torch.matmul(v, W.t()), us), eps=eps)
            us.append(u)
            if update:
                # Write back in place so the caller's buffer persists across steps.
                u_[idx][:] = u
        svs.append(torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t())))
    return svs, us, vs
51 |
52 |
class identity(nn.Module):
    """Pass-through module whose ``forward`` returns its input unchanged."""

    def forward(self, input):
        return input
57 |
58 |
class SN(object):
    """Mixin adding spectral normalization to a module that owns ``self.weight``.

    The host must be an ``nn.Module`` (``register_buffer`` and ``training``
    are used) and must call ``SN.__init__`` after its own module initializer.

    Args:
        num_svs: number of singular values/vectors tracked (must be >= 1).
        num_itrs: power-iteration passes per call to ``W_`` (must be >= 1).
        num_outputs: length of each ``u`` vector; it must match the row count
            of the flattened weight (``weight.size(0)``, or the flattened
            fan-in when ``transpose`` is True).
        transpose: normalize the transposed weight matrix instead.
        eps: epsilon passed to ``F.normalize`` to avoid dividing by zero.

    Raises:
        ValueError: if ``num_svs`` or ``num_itrs`` is smaller than 1.
    """

    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        # W_ reads svs[0] produced by at least one power-iteration pass; with
        # either count at 0 it would fail later with IndexError or
        # UnboundLocalError, so reject those values up front.
        if num_svs < 1:
            raise ValueError('num_svs must be >= 1, got {}'.format(num_svs))
        if num_itrs < 1:
            raise ValueError('num_itrs must be >= 1, got {}'.format(num_itrs))
        # Number of power iterations per step
        self.num_itrs = num_itrs
        # Number of singular values
        self.num_svs = num_svs
        # Transposed?
        self.transpose = transpose
        # Epsilon value for avoiding divide-by-0
        self.eps = eps
        # Register a singular vector and a logged singular value for each sv
        for i in range(self.num_svs):
            self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
            self.register_buffer('sv%d' % i, torch.ones(1))

    @property
    def u(self):
        """Singular vectors (u side), one ``(1, num_outputs)`` buffer per sv."""
        return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]

    @property
    def sv(self):
        """Logged singular values; these buffers are not used in training."""
        return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]

    def W_(self):
        """Return ``weight`` divided by its estimated largest singular value."""
        W_mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            W_mat = W_mat.t()
        # Apply num_itrs power iterations (u buffers only updated in training)
        for _ in range(self.num_itrs):
            svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
        # Update the logged svs
        if self.training:
            # Make sure to do this in a no_grad() context or you'll get memory leaks!
            with torch.no_grad():
                for i, sv in enumerate(svs):
                    self.sv[i][:] = sv
        return self.weight / svs[0]
100 |
101 |
class SNConv2d(nn.Conv2d, SN):
    """``nn.Conv2d`` whose weight is spectrally normalized on every forward pass.

    Accepts the usual ``nn.Conv2d`` arguments plus the ``SN`` settings
    ``num_svs``, ``num_itrs`` and ``eps``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size,
                           stride=stride, padding=padding, dilation=dilation,
                           groups=groups, bias=bias)
        # u vectors are sized by out_channels, the row count of the flattened weight.
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.conv2d(x, normalized_weight, bias=self.bias, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        groups=self.groups)
113 |
114 |
class SNLinear(nn.Linear, SN):
    """``nn.Linear`` whose weight is spectrally normalized on every forward pass."""

    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features, bias=bias)
        # u vectors are sized by out_features, the row count of the weight.
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.linear(x, normalized_weight, bias=self.bias)
123 |
124 |
class SNEmbedding(nn.Embedding, SN):
    """``nn.Embedding`` whose table is spectrally normalized on every lookup.

    ``num_embeddings`` (not ``embedding_dim``) is used as the SN vector size
    for convenience. ``forward`` calls ``F.embedding`` with only the input and
    the normalized weight, so ``padding_idx``, ``max_norm`` and the other
    embedding options are stored but not applied at lookup time.
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                 sparse=False, _weight=None,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim,
                              padding_idx=padding_idx, max_norm=max_norm,
                              norm_type=norm_type,
                              scale_grad_by_freq=scale_grad_by_freq,
                              sparse=sparse, _weight=_weight)
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        normalized_table = self.W_()
        return F.embedding(x, normalized_table)
139 |
140 |
class Attention(nn.Module):
    """Non-local (self-attention) block as used in SA-GAN.

    The implementation as described in the paper is largely incorrect; this
    follows the released code. Query/key projections reduce channels 8x, the
    value projection reduces them 2x, and key/value maps are 2x2 max-pooled.
    The result is scaled by the learnable ``gamma`` (initialised to 0) and
    added to the input.

    Args:
        ch: channel count of the input (and output).
        which_conv: conv factory used for the 1x1 projections.
        name: unused.

    The reshapes assume the input is ``(N, ch, H, W)`` with even ``H`` and
    ``W`` (the pooled maps are viewed as ``H * W // 4`` positions).
    """

    def __init__(self, ch, which_conv=SNConv2d, name='attention'):
        super(Attention, self).__init__()
        self.ch = ch
        self.which_conv = which_conv
        qk_ch, val_ch = self.ch // 8, self.ch // 2
        # Construction order kept as theta, phi, g, o (affects init RNG draws).
        self.theta = self.which_conv(self.ch, qk_ch, kernel_size=1, padding=0, bias=False)
        self.phi = self.which_conv(self.ch, qk_ch, kernel_size=1, padding=0, bias=False)
        self.g = self.which_conv(self.ch, val_ch, kernel_size=1, padding=0, bias=False)
        self.o = self.which_conv(val_ch, self.ch, kernel_size=1, padding=0, bias=False)
        # Learnable gain parameter
        self.gamma = P(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        height, width = x.shape[2], x.shape[3]
        n_pix = height * width
        theta = self.theta(x).view(-1, self.ch // 8, n_pix)
        phi = F.max_pool2d(self.phi(x), [2, 2]).view(-1, self.ch // 8, n_pix // 4)
        g = F.max_pool2d(self.g(x), [2, 2]).view(-1, self.ch // 2, n_pix // 4)
        # Attention map over the pooled positions for every query position.
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        attended = torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, height, width)
        return self.gamma * self.o(attended) + x
170 |
171 |
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    """Normalize ``x`` with given statistics, folding gain/bias into one scale and shift.

    Mathematically equal to
    ``((x - mean) / sqrt(var + eps)) * gain + bias``, computed as
    ``x * scale - shift`` with ``scale = rsqrt(var + eps) * gain`` and
    ``shift = mean * scale - bias``. ``gain``/``bias`` default to no-ops.
    """
    scale = torch.rsqrt(var + eps)
    scale = scale if gain is None else scale * gain
    shift = mean * scale
    shift = shift if bias is None else shift - bias
    return x * scale - shift
187 |
188 |
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    """Batch-normalize ``x`` with batch statistics computed by hand.

    Mean and variance are reduced over dims ``[0, 2, 3]`` (so ``x`` is
    expected to be 4-D with channels on dim 1), as ``E[x]`` and
    ``E[x^2] - E[x]^2`` in float32, then cast back to ``x``'s tensor type
    and applied through :func:`fused_bn`.

    Args:
        x: input tensor.
        gain, bias: optional affine parameters forwarded to ``fused_bn``.
        return_mean_var: also return the (squeezed) mean and variance.
        eps: epsilon added to the variance.

    Returns:
        The normalized tensor, or ``(out, mean, var)`` when
        ``return_mean_var`` is True.
    """
    # Compute statistics in float32 even for half-precision inputs.
    float_x = x.float()
    m = torch.mean(float_x, [0, 2, 3], keepdim=True)
    m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
    # E[x^2] - E[x]^2 suffers from cancellation and can dip slightly below
    # zero; clamp so rsqrt(var + eps) in fused_bn cannot produce NaN.
    var = torch.clamp(m2 - m ** 2, min=0)
    # Cast back to the input type (e.g. float16) if necessary
    var = var.type(x.type())
    m = m.type(x.type())
    out = fused_bn(x, m, var, gain, bias, eps)
    if return_mean_var:
        return out, m.squeeze(), var.squeeze()
    return out
209 |
210 |
class myBN(nn.Module):
    """Batch norm taking external gain/bias, with optional "standing" statistics.

    Training mode normalizes with batch statistics from ``manual_bn``. The
    stored statistics are then updated in one of two ways:

    * by default, as exponential running averages using ``momentum``;
    * when ``accumulate_standing`` is True, by summing the per-batch
      statistics and incrementing ``accumulation_counter``.

    Eval mode normalizes with the stored statistics via ``fused_bn``. When
    ``accumulate_standing`` is True, the stored sums are divided by the
    counter first.
    """

    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        # momentum for updating running stats
        self.momentum = momentum
        # epsilon to avoid dividing by 0
        self.eps = eps
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.register_buffer('accumulation_counter', torch.zeros(1))
        # Accumulate running means and vars
        self.accumulate_standing = False

    def reset_stats(self):
        """Zero the stored statistics and the counter (before accumulating standing stats)."""
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
            # Detach before writing into the buffers. An in-place copy of a
            # grad-tracking tensor would attach the buffer to this step's
            # autograd graph, and each later running-average update would
            # chain onto it, keeping old graphs alive (memory leak).
            mean, var = mean.detach(), var.detach()
            if self.accumulate_standing:
                self.stored_mean[:] = self.stored_mean + mean
                self.stored_var[:] = self.stored_var + var
                self.accumulation_counter += 1.0
            else:
                self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
                self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
            return out
        # Not training: normalize with the stored statistics
        mean = self.stored_mean.view(1, -1, 1, 1)
        var = self.stored_var.view(1, -1, 1, 1)
        if self.accumulate_standing:
            mean = mean / self.accumulation_counter
            var = var / self.accumulation_counter
        return fused_bn(x, mean, var, gain, bias, self.eps)
256 |
257 |
def groupnorm(x, norm_style):
    """Apply ``F.group_norm`` with a group count derived from ``norm_style``.

    * If ``norm_style`` contains ``'ch'``, its trailing ``_``-separated
      integer is the number of channels per group:
      ``groups = max(C // ch, 1)``.
    * Else, if it contains ``'grp'``, the trailing integer is the group count.
    * Otherwise, 16 groups are used.
    """
    def _trailing_int():
        return int(norm_style.split('_')[-1])

    if 'ch' in norm_style:
        num_groups = max(int(x.shape[1]) // _trailing_int(), 1)
    elif 'grp' in norm_style:
        num_groups = _trailing_int()
    else:
        num_groups = 16
    return F.group_norm(x, num_groups)
271 |
272 |
class ccbn(nn.Module):
    """Class-conditional batch norm.

    ``x`` is normalized, then scaled and shifted by per-sample parameters
    predicted from the conditioning input ``y``:
    ``gain = 1 + which_linear(y)`` and ``bias = which_linear(y)``, both
    reshaped to ``(N, C, 1, 1)``.

    Args:
        output_size: channel count of ``x`` (size of the predicted gain/bias).
        input_size: feature size of ``y``.
        which_linear: factory ``(in_features, out_features) -> module`` used
            for the gain and bias projections.
        eps: epsilon added to the variance.
        momentum: running-stat momentum passed to SyncBN2d / myBN. The
            'bn'/'in' paths use a fixed momentum of 0.1.
        cross_replica: normalize with ``SyncBN2d(affine=False)``.
        mybn: normalize with ``myBN`` (only if ``cross_replica`` is False).
        norm_style: 'bn', 'in', 'gn' (16 groups) or 'nonorm'. Only used when
            neither ``cross_replica`` nor ``mybn`` is set.

    Raises:
        ValueError: from ``forward`` if ``norm_style`` is not one of the
            supported values.
    """

    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style='bn',):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Norm style?
        self.norm_style = norm_style

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ['bn', 'in']:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # SyncBN2d / myBN apply gain and bias themselves
        if self.mybn or self.cross_replica:
            return self.bn(x, gain=gain, bias=bias)

        if self.norm_style == 'bn':
            out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                               self.training, 0.1, self.eps)
        elif self.norm_style == 'in':
            out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                  self.training, 0.1, self.eps)
        elif self.norm_style == 'gn':
            # was ``self.normstyle`` (AttributeError on this path)
            out = groupnorm(x, self.norm_style)
        elif self.norm_style == 'nonorm':
            out = x
        else:
            # Previously fell through to an UnboundLocalError on ``out``.
            raise ValueError('Unsupported norm_style: {!r}'.format(self.norm_style))
        return out * gain + bias

    def extra_repr(self):
        s = 'out: {output_size}, in: {input_size},'
        s += ' cross_replica={cross_replica}'
        return s.format(**self.__dict__)
330 |
331 |
class bn(nn.Module):
    """Plain (non-class-conditional) batch norm with learnable per-channel gain/bias.

    Normalization is done by ``SyncBN2d`` when ``cross_replica`` is set, by
    ``myBN`` when ``mybn`` is set, and otherwise by ``F.batch_norm`` with
    locally stored running statistics.
    """

    def __init__(self, output_size, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False):
        super(bn, self).__init__()
        self.output_size = output_size
        # Learnable affine parameters (registered in this order: gain, bias).
        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)
        self.eps = eps
        self.momentum = momentum
        self.cross_replica = cross_replica
        self.mybn = mybn

        if cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        else:
            # Running statistics for the F.batch_norm path.
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y=None):
        if not (self.cross_replica or self.mybn):
            return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                                self.bias, self.training, self.momentum, self.eps)
        # External normalizers take broadcastable (1, C, 1, 1) gain/bias.
        return self.bn(x, gain=self.gain.view(1, -1, 1, 1), bias=self.bias.view(1, -1, 1, 1))
367 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 | """sum over the first and last dimention"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 | """add new dementions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 | # _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size'])
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 | self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 | self._is_parallel = False
45 | self._parallel_id = None
46 | self._slave_pipe = None
47 |
48 | def forward(self, input, gain=None, bias=None):
49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50 | if not (self._is_parallel and self.training):
51 | out = F.batch_norm(
52 | input, self.running_mean, self.running_var, self.weight, self.bias,
53 | self.training, self.momentum, self.eps)
54 | if gain is not None:
55 | out = out + gain
56 | if bias is not None:
57 | out = out + bias
58 | return out
59 |
60 | # Resize the input to (B, C, -1).
61 | input_shape = input.size()
62 | # print(input_shape)
63 | input = input.view(input.size(0), input.size(1), -1)
64 |
65 | # Compute the sum and square-sum.
66 | sum_size = input.size(0) * input.size(2)
67 | input_sum = _sum_ft(input)
68 | input_ssum = _sum_ft(input ** 2)
69 | # Reduce-and-broadcast the statistics.
70 | # print('it begins')
71 | if self._parallel_id == 0:
72 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
73 | else:
74 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
75 | # if self._parallel_id == 0:
76 | # # print('here')
77 | # sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
78 | # else:
79 | # # print('there')
80 | # sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
81 |
82 | # print('how2')
83 | # num = sum_size
84 | # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu())))
85 | # Fix the graph
86 | # sum = (sum.detach() - input_sum.detach()) + input_sum
87 | # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum
88 |
89 | # mean = sum / num
90 | # var = ssum / num - mean ** 2
91 | # # var = (ssum - mean * sum) / num
92 | # inv_std = torch.rsqrt(var + self.eps)
93 |
94 | # Compute the output.
95 | if gain is not None:
96 | # print('gaining')
97 | # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1)
98 | # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1)
99 | # output = input * scale - shift
100 | output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1)
101 | elif self.affine:
102 | # MJY:: Fuse the multiplication for speed.
103 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
104 | else:
105 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
106 |
107 | # Reshape it.
108 | return output.view(input_shape)
109 |
110 | def __data_parallel_replicate__(self, ctx, copy_id):
111 | self._is_parallel = True
112 | self._parallel_id = copy_id
113 |
114 | # parallel_id == 0 means master device.
115 | if self._parallel_id == 0:
116 | ctx.sync_master = self._sync_master
117 | else:
118 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
119 |
120 | def _data_parallel_master(self, intermediates):
121 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
122 |
123 | # Always using same "device order" makes the ReduceAdd operation faster.
124 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
125 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
126 |
127 | to_reduce = [i[1][:2] for i in intermediates]
128 | to_reduce = [j for i in to_reduce for j in i] # flatten
129 | target_gpus = [i[1].sum.get_device() for i in intermediates]
130 |
131 | sum_size = sum([i[1].sum_size for i in intermediates])
132 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
133 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
134 |
135 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
136 | # print('a')
137 | # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size)
138 | # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device))
139 | # print('b')
140 | outputs = []
141 | for i, rec in enumerate(intermediates):
142 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
143 | # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3])))
144 |
145 | return outputs
146 |
147 | def _compute_mean_std(self, sum_, ssum, size):
148 | """Compute the mean and standard-deviation with sum and square-sum. This method
149 | also maintains the moving average on the master device."""
150 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
151 | mean = sum_ / size
152 | sumvar = ssum - sum_ * mean
153 | unbias_var = sumvar / (size - 1)
154 | bias_var = sumvar / size
155 |
156 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
157 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
158 | return mean, torch.rsqrt(bias_var + self.eps)
159 | # return mean, bias_var.clamp(self.eps) ** -0.5
160 |
161 |
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ValueError unless ``input`` is 2-D or 3-D."""
        # The base-class hook is intentionally not called: in current PyTorch
        # ``_BatchNorm._check_input_dim`` raises NotImplementedError, which made
        # this method fail even for valid input.
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
223 |
224 |
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ValueError unless ``input`` is 4-D."""
        # The base-class hook is intentionally not called: in current PyTorch
        # ``_BatchNorm._check_input_dim`` raises NotImplementedError, which made
        # this method fail even for valid input.
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
286 |
287 |
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device. This speeds up the computation and is
    easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed across multiple devices.

    Note that, in the single-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per channel over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (the number of channels).

    During training, this layer keeps running estimates of its computed mean
    and variance, updated with a default momentum of 0.1.

    During evaluation, these running estimates are used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless *input* is 5-dimensional.

        Deliberately not chained to the base class: the ``_check_input_dim``
        inherited from PyTorch's ``_BatchNorm`` is abstract (it raises
        ``NotImplementedError`` in current PyTorch), so chaining to it would
        reject every valid input.
        """
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
--------------------------------------------------------------------------------
/EAL-GAN/models/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
def _sum_ft(tensor):
    """Sum *tensor* over its first and its last dimension.

    For a ``(B, C, L)`` tensor this yields one total per channel, shape ``(C,)``.
    """
    summed_over_first = torch.sum(tensor, dim=0)
    return torch.sum(summed_over_first, dim=-1)
27 |
28 |
def _unsqueeze_ft(tensor):
    """Add a new size-1 dimension at the front and another at the tail.

    A per-channel vector of shape ``(C,)`` becomes ``(1, C, 1)``, ready to
    broadcast against a ``(B, C, L)`` tensor.
    """
    return tensor[None, ..., None]
32 |
33 |
# Sent by every replica to the master: per-channel sum, per-channel sum of
# squares, and the number of elements each channel was reduced over.
_ChildMessage = collections.namedtuple('_ChildMessage', 'sum ssum sum_size')
# Broadcast back by the master: the global per-channel mean (stored in the field
# named ``sum``) and the inverse standard deviation.
_MasterMessage = collections.namedtuple('_MasterMessage', 'sum inv_std')
37 |
class _SynchronizedBatchNorm(_BatchNorm):
    """Batch normalization whose training statistics can be shared across replicas.

    Once a replication callback has called ``__data_parallel_replicate__`` and the
    module is in training mode, every replica sends the per-channel sum and
    square-sum of its slice of the batch to the master replica (``copy_id == 0``).
    The master reduces them, computes the global mean and inverse standard
    deviation, updates the running statistics and broadcasts the result back.
    In every other situation (not replicated, or evaluation mode) the layer falls
    back to ``F.batch_norm``.

    ``forward`` also accepts optional ``gain`` / ``bias`` tensors (conditional
    batch norm). When ``gain`` is given, the normalized activations are multiplied
    by ``gain`` instead of by the module's own affine parameters. ``bias``, when
    given, is added to the result. Both code paths follow this rule, so
    single-device and synchronized runs compute the same function.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input, gain=None, bias=None):
        """Normalize *input* and apply the affine transform or the external gain/bias.

        Args:
            input: tensor whose dimension 1 is the channel dimension.
            gain: optional multiplicative scale for the normalized input. When
                given, it replaces ``self.weight`` / ``self.bias``. It must
                broadcast against *input*; in the synchronized path its last
                dimension is squeezed first.
            bias: optional additive shift applied after scaling. It follows the
                same broadcasting rules as *gain*.

        Returns:
            A tensor with the same shape as *input*.
        """
        # Not replicated, or evaluating: use PyTorch's single-device implementation.
        if not (self._is_parallel and self.training):
            if gain is None:
                out = F.batch_norm(
                    input, self.running_mean, self.running_var, self.weight, self.bias,
                    self.training, self.momentum, self.eps)
            else:
                # The external gain *scales* the normalized input and replaces the
                # module's own affine parameters, exactly as in the synchronized path.
                out = F.batch_norm(
                    input, self.running_mean, self.running_var, None, None,
                    self.training, self.momentum, self.eps) * gain
            if bias is not None:
                out = out + bias
            return out

        # Flatten all trailing dimensions: (B, C, ...) -> (B, C, L).
        input_shape = input.size()
        input = input.reshape(input.size(0), input.size(1), -1)

        # Per-channel sum and square-sum over this replica's slice of the batch.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics. The master gathers every replica's
        # message, and every replica receives the global mean / inverse std.
        message = _ChildMessage(input_sum, input_ssum, sum_size)
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(message)
        else:
            mean, inv_std = self._slave_pipe.run_slave(message)

        centered = input - _unsqueeze_ft(mean)
        if gain is not None:
            output = centered * (_unsqueeze_ft(inv_std) * gain.squeeze(-1))
        elif self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = centered * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = centered * _unsqueeze_ft(inv_std)
        if bias is not None:
            output = output + bias.squeeze(-1)

        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Mark this copy as a replica and connect it to the master's sync channel."""
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast them.

        Args:
            intermediates: list of ``(identifier, _ChildMessage)`` pairs, one per
                replica.

        Returns:
            List of ``(identifier, _MasterMessage)`` pairs. Each pair carries that
            replica's device copy of the global ``(mean, inv_std)``.
        """
        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum(i[1].sum_size for i in intermediates)
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        # One (mean, inv_std) pair per target device, flattened.
        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and inverse standard deviation from sum and square-sum.

        Also updates the running statistics: an exponential moving average with
        ``self.momentum``, using the unbiased variance. The returned inverse std
        uses the biased variance, as in training-mode batch norm.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        return mean, torch.rsqrt(bias_var + self.eps)
160 |
161 |
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device. This speeds up the computation and is
    easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed across multiple devices.

    Note that, in the single-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per channel over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (the number of channels).

    During training, this layer keeps running estimates of its computed mean
    and variance, updated with a default momentum of 0.1.

    During evaluation, these running estimates are used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless *input* is 2- or 3-dimensional.

        Deliberately not chained to the base class: the ``_check_input_dim``
        inherited from PyTorch's ``_BatchNorm`` is abstract (it raises
        ``NotImplementedError`` in current PyTorch), so chaining to it would
        reject every valid input.
        """
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
223 |
224 |
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device. This speeds up the computation and is
    easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed across multiple devices.

    Note that, in the single-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per channel over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (the number of channels).

    During training, this layer keeps running estimates of its computed mean
    and variance, updated with a default momentum of 0.1.

    During evaluation, these running estimates are used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless *input* is 4-dimensional.

        Deliberately not chained to the base class: the ``_check_input_dim``
        inherited from PyTorch's ``_BatchNorm`` is abstract (it raises
        ``NotImplementedError`` in current PyTorch), so chaining to it would
        reject every valid input.
        """
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
286 |
287 |
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device. This speeds up the computation and is
    easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed across multiple devices.

    Note that, in the single-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per channel over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (the number of channels).

    During training, this layer keeps running estimates of its computed mean
    and variance, updated with a default momentum of 0.1.

    During evaluation, these running estimates are used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless *input* is 5-dimensional.

        Deliberately not chained to the base class: the ``_check_input_dim``
        inherited from PyTorch's ``_BatchNorm`` is abstract (it raises
        ``NotImplementedError`` in current PyTorch), so chaining to it would
        reject every valid input.
        """
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
--------------------------------------------------------------------------------
/EAL-GAN/models/EAL_GAN.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import functools
4 | import random
5 | import os
6 | from collections import defaultdict
7 | import matplotlib.pyplot as plt
8 | import matplotlib.font_manager
9 |
10 | from numpy import percentile
11 | import pandas as pd
12 |
13 | import torch
14 | import torch.nn as nn
15 | from torch.nn import init
16 | import torch.optim as optim
17 | import torch.nn.functional as F
18 | from torch.nn import Parameter as P
19 |
20 | from .layers import ccbn, identity, SNLinear, SNEmbedding
21 |
22 | from .utils import prepare_z_y, active_sampling_V1, sample_selector_V1
23 | from .losses import loss_dis_fake, loss_dis_real
24 | from .pyod_utils import AUC_and_Gmean
25 |
26 |
class Generator(nn.Module):
    """Label-conditioned MLP generator.

    Each label is embedded into a ``dim_z // 2``-dimensional vector and
    concatenated with the noise ``z``. The result is passed through
    ``Linear -> ReLU -> (Linear -> ReLU) * hidden_number -> Linear``.

    Args:
        dim_z: size of the noise vector.
        hidden_dim: width of the hidden layers.
        output_dim: size of each generated sample.
        n_classes: number of label classes in the embedding table.
        hidden_number: number of extra ``Linear -> ReLU`` stages.
        init: weight-init scheme: ``'ortho'``, ``'N02'``, ``'glorot'`` or
            ``'xavier'``. Anything else leaves the default init and prints a warning.
        SN_used: if True the input and output layers are built with ``SNLinear``
            (``num_svs=1, num_itrs=1``); otherwise with ``nn.Linear``.
    """

    def __init__(self, dim_z=64, hidden_dim=128, output_dim=128, n_classes=2, hidden_number=1,
                 init='ortho', SN_used=True):
        super(Generator, self).__init__()
        self.hidden_dim = hidden_dim

        self.n_classes = n_classes
        self.init = init
        self.shared_dim = dim_z // 2

        if SN_used:
            self.which_linear = functools.partial(SNLinear, num_svs=1, num_itrs=1)
        else:
            self.which_linear = nn.Linear
        # The label embedding is a plain nn.Embedding in both cases.
        self.which_embedding = nn.Embedding

        self.shared = self.which_embedding(n_classes, self.shared_dim)
        self.input_fc = self.which_linear(dim_z + self.shared_dim, self.hidden_dim)

        self.output_fc = self.which_linear(self.hidden_dim, output_dim)

        self.model = nn.Sequential(self.input_fc,
                                   nn.ReLU())

        for index in range(hidden_number):
            # NOTE(review): hidden layers are plain nn.Linear even when SN_used is
            # True (the Discriminator builds them with self.which_linear); changing
            # this would alter the model and its checkpoints, so confirm intent first.
            middle_fc = nn.Linear(self.hidden_dim, self.hidden_dim)
            self.model.add_module('hidden-layers-{0}'.format(index), middle_fc)
            self.model.add_module('ReLu-{0}'.format(index), nn.ReLU())

        self.model.add_module('output_layer', self.output_fc)

        self.init_weights()

    def init_weights(self):
        """Initialize every Linear/Embedding weight with the scheme in ``self.init``.

        Stores the number of parameters in those layers in ``self.param_count``
        and prints it.
        """
        self.param_count = 0
        # modules() visits each submodule once, even input_fc/output_fc, which are
        # registered both as attributes and inside self.model.
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum(p.data.nelement() for p in module.parameters())
        # Double-quoted so the apostrophe survives ('G''s' concatenates to "Gs").
        print("Param count for G's initialized parameters: %d" % self.param_count)

    def forward(self, z, y):
        """Generate samples from noise *z* and integer labels *y*.

        Returns the output of ``self.model`` (width ``output_dim``); no output
        activation is applied.
        """
        y = self.shared(y)
        h = torch.cat([z, y], 1)
        return self.model(h)
88 |
89 |
class Discriminator(nn.Module):
    """MLP discriminator with a projection real/fake head and a category head.

    Shared trunk: ``Linear -> ReLU -> (Linear -> ReLU) * hidden_number``. On top of
    it sit ``output_fc`` (real/fake logit), ``embed`` (the per-class vectors for
    the projection term) and ``output_category`` (``Linear -> Sigmoid``).

    Args:
        input_dim: number of input features.
        hidden_dim: width of every hidden layer.
        output_dim: output width of both heads.
        n_classes: number of label classes in the projection embedding.
        hidden_number: number of extra ``Linear -> ReLU`` stages in the trunk.
        init: weight-init scheme: ``'ortho'``, ``'N02'``, ``'glorot'`` or
            ``'xavier'``. Anything else leaves the default init and prints a warning.
        SN_used: if True every linear layer is built with ``SNLinear``
            (``num_svs=1, num_itrs=1``); otherwise with ``nn.Linear``.
    """

    def __init__(self, input_dim=64, hidden_dim=64, output_dim=1, n_classes=2,
                 hidden_number=1, init='ortho', SN_used=True):
        super(Discriminator, self).__init__()

        self.hidden_dim = hidden_dim
        self.n_classes = n_classes
        self.init = init

        self.which_linear = (functools.partial(SNLinear, num_svs=1, num_itrs=1)
                             if SN_used else nn.Linear)
        # Label embeddings are a plain nn.Embedding in both cases.
        self.which_embedding = nn.Embedding

        make_linear = self.which_linear
        width = self.hidden_dim

        # Construction order is kept fixed: it decides which random numbers each
        # layer's default initialization draws.
        self.input_fc = make_linear(input_dim, width)
        self.output_fc = make_linear(width, output_dim)
        self.output_category = nn.Sequential(make_linear(width, output_dim),
                                             nn.Sigmoid())
        # Embedding for projection discrimination.
        self.embed = self.which_embedding(self.n_classes, width)

        self.model = nn.Sequential(self.input_fc, nn.ReLU())
        for index in range(hidden_number):
            self.model.add_module('hidden-layers-{0}'.format(index), make_linear(width, width))
            self.model.add_module('ReLu-{0}'.format(index), nn.ReLU())

        self.init_weights()

    def init_weights(self):
        """Initialize Linear/Embedding weights per ``self.init``; count their parameters."""
        self.param_count = 0
        for layer in self.modules():
            if not isinstance(layer, (nn.Linear, nn.Embedding)):
                continue
            scheme = self.init
            if scheme == 'ortho':
                init.orthogonal_(layer.weight)
            elif scheme == 'N02':
                init.normal_(layer.weight, 0, 0.02)
            elif scheme in ('glorot', 'xavier'):
                init.xavier_uniform_(layer.weight)
            else:
                print('Init style not recognized...')
            self.param_count += sum(p.data.nelement() for p in layer.parameters())

    def forward(self, x, y=None, mode=0):
        """Run the trunk on *x* and return the head(s) selected by *mode*.

        - ``mode == 0``: ``(real_fake_logit, category_prob)``. The logit is
          ``output_fc(h)`` plus the projection ``sum(embed(y) * h, dim=1)``, so *y*
          is required.
        - ``mode == 1``: ``output_fc(h)`` only (no projection term).
        - any other value: ``output_category(h)`` only.
        """
        features = self.model(x)
        if mode == 0:
            logits = self.output_fc(features)
            logits = logits + torch.sum(self.embed(y) * features, 1, keepdim=True)
            return logits, self.output_category(features)
        if mode == 1:
            return self.output_fc(features)
        return self.output_category(features)
162 |
163 |
164 | class EAL_GAN(nn.Module):
165 | def __init__(self, args, data_x, data_y, test_x, test_y, visualize=False):
166 | super(EAL_GAN, self).__init__()
167 |
168 | lr_g = args.lr_g
169 | lr_d = args.lr_d
170 |
171 | self.device = torch.device("cuda" if args.cuda else "cpu")
172 | z, y = prepare_z_y(20, args.dim_z, 2, device=self.device)
173 | y = y * 0 + 1
174 | self.y = y.long()
175 | self.noise = z
176 |
177 | self.args = args
178 | # self.data_x = torch.from_numpy(data_x).float()
179 | # self.data_y = torch.from_numpy(data_y).long()
180 | self.data_x = data_x
181 | self.data_y = data_y
182 | self.test_x = test_x
183 | self.test_y = test_y
184 | self.iterations = 0
185 | self.visualize = visualize
186 |
187 | self.feature_size = data_x.shape[1]
188 | self.data_size = data_x.shape[0]
189 | self.batch_size = min(args.batch_size, self.data_size)
190 | self.hidden_dim = self.feature_size * 2
191 | self.dim_z = args.dim_z
192 |
193 | manualSeed = random.randint(1, 10000)
194 | random.seed(manualSeed)
195 | torch.manual_seed(manualSeed)
196 |
197 | # 1: prepare Generator
198 | self.netG = Generator(dim_z=self.dim_z, hidden_dim=self.hidden_dim, output_dim=self.feature_size, n_classes=2,
199 | hidden_number=args.gen_layer, init=args.init_type, SN_used=args.SN_used)
200 | self.optimizerG = optim.Adam(self.netG.parameters(), lr=args.lr_g, betas=(0.00, 0.99))
201 |
202 | if args.cuda:
203 | self.netG = nn.DataParallel(self.netG, device_ids=[0, 1])
204 | self.netG = self.netG.to(self.device)
205 |
206 | # 2: create ensemble of discriminator
207 | self.NetD_Ensemble = []
208 | self.opti_Ensemble = []
209 | lr_ds = np.random.rand(args.ensemble_num) * (args.lr_d * 5 - args.lr_d) + args.lr_d # learning rate
210 | for index in range(args.ensemble_num):
211 | netD = Discriminator(input_dim=self.feature_size, hidden_dim=self.hidden_dim, output_dim=1, n_classes=2,
212 | hidden_number=args.dis_layer, init=args.init_type, SN_used=args.SN_used)
213 | optimizerD = optim.Adam(netD.parameters(), lr=lr_ds[index], betas=(0.00, 0.99))
214 | if args.cuda:
215 | netD = nn.DataParallel(netD, device_ids=[0, 1])
216 | netD = netD.to(self.device)
217 |
218 | self.NetD_Ensemble += [netD]
219 | self.opti_Ensemble += [optimizerD]
220 |
221 | def fit(self):
222 | log_dir = os.path.join('./log/', self.args.data_name)
223 | if not os.path.exists(log_dir):
224 | os.makedirs(log_dir)
225 | z, y = prepare_z_y(self.batch_size, self.dim_z, 2, device=self.device)
226 | # Start iteration
227 | Best_Measure_Recorded = -1
228 | best_auc = 0
229 | best_gmean = 0
230 | self.train_history = defaultdict(list)
231 | for epoch in range(self.args.max_epochs):
232 | train_AUC, train_Gmean, test_auc, test_gmean = self.train_one_epoch(z, y, epoch)
233 | if train_Gmean * train_AUC > Best_Measure_Recorded:
234 | Best_Measure_Recorded = train_Gmean * train_AUC
235 | best_auc = test_auc
236 | best_gmean = test_gmean
237 | states = {
238 | 'epoch': epoch,
239 | 'gen_dict': self.netG.state_dict(),
240 | 'opti_gen': self.optimizerG.state_dict(),
241 | 'max_auc': train_AUC
242 | }
243 | for i in range(self.args.ensemble_num):
244 | netD = self.NetD_Ensemble[i]
245 | optimi_D = self.opti_Ensemble[i]
246 | states['dis_dict' + str(i)] = netD.state_dict()
247 | states['opti_dis' + str(i)] = optimi_D.state_dict()
248 |
249 | torch.save(states, os.path.join(log_dir, 'checkpoint_best.pth'))
250 | # print(train_AUC, test_AUC, epoch)
251 | if self.args.print:
252 | print('Training for epoch %d: Train_AUC=%.4f train_Gmean=%.4f Test_AUC=%.4f Test_Gmean=%.4f' % (
253 | epoch + 1, train_AUC, train_Gmean, test_auc, test_gmean))
254 |
255 | # step 1: load the best models
256 | self.Best_Ensemble = []
257 | states = torch.load(os.path.join(log_dir, 'checkpoint_best.pth'))
258 | self.netG.load_state_dict(states['gen_dict'])
259 | for i in range(self.args.ensemble_num):
260 | netD = self.NetD_Ensemble[i]
261 | netD.load_state_dict(states['dis_dict' + str(i)])
262 | self.Best_Ensemble += [netD]
263 |
264 | return best_auc, best_gmean
265 |
266 | def predict(self, test_x, test_y, dis_Ensemble=None):
267 | p = []
268 | y = []
269 |
270 | data_size = test_x.shape[0]
271 | num_batches = data_size // self.batch_size
272 | num_batches = num_batches + 1 if data_size % self.batch_size > 0 else num_batches
273 |
274 | for index in range(num_batches):
275 | end_pos = min(data_size, (index + 1) * self.batch_size)
276 | real_x = test_x[index * self.batch_size: end_pos]
277 | real_y = test_y[index * self.batch_size: end_pos]
278 |
279 | final_pt = 0
280 | for i in range(self.args.ensemble_num):
281 | pt = self.Best_Ensemble[i](real_x, mode=2) if dis_Ensemble is None else dis_Ensemble[i](real_x, mode=2)
282 | final_pt = pt.detach() if i == 0 else final_pt + pt
283 |
284 | final_pt /= self.args.ensemble_num
285 | final_pt = final_pt.view(-1, )
286 |
287 | p += [final_pt]
288 | y += [real_y]
289 | p = torch.cat(p, 0).cpu().detach().numpy()
290 | y = torch.cat(y, 0).cpu().detach().numpy()
291 | return y, p
292 |
def train_one_epoch(self, z, y, epoch=1):
    """Train the discriminator ensemble and the generator for one epoch.

    Walks ``self.data_x`` / ``self.data_y`` in consecutive mini-batches. For
    each batch every discriminator in ``self.NetD_Ensemble`` is updated on an
    actively-sampled subset of the real batch plus a batch of generated data,
    then the generator is updated against the whole ensemble. At the end the
    ensemble is scored on the data it saw this epoch (selected real samples
    plus generated samples) and on ``self.test_x`` / ``self.test_y``.

    Args:
        z: latent sampler; ``z.sample_()`` draws a fresh batch.
        y: label sampler paired with ``z``; ``y.sample_()`` redraws labels.
        epoch: unused; kept for interface compatibility.

    Returns:
        ``(auc, gmean, test_auc, test_gmean)`` as computed by
        ``AUC_and_Gmean`` on the training-side and test-side predictions.
    """
    data_size = self.data_x.shape[0]
    batch_size = min(self.args.batch_size, data_size)
    # Ceiling division: the last batch may be smaller than batch_size.
    num_batches = -(-data_size // batch_size)

    # NOTE(review): this permutation is drawn but never applied; batches are
    # taken in stored order. The draw is kept so the RNG stream (and hence
    # the z/y samples below) is unchanged. Either apply it or remove it
    # deliberately.
    perm_index = torch.randperm(data_size)
    if self.args.cuda:
        perm_index = perm_index.cuda()

    x_empirical, y_empirical = [], []

    for index in range(num_batches):
        self.iterations += 1

        # ---- Step 1: train every discriminator in the ensemble ----
        start_pos = index * batch_size
        end_pos = min(data_size, start_pos + batch_size)
        real_x = self.data_x[start_pos:end_pos]
        real_y = self.data_y[start_pos:end_pos]

        # Weights returned by one discriminator's loss are passed to the next.
        real_weights = None
        fake_weights = None
        real_cat_weights = None
        fake_cat_weights = None

        z.sample_()
        y.sample_()
        # The sample is detached right away, so no graph is needed here.
        with torch.no_grad():
            generated_x = self.netG(z, y)
        generated_x = generated_x.detach()
        # Snapshot the labels: y is resampled in place below (and on the next
        # batch), so storing y itself would pair these samples with the
        # labels drawn last.
        fake_labels = y.detach().clone()

        dis_loss = 0
        # Select a subset of the real batch (sampling is skipped on the very
        # first iteration).
        real_x_selected, real_y_selected, _ = active_sampling_V1(
            self.args, real_x, real_y, self.NetD_Ensemble,
            need_sample=(self.iterations > 1))
        x_empirical += [real_x_selected, generated_x]
        y_empirical += [real_y_selected, fake_labels]

        for i in range(self.args.ensemble_num):
            optimizer = self.opti_Ensemble[i]
            netD = self.NetD_Ensemble[i]

            # Loss on the selected real data.
            out_real_fake, out_real_category = netD(real_x_selected, real_y_selected)
            loss1, real_loss_cat, real_weights, real_cat_weights = loss_dis_real(
                out_real_fake, out_real_category, real_y_selected,
                real_weights, real_cat_weights)
            real_loss = loss1 + real_loss_cat

            # Loss on the generated data.
            output, out_fake_category = netD(generated_x, y.detach())
            loss2, loss_cat_fake, fake_weights, fake_cat_weights = loss_dis_fake(
                output, out_fake_category, y, fake_weights, fake_cat_weights)
            fake_loss = loss2 + loss_cat_fake

            sum_loss = real_loss + fake_loss
            # History only needs the values. Keeping graph-attached tensors
            # (whose buffers survive retain_graph=True) would keep every
            # step's graph alive.
            dis_loss += sum_loss.detach()
            self.train_history['discriminator_loss_' + str(i)].append(sum_loss.detach())

            optimizer.zero_grad()
            # retain_graph kept as before. NOTE(review): the loss helpers may
            # share graph state across ensemble members (not visible here).
            sum_loss.backward(retain_graph=True)
            optimizer.step()
        self.train_history['discriminator_loss'].append(dis_loss)

        # ---- Step 2: train the generator against the whole ensemble ----
        z.sample_()
        y.sample_()
        generated_x = self.netG(z, y)
        gen_loss = 0
        gen_weights = None
        gen_cat_weights = None
        for i in range(self.args.ensemble_num):
            netD = self.NetD_Ensemble[i]
            output, out_category = netD(generated_x, y)
            # Generated samples are scored with the real-data loss, presumably
            # rewarding G when the discriminators treat them as real.
            loss, loss_cat, gen_weights, gen_cat_weights = loss_dis_real(
                output, out_category, y, gen_weights, gen_cat_weights)
            gen_loss += (loss + loss_cat)
        self.train_history['generator_loss'].append(gen_loss.detach())
        self.optimizerG.zero_grad()
        gen_loss.backward()
        self.optimizerG.step()

    def _nan_to_zero(scores):
        # Replace NaN scores with 0 before computing metrics.
        return pd.DataFrame(scores).fillna(0).values

    # Evaluate on everything seen this epoch (real selections + fakes).
    x_empirical = torch.cat(x_empirical, 0)
    y_empirical = torch.cat(y_empirical, 0)
    y_train, y_pred_train = self.predict(x_empirical, y_empirical, self.NetD_Ensemble)
    auc, gmean = AUC_and_Gmean(y_train, _nan_to_zero(y_pred_train))
    self.train_history['train_auc'].append(auc)
    self.train_history['train_Gmean'].append(gmean)

    # Evaluate on the held-out test set.
    y_test, y_pred_test = self.predict(self.test_x, self.test_y, self.NetD_Ensemble)
    test_auc, test_gmean = AUC_and_Gmean(y_test, _nan_to_zero(y_pred_test))

    return auc, gmean, test_auc, test_gmean
410 |
411 |
--------------------------------------------------------------------------------
/EAL-GAN-image/src/layers.py:
--------------------------------------------------------------------------------
1 | ''' Layers
2 | This file contains various layers for the BigGAN models.
3 | '''
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import init
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | from torch.nn import Parameter as P
11 |
12 | from src.sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
13 |
14 |
15 | # Projection of x onto y
def proj(x, y):
    """Project ``x`` onto ``y``: ``<y, x> * y / <y, y>``.

    Both arguments must be 2-D because ``torch.mm`` is used. In practice they
    are single-row tensors, since the elementwise ``* y`` only broadcasts
    cleanly for one row.
    """
    numerator = torch.mm(y, x.t()) * y
    denominator = torch.mm(y, y.t())
    return numerator / denominator
18 |
19 |
20 | # Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
    """Return ``x`` with its component along each vector in ``ys`` removed.

    The vectors are processed in order, and each projection is taken from the
    running residual rather than from the original ``x``.
    """
    residual = x
    for basis_vec in ys:
        residual = residual - proj(residual, basis_vec)
    return residual
25 |
26 |
27 | # Apply num_itrs steps of the power method to estimate top N singular values.
def power_iteration(W, u_, update=True, eps=1e-12):
    """Run one power-iteration step per singular-vector estimate in ``u_``.

    Args:
        W: 2-D matrix whose top singular values are estimated.
        u_: list of left singular-vector estimates (the SN mixin passes
            ``(1, n)`` buffers). Entries are overwritten in place when
            ``update`` is true.
        update: write the refined left vectors back into ``u_``.
        eps: epsilon forwarded to ``F.normalize``.

    Returns:
        ``(svs, us, vs)``: singular-value estimates plus the refined left
        and right vectors. The vectors are computed under ``no_grad``. The
        singular values are computed outside it, so gradients flow through
        ``W``.
    """
    svs, us, vs = [], [], []
    for idx, u_est in enumerate(u_):
        with torch.no_grad():
            # Right vector: u W, orthogonalised against the right vectors
            # found so far, then normalised.
            v_new = F.normalize(gram_schmidt(torch.matmul(u_est, W), vs), eps=eps)
            vs.append(v_new)
            # Left vector: v W^T, orthogonalised against earlier left
            # vectors, then normalised.
            u_new = F.normalize(gram_schmidt(torch.matmul(v_new, W.t()), us), eps=eps)
            us.append(u_new)
            if update:
                u_[idx][:] = u_new
        # Singular value estimate v W^T u^T (graph-connected to W).
        svs.append(torch.squeeze(torch.matmul(torch.matmul(v_new, W.t()), u_new.t())))
    return svs, us, vs
51 |
52 |
53 | # Convenience passthrough function
class identity(nn.Module):
    """Pass-through module: ``forward`` returns its argument unchanged."""

    def forward(self, input):
        output = input
        return output
57 |
58 |
59 | # Spectral normalization base class
class SN(object):
    """Spectral-normalisation mixin for modules that own a ``weight``.

    It must be combined with an ``nn.Module`` subclass, because it relies on
    the host's ``register_buffer``, ``weight`` and ``training`` attributes.

    Args:
        num_svs: number of singular values/vectors tracked.
        num_itrs: power-iteration steps per ``W_()`` call.
        num_outputs: length of each stored ``u`` vector.
        transpose: transpose the flattened weight before power iteration.
        eps: epsilon used when normalising vectors.
    """
    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        self.num_itrs = num_itrs
        self.num_svs = num_svs
        self.transpose = transpose
        self.eps = eps
        # One randomly initialised u vector plus one logged singular value
        # per tracked singular value.
        for k in range(self.num_svs):
            self.register_buffer('u%d' % k, torch.randn(1, num_outputs))
            self.register_buffer('sv%d' % k, torch.ones(1))

    @property
    def u(self):
        """List of the singular-vector buffers ``u0 .. u{num_svs-1}``."""
        return [getattr(self, 'u%d' % k) for k in range(self.num_svs)]

    @property
    def sv(self):
        """List of the singular-value buffers (used for logging only)."""
        return [getattr(self, 'sv%d' % k) for k in range(self.num_svs)]

    def W_(self):
        """Return ``weight`` divided by its top estimated singular value.

        The weight is flattened to ``(size(0), -1)`` (and transposed if
        requested) before ``num_itrs`` power iterations. In training mode the
        ``u`` buffers are refreshed and the ``sv`` buffers are written.
        """
        mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            mat = mat.t()
        for _ in range(self.num_itrs):
            svs, us, vs = power_iteration(mat, self.u, update=self.training, eps=self.eps)
        if self.training:
            # no_grad so the logging buffers don't keep the graph alive.
            with torch.no_grad():
                for k, value in enumerate(svs):
                    self.sv[k][:] = value
        return self.weight / svs[0]
101 |
102 | # 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
    """``nn.Conv2d`` whose weight is spectrally normalised on every forward.

    The constructor takes the usual ``nn.Conv2d`` arguments plus ``num_svs``,
    ``num_itrs`` and ``eps`` for the SN mixin, which keeps ``u`` vectors of
    length ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size,
                           stride=stride, padding=padding, dilation=dilation,
                           groups=groups, bias=bias)
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.conv2d(x, normalized_weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
113 |
114 |
115 | # Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
    """``nn.Linear`` whose weight is spectrally normalised on every forward.

    ``num_svs``, ``num_itrs`` and ``eps`` configure the SN mixin, which keeps
    ``u`` vectors of length ``out_features``.
    """
    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features, bias=bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        weight_sn = self.W_()
        return F.linear(x, weight_sn, self.bias)
123 |
124 |
125 | # Embedding layer with spectral norm
126 | # We use num_embeddings as the dim instead of embedding_dim here
127 | # for convenience sake
class SNEmbedding(nn.Embedding, SN):
    """``nn.Embedding`` whose table is spectrally normalised on every forward.

    ``num_embeddings`` (not ``embedding_dim``) is the length of the SN ``u``
    vectors. ``padding_idx`` and ``scale_grad_by_freq`` are applied in
    ``forward``. ``max_norm``, ``norm_type`` and ``sparse`` are accepted for
    signature compatibility with ``nn.Embedding`` but are not applied to the
    normalised table. NOTE(review): ``max_norm`` would renormalise the
    computed, non-leaf table in place; validate that before enabling it.
    """
    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                 sparse=False, _weight=None,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
                              max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight)
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        # Honour the gradient-related options configured on nn.Embedding;
        # with the defaults this is identical to a bare F.embedding call.
        return F.embedding(x, self.W_(), padding_idx=self.padding_idx,
                           scale_grad_by_freq=self.scale_grad_by_freq)
139 |
140 |
141 | # A non-local block as used in SA-GAN
142 | # Note that the implementation as described in the paper is largely incorrect;
143 | # refer to the released code for the actual implementation.
class Attention(nn.Module):
    """Non-local self-attention block (SA-GAN style).

    Query and key projections use ``ch // 8`` channels and the value path
    uses ``ch // 2``. Key and value maps are 2x2 max-pooled. The attended
    features are projected back to ``ch`` channels and added to the input,
    scaled by the learnable ``gamma``, which starts at 0 so the block is
    initially an identity.

    Args:
        ch: number of input/output channels.
        which_conv: constructor used for the four 1x1 projections.
        name: unused; kept for interface compatibility.
    """
    def __init__(self, ch, which_conv=SNConv2d, name='attention'):
        super(Attention, self).__init__()
        self.ch = ch
        self.which_conv = which_conv
        # kernel_size/padding are passed explicitly so they override any
        # defaults baked into a functools.partial ``which_conv``.
        conv1x1 = dict(kernel_size=1, padding=0, bias=False)
        self.theta = self.which_conv(self.ch, self.ch // 8, **conv1x1)
        self.phi = self.which_conv(self.ch, self.ch // 8, **conv1x1)
        self.g = self.which_conv(self.ch, self.ch // 2, **conv1x1)
        self.o = self.which_conv(self.ch // 2, self.ch, **conv1x1)
        self.gamma = P(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        height, width = x.shape[2], x.shape[3]
        n_pix = height * width
        query = self.theta(x)
        key = F.max_pool2d(self.phi(x), [2, 2])
        value = F.max_pool2d(self.g(x), [2, 2])
        # Flatten spatial dims; pooled maps have a quarter of the positions.
        query = query.view(-1, self.ch // 8, n_pix)
        key = key.view(-1, self.ch // 8, n_pix // 4)
        value = value.view(-1, self.ch // 2, n_pix // 4)
        # Softmax over pooled positions for every query position.
        attn = F.softmax(torch.bmm(query.transpose(1, 2), key), -1)
        attended = torch.bmm(value, attn.transpose(1, 2)).view(-1, self.ch // 2, height, width)
        return self.gamma * self.o(attended) + x
170 |
171 |
172 | # Fused batchnorm op
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    """Normalise ``x`` with precomputed statistics as one scale and shift.

    Mathematically this is ``(x - mean) / sqrt(var + eps) * gain + bias``.
    The optional ``gain`` is folded into the scale and the optional ``bias``
    into the shift, so ``x`` is touched by one multiply and one subtract.
    """
    inv_std = torch.rsqrt(var + eps)
    scale = inv_std if gain is None else inv_std * gain
    shift = mean * scale
    if bias is not None:
        shift = shift - bias
    return x * scale - shift
187 |
188 |
189 | # Manual BN
190 | # Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    """Batch-normalise ``x`` using statistics computed from ``x`` itself.

    Per-channel mean and variance are reduced over dims (0, 2, 3) in float32,
    with variance computed as ``E[x^2] - E[x]^2``. They are cast back to
    ``x``'s type and applied through ``fused_bn``.

    Returns:
        The normalised tensor, or ``(out, mean, var)`` with squeezed
        per-channel statistics when ``return_mean_var`` is true.
    """
    x32 = x.float()
    reduce_dims = [0, 2, 3]
    mean = torch.mean(x32, reduce_dims, keepdim=True)
    mean_sq = torch.mean(x32 ** 2, reduce_dims, keepdim=True)
    # Variance is formed in float32 before casting back (e.g. to fp16).
    var = (mean_sq - mean ** 2).type(x.type())
    mean = mean.type(x.type())
    out = fused_bn(x, mean, var, gain, bias, eps)
    if return_mean_var:
        return out, mean.squeeze(), var.squeeze()
    return out
209 |
210 |
211 | # My batchnorm, supports standing stats
class myBN(nn.Module):
    """Batch norm with externally supplied gain/bias and optional standing stats.

    In training mode, batch statistics come from ``manual_bn``. They are
    folded into ``stored_mean`` / ``stored_var`` either as an exponential
    running average (default) or, when ``accumulate_standing`` is set, as a
    running sum together with ``accumulation_counter``. Eval mode normalises
    with the stored values, dividing by the counter in the standing case.

    Args:
        num_channels: number of channels (dim 1 of the input).
        eps: added to the variance before the inverse square root.
        momentum: weight of the newest batch in the running average.
    """
    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        self.momentum = momentum
        self.eps = eps
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.register_buffer('accumulation_counter', torch.zeros(1))
        # When true, stored stats are sums to be averaged by the counter.
        self.accumulate_standing = False

    def reset_stats(self):
        """Zero the stored sums and the counter before accumulating standing stats."""
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
            # The batch statistics carry autograd history. Writing them into
            # the buffers with grad enabled would turn the buffers into graph
            # nodes chained across steps (a leak, and it breaks deepcopy).
            with torch.no_grad():
                if self.accumulate_standing:
                    self.stored_mean[:] = self.stored_mean + mean
                    self.stored_var[:] = self.stored_var + var
                    self.accumulation_counter += 1.0
                else:
                    self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
                    self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
            return out
        # Eval mode: normalise with the stored statistics.
        mean = self.stored_mean.view(1, -1, 1, 1)
        var = self.stored_var.view(1, -1, 1, 1)
        if self.accumulate_standing:
            mean = mean / self.accumulation_counter
            var = var / self.accumulation_counter
        return fused_bn(x, mean, var, gain, bias, self.eps)
256 |
257 |
258 | # Simple function to handle groupnorm norm stylization
def groupnorm(x, norm_style):
    """Apply ``F.group_norm`` with a group count derived from ``norm_style``.

    The trailing ``_<n>`` of ``norm_style`` is interpreted as follows:

    * if the style contains ``'ch'``, ``n`` is channels per group, giving
      ``max(C // n, 1)`` groups;
    * if it contains ``'grp'``, ``n`` is the number of groups;
    * otherwise 16 groups are used.
    """
    def _suffix():
        return int(norm_style.split('_')[-1])

    if 'ch' in norm_style:
        num_groups = max(int(x.shape[1]) // _suffix(), 1)
    elif 'grp' in norm_style:
        num_groups = _suffix()
    else:
        num_groups = 16
    return F.group_norm(x, num_groups)
271 |
272 |
273 | # Class-conditional bn
274 | # output size is the number of channels, input size is for the linear layers
275 | # Andy's Note: this class feels messy but I'm not really sure how to clean it up
276 | # Suggestions welcome! (By which I mean, refactor this and make a pull request
277 | # if you want to make this more readable/usable).
class ccbn(nn.Module):
    """Class-conditional batch norm.

    ``x`` is normalised without affine parameters, then scaled by
    ``1 + gain(y)`` and shifted by ``bias(y)``. Here ``gain`` and ``bias``
    are built with ``which_linear(input_size, output_size)``.

    Args:
        output_size: number of channels of ``x``.
        input_size: feature size of the conditioning input ``y``.
        which_linear: constructor for the gain/bias projections.
        eps: normalisation epsilon.
        momentum: running-statistics momentum (all normalisation paths).
        cross_replica: normalise with ``SyncBN2d``.
        mybn: normalise with ``myBN``.
        norm_style: ``'bn'``, ``'in'``, ``'gn'`` or ``'nonorm'``; ignored when
            ``cross_replica`` or ``mybn`` is set.

    Raises:
        ValueError: from ``forward`` if ``norm_style`` is not recognised.
    """
    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style='bn',):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Conditional gain and bias projections.
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        self.eps = eps
        self.momentum = momentum
        self.cross_replica = cross_replica
        self.mybn = mybn
        self.norm_style = norm_style

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ['bn', 'in']:
            # Running statistics for the functional bn/in paths.
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        # Per-sample gains/biases, broadcast over the spatial dims.
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        if self.mybn or self.cross_replica:
            return self.bn(x, gain=gain, bias=bias)

        if self.norm_style == 'bn':
            out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                               self.training, self.momentum, self.eps)
        elif self.norm_style == 'in':
            out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                  self.training, self.momentum, self.eps)
        elif self.norm_style == 'gn':
            out = groupnorm(x, self.norm_style)
        elif self.norm_style == 'nonorm':
            out = x
        else:
            raise ValueError('Unrecognized norm_style: %r' % (self.norm_style,))
        return out * gain + bias

    def extra_repr(self):
        s = 'out: {output_size}, in: {input_size},'
        s += ' cross_replica={cross_replica}'
        return s.format(**self.__dict__)
330 |
331 |
332 | # Normal, non-class-conditional BN
class bn(nn.Module):
    """Unconditional batch norm with learnable per-channel gain and bias.

    Args:
        output_size: number of channels.
        eps: normalisation epsilon.
        momentum: running-statistics momentum.
        cross_replica: normalise with ``SyncBN2d`` (gain/bias passed to it).
        mybn: normalise with ``myBN`` (gain/bias passed to it).
    """
    def __init__(self, output_size, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False):
        super(bn, self).__init__()
        self.output_size = output_size
        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)
        self.eps = eps
        self.momentum = momentum
        self.cross_replica = cross_replica
        self.mybn = mybn

        if cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        else:
            # The plain F.batch_norm path keeps its own running statistics.
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y=None):
        """Normalise ``x``. ``y`` is accepted for interface parity and ignored."""
        if not (self.cross_replica or self.mybn):
            return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                                self.bias, self.training, self.momentum, self.eps)
        return self.bn(x, gain=self.gain.view(1, -1, 1, 1), bias=self.bias.view(1, -1, 1, 1))
367 |
368 |
369 | # Generator blocks
370 | # Note that this class assumes the kernel size and padding (and any other
371 | # settings) have been selected in the main generator module and passed in
372 | # through the which_conv arg. Similar rules apply with which_bn (the input
373 | # size [which is actually the number of channels of the conditional info] must
374 | # be preselected)
class GBlock(nn.Module):
    """Residual generator block.

    Main path: bn1 -> activation -> [upsample] -> conv1 -> bn2 -> activation
    -> conv2. The shortcut is upsampled alongside and, when the channel count
    changes or ``upsample`` is given, passed through a 1x1 ``conv_sc``.

    Args:
        in_channels / out_channels: channel counts.
        which_conv: conv constructor (kernel size and padding preset by the
            caller).
        which_bn: normalisation constructor; instances are called as
            ``bn(h, y)``.
        activation: nonlinearity applied after each normalisation.
        upsample: optional callable applied to both paths before conv1.
    """
    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=bn, activation=None,
                 upsample=None):
        super(GBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        self.upsample = upsample
        self.conv1 = which_conv(in_channels, out_channels)
        self.conv2 = which_conv(out_channels, out_channels)
        self.learnable_sc = in_channels != out_channels or upsample
        if self.learnable_sc:
            self.conv_sc = which_conv(in_channels, out_channels,
                                      kernel_size=1, padding=0)
        self.bn1 = which_bn(in_channels)
        self.bn2 = which_bn(out_channels)

    def forward(self, x, y):
        h = self.activation(self.bn1(x, y))
        if self.upsample:
            h, x = self.upsample(h), self.upsample(x)
        h = self.conv2(self.activation(self.bn2(self.conv1(h), y)))
        shortcut = self.conv_sc(x) if self.learnable_sc else x
        return h + shortcut
409 |
410 |
411 | # Residual block for the discriminator
class DBlock(nn.Module):
    """Residual discriminator block.

    Main path: [relu] -> conv1 -> activation -> conv2 -> [downsample]. The
    leading ReLU is applied only when ``preactivation`` is set. The shortcut
    gets a 1x1 ``conv_sc`` when channels change or ``downsample`` is given.
    With preactivation that conv runs before the downsample; otherwise it
    runs after.

    Args:
        wide: hidden width is ``out_channels`` if true, else ``in_channels``.
        which_conv: conv constructor.
        activation: nonlinearity between conv1 and conv2.
        downsample: optional callable applied to both paths.
    """
    def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
                 preactivation=False, activation=None, downsample=None,):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.hidden_channels = out_channels if wide else in_channels
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample
        self.conv1 = which_conv(in_channels, self.hidden_channels)
        self.conv2 = which_conv(self.hidden_channels, out_channels)
        self.learnable_sc = bool((in_channels != out_channels) or downsample)
        if self.learnable_sc:
            self.conv_sc = which_conv(in_channels, out_channels,
                                      kernel_size=1, padding=0)

    def shortcut(self, x):
        if self.preactivation:
            # Preactivation: 1x1 conv first, then downsample.
            if self.learnable_sc:
                x = self.conv_sc(x)
            if self.downsample:
                x = self.downsample(x)
            return x
        if self.downsample:
            x = self.downsample(x)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return x

    def forward(self, x):
        # Out-of-place ReLU on purpose: an in-place activation here would also
        # overwrite x, which the shortcut still needs.
        h = F.relu(x) if self.preactivation else x
        h = self.conv2(self.activation(self.conv1(h)))
        if self.downsample:
            h = self.downsample(h)
        return h + self.shortcut(x)
458 |
459 | # dogball
--------------------------------------------------------------------------------
/EAL-GAN-image/src/BigGAN.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import functools
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import init
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | from torch.nn import Parameter as P
11 |
12 | import src.layers as layers
13 | from src.sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
14 |
15 |
16 | # Architectures for G
17 | # Attention is passed in in the format '32_64' to mean applying an attention
18 | # block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    """Return generator architecture specs keyed by output resolution.

    Each spec contains:

    * per-block ``in_channels`` / ``out_channels`` (multiples of ``ch``);
    * ``upsample`` flags, all True;
    * block output ``resolution``s;
    * an ``attention`` map from resolution to whether an attention block is
      attached there.

    ``attention`` is an underscore-separated list of resolutions, e.g.
    ``'32_64'``. ``ksize`` and ``dilation`` are unused.
    """
    attn_res = [int(item) for item in attention.split('_')]

    def _spec(in_mults, out_mults, resolutions, max_exp):
        # Builds one resolution's spec; attention keys run over 2**3 .. 2**(max_exp-1).
        return {'in_channels': [ch * m for m in in_mults],
                'out_channels': [ch * m for m in out_mults],
                'upsample': [True] * len(in_mults),
                'resolution': resolutions,
                'attention': {2 ** e: (2 ** e in attn_res) for e in range(3, max_exp)}}

    return {
        512: _spec([16, 16, 8, 8, 4, 2, 1], [16, 8, 8, 4, 2, 1, 1],
                   [8, 16, 32, 64, 128, 256, 512], 10),
        256: _spec([16, 16, 8, 8, 4, 2], [16, 8, 8, 4, 2, 1],
                   [8, 16, 32, 64, 128, 256], 9),
        128: _spec([16, 16, 8, 4, 2], [16, 8, 4, 2, 1],
                   [8, 16, 32, 64, 128], 8),
        64: _spec([16, 16, 8, 4], [16, 8, 4, 2],
                  [8, 16, 32, 64], 7),
        32: _spec([4, 4, 4], [4, 4, 4],
                  [8, 16, 32], 6),
    }
53 |
class Generator(nn.Module):
    """BigGAN-style class-conditional generator.

    Forward pass of ``forward(z, y)``:

    1. Pass the class input ``y`` through ``self.shared`` (an ``nn.Embedding``
       when ``G_shared`` is true, otherwise a pass-through).
    2. Map ``z`` (or its first chunk when ``hier`` is true) through a linear
       layer to a ``bottom_width x bottom_width`` feature map.
    3. Run the GBlocks (and any attention blocks) described by
       ``G_arch(G_ch, G_attn)[resolution]``.
    4. Finish with BN -> activation -> conv -> tanh to a 3-channel output.

    Selected arguments:
        G_ch: channel width multiplier.
        dim_z: latent size; rounded down to a multiple of the number of
            latent slots when ``hier`` is true.
        resolution: output resolution, one of the ``G_arch`` keys.
        G_attn: underscore-separated resolutions that get an attention block.
        n_classes: number of classes for the shared embedding.
        G_shared / shared_dim: use a shared class embedding of size
            ``shared_dim`` (``dim_z`` when ``shared_dim <= 0``).
        hier: split ``z`` into one chunk for the linear layer plus one per
            block; each block is conditioned on ``[embedding, chunk]``.
        G_param: ``'SN'`` for spectrally-normalised convs/linears; any other
            value gives plain ``nn.Conv2d`` / ``nn.Linear``.
        skip_init: skip ``init_weights``.
        no_optim: do not create ``self.optim``.
        G_mixed_precision: build ``utils.Adam16`` instead of ``optim.Adam``.
    """
    def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128,
                 G_kernel_size=3, G_attn='64', n_classes=1000,
                 num_G_SVs=1, num_G_SV_itrs=1,
                 G_shared=True, shared_dim=0, hier=True,
                 cross_replica=False, mybn=False,
                 G_activation=nn.ReLU(inplace=False),
                 G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
                 BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
                 G_init='ortho', skip_init=False, no_optim=False,
                 G_param='SN', norm_style='bn',
                 **kwargs):
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Dimensionality of the latent space (recomputed below if hier)
        self.dim_z = dim_z
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Stored for reference; the convs below are built with kernel_size=3
        self.kernel_size = G_kernel_size
        # Attention resolutions, e.g. '32_64'
        self.attention = G_attn
        # Number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        # Use shared embeddings?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding; unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space?
        self.hier = hier
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm
        self.SN_eps = SN_eps
        # fp16 flag (only stored in this class)
        self.fp16 = G_fp16
        # Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # If using hierarchical latents, adjust z
        if self.hier:
            # Number of places z slots into: the linear layer plus every block
            self.num_slots = len(self.arch['in_channels']) + 1
            self.z_chunk_size = (self.dim_z // self.num_slots)
            # Recalculate latent dimensionality for even splitting into chunks
            self.dim_z = self.z_chunk_size * self.num_slots
        else:
            self.num_slots = 1
            self.z_chunk_size = 0

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                  eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        # We use a non-spectral-normed embedding here regardless;
        # for some reason applying SN to G's embedding seems to randomly cripple G
        self.which_embedding = nn.Embedding
        bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                     else self.which_embedding)
        self.which_bn = functools.partial(layers.ccbn,
                                          which_linear=bn_linear,
                                          cross_replica=self.cross_replica,
                                          mybn=self.mybn,
                                          input_size=(self.shared_dim + self.z_chunk_size if self.G_shared
                                                      else self.n_classes),
                                          norm_style=self.norm_style,
                                          eps=self.BN_eps)

        # Prepare model
        # If not using shared embeddings, self.shared is just a passthrough
        self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
                       else layers.identity())
        # First linear layer
        self.linear = self.which_linear(self.dim_z // self.num_slots,
                                        self.arch['in_channels'][0] * (self.bottom_width ** 2))

        # self.blocks is a doubly-nested list of modules: the outer loop is over
        # blocks at a given resolution (resblocks and/or self-attention),
        # the inner loop over the modules within that block
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv=self.which_conv,
                                           which_bn=self.which_bn,
                                           activation=self.activation,
                                           upsample=(functools.partial(F.interpolate, scale_factor=2)
                                                     if self.arch['upsample'][index] else None))]]

            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # Output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                    cross_replica=self.cross_replica,
                                                    mybn=self.mybn),
                                          self.activation,
                                          self.which_conv(self.arch['out_channels'][-1], 3))

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print('Using fp16 adam in G...')
            # NOTE(review): no top-level 'utils' module appears in this
            # project's file tree (only src/my_utils.py), so this branch
            # likely fails with ImportError; confirm before enabling.
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0,
                                      eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0,
                                    eps=self.adam_eps)

    def init_weights(self):
        """Initialise conv/linear/embedding weights according to ``self.init``.

        Styles:

        * ``'ortho'``: orthogonal;
        * ``'N02'``: normal with mean 0 and std 0.02;
        * ``'glorot'`` / ``'xavier'``: Xavier-uniform;
        * anything else: only prints a message.

        Also records in ``self.param_count`` and prints the number of
        parameters in those modules.
        """
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum(p.data.nelement() for p in module.parameters())
        print("Param count for G's initialized parameters: %d" % self.param_count)

    def forward(self, z, y):
        """Generate images from latents ``z`` and class inputs ``y``.

        ``y`` is passed through ``self.shared`` here. With ``hier``, ``z`` is
        split into ``z_chunk_size`` pieces: the first drives the linear layer,
        and piece ``i + 1`` is concatenated with the embedding to condition
        block ``i``. Without ``hier``, every block gets the embedding alone.
        """
        y = self.shared(y)
        if self.hier:
            zs = torch.split(z, self.z_chunk_size, 1)
            z = zs[0]
            ys = [torch.cat([y, item], 1) for item in zs[1:]]
        else:
            ys = [y] * len(self.blocks)

        # First linear layer, reshaped to the bottom feature map
        h = self.linear(z)
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)

        # Loop over blocks; a stage may hold a GBlock followed by Attention
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h, ys[index])

        # Apply batchnorm-relu-conv-tanh at output
        return torch.tanh(self.output_layer(h))
253 |
254 |
255 | # Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    """Return discriminator architecture specs keyed by input resolution.

    Each spec contains:

    * per-block ``in_channels`` (starting from 3 image channels) and
      ``out_channels`` (multiples of ``ch``);
    * ``downsample`` flags;
    * per-block ``resolution``s;
    * an ``attention`` map from resolution to whether attention is used
      there.

    ``attention`` is an underscore-separated list of resolutions.
    ``ksize`` and ``dilation`` are unused.
    """
    attn_res = [int(item) for item in attention.split('_')]

    def _spec(in_mults, out_mults, downsample, resolutions, max_exp):
        # Builds one resolution's spec; attention keys run over 2**2 .. 2**(max_exp-1).
        return {'in_channels': [3] + [ch * m for m in in_mults],
                'out_channels': [m * ch for m in out_mults],
                'downsample': downsample,
                'resolution': resolutions,
                'attention': {2 ** e: 2 ** e in attn_res for e in range(2, max_exp)}}

    return {
        256: _spec([1, 2, 4, 8, 8, 16], [1, 2, 4, 8, 8, 16, 16],
                   [True] * 6 + [False], [128, 64, 32, 16, 8, 4, 4], 8),
        128: _spec([1, 2, 4, 8, 16], [1, 2, 4, 8, 16, 16],
                   [True] * 5 + [False], [64, 32, 16, 8, 4, 4], 8),
        64: _spec([1, 2, 4, 8], [1, 2, 4, 8, 16],
                  [True] * 4 + [False], [32, 16, 8, 4, 4], 7),
        32: _spec([4, 4, 4], [4, 4, 4, 4],
                  [True, True, False, False], [16, 16, 16, 16], 6),
    }
283 |
class Discriminator(nn.Module):
    """BigGAN-style projection discriminator with an extra sigmoid head.

    The layout comes from ``D_arch(D_ch, D_attn)[resolution]``. The network
    is a stack of ``layers.DBlock`` residual blocks, each optionally followed
    by ``layers.Attention``, then global sum pooling and two output heads:

    * ``linear`` plus a class-embedding projection (``embed``), returned as
      ``out`` in mode 0;
    * ``output_cat``, a single sigmoid-activated score returned as ``out_cat``.

    Only spectral-norm parameterisation (``D_param='SN'``) is implemented.
    The constructor also creates the optimizer in ``self.optim``.
    """

    def __init__(self, D_ch=64, D_wide=True, resolution=128,
                 D_kernel_size=3, D_attn='64', n_classes=1000,
                 num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
                 SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
                 D_init='ortho', skip_init=False, D_param='SN', **kwargs):
        # NOTE: the default D_activation is a single ReLU instance created once,
        # when the class is defined. Every Discriminator built with the default
        # shares it, which is harmless only because ReLU has no parameters.
        # Extra **kwargs are accepted and ignored.
        super(Discriminator, self).__init__()
        # Only spectral norm is implemented. Fail fast with a clear message
        # instead of an AttributeError on self.which_conv further down.
        if D_param != 'SN':
            raise ValueError(
                "Discriminator only supports D_param='SN', got %r" % (D_param,))
        self.ch = D_ch                # width multiplier
        self.D_wide = D_wide          # wide D (BigGAN / SA-GAN) vs skinny D (SN-GAN)
        self.resolution = resolution
        # Stored for reference only: the conv factory below hard-codes a 3x3
        # kernel with padding 1.
        self.kernel_size = D_kernel_size
        self.attention = D_attn
        self.n_classes = n_classes
        self.activation = D_activation
        self.init = D_init            # weight init style, see init_weights()
        self.D_param = D_param
        self.SN_eps = SN_eps          # epsilon for spectral norm
        self.fp16 = D_fp16
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Spectral-normalised conv / linear / embedding factories.
        sn_kwargs = dict(num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, eps=self.SN_eps)
        self.which_conv = functools.partial(layers.SNConv2d,
                                            kernel_size=3, padding=1, **sn_kwargs)
        self.which_linear = functools.partial(layers.SNLinear, **sn_kwargs)
        self.which_embedding = functools.partial(layers.SNEmbedding, **sn_kwargs)

        # One stage per row of the arch table: a DBlock, plus an Attention layer
        # when attention is requested at that row's listed resolution.
        stages = []
        for index in range(len(self.arch['out_channels'])):
            stage = [layers.DBlock(in_channels=self.arch['in_channels'][index],
                                   out_channels=self.arch['out_channels'][index],
                                   which_conv=self.which_conv,
                                   wide=self.D_wide,
                                   activation=self.activation,
                                   # Disabled only for the first block, which
                                   # receives the raw input x.
                                   preactivation=(index > 0),
                                   downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
                stage.append(layers.Attention(self.arch['out_channels'][index],
                                              self.which_conv))
            stages.append(stage)
        # Wrap in ModuleLists so every submodule is registered with this Module.
        self.blocks = nn.ModuleList([nn.ModuleList(stage) for stage in stages])

        # Output heads. output_dim is typically 1, but may be larger (e.g. when an
        # extra inference output is needed).
        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
        self.output_cat = self.which_linear(self.arch['out_channels'][-1], 1)
        # Class embedding for the projection term.
        self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

        if not skip_init:
            self.init_weights()

        # Optimizer.
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            # NOTE(review): expects an importable top-level `utils` module that
            # provides Adam16. The src/ tree of this repo lists my_utils.py but
            # no utils.py, so confirm before enabling mixed precision.
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)

    def init_weights(self):
        """Initialise Conv2d/Linear/Embedding weights according to ``self.init``.

        Also accumulates, in ``self.param_count``, the number of parameter
        elements held by the visited modules, and prints it.
        """
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum(p.data.nelement() for p in module.parameters())
        print("Param count for D's initialized parameters: %d" % self.param_count)

    def forward(self, x, y=None, mode=0):
        """Score a batch of inputs.

        Args:
            x: input batch fed to the first block (presumably images of shape
               (N, 3, resolution, resolution); confirm with the caller).
            y: class indices for the projection term. Used only when
               ``mode == 0``. If None, the projection term is skipped and the
               class-unconditional output is returned.
            mode: 0 returns ``(out, out_cat)``; any other value returns only
                ``out_cat``.

        Returns:
            mode 0: ``(out, out_cat)``, where ``out`` is ``self.linear`` of the
            pooled features plus the projection term, and ``out_cat`` is the
            sigmoid of ``self.output_cat``. Otherwise only ``out_cat``.
        """
        h = x
        for stage in self.blocks:
            for block in stage:
                h = block(h)
        # Global sum pooling over the spatial dims, as in SN-GAN.
        h = torch.sum(self.activation(h), [2, 3])
        out_cat = torch.sigmoid(self.output_cat(h))
        if mode != 0:
            return out_cat
        # Class-unconditional output, plus the projection of the features onto
        # the class embedding when labels are provided.
        out = self.linear(h)
        if y is not None:
            out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        return out, out_cat
412 |
--------------------------------------------------------------------------------