├── .gitignore
├── MANIFEST.in
├── README.md
├── main.py
├── ocgan
│   ├── __init__.py
│   ├── dataset.py
│   ├── models.py
│   ├── ocgan.py
│   └── utils.py
├── requirements.txt
├── setup.py
└── upload.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | ckpts
2 | ckpts/*
3 | tblog/*
4 | *.pyc
5 | *.idea
6 | ocgan.egg-info/*
7 | dist/*
8 | build/*
9 | .eggs/*
10 | dist
11 | build
12 | eggs
13 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include requirements.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow Implementation of OCGAN
2 | This repository provides a [TensorFlow](https://www.tensorflow.org/) implementation of *OCGAN*, presented in
3 | the CVPR 2019 paper "[OCGAN: One-class Novelty Detection Using GANs with Constrained Latent Representations](http://openaccess.thecvf.com/content_CVPR_2019/papers/Perera_OCGAN_One-Class_Novelty_Detection_Using_GANs_With_Constrained_Latent_Representations_CVPR_2019_paper.pdf)".
4 |
5 | The authors' official MXNet implementation of *OCGAN* is available [here](https://github.com/PramuPerera/OCGAN).
6 |
7 |
8 | ## Installation
9 | This code is written in `Python 3.5` and tested with `TensorFlow 1.13`.
10 |
11 | Install using pip or clone this repository.
12 |
13 | 1. Installation using pip:
14 | ```bash
15 | pip install ocgan
16 | ```
17 |
18 | and then import it (see the Usage section below):
19 |
20 | ```python
21 | from ocgan import OCGAN
22 | ```
23 |
24 | 2. Clone this repository:
25 |
26 | ```bash
27 | git clone https://github.com/nuclearboy95/Anomaly-Detection-OCGAN-tensorflow.git
28 | ```
29 |
30 | ## Result (AUROC)
31 | | **MNIST DIGIT** | **OCGAN w/ Informative-negative mining** | **OCGAN w/o Informative-negative mining** |
32 | |:---------------:|:--------------------------------------------------:|:---------------------------------------------------:|
33 | | 0 | **0.9952** | 0.9935 |
34 | | 1 | 0.9976 | **0.9985** |
35 | | 2 | **0.9268** | 0.9133 |
36 | | 3 | **0.9410** | 0.9208 |
37 | | 4 | **0.9636** | 0.9600 |
38 | | 5 | **0.9613** | 0.9145 |
39 | | 6 | **0.9910** | 0.9835 |
40 | | 7 | **0.9658** | 0.9526 |
41 | | 8 | **0.9009** | 0.8758 |
42 | | 9 | 0.9584 | **0.9701** |
43 |
44 | NOTE: *The AUROC values are measured only once for each digit.*
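45 | 
46 | ## Usage
47 | A minimal training sketch, mirroring `main.py` (the hyperparameters below are the values used there; the checkpoint path is illustrative):
48 | 
49 | ```python
50 | import os
51 | import tensorflow as tf
52 | from ocgan import OCGAN
53 | from ocgan.dataset import get_mnist
54 | from ocgan.utils import attrdict
55 | 
56 | # one-class training data: only images of the chosen digit
57 | x_train, x_test, y_test = get_mnist(cls=1)
58 | 
59 | # per-component update counts and learning rates (values from main.py)
60 | steps = attrdict({'C': 5, 'AE': 5, 'Dl': 3, 'Dv': 2, 'IM': 5})
61 | lr = attrdict({'C': 1e-4, 'recon': 3e-4, 'Dl': 1e-5, 'Dv': 1e-5, 'AE': 3e-5})
62 | 
63 | os.makedirs('./ckpts/demo', exist_ok=True)  # checkpoint directory must exist
64 | 
65 | with tf.Session() as sess:
66 |     ocgan = OCGAN(sess, steps, lr, BS=128, use_informative_mining=True)
67 |     sess.run(tf.global_variables_initializer())
68 |     ocgan.fit(x_train, x_test, y_test, epochs=50, ckpt_path='./ckpts/demo/ocgan')
69 | 
70 |     # anomaly scores (higher = more anomalous) and AUROC on the test set
71 |     n = (len(x_test) // 128) * 128  # predict() drops residual batches
72 |     scores = ocgan.predict(x_test[:n])
73 |     auc = ocgan.evaluate(x_test[:n], y_test[:n])
74 | ```
75 | 
76 | During `fit`, the autoencoder is pretrained for the first 20 epochs before adversarial training starts, and a checkpoint is saved after every epoch.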
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from ocgan.utils import attrdict
3 | from ocgan import OCGAN
4 | from ocgan.dataset import get_mnist
5 |
6 |
7 | def train(use_informative_mining=True, cls=1):
8 | x_train, x_test, y_test = get_mnist(cls)
9 | BS = 128  # batch size
10 |
11 | steps = attrdict({  # inner-loop updates per training step for each component (IM: mining steps)
12 | 'C': 5,
13 | 'AE': 5,
14 | 'Dl': 3,
15 | 'Dv': 2,
16 | 'IM': 5
17 | })
18 | lr = attrdict({  # learning rates ('recon' is used for autoencoder pretraining)
19 | 'C': 1e-4,
20 | 'recon': 3e-4,
21 | 'Dl': 1e-5,
22 | 'Dv': 1e-5,
23 | 'AE': 3e-5,
24 | })
25 |
26 | config = tf.ConfigProto()
27 | config.gpu_options.allow_growth = True
28 | with tf.Session(config=config) as sess:
29 | ocgan = OCGAN(sess, steps, lr, BS, use_informative_mining=use_informative_mining)
30 | if use_informative_mining:
31 | tb_name = 'OCGAN(%s)' % cls
32 | else:
33 | tb_name = 'OCGAN_NOIM(%s)' % cls
34 |
35 | sess.run(tf.global_variables_initializer())
36 | ckpt_path = './ckpts/%s/%s' % (tb_name, tb_name)  # the './ckpts/<tb_name>' directory should exist before training (Saver.save does not create it)
37 | ocgan.fit(x_train, x_test, y_test, epochs=50, ckpt_path=ckpt_path)
38 |
39 |
40 | def main():
41 | for cls in range(10):  # train one one-class model per MNIST digit
42 | tf.reset_default_graph()
43 | train(True, cls)
44 |
45 |
46 | if __name__ == '__main__':
47 | main()
48 |
--------------------------------------------------------------------------------
/ocgan/__init__.py:
--------------------------------------------------------------------------------
1 | from .ocgan import OCGAN
2 |
--------------------------------------------------------------------------------
/ocgan/dataset.py:
--------------------------------------------------------------------------------
1 | from tensorflow import keras
2 | import numpy as np
3 |
4 |
5 | def get_mnist(cls=1):
6 | d_train, d_test = keras.datasets.mnist.load_data()
7 | x_train, y_train = d_train
8 | x_test, y_test = d_test
9 |
10 | mask = y_train == cls  # keep only the in-class digit for one-class training
11 |
12 | x_train = x_train[mask]
13 | x_train = np.expand_dims(x_train / 255., axis=-1).astype(np.float32)
14 | x_test = np.expand_dims(x_test / 255., axis=-1).astype(np.float32)
15 |
16 | y_test = (y_test == cls).astype(np.float32)
17 | return x_train, x_test, y_test # y_test: normal -> 1 / abnormal -> 0
18 |
--------------------------------------------------------------------------------
/ocgan/models.py:
--------------------------------------------------------------------------------
1 | from tensorflow import keras
2 | import tensorflow as tf
3 |
4 |
5 | def get_encoder(input_shape=(28, 28, 1)):
6 | model = keras.Sequential([
7 | keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.tanh, input_shape=input_shape),
8 | keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.tanh),
9 | keras.layers.BatchNormalization(),
10 | keras.layers.MaxPool2D(),
11 |
12 | keras.layers.Conv2D(64, 3, padding='same', activation=tf.nn.tanh),
13 | keras.layers.Conv2D(64, 3, padding='same', activation=tf.nn.tanh),
14 | keras.layers.BatchNormalization(),
15 | keras.layers.MaxPool2D(),
16 |
17 | keras.layers.Conv2D(32, 3, padding='valid', activation=tf.nn.tanh),
18 | keras.layers.Conv2D(32, 3, padding='valid', activation=tf.nn.tanh),
19 | keras.layers.Flatten()  # 3 x 3 x 32 = 288-d latent code
20 | ], name='encoder')
21 | return model
22 |
23 |
24 | def get_decoder(input_shape=(288,)):
25 | model = keras.Sequential([
26 | keras.layers.Reshape((3, 3, 32), input_shape=input_shape),  # 288-d latent -> 3 x 3 x 32 feature map
27 | keras.layers.UpSampling2D(),
28 | keras.layers.Conv2DTranspose(32, 3, padding='same', activation=tf.nn.tanh),
29 | keras.layers.Conv2DTranspose(32, 3, padding='same', activation=tf.nn.tanh),
30 | keras.layers.BatchNormalization(),
31 |
32 | keras.layers.UpSampling2D(),
33 | keras.layers.Conv2DTranspose(64, 3, padding='same', activation=tf.nn.tanh),
34 | keras.layers.Conv2DTranspose(64, 3, padding='same', activation=tf.nn.tanh),
35 | keras.layers.BatchNormalization(),
36 |
37 | keras.layers.UpSampling2D(),
38 | keras.layers.Conv2DTranspose(64, 3, padding='valid', activation=tf.nn.tanh),
39 | keras.layers.Conv2DTranspose(64, 3, padding='valid', activation=tf.nn.tanh),
40 | keras.layers.BatchNormalization(),
41 |
42 | keras.layers.Conv2DTranspose(1, 3, padding='same', activation='sigmoid')
43 | ], name='decoder')
44 | return model
45 |
46 |
47 | def get_disc_latent(input_shape=(288,)):
48 | model = keras.Sequential([
49 | keras.layers.Dense(128, input_shape=input_shape),
50 | keras.layers.BatchNormalization(),
51 | keras.layers.ReLU(),
52 |
53 | keras.layers.Dense(64),
54 | keras.layers.BatchNormalization(),
55 | keras.layers.ReLU(),
56 |
57 | keras.layers.Dense(32),
58 | keras.layers.BatchNormalization(),
59 | keras.layers.ReLU(),
60 |
61 | keras.layers.Dense(16),
62 | keras.layers.BatchNormalization(),
63 | keras.layers.ReLU(),
64 |
65 | keras.layers.Dense(1)
66 | ], name='discriminator_latent')
67 | return model
68 |
69 |
70 | def get_disc_visual(input_shape=(28, 28, 1)):
71 | model = keras.Sequential([
72 | keras.layers.Conv2D(16, 5, (2, 2), padding='same', input_shape=input_shape),
73 | keras.layers.BatchNormalization(),
74 | keras.layers.LeakyReLU(0.2),
75 |
76 | keras.layers.Conv2D(16, 5, (2, 2), padding='same'),
77 | keras.layers.BatchNormalization(),
78 | keras.layers.LeakyReLU(0.2),
79 |
80 | keras.layers.Conv2D(16, 5, (2, 2), padding='same'),
81 | keras.layers.BatchNormalization(),
82 | keras.layers.LeakyReLU(0.2),
83 |
84 | keras.layers.Conv2D(1, 5, (2, 2), padding='same'),
85 | keras.layers.GlobalAveragePooling2D()
86 | ], name='discriminator_visual')
87 | return model
88 |
89 |
90 | def get_classifier(input_shape=(28, 28, 1)):
91 | model = keras.Sequential([
92 | keras.layers.Conv2D(32, 5, (2, 2), padding='same', input_shape=input_shape),
93 | keras.layers.BatchNormalization(),
94 | keras.layers.LeakyReLU(0.2),
95 |
96 | keras.layers.Conv2D(64, 5, (2, 2), padding='same'),
97 | keras.layers.BatchNormalization(),
98 | keras.layers.LeakyReLU(0.2),
99 |
100 | keras.layers.Conv2D(64, 5, (2, 2), padding='same'),
101 | keras.layers.BatchNormalization(),
102 | keras.layers.LeakyReLU(0.2),
103 |
104 | keras.layers.Conv2D(1, 5, (2, 2), padding='same'),
105 | keras.layers.GlobalAveragePooling2D()
106 | ], name='classifier')
107 | return model
108 |
109 |
--------------------------------------------------------------------------------
/ocgan/ocgan.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import roc_auc_score
3 | from .models import *
4 | from .utils import task, attrdict, d_of_l, assure_dtype_uint8, merge, gray2rgb, add_border
5 |
6 |
7 | class OCGAN:
8 | def __init__(self, sess, steps, lr, BS, use_informative_mining=True):
9 | self.sess = sess
10 | latent_shape = [288]
11 |
12 | self.enc = get_encoder()
13 | self.dec = get_decoder()
14 | self.disc_v = get_disc_visual()
15 | self.disc_l = get_disc_latent(latent_shape)
16 | self.cl = get_classifier()
17 | self.BS = BS
18 |
19 | self.steps = steps
20 | self.use_informative_mining = use_informative_mining
21 |
22 | with task('Build Graph'):
23 | X = tf.placeholder(tf.float32, [BS, 28, 28, 1])
24 | z = tf.random.normal(tf.shape(X), stddev=1e-5)  # small Gaussian perturbation added to the input
25 |
26 | l2 = tf.random.uniform([BS] + latent_shape, minval=-1, maxval=1)  # random latent codes ~ U(-1, 1)
27 | Xz = X + z
28 | l1 = self.enc(Xz)  # latent codes of the (perturbed) in-class images
29 | self.recon = self.dec(self.enc(X))
30 | dec_l2 = self.dec(l2)
31 | self.gen = dec_l2  # images generated from random latent codes
32 |
33 | with task('Loss'):
34 | loss_op = attrdict()
35 | with task('1. Classifier loss'):
36 | logits_C_l1 = self.cl(self.dec(l1))
37 | logits_C_l2 = self.cl(dec_l2)
38 |
39 | loss_op.C_l1 = tf.losses.sigmoid_cross_entropy(tf.ones_like(logits_C_l1), logits=logits_C_l1)
40 | loss_op.C_l2 = tf.losses.sigmoid_cross_entropy(tf.zeros_like(logits_C_l2), logits=logits_C_l2)
41 | loss_op.C = (loss_op.C_l1 + loss_op.C_l2) / 2
42 |
43 | with task('2. Discriminator latent loss'):
44 | logits_Dl_l1 = self.disc_l(l1)
45 | logits_Dl_l2 = self.disc_l(l2)
46 |
47 | loss_op.Dl_l1 = tf.losses.sigmoid_cross_entropy(tf.zeros_like(logits_Dl_l1), logits=logits_Dl_l1)
48 | loss_op.Dl_l2 = tf.losses.sigmoid_cross_entropy(tf.ones_like(logits_Dl_l2), logits=logits_Dl_l2)
49 | loss_op.Dl = (loss_op.Dl_l1 + loss_op.Dl_l2) / 2
50 |
51 | with task('3. Discriminator visual loss'):
52 | logits_Dv_X = self.disc_v(X)
53 | logits_Dv_l2 = self.disc_v(self.dec(l2))
54 |
55 | loss_op.Dv_X = tf.losses.sigmoid_cross_entropy(tf.ones_like(logits_Dv_X), logits=logits_Dv_X)
56 | loss_op.Dv_l2 = tf.losses.sigmoid_cross_entropy(tf.zeros_like(logits_Dv_l2), logits=logits_Dv_l2)
57 | loss_op.Dv = (loss_op.Dv_X + loss_op.Dv_l2) / 2
58 |
59 | with task('4. Informative-negative mining'):
60 | l2_mine = tf.get_variable('l2_mine', [BS] + latent_shape, tf.float32)  # latent codes optimized during mining
61 | logits_C_l2_mine = self.cl(self.dec(l2_mine))
62 | loss_C_l2_mine = tf.losses.sigmoid_cross_entropy(tf.zeros_like(logits_C_l2_mine), logits=logits_C_l2_mine)  # minimized w.r.t. l2_mine: seeks latents whose decodings the classifier rejects
63 | opt = tf.train.GradientDescentOptimizer(1)  # gradient descent in latent space, step size 1
64 |
65 | def cond(i):
66 | return i < self.steps.IM
67 |
68 | def body(i):
69 | descent_op = opt.minimize(loss_C_l2_mine, var_list=[l2_mine])
70 | with tf.control_dependencies([descent_op]):
71 | return i + 1
72 |
73 | self.l2_mine_descent = tf.while_loop(cond, body, [tf.constant(0)])
74 |
75 | with task('5. AE loss'):
76 | Xh = self.dec(l1)
77 | loss_AE_recon = tf.reduce_mean(tf.square(X - Xh), axis=[1, 2, 3])
78 | loss_op.AE_recon = tf.reduce_mean(loss_AE_recon)
79 |
80 | loss_op.AE_l = tf.losses.sigmoid_cross_entropy(tf.ones_like(logits_Dl_l1), logits=logits_Dl_l1)
81 |
82 | logits_Dv_l2_mine = self.disc_v(self.dec(l2_mine))
83 | loss_op.AE_v = tf.losses.sigmoid_cross_entropy(tf.ones_like(logits_Dv_l2_mine), logits=logits_Dv_l2_mine)
84 |
85 | self.lamb = tf.placeholder_with_default(10., [])
86 | loss_op.AE = loss_op.AE_l + loss_op.AE_v + self.lamb * loss_op.AE_recon
87 |
88 | with task('Optimize'):
89 | Opt = tf.train.AdamOptimizer
90 | train_op = attrdict()
91 | ae_vars = self.enc.trainable_variables + self.dec.trainable_variables
92 | train_op.C = Opt(lr.C).minimize(loss_op.C, var_list=self.cl.trainable_variables)
93 | train_op.Dl = Opt(lr.Dl).minimize(loss_op.Dl, var_list=self.disc_l.trainable_variables)
94 | train_op.Dv = Opt(lr.Dv).minimize(loss_op.Dv, var_list=self.disc_v.trainable_variables)
95 | train_op.AE = Opt(lr.AE).minimize(loss_op.AE, var_list=ae_vars)
96 | train_op.recon = Opt(lr.recon).minimize(loss_op.AE_recon, var_list=ae_vars)
97 |
98 | with task('Placeholders'):
99 | self.loss_op = loss_op
100 | self.train_op = train_op
101 | self.X = X
102 | self.Xz = Xz
103 | self.l2_mine = l2_mine
104 | self.l2 = l2
105 | self.anomaly_score = tf.reduce_mean(tf.square(X - self.recon), axis=[1, 2, 3])
106 |
107 | def pretrain_AE(self, X):
108 | sess = self.sess
109 | feed_dict = {self.X: X}
110 | loss = attrdict()
111 | loss.AE_recon, _ = sess.run([self.loss_op.AE_recon, self.train_op.recon], feed_dict)
112 | return loss.as_dict()
113 |
114 | def train_step(self, X):
115 | sess = self.sess
116 | feed_dict = {self.X: X}
117 | loss = attrdict()
118 |
119 | with task('1. Train classifier'):
120 | for _ in range(self.steps.C):
121 | loss_C, _ = sess.run([self.loss_op.filt_keys('C'), self.train_op.C], feed_dict)
122 | loss.update(loss_C)
123 |
124 | with task('2. Train discriminators'):
125 | for _ in range(self.steps.Dl):
126 | loss_Dl, _ = sess.run([self.loss_op.filt_keys('Dl'), self.train_op.Dl], feed_dict)
127 | loss.update(loss_Dl)
128 |
129 | for _ in range(self.steps.Dv):
130 | loss_Dv, _ = sess.run([self.loss_op.filt_keys('Dv'), self.train_op.Dv], feed_dict)
131 | loss.update(loss_Dv)
132 |
133 | if self.use_informative_mining:
134 | self.informative_negative_mining()
135 |
136 | with task('4. Train AutoEncoder'):
137 | for _ in range(self.steps.AE):
138 | loss_AE, _ = sess.run([self.loss_op.filt_keys('AE'), self.train_op.AE], feed_dict)
139 | loss.update(loss_AE)
140 |
141 | return loss.as_dict()
142 |
143 | def informative_negative_mining(self):
144 | self.sess.run(tf.assign(self.l2_mine, self.l2))  # restart mining from fresh random latent codes
145 | self.sess.run(self.l2_mine_descent)
146 |
147 | def fit(self, x_train, x_test, y_test, epochs=100, ckpt_path='ocgan'):
148 | BS = self.BS
149 | saver = tf.train.Saver()
150 | N = x_train.shape[0]
151 | sess = self.sess
152 |
153 | for i_epoch in range(epochs):
154 | keras.backend.set_learning_phase(True)
155 | results = d_of_l()
156 | X = x_train[np.random.permutation(N)]
157 | BN = N // BS # residual batches are dropped!
158 | for i_batch in range(BN):
159 | x_batch = X[i_batch * BS: (i_batch + 1) * BS]
160 |
161 | if i_epoch < 20:
162 | result = self.pretrain_AE(x_batch)
163 | else:
164 | result = self.train_step(x_batch)
165 |
166 | results.appends(result)
167 |
168 | else:  # for-else: runs once per epoch, after the batch loop completes
169 | with task('Eval test performance'):
170 | N_test = BS * (len(x_test) // BS)  # truncate to a multiple of BS since predict() drops residual batches
171 | auc = self.evaluate(x_test[:N_test], y_test[:N_test])
172 | print('Epoch %d AUROC: %.4f' % (i_epoch, auc))
173 | saver.save(sess, ckpt_path)
174 |
175 | with task('Save Images'):
176 | keras.backend.set_learning_phase(False)
177 | gens = self.generate()[:BS]
178 | gens = assure_dtype_uint8(gens)
179 |
180 | is_normal = y_test[:BS] == 1.
181 |
182 | origin = x_test[:BS]
183 | recons = self.reconstruct(origin)
184 | recons = gray2rgb(assure_dtype_uint8(recons))
185 | recons[is_normal] = add_border(recons[is_normal])
186 |
187 | origin = gray2rgb(assure_dtype_uint8(origin))[:BS]
188 | origin[is_normal] = add_border(origin[is_normal])
189 |
190 | d = {
191 | 'example/generated': merge(gens[:64], (8, 8)),
192 | 'example/original': merge(origin[:64], (8, 8)),
193 | 'example/recon': merge(recons[:64], (8, 8))
194 | }  # example image grids (generated / original / reconstruction); built for visualization but not logged here
195 |
196 | # feedforward
197 |
198 | def generate(self):
199 | return self.sess.run(self.gen)
200 |
201 | def reconstruct(self, X):
202 | recon = self.sess.run(self.recon, feed_dict={self.X: X})
203 | return recon
204 |
205 | def predict(self, X):
206 | anomaly_scores = list()
207 | BS = self.BS
208 | N = X.shape[0]
209 | BN = N // BS # residual batches are dropped!
210 | for i_batch in range(BN):
211 | x_batch = X[i_batch * BS: (i_batch + 1) * BS]
212 | anomaly_score = self.sess.run(self.anomaly_score, feed_dict={self.X: x_batch})
213 | anomaly_scores.append(anomaly_score)
214 |
215 | return np.concatenate(anomaly_scores)
216 |
217 | def evaluate(self, X, y):
218 | anomaly_score = self.predict(X)
219 | auc = roc_auc_score(y, -anomaly_score)  # y: 1 = normal, so negate the anomaly score
220 | return auc
221 |
--------------------------------------------------------------------------------
/ocgan/utils.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from collections import defaultdict
3 | import numpy as np
4 |
5 |
6 | __all__ = ['task', 'attrdict', 'd_of_l', 'assure_dtype_uint8',
7 | 'merge', 'gray2rgb', 'add_border']
8 |
9 |
10 | @contextmanager
11 | def task(_=''):
12 | yield
13 |
14 |
15 | class attrdict(dict):
16 | def __init__(self, *args, **kwargs):
17 | super().__init__(*args, **kwargs)
18 |
19 | __getattr__ = dict.__getitem__
20 | __setattr__ = dict.__setitem__
21 |
22 | def as_dict(self):
23 | return dict(self)
24 |
25 | def filt_keys(self, prefix=''):
26 | d = {k: v for k, v in self.items() if k.startswith(prefix)}
27 | return self.__class__(d)
28 |
29 |
30 | class d_of_l(defaultdict):
31 | __getattr__ = dict.__getitem__
32 |
33 | def __init__(self, *args, **kwargs):
34 | super().__init__(list, *args, **kwargs)
35 |
36 | def as_dict(self):
37 | return dict(self)
38 |
39 | def appends(self, d):
40 | for key, value in d.items():
41 | self[key].append(value)
42 |
43 |
44 | def assure_dtype_uint8(image):
45 | def raise_unknown_float_image():
46 | raise ValueError('Unknown float image range. Min: {}, Max: {}'.format(min_v, max_v))
47 |
48 | def raise_unknown_image_dtype():
49 | raise ValueError('Unknown image dtype: {}'.format(image.dtype))
50 |
51 | max_v = image.max()
52 | min_v = image.min()
53 | if image.dtype in [np.float32, np.float64]:
54 | if 0 <= max_v <= 1:
55 | if 0 <= min_v <= 1: # [0, 1)
56 | min_v, max_v = 0, 1
57 |
58 | elif -1 <= min_v <= 0: # Presumably [-1, 1)
59 | min_v, max_v = -1, 1
60 |
61 | else:
62 | raise_unknown_float_image()
63 |
64 | elif 0 <= max_v <= 255:
65 | if 0 <= min_v <= 255: # Presumably [0, 255)
66 | min_v, max_v = 0, 255
67 |
68 | elif -256 <= min_v <= 0: # Presumably [-256, 255)
69 | min_v, max_v = -256, 255
70 |
71 | else:
72 | raise_unknown_float_image()
73 |
74 | else:
75 | raise_unknown_float_image()
76 |
77 | return rescale(image,
78 | min_from=min_v, max_from=max_v,
79 | min_to=0, max_to=255,
80 | dtype='uint8')
81 |
82 | elif image.dtype in [np.uint8]:
83 | return image
84 |
85 | else:
86 | raise_unknown_image_dtype()
87 |
88 |
89 | def rescale(img, min_from=-1, max_from=1, min_to=0, max_to=255, dtype='float32'):
90 | len_from = max_from - min_from
91 | len_to = max_to - min_to
92 | img = (img.astype(np.float32) - min_from) * len_to / len_from + min_to
93 | return img.astype(dtype)
94 |
95 |
96 | def flatten_image_list(images, show_shape) -> np.ndarray:
97 | """
98 | Flatten a list/array of images into a batch of shape (rows * cols,) + image_shape.
99 | :param images: list or np.ndarray of images.
100 | :param tuple show_shape: grid shape (rows, cols); rows * cols must match the number of images.
101 | :return: np.ndarray of shape (rows * cols,) + image_shape.
102 | """
103 | N = np.prod(show_shape)
104 |
105 | if isinstance(images, list):
106 | images = np.array(images)
107 |
108 | for i in range(len(images.shape)): # find axis.
109 | if N == np.prod(images.shape[:i]):
110 | img_shape = images.shape[i:]
111 | new_shape = (N,) + img_shape
112 | return np.reshape(images, new_shape)
113 |
114 | else:
115 | raise ValueError('Cannot distinguish images. imgs shape: %s, show_shape: %s' % (images.shape, show_shape))
116 |
117 |
118 | def merge(images, show_shape, order='row') -> np.ndarray:
119 | """
120 | Tile a batch of images into a single grid image.
121 | :param np.ndarray images: images of shape (rows * cols, H, W, C).
122 | :param tuple show_shape: grid shape (rows, cols).
123 | :param str order: 'row' to fill the grid row by row, otherwise column by column.
124 |
125 | :return: np.ndarray grid image of shape (rows * H, cols * W, C).
126 | """
127 | images = flatten_image_list(images, show_shape)
128 | H, W, C = images.shape[-3:]
129 | I, J = show_shape
130 | result = np.zeros((I * H, J * W, C), dtype=images.dtype)
131 |
132 | for k, image in enumerate(images):
133 | if order.lower().startswith('row'):
134 | i = k // J
135 | j = k % J
136 | else:
137 | i = k % I
138 | j = k // I
139 |
140 | target_shape = result[i * H: (i + 1) * H, j * W: (j + 1) * W].shape
141 | result[i * H: (i + 1) * H, j * W: (j + 1) * W] = image.reshape(target_shape)
142 |
143 | return result
144 |
145 |
146 | def gray2rgb(images):
147 | H, W, C = images.shape[-3:]
148 | if C != 1:
149 | return images
150 |
151 | if images.shape[-1] != C:
152 | images = np.expand_dims(images, axis=-1)
153 |
154 | tile_shape = np.ones(len(images.shape), dtype=int)
155 | tile_shape[-1] = 3
156 | images = np.tile(images, tile_shape)
157 | return images
158 |
159 |
160 | def add_border(images, color=(0, 255, 0), border=0.07):
161 | H, W, C = images.shape[-3:]
162 |
163 | if isinstance(border, float): # if fraction
164 | border = int(round(min(H, W) * border))
165 |
166 | T = border
167 | images = images.copy()
168 | images = assure_dtype_uint8(images)
169 | images[:, :T, :] = color
170 | images[:, -T:, :] = color
171 | images[:, :, :T] = color
172 | images[:, :, -T:] = color
173 |
174 | return images
175 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scikit-learn
3 | tensorflow-gpu>=1.12.0
4 | matplotlib
5 | tqdm
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open('requirements.txt', 'r') as f:
4 | install_reqs = [
5 | s for s in [
6 | line.strip(' \n') for line in f
7 | ] if not s.startswith('#') and s != ''
8 | ]
9 |
10 | setup(name='ocgan',
11 | version='1.0',
12 | url='https://github.com/nuclearboy95/Anomaly-Detection-OCGAN-tensorflow',
13 | license='MIT',
14 | author='Jihun Yi',
15 | author_email='t080205@gmail.com',
16 | description='Tensorflow implementation of OCGAN',
17 | packages=find_packages(exclude=['dist', 'build']),
18 | include_package_data=True,
19 | long_description=open('README.md').read(),
20 | long_description_content_type='text/markdown',
21 | zip_safe=False,
22 | setup_requires=['nose>=1.0'],
23 | install_requires=install_reqs,
24 | test_suite='nose.collector')
25 |
--------------------------------------------------------------------------------
/upload.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | rm -rf build/*
4 | rm -rf dist/*
5 | python setup.py bdist_wheel
6 | twine upload dist/ocgan-*.whl
--------------------------------------------------------------------------------