├── .gitignore ├── MANIFEST.in ├── README.md ├── ckpts ├── dsvdd ├── __init__.py ├── datasets.py ├── deepSVDD.py ├── networks.py └── utils.py ├── main.py ├── main_cifar.py ├── requirements.txt ├── results ├── mnist_0.png ├── mnist_1.png ├── mnist_2.png ├── mnist_3.png ├── mnist_4.png ├── mnist_5.png ├── mnist_6.png ├── mnist_7.png ├── mnist_8.png └── mnist_9.png └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | *.pyc 3 | deep_svdd.egg-info/* 4 | dist/* 5 | build/* 6 | ckpts/* 7 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include requirements.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow Implementation of Deep SVDD 2 | This repository provides a [Tensorflow](https://www.tensorflow.org/) implementation of the *Deep SVDD* method presented in 3 | ICML 2018 paper ”Deep One-Class Classification”. 4 | 5 | The author's implementation of *Deep-SVDD* in PyTorch is at [https://github.com/lukasruff/Deep-SVDD-PyTorch](https://github.com/lukasruff/Deep-SVDD-PyTorch). 6 | 7 | 8 | ## Citation and Contact 9 | You find a PDF of the Deep One-Class Classification ICML 2018 paper at 10 | [http://proceedings.mlr.press/v80/ruff18a.html](http://proceedings.mlr.press/v80/ruff18a.html). 11 | 12 | 13 | ## Installation 14 | This code is written in `Python 3.5` and tested with `Tensorflow 1.12`. 15 | 16 | Install using pip or clone this repository. 17 | 18 | 1. Installation using pip: 19 | ```bash 20 | pip install deep-svdd 21 | ``` 22 | 23 | and 24 | 25 | ```python 26 | from dsvdd import DeepSVDD 27 | ``` 28 | 29 | 2. Clone this repository: 30 | 31 | ```bash 32 | git clone https://github.com/nuclearboy95/Deep-SVDD-Tensorflow.git 33 | ``` -------------------------------------------------------------------------------- /ckpts: -------------------------------------------------------------------------------- 1 | /home/t080205/Datasets/Deep-SVDD -------------------------------------------------------------------------------- /dsvdd/__init__.py: -------------------------------------------------------------------------------- 1 | from .networks import * 2 | from .datasets import * 3 | from .deepSVDD import DeepSVDD 4 | -------------------------------------------------------------------------------- /dsvdd/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow import keras 3 | 4 | 5 | __all__ = ['get_mnist', 'get_cifar10'] 6 | 7 | 8 | def get_mnist(cls=1): 9 | d_train, d_test = keras.datasets.mnist.load_data() 10 | x_train, y_train = d_train 11 | x_test, y_test = d_test 12 | 13 | mask = y_train == cls 14 | 15 | x_train = x_train[mask] 16 | x_train = np.expand_dims(x_train / 255., axis=-1).astype(np.float32) 17 | x_test = np.expand_dims(x_test / 255., axis=-1).astype(np.float32) 18 | 19 | y_test = (y_test == cls).astype(np.float32) 20 | return x_train, x_test, y_test 21 | 22 | 23 | def get_cifar10(cls=1): 24 | d_train, d_test = keras.datasets.cifar10.load_data() 25 | x_train, y_train = d_train 26 | x_test, y_test = d_test 27 | y_train = np.squeeze(y_train) 28 | y_test = np.squeeze(y_test) 29 | 30 | mask = y_train == cls 31 | 32 | x_train = x_train[mask] 33 | x_train = (x_train / 255.).astype(np.float32) 34 | x_test = (x_test / 255.).astype(np.float32) 35 | 36 | y_test = (y_test == cls).astype(np.float32) 37 | return x_train, x_test, y_test 38 | -------------------------------------------------------------------------------- /dsvdd/deepSVDD.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from math import ceil 4 | from sklearn.metrics import roc_auc_score 5 | from tensorflow import keras 6 | from tqdm import tqdm 7 | 8 | from .utils import task 9 | 10 | 11 | class DeepSVDD: 12 | def __init__(self, keras_model, input_shape=(28, 28, 1), objective='one-class', 13 | nu=0.1, representation_dim=32, batch_size=128, lr=1e-3): 14 | self.represetation_dim = representation_dim 15 | self.objective = objective 16 | self.keras_model = keras_model 17 | self.nu = nu 18 | self.R = tf.get_variable('R', [], dtype=tf.float32, trainable=False) 19 | self.c = tf.get_variable('c', [self.represetation_dim], dtype=tf.float32, trainable=False) 20 | self.warm_up_n_epochs = 10 21 | self.batch_size = batch_size 22 | 23 | with task('Build graph'): 24 | self.x = tf.placeholder(tf.float32, [None] + list(input_shape)) 25 | self.latent_op = self.keras_model(self.x) 26 | self.dist_op = tf.reduce_sum(tf.square(self.latent_op - self.c), axis=-1) 27 | 28 | if self.objective == 'soft-boundary': 29 | self.score_op = self.dist_op - self.R ** 2 30 | penalty = tf.maximum(self.score_op, tf.zeros_like(self.score_op)) 31 | self.loss_op = self.R ** 2 + (1 / self.nu) * penalty 32 | 33 | else: # one-class 34 | self.score_op = self.dist_op 35 | self.loss_op = self.score_op 36 | 37 | opt = tf.train.AdamOptimizer(lr) 38 | self.train_op = opt.minimize(self.loss_op) 39 | 40 | config = tf.ConfigProto() 41 | config.gpu_options.allow_growth = True 42 | self.sess = tf.Session(config=config) 43 | self.sess.run(tf.global_variables_initializer()) 44 | 45 | def __del__(self): 46 | self.sess.close() 47 | 48 | def fit(self, X, X_test, y_test, epochs=10, verbose=True): 49 | N = X.shape[0] 50 | BS = self.batch_size 51 | BN = int(ceil(N / BS)) 52 | 53 | self.sess.run(tf.global_variables_initializer()) 54 | self._init_c(X) 55 | 56 | ops = { 57 | 'train': self.train_op, 58 | 'loss': tf.reduce_mean(self.loss_op), 59 | 'dist': self.dist_op 60 | } 61 | keras.backend.set_learning_phase(True) 62 | 63 | for i_epoch in range(epochs): 64 | ind = np.random.permutation(N) 65 | x_train = X[ind] 66 | g_batch = tqdm(range(BN)) if verbose else range(BN) 67 | for i_batch in g_batch: 68 | x_batch = x_train[i_batch * BS: (i_batch + 1) * BS] 69 | results = self.sess.run(ops, feed_dict={self.x: x_batch}) 70 | 71 | if self.objective == 'soft-boundary' and i_epoch >= self.warm_up_n_epochs: 72 | self.sess.run(tf.assign(self.R, self._get_R(results['dist'], self.nu))) 73 | 74 | else: 75 | if verbose: 76 | pred = self.predict(X_test) # pred: large->fail small->pass 77 | auc = roc_auc_score(y_test, -pred) # y_test: 1->pass 0->fail 78 | print('\rEpoch: %3d AUROC: %.3f' % (i_epoch, auc)) 79 | 80 | def predict(self, X): 81 | N = X.shape[0] 82 | BS = self.batch_size 83 | BN = int(ceil(N / BS)) 84 | scores = list() 85 | keras.backend.set_learning_phase(False) 86 | 87 | for i_batch in range(BN): 88 | x_batch = X[i_batch * BS: (i_batch + 1) * BS] 89 | s_batch = self.sess.run(self.score_op, feed_dict={self.x: x_batch}) 90 | scores.append(s_batch) 91 | 92 | return np.concatenate(scores) 93 | 94 | def _init_c(self, X, eps=1e-1): 95 | N = X.shape[0] 96 | BS = self.batch_size 97 | BN = int(ceil(N / BS)) 98 | keras.backend.set_learning_phase(False) 99 | 100 | with task('1. Get output'): 101 | latent_sum = np.zeros(self.latent_op.shape[-1]) 102 | for i_batch in range(BN): 103 | x_batch = X[i_batch * BS: (i_batch + 1) * BS] 104 | latent_v = self.sess.run(self.latent_op, feed_dict={self.x: x_batch}) 105 | latent_sum += latent_v.sum(axis=0) 106 | 107 | c = latent_sum / N 108 | 109 | with task('2. Modify eps'): 110 | c[(abs(c) < eps) & (c < 0)] = -eps 111 | c[(abs(c) < eps) & (c > 0)] = eps 112 | 113 | self.sess.run(tf.assign(self.c, c)) 114 | 115 | def _get_R(self, dist, nu): 116 | return np.quantile(np.sqrt(dist), 1 - nu) 117 | -------------------------------------------------------------------------------- /dsvdd/networks.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | 3 | 4 | __all__ = ['mnist_lenet', 'cifar_lenet'] 5 | 6 | 7 | def mnist_lenet(H=32): 8 | model = keras.models.Sequential() 9 | 10 | model.add(keras.layers.Conv2D(8, (5, 5), padding='same', use_bias=False, input_shape=(28, 28, 1))) 11 | model.add(keras.layers.LeakyReLU(1e-2)) 12 | model.add(keras.layers.BatchNormalization(epsilon=1e-4, trainable=False)) 13 | model.add(keras.layers.MaxPool2D()) 14 | 15 | model.add(keras.layers.Conv2D(4, (5, 5), padding='same', use_bias=False)) 16 | model.add(keras.layers.LeakyReLU(1e-2)) 17 | model.add(keras.layers.BatchNormalization(epsilon=1e-4, trainable=False)) 18 | model.add(keras.layers.MaxPool2D()) 19 | 20 | model.add(keras.layers.Flatten()) 21 | model.add(keras.layers.Dense(H, use_bias=False)) 22 | 23 | return model 24 | 25 | 26 | def cifar_lenet(H=128): 27 | model = keras.models.Sequential() 28 | 29 | model.add(keras.layers.Conv2D(32, (5, 5), strides=(3, 3), padding='same', use_bias=False, input_shape=(32, 32, 3))) 30 | model.add(keras.layers.LeakyReLU(1e-2)) 31 | model.add(keras.layers.BatchNormalization(epsilon=1e-4, trainable=False)) 32 | 33 | model.add(keras.layers.Conv2D(64, (5, 5), strides=(3, 3), padding='same', use_bias=False)) 34 | model.add(keras.layers.LeakyReLU(1e-2)) 35 | model.add(keras.layers.BatchNormalization(epsilon=1e-4, trainable=False)) 36 | 37 | model.add(keras.layers.Conv2D(128, (5, 5), strides=(3, 3), padding='same', use_bias=False)) 38 | model.add(keras.layers.LeakyReLU(1e-2)) 39 | model.add(keras.layers.BatchNormalization(epsilon=1e-4, trainable=False)) 40 | 41 | model.add(keras.layers.Flatten()) 42 | model.add(keras.layers.Dense(H, use_bias=False)) 43 | 44 | return model 45 | -------------------------------------------------------------------------------- /dsvdd/utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | 6 | __all__ = ['plot_most_normal_and_abnormal_images', 'task'] 7 | 8 | 9 | @contextmanager 10 | def task(_=''): 11 | yield 12 | 13 | 14 | def flatten_image_list(images, show_shape) -> np.ndarray: 15 | """ 16 | 17 | :param images: 18 | :param tuple show_shape: 19 | :return: 20 | """ 21 | N = np.prod(show_shape) 22 | 23 | if isinstance(images, list): 24 | images = np.array(images) 25 | 26 | for i in range(len(images.shape)): # find axis. 27 | if N == np.prod(images.shape[:i]): 28 | img_shape = images.shape[i:] 29 | new_shape = (N,) + img_shape 30 | return np.reshape(images, new_shape) 31 | 32 | else: 33 | raise ValueError('Cannot distinguish images. imgs shape: %s, show_shape: %s' % (images.shape, show_shape)) 34 | 35 | 36 | def get_shape(image): 37 | shape_ = image.shape[-3:] 38 | if len(shape_) <= 1: 39 | raise ValueError('Unexpected shape: {}'.format(shape_)) 40 | 41 | elif len(shape_) == 2: 42 | H, W = shape_ 43 | return H, W, 1 44 | 45 | elif len(shape_) == 3: 46 | H, W, C = shape_ 47 | if C in [1, 3]: 48 | return H, W, C 49 | else: 50 | raise ValueError('Unexpected shape: {}'.format(shape_)) 51 | 52 | else: 53 | raise ValueError('Unexpected shape: {}'.format(shape_)) 54 | 55 | 56 | def merge_image(images, show_shape, order='row'): 57 | images = flatten_image_list(images, show_shape) 58 | H, W, C = get_shape(images) 59 | I, J = show_shape 60 | result = np.zeros((I * H, J * W, C), dtype=images.dtype) 61 | 62 | for k, img in enumerate(images): 63 | if order.lower().startswith('row'): 64 | i = k // J 65 | j = k % J 66 | else: 67 | i = k % I 68 | j = k // I 69 | 70 | result[i * H: (i + 1) * H, j * W: (j + 1) * W] = img 71 | 72 | return result 73 | 74 | 75 | def plot_most_normal_and_abnormal_images(X_test, score): 76 | fig, axes = plt.subplots(nrows=2) 77 | fig.set_size_inches((5, 5)) 78 | inds = np.argsort(score) 79 | 80 | image1 = merge_image(X_test[inds[:10]], (2, 5)) 81 | axes[0].imshow(np.squeeze(image1)) 82 | axes[0].set_title('Most normal images') 83 | axes[0].set_axis_off() 84 | 85 | image2 = merge_image(X_test[inds[-10:]], (2, 5)) 86 | axes[1].imshow(np.squeeze(image2)) 87 | axes[1].set_title('Most abnormal images') 88 | axes[1].set_axis_off() 89 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from dsvdd import * 2 | import matplotlib.pyplot as plt 3 | import os 4 | import tensorflow as tf 5 | from sklearn.metrics import roc_auc_score 6 | 7 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '5' 8 | 9 | 10 | def main(cls=1): 11 | tf.reset_default_graph() 12 | from dsvdd.utils import plot_most_normal_and_abnormal_images 13 | # build model and DeepSVDD 14 | keras_model = mnist_lenet(32) 15 | svdd = DeepSVDD(keras_model, input_shape=(28, 28, 1), representation_dim=32, 16 | objective='soft-boundary') 17 | 18 | # get dataset 19 | X_train, X_test, y_test = get_mnist(cls) 20 | 21 | # train DeepSVDD 22 | svdd.fit(X_train, X_test, y_test, epochs=10, verbose=True) 23 | 24 | # test DeepSVDD 25 | score = svdd.predict(X_test) 26 | auc = roc_auc_score(y_test, -score) 27 | print('AUROC: %.3f' % auc) 28 | 29 | plot_most_normal_and_abnormal_images(X_test, score) 30 | plt.show() 31 | 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /main_cifar.py: -------------------------------------------------------------------------------- 1 | from dsvdd import * 2 | import matplotlib.pyplot as plt 3 | import os 4 | import tensorflow as tf 5 | from sklearn.metrics import roc_auc_score 6 | 7 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '5' 8 | 9 | 10 | def main(cls=1): 11 | tf.reset_default_graph() 12 | from dsvdd.utils import plot_most_normal_and_abnormal_images 13 | # build model and DeepSVDD 14 | keras_model = cifar_lenet(128) 15 | keras_model.summary() 16 | svdd = DeepSVDD(keras_model, input_shape=(32, 32, 3), representation_dim=128, 17 | objective='one-class') 18 | 19 | # get dataset 20 | X_train, X_test, y_test = get_cifar10(cls) 21 | 22 | # train DeepSVDD 23 | svdd.fit(X_train, X_test, y_test, epochs=10, verbose=True) 24 | 25 | # test DeepSVDD 26 | score = svdd.predict(X_test) 27 | auc = roc_auc_score(y_test, -score) 28 | print('AUROC: %.3f' % auc) 29 | 30 | plot_most_normal_and_abnormal_images(X_test, score) 31 | plt.show() 32 | 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-learn 3 | tensorflow-gpu>=1.12.0 4 | matplotlib 5 | tqdm 6 | -------------------------------------------------------------------------------- /results/mnist_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuclearboy95/Anomaly-Detection-Deep-SVDD-Tensorflow/7ca84c3ecaa52b40c0c4aac86de446773f82dcd2/results/mnist_0.png -------------------------------------------------------------------------------- /results/mnist_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuclearboy95/Anomaly-Detection-Deep-SVDD-Tensorflow/7ca84c3ecaa52b40c0c4aac86de446773f82dcd2/results/mnist_1.png -------------------------------------------------------------------------------- /results/mnist_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuclearboy95/Anomaly-Detection-Deep-SVDD-Tensorflow/7ca84c3ecaa52b40c0c4aac86de446773f82dcd2/results/mnist_2.png -------------------------------------------------------------------------------- /results/mnist_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuclearboy95/Anomaly-Detection-Deep-SVDD-Tensorflow/7ca84c3ecaa52b40c0c4aac86de446773f82dcd2/results/mnist_3.png -------------------------------------------------------------------------------- /results/mnist_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuclearboy95/Anomaly-Detection-Deep-SVDD-Tensorflow/7ca84c3ecaa52b40c0c4aac86de446773f82dcd2/results/mnist_4.png -------------------------------------------------------------------------------- /results/mnist_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuclearboy95/Anomaly-Detection-Deep-SVDD-Tensorflow/7ca84c3ecaa52b40c0c4aac86de446773f82dcd2/results/mnist_5.png -------------------------------------------------------------------------------- /results/mnist_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuclearboy95/Anomaly-Detection-Deep-SVDD-Tensorflow/7ca84c3ecaa52b40c0c4aac86de446773f82dcd2/results/mnist_6.png -------------------------------------------------------------------------------- /results/mnist_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuclearboy95/Anomaly-Detection-Deep-SVDD-Tensorflow/7ca84c3ecaa52b40c0c4aac86de446773f82dcd2/results/mnist_7.png -------------------------------------------------------------------------------- /results/mnist_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuclearboy95/Anomaly-Detection-Deep-SVDD-Tensorflow/7ca84c3ecaa52b40c0c4aac86de446773f82dcd2/results/mnist_8.png -------------------------------------------------------------------------------- /results/mnist_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nuclearboy95/Anomaly-Detection-Deep-SVDD-Tensorflow/7ca84c3ecaa52b40c0c4aac86de446773f82dcd2/results/mnist_9.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('requirements.txt', 'r') as f: 4 | install_reqs = [ 5 | s for s in [ 6 | line.strip(' \n') for line in f 7 | ] if not s.startswith('#') and s != '' 8 | ] 9 | 10 | setup(name='deep-svdd', 11 | version='1.3', 12 | url='https://github.com/nuclearboy95/Deep-SVDD-Tensorflow', 13 | license='MIT', 14 | author='Jihun Yi', 15 | author_email='t080205@gmail.com', 16 | description='Tensorflow implementation of Deep SVDD', 17 | packages=find_packages(exclude=['dist', 'build']), 18 | include_package_data=True, 19 | long_description=open('README.md').read(), 20 | zip_safe=False, 21 | setup_requires=['nose>=1.0'], 22 | install_requires=install_reqs, 23 | test_suite='nose.collector') 24 | --------------------------------------------------------------------------------