├── ODIR
│   ├── training annotation (English).xlsx
│   ├── off-site test annotation (English).xlsx
│   ├── on-site test annotation (English).xlsx
│   └── ODIR dataset building steps.ipynb
├── README.md
├── util.py
├── input_ops.py
├── sngan_gen.py
├── sngan_gen_y.py
├── sngan_joint.py
├── diffAugment.py
├── config.py
├── BigGAN
│   ├── resnet_ops.py
│   ├── resnet_gen_y.py
│   ├── resnet_joint.py
│   ├── resnet_gen_y_deep.py
│   └── resnet_joint_deep.py
├── ops.py
├── trainer_joint.py
├── model_joint.py
└── arch_ops.py
/ODIR/training annotation (English).xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xyporz/CISSL-GANs/HEAD/ODIR/training annotation (English).xlsx
--------------------------------------------------------------------------------
/ODIR/off-site test annotation (English).xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xyporz/CISSL-GANs/HEAD/ODIR/off-site test annotation (English).xlsx
--------------------------------------------------------------------------------
/ODIR/on-site test annotation (English).xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xyporz/CISSL-GANs/HEAD/ODIR/on-site test annotation (English).xlsx
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CISSL-GANs
2 |
3 | ### About
4 |
5 | 📫 My email: xieyingpeng2017@email.szu.edu.cn
6 |
7 | Official TensorFlow implementation of the paper:
8 | #### Fundus Image-label Pairs Synthesis and Retinopathy Screening via GANs with Class-imbalanced Semi-supervised Learning
9 | IEEE Transactions on Medical Imaging (IEEE-TMI).
10 |
11 | This code is primarily based on [SSGAN-Tensorflow](https://github.com/clvrai/SSGAN-Tensorflow) and [Compare GAN](https://github.com/google/compare_gan).
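### Quick orientation

A minimal sketch of how the pieces fit together, based on the function signatures in `config.py`. The training loop itself lives in `trainer_joint.py` (whose contents are not shown in this listing), so treat this wiring as an assumption rather than the official entry point:

```python
# Hypothetical wiring sketch; get_params() and argparser() are real functions
# in config.py, but calling them this way is an illustrative assumption.
from config import get_params, argparser

config, args = get_params()  # parse CLI flags (batch sizes, learning rates, ...)
# argparser() loads the HDF5 datasets and builds the model_joint.Model.
config, model, d_train, d_unlabel, d_val, d_test = argparser(vars(config))
```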
12 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | """ Utilities """ 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | 8 | # Logging 9 | # ======= 10 | 11 | import logging 12 | import os, os.path 13 | from colorlog import ColoredFormatter 14 | 15 | ch = logging.StreamHandler() 16 | ch.setLevel(logging.DEBUG) 17 | 18 | formatter = ColoredFormatter( 19 | "%(log_color)s[%(asctime)s] %(message)s", 20 | # datefmt='%H:%M:%S.%f', 21 | datefmt=None, 22 | reset=True, 23 | log_colors={ 24 | 'DEBUG': 'cyan', 25 | 'INFO': 'white,bold', 26 | 'INFOV': 'cyan,bold', 27 | 'WARNING': 'yellow', 28 | 'ERROR': 'red,bold', 29 | 'CRITICAL': 'red,bg_white', 30 | }, 31 | secondary_log_colors={}, 32 | style='%' 33 | ) 34 | 35 | ch.setFormatter(formatter) 36 | 37 | log = logging.getLogger('GBGANs') 38 | log.setLevel(logging.DEBUG) 39 | log.handlers = [] # No duplicated handlers 40 | log.propagate = False # workaround for duplicated logs in ipython 41 | log.addHandler(ch) 42 | 43 | logging.addLevelName(logging.INFO + 1, 'INFOV') 44 | def _infov(self, msg, *args, **kwargs): 45 | self.log(logging.INFO + 1, msg, *args, **kwargs) 46 | 47 | logging.Logger.infov = _infov 48 | -------------------------------------------------------------------------------- /input_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from util import log 4 | 5 | def check_data_id(dataset, data_id): 6 | if not data_id: 7 | return 8 | 9 | wrong = [] 10 | for id in data_id: 11 | if id in dataset.data: 12 | pass 13 | else: 14 | wrong.append(id) 15 | 16 | if len(wrong) > 0: 17 | raise RuntimeError("There are %d invalid ids, including %s" % ( 18 | len(wrong), wrong[:5] 19 | )) 20 | 21 | 22 | def create_input_ops(dataset, 23 | batch_size, 24 | num_threads=16, # for creating batches 25 | data_id=None, 26 | scope='inputs', 27 | shuffle=True, 28 | ): 29 | ''' 30 | Return a batched tensor for the inputs from the dataset. 
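Built on TF1 queue runners (tf.train.string_input_producer / tf.train.shuffle_batch); callers must start the queue runners, e.g. with tf.train.start_queue_runners(), before fetching batches.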
31 | ''' 32 | input_ops = {} 33 | 34 | if data_id is None: 35 | data_id = dataset.ids 36 | log.info("input_ops [%s]: Using %d IDs from dataset", scope, len(data_id)) 37 | else: 38 | log.info("input_ops [%s]: Using specified %d IDs", scope, len(data_id)) 39 | 40 | # single operations 41 | with tf.device("/cpu:0"), tf.name_scope(scope): 42 | input_ops['id'] = tf.train.string_input_producer( 43 | tf.convert_to_tensor(data_id), 44 | capacity=128 45 | ).dequeue(name='input_ids_dequeue') 46 | 47 | m, label = dataset.get_data(data_id[0]) 48 | 49 | def load_fn(id): 50 | image, label = dataset.get_data(id) 51 | return (id, 52 | image.astype(np.float32), 53 | label.astype(np.float32)) 54 | 55 | input_ops['id'], input_ops['image'], input_ops['label'] = tf.py_func( 56 | load_fn, inp=[input_ops['id']], 57 | Tout=[tf.string, tf.float32, tf.float32], 58 | name='func_hp' 59 | ) 60 | input_ops['id'].set_shape([]) 61 | input_ops['image'].set_shape(list(m.shape)) 62 | input_ops['label'].set_shape(list(label.shape)) 63 | 64 | # batchify 65 | capacity = 2 * batch_size * num_threads 66 | min_capacity = min(int(capacity * 0.75), 1024) 67 | 68 | if shuffle: 69 | batch_ops = tf.train.shuffle_batch( 70 | input_ops, 71 | batch_size=batch_size, 72 | num_threads=num_threads, 73 | capacity=capacity, 74 | min_after_dequeue=min_capacity, 75 | ) 76 | else: 77 | batch_ops = tf.train.batch( 78 | input_ops, 79 | batch_size=batch_size, 80 | num_threads=num_threads, 81 | capacity=capacity 82 | ) 83 | 84 | return batch_ops 85 | -------------------------------------------------------------------------------- /sngan_gen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from ops import linear 4 | from ops import deconv2d 5 | from util import log 6 | from ops import non_local_block 7 | import time 8 | 9 | def conv_out_size_same(size, stride): 10 | return int(np.ceil(float(size) / float(stride))) 11 | 12 | class Generator(object): 13 | def __init__(self, name, h, w, c, is_train, use_sn): 14 | 15 | self.name = name 16 | self.s_h, self.s_w, self.colors = [h,w,c] 17 | self.s_h2, self.s_w2 = conv_out_size_same(self.s_h, 2), conv_out_size_same(self.s_w, 2) 18 | self.s_h4, self.s_w4 = conv_out_size_same(self.s_h2, 2), conv_out_size_same(self.s_w2, 2) 19 | self.s_h8, self.s_w8 = conv_out_size_same(self.s_h4, 2), conv_out_size_same(self.s_w4, 2) 20 | self._is_train = is_train 21 | self.use_sn = use_sn 22 | 23 | def __call__(self, z): 24 | 25 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 26 | batch_size = z.shape[0] 27 | 28 | net = linear(z, self.s_h8 * self.s_w8 * 512, scope="g_fc1", use_sn=self.use_sn) 29 | net = tf.reshape(net, [batch_size, self.s_h8, self.s_w8, 512]) 30 | net = tf.contrib.layers.batch_norm(net, 31 | updates_collections=None, is_training=self._is_train, center=True, scale=True, decay=0.9, epsilon=1e-5, scope="g_bn_deconv0") 32 | net = tf.nn.relu(net) 33 | 34 | net = deconv2d(net, [batch_size, self.s_h4, self.s_w4, 256], 4, 4, 2, 2, name="g_dc1", use_sn=self.use_sn) 35 | net = tf.contrib.layers.batch_norm(net, 36 | updates_collections=None, is_training=self._is_train, center=True, scale=True, decay=0.9, epsilon=1e-5, scope="g_bn_deconv1") 37 | net = tf.nn.relu(net) 38 | 39 | net = deconv2d(net, [batch_size, self.s_h2, self.s_w2, 128], 4, 4, 2, 2, name="g_dc2", use_sn=self.use_sn) 40 | net = tf.contrib.layers.batch_norm(net, 41 | updates_collections=None, is_training=self._is_train, center=True, scale=True, decay=0.9, 
epsilon=1e-5, scope="g_bn_deconv2") 42 | net = tf.nn.relu(net) 43 | 44 | net = deconv2d(net, [batch_size, self.s_h, self.s_w, 64], 4, 4, 2, 2, name="g_dc3", use_sn=self.use_sn) 45 | net = tf.contrib.layers.batch_norm(net, 46 | updates_collections=None, is_training=self._is_train, center=True, scale=True, decay=0.9, epsilon=1e-5, scope="g_bn_deconv3") 47 | net = tf.nn.relu(net) 48 | 49 | net = deconv2d(net, [batch_size, self.s_h, self.s_w, self.colors], 3, 3, 1, 1, name="g_dc4", use_sn=self.use_sn) 50 | out = tf.div(tf.nn.tanh(net) + 1.0, 2.0) 51 | 52 | return out -------------------------------------------------------------------------------- /sngan_gen_y.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from ops import linear 4 | from ops import deconv2d 5 | from util import log 6 | from ops import non_local_block, conditional_batch_norm 7 | import time 8 | 9 | def conv_out_size_same(size, stride): 10 | return int(np.ceil(float(size) / float(stride))) 11 | 12 | class Generator(object): 13 | def __init__(self, name, h, w, c, is_train, use_sn): 14 | 15 | self.name = name 16 | self.s_h, self.s_w, self.colors = [h,w,c] 17 | self.s_h2, self.s_w2 = conv_out_size_same(self.s_h, 2), conv_out_size_same(self.s_w, 2) 18 | self.s_h4, self.s_w4 = conv_out_size_same(self.s_h2, 2), conv_out_size_same(self.s_w2, 2) 19 | self.s_h8, self.s_w8 = conv_out_size_same(self.s_h4, 2), conv_out_size_same(self.s_w4, 2) 20 | self._is_train = is_train 21 | self.use_sn = use_sn 22 | self._embed_y_dim = 128 23 | 24 | def __call__(self, z, y): 25 | 26 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 27 | batch_size = z.shape[0] 28 | 29 | y = tf.concat([z, y], axis=1) 30 | z = y 31 | 32 | net = linear(z, self.s_h8 * self.s_w8 * 512, scope="g_fc1", use_sn=self.use_sn) 33 | net = tf.reshape(net, [batch_size, self.s_h8, self.s_w8, 512]) 34 | net = conditional_batch_norm(net, y, is_training=self._is_train, use_sn = self.use_sn, center=True, scale=True, name="g_cbn_deconv0", use_bias=False) 35 | net = tf.nn.relu(net) 36 | 37 | net = deconv2d(net, [batch_size, self.s_h4, self.s_w4, 256], 4, 4, 2, 2, name="g_dc1", use_sn=self.use_sn) 38 | net = conditional_batch_norm(net, y, is_training=self._is_train, use_sn = self.use_sn, center=True, scale=True, name="g_cbn_deconv1", use_bias=False) 39 | net = tf.nn.relu(net) 40 | 41 | net = deconv2d(net, [batch_size, self.s_h2, self.s_w2, 128], 4, 4, 2, 2, name="g_dc2", use_sn=self.use_sn) 42 | net = conditional_batch_norm(net, y, is_training=self._is_train, use_sn = self.use_sn, center=True, scale=True, name="g_cbn_deconv2", use_bias=False) 43 | net = tf.nn.relu(net) 44 | 45 | net = deconv2d(net, [batch_size, self.s_h, self.s_w, 64], 4, 4, 2, 2, name="g_dc3", use_sn=self.use_sn) 46 | net = tf.contrib.layers.batch_norm(net, 47 | updates_collections=None, is_training=self._is_train, center=True, scale=True, decay=0.9, epsilon=1e-5, scope="g_cbn_deconv3") 48 | net = tf.nn.relu(net) 49 | 50 | net = deconv2d(net, [batch_size, self.s_h, self.s_w, self.colors], 3, 3, 1, 1, name="g_dc4", use_sn=self.use_sn) 51 | out = tf.div(tf.nn.tanh(net) + 1.0, 2.0) 52 | 53 | return out 54 | -------------------------------------------------------------------------------- /sngan_joint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from ops import conv2d 4 | from ops import linear 5 | from util import log 6 | from ops import 
lrelu 7 | from ops import non_local_block, spectral_norm 8 | import time 9 | from tensorflow.python.framework import ops 10 | from tensorflow.python.ops import gen_nn_ops 11 | 12 | import cv2 13 | 14 | def normalize(x): 15 | """Utility function to normalize a tensor by its L2 norm""" 16 | return tf.div(x + tf.constant(1e-10), tf.sqrt(tf.reduce_mean(tf.square(x))) + tf.constant(1e-10)) 17 | 18 | class Classifier_proD(object): 19 | def __init__(self, name, num_class, use_sn): 20 | 21 | self.name = name 22 | self._num_class = num_class 23 | self.use_sn = use_sn 24 | 25 | def __call__(self, input, y): 26 | 27 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 28 | 29 | input = input * 2.0 - 1.0 30 | 31 | net = conv2d(input, 64, 3, 3, 1, 1, name="d_conv1", use_sn=self.use_sn) 32 | net = lrelu(net, leak=0.1) 33 | 34 | net = conv2d(net, 128, 4, 4, 2, 2, name="d_conv2", use_sn=self.use_sn) 35 | net = lrelu(net, leak=0.1) 36 | 37 | net = conv2d(net, 128, 3, 3, 1, 1, name="d_conv3", use_sn=self.use_sn) 38 | net = lrelu(net, leak=0.1) 39 | 40 | net = conv2d(net, 256, 4, 4, 2, 2, name="d_conv4", use_sn=self.use_sn) 41 | net = lrelu(net, leak=0.1) 42 | 43 | net = conv2d(net, 256, 3, 3, 1, 1, name="d_conv5", use_sn=self.use_sn) 44 | net = lrelu(net, leak=0.1) 45 | 46 | net = conv2d(net, 512, 4, 4, 2, 2, name="d_conv6", use_sn=self.use_sn) 47 | net = lrelu(net, leak=0.1) 48 | 49 | net = conv2d(net, 512, 3, 3, 1, 1, name="d_conv7", use_sn=self.use_sn) 50 | net_conv = lrelu(net, leak=0.1) 51 | 52 | h = tf.layers.flatten(net_conv) 53 | out_logit = linear(h, self._num_class, scope="d_fc1", use_sn=self.use_sn) 54 | out_logit_tf = linear(h, 1, scope="final_fc", use_sn=self.use_sn) 55 | 56 | feature_matching = h 57 | 58 | log.info("[Discriminator] after final processing: %s", net_conv.shape) 59 | with tf.variable_scope("embedding_fc", reuse=tf.AUTO_REUSE): 60 | # We do not use ops.linear() below since it does not have an option to 61 | # override the initializer. 
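# Projection-discriminator head (Miyato & Koyama, 2018): embed the one-hot
# label y and add its inner product with the feature vector h to the
# unconditional real/fake logit out_logit_tf computed above.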
62 | kernel = tf.get_variable( 63 | "kernel", [y.shape[1], h.shape[1]], tf.float32, 64 | initializer=tf.initializers.glorot_normal()) 65 | if self.use_sn: 66 | kernel = spectral_norm(kernel) 67 | embedded_y = tf.matmul(y, kernel) 68 | out_logit_tf += tf.reduce_sum(embedded_y * h, axis=1, keepdims=True) 69 | 70 | return tf.nn.softmax(out_logit), out_logit, tf.nn.sigmoid(out_logit_tf), out_logit_tf, feature_matching 71 | -------------------------------------------------------------------------------- /diffAugment.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def DiffAugment(x, policy="color,translation,cutout", channels_first=False): 4 | if policy: 5 | if channels_first: 6 | x = tf.transpose(x, [0, 2, 3, 1]) 7 | for p in policy.split(','): 8 | for f in AUGMENT_FNS[p]: 9 | x = f(x) 10 | if channels_first: 11 | x = tf.transpose(x, [0, 3, 1, 2]) 12 | return x 13 | 14 | 15 | def rand_brightness(x): 16 | magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) - 0.5 17 | x = x + magnitude 18 | return x 19 | 20 | 21 | def rand_saturation(x): 22 | magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) * 2 23 | x_mean = tf.reduce_mean(x, axis=3, keepdims=True) 24 | x = (x - x_mean) * magnitude + x_mean 25 | return x 26 | 27 | 28 | def rand_contrast(x): 29 | magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) + 0.5 30 | x_mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True) 31 | x = (x - x_mean) * magnitude + x_mean 32 | return x 33 | 34 | 35 | def rand_translation(x, ratio=0.125): 36 | batch_size = tf.shape(x)[0] 37 | image_size = tf.shape(x)[1:3] 38 | shift = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32) 39 | translation_x = tf.random.uniform([batch_size, 1], -shift[0], shift[0] + 1, dtype=tf.int32) 40 | translation_y = tf.random.uniform([batch_size, 1], -shift[1], shift[1] + 1, dtype=tf.int32) 41 | grid_x = tf.clip_by_value(tf.expand_dims(tf.range(image_size[0], dtype=tf.int32), 0) + translation_x + 1, 0, image_size[0] + 1) 42 | grid_y = tf.clip_by_value(tf.expand_dims(tf.range(image_size[1], dtype=tf.int32), 0) + translation_y + 1, 0, image_size[1] + 1) 43 | x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_x, -1), batch_dims=1) 44 | x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_y, -1), batch_dims=1), [0, 2, 1, 3]) 45 | return x 46 | 47 | 48 | def rand_cutout(x, ratio=0.5): 49 | batch_size = tf.shape(x)[0] 50 | image_size = tf.shape(x)[1:3] 51 | cutout_size = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32) 52 | offset_x = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[0] + (1 - cutout_size[0] % 2), dtype=tf.int32) 53 | offset_y = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[1] + (1 - cutout_size[1] % 2), dtype=tf.int32) 54 | grid_batch, grid_x, grid_y = tf.meshgrid(tf.range(batch_size, dtype=tf.int32), tf.range(cutout_size[0], dtype=tf.int32), tf.range(cutout_size[1], dtype=tf.int32), indexing='ij') 55 | cutout_grid = tf.stack([grid_batch, grid_x + offset_x - cutout_size[0] // 2, grid_y + offset_y - cutout_size[1] // 2], axis=-1) 56 | mask_shape = tf.stack([batch_size, image_size[0], image_size[1]]) 57 | cutout_grid = tf.maximum(cutout_grid, 0) 58 | cutout_grid = tf.minimum(cutout_grid, tf.reshape(mask_shape - 1, [1, 1, 1, 3])) 59 | mask = tf.maximum(1 - tf.scatter_nd(cutout_grid, tf.ones([batch_size, cutout_size[0], cutout_size[1]], 
dtype=tf.float32), mask_shape), 0) 60 | x = x * tf.expand_dims(mask, axis=3) 61 | return x 62 | 63 | 64 | AUGMENT_FNS = { 65 | 'color': [rand_brightness, rand_saturation, rand_contrast], 66 | 'translation': [rand_translation], 67 | 'cutout': [rand_cutout], 68 | } -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from model_joint import Model 5 | import datasets.hdf5_loader as dataset 6 | import numpy as np 7 | import pandas as pd 8 | from six.moves import xrange 9 | from util import log 10 | 11 | def get_params(): 12 | 13 | def str2bool(v): 14 | return v.lower() == 'true' 15 | 16 | parser = argparse.ArgumentParser( 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | 19 | parser.add_argument('--debug', action='store_true', default=False) 20 | parser.add_argument('--prefix', type=str, default='GANs') 21 | parser.add_argument('--train_dir', type=str) 22 | parser.add_argument('--checkpoint', type=str, default=None) 23 | parser.add_argument('--dataset', type=str, default="Orgimage_128", choices=["Orgimage_128"]) 24 | parser.add_argument('--dump_result', type=str2bool, default=False) 25 | 26 | # Model 27 | parser.add_argument('--batch_size_G', type=int, default=16) 28 | parser.add_argument('--batch_size_L', type=int, default=16) 29 | parser.add_argument('--batch_size_U', type=int, default=16) 30 | parser.add_argument('--n_z', type=int, default=128) 31 | 32 | # Training config {{{ 33 | # ======== 34 | # log 35 | parser.add_argument('--log_step', type=int, default=10) 36 | parser.add_argument('--save_image_step', type=int, default=10000) 37 | parser.add_argument('--test_sample_step', type=int, default=100) 38 | parser.add_argument('--output_save_step', type=int, default=10000) 39 | # learning 40 | parser.add_argument('--max_training_steps', type=int, default=250001) 41 | parser.add_argument('--learning_rate_g', type=float, default=2e-4) 42 | parser.add_argument('--learning_rate_d', type=float, default=5e-5) 43 | parser.add_argument('--update_rate', type=int, default=2) 44 | # }}} 45 | 46 | # Testing config {{{ 47 | # ======== 48 | parser.add_argument('--data_id', nargs='*', default=None) 49 | # }}} 50 | 51 | config, _ = parser.parse_known_args() 52 | args = parser.parse_args() 53 | 54 | return config, args 55 | 56 | def argparser(config, is_train=True): 57 | 58 | dataset_path = os.path.join('./datasets/TMI/', config["dataset"].lower()) 59 | dataset_train, dataset_val, dataset_test = dataset.create_default_splits(dataset_path) 60 | dataset_train_unlabel, _ = dataset.create_default_splits_unlabel(dataset_path) 61 | 62 | config["img"] = [] 63 | labels = [] 64 | with open('./datasets/metadata.tsv','w') as f: 65 | f.write("Index\tLabel\n") 66 | for index, labeldex in enumerate(dataset_test.ids): 67 | config["img"].append(dataset_test.get_data(labeldex)[0]) 68 | label = np.argmax(dataset_test.get_data(labeldex)[1]) 69 | labels.append(label) 70 | f.write("%d\t%d\n" % (index, label)) 71 | config["img"] = np.array(config["img"]) 72 | log.info(config["img"].shape) 73 | config["len"] = config["img"].shape[0] 74 | config["label"] = labels 75 | 76 | config["Size"] = config["len"] 77 | picname = [] 78 | for i in xrange(config["Size"]): 79 | picname.append("V{step:04d}.jpg".format(step=i+1)) 80 | Csv=pd.DataFrame(columns=['Label'], index=picname, data=labels) 81 | Csv.to_csv('./datasets/Classification_Results_label.csv',encoding='gbk') 
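# Infer image geometry (h, w, c) and the number of classes from a single
# training sample so the model below is built without hard-coded shapes.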
82 | 83 | img, label = dataset_train.get_data(dataset_train.ids[0]) 84 | config["h"] = img.shape[0] 85 | config["w"] = img.shape[1] 86 | config["c"] = img.shape[2] 87 | config["num_class"] = label.shape[0] 88 | 89 | # --- create model --- 90 | model = Model(config, debug_information=config["debug"], is_train=is_train) 91 | 92 | return config, model, dataset_train, dataset_train_unlabel, dataset_val, dataset_test 93 | -------------------------------------------------------------------------------- /BigGAN/resnet_ops.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ResNet specific operations. 17 | Defines the default ResNet generator and discriminator blocks and some helper 18 | operations such as unpooling. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import math 26 | 27 | from compare_gan.architectures import abstract_arch 28 | from compare_gan.architectures import arch_ops as ops 29 | 30 | from six.moves import range 31 | import tensorflow as tf 32 | 33 | 34 | def unpool(value, name="unpool"): 35 | """Unpooling operation. 36 | N-dimensional version of the unpooling operation from 37 | https://www.robots.ox.ac.uk/~vgg/rg/papers/Dosovitskiy_Learning_to_Generate_2015_CVPR_paper.pdf 38 | Taken from: https://github.com/tensorflow/tensorflow/issues/2169 39 | Args: 40 | value: a Tensor of shape [b, d0, d1, ..., dn, ch] 41 | name: name of the op 42 | Returns: 43 | A Tensor of shape [b, 2*d0, 2*d1, ..., 2*dn, ch] 44 | """ 45 | with tf.name_scope(name) as scope: 46 | sh = value.get_shape().as_list() 47 | dim = len(sh[1:-1]) 48 | out = (tf.reshape(value, [-1] + sh[-dim:])) 49 | for i in range(dim, 0, -1): 50 | out = tf.concat([out, tf.zeros_like(out)], i) 51 | out_size = [-1] + [s * 2 for s in sh[1:-1]] + [sh[-1]] 52 | out = tf.reshape(out, out_size, name=scope) 53 | return out 54 | 55 | 56 | def validate_image_inputs(inputs, validate_power2=True): 57 | inputs.get_shape().assert_has_rank(4) 58 | inputs.get_shape()[1:3].assert_is_fully_defined() 59 | # if inputs.get_shape()[1] != inputs.get_shape()[2]: 60 | # raise ValueError("Input tensor does not have equal width and height: ", 61 | # inputs.get_shape()[1:3]) 62 | width = inputs.get_shape().as_list()[1] 63 | if validate_power2 and math.log(width, 2) != int(math.log(width, 2)): 64 | raise ValueError("Input tensor `width` is not a power of 2: ", width) 65 | 66 | 67 | class ResNetBlock(object): 68 | """ResNet block with options for various normalizations.""" 69 | 70 | def __init__(self, 71 | name, 72 | in_channels, 73 | out_channels, 74 | scale, 75 | is_gen_block, 76 | layer_norm=False, 77 | spectral_norm=False, 78 | batch_norm=None): 79 | """Constructs a new ResNet block. 80 | Args: 81 | name: Scope name for the resent block. 
82 | in_channels: Integer, the input channel size. 83 | out_channels: Integer, the output channel size. 84 | scale: Whether or not to scale up or down, choose from "up", "down" or 85 | "none". 86 | is_gen_block: Boolean, deciding whether this is a generator or 87 | discriminator block. 88 | layer_norm: Apply layer norm before both convolutions. 89 | spectral_norm: Use spectral normalization for all weights. 90 | batch_norm: Function for batch normalization. 91 | """ 92 | assert scale in ["up", "down", "none"] 93 | self._name = name 94 | self._in_channels = in_channels 95 | self._out_channels = out_channels 96 | self._scale = scale 97 | # In SN paper, if they upscale in generator they do this in the first conv. 98 | # For discriminator downsampling happens after second conv. 99 | self._scale1 = scale if is_gen_block else "none" 100 | self._scale2 = "none" if is_gen_block else scale 101 | self._layer_norm = layer_norm 102 | self._spectral_norm = spectral_norm 103 | self.batch_norm = batch_norm 104 | 105 | def __call__(self, inputs, z, y, is_training): 106 | return self.apply(inputs=inputs, z=z, y=y, is_training=is_training) 107 | 108 | def _get_conv(self, inputs, in_channels, out_channels, scale, suffix, 109 | kernel_size=(3, 3), strides=(1, 1)): 110 | """Performs a convolution in the ResNet block.""" 111 | if inputs.get_shape().as_list()[-1] != in_channels: 112 | raise ValueError("Unexpected number of input channels.") 113 | if scale not in ["up", "down", "none"]: 114 | raise ValueError( 115 | "Scale: got {}, expected 'up', 'down', or 'none'.".format(scale)) 116 | 117 | outputs = inputs 118 | if scale == "up": 119 | outputs = unpool(outputs) 120 | outputs = ops.conv2d( 121 | outputs, 122 | output_dim=out_channels, 123 | k_h=kernel_size[0], k_w=kernel_size[1], 124 | d_h=strides[0], d_w=strides[1], 125 | use_sn=self._spectral_norm, 126 | name="{}_{}".format("same" if scale == "none" else scale, suffix)) 127 | if scale == "down": 128 | outputs = tf.nn.pool(outputs, [2, 2], "AVG", "SAME", strides=[2, 2], 129 | name="pool_%s" % suffix) 130 | return outputs 131 | 132 | def apply(self, inputs, z, y, is_training): 133 | """"ResNet block containing possible down/up sampling, shared for G / D. 134 | Args: 135 | inputs: a 3d input tensor of feature map. 136 | z: the latent vector for potential self-modulation. Can be None if use_sbn 137 | is set to False. 138 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 139 | labels. 140 | is_training: boolean, whether or notthis is called during the training. 141 | Returns: 142 | output: a 3d output tensor of feature map. 
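Raises:
  ValueError: If the channel count of `inputs` does not match `in_channels`.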
143 | """ 144 | if inputs.get_shape().as_list()[-1] != self._in_channels: 145 | raise ValueError("Unexpected number of input channels.") 146 | 147 | with tf.variable_scope(self._name, values=[inputs]): 148 | output = inputs 149 | 150 | shortcut = self._get_conv( 151 | output, self._in_channels, self._out_channels, self._scale, 152 | suffix="conv_shortcut") 153 | 154 | output = self.batch_norm( 155 | output, z=z, y=y, is_training=is_training, name="bn1") 156 | if self._layer_norm: 157 | output = ops.layer_norm(output, is_training=is_training, scope="ln1") 158 | 159 | output = tf.nn.relu(output) 160 | output = self._get_conv( 161 | output, self._in_channels, self._out_channels, self._scale1, 162 | suffix="conv1") 163 | 164 | output = self.batch_norm( 165 | output, z=z, y=y, is_training=is_training, name="bn2") 166 | if self._layer_norm: 167 | output = ops.layer_norm(output, is_training=is_training, scope="ln2") 168 | 169 | output = tf.nn.relu(output) 170 | output = self._get_conv( 171 | output, self._out_channels, self._out_channels, self._scale2, 172 | suffix="conv2") 173 | 174 | # Combine skip-connection with the convolved part. 175 | output += shortcut 176 | return output 177 | 178 | 179 | class ResNetGenerator(abstract_arch.AbstractGenerator): 180 | """Abstract base class for generators based on the ResNet architecture.""" 181 | 182 | def _resnet_block(self, name, in_channels, out_channels, scale): 183 | """ResNet block for the generator.""" 184 | if scale not in ["up", "none"]: 185 | raise ValueError( 186 | "Unknown generator ResNet block scaling: {}.".format(scale)) 187 | return ResNetBlock( 188 | name=name, 189 | in_channels=in_channels, 190 | out_channels=out_channels, 191 | scale=scale, 192 | is_gen_block=True, 193 | spectral_norm=self._spectral_norm, 194 | batch_norm=self.batch_norm) 195 | 196 | 197 | class ResNetDiscriminator(abstract_arch.AbstractDiscriminator): 198 | """Abstract base class for discriminators based on the ResNet architecture.""" 199 | 200 | def _resnet_block(self, name, in_channels, out_channels, scale): 201 | """ResNet block for the generator.""" 202 | if scale not in ["down", "none"]: 203 | raise ValueError( 204 | "Unknown discriminator ResNet block scaling: {}.".format(scale)) 205 | return ResNetBlock( 206 | name=name, 207 | in_channels=in_channels, 208 | out_channels=out_channels, 209 | scale=scale, 210 | is_gen_block=False, 211 | layer_norm=self._layer_norm, 212 | spectral_norm=self._spectral_norm, 213 | batch_norm=self.batch_norm) -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as layers 3 | import tensorflow.contrib.slim as slim 4 | from util import log 5 | import functools 6 | 7 | def weight_initializer(initializer="org", stddev=0.02): 8 | """Returns the initializer for the given name. 9 | Args: 10 | initializer: Name of the initalizer. Use one in consts.INITIALIZERS. 11 | stddev: Standard deviation passed to initalizer. 12 | Returns: 13 | Initializer from `tf.initializers`. 
14 | """ 15 | if initializer == "normal": 16 | return tf.initializers.random_normal(stddev=stddev) 17 | if initializer == "truncated": 18 | return tf.initializers.truncated_normal(stddev=stddev) 19 | if initializer == "org": 20 | return tf.initializers.orthogonal() 21 | raise ValueError("Unknown weight initializer {}.".format(initializer)) 22 | 23 | def lrelu(inputs, leak=0.2, name = "lrelu"): 24 | """Performs leaky-ReLU on the input.""" 25 | return tf.maximum(inputs, leak * inputs, name=name) 26 | 27 | def spectral_norm(inputs, epsilon=1e-12, singular_value="auto"): 28 | """Performs Spectral Normalization on a weight tensor. 29 | Details of why this is helpful for GAN's can be found in "Spectral 30 | Normalization for Generative Adversarial Networks", Miyato T. et al., 2018. 31 | [https://arxiv.org/abs/1802.05957]. 32 | Args: 33 | inputs: The weight tensor to normalize. 34 | epsilon: Epsilon for L2 normalization. 35 | singular_value: Which first singular value to store (left or right). Use 36 | "auto" to automatically choose the one that has fewer dimensions. 37 | Returns: 38 | The normalized weight tensor. 39 | """ 40 | if len(inputs.shape) < 2: 41 | raise ValueError("Spectral norm can only be applied to multi-dimensional tensors") 42 | 43 | # The paper says to flatten convnet kernel weights from (C_out, C_in, KH, KW) 44 | # to (C_out, C_in * KH * KW). Our Conv2D kernel shape is (KH, KW, C_in, C_out) 45 | # so it should be reshaped to (KH * KW * C_in, C_out), and similarly for other 46 | # layers that put output channels as last dimension. This implies that w 47 | # here is equivalent to w.T in the paper. 48 | w = tf.reshape(inputs, (-1, inputs.shape[-1])) 49 | 50 | # Choose whether to persist the first left or first right singular vector. 51 | # As the underlying matrix is PSD, this should be equivalent, but in practice 52 | # the shape of the persisted vector is different. Here one can choose whether 53 | # to maintain the left or right one, or pick the one which has the smaller 54 | # dimension. We use the same variable for the singular vector if we switch 55 | # from normal weights to EMA weights. 56 | var_name = inputs.name.replace("/ExponentialMovingAverage", "").split("/")[-1] 57 | var_name = var_name.split(":")[0] + "/u_var" 58 | if singular_value == "auto": 59 | singular_value = "left" if w.shape[0] <= w.shape[1] else "right" 60 | u_shape = (w.shape[0], 1) if singular_value == "left" else (1, w.shape[-1]) 61 | u_var = tf.get_variable( 62 | var_name, 63 | shape=u_shape, 64 | dtype=w.dtype, 65 | initializer=tf.random_normal_initializer(), 66 | trainable=False) 67 | u = u_var 68 | 69 | # Use power iteration method to approximate the spectral norm. 70 | # The authors suggest that one round of power iteration was sufficient in the 71 | # actual experiment to achieve satisfactory performance. 72 | power_iteration_rounds = 1 73 | for _ in range(power_iteration_rounds): 74 | if singular_value == "left": 75 | # `v` approximates the first right singular vector of matrix `w`. 76 | v = tf.math.l2_normalize( 77 | tf.matmul(tf.transpose(w), u), axis=None, epsilon=epsilon) 78 | u = tf.math.l2_normalize(tf.matmul(w, v), axis=None, epsilon=epsilon) 79 | else: 80 | v = tf.math.l2_normalize(tf.matmul(u, w, transpose_b=True), 81 | epsilon=epsilon) 82 | u = tf.math.l2_normalize(tf.matmul(v, w), epsilon=epsilon) 83 | 84 | # Update the approximation. 
85 | with tf.control_dependencies([tf.assign(u_var, u, name="update_u")]): 86 | u = tf.identity(u) 87 | 88 | # The authors of SN-GAN chose to stop gradient propagating through u and v 89 | # and we maintain that option. 90 | u = tf.stop_gradient(u) 91 | v = tf.stop_gradient(v) 92 | 93 | if singular_value == "left": 94 | norm_value = tf.matmul(tf.matmul(tf.transpose(u), w), v) 95 | else: 96 | norm_value = tf.matmul(tf.matmul(v, w), u, transpose_b=True) 97 | norm_value.shape.assert_is_fully_defined() 98 | norm_value.shape.assert_is_compatible_with([1, 1]) 99 | 100 | w_normalized = w / norm_value 101 | 102 | # Deflate normalized weights to match the unnormalized tensor. 103 | w_tensor_normalized = tf.reshape(w_normalized, inputs.shape) 104 | return w_tensor_normalized 105 | 106 | def conv2d(inputs, output_dim, k_h, k_w, d_h, d_w, stddev=0.02, 107 | name="conv2d", use_sn=False, use_bias=True): 108 | with tf.variable_scope(name): 109 | w = tf.get_variable( 110 | "kernel", [k_h, k_w, inputs.shape[-1].value, output_dim], 111 | initializer=weight_initializer(stddev=stddev)) 112 | if use_sn: 113 | w = spectral_norm(w) 114 | outputs = tf.nn.conv2d(inputs, w, strides=[1, d_h, d_w, 1], padding="SAME") 115 | if use_bias: 116 | bias = tf.get_variable( 117 | "bias", [output_dim], initializer=tf.constant_initializer(0.0)) 118 | outputs += bias 119 | return outputs 120 | 121 | def deconv2d(inputs, output_shape, k_h, k_w, d_h, d_w, stddev=0.02, 122 | name='deconv2d', use_sn=False): 123 | with tf.variable_scope(name): 124 | w = tf.get_variable( 125 | "kernel", [k_h, k_w, output_shape[-1], inputs.get_shape()[-1]], 126 | initializer=weight_initializer(stddev=stddev)) 127 | if use_sn: 128 | w = spectral_norm(w) 129 | deconv = tf.nn.conv2d_transpose( 130 | inputs, w, output_shape=output_shape, strides=[1, d_h, d_w, 1]) 131 | bias = tf.get_variable( 132 | "bias", [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 133 | return tf.reshape(tf.nn.bias_add(deconv, bias), tf.shape(deconv)) 134 | 135 | def linear(inputs, output_size, scope=None, stddev=0.02, bias_start=0.0, use_sn=False, use_bias=True): 136 | shape = inputs.get_shape().as_list() 137 | with tf.variable_scope(scope or "Linear"): 138 | kernel = tf.get_variable( 139 | "kernel", 140 | [shape[1], output_size], 141 | initializer=weight_initializer(stddev=stddev)) 142 | if use_sn: 143 | kernel = spectral_norm(kernel) 144 | outputs = tf.matmul(inputs, kernel) 145 | if use_bias: 146 | bias = tf.get_variable( 147 | "bias", 148 | [output_size], 149 | initializer=tf.constant_initializer(bias_start)) 150 | outputs += bias 151 | return outputs 152 | 153 | conv1x1 = functools.partial(conv2d, k_h=1, k_w=1, d_h=1, d_w=1) 154 | 155 | def non_local_block(x, name, use_sn): 156 | """Self-attention (non-local) block. 157 | This method is used to exactly reproduce SAGAN and ignores Gin settings on 158 | weight initialization and spectral normalization. 159 | Args: 160 | x: Input tensor of shape [batch, h, w, c]. 161 | name: Name of the variable scope. 162 | use_sn: Apply spectral norm to the weights. 163 | Returns: 164 | A tensor of the same shape after self-attention was applied. 
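As in SAGAN, the phi and g branches are max-pooled 2x2 before the matmuls, so attention keys and values come from a 4x-smaller set of spatial locations.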
165 | """ 166 | def _spatial_flatten(inputs): 167 | shape = inputs.shape 168 | return tf.reshape(inputs, (-1, shape[1] * shape[2], shape[3])) 169 | 170 | with tf.variable_scope(name): 171 | h, w, num_channels = x.get_shape().as_list()[1:] 172 | num_channels_attn = num_channels // 8 173 | num_channels_g = num_channels // 2 174 | 175 | # Theta path 176 | theta = conv1x1(x, num_channels_attn, name="conv2d_theta", use_sn=use_sn, 177 | use_bias=False) 178 | theta = _spatial_flatten(theta) 179 | 180 | # Phi path 181 | phi = conv1x1(x, num_channels_attn, name="conv2d_phi", use_sn=use_sn, use_bias=False) 182 | phi = tf.layers.max_pooling2d(inputs=phi, pool_size=[2, 2], strides=2) 183 | phi = _spatial_flatten(phi) 184 | 185 | attn = tf.matmul(theta, phi, transpose_b=True) 186 | attn = tf.nn.softmax(attn) 187 | 188 | # G path 189 | g = conv1x1(x, num_channels_g, name="conv2d_g", use_sn=use_sn, 190 | use_bias=False) 191 | g = tf.layers.max_pooling2d(inputs=g, pool_size=[2, 2], strides=2) 192 | g = _spatial_flatten(g) 193 | 194 | attn_g = tf.matmul(attn, g) 195 | attn_g = tf.reshape(attn_g, [-1, h, w, num_channels_g]) 196 | sigma = tf.get_variable("sigma", [], initializer=tf.zeros_initializer()) 197 | attn_g = conv1x1(attn_g, num_channels, name="conv2d_attn_g", use_sn=use_sn, 198 | use_bias=False) 199 | 200 | return x + sigma * attn_g 201 | 202 | def conditional_batch_norm(inputs, y, is_training, use_sn, center=True, 203 | scale=True, name="batch_norm", use_bias=False): 204 | """Conditional batch normalization.""" 205 | if y is None: 206 | raise ValueError("You must provide y for conditional batch normalization.") 207 | if y.shape.ndims != 2: 208 | raise ValueError("Conditioning must have rank 2.") 209 | with tf.variable_scope(name, values=[inputs]): 210 | outputs = tf.contrib.layers.batch_norm(inputs, center=False, scale=False, decay=0.9, epsilon=1e-5, updates_collections=None, is_training=is_training) 211 | num_channels = inputs.shape[-1].value 212 | with tf.variable_scope("condition", values=[inputs, y]): 213 | if scale: 214 | gamma = linear(y, num_channels, scope="gamma", use_sn=use_sn, 215 | use_bias=use_bias) 216 | gamma = tf.reshape(gamma, [-1, 1, 1, num_channels]) 217 | outputs *= gamma 218 | if center: 219 | beta = linear(y, num_channels, scope="beta", use_sn=use_sn, 220 | use_bias=use_bias) 221 | beta = tf.reshape(beta, [-1, 1, 1, num_channels]) 222 | outputs += beta 223 | return outputs -------------------------------------------------------------------------------- /BigGAN/resnet_gen_y.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from absl import logging 6 | 7 | import arch_ops as ops 8 | import resnet_ops 9 | 10 | import utils 11 | 12 | import gin 13 | from six.moves import range 14 | import tensorflow as tf 15 | 16 | class BigGanResNetBlock(resnet_ops.ResNetBlock): 17 | """ResNet block with options for various normalizations. 18 | This block uses a 1x1 convolution for the (optional) shortcut connection. 19 | """ 20 | 21 | def __init__(self, 22 | add_shortcut=True, 23 | **kwargs): 24 | """Constructs a new ResNet block for BigGAN. 25 | Args: 26 | add_shortcut: Whether to add a shortcut connection. 27 | **kwargs: Additional arguments for ResNetBlock. 
28 | """ 29 | super(BigGanResNetBlock, self).__init__(**kwargs) 30 | self._add_shortcut = add_shortcut 31 | 32 | def apply(self, inputs, z, y, is_training): 33 | """"ResNet block containing possible down/up sampling, shared for G / D. 34 | Args: 35 | inputs: a 3d input tensor of feature map. 36 | z: the latent vector for potential self-modulation. Can be None if use_sbn 37 | is set to False. 38 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 39 | labels. 40 | is_training: boolean, whether or notthis is called during the training. 41 | Returns: 42 | output: a 3d output tensor of feature map. 43 | """ 44 | if inputs.shape[-1].value != self._in_channels: 45 | raise ValueError( 46 | "Unexpected number of input channels (expected {}, got {}).".format( 47 | self._in_channels, inputs.shape[-1].value)) 48 | 49 | with tf.variable_scope(self._name, values=[inputs]): 50 | outputs = inputs 51 | 52 | outputs = self.batch_norm( 53 | outputs, z=z, y=y, is_training=is_training, name="bn1") 54 | if self._layer_norm: 55 | outputs = ops.layer_norm(outputs, is_training=is_training, scope="ln1") 56 | 57 | outputs = tf.nn.relu(outputs) 58 | outputs = self._get_conv( 59 | outputs, self._in_channels, self._out_channels, self._scale1, 60 | suffix="conv1") 61 | 62 | outputs = self.batch_norm( 63 | outputs, z=z, y=y, is_training=is_training, name="bn2") 64 | if self._layer_norm: 65 | outputs = ops.layer_norm(outputs, is_training=is_training, scope="ln2") 66 | 67 | outputs = tf.nn.relu(outputs) 68 | outputs = self._get_conv( 69 | outputs, self._out_channels, self._out_channels, self._scale2, 70 | suffix="conv2") 71 | 72 | # Combine skip-connection with the convolved part. 73 | if self._add_shortcut: 74 | shortcut = self._get_conv( 75 | inputs, self._in_channels, self._out_channels, self._scale, 76 | kernel_size=(1, 1), 77 | suffix="conv_shortcut") 78 | outputs += shortcut 79 | logging.info("[Block] %s (z=%s, y=%s) -> %s", inputs.shape, 80 | None if z is None else z.shape, 81 | None if y is None else y.shape, outputs.shape) 82 | return outputs 83 | 84 | class Generator(object): 85 | """ResNet-based generator supporting resolutions 32, 64, 128, 256, 512.""" 86 | 87 | def __init__(self, 88 | name, h, w, c, is_train, use_sn, 89 | ch=96, 90 | blocks_with_attention="B4", 91 | hierarchical_z=True, 92 | embed_z=False, 93 | embed_y=True, 94 | embed_y_dim=128, 95 | embed_bias=False, 96 | **kwargs): 97 | """Constructor for BigGAN generator. 98 | Args: 99 | ch: Channel multiplier. 100 | blocks_with_attention: Comma-separated list of blocks that are followed by 101 | a non-local block. 102 | hierarchical_z: Split z into chunks and only give one chunk to each. 103 | Each chunk will also be concatenated to y, the one hot encoded labels. 104 | embed_z: If True use a learnable embedding of z that is used instead. 105 | The embedding will have the length of z. 106 | embed_y: If True use a learnable embedding of y that is used instead. 107 | embed_y_dim: Size of the embedding of y. 108 | embed_bias: Use bias with for the embedding of z and y. 109 | **kwargs: additional arguments past on to ResNetGenerator. 
110 | """ 111 | self.name = name 112 | self.s_h, self.s_w, self.colors = [h,w,c] 113 | self._image_shape = [h,w,c] 114 | self._is_train = is_train 115 | self._batch_norm_fn = ops.conditional_batch_norm 116 | self._spectral_norm = True 117 | 118 | self._ch = ch 119 | self._blocks_with_attention = set(blocks_with_attention.split(",")) 120 | self._hierarchical_z = hierarchical_z 121 | self._embed_z = embed_z 122 | self._embed_y = embed_y 123 | self._embed_y_dim = embed_y_dim 124 | self._embed_bias = embed_bias 125 | 126 | def _resnet_block(self, name, in_channels, out_channels, scale): 127 | """ResNet block for the generator.""" 128 | if scale not in ["up", "none"]: 129 | raise ValueError( 130 | "Unknown generator ResNet block scaling: {}.".format(scale)) 131 | return BigGanResNetBlock( 132 | name=name, 133 | in_channels=in_channels, 134 | out_channels=out_channels, 135 | scale=scale, 136 | is_gen_block=True, 137 | spectral_norm=self._spectral_norm, 138 | batch_norm=self.batch_norm) 139 | 140 | def batch_norm(self, inputs, **kwargs): 141 | if self._batch_norm_fn is None: 142 | return inputs 143 | args = kwargs.copy() 144 | args["inputs"] = inputs 145 | if "use_sn" not in args: 146 | args["use_sn"] = self._spectral_norm 147 | return utils.call_with_accepted_args(self._batch_norm_fn, **args) 148 | 149 | def _get_in_out_channels(self): 150 | resolution = self._image_shape[0] 151 | if resolution == 512: 152 | channel_multipliers = [16, 16, 8, 8, 4, 2, 1, 1] 153 | elif resolution == 256: 154 | channel_multipliers = [16, 16, 8, 8, 4, 2, 1] 155 | elif resolution == 128: 156 | channel_multipliers = [16, 16, 8, 4, 2, 1] 157 | elif resolution == 64: 158 | channel_multipliers = [16, 16, 8, 4, 2] 159 | elif resolution == 32: 160 | channel_multipliers = [4, 4, 4, 4] 161 | else: 162 | raise ValueError("Unsupported resolution: {}".format(resolution)) 163 | in_channels = [self._ch * c for c in channel_multipliers[:-1]] 164 | out_channels = [self._ch * c for c in channel_multipliers[1:]] 165 | return in_channels, out_channels 166 | 167 | def __call__(self, z, y): 168 | with tf.variable_scope(self.name, values=[z, y], reuse=tf.AUTO_REUSE): 169 | """Build the generator network for the given inputs. 170 | Args: 171 | z: `Tensor` of shape [batch_size, z_dim] with latent code. 172 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 173 | labels. 174 | is_training: boolean, are we in train or eval model. 175 | Returns: 176 | A tensor of size [batch_size] + self._image_shape with values in [0, 1]. 177 | """ 178 | shape_or_none = lambda t: None if t is None else t.shape 179 | logging.info("[Generator] inputs are z=%s, y=%s", z.shape, shape_or_none(y)) 180 | # Each block upscales by a factor of 2. 
181 | seed_size = 4 182 | z_dim = z.shape[1].value 183 | 184 | in_channels, out_channels = self._get_in_out_channels() 185 | num_blocks = len(in_channels) 186 | 187 | if self._embed_z: 188 | z = ops.linear(z, z_dim, scope="embed_z", use_sn=False, 189 | use_bias=self._embed_bias) 190 | if self._embed_y: 191 | y = ops.linear(y, self._embed_y_dim, scope="embed_y", use_sn=False, 192 | use_bias=self._embed_bias) 193 | y_per_block = num_blocks * [y] 194 | if self._hierarchical_z: 195 | z_per_block = tf.split(z, num_blocks + 1, axis=1) 196 | z0, z_per_block = z_per_block[0], z_per_block[1:] 197 | if y is not None: 198 | y_per_block = [tf.concat([zi, y], 1) for zi in z_per_block] 199 | else: 200 | z0 = z 201 | z_per_block = num_blocks * [z] 202 | 203 | logging.info("[Generator] z0=%s, z_per_block=%s, y_per_block=%s", 204 | z0.shape, [str(shape_or_none(t)) for t in z_per_block], 205 | [str(shape_or_none(t)) for t in y_per_block]) 206 | 207 | # Map noise to the actual seed. 208 | net = ops.linear( 209 | z0, 210 | in_channels[0] * seed_size * seed_size, 211 | scope="fc_noise", 212 | use_sn=self._spectral_norm) 213 | # Reshape the seed to be a rank-4 Tensor. 214 | net = tf.reshape( 215 | net, 216 | [-1, seed_size, seed_size, in_channels[0]], 217 | name="fc_reshaped") 218 | 219 | for block_idx in range(num_blocks): 220 | name = "B{}".format(block_idx + 1) 221 | block = self._resnet_block( 222 | name=name, 223 | in_channels=in_channels[block_idx], 224 | out_channels=out_channels[block_idx], 225 | scale="up") 226 | net = block( 227 | net, 228 | z=z_per_block[block_idx], 229 | y=y_per_block[block_idx], 230 | is_training=self._is_train) 231 | if name in self._blocks_with_attention: 232 | logging.info("[Generator] Applying non-local block to %s", net.shape) 233 | net = ops.non_local_block(net, "non_local_block", 234 | use_sn=self._spectral_norm) 235 | # Final processing of the net. 236 | # Use unconditional batch norm. 237 | logging.info("[Generator] before final processing: %s", net.shape) 238 | net = ops.batch_norm(net, is_training=self._is_train, name="final_norm") 239 | net = tf.nn.relu(net) 240 | net = ops.conv2d(net, output_dim=self._image_shape[2], k_h=3, k_w=3, 241 | d_h=1, d_w=1, name="final_conv", 242 | use_sn=self._spectral_norm) 243 | logging.info("[Generator] after final processing: %s", net.shape) 244 | net = (tf.nn.tanh(net) + 1.0) / 2.0 245 | return net -------------------------------------------------------------------------------- /BigGAN/resnet_joint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from absl import logging 6 | 7 | import arch_ops as ops 8 | import resnet_ops 9 | 10 | import gin 11 | from six.moves import range 12 | import tensorflow as tf 13 | 14 | 15 | class BigGanResNetBlock(resnet_ops.ResNetBlock): 16 | """ResNet block with options for various normalizations. 17 | This block uses a 1x1 convolution for the (optional) shortcut connection. 18 | """ 19 | 20 | def __init__(self, 21 | add_shortcut=True, 22 | **kwargs): 23 | """Constructs a new ResNet block for BigGAN. 24 | Args: 25 | add_shortcut: Whether to add a shortcut connection. 26 | **kwargs: Additional arguments for ResNetBlock. 
27 | """ 28 | super(BigGanResNetBlock, self).__init__(**kwargs) 29 | self._add_shortcut = add_shortcut 30 | 31 | def apply(self, inputs, z, y, is_training): 32 | """"ResNet block containing possible down/up sampling, shared for G / D. 33 | Args: 34 | inputs: a 3d input tensor of feature map. 35 | z: the latent vector for potential self-modulation. Can be None if use_sbn 36 | is set to False. 37 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 38 | labels. 39 | is_training: boolean, whether or notthis is called during the training. 40 | Returns: 41 | output: a 3d output tensor of feature map. 42 | """ 43 | if inputs.shape[-1].value != self._in_channels: 44 | raise ValueError( 45 | "Unexpected number of input channels (expected {}, got {}).".format( 46 | self._in_channels, inputs.shape[-1].value)) 47 | 48 | with tf.variable_scope(self._name, values=[inputs]): 49 | outputs = inputs 50 | 51 | outputs = self.batch_norm( 52 | outputs, z=z, y=y, is_training=is_training, name="bn1") 53 | if self._layer_norm: 54 | outputs = ops.layer_norm(outputs, is_training=is_training, scope="ln1") 55 | 56 | outputs = tf.nn.relu(outputs) 57 | outputs = self._get_conv( 58 | outputs, self._in_channels, self._out_channels, self._scale1, 59 | suffix="conv1") 60 | 61 | outputs = self.batch_norm( 62 | outputs, z=z, y=y, is_training=is_training, name="bn2") 63 | if self._layer_norm: 64 | outputs = ops.layer_norm(outputs, is_training=is_training, scope="ln2") 65 | 66 | outputs = tf.nn.relu(outputs) 67 | outputs = self._get_conv( 68 | outputs, self._out_channels, self._out_channels, self._scale2, 69 | suffix="conv2") 70 | 71 | # Combine skip-connection with the convolved part. 72 | if self._add_shortcut: 73 | shortcut = self._get_conv( 74 | inputs, self._in_channels, self._out_channels, self._scale, 75 | kernel_size=(1, 1), 76 | suffix="conv_shortcut") 77 | outputs += shortcut 78 | logging.info("[Block] %s (z=%s, y=%s) -> %s", inputs.shape, 79 | None if z is None else z.shape, 80 | None if y is None else y.shape, outputs.shape) 81 | return outputs 82 | 83 | class Classifier_proD(object): 84 | """ResNet-based discriminator supporting resolutions 32, 64, 128, 256, 512.""" 85 | 86 | def __init__(self, 87 | name, num_class, use_sn, 88 | ch=96, 89 | blocks_with_attention="B1", 90 | project_y=True, 91 | **kwargs): 92 | """Constructor for BigGAN discriminator. 93 | Args: 94 | ch: Channel multiplier. 95 | blocks_with_attention: Comma-separated list of blocks that are followed by 96 | a non-local block. 97 | project_y: Add an embedding of y in the output layer. 98 | **kwargs: additional arguments past on to ResNetDiscriminator. 
99 | """ 100 | self.name = name 101 | self.num_class = num_class 102 | self._spectral_norm = True 103 | self._batch_norm_fn = None 104 | self._layer_norm = False 105 | 106 | self._ch = ch 107 | self._blocks_with_attention = set(blocks_with_attention.split(",")) 108 | self._project_y = project_y 109 | 110 | def _resnet_block(self, name, in_channels, out_channels, scale): 111 | """ResNet block for the generator.""" 112 | if scale not in ["down", "none"]: 113 | raise ValueError( 114 | "Unknown discriminator ResNet block scaling: {}.".format(scale)) 115 | return BigGanResNetBlock( 116 | name=name, 117 | in_channels=in_channels, 118 | out_channels=out_channels, 119 | scale=scale, 120 | is_gen_block=False, 121 | add_shortcut=in_channels != out_channels, 122 | layer_norm=self._layer_norm, 123 | spectral_norm=self._spectral_norm, 124 | batch_norm=self.batch_norm) 125 | 126 | def batch_norm(self, inputs, **kwargs): 127 | if self._batch_norm_fn is None: 128 | return inputs 129 | args = kwargs.copy() 130 | args["inputs"] = inputs 131 | if "use_sn" not in args: 132 | args["use_sn"] = self._spectral_norm 133 | return utils.call_with_accepted_args(self._batch_norm_fn, **args) 134 | 135 | def _get_in_out_channels(self, colors, resolution): 136 | if colors not in [1, 3]: 137 | raise ValueError("Unsupported color channels: {}".format(colors)) 138 | if resolution == 512: 139 | channel_multipliers = [1, 1, 2, 4, 8, 8, 16, 16] 140 | elif resolution == 256: 141 | channel_multipliers = [1, 2, 4, 8, 8, 16, 16] 142 | elif resolution == 128: 143 | channel_multipliers = [1, 2, 4, 8, 16, 16] 144 | elif resolution == 64: 145 | channel_multipliers = [2, 4, 8, 16, 16] 146 | elif resolution == 32: 147 | channel_multipliers = [2, 2, 2, 2] 148 | else: 149 | raise ValueError("Unsupported resolution: {}".format(resolution)) 150 | out_channels = [self._ch * c for c in channel_multipliers] 151 | in_channels = [colors] + out_channels[:-1] 152 | return in_channels, out_channels 153 | 154 | def __call__(self, x, y, _, __): 155 | with tf.variable_scope(self.name, values=[x, y], reuse=tf.AUTO_REUSE): 156 | """Apply the discriminator on a input. 157 | Args: 158 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images. 159 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 160 | labels. 161 | is_training: Boolean, whether the architecture should be constructed for 162 | training or inference. 163 | Returns: 164 | Tuple of 3 Tensors, the final prediction of the discriminator, the logits 165 | before the final output activation function and logits form the second 166 | last layer. 
167 | """ 168 | logging.info("[Discriminator] inputs are x=%s, y=%s", x.shape, 169 | None if y is None else y.shape) 170 | resnet_ops.validate_image_inputs(x) 171 | 172 | in_channels, out_channels = self._get_in_out_channels( 173 | colors=x.shape[-1].value, resolution=x.shape[1].value) 174 | num_blocks = len(in_channels) 175 | 176 | net = x 177 | for block_idx in range(num_blocks): 178 | name = "B{}".format(block_idx + 1) 179 | is_last_block = block_idx == num_blocks - 1 180 | block = self._resnet_block( 181 | name=name, 182 | in_channels=in_channels[block_idx], 183 | out_channels=out_channels[block_idx], 184 | scale="none" if is_last_block else "down") 185 | net = block(net, z=None, y=y, is_training=None) 186 | if name in self._blocks_with_attention: 187 | logging.info("[Discriminator] Applying non-local block to %s", 188 | net.shape) 189 | net = ops.non_local_block(net, "non_local_block", 190 | use_sn=self._spectral_norm) 191 | 192 | # Final part 193 | logging.info("[Discriminator] before final processing: %s", net.shape) 194 | net_conv = tf.nn.relu(net) 195 | h = tf.math.reduce_sum(net_conv, axis=[1, 2]) 196 | out_logit_tf = ops.linear(h, 1, scope="final_fc", use_sn=self._spectral_norm) 197 | logging.info("[Discriminator] after final processing: %s", net.shape) 198 | if self._project_y: 199 | if y is None: 200 | raise ValueError("You must provide class information y to project.") 201 | with tf.variable_scope("embedding_fc"): 202 | y_embedding_dim = out_channels[-1] 203 | # We do not use ops.linear() below since it does not have an option to 204 | # override the initializer. 205 | kernel = tf.get_variable( 206 | "kernel", [y.shape[1], y_embedding_dim], tf.float32, 207 | initializer=tf.initializers.glorot_normal()) 208 | if self._spectral_norm: 209 | kernel = ops.spectral_norm(kernel) 210 | embedded_y = tf.matmul(y, kernel) 211 | logging.info("[Discriminator] embedded_y for projection: %s", 212 | embedded_y.shape) 213 | out_logit_tf += tf.reduce_sum(embedded_y * h, axis=1, keepdims=True) 214 | 215 | feature_matching = h 216 | t_SNE = h 217 | out_logit = ops.linear(h, self.num_class, scope="final_fc_cla", use_sn=self._spectral_norm) 218 | 219 | # grad cam 220 | cls = tf.argmax(y,axis=1)[0] 221 | out_logit_cam = tf.identity(out_logit) 222 | y_c = out_logit_cam[0, cls] 223 | grads = tf.gradients(y_c, net_conv)[0] 224 | output_conv = net_conv[0] 225 | grads_val = grads[0] 226 | 227 | # grad cam ++ 228 | first = tf.exp(y_c)*grads 229 | second = tf.exp(y_c)*grads*grads 230 | third = tf.exp(y_c)*grads*grads*grads 231 | conv_first_grad, conv_second_grad, conv_third_grad = first[0], second[0], third[0] 232 | grad_val_plusplus = [conv_first_grad, conv_second_grad, conv_third_grad] 233 | 234 | # saliency 235 | signal = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=out_logit[0, :], labels=y[0])) 236 | guide_grad = tf.gradients(signal, x)[0] 237 | 238 | return tf.nn.softmax(out_logit), out_logit, tf.nn.sigmoid(out_logit_tf), out_logit_tf, feature_matching, t_SNE, output_conv, grads_val, guide_grad, grad_val_plusplus, tf.argmax(y,axis=1)[0], tf.argmax(out_logit_cam,axis=1)[0] 239 | 240 | 241 | -------------------------------------------------------------------------------- /BigGAN/resnet_gen_y_deep.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import functools 6 | 7 | from absl import logging 8 | 9 | import arch_ops as 
ops 10 | import resnet_ops 11 | 12 | import utils 13 | 14 | from six.moves import range 15 | import tensorflow as tf 16 | import time 17 | 18 | class BigGanDeepResNetBlock(object): 19 | """ResNet block with bottleneck and identity preserving skip connections.""" 20 | 21 | def __init__(self, 22 | name, 23 | in_channels, 24 | out_channels, 25 | scale, 26 | spectral_norm=False, 27 | batch_norm=None): 28 | """Constructs a new ResNet block with bottleneck. 29 | Args: 30 | name: Scope name for the ResNet block. 31 | in_channels: Integer, the input channel size. 32 | out_channels: Integer, the output channel size. 33 | scale: Whether or not to scale up or down, choose from "up", "down" or 34 | "none". 35 | spectral_norm: Use spectral normalization for all weights. 36 | batch_norm: Function for batch normalization. 37 | """ 38 | assert scale in ["up", "down", "none"] 39 | self._name = name 40 | self._in_channels = in_channels 41 | self._out_channels = out_channels 42 | self._scale = scale 43 | self._spectral_norm = spectral_norm 44 | self.batch_norm = batch_norm 45 | 46 | def __call__(self, inputs, z, y, is_training): 47 | return self.apply(inputs=inputs, z=z, y=y, is_training=is_training) 48 | 49 | def _shortcut(self, inputs): 50 | """Constructs a skip connection from inputs.""" 51 | with tf.variable_scope("shortcut", values=[inputs]): 52 | shortcut = inputs 53 | num_channels = inputs.shape[-1].value 54 | if num_channels > self._out_channels: 55 | assert self._scale == "up" 56 | # Drop redundant channels. 57 | logging.info("[Shortcut] Dropping %d channels in shortcut.", 58 | num_channels - self._out_channels) 59 | shortcut = shortcut[:, :, :, :self._out_channels] 60 | if self._scale == "up": 61 | shortcut = resnet_ops.unpool(shortcut) 62 | if self._scale == "down": 63 | shortcut = tf.nn.pool(shortcut, [2, 2], "AVG", "SAME", 64 | strides=[2, 2], name="pool") 65 | if num_channels < self._out_channels: 66 | assert self._scale == "down" 67 | # Increase number of channels if necessary. 68 | num_missing = self._out_channels - num_channels 69 | logging.info("[Shortcut] Adding %d channels in shortcut.", num_missing) 70 | added = ops.conv1x1(shortcut, num_missing, name="add_channels", 71 | use_sn=self._spectral_norm) 72 | shortcut = tf.concat([shortcut, added], axis=-1) 73 | return shortcut 74 | 75 | def apply(self, inputs, z, y, is_training): 76 | """ResNet block containing possible down/up sampling, shared for G / D. 77 | Args: 78 | inputs: a 3d input tensor of feature map. 79 | z: the latent vector for potential self-modulation. Can be None if use_sbn 80 | is set to False. 81 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 82 | labels. 83 | is_training: boolean, whether or not this is called during the training. 84 | Returns: 85 | output: a 3d output tensor of feature map. 86 | """
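# Illustrative channel walk-through (values assumed for this note, not taken from the repo config):
# with in_channels=256, out_channels=512 and scale="down", the bottleneck width below is
# max(256, 512) // 4 = 128, so the main path runs
# 256 -> 1x1 conv -> 128 -> 3x3 -> 128 -> 3x3 -> 128 -> avg-pool -> 1x1 -> 512,
# while the shortcut average-pools and concatenates 256 extra channels to match.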
87 | if inputs.shape[-1].value != self._in_channels: 88 | raise ValueError( 89 | "Unexpected number of input channels (expected {}, got {}).".format( 90 | self._in_channels, inputs.shape[-1].value)) 91 | 92 | bottleneck_channels = max(self._in_channels, self._out_channels) // 4 93 | bn = functools.partial(self.batch_norm, z=z, y=y, is_training=is_training) 94 | conv1x1 = functools.partial(ops.conv1x1, use_sn=self._spectral_norm) 95 | conv3x3 = functools.partial(ops.conv2d, k_h=3, k_w=3, d_h=1, d_w=1, 96 | use_sn=self._spectral_norm) 97 | 98 | with tf.variable_scope(self._name, values=[inputs]): 99 | outputs = inputs 100 | 101 | with tf.variable_scope("conv1", values=[outputs]): 102 | outputs = bn(outputs, name="bn") 103 | outputs = tf.nn.relu(outputs) 104 | outputs = conv1x1(outputs, bottleneck_channels, name="1x1_conv") 105 | 106 | with tf.variable_scope("conv2", values=[outputs]): 107 | outputs = bn(outputs, name="bn") 108 | outputs = tf.nn.relu(outputs) 109 | if self._scale == "up": 110 | outputs = resnet_ops.unpool(outputs) 111 | outputs = conv3x3(outputs, bottleneck_channels, name="3x3_conv") 112 | 113 | with tf.variable_scope("conv3", values=[outputs]): 114 | outputs = bn(outputs, name="bn") 115 | outputs = tf.nn.relu(outputs) 116 | outputs = conv3x3(outputs, bottleneck_channels, name="3x3_conv") 117 | 118 | with tf.variable_scope("conv4", values=[outputs]): 119 | outputs = bn(outputs, name="bn") 120 | outputs = tf.nn.relu(outputs) 121 | if self._scale == "down": 122 | outputs = tf.nn.pool(outputs, [2, 2], "AVG", "SAME", strides=[2, 2], 123 | name="avg_pool") 124 | outputs = conv1x1(outputs, self._out_channels, name="1x1_conv") 125 | 126 | # Add skip-connection. 127 | outputs += self._shortcut(inputs) 128 | 129 | logging.info("[Block] %s (z=%s, y=%s) -> %s", inputs.shape, 130 | None if z is None else z.shape, 131 | None if y is None else y.shape, outputs.shape) 132 | return outputs 133 | 134 | 135 | class Generator(object): 136 | """ResNet-based generator supporting resolutions 32, 64, 128, 256, 512.""" 137 | 138 | def __init__(self, 139 | name, h, w, c, is_train, use_sn, 140 | ch=128, 141 | embed_y=True, 142 | embed_y_dim=128, 143 | experimental_fast_conv_to_rgb=False, 144 | **kwargs): 145 | """Constructor for BigGAN generator. 146 | Args: 147 | ch: Channel multiplier. 148 | embed_y: If True, use a learnable embedding of y instead of the raw one-hot labels. 149 | embed_y_dim: Size of the embedding of y. 150 | experimental_fast_conv_to_rgb: If True, optimize the last convolution to 151 | sacrifice memory for better speed. 152 | **kwargs: additional arguments passed on to ResNetGenerator.
153 | """ 154 | self.name = name 155 | self.s_h, self.s_w, self.colors = [h,w,c] 156 | self._image_shape = [h,w,c] 157 | self._is_train = is_train 158 | self._batch_norm_fn = ops.conditional_batch_norm 159 | self._spectral_norm = False 160 | 161 | self._ch = ch 162 | self._embed_y = embed_y 163 | self._embed_y_dim = embed_y_dim 164 | self._experimental_fast_conv_to_rgb = experimental_fast_conv_to_rgb 165 | 166 | def _resnet_block(self, name, in_channels, out_channels, scale): 167 | """ResNet block for the generator.""" 168 | if scale not in ["up", "none"]: 169 | raise ValueError( 170 | "Unknown generator ResNet block scaling: {}.".format(scale)) 171 | return BigGanDeepResNetBlock( 172 | name=name, 173 | in_channels=in_channels, 174 | out_channels=out_channels, 175 | scale=scale, 176 | spectral_norm=self._spectral_norm, 177 | batch_norm=self.batch_norm) 178 | 179 | def _get_in_out_channels(self): 180 | # See Table 7-9. 181 | resolution = self._image_shape[0] 182 | if resolution == 512: 183 | channel_multipliers = 4 * [16] + 4 * [8] + [4, 4, 2, 2, 1, 1, 1] 184 | elif resolution == 256: 185 | channel_multipliers = 4 * [16] + 4 * [8] + [4, 4, 2, 2, 1] 186 | elif resolution == 128: 187 | channel_multipliers = 4 * [16] + 2 * [8] + [4, 4, 2, 2, 1] 188 | elif resolution == 64: 189 | channel_multipliers = 4 * [16] + 2 * [8] + [4, 4, 2] 190 | elif resolution == 32: 191 | channel_multipliers = 8 * [4] 192 | else: 193 | raise ValueError("Unsupported resolution: {}".format(resolution)) 194 | in_channels = [self._ch * c for c in channel_multipliers[:-1]] 195 | out_channels = [self._ch * c for c in channel_multipliers[1:]] 196 | return in_channels, out_channels 197 | 198 | def batch_norm(self, inputs, **kwargs): 199 | if self._batch_norm_fn is None: 200 | return inputs 201 | args = kwargs.copy() 202 | args["inputs"] = inputs 203 | if "use_sn" not in args: 204 | args["use_sn"] = self._spectral_norm 205 | return utils.call_with_accepted_args(self._batch_norm_fn, **args) 206 | 207 | def __call__(self, z, y): 208 | with tf.variable_scope(self.name, values=[z, y], reuse=tf.AUTO_REUSE): 209 | """Build the generator network for the given inputs. 210 | Args: 211 | z: `Tensor` of shape [batch_size, z_dim] with latent code. 212 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 213 | labels. 214 | is_training: boolean, are we in train or eval mode. 215 | Returns: 216 | A tensor of size [batch_size] + self._image_shape with values in [0, 1]. 217 | """ 218 | shape_or_none = lambda t: None if t is None else t.shape 219 | logging.info("[Generator] inputs are z=%s, y=%s", z.shape, shape_or_none(y)) 220 | seed_size = 4 221 | 222 | if self._embed_y: 223 | y = ops.linear(y, self._embed_y_dim, scope="embed_y", use_sn=False, 224 | use_bias=False) 225 | if y is not None: 226 | y = tf.concat([z, y], axis=1) 227 | z = y 228 | 229 | in_channels, out_channels = self._get_in_out_channels() 230 | num_blocks = len(in_channels) 231 | 232 | # Map noise to the actual seed. 233 | net = ops.linear( 234 | z, 235 | in_channels[0] * seed_size * seed_size, 236 | scope="fc_noise", 237 | use_sn=self._spectral_norm)
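# Illustrative shape walk-through (assuming the defaults ch=128 and a 128x128 output;
# channel counts follow _get_in_out_channels above): fc_noise maps the concatenated
# (z, embedded y) vector to 4 * 4 * 2048 units, the reshape below yields a
# [batch, 4, 4, 2048] seed, and the ten blocks then alternate scale "none"/"up",
# doubling the spatial size five times: 4 -> 8 -> 16 -> 32 -> 64 -> 128.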
238 | # Reshape the seed to be a rank-4 Tensor. 239 | net = tf.reshape( 240 | net, 241 | [-1, seed_size, seed_size, in_channels[0]], 242 | name="fc_reshaped") 243 | 244 | for block_idx in range(num_blocks): 245 | scale = "none" if block_idx % 2 == 0 else "up" 246 | block = self._resnet_block( 247 | name="B{}".format(block_idx + 1), 248 | in_channels=in_channels[block_idx], 249 | out_channels=out_channels[block_idx], 250 | scale=scale) 251 | net = block(net, z=z, y=y, is_training=(self._is_train)) 252 | # At resolution 64x64 there is a self-attention block. 253 | if scale == "up" and net.shape[1].value == 64: 254 | logging.info("[Generator] Applying non-local block to %s", net.shape) 255 | net = ops.non_local_block(net, "non_local_block", 256 | use_sn=self._spectral_norm) 257 | # Final processing of the net. 258 | # Use unconditional batch norm. 259 | logging.info("[Generator] before final processing: %s", net.shape) 260 | net = ops.batch_norm(net, is_training=(self._is_train), name="final_norm") 261 | net = tf.nn.relu(net) 262 | colors = self._image_shape[2] 263 | if self._experimental_fast_conv_to_rgb: 264 | 265 | net = ops.conv2d(net, output_dim=128, k_h=3, k_w=3, 266 | d_h=1, d_w=1, name="final_conv", 267 | use_sn=self._spectral_norm) 268 | net = net[:, :, :, :colors] 269 | else: 270 | net = ops.conv2d(net, output_dim=colors, k_h=3, k_w=3, 271 | d_h=1, d_w=1, name="final_conv", 272 | use_sn=self._spectral_norm) 273 | logging.info("[Generator] after final processing: %s", net.shape) 274 | net = (tf.nn.tanh(net) + 1.0) / 2.0 275 | return net -------------------------------------------------------------------------------- /BigGAN/resnet_joint_deep.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import functools 6 | 7 | from absl import logging 8 | 9 | import arch_ops as ops 10 | import resnet_ops 11 | 12 | import utils 13 | 14 | from six.moves import range 15 | import tensorflow as tf 16 | import time 17 | 18 | class BigGanDeepResNetBlock(object): 19 | """ResNet block with bottleneck and identity preserving skip connections.""" 20 | 21 | def __init__(self, 22 | name, 23 | in_channels, 24 | out_channels, 25 | scale, 26 | spectral_norm=False, 27 | batch_norm=None): 28 | """Constructs a new ResNet block with bottleneck. 29 | Args: 30 | name: Scope name for the ResNet block. 31 | in_channels: Integer, the input channel size. 32 | out_channels: Integer, the output channel size. 33 | scale: Whether or not to scale up or down, choose from "up", "down" or 34 | "none". 35 | spectral_norm: Use spectral normalization for all weights. 36 | batch_norm: Function for batch normalization. 37 | """ 38 | assert scale in ["up", "down", "none"] 39 | self._name = name 40 | self._in_channels = in_channels 41 | self._out_channels = out_channels 42 | self._scale = scale 43 | self._spectral_norm = spectral_norm 44 | self.batch_norm = batch_norm 45 | 46 | def __call__(self, inputs, z, y, is_training): 47 | return self.apply(inputs=inputs, z=z, y=y, is_training=is_training) 48 | 49 | def _shortcut(self, inputs): 50 | """Constructs a skip connection from inputs.""" 51 | with tf.variable_scope("shortcut", values=[inputs]): 52 | shortcut = inputs 53 | num_channels = inputs.shape[-1].value 54 | if num_channels > self._out_channels: 55 | assert self._scale == "up" 56 | # Drop redundant channels.
57 | logging.info("[Shortcut] Dropping %d channels in shortcut.", 58 | num_channels - self._out_channels) 59 | shortcut = shortcut[:, :, :, :self._out_channels] 60 | if self._scale == "up": 61 | shortcut = resnet_ops.unpool(shortcut) 62 | if self._scale == "down": 63 | shortcut = tf.nn.pool(shortcut, [2, 2], "AVG", "SAME", 64 | strides=[2, 2], name="pool") 65 | if num_channels < self._out_channels: 66 | assert self._scale == "down" 67 | # Increase number of channels if necessary. 68 | num_missing = self._out_channels - num_channels 69 | logging.info("[Shortcut] Adding %d channels in shortcut.", num_missing) 70 | added = ops.conv1x1(shortcut, num_missing, name="add_channels", 71 | use_sn=self._spectral_norm) 72 | shortcut = tf.concat([shortcut, added], axis=-1) 73 | return shortcut 74 | 75 | def apply(self, inputs, z, y, is_training): 76 | """ResNet block containing possible down/up sampling, shared for G / D. 77 | Args: 78 | inputs: a 3d input tensor of feature map. 79 | z: the latent vector for potential self-modulation. Can be None if use_sbn 80 | is set to False. 81 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 82 | labels. 83 | is_training: boolean, whether or not this is called during the training. 84 | Returns: 85 | output: a 3d output tensor of feature map. 86 | """ 87 | if inputs.shape[-1].value != self._in_channels: 88 | raise ValueError( 89 | "Unexpected number of input channels (expected {}, got {}).".format( 90 | self._in_channels, inputs.shape[-1].value)) 91 | 92 | bottleneck_channels = max(self._in_channels, self._out_channels) // 4 93 | bn = functools.partial(self.batch_norm, z=z, y=y, is_training=is_training) 94 | conv1x1 = functools.partial(ops.conv1x1, use_sn=self._spectral_norm) 95 | conv3x3 = functools.partial(ops.conv2d, k_h=3, k_w=3, d_h=1, d_w=1, 96 | use_sn=self._spectral_norm) 97 | 98 | with tf.variable_scope(self._name, values=[inputs]): 99 | outputs = inputs 100 | 101 | with tf.variable_scope("conv1", values=[outputs]): 102 | outputs = bn(outputs, name="bn") 103 | outputs = tf.nn.relu(outputs) 104 | outputs = conv1x1(outputs, bottleneck_channels, name="1x1_conv") 105 | 106 | with tf.variable_scope("conv2", values=[outputs]): 107 | outputs = bn(outputs, name="bn") 108 | outputs = tf.nn.relu(outputs) 109 | if self._scale == "up": 110 | outputs = resnet_ops.unpool(outputs) 111 | outputs = conv3x3(outputs, bottleneck_channels, name="3x3_conv") 112 | 113 | with tf.variable_scope("conv3", values=[outputs]): 114 | outputs = bn(outputs, name="bn") 115 | outputs = tf.nn.relu(outputs) 116 | outputs = conv3x3(outputs, bottleneck_channels, name="3x3_conv") 117 | 118 | with tf.variable_scope("conv4", values=[outputs]): 119 | outputs = bn(outputs, name="bn") 120 | outputs = tf.nn.relu(outputs) 121 | if self._scale == "down": 122 | outputs = tf.nn.pool(outputs, [2, 2], "AVG", "SAME", strides=[2, 2], 123 | name="avg_pool") 124 | outputs = conv1x1(outputs, self._out_channels, name="1x1_conv") 125 | 126 | # Add skip-connection.
127 | outputs += self._shortcut(inputs) 128 | 129 | logging.info("[Block] %s (z=%s, y=%s) -> %s", inputs.shape, 130 | None if z is None else z.shape, 131 | None if y is None else y.shape, outputs.shape) 132 | return outputs 133 | 134 | 135 | class Classifier_proD(object): 136 | """ResNet-based discriminator supporting resolutions 32, 64, 128, 256, 512.""" 137 | 138 | def __init__(self, 139 | name, num_class, use_sn, 140 | ch=128, 141 | blocks_with_attention="B1", 142 | project_y=True, 143 | **kwargs): 144 | """Constructor for BigGAN discriminator. 145 | Args: 146 | ch: Channel multiplier. 147 | blocks_with_attention: Comma-separated list of blocks that are followed by 148 | a non-local block. 149 | project_y: Add an embedding of y in the output layer. 150 | **kwargs: additional arguments passed on to ResNetDiscriminator. 151 | """ 152 | self.name = name 153 | self.num_class = num_class 154 | self._spectral_norm = True 155 | self._batch_norm_fn = None 156 | 157 | self._ch = ch 158 | self._blocks_with_attention = set(blocks_with_attention.split(",")) 159 | self._project_y = project_y 160 | 161 | def _resnet_block(self, name, in_channels, out_channels, scale): 162 | """ResNet block for the discriminator.""" 163 | if scale not in ["down", "none"]: 164 | raise ValueError( 165 | "Unknown discriminator ResNet block scaling: {}.".format(scale)) 166 | return BigGanDeepResNetBlock( 167 | name=name, 168 | in_channels=in_channels, 169 | out_channels=out_channels, 170 | scale=scale, 171 | spectral_norm=self._spectral_norm, 172 | batch_norm=self.batch_norm) 173 | 174 | def _get_in_out_channels(self, colors, resolution): 175 | # See Table 7-9. 176 | if colors not in [1, 3]: 177 | raise ValueError("Unsupported color channels: {}".format(colors)) 178 | if resolution == 512: 179 | channel_multipliers = [1, 1, 1, 2, 2, 4, 4] + 4 * [8] + 4 * [16] 180 | elif resolution == 256: 181 | channel_multipliers = [1, 2, 2, 4, 4] + 4 * [8] + 4 * [16] 182 | elif resolution == 128: 183 | channel_multipliers = [1, 2, 2, 4, 4] + 2 * [8] + 4 * [16] 184 | elif resolution == 64: 185 | channel_multipliers = [2, 4, 4] + 2 * [8] + 4 * [16] 186 | elif resolution == 32: 187 | channel_multipliers = 8 * [2] 188 | else: 189 | raise ValueError("Unsupported resolution: {}".format(resolution)) 190 | in_channels = [self._ch * c for c in channel_multipliers[:-1]] 191 | out_channels = [self._ch * c for c in channel_multipliers[1:]] 192 | return in_channels, out_channels 193 | 194 | def batch_norm(self, inputs, **kwargs): 195 | if self._batch_norm_fn is None: 196 | return inputs 197 | args = kwargs.copy() 198 | args["inputs"] = inputs 199 | if "use_sn" not in args: 200 | args["use_sn"] = self._spectral_norm 201 | return utils.call_with_accepted_args(self._batch_norm_fn, **args) 202 | 203 | def __call__(self, x, y, _, __): 204 | with tf.variable_scope(self.name, values=[x, y], reuse=tf.AUTO_REUSE): 205 | """Apply the discriminator on an input. 206 | Args: 207 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images. 208 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 209 | labels. 210 | is_training: Boolean, whether the architecture should be constructed for 211 | training or inference. 212 | Returns: 213 | Tuple of tensors: the class prediction and its logits, the real/fake 214 | prediction and its logits, features for feature matching and t-SNE, and 215 | the tensors needed for Grad-CAM, Grad-CAM++ and saliency maps.
216 | """ 217 | logging.info("[Discriminator] inputs are x=%s, y=%s", x.shape, 218 | None if y is None else y.shape) 219 | resnet_ops.validate_image_inputs(x) 220 | 221 | in_channels, out_channels = self._get_in_out_channels( 222 | colors=x.shape[-1].value, resolution=x.shape[1].value) 223 | num_blocks = len(in_channels) 224 | 225 | net = ops.conv2d(x, output_dim=in_channels[0], k_h=3, k_w=3, 226 | d_h=1, d_w=1, name="initial_conv", 227 | use_sn=self._spectral_norm) 228 | 229 | for block_idx in range(num_blocks): 230 | scale = "down" if block_idx % 2 == 0 else "none" 231 | block = self._resnet_block( 232 | name="B{}".format(block_idx + 1), 233 | in_channels=in_channels[block_idx], 234 | out_channels=out_channels[block_idx], 235 | scale=scale) 236 | net = block(net, z=None, y=y, is_training=None) 237 | # At resolution 64x64 there is a self-attention block. 238 | if scale == "none" and net.shape[1].value == 64: 239 | logging.info("[Discriminator] Applying non-local block to %s", 240 | net.shape) 241 | net = ops.non_local_block(net, "non_local_block", 242 | use_sn=self._spectral_norm) 243 | 244 | # Final part 245 | logging.info("[Discriminator] before final processing: %s", net.shape) 246 | net_conv = tf.nn.relu(net) 247 | h = tf.math.reduce_sum(net_conv, axis=[1, 2]) 248 | out_logit_tf = ops.linear(h, 1, scope="final_fc", use_sn=self._spectral_norm) 249 | logging.info("[Discriminator] after final processing: %s", net.shape) 250 | if self._project_y: 251 | if y is None: 252 | raise ValueError("You must provide class information y to project.") 253 | with tf.variable_scope("embedding_fc"): 254 | y_embedding_dim = out_channels[-1] 255 | # We do not use ops.linear() below since it does not have an option to 256 | # override the initializer. 257 | kernel = tf.get_variable( 258 | "kernel", [y.shape[1], y_embedding_dim], tf.float32, 259 | initializer=tf.initializers.glorot_normal()) 260 | if self._spectral_norm: 261 | kernel = ops.spectral_norm(kernel) 262 | embedded_y = tf.matmul(y, kernel) 263 | logging.info("[Discriminator] embedded_y for projection: %s", 264 | embedded_y.shape) 265 | out_logit_tf += tf.reduce_sum(embedded_y * h, axis=1, keepdims=True) 266 | 267 | feature_matching = h 268 | t_SNE = h 269 | out_logit = ops.linear(h, self.num_class, scope="final_fc_cla", use_sn=self._spectral_norm) 270 | 271 | # grad cam 272 | cls = tf.argmax(y,axis=1)[0] 273 | out_logit_cam = tf.identity(out_logit) 274 | y_c = out_logit_cam[0, cls] 275 | grads = tf.gradients(y_c, net_conv)[0] 276 | output_conv = net_conv[0] 277 | grads_val = grads[0] 278 | 279 | # grad cam ++ 280 | first = tf.exp(y_c)*grads 281 | second = tf.exp(y_c)*grads*grads 282 | third = tf.exp(y_c)*grads*grads*grads 283 | conv_first_grad, conv_second_grad, conv_third_grad = first[0], second[0], third[0] 284 | grad_val_plusplus = [conv_first_grad, conv_second_grad, conv_third_grad] 285 | 286 | # saliency 287 | signal = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=out_logit[0, :], labels=y[0])) 288 | guide_grad = tf.gradients(signal, x)[0] 289 | 290 | return tf.nn.softmax(out_logit), out_logit, tf.nn.sigmoid(out_logit_tf), out_logit_tf, feature_matching, t_SNE, output_conv, grads_val, guide_grad, grad_val_plusplus, tf.argmax(y,axis=1)[0], tf.argmax(out_logit_cam,axis=1)[0] -------------------------------------------------------------------------------- /trainer_joint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import 
division 3 | from __future__ import print_function 4 | 5 | import os 6 | import time 7 | from six.moves import xrange 8 | from pprint import pprint 9 | import h5py 10 | import tensorflow as tf 11 | import numpy as np 12 | import tensorflow.contrib.slim as slim 13 | 14 | from input_ops import create_input_ops 15 | from util import log 16 | from config import argparser, get_params 17 | from tensorflow.contrib.tensorboard.plugins import projector 18 | import cv2 19 | import time 20 | from tensorflow.python.framework import ops 21 | from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, roc_auc_score, f1_score, fbeta_score, cohen_kappa_score 22 | 23 | def sigmoid(x, a, b, c): 24 | return c / (1 + np.exp(-a * (x-b))) 25 | 26 | class Trainer(object): 27 | 28 | def __init__(self, config, model, dataset_train, dataset_train_unlabel, dataset_test): 29 | self.config = config 30 | self.model = model 31 | hyper_parameter_str = '{}_lr_g_{}_d_{}_update_G{}D{}'.format( 32 | config["dataset"], config["learning_rate_g"], config["learning_rate_d"], 33 | config["update_rate"], 1 34 | ) 35 | self.train_dir = './train_dir/TMI/%s-%s-%s' % ( 36 | config["prefix"], 37 | hyper_parameter_str, 38 | time.strftime("%Y%m%d-%H%M%S") 39 | ) 40 | 41 | os.makedirs(self.train_dir) 42 | log.infov("Train Dir: %s", self.train_dir) 43 | 44 | self.batch_size = config["batch_size_L"] 45 | 46 | # --- input ops --- 47 | self.batch_train = create_input_ops( 48 | dataset_train, config["batch_size_L"]) 49 | 50 | self.batch_train_unlabel = create_input_ops( 51 | dataset_train_unlabel, config["batch_size_U"] * 2) 52 | 53 | self.batch_test = create_input_ops( 54 | dataset_test, config["batch_size_L"]) 55 | 56 | # --- optimizer --- 57 | self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None) 58 | # --- checkpoint and monitoring --- 59 | all_var = tf.trainable_variables() 60 | 61 | gG_var = [v for v in all_var if v.name.startswith(('Generator_g'))] 62 | log.warning("********* gG_var ********** ") 63 | slim.model_analyzer.analyze_vars(gG_var, print_info=True) 64 | 65 | bG_var = [v for v in all_var if v.name.startswith(('Generator_b'))] 66 | log.warning("********* bG_var ********** ") 67 | slim.model_analyzer.analyze_vars(bG_var, print_info=True) 68 | 69 | c_var = [v for v in all_var if v.name.startswith(('Classifier'))] 70 | log.warning("********* c_var ********** ") 71 | slim.model_analyzer.analyze_vars(c_var, print_info=True) 72 | 73 | rem_var = (set(all_var) - set(gG_var) - set(bG_var) - set(c_var)) 74 | log.error([v.name for v in rem_var]) 75 | assert not rem_var 76 | 77 | self.gG_optimizer = tf.contrib.layers.optimize_loss( 78 | loss=self.model.g_loss_good, 79 | global_step=self.global_step, 80 | learning_rate=config["learning_rate_g"], 81 | optimizer=tf.train.AdamOptimizer(beta1=0.5,beta2=0.999), 82 | clip_gradients=20.0, 83 | name='gG_optimize_loss', 84 | variables=gG_var 85 | ) 86 | 87 | self.bG_optimizer = tf.contrib.layers.optimize_loss( 88 | loss=self.model.g_loss_bad, 89 | global_step=self.global_step, 90 | learning_rate=config["learning_rate_g"], 91 | optimizer=tf.train.AdamOptimizer(beta1=0.5,beta2=0.999), 92 | clip_gradients=20.0, 93 | name='bG_optimize_loss', 94 | variables=bG_var 95 | ) 96 | 97 | self.c_optimizer = tf.contrib.layers.optimize_loss( 98 | loss=self.model.c_loss + self.model.d_loss, 99 | global_step=self.global_step, 100 | learning_rate=config["learning_rate_d"], 101 | optimizer=tf.train.AdamOptimizer(beta1=0.5,beta2=0.999), 102 | clip_gradients=20.0, 
103 | name='c_optimize_loss', 104 | variables=c_var 105 | ) 106 | 107 | self.c_optimizer_only = tf.contrib.layers.optimize_loss( 108 | loss=self.model.c_loss, 109 | global_step=self.global_step, 110 | learning_rate=config["learning_rate_d"], 111 | optimizer=tf.train.AdamOptimizer(beta1=0.5,beta2=0.999), 112 | clip_gradients=20.0, 113 | name='c_optimize_loss_only', 114 | variables=c_var 115 | ) 116 | 117 | self.summary_op = tf.summary.merge_all() 118 | self.saver = tf.train.Saver(max_to_keep=1000) 119 | self.summary_writer = tf.summary.FileWriter(self.train_dir) 120 | 121 | self.supervisor = tf.train.Supervisor( 122 | logdir=self.train_dir, 123 | is_chief=True, 124 | saver=None, 125 | summary_op=None, 126 | summary_writer=self.summary_writer, 127 | save_summaries_secs=300, 128 | save_model_secs=600, 129 | global_step=self.global_step, 130 | ) 131 | 132 | session_config = tf.ConfigProto( 133 | allow_soft_placement=True, 134 | gpu_options=tf.GPUOptions(allow_growth=True, per_process_gpu_memory_fraction = 0.99), 135 | device_count={'GPU': 1} 136 | ) 137 | self.session = self.supervisor.prepare_or_wait_for_session(config=session_config) 138 | 139 | self.ckpt_path = config["checkpoint"] 140 | if self.ckpt_path is not None: 141 | log.info("Checkpoint path: %s", self.ckpt_path) 142 | self.saver.restore(self.session, self.ckpt_path) 143 | log.info("Loaded the pretrained parameters from the provided checkpoint path") 144 | 145 | def train(self): 146 | log.infov("Training Starts!") 147 | log.infov(self.batch_train) 148 | log.infov(self.batch_train_unlabel) 149 | step = self.session.run(self.global_step) 150 | 151 | for s in xrange(self.config["max_training_steps"]): 152 | step, accuracy, d_loss, g_loss, step_time, prediction_train, gt_train, Image_Real = \ 153 | self.run_single_step(self.batch_train, self.batch_train_unlabel, step=s, is_train = True) 154 | 155 | if s % self.config["log_step"] == self.config["log_step"] - 1: 156 | self.log_step_message(step + 1, accuracy, d_loss, g_loss, step_time, is_train=True) 157 | 158 | # periodic inference 159 | if s % self.config["test_sample_step"] == self.config["test_sample_step"] - 1: 160 | 161 | accuracy_mean, recall, precision, f1, kappa = [], [], [], [], [] 162 | 163 | accuracy, summary, d_loss, g_loss, step_time, prediction_test, gt_test = \ 164 | self.run_test(self.batch_test, self.batch_train_unlabel, is_train=False, step=s) 165 | self.log_step_message(step + 1, accuracy, d_loss, g_loss, 166 | step_time, is_train=False) 167 | 168 | self.summary_writer.add_summary(summary, global_step=step + 1) 169 | 170 | y_true = np.argmax(gt_test, axis=1) 171 | y_pred = np.argmax(prediction_test, axis=1) 172 | 173 | accuracy_mean.append(accuracy_score(y_true, y_pred)) 174 | recall.append(recall_score(y_true, y_pred, average="macro")) 175 | precision.append(precision_score(y_true, y_pred, average="macro")) 176 | f1.append(f1_score(y_true, y_pred, average="macro")) 177 | kappa.append(cohen_kappa_score(y_true, y_pred)) 178 | 179 | 180 | if s % self.config["output_save_step"] == self.config["output_save_step"] - 1 and s != self.config["max_training_steps"]-1: 181 | log.infov("Saved checkpoint at %d", step + 1) 182 | self.saver.save(self.session, os.path.join(self.train_dir, 'model'), global_step=step + 1) 183 | 184 | if s == self.config["max_training_steps"] - 1: 185 | log.infov("Saved checkpoint at %d", step + 1) 186 | self.saver.save(self.session, os.path.join(self.train_dir, 'model'), global_step=step + 1) 187 | 188 | def run_single_step(self, batch, batch_unlabel, step=None, is_train=True):
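# Update schedule illustration (an assumed example, not a value read from the repo config:
# with update_rate=2 the cycle length below is update_rate + 1 = 3):
# steps 0, 3, 6, ... run c_optimizer (classifier + projection discriminator),
# steps 1, 4, 7, ... run gG_optimizer (the conditional "good" generator), and
# steps 2, 5, 8, ... run bG_optimizer (the "bad" generator).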
189 | 190 | _start_time = time.time() 191 | 192 | batch_chunk = self.session.run(batch) 193 | batch_chunk_unlabel = self.session.run(batch_unlabel) 194 | 195 | z = np.random.uniform(low = -1.0, high = 1.0, size=(self.config["batch_size_G"], self.config["n_z"])).astype(np.float32) 196 | z_tmp = np.random.uniform(low = -1.0, high = 1.0, size=(self.config["n_z"])).astype(np.float32) 197 | z_linspace = np.linspace(z[0], z_tmp, 10) 198 | y_temp = np.random.randint(low = 0, high = self.config["num_class"], size = (self.config["batch_size_G"])) 199 | y = np.zeros((self.config["batch_size_G"], self.config["num_class"])) 200 | y[np.arange(self.config["batch_size_G"]), y_temp] = 1 201 | 202 | fetch = [self.global_step, self.model.accuracy, 203 | self.model.d_loss, self.model.g_loss, 204 | self.model.x_l_ph, 205 | self.model.all_preds, self.model.all_targets] 206 | 207 | if step % (self.config["update_rate"]+1) == 0: 208 | # Train the classifier / projection discriminator 209 | fetch.append(self.c_optimizer) 210 | elif step % (self.config["update_rate"]+1) == 1: 211 | # Train the good (conditional) generator 212 | fetch.append(self.gG_optimizer) 213 | elif step % (self.config["update_rate"]+1) == 2: 214 | fetch.append(self.bG_optimizer)  # train the bad generator 215 | 216 | fetch_values = self.session.run(fetch, 217 | feed_dict = self.model.get_feed_dict_withunlabel(batch_chunk, batch_chunk_unlabel, z, z_linspace, y, step=step, is_training = is_train)) 218 | # log.error(fetch_values[8]) 219 | [step, accuracy, d_loss, g_loss, Image_Real,\ 220 | all_preds, all_targets] = fetch_values[:7] 221 | 222 | _end_time = time.time() 223 | 224 | return step, accuracy, d_loss, g_loss, \ 225 | (_end_time - _start_time), all_preds, all_targets, Image_Real 226 | 227 | 228 | def run_test(self, batch, batch_unlabel, is_train=False, step=None): 229 | 230 | _start_time = time.time() 231 | 232 | batch_chunk = self.session.run(batch) 233 | batch_chunk_unlabel = self.session.run(batch_unlabel) 234 | 235 | z = np.random.uniform(low = -1.0, high = 1.0, size=(self.config["batch_size_G"], self.config["n_z"])).astype(np.float32) 236 | z_tmp = np.random.uniform(low = -1.0, high = 1.0, size=(self.config["n_z"])).astype(np.float32) 237 | z_linspace = np.linspace(z[0], z_tmp, 10) 238 | y_temp = np.random.randint(low = 0, high = self.config["num_class"], size = (self.config["batch_size_G"])) 239 | y = np.zeros((self.config["batch_size_G"], self.config["num_class"])) 240 | y[np.arange(self.config["batch_size_G"]), y_temp] = 1 241 | 242 | [accuracy, summary, d_loss, g_loss, all_preds, all_targets] = self.session.run( 243 | [self.model.accuracy, self.summary_op, self.model.d_loss, 244 | self.model.g_loss, self.model.all_preds, self.model.all_targets], 245 | feed_dict=self.model.get_feed_dict_withunlabel(batch_chunk, batch_chunk_unlabel, z, z_linspace, y, step=step, is_training = is_train)) 246 | 247 | _end_time = time.time() 248 | 249 | return accuracy, summary, d_loss, g_loss, (_end_time - _start_time), all_preds, all_targets 250 | 251 | def log_step_message(self, step, accuracy, d_loss, g_loss, 252 | step_time, is_train=True): 253 | if step_time == 0: step_time = 0.001 254 | log_fn = (is_train and log.info or log.infov) 255 | log_fn((" [{split_mode:5s} step {step:4d}] " + 256 | "D loss: {d_loss:.5f} " + 257 | "G loss: {g_loss:.5f} " + 258 | "Accuracy: {accuracy:.5f} " 259 | "({sec_per_batch:.3f} sec/batch, {instance_per_sec:.3f} instances/sec) " 260 | ).format(split_mode=(is_train and 'train' or 'val'), 261 | step = step, 262 | d_loss = d_loss, 263 | g_loss = g_loss, 264 |
accuracy = accuracy, 265 | sec_per_batch = step_time, 266 | instance_per_sec = self.batch_size / step_time) 267 | ) 268 | 269 | def main(): 270 | try: 271 | params, args = get_params() 272 | params = vars(params) 273 | 274 | config, model, dataset_train, dataset_train_unlabel, dataset_val, dataset_test = argparser(params, is_train=True) 275 | trainer = Trainer(config, model, dataset_train, dataset_train_unlabel, dataset_val)  # the validation split serves as the periodic test set during training 276 | 277 | log.info("dataset: %s, learning_rate_g: %f, learning_rate_d: %f", 278 | config["dataset"], config["learning_rate_g"], config["learning_rate_d"]) 279 | 280 | trainer.train() 281 | 282 | except Exception as exception: 283 | log.exception(exception) 284 | raise 285 | 286 | if __name__ == '__main__': 287 | 288 | main() 289 | 290 | -------------------------------------------------------------------------------- /model_joint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | from util import log 8 | from sngan_gen import Generator 9 | from sngan_gen_y import Generator as gGenerator 10 | from sngan_joint import Classifier_proD 11 | import os 12 | import time 13 | import numpy as np 14 | from diffAugment import DiffAugment 15 | 16 | os.environ["CUDA_VISIBLE_DEVICES"] = '2' 17 | 18 | class Model(object): 19 | 20 | def __init__(self, config, 21 | debug_information=False, 22 | is_train=True): 23 | self.debug = debug_information 24 | 25 | self.config = config 26 | self.batch_size_G = self.config["batch_size_G"] 27 | self.batch_size_L = self.config["batch_size_L"] 28 | self.batch_size_U = self.config["batch_size_U"] 29 | 30 | self.h = self.config["h"] 31 | self.w = self.config["w"] 32 | self.c = self.config["c"] 33 | self.IMAGE_DIM = [self.h, self.w, self.c] 34 | 35 | self.len = self.config["len"] 36 | self.num_class = self.config["num_class"] 37 | self.n_z = self.config["n_z"] 38 | 39 | self.is_training = tf.placeholder_with_default(bool(is_train), [], name='is_training') 40 | 41 | self.z_g_ph = tf.placeholder(tf.float32, [self.batch_size_G, self.n_z], name = 'latent_variable') # latent variable 42 | self.z_g_ph_linspace = tf.placeholder(tf.float32, [10, self.n_z], name = 'latent_variable_linspace') 43 | self.y_g_ph = tf.placeholder(tf.float32, [self.batch_size_G, self.num_class], name='condition_label') 44 | 45 | self.x_l_ph = tf.placeholder(tf.float32, [self.batch_size_L] + self.IMAGE_DIM, name='labeled_images') 46 | self.y_l_ph = tf.placeholder(tf.float32, [self.batch_size_L, self.num_class], name='real_label') 47 | 48 | self.x_u_ph = tf.placeholder(tf.float32, [self.batch_size_U] + self.IMAGE_DIM, name='unlabeled_images') 49 | self.x_u_c_ph = tf.placeholder(tf.float32, [self.batch_size_U] + self.IMAGE_DIM, name='unlabeled_images_for_c') 50 | self.y_u_ph = tf.placeholder(tf.float32, [self.batch_size_U, self.num_class], name='unlabeled_tmp') 51 | 52 | self.weights = tf.placeholder_with_default(0.0, [], name='weight') 53 | self.keep_prob_first = tf.placeholder(tf.float32, name='keep_prob_first') 54 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 55 | tf.summary.scalar("Loss/recon_weight", self.weights) 56 | 57 | self.build(is_train=is_train) 58 | 59 | def get_feed_dict_withunlabel(self, batch_chunk, batch_chunk_unlabel, z, z_linspace, y, step=None, is_training=None): 60 | fd = { 61 | self.x_l_ph: batch_chunk['image'], 62 | self.y_l_ph: batch_chunk['label'], 63
| self.z_g_ph: z, 64 | self.z_g_ph_linspace: z_linspace, 65 | self.y_g_ph: y, 66 | self.x_u_ph: batch_chunk_unlabel['image'][:self.batch_size_U], 67 | self.x_u_c_ph: batch_chunk_unlabel['image'][self.batch_size_U:self.batch_size_U + self.batch_size_U], 68 | self.y_u_ph: y 69 | } 70 | 71 | if is_training is not None: 72 | fd[self.is_training] = is_training 73 | 74 | if step > 50000: 75 | fd[self.weights] = 1.0 76 | 77 | return fd 78 | 79 | def _entropy(self, logits): 80 | with tf.name_scope('Entropy'): 81 | probs = tf.nn.softmax(logits) 82 | ent = tf.reduce_mean(- tf.reduce_sum(probs * logits, axis=1, keepdims=True) \ 83 | + tf.reduce_logsumexp(logits, axis=1, keepdims=True)) 84 | return ent 85 | 86 | def build(self, is_train=True): 87 | 88 | n = self.num_class 89 | 90 | # build loss and accuracy {{{ 91 | def build_loss(D_real, D_real_logits, D_real_logits_org, D_real_FM, D_real_tf, D_real_logits_tf, D_fake_bad, D_fake_logits_bad, D_fake_bad_FM, D_fake_good, D_fake_logits_good, D_fake_good_tf, D_fake_logits_good_tf, D_unl, D_unl_logits, D_unl_logits_noaug, D_unl_FM, D_unl_hard, D_unl_tf, D_unl_logits_tf, C_unl_hard, C_unl_tf, C_unl_logits_tf, x, fake_image_bad, fake_image_good, label, y_g_ph, x_u_ph, y_u, C): 92 | 93 | # Good GANs 94 | # Discriminator/classifier loss 95 | c_loss_real = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2( 96 | logits=D_real_logits, labels=label)) 97 | 98 | c_loss_fake_bad = 0.5 * tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(D_fake_logits_bad, axis = 1))) 99 | c_loss_unl_bad = - 0.5 * tf.reduce_mean(tf.reduce_logsumexp(D_unl_logits, axis = 1)) + 0.5 * tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(D_unl_logits, axis = 1))) 100 | 101 | probs = tf.nn.softmax(D_unl_logits) 102 | p = tf.reduce_max(probs, axis = 1) 103 | c_loss_unl_rein = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(D_unl_logits_tf), 104 | logits=D_unl_logits_tf), axis = 1) ## C fools D 105 | c_loss_unl_rein = 0.5 * tf.reduce_mean(tf.multiply(p, c_loss_unl_rein)) 106 | 107 | c_loss_fake_good_pseudo = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=D_fake_logits_good, labels=y_g_ph)) 108 | 109 | d_loss_real_tf = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(D_real_logits_tf), logits = D_real_logits_tf)) 110 | d_loss_fake_good_tf = 0.5 * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_fake_logits_good_tf), logits = D_fake_logits_good_tf)) 111 | #d_loss_unl_tf = 0.5 * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_unl_logits_tf), logits = D_unl_logits_tf)) 112 | #d_loss_unl_tf_c = 0.5 * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(C_unl_logits_tf), logits = C_unl_logits_tf)) 113 | 114 | # Dynamic class-rebalancing 115 | d_unl_add=tf.constant(1e-10) 116 | d_loss_unl_tf_tmp=tf.constant(0.0) 117 | for i in range(C_unl_logits_tf.shape[0]): 118 | d_loss_unl_tf_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(0,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_unl_logits_tf[i]), logits = D_unl_logits_tf[i])),tf.constant(0.0)) 119 | d_loss_unl_tf_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(2,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_unl_logits_tf[i]), logits = D_unl_logits_tf[i])),tf.constant(0.0)) 120 | 
d_loss_unl_tf_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(3,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_unl_logits_tf[i]), logits = D_unl_logits_tf[i])),tf.constant(0.0)) 121 | d_loss_unl_tf_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(4,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_unl_logits_tf[i]), logits = D_unl_logits_tf[i])),tf.constant(0.0)) 122 | d_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(0,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 123 | d_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(2,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 124 | d_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(3,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 125 | d_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(4,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 126 | d_loss_unl_tf = 0.5 * d_loss_unl_tf_tmp / d_unl_add 127 | 128 | c_unl_add=tf.constant(1e-10) 129 | d_loss_unl_tf_c_tmp=tf.constant(0.0) 130 | for i in range(C_unl_logits_tf.shape[0]): 131 | d_loss_unl_tf_c_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(1,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(C_unl_logits_tf[i]), logits = C_unl_logits_tf[i])),tf.constant(0.0)) 132 | c_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(1,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 133 | d_loss_unl_tf_c_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(2,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(C_unl_logits_tf[i]), logits = C_unl_logits_tf[i])),tf.constant(0.0)) 134 | c_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(2,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 135 | d_loss_unl_tf_c_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(3,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(C_unl_logits_tf[i]), logits = C_unl_logits_tf[i])),tf.constant(0.0)) 136 | c_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(3,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 137 | d_loss_unl_tf_c_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(4,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(C_unl_logits_tf[i]), logits = C_unl_logits_tf[i])),tf.constant(0.0)) 138 | c_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(4,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 139 | d_loss_unl_tf_c = 0.5 * d_loss_unl_tf_c_tmp / c_unl_add 140 | 141 | # Conditional entropy 142 | c_ent = 0.1 * tf.reduce_mean(tf.distributions.Categorical(logits=D_unl_logits).entropy()) 143 |
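# A vectorized sketch equivalent to the two rebalancing loops above (an illustration,
# not part of the original code, assuming num_class=5 as in the ODIR setup, where the
# D-side terms cover predicted classes {0, 2, 3, 4} and the C-side terms cover {1, 2, 3, 4}):
#   d_mask = tf.cast(tf.not_equal(D_unl_hard, 1), tf.float32)
#   d_ce = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
#       labels=tf.zeros_like(D_unl_logits_tf), logits=D_unl_logits_tf), axis=1)
#   d_loss_unl_tf = 0.5 * tf.reduce_sum(d_mask * d_ce) / (tf.reduce_sum(d_mask) + 1e-10)
# and symmetrically for the C side with tf.not_equal(C_unl_hard, 0) and ones labels.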
144 | # Batch Nuclear-norm Maximization 145 | c_bnm = - 0.1 * tf.reduce_sum(tf.svd(D_unl, compute_uv = False)) / self.batch_size_U 146 | 147 | c_loss = c_loss_real + self.weights * c_loss_fake_good_pseudo + c_loss_unl_rein 148 | c_loss += c_loss_fake_bad + c_loss_unl_bad + c_bnm 149 | 150 | d_loss = d_loss_real_tf + d_loss_fake_good_tf + d_loss_unl_tf_c + d_loss_unl_tf 151 | 152 | # Feature matching loss 153 | F_match_loss = tf.reduce_mean(tf.abs(tf.reduce_mean(D_unl_FM,axis=0) - tf.reduce_mean(D_fake_bad_FM,axis=0))) 154 | 155 | # Entropy term via pull-away term 156 | feat_norm = D_fake_bad_FM / tf.norm(D_fake_bad_FM, ord='euclidean', axis=1, \ 157 | keepdims=True) 158 | cosine = tf.tensordot(feat_norm, feat_norm, axes=[[1], [1]]) 159 | mask = tf.ones(tf.shape(cosine)) - tf.diag(tf.ones(tf.shape(cosine)[0])) 160 | square = tf.reduce_sum(tf.square(tf.multiply(cosine, mask))) 161 | denominator = tf.cast(tf.shape(cosine)[0] * (tf.shape(cosine)[0] - 1), tf.float32) 162 | G_pt = 0.1 * tf.divide(square, denominator) 163 | 164 | g_loss_bad = F_match_loss + G_pt 165 | 166 | g_loss_good = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(D_fake_logits_good_tf), logits = D_fake_logits_good_tf)) 167 | 168 | g_loss = g_loss_good + g_loss_bad 169 | 170 | GAN_loss = tf.reduce_mean(d_loss + g_loss + c_loss) 171 | 172 | # Classification accuracy 173 | correct_prediction = tf.equal(tf.argmax(D_real_logits_org, axis = 1), 174 | tf.argmax(label, 1)) 175 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 176 | 177 | return c_loss_real, c_loss_fake_bad, c_loss_unl_bad, c_loss_unl_rein, c_loss_fake_good_pseudo, d_loss_real_tf, d_loss_fake_good_tf, d_loss_unl_tf, d_loss_unl_tf_c, c_ent, c_bnm, F_match_loss, G_pt, g_loss_bad, g_loss_good, c_loss, d_loss, g_loss, GAN_loss, accuracy, d_loss_unl_tf_c_tmp, c_unl_add, d_loss_unl_tf_tmp, d_unl_add 178 | # }}} 179 | 180 | # Generator {{{ 181 | # ========= 182 | G_bad = Generator('Generator_bad', self.h, self.w, self.c, self.is_training, use_sn = False) 183 | G_good = gGenerator('Generator_good', self.h, self.w, self.c, self.is_training, use_sn = False) 184 | 185 | self.fake_image_bad = G_bad(self.z_g_ph) 186 | self.fake_image_good = G_good(self.z_g_ph, self.y_g_ph) 187 | # }}} 188 | 189 | # Classifier / projection discriminator {{{ 190 | # ========= 191 | # output of C for real images 192 | C = Classifier_proD('Classifier_proD', self.num_class, use_sn = True) 193 | 194 | # Discriminator {{{ 195 | # ========= 196 | x_r = DiffAugment(self.x_l_ph) 197 | _, _, D_real_tf, D_real_logits_tf, _ = C(x_r, self.y_l_ph) 198 | D_real, D_real_logits, _, _, D_real_FM = C(x_r, self.y_l_ph) 199 | 200 | self.real_predict, D_real_logits_org, _, _, _ = C(self.x_l_ph, self.y_l_ph) 201 | 202 | # output of D for generated examples 203 | x_bG = DiffAugment(self.fake_image_bad) 204 | D_fake_bad, D_fake_logits_bad, _, _, D_fake_bad_FM = C(x_bG, self.y_g_ph) 205 | 206 | x_gG = DiffAugment(self.fake_image_good) 207 | D_fake_good, D_fake_logits_good, _, _, D_fake_good_FM = C(x_gG, self.y_g_ph) 208 | _, _, D_fake_good_tf, D_fake_logits_good_tf, _ = C(x_gG, self.y_g_ph) 209 | 210 | # output of D for unlabeled examples (negative example) 211 | x_d = DiffAugment(self.x_u_ph) 212 | D_unl, D_unl_logits, _, _, D_unl_FM = C(x_d, self.y_u_ph) 213 | D_unl_tmp, D_unl_logits_noaug, _, _, _ = C(self.x_u_ph, self.y_u_ph) 214 | D_unl_hard = tf.argmax(D_unl_tmp, axis = 1) 215 | _, _, D_unl_tf, D_unl_logits_tf, _ = C(x_d, tf.one_hot(D_unl_hard, depth = self.num_class)) 216 | 217 | x_c = DiffAugment(self.x_u_c_ph) 218 | C_unl, C_unl_logits, _, 
_, C_unl_FM = C(x_c, self.y_u_ph) 219 | C_unl_tmp, _, _, _, _ = C(self.x_u_c_ph, self.y_u_ph) 220 | C_unl_hard = tf.argmax(C_unl_tmp, axis = 1) 221 | _, _, C_unl_tf, C_unl_logits_tf, _ = C(x_c, tf.one_hot(C_unl_hard, depth = self.num_class)) 222 | 223 | self.all_preds = D_real_logits_org 224 | self.all_targets = self.y_l_ph 225 | # }}} 226 | 227 | self.real_activations = D_real_logits 228 | self.fake_activations = D_fake_logits_good 229 | 230 | self.c_loss_real, self.c_loss_fake_bad, self.c_loss_unl_bad, self.c_loss_unl_rein, self.c_loss_fake_good_pseudo, self.d_loss_real_tf, self.d_loss_fake_good_tf, self.d_loss_unl_tf, self.d_loss_unl_tf_c, self.c_ent, self.c_bnm, self.F_match_loss, self.G_pt, self.g_loss_bad, self.g_loss_good, self.c_loss, self.d_loss, self.g_loss, self.GAN_loss, self.accuracy, self.d_loss_unl_tf_c_tmp, self.c_unl_add, self.d_loss_unl_tf_tmp, self.d_unl_add = build_loss(D_real, D_real_logits, D_real_logits_org, D_real_FM, D_real_tf, D_real_logits_tf, D_fake_bad, D_fake_logits_bad, D_fake_bad_FM, D_fake_good, D_fake_logits_good, D_fake_good_tf, D_fake_logits_good_tf, D_unl, D_unl_logits, D_unl_logits_noaug, D_unl_FM, D_unl_hard, D_unl_tf, D_unl_logits_tf, C_unl_hard, C_unl_tf, C_unl_logits_tf, self.x_l_ph, self.fake_image_bad, self.fake_image_good, self.y_l_ph, self.y_g_ph, x_d, self.y_u_ph, C) 231 | 232 | tf.summary.scalar("Loss/Accuracy", self.accuracy) 233 | tf.summary.scalar("Loss/C_loss_real", self.c_loss_real) 234 | tf.summary.scalar("Loss/C_loss_fake_bad", self.c_loss_fake_bad) 235 | tf.summary.scalar("Loss/C_loss_unl_bad", self.c_loss_unl_bad) 236 | tf.summary.scalar("Loss/C_loss_unl_rein", self.c_loss_unl_rein) 237 | tf.summary.scalar("Loss/C_loss_fake_good_pseudo", self.c_loss_fake_good_pseudo) 238 | tf.summary.scalar("Loss/C_bnm", self.c_bnm) 239 | tf.summary.scalar("Loss/D_loss_real_tf", self.d_loss_real_tf) 240 | tf.summary.scalar("Loss/D_loss_fake_good_tf", self.d_loss_fake_good_tf) 241 | tf.summary.scalar("Loss/D_loss_unl_tf", self.d_loss_unl_tf) 242 | tf.summary.scalar("Loss/D_loss_unl_tf_c", self.d_loss_unl_tf_c) 243 | tf.summary.scalar("Loss/F_match_loss", self.F_match_loss) 244 | tf.summary.scalar("Loss/G_pt", self.G_pt) 245 | tf.summary.scalar("Loss/G_loss_bad", self.g_loss_bad) 246 | tf.summary.scalar("Loss/G_loss_good", self.g_loss_good) 247 | tf.summary.scalar("Loss/D_loss_unl_tf_tmp", self.d_loss_unl_tf_tmp) 248 | tf.summary.scalar("Loss/D_unl_add", self.d_unl_add) 249 | tf.summary.scalar("Loss/D_loss_unl_tf_c_tmp", self.d_loss_unl_tf_c_tmp) 250 | tf.summary.scalar("Loss/C_unl_add", self.c_unl_add) 251 | 252 | tf.summary.image("Img/Fake_bad", G_bad(self.z_g_ph), max_outputs=10) 253 | tf.summary.image("Img/Normal", G_good(self.z_g_ph_linspace, tf.constant([[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0]], dtype=tf.float32)), max_outputs=10) 254 | tf.summary.image("Img/High", G_good(self.z_g_ph_linspace, tf.constant([[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0]], dtype=tf.float32)), max_outputs=10) 255 | tf.summary.image("Img/PM", G_good(self.z_g_ph_linspace, tf.constant([[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0]], dtype=tf.float32)), max_outputs=10) 256 | tf.summary.image("Img/AMD", G_good(self.z_g_ph_linspace, 
tf.constant([[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0]], dtype=tf.float32)), max_outputs=10) 257 | tf.summary.image("Img/Glaucoma", G_good(self.z_g_ph_linspace, tf.constant([[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1]], dtype=tf.float32)), max_outputs=10) 258 | tf.summary.image("Img/Real", self.x_l_ph, max_outputs=1) 259 | tf.summary.image("Img/DiffAug", DiffAugment(self.x_l_ph), max_outputs=1) 260 | log.warn('\033[93mSuccessfully loaded the model.\033[0m') 261 | -------------------------------------------------------------------------------- /ODIR/ODIR dataset building steps.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# ODIR dataset building steps\n", 8 | "To facilitate future method comparisons, we provide the detailed steps to construct the ODIR dataset used in our TMI paper as follows." 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "### First, these 16 images are removed from the training set.\n", 16 | "The background of the following images is quite different from that of the rest; they are fundus images uploaded from the hospital.\n", 17 | "\n", 18 | "- 2174_right.jpg\n", 19 | "- 2175_left.jpg\n", 20 | "- 2176_left.jpg\n", 21 | "- 2177_left.jpg\n", 22 | "- 2177_right.jpg\n", 23 | "- 2178_right.jpg\n", 24 | "- 2179_left.jpg\n", 25 | "- 2179_right.jpg\n", 26 | "- 2180_left.jpg\n", 27 | "- 2180_right.jpg\n", 28 | "- 2181_left.jpg\n", 29 | "- 2181_right.jpg\n", 30 | "- 2182_left.jpg\n", 31 | "- 2182_right.jpg\n", 32 | "- 2957_left.jpg\n", 33 | "- 2957_right.jpg" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "### Second, we need to modify the following excel files provided by the ODIR dataset:\n", 41 | "- training annotation (English).xlsx\n", 42 | "- off-site test annotation (English).xlsx\n", 43 | "- on-site test annotation (English).xlsx\n", 44 | "\n", 45 | "Specifically, we used global substitution to unify diagnostic keywords for the same disease according to Table II in our paper. \n", 46 | "\n", 47 | "> For example, diagnostic keywords including \"Mild nonproliferative retinopathy\", \"Moderate nonproliferative retinopathy\", \"Severe nonproliferative retinopathy\", \"Proliferative diabetic retinopathy\", \"Severe proliferative diabetic retinopathy\", and \"Diabetic retinopathy\" are all replaced with \"Diabetic retinopathy\".\n", 48 | "\n", 49 | "Moreover, we treat all suspected diseases or abnormalities as diagnosed diseases or abnormalities, so all \"suspected \" prefixes are replaced with \"\".\n", 50 | "\n", 51 | "For the convenience of follow-up work, we upload the final excel files at https://github.com/Xyporz/CISSL-GANs/tree/main/ODIR." 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "### Third, run the following code step by step.\n", 59 | "Remember to change the paths below to your own file paths.\n",
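"\n", "For the global keyword substitution described in the previous cell, a minimal pandas sketch is given below (a hypothetical illustration, not a cell from the original pipeline; the lowercase keywords and column names follow the annotation preview further down):\n", "\n", "```python\n", "df = pd.read_excel('training annotation (English).xlsx')\n", "for col in ['Left-Diagnostic Keywords', 'Right-Diagnostic Keywords']:\n", "    # strip the 'suspected ' prefix, then unify per-disease keywords\n", "    df[col] = df[col].str.replace('suspected ', '', regex=False)\n", "    df[col] = df[col].str.replace('mild nonproliferative retinopathy', 'diabetic retinopathy', regex=False)\n", "df.to_excel('training annotation (English).xlsx', index=False)\n", "```"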
60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 1, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "import pandas as pd # for reading the annotation spreadsheet\n", 69 | "import os # for setting up directory paths\n", 70 | "import shutil # for moving files\n", 71 | "\n", 72 | "# Prepare the directories\n", 73 | "Picture_Current_Path = \"F:/Data/ODIR_Fromscratch/OIA-ODIR/Training Set/Images/\"\n", 74 | "Current_Path = \"F:/Data/OIA-ODIR/ODIR/\"\n", 75 | "CSV_Path = \"F:/Data/ODIR_Fromscratch/OIA-ODIR/Training Set/Annotation/training annotation (English).xlsx\"\n", 76 | "\n", 77 | "Train_Path = \"F:/Data/ODIR_Fromscratch/OIA-ODIR/ODIR/Train/\"\n", 78 | "\n", 79 | "# Open and read the annotation spreadsheet under the directory\n", 80 | "list = pd.read_excel(CSV_Path)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 2, 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "data": { 90 | "text/html": [ 91 | "
|      | ID | Patient Age | Patient Sex | Left-Fundus | Right-Fundus | Left-Diagnostic Keywords | Right-Diagnostic Keywords | N | D | G | C | A | H | M | O |
|------|----|-------------|-------------|-------------|--------------|--------------------------|---------------------------|---|---|---|---|---|---|---|---|
| 0    | 0  | 69 | Female | 0_left.jpg | 0_right.jpg | cataract | normal fundus | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
| 1    | 1  | 57 | Male | 1_left.jpg | 1_right.jpg | normal fundus | normal fundus | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2    | 2  | 42 | Male | 2_left.jpg | 2_right.jpg | laser spot,diabetic retinopathy | diabetic retinopathy | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
| 3    | 3  | 66 | Male | 3_left.jpg | 3_right.jpg | normal fundus | branch retinal artery occlusion | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
| 4    | 4  | 53 | Male | 4_left.jpg | 4_right.jpg | macular epiretinal membrane | diabetic retinopathy | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
| ...  | ...  | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 3495 | 4686 | 63 | Male | 4686_left.jpg | 4686_right.jpg | diabetic retinopathy | diabetic retinopathy | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3496 | 4688 | 42 | Male | 4688_left.jpg | 4688_right.jpg | diabetic retinopathy | diabetic retinopathy | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3497 | 4689 | 54 | Male | 4689_left.jpg | 4689_right.jpg | diabetic retinopathy | normal fundus | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3498 | 4690 | 57 | Male | 4690_left.jpg | 4690_right.jpg | diabetic retinopathy | diabetic retinopathy | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3499 | 4784 | 58 | Male | 4784_left.jpg | 4784_right.jpg | hypertensive retinopathy,age-related macular d... | hypertensive retinopathy,age-related macular d... | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 |

3500 rows × 15 columns
\n", 328 | "