├── ODIR
│   ├── training annotation (English).xlsx
│   ├── off-site test annotation (English).xlsx
│   ├── on-site test annotation (English).xlsx
│   └── ODIR dataset building steps.ipynb
├── README.md
├── util.py
├── input_ops.py
├── sngan_gen.py
├── sngan_gen_y.py
├── sngan_joint.py
├── diffAugment.py
├── config.py
├── BigGAN
│   ├── resnet_ops.py
│   ├── resnet_gen_y.py
│   ├── resnet_joint.py
│   ├── resnet_gen_y_deep.py
│   └── resnet_joint_deep.py
├── ops.py
├── trainer_joint.py
├── model_joint.py
└── arch_ops.py

/ODIR/training annotation (English).xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xyporz/CISSL-GANs/HEAD/ODIR/training annotation (English).xlsx
--------------------------------------------------------------------------------
/ODIR/off-site test annotation (English).xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xyporz/CISSL-GANs/HEAD/ODIR/off-site test annotation (English).xlsx
--------------------------------------------------------------------------------
/ODIR/on-site test annotation (English).xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xyporz/CISSL-GANs/HEAD/ODIR/on-site test annotation (English).xlsx
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CISSL-GANs
2 | 
3 | ### About
4 | 
5 | 📫 My email: xieyingpeng2017@email.szu.edu.cn
6 | 
7 | Official TensorFlow implementation of the paper:
8 | #### Fundus Image-label Pairs Synthesis and Retinopathy Screening via GANs with Class-imbalanced Semi-supervised Learning
9 | IEEE Transactions on Medical Imaging (IEEE-TMI).
10 | 
11 | This code is primarily based on [SSGAN-Tensorflow](https://github.com/clvrai/SSGAN-Tensorflow) and [Compare GAN](https://github.com/google/compare_gan).
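
### Example

As a minimal sketch of how the bundled `diffAugment.py` can be applied to a batch of NHWC images in [0, 1] (assuming a TensorFlow 1.x runtime; the tensor shapes below are illustrative only):

```python
import tensorflow as tf
from diffAugment import DiffAugment

# Illustrative batch standing in for real or generated fundus images.
images = tf.random.uniform([16, 128, 128, 3])  # NHWC, values in [0, 1]
# Apply the color / translation / cutout policies defined in diffAugment.py.
augmented = DiffAugment(images, policy="color,translation,cutout")
```

The augmentations are differentiable, so the same call can be applied to both real and generated batches inside the training graph.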
12 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | """ Utilities """ 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | 8 | # Logging 9 | # ======= 10 | 11 | import logging 12 | import os, os.path 13 | from colorlog import ColoredFormatter 14 | 15 | ch = logging.StreamHandler() 16 | ch.setLevel(logging.DEBUG) 17 | 18 | formatter = ColoredFormatter( 19 | "%(log_color)s[%(asctime)s] %(message)s", 20 | # datefmt='%H:%M:%S.%f', 21 | datefmt=None, 22 | reset=True, 23 | log_colors={ 24 | 'DEBUG': 'cyan', 25 | 'INFO': 'white,bold', 26 | 'INFOV': 'cyan,bold', 27 | 'WARNING': 'yellow', 28 | 'ERROR': 'red,bold', 29 | 'CRITICAL': 'red,bg_white', 30 | }, 31 | secondary_log_colors={}, 32 | style='%' 33 | ) 34 | 35 | ch.setFormatter(formatter) 36 | 37 | log = logging.getLogger('GBGANs') 38 | log.setLevel(logging.DEBUG) 39 | log.handlers = [] # No duplicated handlers 40 | log.propagate = False # workaround for duplicated logs in ipython 41 | log.addHandler(ch) 42 | 43 | logging.addLevelName(logging.INFO + 1, 'INFOV') 44 | def _infov(self, msg, *args, **kwargs): 45 | self.log(logging.INFO + 1, msg, *args, **kwargs) 46 | 47 | logging.Logger.infov = _infov 48 | -------------------------------------------------------------------------------- /input_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from util import log 4 | 5 | def check_data_id(dataset, data_id): 6 | if not data_id: 7 | return 8 | 9 | wrong = [] 10 | for id in data_id: 11 | if id in dataset.data: 12 | pass 13 | else: 14 | wrong.append(id) 15 | 16 | if len(wrong) > 0: 17 | raise RuntimeError("There are %d invalid ids, including %s" % ( 18 | len(wrong), wrong[:5] 19 | )) 20 | 21 | 22 | def create_input_ops(dataset, 23 | batch_size, 24 | num_threads=16, # for creating batches 25 | data_id=None, 26 | scope='inputs', 27 | shuffle=True, 28 | ): 29 | ''' 30 | Return a batched tensor for the inputs from the dataset. 
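The returned value is a dictionary with keys 'id', 'image' and 'label',
each mapped to a batched tensor; batches are assembled with
tf.train.shuffle_batch (or tf.train.batch when shuffle=False).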
31 | ''' 32 | input_ops = {} 33 | 34 | if data_id is None: 35 | data_id = dataset.ids 36 | log.info("input_ops [%s]: Using %d IDs from dataset", scope, len(data_id)) 37 | else: 38 | log.info("input_ops [%s]: Using specified %d IDs", scope, len(data_id)) 39 | 40 | # single operations 41 | with tf.device("/cpu:0"), tf.name_scope(scope): 42 | input_ops['id'] = tf.train.string_input_producer( 43 | tf.convert_to_tensor(data_id), 44 | capacity=128 45 | ).dequeue(name='input_ids_dequeue') 46 | 47 | m, label = dataset.get_data(data_id[0]) 48 | 49 | def load_fn(id): 50 | image, label = dataset.get_data(id) 51 | return (id, 52 | image.astype(np.float32), 53 | label.astype(np.float32)) 54 | 55 | input_ops['id'], input_ops['image'], input_ops['label'] = tf.py_func( 56 | load_fn, inp=[input_ops['id']], 57 | Tout=[tf.string, tf.float32, tf.float32], 58 | name='func_hp' 59 | ) 60 | input_ops['id'].set_shape([]) 61 | input_ops['image'].set_shape(list(m.shape)) 62 | input_ops['label'].set_shape(list(label.shape)) 63 | 64 | # batchify 65 | capacity = 2 * batch_size * num_threads 66 | min_capacity = min(int(capacity * 0.75), 1024) 67 | 68 | if shuffle: 69 | batch_ops = tf.train.shuffle_batch( 70 | input_ops, 71 | batch_size=batch_size, 72 | num_threads=num_threads, 73 | capacity=capacity, 74 | min_after_dequeue=min_capacity, 75 | ) 76 | else: 77 | batch_ops = tf.train.batch( 78 | input_ops, 79 | batch_size=batch_size, 80 | num_threads=num_threads, 81 | capacity=capacity 82 | ) 83 | 84 | return batch_ops 85 | -------------------------------------------------------------------------------- /sngan_gen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from ops import linear 4 | from ops import deconv2d 5 | from util import log 6 | from ops import non_local_block 7 | import time 8 | 9 | def conv_out_size_same(size, stride): 10 | return int(np.ceil(float(size) / float(stride))) 11 | 12 | class Generator(object): 13 | def __init__(self, name, h, w, c, is_train, use_sn): 14 | 15 | self.name = name 16 | self.s_h, self.s_w, self.colors = [h,w,c] 17 | self.s_h2, self.s_w2 = conv_out_size_same(self.s_h, 2), conv_out_size_same(self.s_w, 2) 18 | self.s_h4, self.s_w4 = conv_out_size_same(self.s_h2, 2), conv_out_size_same(self.s_w2, 2) 19 | self.s_h8, self.s_w8 = conv_out_size_same(self.s_h4, 2), conv_out_size_same(self.s_w4, 2) 20 | self._is_train = is_train 21 | self.use_sn = use_sn 22 | 23 | def __call__(self, z): 24 | 25 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 26 | batch_size = z.shape[0] 27 | 28 | net = linear(z, self.s_h8 * self.s_w8 * 512, scope="g_fc1", use_sn=self.use_sn) 29 | net = tf.reshape(net, [batch_size, self.s_h8, self.s_w8, 512]) 30 | net = tf.contrib.layers.batch_norm(net, 31 | updates_collections=None, is_training=self._is_train, center=True, scale=True, decay=0.9, epsilon=1e-5, scope="g_bn_deconv0") 32 | net = tf.nn.relu(net) 33 | 34 | net = deconv2d(net, [batch_size, self.s_h4, self.s_w4, 256], 4, 4, 2, 2, name="g_dc1", use_sn=self.use_sn) 35 | net = tf.contrib.layers.batch_norm(net, 36 | updates_collections=None, is_training=self._is_train, center=True, scale=True, decay=0.9, epsilon=1e-5, scope="g_bn_deconv1") 37 | net = tf.nn.relu(net) 38 | 39 | net = deconv2d(net, [batch_size, self.s_h2, self.s_w2, 128], 4, 4, 2, 2, name="g_dc2", use_sn=self.use_sn) 40 | net = tf.contrib.layers.batch_norm(net, 41 | updates_collections=None, is_training=self._is_train, center=True, scale=True, decay=0.9, 
epsilon=1e-5, scope="g_bn_deconv2") 42 | net = tf.nn.relu(net) 43 | 44 | net = deconv2d(net, [batch_size, self.s_h, self.s_w, 64], 4, 4, 2, 2, name="g_dc3", use_sn=self.use_sn) 45 | net = tf.contrib.layers.batch_norm(net, 46 | updates_collections=None, is_training=self._is_train, center=True, scale=True, decay=0.9, epsilon=1e-5, scope="g_bn_deconv3") 47 | net = tf.nn.relu(net) 48 | 49 | net = deconv2d(net, [batch_size, self.s_h, self.s_w, self.colors], 3, 3, 1, 1, name="g_dc4", use_sn=self.use_sn) 50 | out = tf.div(tf.nn.tanh(net) + 1.0, 2.0) 51 | 52 | return out -------------------------------------------------------------------------------- /sngan_gen_y.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from ops import linear 4 | from ops import deconv2d 5 | from util import log 6 | from ops import non_local_block, conditional_batch_norm 7 | import time 8 | 9 | def conv_out_size_same(size, stride): 10 | return int(np.ceil(float(size) / float(stride))) 11 | 12 | class Generator(object): 13 | def __init__(self, name, h, w, c, is_train, use_sn): 14 | 15 | self.name = name 16 | self.s_h, self.s_w, self.colors = [h,w,c] 17 | self.s_h2, self.s_w2 = conv_out_size_same(self.s_h, 2), conv_out_size_same(self.s_w, 2) 18 | self.s_h4, self.s_w4 = conv_out_size_same(self.s_h2, 2), conv_out_size_same(self.s_w2, 2) 19 | self.s_h8, self.s_w8 = conv_out_size_same(self.s_h4, 2), conv_out_size_same(self.s_w4, 2) 20 | self._is_train = is_train 21 | self.use_sn = use_sn 22 | self._embed_y_dim = 128 23 | 24 | def __call__(self, z, y): 25 | 26 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 27 | batch_size = z.shape[0] 28 | 29 | y = tf.concat([z, y], axis=1) 30 | z = y 31 | 32 | net = linear(z, self.s_h8 * self.s_w8 * 512, scope="g_fc1", use_sn=self.use_sn) 33 | net = tf.reshape(net, [batch_size, self.s_h8, self.s_w8, 512]) 34 | net = conditional_batch_norm(net, y, is_training=self._is_train, use_sn = self.use_sn, center=True, scale=True, name="g_cbn_deconv0", use_bias=False) 35 | net = tf.nn.relu(net) 36 | 37 | net = deconv2d(net, [batch_size, self.s_h4, self.s_w4, 256], 4, 4, 2, 2, name="g_dc1", use_sn=self.use_sn) 38 | net = conditional_batch_norm(net, y, is_training=self._is_train, use_sn = self.use_sn, center=True, scale=True, name="g_cbn_deconv1", use_bias=False) 39 | net = tf.nn.relu(net) 40 | 41 | net = deconv2d(net, [batch_size, self.s_h2, self.s_w2, 128], 4, 4, 2, 2, name="g_dc2", use_sn=self.use_sn) 42 | net = conditional_batch_norm(net, y, is_training=self._is_train, use_sn = self.use_sn, center=True, scale=True, name="g_cbn_deconv2", use_bias=False) 43 | net = tf.nn.relu(net) 44 | 45 | net = deconv2d(net, [batch_size, self.s_h, self.s_w, 64], 4, 4, 2, 2, name="g_dc3", use_sn=self.use_sn) 46 | net = tf.contrib.layers.batch_norm(net, 47 | updates_collections=None, is_training=self._is_train, center=True, scale=True, decay=0.9, epsilon=1e-5, scope="g_cbn_deconv3") 48 | net = tf.nn.relu(net) 49 | 50 | net = deconv2d(net, [batch_size, self.s_h, self.s_w, self.colors], 3, 3, 1, 1, name="g_dc4", use_sn=self.use_sn) 51 | out = tf.div(tf.nn.tanh(net) + 1.0, 2.0) 52 | 53 | return out 54 | -------------------------------------------------------------------------------- /sngan_joint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from ops import conv2d 4 | from ops import linear 5 | from util import log 6 | from ops import 
lrelu 7 | from ops import non_local_block, spectral_norm 8 | import time 9 | from tensorflow.python.framework import ops 10 | from tensorflow.python.ops import gen_nn_ops 11 | 12 | import cv2 13 | 14 | def normalize(x): 15 | """Utility function to normalize a tensor by its L2 norm""" 16 | return tf.div(x + tf.constant(1e-10), tf.sqrt(tf.reduce_mean(tf.square(x))) + tf.constant(1e-10)) 17 | 18 | class Classifier_proD(object): 19 | def __init__(self, name, num_class, use_sn): 20 | 21 | self.name = name 22 | self._num_class = num_class 23 | self.use_sn = use_sn 24 | 25 | def __call__(self, input, y): 26 | 27 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 28 | 29 | input = input * 2.0 - 1.0 30 | 31 | net = conv2d(input, 64, 3, 3, 1, 1, name="d_conv1", use_sn=self.use_sn) 32 | net = lrelu(net, leak=0.1) 33 | 34 | net = conv2d(net, 128, 4, 4, 2, 2, name="d_conv2", use_sn=self.use_sn) 35 | net = lrelu(net, leak=0.1) 36 | 37 | net = conv2d(net, 128, 3, 3, 1, 1, name="d_conv3", use_sn=self.use_sn) 38 | net = lrelu(net, leak=0.1) 39 | 40 | net = conv2d(net, 256, 4, 4, 2, 2, name="d_conv4", use_sn=self.use_sn) 41 | net = lrelu(net, leak=0.1) 42 | 43 | net = conv2d(net, 256, 3, 3, 1, 1, name="d_conv5", use_sn=self.use_sn) 44 | net = lrelu(net, leak=0.1) 45 | 46 | net = conv2d(net, 512, 4, 4, 2, 2, name="d_conv6", use_sn=self.use_sn) 47 | net = lrelu(net, leak=0.1) 48 | 49 | net = conv2d(net, 512, 3, 3, 1, 1, name="d_conv7", use_sn=self.use_sn) 50 | net_conv = lrelu(net, leak=0.1) 51 | 52 | h = tf.layers.flatten(net_conv) 53 | out_logit = linear(h, self._num_class, scope="d_fc1", use_sn=self.use_sn) 54 | out_logit_tf = linear(h, 1, scope="final_fc", use_sn=self.use_sn) 55 | 56 | feature_matching = h 57 | 58 | log.info("[Discriminator] after final processing: %s", net_conv.shape) 59 | with tf.variable_scope("embedding_fc", reuse=tf.AUTO_REUSE): 60 | # We do not use ops.linear() below since it does not have an option to 61 | # override the initializer. 
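# Projection conditioning as in "cGANs with Projection Discriminator"
# (Miyato & Koyama, 2018): the one-hot label y is embedded with a learned
# kernel and its inner product with the penultimate features h is added to
# the real/fake logit out_logit_tf below.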
62 | kernel = tf.get_variable( 63 | "kernel", [y.shape[1], h.shape[1]], tf.float32, 64 | initializer=tf.initializers.glorot_normal()) 65 | if self.use_sn: 66 | kernel = spectral_norm(kernel) 67 | embedded_y = tf.matmul(y, kernel) 68 | out_logit_tf += tf.reduce_sum(embedded_y * h, axis=1, keepdims=True) 69 | 70 | return tf.nn.softmax(out_logit), out_logit, tf.nn.sigmoid(out_logit_tf), out_logit_tf, feature_matching 71 | -------------------------------------------------------------------------------- /diffAugment.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def DiffAugment(x, policy="color,translation,cutout", channels_first=False): 4 | if policy: 5 | if channels_first: 6 | x = tf.transpose(x, [0, 2, 3, 1]) 7 | for p in policy.split(','): 8 | for f in AUGMENT_FNS[p]: 9 | x = f(x) 10 | if channels_first: 11 | x = tf.transpose(x, [0, 3, 1, 2]) 12 | return x 13 | 14 | 15 | def rand_brightness(x): 16 | magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) - 0.5 17 | x = x + magnitude 18 | return x 19 | 20 | 21 | def rand_saturation(x): 22 | magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) * 2 23 | x_mean = tf.reduce_mean(x, axis=3, keepdims=True) 24 | x = (x - x_mean) * magnitude + x_mean 25 | return x 26 | 27 | 28 | def rand_contrast(x): 29 | magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) + 0.5 30 | x_mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True) 31 | x = (x - x_mean) * magnitude + x_mean 32 | return x 33 | 34 | 35 | def rand_translation(x, ratio=0.125): 36 | batch_size = tf.shape(x)[0] 37 | image_size = tf.shape(x)[1:3] 38 | shift = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32) 39 | translation_x = tf.random.uniform([batch_size, 1], -shift[0], shift[0] + 1, dtype=tf.int32) 40 | translation_y = tf.random.uniform([batch_size, 1], -shift[1], shift[1] + 1, dtype=tf.int32) 41 | grid_x = tf.clip_by_value(tf.expand_dims(tf.range(image_size[0], dtype=tf.int32), 0) + translation_x + 1, 0, image_size[0] + 1) 42 | grid_y = tf.clip_by_value(tf.expand_dims(tf.range(image_size[1], dtype=tf.int32), 0) + translation_y + 1, 0, image_size[1] + 1) 43 | x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_x, -1), batch_dims=1) 44 | x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_y, -1), batch_dims=1), [0, 2, 1, 3]) 45 | return x 46 | 47 | 48 | def rand_cutout(x, ratio=0.5): 49 | batch_size = tf.shape(x)[0] 50 | image_size = tf.shape(x)[1:3] 51 | cutout_size = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32) 52 | offset_x = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[0] + (1 - cutout_size[0] % 2), dtype=tf.int32) 53 | offset_y = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[1] + (1 - cutout_size[1] % 2), dtype=tf.int32) 54 | grid_batch, grid_x, grid_y = tf.meshgrid(tf.range(batch_size, dtype=tf.int32), tf.range(cutout_size[0], dtype=tf.int32), tf.range(cutout_size[1], dtype=tf.int32), indexing='ij') 55 | cutout_grid = tf.stack([grid_batch, grid_x + offset_x - cutout_size[0] // 2, grid_y + offset_y - cutout_size[1] // 2], axis=-1) 56 | mask_shape = tf.stack([batch_size, image_size[0], image_size[1]]) 57 | cutout_grid = tf.maximum(cutout_grid, 0) 58 | cutout_grid = tf.minimum(cutout_grid, tf.reshape(mask_shape - 1, [1, 1, 1, 3])) 59 | mask = tf.maximum(1 - tf.scatter_nd(cutout_grid, tf.ones([batch_size, cutout_size[0], cutout_size[1]], 
dtype=tf.float32), mask_shape), 0) 60 | x = x * tf.expand_dims(mask, axis=3) 61 | return x 62 | 63 | 64 | AUGMENT_FNS = { 65 | 'color': [rand_brightness, rand_saturation, rand_contrast], 66 | 'translation': [rand_translation], 67 | 'cutout': [rand_cutout], 68 | } -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from model_joint import Model 5 | import datasets.hdf5_loader as dataset 6 | import numpy as np 7 | import pandas as pd 8 | from six.moves import xrange 9 | from util import log 10 | 11 | def get_params(): 12 | 13 | def str2bool(v): 14 | return v.lower() == 'true' 15 | 16 | parser = argparse.ArgumentParser( 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | 19 | parser.add_argument('--debug', action='store_true', default=False) 20 | parser.add_argument('--prefix', type=str, default='GANs') 21 | parser.add_argument('--train_dir', type=str) 22 | parser.add_argument('--checkpoint', type=str, default=None) 23 | parser.add_argument('--dataset', type=str, default="Orgimage_128", choices=["Orgimage_128"]) 24 | parser.add_argument('--dump_result', type=str2bool, default=False) 25 | 26 | # Model 27 | parser.add_argument('--batch_size_G', type=int, default=16) 28 | parser.add_argument('--batch_size_L', type=int, default=16) 29 | parser.add_argument('--batch_size_U', type=int, default=16) 30 | parser.add_argument('--n_z', type=int, default=128) 31 | 32 | # Training config {{{ 33 | # ======== 34 | # log 35 | parser.add_argument('--log_step', type=int, default=10) 36 | parser.add_argument('--save_image_step', type=int, default=10000) 37 | parser.add_argument('--test_sample_step', type=int, default=100) 38 | parser.add_argument('--output_save_step', type=int, default=10000) 39 | # learning 40 | parser.add_argument('--max_training_steps', type=int, default=250001) 41 | parser.add_argument('--learning_rate_g', type=float, default=2e-4) 42 | parser.add_argument('--learning_rate_d', type=float, default=5e-5) 43 | parser.add_argument('--update_rate', type=int, default=2) 44 | # }}} 45 | 46 | # Testing config {{{ 47 | # ======== 48 | parser.add_argument('--data_id', nargs='*', default=None) 49 | # }}} 50 | 51 | config, _ = parser.parse_known_args() 52 | args = parser.parse_args() 53 | 54 | return config, args 55 | 56 | def argparser(config, is_train=True): 57 | 58 | dataset_path = os.path.join('./datasets/TMI/', config["dataset"].lower()) 59 | dataset_train, dataset_val, dataset_test = dataset.create_default_splits(dataset_path) 60 | dataset_train_unlabel, _ = dataset.create_default_splits_unlabel(dataset_path) 61 | 62 | config["img"] = [] 63 | labels = [] 64 | with open('./datasets/metadata.tsv','w') as f: 65 | f.write("Index\tLabel\n") 66 | for index, labeldex in enumerate(dataset_test.ids): 67 | config["img"].append(dataset_test.get_data(labeldex)[0]) 68 | label = np.argmax(dataset_test.get_data(labeldex)[1]) 69 | labels.append(label) 70 | f.write("%d\t%d\n" % (index, label)) 71 | config["img"] = np.array(config["img"]) 72 | log.info(config["img"].shape) 73 | config["len"] = config["img"].shape[0] 74 | config["label"] = labels 75 | 76 | config["Size"] = config["len"] 77 | picname = [] 78 | for i in xrange(config["Size"]): 79 | picname.append("V{step:04d}.jpg".format(step=i+1)) 80 | Csv=pd.DataFrame(columns=['Label'], index=picname, data=labels) 81 | Csv.to_csv('./datasets/Classification_Results_label.csv',encoding='gbk') 
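# Probe a single training sample so that the image geometry (h, w, c) and
# the number of classes can be read off before the model is constructed.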
82 | 83 | img, label = dataset_train.get_data(dataset_train.ids[0]) 84 | config["h"] = img.shape[0] 85 | config["w"] = img.shape[1] 86 | config["c"] = img.shape[2] 87 | config["num_class"] = label.shape[0] 88 | 89 | # --- create model --- 90 | model = Model(config, debug_information=config["debug"], is_train=is_train) 91 | 92 | return config, model, dataset_train, dataset_train_unlabel, dataset_val, dataset_test 93 | -------------------------------------------------------------------------------- /BigGAN/resnet_ops.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ResNet specific operations. 17 | Defines the default ResNet generator and discriminator blocks and some helper 18 | operations such as unpooling. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import math 26 | 27 | from compare_gan.architectures import abstract_arch 28 | from compare_gan.architectures import arch_ops as ops 29 | 30 | from six.moves import range 31 | import tensorflow as tf 32 | 33 | 34 | def unpool(value, name="unpool"): 35 | """Unpooling operation. 36 | N-dimensional version of the unpooling operation from 37 | https://www.robots.ox.ac.uk/~vgg/rg/papers/Dosovitskiy_Learning_to_Generate_2015_CVPR_paper.pdf 38 | Taken from: https://github.com/tensorflow/tensorflow/issues/2169 39 | Args: 40 | value: a Tensor of shape [b, d0, d1, ..., dn, ch] 41 | name: name of the op 42 | Returns: 43 | A Tensor of shape [b, 2*d0, 2*d1, ..., 2*dn, ch] 44 | """ 45 | with tf.name_scope(name) as scope: 46 | sh = value.get_shape().as_list() 47 | dim = len(sh[1:-1]) 48 | out = (tf.reshape(value, [-1] + sh[-dim:])) 49 | for i in range(dim, 0, -1): 50 | out = tf.concat([out, tf.zeros_like(out)], i) 51 | out_size = [-1] + [s * 2 for s in sh[1:-1]] + [sh[-1]] 52 | out = tf.reshape(out, out_size, name=scope) 53 | return out 54 | 55 | 56 | def validate_image_inputs(inputs, validate_power2=True): 57 | inputs.get_shape().assert_has_rank(4) 58 | inputs.get_shape()[1:3].assert_is_fully_defined() 59 | # if inputs.get_shape()[1] != inputs.get_shape()[2]: 60 | # raise ValueError("Input tensor does not have equal width and height: ", 61 | # inputs.get_shape()[1:3]) 62 | width = inputs.get_shape().as_list()[1] 63 | if validate_power2 and math.log(width, 2) != int(math.log(width, 2)): 64 | raise ValueError("Input tensor `width` is not a power of 2: ", width) 65 | 66 | 67 | class ResNetBlock(object): 68 | """ResNet block with options for various normalizations.""" 69 | 70 | def __init__(self, 71 | name, 72 | in_channels, 73 | out_channels, 74 | scale, 75 | is_gen_block, 76 | layer_norm=False, 77 | spectral_norm=False, 78 | batch_norm=None): 79 | """Constructs a new ResNet block. 80 | Args: 81 | name: Scope name for the resent block. 
82 |       in_channels: Integer, the input channel size.
83 |       out_channels: Integer, the output channel size.
84 |       scale: Whether or not to scale up or down, choose from "up", "down" or
85 |         "none".
86 |       is_gen_block: Boolean, deciding whether this is a generator or
87 |         discriminator block.
88 |       layer_norm: Apply layer norm before both convolutions.
89 |       spectral_norm: Use spectral normalization for all weights.
90 |       batch_norm: Function for batch normalization.
91 |     """
92 |     assert scale in ["up", "down", "none"]
93 |     self._name = name
94 |     self._in_channels = in_channels
95 |     self._out_channels = out_channels
96 |     self._scale = scale
97 |     # In the SN paper, upscaling in the generator is done in the first conv.
98 |     # For the discriminator, downsampling happens after the second conv.
99 |     self._scale1 = scale if is_gen_block else "none"
100 |     self._scale2 = "none" if is_gen_block else scale
101 |     self._layer_norm = layer_norm
102 |     self._spectral_norm = spectral_norm
103 |     self.batch_norm = batch_norm
104 | 
105 |   def __call__(self, inputs, z, y, is_training):
106 |     return self.apply(inputs=inputs, z=z, y=y, is_training=is_training)
107 | 
108 |   def _get_conv(self, inputs, in_channels, out_channels, scale, suffix,
109 |                 kernel_size=(3, 3), strides=(1, 1)):
110 |     """Performs a convolution in the ResNet block."""
111 |     if inputs.get_shape().as_list()[-1] != in_channels:
112 |       raise ValueError("Unexpected number of input channels.")
113 |     if scale not in ["up", "down", "none"]:
114 |       raise ValueError(
115 |           "Scale: got {}, expected 'up', 'down', or 'none'.".format(scale))
116 | 
117 |     outputs = inputs
118 |     if scale == "up":
119 |       outputs = unpool(outputs)
120 |     outputs = ops.conv2d(
121 |         outputs,
122 |         output_dim=out_channels,
123 |         k_h=kernel_size[0], k_w=kernel_size[1],
124 |         d_h=strides[0], d_w=strides[1],
125 |         use_sn=self._spectral_norm,
126 |         name="{}_{}".format("same" if scale == "none" else scale, suffix))
127 |     if scale == "down":
128 |       outputs = tf.nn.pool(outputs, [2, 2], "AVG", "SAME", strides=[2, 2],
129 |                            name="pool_%s" % suffix)
130 |     return outputs
131 | 
132 |   def apply(self, inputs, z, y, is_training):
133 |     """ResNet block containing possible down/up sampling, shared for G / D.
134 |     Args:
135 |       inputs: a 3d input tensor of feature map.
136 |       z: the latent vector for potential self-modulation. Can be None if use_sbn
137 |         is set to False.
138 |       y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
139 |         labels.
140 |       is_training: boolean, whether or not this is called during training.
141 |     Returns:
142 |       output: a 3d output tensor of feature map.
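Note: as configured in the constructor, generator blocks resample in the
first convolution while discriminator blocks resample in the second.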
143 | """ 144 | if inputs.get_shape().as_list()[-1] != self._in_channels: 145 | raise ValueError("Unexpected number of input channels.") 146 | 147 | with tf.variable_scope(self._name, values=[inputs]): 148 | output = inputs 149 | 150 | shortcut = self._get_conv( 151 | output, self._in_channels, self._out_channels, self._scale, 152 | suffix="conv_shortcut") 153 | 154 | output = self.batch_norm( 155 | output, z=z, y=y, is_training=is_training, name="bn1") 156 | if self._layer_norm: 157 | output = ops.layer_norm(output, is_training=is_training, scope="ln1") 158 | 159 | output = tf.nn.relu(output) 160 | output = self._get_conv( 161 | output, self._in_channels, self._out_channels, self._scale1, 162 | suffix="conv1") 163 | 164 | output = self.batch_norm( 165 | output, z=z, y=y, is_training=is_training, name="bn2") 166 | if self._layer_norm: 167 | output = ops.layer_norm(output, is_training=is_training, scope="ln2") 168 | 169 | output = tf.nn.relu(output) 170 | output = self._get_conv( 171 | output, self._out_channels, self._out_channels, self._scale2, 172 | suffix="conv2") 173 | 174 | # Combine skip-connection with the convolved part. 175 | output += shortcut 176 | return output 177 | 178 | 179 | class ResNetGenerator(abstract_arch.AbstractGenerator): 180 | """Abstract base class for generators based on the ResNet architecture.""" 181 | 182 | def _resnet_block(self, name, in_channels, out_channels, scale): 183 | """ResNet block for the generator.""" 184 | if scale not in ["up", "none"]: 185 | raise ValueError( 186 | "Unknown generator ResNet block scaling: {}.".format(scale)) 187 | return ResNetBlock( 188 | name=name, 189 | in_channels=in_channels, 190 | out_channels=out_channels, 191 | scale=scale, 192 | is_gen_block=True, 193 | spectral_norm=self._spectral_norm, 194 | batch_norm=self.batch_norm) 195 | 196 | 197 | class ResNetDiscriminator(abstract_arch.AbstractDiscriminator): 198 | """Abstract base class for discriminators based on the ResNet architecture.""" 199 | 200 | def _resnet_block(self, name, in_channels, out_channels, scale): 201 | """ResNet block for the generator.""" 202 | if scale not in ["down", "none"]: 203 | raise ValueError( 204 | "Unknown discriminator ResNet block scaling: {}.".format(scale)) 205 | return ResNetBlock( 206 | name=name, 207 | in_channels=in_channels, 208 | out_channels=out_channels, 209 | scale=scale, 210 | is_gen_block=False, 211 | layer_norm=self._layer_norm, 212 | spectral_norm=self._spectral_norm, 213 | batch_norm=self.batch_norm) -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as layers 3 | import tensorflow.contrib.slim as slim 4 | from util import log 5 | import functools 6 | 7 | def weight_initializer(initializer="org", stddev=0.02): 8 | """Returns the initializer for the given name. 9 | Args: 10 | initializer: Name of the initalizer. Use one in consts.INITIALIZERS. 11 | stddev: Standard deviation passed to initalizer. 12 | Returns: 13 | Initializer from `tf.initializers`. 
14 | """ 15 | if initializer == "normal": 16 | return tf.initializers.random_normal(stddev=stddev) 17 | if initializer == "truncated": 18 | return tf.initializers.truncated_normal(stddev=stddev) 19 | if initializer == "org": 20 | return tf.initializers.orthogonal() 21 | raise ValueError("Unknown weight initializer {}.".format(initializer)) 22 | 23 | def lrelu(inputs, leak=0.2, name = "lrelu"): 24 | """Performs leaky-ReLU on the input.""" 25 | return tf.maximum(inputs, leak * inputs, name=name) 26 | 27 | def spectral_norm(inputs, epsilon=1e-12, singular_value="auto"): 28 | """Performs Spectral Normalization on a weight tensor. 29 | Details of why this is helpful for GAN's can be found in "Spectral 30 | Normalization for Generative Adversarial Networks", Miyato T. et al., 2018. 31 | [https://arxiv.org/abs/1802.05957]. 32 | Args: 33 | inputs: The weight tensor to normalize. 34 | epsilon: Epsilon for L2 normalization. 35 | singular_value: Which first singular value to store (left or right). Use 36 | "auto" to automatically choose the one that has fewer dimensions. 37 | Returns: 38 | The normalized weight tensor. 39 | """ 40 | if len(inputs.shape) < 2: 41 | raise ValueError("Spectral norm can only be applied to multi-dimensional tensors") 42 | 43 | # The paper says to flatten convnet kernel weights from (C_out, C_in, KH, KW) 44 | # to (C_out, C_in * KH * KW). Our Conv2D kernel shape is (KH, KW, C_in, C_out) 45 | # so it should be reshaped to (KH * KW * C_in, C_out), and similarly for other 46 | # layers that put output channels as last dimension. This implies that w 47 | # here is equivalent to w.T in the paper. 48 | w = tf.reshape(inputs, (-1, inputs.shape[-1])) 49 | 50 | # Choose whether to persist the first left or first right singular vector. 51 | # As the underlying matrix is PSD, this should be equivalent, but in practice 52 | # the shape of the persisted vector is different. Here one can choose whether 53 | # to maintain the left or right one, or pick the one which has the smaller 54 | # dimension. We use the same variable for the singular vector if we switch 55 | # from normal weights to EMA weights. 56 | var_name = inputs.name.replace("/ExponentialMovingAverage", "").split("/")[-1] 57 | var_name = var_name.split(":")[0] + "/u_var" 58 | if singular_value == "auto": 59 | singular_value = "left" if w.shape[0] <= w.shape[1] else "right" 60 | u_shape = (w.shape[0], 1) if singular_value == "left" else (1, w.shape[-1]) 61 | u_var = tf.get_variable( 62 | var_name, 63 | shape=u_shape, 64 | dtype=w.dtype, 65 | initializer=tf.random_normal_initializer(), 66 | trainable=False) 67 | u = u_var 68 | 69 | # Use power iteration method to approximate the spectral norm. 70 | # The authors suggest that one round of power iteration was sufficient in the 71 | # actual experiment to achieve satisfactory performance. 72 | power_iteration_rounds = 1 73 | for _ in range(power_iteration_rounds): 74 | if singular_value == "left": 75 | # `v` approximates the first right singular vector of matrix `w`. 76 | v = tf.math.l2_normalize( 77 | tf.matmul(tf.transpose(w), u), axis=None, epsilon=epsilon) 78 | u = tf.math.l2_normalize(tf.matmul(w, v), axis=None, epsilon=epsilon) 79 | else: 80 | v = tf.math.l2_normalize(tf.matmul(u, w, transpose_b=True), 81 | epsilon=epsilon) 82 | u = tf.math.l2_normalize(tf.matmul(v, w), epsilon=epsilon) 83 | 84 | # Update the approximation. 
85 | with tf.control_dependencies([tf.assign(u_var, u, name="update_u")]): 86 | u = tf.identity(u) 87 | 88 | # The authors of SN-GAN chose to stop gradient propagating through u and v 89 | # and we maintain that option. 90 | u = tf.stop_gradient(u) 91 | v = tf.stop_gradient(v) 92 | 93 | if singular_value == "left": 94 | norm_value = tf.matmul(tf.matmul(tf.transpose(u), w), v) 95 | else: 96 | norm_value = tf.matmul(tf.matmul(v, w), u, transpose_b=True) 97 | norm_value.shape.assert_is_fully_defined() 98 | norm_value.shape.assert_is_compatible_with([1, 1]) 99 | 100 | w_normalized = w / norm_value 101 | 102 | # Deflate normalized weights to match the unnormalized tensor. 103 | w_tensor_normalized = tf.reshape(w_normalized, inputs.shape) 104 | return w_tensor_normalized 105 | 106 | def conv2d(inputs, output_dim, k_h, k_w, d_h, d_w, stddev=0.02, 107 | name="conv2d", use_sn=False, use_bias=True): 108 | with tf.variable_scope(name): 109 | w = tf.get_variable( 110 | "kernel", [k_h, k_w, inputs.shape[-1].value, output_dim], 111 | initializer=weight_initializer(stddev=stddev)) 112 | if use_sn: 113 | w = spectral_norm(w) 114 | outputs = tf.nn.conv2d(inputs, w, strides=[1, d_h, d_w, 1], padding="SAME") 115 | if use_bias: 116 | bias = tf.get_variable( 117 | "bias", [output_dim], initializer=tf.constant_initializer(0.0)) 118 | outputs += bias 119 | return outputs 120 | 121 | def deconv2d(inputs, output_shape, k_h, k_w, d_h, d_w, stddev=0.02, 122 | name='deconv2d', use_sn=False): 123 | with tf.variable_scope(name): 124 | w = tf.get_variable( 125 | "kernel", [k_h, k_w, output_shape[-1], inputs.get_shape()[-1]], 126 | initializer=weight_initializer(stddev=stddev)) 127 | if use_sn: 128 | w = spectral_norm(w) 129 | deconv = tf.nn.conv2d_transpose( 130 | inputs, w, output_shape=output_shape, strides=[1, d_h, d_w, 1]) 131 | bias = tf.get_variable( 132 | "bias", [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 133 | return tf.reshape(tf.nn.bias_add(deconv, bias), tf.shape(deconv)) 134 | 135 | def linear(inputs, output_size, scope=None, stddev=0.02, bias_start=0.0, use_sn=False, use_bias=True): 136 | shape = inputs.get_shape().as_list() 137 | with tf.variable_scope(scope or "Linear"): 138 | kernel = tf.get_variable( 139 | "kernel", 140 | [shape[1], output_size], 141 | initializer=weight_initializer(stddev=stddev)) 142 | if use_sn: 143 | kernel = spectral_norm(kernel) 144 | outputs = tf.matmul(inputs, kernel) 145 | if use_bias: 146 | bias = tf.get_variable( 147 | "bias", 148 | [output_size], 149 | initializer=tf.constant_initializer(bias_start)) 150 | outputs += bias 151 | return outputs 152 | 153 | conv1x1 = functools.partial(conv2d, k_h=1, k_w=1, d_h=1, d_w=1) 154 | 155 | def non_local_block(x, name, use_sn): 156 | """Self-attention (non-local) block. 157 | This method is used to exactly reproduce SAGAN and ignores Gin settings on 158 | weight initialization and spectral normalization. 159 | Args: 160 | x: Input tensor of shape [batch, h, w, c]. 161 | name: Name of the variable scope. 162 | use_sn: Apply spectral norm to the weights. 163 | Returns: 164 | A tensor of the same shape after self-attention was applied. 
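Since sigma is initialized to zero, the block starts out as an identity
mapping, and the attention branch is blended in as sigma is learned.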
165 | """ 166 | def _spatial_flatten(inputs): 167 | shape = inputs.shape 168 | return tf.reshape(inputs, (-1, shape[1] * shape[2], shape[3])) 169 | 170 | with tf.variable_scope(name): 171 | h, w, num_channels = x.get_shape().as_list()[1:] 172 | num_channels_attn = num_channels // 8 173 | num_channels_g = num_channels // 2 174 | 175 | # Theta path 176 | theta = conv1x1(x, num_channels_attn, name="conv2d_theta", use_sn=use_sn, 177 | use_bias=False) 178 | theta = _spatial_flatten(theta) 179 | 180 | # Phi path 181 | phi = conv1x1(x, num_channels_attn, name="conv2d_phi", use_sn=use_sn, use_bias=False) 182 | phi = tf.layers.max_pooling2d(inputs=phi, pool_size=[2, 2], strides=2) 183 | phi = _spatial_flatten(phi) 184 | 185 | attn = tf.matmul(theta, phi, transpose_b=True) 186 | attn = tf.nn.softmax(attn) 187 | 188 | # G path 189 | g = conv1x1(x, num_channels_g, name="conv2d_g", use_sn=use_sn, 190 | use_bias=False) 191 | g = tf.layers.max_pooling2d(inputs=g, pool_size=[2, 2], strides=2) 192 | g = _spatial_flatten(g) 193 | 194 | attn_g = tf.matmul(attn, g) 195 | attn_g = tf.reshape(attn_g, [-1, h, w, num_channels_g]) 196 | sigma = tf.get_variable("sigma", [], initializer=tf.zeros_initializer()) 197 | attn_g = conv1x1(attn_g, num_channels, name="conv2d_attn_g", use_sn=use_sn, 198 | use_bias=False) 199 | 200 | return x + sigma * attn_g 201 | 202 | def conditional_batch_norm(inputs, y, is_training, use_sn, center=True, 203 | scale=True, name="batch_norm", use_bias=False): 204 | """Conditional batch normalization.""" 205 | if y is None: 206 | raise ValueError("You must provide y for conditional batch normalization.") 207 | if y.shape.ndims != 2: 208 | raise ValueError("Conditioning must have rank 2.") 209 | with tf.variable_scope(name, values=[inputs]): 210 | outputs = tf.contrib.layers.batch_norm(inputs, center=False, scale=False, decay=0.9, epsilon=1e-5, updates_collections=None, is_training=is_training) 211 | num_channels = inputs.shape[-1].value 212 | with tf.variable_scope("condition", values=[inputs, y]): 213 | if scale: 214 | gamma = linear(y, num_channels, scope="gamma", use_sn=use_sn, 215 | use_bias=use_bias) 216 | gamma = tf.reshape(gamma, [-1, 1, 1, num_channels]) 217 | outputs *= gamma 218 | if center: 219 | beta = linear(y, num_channels, scope="beta", use_sn=use_sn, 220 | use_bias=use_bias) 221 | beta = tf.reshape(beta, [-1, 1, 1, num_channels]) 222 | outputs += beta 223 | return outputs -------------------------------------------------------------------------------- /BigGAN/resnet_gen_y.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from absl import logging 6 | 7 | import arch_ops as ops 8 | import resnet_ops 9 | 10 | import utils 11 | 12 | import gin 13 | from six.moves import range 14 | import tensorflow as tf 15 | 16 | class BigGanResNetBlock(resnet_ops.ResNetBlock): 17 | """ResNet block with options for various normalizations. 18 | This block uses a 1x1 convolution for the (optional) shortcut connection. 19 | """ 20 | 21 | def __init__(self, 22 | add_shortcut=True, 23 | **kwargs): 24 | """Constructs a new ResNet block for BigGAN. 25 | Args: 26 | add_shortcut: Whether to add a shortcut connection. 27 | **kwargs: Additional arguments for ResNetBlock. 
28 | """ 29 | super(BigGanResNetBlock, self).__init__(**kwargs) 30 | self._add_shortcut = add_shortcut 31 | 32 | def apply(self, inputs, z, y, is_training): 33 | """"ResNet block containing possible down/up sampling, shared for G / D. 34 | Args: 35 | inputs: a 3d input tensor of feature map. 36 | z: the latent vector for potential self-modulation. Can be None if use_sbn 37 | is set to False. 38 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 39 | labels. 40 | is_training: boolean, whether or notthis is called during the training. 41 | Returns: 42 | output: a 3d output tensor of feature map. 43 | """ 44 | if inputs.shape[-1].value != self._in_channels: 45 | raise ValueError( 46 | "Unexpected number of input channels (expected {}, got {}).".format( 47 | self._in_channels, inputs.shape[-1].value)) 48 | 49 | with tf.variable_scope(self._name, values=[inputs]): 50 | outputs = inputs 51 | 52 | outputs = self.batch_norm( 53 | outputs, z=z, y=y, is_training=is_training, name="bn1") 54 | if self._layer_norm: 55 | outputs = ops.layer_norm(outputs, is_training=is_training, scope="ln1") 56 | 57 | outputs = tf.nn.relu(outputs) 58 | outputs = self._get_conv( 59 | outputs, self._in_channels, self._out_channels, self._scale1, 60 | suffix="conv1") 61 | 62 | outputs = self.batch_norm( 63 | outputs, z=z, y=y, is_training=is_training, name="bn2") 64 | if self._layer_norm: 65 | outputs = ops.layer_norm(outputs, is_training=is_training, scope="ln2") 66 | 67 | outputs = tf.nn.relu(outputs) 68 | outputs = self._get_conv( 69 | outputs, self._out_channels, self._out_channels, self._scale2, 70 | suffix="conv2") 71 | 72 | # Combine skip-connection with the convolved part. 73 | if self._add_shortcut: 74 | shortcut = self._get_conv( 75 | inputs, self._in_channels, self._out_channels, self._scale, 76 | kernel_size=(1, 1), 77 | suffix="conv_shortcut") 78 | outputs += shortcut 79 | logging.info("[Block] %s (z=%s, y=%s) -> %s", inputs.shape, 80 | None if z is None else z.shape, 81 | None if y is None else y.shape, outputs.shape) 82 | return outputs 83 | 84 | class Generator(object): 85 | """ResNet-based generator supporting resolutions 32, 64, 128, 256, 512.""" 86 | 87 | def __init__(self, 88 | name, h, w, c, is_train, use_sn, 89 | ch=96, 90 | blocks_with_attention="B4", 91 | hierarchical_z=True, 92 | embed_z=False, 93 | embed_y=True, 94 | embed_y_dim=128, 95 | embed_bias=False, 96 | **kwargs): 97 | """Constructor for BigGAN generator. 98 | Args: 99 | ch: Channel multiplier. 100 | blocks_with_attention: Comma-separated list of blocks that are followed by 101 | a non-local block. 102 | hierarchical_z: Split z into chunks and only give one chunk to each. 103 | Each chunk will also be concatenated to y, the one hot encoded labels. 104 | embed_z: If True use a learnable embedding of z that is used instead. 105 | The embedding will have the length of z. 106 | embed_y: If True use a learnable embedding of y that is used instead. 107 | embed_y_dim: Size of the embedding of y. 108 | embed_bias: Use bias with for the embedding of z and y. 109 | **kwargs: additional arguments past on to ResNetGenerator. 
110 | """ 111 | self.name = name 112 | self.s_h, self.s_w, self.colors = [h,w,c] 113 | self._image_shape = [h,w,c] 114 | self._is_train = is_train 115 | self._batch_norm_fn = ops.conditional_batch_norm 116 | self._spectral_norm = True 117 | 118 | self._ch = ch 119 | self._blocks_with_attention = set(blocks_with_attention.split(",")) 120 | self._hierarchical_z = hierarchical_z 121 | self._embed_z = embed_z 122 | self._embed_y = embed_y 123 | self._embed_y_dim = embed_y_dim 124 | self._embed_bias = embed_bias 125 | 126 | def _resnet_block(self, name, in_channels, out_channels, scale): 127 | """ResNet block for the generator.""" 128 | if scale not in ["up", "none"]: 129 | raise ValueError( 130 | "Unknown generator ResNet block scaling: {}.".format(scale)) 131 | return BigGanResNetBlock( 132 | name=name, 133 | in_channels=in_channels, 134 | out_channels=out_channels, 135 | scale=scale, 136 | is_gen_block=True, 137 | spectral_norm=self._spectral_norm, 138 | batch_norm=self.batch_norm) 139 | 140 | def batch_norm(self, inputs, **kwargs): 141 | if self._batch_norm_fn is None: 142 | return inputs 143 | args = kwargs.copy() 144 | args["inputs"] = inputs 145 | if "use_sn" not in args: 146 | args["use_sn"] = self._spectral_norm 147 | return utils.call_with_accepted_args(self._batch_norm_fn, **args) 148 | 149 | def _get_in_out_channels(self): 150 | resolution = self._image_shape[0] 151 | if resolution == 512: 152 | channel_multipliers = [16, 16, 8, 8, 4, 2, 1, 1] 153 | elif resolution == 256: 154 | channel_multipliers = [16, 16, 8, 8, 4, 2, 1] 155 | elif resolution == 128: 156 | channel_multipliers = [16, 16, 8, 4, 2, 1] 157 | elif resolution == 64: 158 | channel_multipliers = [16, 16, 8, 4, 2] 159 | elif resolution == 32: 160 | channel_multipliers = [4, 4, 4, 4] 161 | else: 162 | raise ValueError("Unsupported resolution: {}".format(resolution)) 163 | in_channels = [self._ch * c for c in channel_multipliers[:-1]] 164 | out_channels = [self._ch * c for c in channel_multipliers[1:]] 165 | return in_channels, out_channels 166 | 167 | def __call__(self, z, y): 168 | with tf.variable_scope(self.name, values=[z, y], reuse=tf.AUTO_REUSE): 169 | """Build the generator network for the given inputs. 170 | Args: 171 | z: `Tensor` of shape [batch_size, z_dim] with latent code. 172 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 173 | labels. 174 | is_training: boolean, are we in train or eval model. 175 | Returns: 176 | A tensor of size [batch_size] + self._image_shape with values in [0, 1]. 177 | """ 178 | shape_or_none = lambda t: None if t is None else t.shape 179 | logging.info("[Generator] inputs are z=%s, y=%s", z.shape, shape_or_none(y)) 180 | # Each block upscales by a factor of 2. 
181 | seed_size = 4 182 | z_dim = z.shape[1].value 183 | 184 | in_channels, out_channels = self._get_in_out_channels() 185 | num_blocks = len(in_channels) 186 | 187 | if self._embed_z: 188 | z = ops.linear(z, z_dim, scope="embed_z", use_sn=False, 189 | use_bias=self._embed_bias) 190 | if self._embed_y: 191 | y = ops.linear(y, self._embed_y_dim, scope="embed_y", use_sn=False, 192 | use_bias=self._embed_bias) 193 | y_per_block = num_blocks * [y] 194 | if self._hierarchical_z: 195 | z_per_block = tf.split(z, num_blocks + 1, axis=1) 196 | z0, z_per_block = z_per_block[0], z_per_block[1:] 197 | if y is not None: 198 | y_per_block = [tf.concat([zi, y], 1) for zi in z_per_block] 199 | else: 200 | z0 = z 201 | z_per_block = num_blocks * [z] 202 | 203 | logging.info("[Generator] z0=%s, z_per_block=%s, y_per_block=%s", 204 | z0.shape, [str(shape_or_none(t)) for t in z_per_block], 205 | [str(shape_or_none(t)) for t in y_per_block]) 206 | 207 | # Map noise to the actual seed. 208 | net = ops.linear( 209 | z0, 210 | in_channels[0] * seed_size * seed_size, 211 | scope="fc_noise", 212 | use_sn=self._spectral_norm) 213 | # Reshape the seed to be a rank-4 Tensor. 214 | net = tf.reshape( 215 | net, 216 | [-1, seed_size, seed_size, in_channels[0]], 217 | name="fc_reshaped") 218 | 219 | for block_idx in range(num_blocks): 220 | name = "B{}".format(block_idx + 1) 221 | block = self._resnet_block( 222 | name=name, 223 | in_channels=in_channels[block_idx], 224 | out_channels=out_channels[block_idx], 225 | scale="up") 226 | net = block( 227 | net, 228 | z=z_per_block[block_idx], 229 | y=y_per_block[block_idx], 230 | is_training=self._is_train) 231 | if name in self._blocks_with_attention: 232 | logging.info("[Generator] Applying non-local block to %s", net.shape) 233 | net = ops.non_local_block(net, "non_local_block", 234 | use_sn=self._spectral_norm) 235 | # Final processing of the net. 236 | # Use unconditional batch norm. 237 | logging.info("[Generator] before final processing: %s", net.shape) 238 | net = ops.batch_norm(net, is_training=self._is_train, name="final_norm") 239 | net = tf.nn.relu(net) 240 | net = ops.conv2d(net, output_dim=self._image_shape[2], k_h=3, k_w=3, 241 | d_h=1, d_w=1, name="final_conv", 242 | use_sn=self._spectral_norm) 243 | logging.info("[Generator] after final processing: %s", net.shape) 244 | net = (tf.nn.tanh(net) + 1.0) / 2.0 245 | return net -------------------------------------------------------------------------------- /BigGAN/resnet_joint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from absl import logging 6 | 7 | import arch_ops as ops 8 | import resnet_ops 9 | 10 | import gin 11 | from six.moves import range 12 | import tensorflow as tf 13 | 14 | 15 | class BigGanResNetBlock(resnet_ops.ResNetBlock): 16 | """ResNet block with options for various normalizations. 17 | This block uses a 1x1 convolution for the (optional) shortcut connection. 18 | """ 19 | 20 | def __init__(self, 21 | add_shortcut=True, 22 | **kwargs): 23 | """Constructs a new ResNet block for BigGAN. 24 | Args: 25 | add_shortcut: Whether to add a shortcut connection. 26 | **kwargs: Additional arguments for ResNetBlock. 
27 | """ 28 | super(BigGanResNetBlock, self).__init__(**kwargs) 29 | self._add_shortcut = add_shortcut 30 | 31 | def apply(self, inputs, z, y, is_training): 32 | """"ResNet block containing possible down/up sampling, shared for G / D. 33 | Args: 34 | inputs: a 3d input tensor of feature map. 35 | z: the latent vector for potential self-modulation. Can be None if use_sbn 36 | is set to False. 37 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 38 | labels. 39 | is_training: boolean, whether or notthis is called during the training. 40 | Returns: 41 | output: a 3d output tensor of feature map. 42 | """ 43 | if inputs.shape[-1].value != self._in_channels: 44 | raise ValueError( 45 | "Unexpected number of input channels (expected {}, got {}).".format( 46 | self._in_channels, inputs.shape[-1].value)) 47 | 48 | with tf.variable_scope(self._name, values=[inputs]): 49 | outputs = inputs 50 | 51 | outputs = self.batch_norm( 52 | outputs, z=z, y=y, is_training=is_training, name="bn1") 53 | if self._layer_norm: 54 | outputs = ops.layer_norm(outputs, is_training=is_training, scope="ln1") 55 | 56 | outputs = tf.nn.relu(outputs) 57 | outputs = self._get_conv( 58 | outputs, self._in_channels, self._out_channels, self._scale1, 59 | suffix="conv1") 60 | 61 | outputs = self.batch_norm( 62 | outputs, z=z, y=y, is_training=is_training, name="bn2") 63 | if self._layer_norm: 64 | outputs = ops.layer_norm(outputs, is_training=is_training, scope="ln2") 65 | 66 | outputs = tf.nn.relu(outputs) 67 | outputs = self._get_conv( 68 | outputs, self._out_channels, self._out_channels, self._scale2, 69 | suffix="conv2") 70 | 71 | # Combine skip-connection with the convolved part. 72 | if self._add_shortcut: 73 | shortcut = self._get_conv( 74 | inputs, self._in_channels, self._out_channels, self._scale, 75 | kernel_size=(1, 1), 76 | suffix="conv_shortcut") 77 | outputs += shortcut 78 | logging.info("[Block] %s (z=%s, y=%s) -> %s", inputs.shape, 79 | None if z is None else z.shape, 80 | None if y is None else y.shape, outputs.shape) 81 | return outputs 82 | 83 | class Classifier_proD(object): 84 | """ResNet-based discriminator supporting resolutions 32, 64, 128, 256, 512.""" 85 | 86 | def __init__(self, 87 | name, num_class, use_sn, 88 | ch=96, 89 | blocks_with_attention="B1", 90 | project_y=True, 91 | **kwargs): 92 | """Constructor for BigGAN discriminator. 93 | Args: 94 | ch: Channel multiplier. 95 | blocks_with_attention: Comma-separated list of blocks that are followed by 96 | a non-local block. 97 | project_y: Add an embedding of y in the output layer. 98 | **kwargs: additional arguments past on to ResNetDiscriminator. 
99 | """ 100 | self.name = name 101 | self.num_class = num_class 102 | self._spectral_norm = True 103 | self._batch_norm_fn = None 104 | self._layer_norm = False 105 | 106 | self._ch = ch 107 | self._blocks_with_attention = set(blocks_with_attention.split(",")) 108 | self._project_y = project_y 109 | 110 | def _resnet_block(self, name, in_channels, out_channels, scale): 111 | """ResNet block for the generator.""" 112 | if scale not in ["down", "none"]: 113 | raise ValueError( 114 | "Unknown discriminator ResNet block scaling: {}.".format(scale)) 115 | return BigGanResNetBlock( 116 | name=name, 117 | in_channels=in_channels, 118 | out_channels=out_channels, 119 | scale=scale, 120 | is_gen_block=False, 121 | add_shortcut=in_channels != out_channels, 122 | layer_norm=self._layer_norm, 123 | spectral_norm=self._spectral_norm, 124 | batch_norm=self.batch_norm) 125 | 126 | def batch_norm(self, inputs, **kwargs): 127 | if self._batch_norm_fn is None: 128 | return inputs 129 | args = kwargs.copy() 130 | args["inputs"] = inputs 131 | if "use_sn" not in args: 132 | args["use_sn"] = self._spectral_norm 133 | return utils.call_with_accepted_args(self._batch_norm_fn, **args) 134 | 135 | def _get_in_out_channels(self, colors, resolution): 136 | if colors not in [1, 3]: 137 | raise ValueError("Unsupported color channels: {}".format(colors)) 138 | if resolution == 512: 139 | channel_multipliers = [1, 1, 2, 4, 8, 8, 16, 16] 140 | elif resolution == 256: 141 | channel_multipliers = [1, 2, 4, 8, 8, 16, 16] 142 | elif resolution == 128: 143 | channel_multipliers = [1, 2, 4, 8, 16, 16] 144 | elif resolution == 64: 145 | channel_multipliers = [2, 4, 8, 16, 16] 146 | elif resolution == 32: 147 | channel_multipliers = [2, 2, 2, 2] 148 | else: 149 | raise ValueError("Unsupported resolution: {}".format(resolution)) 150 | out_channels = [self._ch * c for c in channel_multipliers] 151 | in_channels = [colors] + out_channels[:-1] 152 | return in_channels, out_channels 153 | 154 | def __call__(self, x, y, _, __): 155 | with tf.variable_scope(self.name, values=[x, y], reuse=tf.AUTO_REUSE): 156 | """Apply the discriminator on a input. 157 | Args: 158 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images. 159 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 160 | labels. 161 | is_training: Boolean, whether the architecture should be constructed for 162 | training or inference. 163 | Returns: 164 | Tuple of 3 Tensors, the final prediction of the discriminator, the logits 165 | before the final output activation function and logits form the second 166 | last layer. 
167 | """ 168 | logging.info("[Discriminator] inputs are x=%s, y=%s", x.shape, 169 | None if y is None else y.shape) 170 | resnet_ops.validate_image_inputs(x) 171 | 172 | in_channels, out_channels = self._get_in_out_channels( 173 | colors=x.shape[-1].value, resolution=x.shape[1].value) 174 | num_blocks = len(in_channels) 175 | 176 | net = x 177 | for block_idx in range(num_blocks): 178 | name = "B{}".format(block_idx + 1) 179 | is_last_block = block_idx == num_blocks - 1 180 | block = self._resnet_block( 181 | name=name, 182 | in_channels=in_channels[block_idx], 183 | out_channels=out_channels[block_idx], 184 | scale="none" if is_last_block else "down") 185 | net = block(net, z=None, y=y, is_training=None) 186 | if name in self._blocks_with_attention: 187 | logging.info("[Discriminator] Applying non-local block to %s", 188 | net.shape) 189 | net = ops.non_local_block(net, "non_local_block", 190 | use_sn=self._spectral_norm) 191 | 192 | # Final part 193 | logging.info("[Discriminator] before final processing: %s", net.shape) 194 | net_conv = tf.nn.relu(net) 195 | h = tf.math.reduce_sum(net_conv, axis=[1, 2]) 196 | out_logit_tf = ops.linear(h, 1, scope="final_fc", use_sn=self._spectral_norm) 197 | logging.info("[Discriminator] after final processing: %s", net.shape) 198 | if self._project_y: 199 | if y is None: 200 | raise ValueError("You must provide class information y to project.") 201 | with tf.variable_scope("embedding_fc"): 202 | y_embedding_dim = out_channels[-1] 203 | # We do not use ops.linear() below since it does not have an option to 204 | # override the initializer. 205 | kernel = tf.get_variable( 206 | "kernel", [y.shape[1], y_embedding_dim], tf.float32, 207 | initializer=tf.initializers.glorot_normal()) 208 | if self._spectral_norm: 209 | kernel = ops.spectral_norm(kernel) 210 | embedded_y = tf.matmul(y, kernel) 211 | logging.info("[Discriminator] embedded_y for projection: %s", 212 | embedded_y.shape) 213 | out_logit_tf += tf.reduce_sum(embedded_y * h, axis=1, keepdims=True) 214 | 215 | feature_matching = h 216 | t_SNE = h 217 | out_logit = ops.linear(h, self.num_class, scope="final_fc_cla", use_sn=self._spectral_norm) 218 | 219 | # grad cam 220 | cls = tf.argmax(y,axis=1)[0] 221 | out_logit_cam = tf.identity(out_logit) 222 | y_c = out_logit_cam[0, cls] 223 | grads = tf.gradients(y_c, net_conv)[0] 224 | output_conv = net_conv[0] 225 | grads_val = grads[0] 226 | 227 | # grad cam ++ 228 | first = tf.exp(y_c)*grads 229 | second = tf.exp(y_c)*grads*grads 230 | third = tf.exp(y_c)*grads*grads*grads 231 | conv_first_grad, conv_second_grad, conv_third_grad = first[0], second[0], third[0] 232 | grad_val_plusplus = [conv_first_grad, conv_second_grad, conv_third_grad] 233 | 234 | # saliency 235 | signal = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=out_logit[0, :], labels=y[0])) 236 | guide_grad = tf.gradients(signal, x)[0] 237 | 238 | return tf.nn.softmax(out_logit), out_logit, tf.nn.sigmoid(out_logit_tf), out_logit_tf, feature_matching, t_SNE, output_conv, grads_val, guide_grad, grad_val_plusplus, tf.argmax(y,axis=1)[0], tf.argmax(out_logit_cam,axis=1)[0] 239 | 240 | 241 | -------------------------------------------------------------------------------- /BigGAN/resnet_gen_y_deep.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import functools 6 | 7 | from absl import logging 8 | 9 | import arch_ops as 
ops 10 | import resnet_ops 11 | 12 | import utils 13 | 14 | from six.moves import range 15 | import tensorflow as tf 16 | import time 17 | 18 | class BigGanDeepResNetBlock(object): 19 | """ResNet block with bottleneck and identity preserving skip connections.""" 20 | 21 | def __init__(self, 22 | name, 23 | in_channels, 24 | out_channels, 25 | scale, 26 | spectral_norm=False, 27 | batch_norm=None): 28 | """Constructs a new ResNet block with bottleneck. 29 | Args: 30 | name: Scope name for the ResNet block. 31 | in_channels: Integer, the input channel size. 32 | out_channels: Integer, the output channel size. 33 | scale: Whether or not to scale up or down, choose from "up", "down" or 34 | "none". 35 | spectral_norm: Use spectral normalization for all weights. 36 | batch_norm: Function for batch normalization. 37 | """ 38 | assert scale in ["up", "down", "none"] 39 | self._name = name 40 | self._in_channels = in_channels 41 | self._out_channels = out_channels 42 | self._scale = scale 43 | self._spectral_norm = spectral_norm 44 | self.batch_norm = batch_norm 45 | 46 | def __call__(self, inputs, z, y, is_training): 47 | return self.apply(inputs=inputs, z=z, y=y, is_training=is_training) 48 | 49 | def _shortcut(self, inputs): 50 | """Constructs a skip connection from inputs.""" 51 | with tf.variable_scope("shortcut", values=[inputs]): 52 | shortcut = inputs 53 | num_channels = inputs.shape[-1].value 54 | if num_channels > self._out_channels: 55 | assert self._scale == "up" 56 | # Drop redundant channels. 57 | logging.info("[Shortcut] Dropping %d channels in shortcut.", 58 | num_channels - self._out_channels) 59 | shortcut = shortcut[:, :, :, :self._out_channels] 60 | if self._scale == "up": 61 | shortcut = resnet_ops.unpool(shortcut) 62 | if self._scale == "down": 63 | shortcut = tf.nn.pool(shortcut, [2, 2], "AVG", "SAME", 64 | strides=[2, 2], name="pool") 65 | if num_channels < self._out_channels: 66 | assert self._scale == "down" 67 | # Increase number of channels if necessary. 68 | num_missing = self._out_channels - num_channels 69 | logging.info("[Shortcut] Adding %d channels in shortcut.", num_missing) 70 | added = ops.conv1x1(shortcut, num_missing, name="add_channels", 71 | use_sn=self._spectral_norm) 72 | shortcut = tf.concat([shortcut, added], axis=-1) 73 | return shortcut 74 | 75 | def apply(self, inputs, z, y, is_training): 76 | """ResNet block containing possible down/up sampling, shared for G / D. 77 | Args: 78 | inputs: a 3d input tensor of feature map. 79 | z: the latent vector for potential self-modulation. Can be None if use_sbn 80 | is set to False. 81 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 82 | labels. 83 | is_training: boolean, whether or not this is called during training. 84 | Returns: 85 | output: a 3d output tensor of feature map.
86 | """ 87 | if inputs.shape[-1].value != self._in_channels: 88 | raise ValueError( 89 | "Unexpected number of input channels (expected {}, got {}).".format( 90 | self._in_channels, inputs.shape[-1].value)) 91 | 92 | bottleneck_channels = max(self._in_channels, self._out_channels) // 4 93 | bn = functools.partial(self.batch_norm, z=z, y=y, is_training=is_training) 94 | conv1x1 = functools.partial(ops.conv1x1, use_sn=self._spectral_norm) 95 | conv3x3 = functools.partial(ops.conv2d, k_h=3, k_w=3, d_h=1, d_w=1, 96 | use_sn=self._spectral_norm) 97 | 98 | with tf.variable_scope(self._name, values=[inputs]): 99 | outputs = inputs 100 | 101 | with tf.variable_scope("conv1", values=[outputs]): 102 | outputs = bn(outputs, name="bn") 103 | outputs = tf.nn.relu(outputs) 104 | outputs = conv1x1(outputs, bottleneck_channels, name="1x1_conv") 105 | 106 | with tf.variable_scope("conv2", values=[outputs]): 107 | outputs = bn(outputs, name="bn") 108 | outputs = tf.nn.relu(outputs) 109 | if self._scale == "up": 110 | outputs = resnet_ops.unpool(outputs) 111 | outputs = conv3x3(outputs, bottleneck_channels, name="3x3_conv") 112 | 113 | with tf.variable_scope("conv3", values=[outputs]): 114 | outputs = bn(outputs, name="bn") 115 | outputs = tf.nn.relu(outputs) 116 | outputs = conv3x3(outputs, bottleneck_channels, name="3x3_conv") 117 | 118 | with tf.variable_scope("conv4", values=[outputs]): 119 | outputs = bn(outputs, name="bn") 120 | outputs = tf.nn.relu(outputs) 121 | if self._scale == "down": 122 | outputs = tf.nn.pool(outputs, [2, 2], "AVG", "SAME", strides=[2, 2], 123 | name="avg_pool") 124 | outputs = conv1x1(outputs, self._out_channels, name="1x1_conv") 125 | 126 | # Add skip-connection. 127 | outputs += self._shortcut(inputs) 128 | 129 | logging.info("[Block] %s (z=%s, y=%s) -> %s", inputs.shape, 130 | None if z is None else z.shape, 131 | None if y is None else y.shape, outputs.shape) 132 | return outputs 133 | 134 | 135 | class Generator(object): 136 | """ResNet-based generator supporting resolutions 32, 64, 128, 256, 512.""" 137 | 138 | def __init__(self, 139 | name, h, w, c, is_train, use_sn, 140 | ch=128, 141 | embed_y=True, 142 | embed_y_dim=128, 143 | experimental_fast_conv_to_rgb=False, 144 | **kwargs): 145 | """Constructor for BigGAN generator. 146 | Args: 147 | ch: Channel multiplier. 148 | embed_y: If True, use a learnable embedding of y instead of the raw one-hot labels. 149 | embed_y_dim: Size of the embedding of y. 150 | experimental_fast_conv_to_rgb: If True optimize the last convolution to 151 | sacrifice memory for better speed. 152 | **kwargs: additional arguments passed on to ResNetGenerator.
153 | """ 154 | self.name = name 155 | self.s_h, self.s_w, self.colors = [h,w,c] 156 | self._image_shape = [h,w,c] 157 | self._is_train = is_train 158 | self._batch_norm_fn = ops.conditional_batch_norm 159 | self._spectral_norm = False 160 | 161 | self._ch = ch 162 | self._embed_y = embed_y 163 | self._embed_y_dim = embed_y_dim 164 | self._experimental_fast_conv_to_rgb = experimental_fast_conv_to_rgb 165 | 166 | def _resnet_block(self, name, in_channels, out_channels, scale): 167 | """ResNet block for the generator.""" 168 | if scale not in ["up", "none"]: 169 | raise ValueError( 170 | "Unknown generator ResNet block scaling: {}.".format(scale)) 171 | return BigGanDeepResNetBlock( 172 | name=name, 173 | in_channels=in_channels, 174 | out_channels=out_channels, 175 | scale=scale, 176 | spectral_norm=self._spectral_norm, 177 | batch_norm=self.batch_norm) 178 | 179 | def _get_in_out_channels(self): 180 | # See Table 7-9. 181 | resolution = self._image_shape[0] 182 | if resolution == 512: 183 | channel_multipliers = 4 * [16] + 4 * [8] + [4, 4, 2, 2, 1, 1, 1] 184 | elif resolution == 256: 185 | channel_multipliers = 4 * [16] + 4 * [8] + [4, 4, 2, 2, 1] 186 | elif resolution == 128: 187 | channel_multipliers = 4 * [16] + 2 * [8] + [4, 4, 2, 2, 1] 188 | elif resolution == 64: 189 | channel_multipliers = 4 * [16] + 2 * [8] + [4, 4, 2] 190 | elif resolution == 32: 191 | channel_multipliers = 8 * [4] 192 | else: 193 | raise ValueError("Unsupported resolution: {}".format(resolution)) 194 | in_channels = [self._ch * c for c in channel_multipliers[:-1]] 195 | out_channels = [self._ch * c for c in channel_multipliers[1:]] 196 | return in_channels, out_channels 197 | 198 | def batch_norm(self, inputs, **kwargs): 199 | if self._batch_norm_fn is None: 200 | return inputs 201 | args = kwargs.copy() 202 | args["inputs"] = inputs 203 | if "use_sn" not in args: 204 | args["use_sn"] = self._spectral_norm 205 | return utils.call_with_accepted_args(self._batch_norm_fn, **args) 206 | 207 | def __call__(self, z, y): 208 | with tf.variable_scope(self.name, values=[z, y], reuse=tf.AUTO_REUSE): 209 | """Build the generator network for the given inputs. 210 | Args: 211 | z: `Tensor` of shape [batch_size, z_dim] with latent code. 212 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 213 | labels. 214 | is_training: boolean, are we in train or eval mode. 215 | Returns: 216 | A tensor of size [batch_size] + self._image_shape with values in [0, 1]. 217 | """ 218 | shape_or_none = lambda t: None if t is None else t.shape 219 | logging.info("[Generator] inputs are z=%s, y=%s", z.shape, shape_or_none(y)) 220 | seed_size = 4 221 | 222 | if self._embed_y: 223 | y = ops.linear(y, self._embed_y_dim, scope="embed_y", use_sn=False, 224 | use_bias=False) 225 | if y is not None: 226 | y = tf.concat([z, y], axis=1) 227 | z = y 228 | 229 | in_channels, out_channels = self._get_in_out_channels() 230 | num_blocks = len(in_channels) 231 | 232 | # Map noise to the actual seed. 233 | net = ops.linear( 234 | z, 235 | in_channels[0] * seed_size * seed_size, 236 | scope="fc_noise", 237 | use_sn=self._spectral_norm) 238 | # Reshape the seed to be a rank-4 Tensor.
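# Editor's note (not in the original source): fc_noise above maps the
# (possibly y-augmented) latent z to a flat vector of length
# in_channels[0] * seed_size * seed_size, and the reshape below turns it into
# the initial 4x4 spatial seed. For instance, at resolution 32 with the
# default ch=128, in_channels[0] = 512, so [batch, 8192] -> [batch, 4, 4, 512].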
239 | net = tf.reshape( 240 | net, 241 | [-1, seed_size, seed_size, in_channels[0]], 242 | name="fc_reshaped") 243 | 244 | for block_idx in range(num_blocks): 245 | scale = "none" if block_idx % 2 == 0 else "up" 246 | block = self._resnet_block( 247 | name="B{}".format(block_idx + 1), 248 | in_channels=in_channels[block_idx], 249 | out_channels=out_channels[block_idx], 250 | scale=scale) 251 | net = block(net, z=z, y=y, is_training=(self._is_train)) 252 | # At resolution 64x64 there is a self-attention block. 253 | if scale == "up" and net.shape[1].value == 64: 254 | logging.info("[Generator] Applying non-local block to %s", net.shape) 255 | net = ops.non_local_block(net, "non_local_block", 256 | use_sn=self._spectral_norm) 257 | # Final processing of the net. 258 | # Use unconditional batch norm. 259 | logging.info("[Generator] before final processing: %s", net.shape) 260 | net = ops.batch_norm(net, is_training=(self._is_train), name="final_norm") 261 | net = tf.nn.relu(net) 262 | colors = self._image_shape[2] 263 | if self._experimental_fast_conv_to_rgb: 264 | 265 | net = ops.conv2d(net, output_dim=128, k_h=3, k_w=3, 266 | d_h=1, d_w=1, name="final_conv", 267 | use_sn=self._spectral_norm) 268 | net = net[:, :, :, :colors] 269 | else: 270 | net = ops.conv2d(net, output_dim=colors, k_h=3, k_w=3, 271 | d_h=1, d_w=1, name="final_conv", 272 | use_sn=self._spectral_norm) 273 | logging.info("[Generator] after final processing: %s", net.shape) 274 | net = (tf.nn.tanh(net) + 1.0) / 2.0 275 | return net -------------------------------------------------------------------------------- /BigGAN/resnet_joint_deep.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import functools 6 | 7 | from absl import logging 8 | 9 | import arch_ops as ops 10 | import resnet_ops 11 | 12 | import utils 13 | 14 | from six.moves import range 15 | import tensorflow as tf 16 | import time 17 | 18 | class BigGanDeepResNetBlock(object): 19 | """ResNet block with bottleneck and identity preserving skip connections.""" 20 | 21 | def __init__(self, 22 | name, 23 | in_channels, 24 | out_channels, 25 | scale, 26 | spectral_norm=False, 27 | batch_norm=None): 28 | """Constructs a new ResNet block with bottleneck. 29 | Args: 30 | name: Scope name for the ResNet block. 31 | in_channels: Integer, the input channel size. 32 | out_channels: Integer, the output channel size. 33 | scale: Whether or not to scale up or down, choose from "up", "down" or 34 | "none". 35 | spectral_norm: Use spectral normalization for all weights. 36 | batch_norm: Function for batch normalization. 37 | """ 38 | assert scale in ["up", "down", "none"] 39 | self._name = name 40 | self._in_channels = in_channels 41 | self._out_channels = out_channels 42 | self._scale = scale 43 | self._spectral_norm = spectral_norm 44 | self.batch_norm = batch_norm 45 | 46 | def __call__(self, inputs, z, y, is_training): 47 | return self.apply(inputs=inputs, z=z, y=y, is_training=is_training) 48 | 49 | def _shortcut(self, inputs): 50 | """Constructs a skip connection from inputs.""" 51 | with tf.variable_scope("shortcut", values=[inputs]): 52 | shortcut = inputs 53 | num_channels = inputs.shape[-1].value 54 | if num_channels > self._out_channels: 55 | assert self._scale == "up" 56 | # Drop redundant channels.
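# Editor's note (not in the original source): BigGAN-deep keeps the shortcut
# as close to an identity as possible. When an up-scaling block reduces the
# channel count, the slicing below simply keeps the first out_channels
# feature maps instead of learning a projection.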
57 | logging.info("[Shortcut] Dropping %d channels in shortcut.", 58 | num_channels - self._out_channels) 59 | shortcut = shortcut[:, :, :, :self._out_channels] 60 | if self._scale == "up": 61 | shortcut = resnet_ops.unpool(shortcut) 62 | if self._scale == "down": 63 | shortcut = tf.nn.pool(shortcut, [2, 2], "AVG", "SAME", 64 | strides=[2, 2], name="pool") 65 | if num_channels < self._out_channels: 66 | assert self._scale == "down" 67 | # Increase number of channels if necessary. 68 | num_missing = self._out_channels - num_channels 69 | logging.info("[Shortcut] Adding %d channels in shortcut.", num_missing) 70 | added = ops.conv1x1(shortcut, num_missing, name="add_channels", 71 | use_sn=self._spectral_norm) 72 | shortcut = tf.concat([shortcut, added], axis=-1) 73 | return shortcut 74 | 75 | def apply(self, inputs, z, y, is_training): 76 | """ResNet block containing possible down/up sampling, shared for G / D. 77 | Args: 78 | inputs: a 3d input tensor of feature map. 79 | z: the latent vector for potential self-modulation. Can be None if use_sbn 80 | is set to False. 81 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 82 | labels. 83 | is_training: boolean, whether or not this is called during training. 84 | Returns: 85 | output: a 3d output tensor of feature map. 86 | """ 87 | if inputs.shape[-1].value != self._in_channels: 88 | raise ValueError( 89 | "Unexpected number of input channels (expected {}, got {}).".format( 90 | self._in_channels, inputs.shape[-1].value)) 91 | 92 | bottleneck_channels = max(self._in_channels, self._out_channels) // 4 93 | bn = functools.partial(self.batch_norm, z=z, y=y, is_training=is_training) 94 | conv1x1 = functools.partial(ops.conv1x1, use_sn=self._spectral_norm) 95 | conv3x3 = functools.partial(ops.conv2d, k_h=3, k_w=3, d_h=1, d_w=1, 96 | use_sn=self._spectral_norm) 97 | 98 | with tf.variable_scope(self._name, values=[inputs]): 99 | outputs = inputs 100 | 101 | with tf.variable_scope("conv1", values=[outputs]): 102 | outputs = bn(outputs, name="bn") 103 | outputs = tf.nn.relu(outputs) 104 | outputs = conv1x1(outputs, bottleneck_channels, name="1x1_conv") 105 | 106 | with tf.variable_scope("conv2", values=[outputs]): 107 | outputs = bn(outputs, name="bn") 108 | outputs = tf.nn.relu(outputs) 109 | if self._scale == "up": 110 | outputs = resnet_ops.unpool(outputs) 111 | outputs = conv3x3(outputs, bottleneck_channels, name="3x3_conv") 112 | 113 | with tf.variable_scope("conv3", values=[outputs]): 114 | outputs = bn(outputs, name="bn") 115 | outputs = tf.nn.relu(outputs) 116 | outputs = conv3x3(outputs, bottleneck_channels, name="3x3_conv") 117 | 118 | with tf.variable_scope("conv4", values=[outputs]): 119 | outputs = bn(outputs, name="bn") 120 | outputs = tf.nn.relu(outputs) 121 | if self._scale == "down": 122 | outputs = tf.nn.pool(outputs, [2, 2], "AVG", "SAME", strides=[2, 2], 123 | name="avg_pool") 124 | outputs = conv1x1(outputs, self._out_channels, name="1x1_conv") 125 | 126 | # Add skip-connection.
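# Editor's note (not in the original source): conv1-conv4 above form a
# bottleneck (channels are reduced to max(in, out) // 4 and then expanded
# back), so the residual sum below adds a cheap transformed branch to the
# near-identity shortcut built by _shortcut().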
127 | outputs += self._shortcut(inputs) 128 | 129 | logging.info("[Block] %s (z=%s, y=%s) -> %s", inputs.shape, 130 | None if z is None else z.shape, 131 | None if y is None else y.shape, outputs.shape) 132 | return outputs 133 | 134 | 135 | class Classifier_proD(object): 136 | """ResNet-based discriminator supporting resolutions 32, 64, 128, 256, 512.""" 137 | 138 | def __init__(self, 139 | name, num_class, use_sn, 140 | ch=128, 141 | blocks_with_attention="B1", 142 | project_y=True, 143 | **kwargs): 144 | """Constructor for BigGAN discriminator. 145 | Args: 146 | ch: Channel multiplier. 147 | blocks_with_attention: Comma-separated list of blocks that are followed by 148 | a non-local block. 149 | project_y: Add an embedding of y in the output layer. 150 | **kwargs: additional arguments passed on to ResNetDiscriminator. 151 | """ 152 | self.name = name 153 | self.num_class = num_class 154 | self._spectral_norm = True 155 | self._batch_norm_fn = None 156 | 157 | self._ch = ch 158 | self._blocks_with_attention = set(blocks_with_attention.split(",")) 159 | self._project_y = project_y 160 | 161 | def _resnet_block(self, name, in_channels, out_channels, scale): 162 | """ResNet block for the discriminator.""" 163 | if scale not in ["down", "none"]: 164 | raise ValueError( 165 | "Unknown discriminator ResNet block scaling: {}.".format(scale)) 166 | return BigGanDeepResNetBlock( 167 | name=name, 168 | in_channels=in_channels, 169 | out_channels=out_channels, 170 | scale=scale, 171 | spectral_norm=self._spectral_norm, 172 | batch_norm=self.batch_norm) 173 | 174 | def _get_in_out_channels(self, colors, resolution): 175 | # See Table 7-9. 176 | if colors not in [1, 3]: 177 | raise ValueError("Unsupported color channels: {}".format(colors)) 178 | if resolution == 512: 179 | channel_multipliers = [1, 1, 1, 2, 2, 4, 4] + 4 * [8] + 4 * [16] 180 | elif resolution == 256: 181 | channel_multipliers = [1, 2, 2, 4, 4] + 4 * [8] + 4 * [16] 182 | elif resolution == 128: 183 | channel_multipliers = [1, 2, 2, 4, 4] + 2 * [8] + 4 * [16] 184 | elif resolution == 64: 185 | channel_multipliers = [2, 4, 4] + 2 * [8] + 4 * [16] 186 | elif resolution == 32: 187 | channel_multipliers = 8 * [2] 188 | else: 189 | raise ValueError("Unsupported resolution: {}".format(resolution)) 190 | in_channels = [self._ch * c for c in channel_multipliers[:-1]] 191 | out_channels = [self._ch * c for c in channel_multipliers[1:]] 192 | return in_channels, out_channels 193 | 194 | def batch_norm(self, inputs, **kwargs): 195 | if self._batch_norm_fn is None: 196 | return inputs 197 | args = kwargs.copy() 198 | args["inputs"] = inputs 199 | if "use_sn" not in args: 200 | args["use_sn"] = self._spectral_norm 201 | return utils.call_with_accepted_args(self._batch_norm_fn, **args) 202 | 203 | def __call__(self, x, y, _, __): 204 | with tf.variable_scope(self.name, values=[x, y], reuse=tf.AUTO_REUSE): 205 | """Apply the discriminator to an input. 206 | Args: 207 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images. 208 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 209 | labels. 210 | is_training: Boolean, whether the architecture should be constructed for 211 | training or inference. 212 | Returns: 213 | A tuple of Tensors: the classifier's softmax prediction and logits, the 214 | discriminator's sigmoid output and logit, features from the second-to-last 215 | layer (for feature matching and t-SNE), plus Grad-CAM, Grad-CAM++ and saliency tensors.
216 | """ 217 | logging.info("[Discriminator] inputs are x=%s, y=%s", x.shape, 218 | None if y is None else y.shape) 219 | resnet_ops.validate_image_inputs(x) 220 | 221 | in_channels, out_channels = self._get_in_out_channels( 222 | colors=x.shape[-1].value, resolution=x.shape[1].value) 223 | num_blocks = len(in_channels) 224 | 225 | net = ops.conv2d(x, output_dim=in_channels[0], k_h=3, k_w=3, 226 | d_h=1, d_w=1, name="initial_conv", 227 | use_sn=self._spectral_norm) 228 | 229 | for block_idx in range(num_blocks): 230 | scale = "down" if block_idx % 2 == 0 else "none" 231 | block = self._resnet_block( 232 | name="B{}".format(block_idx + 1), 233 | in_channels=in_channels[block_idx], 234 | out_channels=out_channels[block_idx], 235 | scale=scale) 236 | net = block(net, z=None, y=y, is_training=None) 237 | # At resolution 64x64 there is a self-attention block. 238 | if scale == "none" and net.shape[1].value == 64: 239 | logging.info("[Discriminator] Applying non-local block to %s", 240 | net.shape) 241 | net = ops.non_local_block(net, "non_local_block", 242 | use_sn=self._spectral_norm) 243 | 244 | # Final part 245 | logging.info("[Discriminator] before final processing: %s", net.shape) 246 | net_conv = tf.nn.relu(net) 247 | h = tf.math.reduce_sum(net_conv, axis=[1, 2]) 248 | out_logit_tf = ops.linear(h, 1, scope="final_fc", use_sn=self._spectral_norm) 249 | logging.info("[Discriminator] after final processing: %s", net.shape) 250 | if self._project_y: 251 | if y is None: 252 | raise ValueError("You must provide class information y to project.") 253 | with tf.variable_scope("embedding_fc"): 254 | y_embedding_dim = out_channels[-1] 255 | # We do not use ops.linear() below since it does not have an option to 256 | # override the initializer. 257 | kernel = tf.get_variable( 258 | "kernel", [y.shape[1], y_embedding_dim], tf.float32, 259 | initializer=tf.initializers.glorot_normal()) 260 | if self._spectral_norm: 261 | kernel = ops.spectral_norm(kernel) 262 | embedded_y = tf.matmul(y, kernel) 263 | logging.info("[Discriminator] embedded_y for projection: %s", 264 | embedded_y.shape) 265 | out_logit_tf += tf.reduce_sum(embedded_y * h, axis=1, keepdims=True) 266 | 267 | feature_matching = h 268 | t_SNE = h 269 | out_logit = ops.linear(h, self.num_class, scope="final_fc_cla", use_sn=self._spectral_norm) 270 | 271 | # grad cam 272 | cls = tf.argmax(y,axis=1)[0] 273 | out_logit_cam = tf.identity(out_logit) 274 | y_c = out_logit_cam[0, cls] 275 | grads = tf.gradients(y_c, net_conv)[0] 276 | output_conv = net_conv[0] 277 | grads_val = grads[0] 278 | 279 | # grad cam ++ 280 | first = tf.exp(y_c)*grads 281 | second = tf.exp(y_c)*grads*grads 282 | third = tf.exp(y_c)*grads*grads*grads 283 | conv_first_grad, conv_second_grad, conv_third_grad = first[0], second[0], third[0] 284 | grad_val_plusplus = [conv_first_grad, conv_second_grad, conv_third_grad] 285 | 286 | # saliency 287 | signal = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=out_logit[0, :], labels=y[0])) 288 | guide_grad = tf.gradients(signal, x)[0] 289 | 290 | return tf.nn.softmax(out_logit), out_logit, tf.nn.sigmoid(out_logit_tf), out_logit_tf, feature_matching, t_SNE, output_conv, grads_val, guide_grad, grad_val_plusplus, tf.argmax(y,axis=1)[0], tf.argmax(out_logit_cam,axis=1)[0] -------------------------------------------------------------------------------- /trainer_joint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import 
division 3 | from __future__ import print_function 4 | 5 | import os 6 | import time 7 | from six.moves import xrange 8 | from pprint import pprint 9 | import h5py 10 | import tensorflow as tf 11 | import numpy as np 12 | import tensorflow.contrib.slim as slim 13 | 14 | from input_ops import create_input_ops 15 | from util import log 16 | from config import argparser, get_params 17 | from tensorflow.contrib.tensorboard.plugins import projector 18 | import cv2 19 | import time 20 | from tensorflow.python.framework import ops 21 | from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, roc_auc_score, f1_score, fbeta_score, cohen_kappa_score 22 | 23 | def sigmoid(x, a, b, c): 24 | return c / (1 + np.exp(-a * (x-b))) 25 | 26 | class Trainer(object): 27 | 28 | def __init__(self, config, model, dataset_train, dataset_train_unlabel, dataset_test): 29 | self.config = config 30 | self.model = model 31 | hyper_parameter_str = '{}_lr_g_{}_d_{}_update_G{}D{}'.format( 32 | config["dataset"], config["learning_rate_g"], config["learning_rate_d"], 33 | config["update_rate"], 1 34 | ) 35 | self.train_dir = './train_dir/TMI/%s-%s-%s' % ( 36 | config["prefix"], 37 | hyper_parameter_str, 38 | time.strftime("%Y%m%d-%H%M%S") 39 | ) 40 | 41 | os.makedirs(self.train_dir) 42 | log.infov("Train Dir: %s", self.train_dir) 43 | 44 | self.batch_size = config["batch_size_L"] 45 | 46 | # --- input ops --- 47 | self.batch_train = create_input_ops( 48 | dataset_train, config["batch_size_L"]) 49 | 50 | self.batch_train_unlabel = create_input_ops( 51 | dataset_train_unlabel, config["batch_size_U"] * 2) 52 | 53 | self.batch_test = create_input_ops( 54 | dataset_test, config["batch_size_L"]) 55 | 56 | # --- optimizer --- 57 | self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None) 58 | # --- checkpoint and monitoring --- 59 | all_var = tf.trainable_variables() 60 | 61 | gG_var = [v for v in all_var if v.name.startswith(('Generator_g'))] 62 | log.warning("********* gG_var ********** ") 63 | slim.model_analyzer.analyze_vars(gG_var, print_info=True) 64 | 65 | bG_var = [v for v in all_var if v.name.startswith(('Generator_b'))] 66 | log.warning("********* bG_var ********** ") 67 | slim.model_analyzer.analyze_vars(bG_var, print_info=True) 68 | 69 | c_var = [v for v in all_var if v.name.startswith(('Classifier'))] 70 | log.warning("********* c_var ********** ") 71 | slim.model_analyzer.analyze_vars(c_var, print_info=True) 72 | 73 | rem_var = (set(all_var) - set(gG_var) - set(bG_var) - set(c_var)) 74 | log.error([v.name for v in rem_var]) 75 | assert not rem_var 76 | 77 | self.gG_optimizer = tf.contrib.layers.optimize_loss( 78 | loss=self.model.g_loss_good, 79 | global_step=self.global_step, 80 | learning_rate=config["learning_rate_g"], 81 | optimizer=tf.train.AdamOptimizer(beta1=0.5,beta2=0.999), 82 | clip_gradients=20.0, 83 | name='gG_optimize_loss', 84 | variables=gG_var 85 | ) 86 | 87 | self.bG_optimizer = tf.contrib.layers.optimize_loss( 88 | loss=self.model.g_loss_bad, 89 | global_step=self.global_step, 90 | learning_rate=config["learning_rate_g"], 91 | optimizer=tf.train.AdamOptimizer(beta1=0.5,beta2=0.999), 92 | clip_gradients=20.0, 93 | name='bG_optimize_loss', 94 | variables=bG_var 95 | ) 96 | 97 | self.c_optimizer = tf.contrib.layers.optimize_loss( 98 | loss=self.model.c_loss + self.model.d_loss, 99 | global_step=self.global_step, 100 | learning_rate=config["learning_rate_d"], 101 | optimizer=tf.train.AdamOptimizer(beta1=0.5,beta2=0.999), 102 | clip_gradients=20.0, 
103 | name='c_optimize_loss', 104 | variables=c_var 105 | ) 106 | 107 | self.c_optimizer_only = tf.contrib.layers.optimize_loss( 108 | loss=self.model.c_loss, 109 | global_step=self.global_step, 110 | learning_rate=config["learning_rate_d"], 111 | optimizer=tf.train.AdamOptimizer(beta1=0.5,beta2=0.999), 112 | clip_gradients=20.0, 113 | name='c_optimize_loss_only', 114 | variables=c_var 115 | ) 116 | 117 | self.summary_op = tf.summary.merge_all() 118 | self.saver = tf.train.Saver(max_to_keep=1000) 119 | self.summary_writer = tf.summary.FileWriter(self.train_dir) 120 | 121 | self.supervisor = tf.train.Supervisor( 122 | logdir=self.train_dir, 123 | is_chief=True, 124 | saver=None, 125 | summary_op=None, 126 | summary_writer=self.summary_writer, 127 | save_summaries_secs=300, 128 | save_model_secs=600, 129 | global_step=self.global_step, 130 | ) 131 | 132 | session_config = tf.ConfigProto( 133 | allow_soft_placement=True, 134 | gpu_options=tf.GPUOptions(allow_growth=True, per_process_gpu_memory_fraction = 0.99), 135 | device_count={'GPU': 1} 136 | ) 137 | self.session = self.supervisor.prepare_or_wait_for_session(config=session_config) 138 | 139 | self.ckpt_path = config["checkpoint"] 140 | if self.ckpt_path is not None: 141 | log.info("Checkpoint path: %s", self.ckpt_path) 142 | self.saver.restore(self.session, self.ckpt_path) 143 | log.info("Loaded the pretrained parameters from the provided checkpoint path") 144 | 145 | def train(self): 146 | log.infov("Training Starts!") 147 | log.infov(self.batch_train) 148 | log.infov(self.batch_train_unlabel) 149 | step = self.session.run(self.global_step) 150 | 151 | for s in xrange(self.config["max_training_steps"]): 152 | step, accuracy, d_loss, g_loss, step_time, prediction_train, gt_train, Image_Real = \ 153 | self.run_single_step(self.batch_train, self.batch_train_unlabel, step=s, is_train = True) 154 | 155 | if s % self.config["log_step"] == self.config["log_step"] - 1: 156 | self.log_step_message(step + 1, accuracy, d_loss, g_loss, step_time, is_train=True) 157 | 158 | # periodic inference 159 | if s % self.config["test_sample_step"] == self.config["test_sample_step"] - 1: 160 | 161 | accuracy_mean, recall, precision, f1, kappa = [], [], [], [], [] 162 | 163 | accuracy, summary, d_loss, g_loss, step_time, prediction_test, gt_test = \ 164 | self.run_test(self.batch_test, self.batch_train_unlabel, is_train=False, step=s) 165 | self.log_step_message(step + 1, accuracy, d_loss, g_loss, 166 | step_time, is_train=False) 167 | 168 | self.summary_writer.add_summary(summary, global_step=step + 1) 169 | 170 | y_true = np.argmax(gt_test, axis=1) 171 | y_pred = np.argmax(prediction_test, axis=1) 172 | 173 | accuracy_mean.append(accuracy_score(y_true, y_pred)) 174 | recall.append(recall_score(y_true, y_pred, average="macro")) 175 | precision.append(precision_score(y_true, y_pred, average="macro")) 176 | f1.append(f1_score(y_true, y_pred, average="macro")) 177 | kappa.append(cohen_kappa_score(y_true, y_pred)) 178 | 179 | 180 | if s % self.config["output_save_step"] == self.config["output_save_step"] - 1 and s != self.config["max_training_steps"]-1: 181 | log.infov("Saved checkpoint at %d", step + 1) 182 | self.saver.save(self.session, os.path.join(self.train_dir, 'model'), global_step=step + 1) 183 | 184 | if s == self.config["max_training_steps"] - 1: 185 | log.infov("Saved checkpoint at %d", step + 1) 186 | self.saver.save(self.session, os.path.join(self.train_dir, 'model'), global_step=step + 1) 187 | 188 | def run_single_step(self, batch, batch_unlabel,
step=None, is_train=True): 189 | 190 | _start_time = time.time() 191 | 192 | batch_chunk = self.session.run(batch) 193 | batch_chunk_unlabel = self.session.run(batch_unlabel) 194 | 195 | z = np.random.uniform(low = -1.0, high = 1.0, size=(self.config["batch_size_G"], self.config["n_z"])).astype(np.float32) 196 | z_tmp = np.random.uniform(low = -1.0, high = 1.0, size=(self.config["n_z"])).astype(np.float32) 197 | z_linspace = np.linspace(z[0], z_tmp, 10) 198 | y_temp = np.random.randint(low = 0, high = self.config["num_class"], size = (self.config["batch_size_G"])) 199 | y = np.zeros((self.config["batch_size_G"], self.config["num_class"])) 200 | y[np.arange(self.config["batch_size_G"]), y_temp] = 1 201 | 202 | fetch = [self.global_step, self.model.accuracy, 203 | self.model.d_loss, self.model.g_loss, 204 | self.model.x_l_ph, 205 | self.model.all_preds, self.model.all_targets] 206 | 207 | if step % (self.config["update_rate"]+1) == 0: 208 | # Update the classifier/discriminator 209 | fetch.append(self.c_optimizer) 210 | elif step % (self.config["update_rate"]+1) == 1: 211 | # Update the "good" (class-conditional) generator 212 | fetch.append(self.gG_optimizer) 213 | elif step % (self.config["update_rate"]+1) == 2: 214 | fetch.append(self.bG_optimizer) # update the "bad" generator 215 | 216 | fetch_values = self.session.run(fetch, 217 | feed_dict = self.model.get_feed_dict_withunlabel(batch_chunk, batch_chunk_unlabel, z, z_linspace, y, step=step, is_training = is_train)) 218 | # log.error(fetch_values[8]) 219 | [step, accuracy, d_loss, g_loss, Image_Real,\ 220 | all_preds, all_targets] = fetch_values[:7] 221 | 222 | _end_time = time.time() 223 | 224 | return step, accuracy, d_loss, g_loss, \ 225 | (_end_time - _start_time), all_preds, all_targets, Image_Real 226 | 227 | 228 | def run_test(self, batch, batch_unlabel, is_train=False, step=None): 229 | 230 | _start_time = time.time() 231 | 232 | batch_chunk = self.session.run(batch) 233 | batch_chunk_unlabel = self.session.run(batch_unlabel) 234 | 235 | z = np.random.uniform(low = -1.0, high = 1.0, size=(self.config["batch_size_G"], self.config["n_z"])).astype(np.float32) 236 | z_tmp = np.random.uniform(low = -1.0, high = 1.0, size=(self.config["n_z"])).astype(np.float32) 237 | z_linspace = np.linspace(z[0], z_tmp, 10) 238 | y_temp = np.random.randint(low = 0, high = self.config["num_class"], size = (self.config["batch_size_G"])) 239 | y = np.zeros((self.config["batch_size_G"], self.config["num_class"])) 240 | y[np.arange(self.config["batch_size_G"]), y_temp] = 1 241 | 242 | [accuracy, summary, d_loss, g_loss, all_preds, all_targets] = self.session.run( 243 | [self.model.accuracy, self.summary_op, self.model.d_loss, 244 | self.model.g_loss, self.model.all_preds, self.model.all_targets], 245 | feed_dict=self.model.get_feed_dict_withunlabel(batch_chunk, batch_chunk_unlabel, z, z_linspace, y, step=step, is_training = is_train)) 246 | 247 | _end_time = time.time() 248 | 249 | return accuracy, summary, d_loss, g_loss, (_end_time - _start_time), all_preds, all_targets 250 | 251 | def log_step_message(self, step, accuracy, d_loss, g_loss, 252 | step_time, is_train=True): 253 | if step_time == 0: step_time = 0.001 254 | log_fn = (is_train and log.info or log.infov) 255 | log_fn((" [{split_mode:5s} step {step:4d}] " + 256 | "D loss: {d_loss:.5f} " + 257 | "G loss: {g_loss:.5f} " + 258 | "Accuracy: {accuracy:.5f} " 259 | "({sec_per_batch:.3f} sec/batch, {instance_per_sec:.3f} instances/sec) " 260 | ).format(split_mode=(is_train and 'train' or 'val'), 261 | step = step, 262 | d_loss = d_loss, 263 | g_loss = g_loss, 264 |
accuracy = accuracy, 265 | sec_per_batch = step_time, 266 | instance_per_sec = self.batch_size / step_time) 267 | ) 268 | 269 | def main(): 270 | try: 271 | params, args = get_params() 272 | params = vars(params) 273 | 274 | config, model, dataset_train, dataset_train_unlabel, dataset_val, dataset_test = argparser(params, is_train=True) 275 | trainer = Trainer(config, model, dataset_train, dataset_train_unlabel, dataset_val) 276 | 277 | log.info("dataset: %s, learning_rate_g: %f, learning_rate_d: %f", 278 | config["dataset"], config["learning_rate_g"], config["learning_rate_d"]) 279 | 280 | trainer.train() 281 | 282 | except Exception as exception: 283 | log.exception(exception) 284 | raise 285 | 286 | if __name__ == '__main__': 287 | 288 | main() 289 | 290 | -------------------------------------------------------------------------------- /model_joint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | from util import log 8 | from sngan_gen import Generator 9 | from sngan_gen_y import Generator as gGenerator 10 | from sngan_joint import Classifier_proD 11 | import os 12 | import time 13 | import numpy as np 14 | from diffAugment import DiffAugment 15 | 16 | os.environ["CUDA_VISIBLE_DEVICES"] = '2' 17 | 18 | class Model(object): 19 | 20 | def __init__(self, config, 21 | debug_information=False, 22 | is_train=True): 23 | self.debug = debug_information 24 | 25 | self.config = config 26 | self.batch_size_G = self.config["batch_size_G"] 27 | self.batch_size_L = self.config["batch_size_L"] 28 | self.batch_size_U = self.config["batch_size_U"] 29 | 30 | self.h = self.config["h"] 31 | self.w = self.config["w"] 32 | self.c = self.config["c"] 33 | self.IMAGE_DIM = [self.h, self.w, self.c] 34 | 35 | self.len = self.config["len"] 36 | self.num_class = self.config["num_class"] 37 | self.n_z = self.config["n_z"] 38 | 39 | self.is_training = tf.placeholder_with_default(bool(is_train), [], name='is_training') 40 | 41 | self.z_g_ph = tf.placeholder(tf.float32, [self.batch_size_G, self.n_z], name = 'latent_variable') # latent variable 42 | self.z_g_ph_linspace = tf.placeholder(tf.float32, [10, self.n_z], name = 'latent_variable_linspace') 43 | self.y_g_ph = tf.placeholder(tf.float32, [self.batch_size_G, self.num_class], name='condition_label') 44 | 45 | self.x_l_ph = tf.placeholder(tf.float32, [self.batch_size_L] + self.IMAGE_DIM, name='labeled_images') 46 | self.y_l_ph = tf.placeholder(tf.float32, [self.batch_size_L, self.num_class], name='real_label') 47 | 48 | self.x_u_ph = tf.placeholder(tf.float32, [self.batch_size_U] + self.IMAGE_DIM, name='unlabeled_images') 49 | self.x_u_c_ph = tf.placeholder(tf.float32, [self.batch_size_U] + self.IMAGE_DIM, name='unlabeled_images_for_c') 50 | self.y_u_ph = tf.placeholder(tf.float32, [self.batch_size_U, self.num_class], name='unlabeled_tmp') 51 | 52 | self.weights = tf.placeholder_with_default(0.0, [], name='weight') 53 | self.keep_prob_first = tf.placeholder(tf.float32, name='keep_prob_first') 54 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 55 | tf.summary.scalar("Loss/recon_wieght", self.weights) 56 | 57 | self.build(is_train=is_train) 58 | 59 | def get_feed_dict_withunlabel(self, batch_chunk, batch_chunk_unlabel, z, z_linspace, y, step=None, is_training=None): 60 | fd = { 61 | self.x_l_ph: batch_chunk['image'], 62 | self.y_l_ph: batch_chunk['label'], 63 
| self.z_g_ph: z, 64 | self.z_g_ph_linspace: z_linspace, 65 | self.y_g_ph: y, 66 | self.x_u_ph: batch_chunk_unlabel['image'][:self.batch_size_U], 67 | self.x_u_c_ph: batch_chunk_unlabel['image'][self.batch_size_U:self.batch_size_U + self.batch_size_U], 68 | self.y_u_ph: y 69 | } 70 | 71 | if is_training is not None: 72 | fd[self.is_training] = is_training 73 | 74 | if step > 50000: 75 | fd[self.weights] = 1.0 76 | 77 | return fd 78 | 79 | def _entropy(self, logits): 80 | with tf.name_scope('Entropy'): 81 | probs = tf.nn.softmax(logits) 82 | ent = tf.reduce_mean(- tf.reduce_sum(probs * logits, axis=1, keepdims=True) \ 83 | + tf.reduce_logsumexp(logits, axis=1, keepdims=True)) 84 | return ent 85 | 86 | def build(self, is_train=True): 87 | 88 | n = self.num_class 89 | 90 | # build loss and accuracy {{{ 91 | def build_loss(D_real, D_real_logits, D_real_logits_org, D_real_FM, D_real_tf, D_real_logits_tf, D_fake_bad, D_fake_logits_bad, D_fake_bad_FM, D_fake_good, D_fake_logits_good, D_fake_good_tf, D_fake_logits_good_tf, D_unl, D_unl_logits, D_unl_logits_noaug, D_unl_FM, D_unl_hard, D_unl_tf, D_unl_logits_tf, C_unl_hard, C_unl_tf, C_unl_logits_tf, x, fake_image_bad, fake_image_good, label, y_g_ph, x_u_ph, y_u, C): 92 | 93 | # Good GANs 94 | # Discriminator/classifier loss 95 | c_loss_real = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2( 96 | logits=D_real_logits, labels=label)) 97 | 98 | c_loss_fake_bad = 0.5 * tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(D_fake_logits_bad, axis = 1))) 99 | c_loss_unl_bad = - 0.5 * tf.reduce_mean(tf.reduce_logsumexp(D_unl_logits, axis = 1)) + 0.5 * tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(D_unl_logits, axis = 1))) 100 | 101 | probs = tf.nn.softmax(D_unl_logits) 102 | p = tf.reduce_max(probs, axis = 1) 103 | c_loss_unl_rein = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(D_unl_logits_tf), 104 | logits=D_unl_logits_tf), axis = 1) ## C fools D 105 | c_loss_unl_rein = 0.5 * tf.reduce_mean(tf.multiply(p, c_loss_unl_rein)) 106 | 107 | c_loss_fake_good_pseudo = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=D_fake_logits_good, labels=y_g_ph)) 108 | 109 | d_loss_real_tf = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(D_real_logits_tf), logits = D_real_logits_tf)) 110 | d_loss_fake_good_tf = 0.5 * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_fake_logits_good_tf), logits = D_fake_logits_good_tf)) 111 | #d_loss_unl_tf = 0.5 * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_unl_logits_tf), logits = D_unl_logits_tf)) 112 | #d_loss_unl_tf_c = 0.5 * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(C_unl_logits_tf), logits = C_unl_logits_tf)) 113 | 114 | # Dynamic class-rebalancing 115 | d_unl_add=tf.constant(1e-10) 116 | d_loss_unl_tf_tmp=tf.constant(0.0) 117 | for i in range(C_unl_logits_tf.shape[0]): 118 | d_loss_unl_tf_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(0,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_unl_logits_tf[i]), logits = D_unl_logits_tf[i])),tf.constant(0.0)) 119 | d_loss_unl_tf_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(2,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_unl_logits_tf[i]), logits = D_unl_logits_tf[i])),tf.constant(0.0)) 120 | 
d_loss_unl_tf_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(3,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_unl_logits_tf[i]), logits = D_unl_logits_tf[i])),tf.constant(0.0)) 121 | d_loss_unl_tf_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(4,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_unl_logits_tf[i]), logits = D_unl_logits_tf[i])),tf.constant(0.0)) 122 | d_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(0,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 123 | d_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(2,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 124 | d_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(3,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 125 | d_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(D_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(4,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 126 | d_loss_unl_tf = 0.5 * d_loss_unl_tf_tmp / d_unl_add 127 | 128 | c_unl_add=tf.constant(1e-10) 129 | d_loss_unl_tf_c_tmp=tf.constant(0.0) 130 | for i in range(C_unl_logits_tf.shape[0]): 131 | d_loss_unl_tf_c_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(1,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(C_unl_logits_tf[i]), logits = C_unl_logits_tf[i])),tf.constant(0.0)) 132 | c_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(1,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 133 | d_loss_unl_tf_c_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(2,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(C_unl_logits_tf[i]), logits = C_unl_logits_tf[i])),tf.constant(0.0)) 134 | c_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(2,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 135 | d_loss_unl_tf_c_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(3,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(C_unl_logits_tf[i]), logits = C_unl_logits_tf[i])),tf.constant(0.0)) 136 | c_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(3,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 137 | d_loss_unl_tf_c_tmp += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(4,dtype=tf.int64)),tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(C_unl_logits_tf[i]), logits = C_unl_logits_tf[i])),tf.constant(0.0)) 138 | c_unl_add += tf.where(tf.equal(tf.argmax(tf.one_hot(C_unl_hard, depth = self.num_class),axis=1)[i], tf.constant(4,dtype=tf.int64)),tf.constant(1.0),tf.constant(0.0)) 139 | d_loss_unl_tf_c = 0.5 * d_loss_unl_tf_c_tmp / c_unl_add 140 | 141 | # Conditional entropy 142 | c_ent = 0.1 * tf.reduce_mean(tf.distributions.Categorical(logits=D_unl_logits).entropy()) 143 | 144 | # Batch Nuclear-norm Maximization 
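# Editor's note (not in the original source): Batch Nuclear-norm Maximization
# (Cui et al., CVPR 2020) maximizes the sum of the singular values of the
# [batch_size_U, num_class] softmax prediction matrix D_unl. A larger nuclear
# norm encourages predictions on unlabeled data that are simultaneously
# confident and class-diverse, hence the negative sign when c_bnm is added to
# the classifier loss below.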
145 | c_bnm = - 0.1 * tf.reduce_sum(tf.svd(D_unl, compute_uv = False)) / self.batch_size_U 146 | 147 | c_loss = c_loss_real + self.weights * c_loss_fake_good_pseudo + c_loss_unl_rein 148 | c_loss += c_loss_fake_bad + c_loss_unl_bad + c_bnm 149 | 150 | d_loss = d_loss_real_tf + d_loss_fake_good_tf + d_loss_unl_tf_c + d_loss_unl_tf 151 | 152 | # Feature matching loss 153 | F_match_loss = tf.reduce_mean(tf.abs(tf.reduce_mean(D_unl_FM,axis=0) - tf.reduce_mean(D_fake_bad_FM,axis=0))) 154 | 155 | # Entropy term via pull-away term 156 | feat_norm = D_fake_bad_FM / tf.norm(D_fake_bad_FM, ord='euclidean', axis=1, \ 157 | keepdims=True) 158 | cosine = tf.tensordot(feat_norm, feat_norm, axes=[[1], [1]]) 159 | mask = tf.ones(tf.shape(cosine)) - tf.diag(tf.ones(tf.shape(cosine)[0])) 160 | square = tf.reduce_sum(tf.square(tf.multiply(cosine, mask))) 161 | dividend = tf.cast(tf.shape(cosine)[0] * (tf.shape(cosine)[0] - 1), tf.float32) 162 | G_pt = 0.1 * tf.divide(square, dividend) 163 | 164 | g_loss_bad = F_match_loss + G_pt 165 | 166 | g_loss_good = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(D_fake_logits_good_tf), logits = D_fake_logits_good_tf)) 167 | 168 | g_loss = g_loss_good + g_loss_bad 169 | 170 | GAN_loss = tf.reduce_mean(d_loss + g_loss + c_loss) 171 | 172 | # Classification accuracy 173 | correct_prediction = tf.equal(tf.argmax(D_real_logits_org, axis = 1), 174 | tf.argmax(label, 1)) 175 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 176 | 177 | return c_loss_real, c_loss_fake_bad, c_loss_unl_bad, c_loss_unl_rein, c_loss_fake_good_pseudo, d_loss_real_tf, d_loss_fake_good_tf, d_loss_unl_tf, d_loss_unl_tf_c, c_ent, c_bnm, F_match_loss, G_pt, g_loss_bad, g_loss_good, c_loss, d_loss, g_loss, GAN_loss, accuracy, d_loss_unl_tf_c_tmp, c_unl_add, d_loss_unl_tf_tmp, d_unl_add 178 | # }}} 179 | 180 | # Generator {{{ 181 | # ========= 182 | G_bad = Generator('Generator_bad', self.h, self.w, self.c, self.is_training, use_sn = False) 183 | G_good = gGenerator('Generator_good', self.h, self.w, self.c, self.is_training, use_sn = False) 184 | 185 | self.fake_image_bad = G_bad(self.z_g_ph) 186 | self.fake_image_good = G_good(self.z_g_ph, self.y_g_ph) 187 | # }}} 188 | 189 | # Classifier {{{ 190 | # ========= 191 | # output of C for real images 192 | C = Classifier_proD('Classifier_proD', self.num_class, use_sn = True) 193 | 194 | # Discriminator {{{ 195 | # ========= 196 | x_r = DiffAugment(self.x_l_ph) 197 | _, _, D_real_tf, D_real_logits_tf, _ = C(x_r, self.y_l_ph) 198 | D_real, D_real_logits, _, _, D_real_FM = C(x_r, self.y_l_ph) 199 | 200 | self.real_predict, D_real_logits_org, _, _, _ = C(self.x_l_ph, self.y_l_ph) 201 | 202 | # output of D for generated examples 203 | x_bG = DiffAugment(self.fake_image_bad) 204 | D_fake_bad, D_fake_logits_bad, _, _, D_fake_bad_FM = C(x_bG, self.y_g_ph) 205 | 206 | x_gG = DiffAugment(self.fake_image_good) 207 | D_fake_good, D_fake_logits_good, _, _, D_fake_good_FM = C(x_gG, self.y_g_ph) 208 | _, _, D_fake_good_tf, D_fake_logits_good_tf, _ = C(x_gG, self.y_g_ph) 209 | 210 | # output of D for unlabeled examples (negative example) 211 | x_d = DiffAugment(self.x_u_ph) 212 | D_unl, D_unl_logits, _, _, D_unl_FM = C(x_d, self.y_u_ph) 213 | D_unl_tmp, D_unl_logits_noaug, _, _, _ = C(self.x_u_ph, self.y_u_ph) 214 | D_unl_hard = tf.argmax(D_unl_tmp, axis = 1) 215 | _, _, D_unl_tf, D_unl_logits_tf, _ = C(x_d, tf.one_hot(D_unl_hard, depth = self.num_class)) 216 | 217 | x_c = DiffAugment(self.x_u_c_ph) 218 | C_unl, C_unl_logits, _,
_, C_unl_FM = C(x_c, self.y_u_ph) 219 | C_unl_tmp, _, _, _, _ = C(self.x_u_c_ph, self.y_u_ph) 220 | C_unl_hard = tf.argmax(C_unl_tmp, axis = 1) 221 | _, _, C_unl_tf, C_unl_logits_tf, _ = C(x_c, tf.one_hot(C_unl_hard, depth = self.num_class)) 222 | 223 | self.all_preds = D_real_logits_org 224 | self.all_targets = self.y_l_ph 225 | # }}} 226 | 227 | self.real_activations = D_real_logits 228 | self.fake_activations = D_fake_logits_good 229 | 230 | self.c_loss_real, self.c_loss_fake_bad, self.c_loss_unl_bad, self.c_loss_unl_rein, self.c_loss_fake_good_pseudo, self.d_loss_real_tf, self.d_loss_fake_good_tf, self.d_loss_unl_tf, self.d_loss_unl_tf_c, self.c_ent, self.c_bnm, self.F_match_loss, self.G_pt, self.g_loss_bad, self.g_loss_good, self.c_loss, self.d_loss, self.g_loss, self.GAN_loss, self.accuracy, self.d_loss_unl_tf_c_tmp, self.c_unl_add, self.d_loss_unl_tf_tmp, self.d_unl_add = build_loss(D_real, D_real_logits, D_real_logits_org, D_real_FM, D_real_tf, D_real_logits_tf, D_fake_bad, D_fake_logits_bad, D_fake_bad_FM, D_fake_good, D_fake_logits_good, D_fake_good_tf, D_fake_logits_good_tf, D_unl, D_unl_logits, D_unl_logits_noaug, D_unl_FM, D_unl_hard, D_unl_tf, D_unl_logits_tf, C_unl_hard, C_unl_tf, C_unl_logits_tf, self.x_l_ph, self.fake_image_bad, self.fake_image_good, self.y_l_ph, self.y_g_ph, x_d, self.y_u_ph, C) 231 | 232 | tf.summary.scalar("Loss/Accuracy", self.accuracy) 233 | tf.summary.scalar("Loss/C_loss_real", self.c_loss_real) 234 | tf.summary.scalar("Loss/C_loss_fake_bad", self.c_loss_fake_bad) 235 | tf.summary.scalar("Loss/C_loss_unl_bad", self.c_loss_unl_bad) 236 | tf.summary.scalar("Loss/C_loss_unl_rein", self.c_loss_unl_rein) 237 | tf.summary.scalar("Loss/C_loss_fake_good_pseudo", self.c_loss_fake_good_pseudo) 238 | tf.summary.scalar("Loss/C_bnm", self.c_bnm) 239 | tf.summary.scalar("Loss/D_loss_real_tf", self.d_loss_real_tf) 240 | tf.summary.scalar("Loss/D_loss_fake_good_tf", self.d_loss_fake_good_tf) 241 | tf.summary.scalar("Loss/D_loss_unl_tf", self.d_loss_unl_tf) 242 | tf.summary.scalar("Loss/D_loss_unl_tf_c", self.d_loss_unl_tf_c) 243 | tf.summary.scalar("Loss/F_match_loss", self.F_match_loss) 244 | tf.summary.scalar("Loss/G_pt", self.G_pt) 245 | tf.summary.scalar("Loss/G_loss_bad", self.g_loss_bad) 246 | tf.summary.scalar("Loss/G_loss_good", self.g_loss_good) 247 | tf.summary.scalar("Loss/D_loss_unl_tf_tmp", self.d_loss_unl_tf_tmp) 248 | tf.summary.scalar("Loss/D_unl_add", self.d_unl_add) 249 | tf.summary.scalar("Loss/D_loss_unl_tf_c_tmp", self.d_loss_unl_tf_c_tmp) 250 | tf.summary.scalar("Loss/C_unl_add", self.c_unl_add) 251 | 252 | tf.summary.image("Img/Fake_bad", G_bad(self.z_g_ph), max_outputs=10) 253 | tf.summary.image("Img/Normal", G_good(self.z_g_ph_linspace, tf.constant([[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[1,0,0,0,0]], dtype=tf.float32)), max_outputs=10) 254 | tf.summary.image("Img/High", G_good(self.z_g_ph_linspace, tf.constant([[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,1,0,0,0]], dtype=tf.float32)), max_outputs=10) 255 | tf.summary.image("Img/PM", G_good(self.z_g_ph_linspace, tf.constant([[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0]], dtype=tf.float32)), max_outputs=10) 256 | tf.summary.image("Img/AMD", G_good(self.z_g_ph_linspace, 
tf.constant([[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0],[0,0,0,1,0]], dtype=tf.float32)), max_outputs=10) 257 | tf.summary.image("Img/Glaucoma", G_good(self.z_g_ph_linspace, tf.constant([[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1]], dtype=tf.float32)), max_outputs=10) 258 | tf.summary.image("Img/Real", self.x_l_ph, max_outputs=1) 259 | tf.summary.image("Img/DiffAug", DiffAugment(self.x_l_ph), max_outputs=1) 260 | log.warn('\033[93mSuccessfully loaded the model.\033[0m') 261 | -------------------------------------------------------------------------------- /ODIR/ODIR dataset building steps.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# ODIR dataset building steps\n", 8 | "To facilitate future method comparisons, we provide below the detailed steps used to construct the ODIR dataset in our TMI paper." 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "### First of all, these 16 images are removed from the training set.\n", 16 | "The background of the following images is quite different from that of the remaining ones; they are fundus images uploaded from the hospital.\n", 17 | "\n", 18 | "- 2174_right.jpg\n", 19 | "- 2175_left.jpg\n", 20 | "- 2176_left.jpg\n", 21 | "- 2177_left.jpg\n", 22 | "- 2177_right.jpg\n", 23 | "- 2178_right.jpg\n", 24 | "- 2179_left.jpg\n", 25 | "- 2179_right.jpg\n", 26 | "- 2180_left.jpg\n", 27 | "- 2180_right.jpg\n", 28 | "- 2181_left.jpg\n", 29 | "- 2181_right.jpg\n", 30 | "- 2182_left.jpg\n", 31 | "- 2182_right.jpg\n", 32 | "- 2957_left.jpg\n", 33 | "- 2957_right.jpg" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "### Second, we need to modify the following excel files provided by the ODIR dataset:\n", 41 | "- training annotation (English).xlsx\n", 42 | "- off-site test annotation (English).xlsx\n", 43 | "- on-site test annotation (English).xlsx\n", 44 | "\n", 45 | "Specifically, we used global substitution to unify diagnostic keywords for the same disease according to Table II in our paper. \n", 46 | "\n", 47 | "> For example, diagnostic keywords including \"Mild nonproliferative retinopathy\", \"Moderate nonproliferative retinopathy\", \"Severe nonproliferative retinopathy\", \"Proliferative diabetic retinopathy\", \"Severe proliferative diabetic retinopathy\", and \"Diabetic retinopathy\" are all replaced with \"Diabetic retinopathy\".\n", 48 | "\n", 49 | "Moreover, we treat all suspected diseases or abnormalities as diagnosed diseases or abnormalities, so all occurrences of \"suspected \" are replaced with \"\".\n", 50 | "\n", 51 | "For the convenience of follow-up work, we upload the final excel files at https://github.com/Xyporz/CISSL-GANs/tree/main/ODIR." 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "### Third, run the following code step by step.\n", 59 | "Remember to change the paths to your own file paths."
60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 1, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "import pandas as pd # for reading the annotation spreadsheets\n", 69 | "import os # for handling directories\n", 70 | "import shutil # for moving files\n", 71 | "\n", 72 | "# prepare the directories\n", 73 | "Picture_Current_Path = \"F:/Data/ODIR_Fromscratch/OIA-ODIR/Training Set/Images/\"\n", 74 | "Current_Path =\"F:/Data/OIA-ODIR/ODIR/\"\n", 75 | "CSV_Path = \"F:/Data/ODIR_Fromscratch/OIA-ODIR/Training Set/Annotation/training annotation (English).xlsx\"\n", 76 | "\n", 77 | "Train_Path = \"F:/Data/ODIR_Fromscratch/OIA-ODIR/ODIR/Train/\"\n", 78 | "\n", 79 | "# open and read the annotation spreadsheet in the directory\n", 80 | "list = pd.read_excel(CSV_Path)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 2, 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "data": { 90 | "text/html": [ 91 | "
(HTML-rendered preview of the annotation dataframe removed as extraction residue; the full table of 3500 rows × 15 columns follows in the text/plain output below.)
" 329 | ], 330 | "text/plain": [ 331 | " ID Patient Age Patient Sex Left-Fundus Right-Fundus \\\n", 332 | "0 0 69 Female 0_left.jpg 0_right.jpg \n", 333 | "1 1 57 Male 1_left.jpg 1_right.jpg \n", 334 | "2 2 42 Male 2_left.jpg 2_right.jpg \n", 335 | "3 3 66 Male 3_left.jpg 3_right.jpg \n", 336 | "4 4 53 Male 4_left.jpg 4_right.jpg \n", 337 | "... ... ... ... ... ... \n", 338 | "3495 4686 63 Male 4686_left.jpg 4686_right.jpg \n", 339 | "3496 4688 42 Male 4688_left.jpg 4688_right.jpg \n", 340 | "3497 4689 54 Male 4689_left.jpg 4689_right.jpg \n", 341 | "3498 4690 57 Male 4690_left.jpg 4690_right.jpg \n", 342 | "3499 4784 58 Male 4784_left.jpg 4784_right.jpg \n", 343 | "\n", 344 | " Left-Diagnostic Keywords \\\n", 345 | "0 cataract \n", 346 | "1 normal fundus \n", 347 | "2 laser spot,diabetic retinopathy \n", 348 | "3 normal fundus \n", 349 | "4 macular epiretinal membrane \n", 350 | "... ... \n", 351 | "3495 diabetic retinopathy \n", 352 | "3496 diabetic retinopathy \n", 353 | "3497 diabetic retinopathy \n", 354 | "3498 diabetic retinopathy \n", 355 | "3499 hypertensive retinopathy,age-related macular d... \n", 356 | "\n", 357 | " Right-Diagnostic Keywords N D G C A H M \\\n", 358 | "0 normal fundus 0 0 0 1 0 0 0 \n", 359 | "1 normal fundus 1 0 0 0 0 0 0 \n", 360 | "2 diabetic retinopathy 0 1 0 0 0 0 0 \n", 361 | "3 branch retinal artery occlusion 0 0 0 0 0 0 0 \n", 362 | "4 diabetic retinopathy 0 1 0 0 0 0 0 \n", 363 | "... ... .. .. .. .. .. .. .. \n", 364 | "3495 diabetic retinopathy 0 1 0 0 0 0 0 \n", 365 | "3496 diabetic retinopathy 0 1 0 0 0 0 0 \n", 366 | "3497 normal fundus 0 1 0 0 0 0 0 \n", 367 | "3498 diabetic retinopathy 0 1 0 0 0 0 0 \n", 368 | "3499 hypertensive retinopathy,age-related macular d... 0 0 0 0 1 1 0 \n", 369 | "\n", 370 | " O \n", 371 | "0 0 \n", 372 | "1 0 \n", 373 | "2 1 \n", 374 | "3 1 \n", 375 | "4 1 \n", 376 | "... .. 
\n", 377 | "3495 0 \n", 378 | "3496 0 \n", 379 | "3497 0 \n", 380 | "3498 0 \n", 381 | "3499 0 \n", 382 | "\n", 383 | "[3500 rows x 15 columns]" 384 | ] 385 | }, 386 | "execution_count": 2, 387 | "metadata": {}, 388 | "output_type": "execute_result" 389 | } 390 | ], 391 | "source": [ 392 | "list" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": 3, 398 | "metadata": {}, 399 | "outputs": [], 400 | "source": [ 401 | "col = [ \"normal fundus\",\n", 402 | " \"diabetic retinopathy\",\n", 403 | " \"glaucoma\",\n", 404 | " \"cataract\",\n", 405 | " \"age-related macular degeneration\",\n", 406 | " \"hypertensive retinopathy\",\n", 407 | " \"myopia retinopathy\" ]" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 4, 413 | "metadata": {}, 414 | "outputs": [ 415 | { 416 | "name": "stdout", 417 | "output_type": "stream", 418 | "text": [ 419 | "2957_left.jpg\n", 420 | "2175_left.jpg\n", 421 | "2176_left.jpg\n", 422 | "2179_left.jpg\n", 423 | "2180_left.jpg\n", 424 | "2181_left.jpg\n", 425 | "2182_left.jpg\n" 426 | ] 427 | } 428 | ], 429 | "source": [ 430 | "for i in col:\n", 431 | " listnew=list[list[\"Left-Diagnostic Keywords\"]==i]\n", 432 | " l=listnew[\"Left-Fundus\"].tolist()\n", 433 | " s=listnew[\"Patient Sex\"].tolist()\n", 434 | " for each in zip(l,s):\n", 435 | " if each[1]=='Male':\n", 436 | " sex = 0\n", 437 | " elif each[1]=='Female':\n", 438 | " sex = 1\n", 439 | " if os.path.exists(Picture_Current_Path+each[0]):\n", 440 | " shutil.move(Picture_Current_Path+each[0],Train_Path+i+'/'+str(sex)+'_'+each[0])\n", 441 | " else:\n", 442 | " print(each[0])" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": 5, 448 | "metadata": {}, 449 | "outputs": [ 450 | { 451 | "name": "stdout", 452 | "output_type": "stream", 453 | "text": [ 454 | "2957_right.jpg\n", 455 | "2174_right.jpg\n", 456 | "2177_right.jpg\n", 457 | "2178_right.jpg\n", 458 | "2179_right.jpg\n", 459 | "2180_right.jpg\n", 460 | "2181_right.jpg\n", 461 | "2182_right.jpg\n" 462 | ] 463 | } 464 | ], 465 | "source": [ 466 | "for i in col:\n", 467 | " listnew=list[list[\"Right-Diagnostic Keywords\"]==i]\n", 468 | " l=listnew[\"Right-Fundus\"].tolist()\n", 469 | " s=listnew[\"Patient Sex\"].tolist()\n", 470 | " for each in zip(l,s):\n", 471 | " if each[1]=='Male':\n", 472 | " sex = 2\n", 473 | " elif each[1]=='Female':\n", 474 | " sex = 3\n", 475 | " if os.path.exists(Picture_Current_Path+each[0]):\n", 476 | " shutil.move(Picture_Current_Path+each[0],Train_Path+i+'/'+str(sex)+'_'+each[0])\n", 477 | " else:\n", 478 | " print(each[0])" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 11, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "# 目录准备\n", 488 | "Picture_Current_Path = \"F:/Data/ODIR_Fromscratch/OIA-ODIR/Off-site Test Set/Images/\"\n", 489 | "Current_Path =\"F:/Data/OIA-ODIR/ODIR/\"\n", 490 | "CSV_Path = \"F:/Data/ODIR_Fromscratch/OIA-ODIR/Off-site Test Set/Annotation/off-site test annotation (English).xlsx\"\n", 491 | "\n", 492 | "Train_Path = \"F:/Data/ODIR_Fromscratch/OIA-ODIR/ODIR/Val/\"\n", 493 | "\n", 494 | "#打开目录下表格文件并读取\n", 495 | "list = pd.read_excel(CSV_Path)" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": 12, 501 | "metadata": {}, 502 | "outputs": [], 503 | "source": [ 504 | "for i in col:\n", 505 | " listnew=list[list[\"Left-Diagnostic Keywords\"]==i]\n", 506 | " l=listnew[\"Left-Fundus\"].tolist()\n", 507 | " s=listnew[\"Patient Sex\"].tolist()\n", 508 | " for each in 
zip(l,s):\n", 509 | " if each[1]=='Male':\n", 510 | " sex = 0\n", 511 | " elif each[1]=='Female':\n", 512 | " sex = 1\n", 513 | " if os.path.exists(Picture_Current_Path+each[0]):\n", 514 | " shutil.move(Picture_Current_Path+each[0],Train_Path+i+'/'+str(sex)+'_'+each[0])\n", 515 | " else:\n", 516 | " print(each[0])\n", 517 | "\n", 518 | "for i in col:\n", 519 | " listnew=list[list[\"Right-Diagnostic Keywords\"]==i]\n", 520 | " l=listnew[\"Right-Fundus\"].tolist()\n", 521 | " s=listnew[\"Patient Sex\"].tolist()\n", 522 | " for each in zip(l,s):\n", 523 | " if each[1]=='Male':\n", 524 | " sex = 2\n", 525 | " elif each[1]=='Female':\n", 526 | " sex = 3\n", 527 | " if os.path.exists(Picture_Current_Path+each[0]):\n", 528 | " shutil.move(Picture_Current_Path+each[0],Train_Path+i+'/'+str(sex)+'_'+each[0])\n", 529 | " else:\n", 530 | " print(each[0])" 531 | ] 532 | }, 533 | { 534 | "cell_type": "code", 535 | "execution_count": 13, 536 | "metadata": {}, 537 | "outputs": [], 538 | "source": [ 539 | "# 目录准备\n", 540 | "Picture_Current_Path = \"F:/Data/ODIR_Fromscratch/OIA-ODIR/On-site Test Set/Images/\"\n", 541 | "Current_Path =\"F:/Data/OIA-ODIR/ODIR/\"\n", 542 | "CSV_Path = \"F:/Data/ODIR_Fromscratch/OIA-ODIR/On-site Test Set/Annotation/on-site test annotation (English).xlsx\"\n", 543 | "\n", 544 | "Train_Path = \"F:/Data/ODIR_Fromscratch/OIA-ODIR/ODIR/Test/\"\n", 545 | "\n", 546 | "#打开目录下表格文件并读取\n", 547 | "list = pd.read_excel(CSV_Path)" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 14, 553 | "metadata": {}, 554 | "outputs": [], 555 | "source": [ 556 | "for i in col:\n", 557 | " listnew=list[list[\"Left-Diagnostic Keywords\"]==i]\n", 558 | " l=listnew[\"Left-Fundus\"].tolist()\n", 559 | " s=listnew[\"Patient Sex\"].tolist()\n", 560 | " for each in zip(l,s):\n", 561 | " if each[1]=='Male':\n", 562 | " sex = 0\n", 563 | " elif each[1]=='Female':\n", 564 | " sex = 1\n", 565 | " if os.path.exists(Picture_Current_Path+each[0]):\n", 566 | " shutil.move(Picture_Current_Path+each[0],Train_Path+i+'/'+str(sex)+'_'+each[0])\n", 567 | " else:\n", 568 | " print(each[0])\n", 569 | "\n", 570 | "for i in col:\n", 571 | " listnew=list[list[\"Right-Diagnostic Keywords\"]==i]\n", 572 | " l=listnew[\"Right-Fundus\"].tolist()\n", 573 | " s=listnew[\"Patient Sex\"].tolist()\n", 574 | " for each in zip(l,s):\n", 575 | " if each[1]=='Male':\n", 576 | " sex = 2\n", 577 | " elif each[1]=='Female':\n", 578 | " sex = 3\n", 579 | " if os.path.exists(Picture_Current_Path+each[0]):\n", 580 | " shutil.move(Picture_Current_Path+each[0],Train_Path+i+'/'+str(sex)+'_'+each[0])\n", 581 | " else:\n", 582 | " print(each[0])" 583 | ] 584 | } 585 | ], 586 | "metadata": { 587 | "kernelspec": { 588 | "display_name": "Python 3", 589 | "language": "python", 590 | "name": "python3" 591 | }, 592 | "language_info": { 593 | "codemirror_mode": { 594 | "name": "ipython", 595 | "version": 3 596 | }, 597 | "file_extension": ".py", 598 | "mimetype": "text/x-python", 599 | "name": "python", 600 | "nbconvert_exporter": "python", 601 | "pygments_lexer": "ipython3", 602 | "version": "3.7.9" 603 | } 604 | }, 605 | "nbformat": 4, 606 | "nbformat_minor": 4 607 | } 608 | -------------------------------------------------------------------------------- /arch_ops.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Provides a library of custom architecture-related operations. 17 | It currently provides the following operations: 18 | - linear, conv2d, deconv2d, lrelu 19 | - batch norm, conditional batch norm, self-modulation 20 | - spectral norm, weight norm, layer norm 21 | - self-attention block 22 | - various weight initialization schemes 23 | These operations are supported on both GPUs and TPUs. 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import functools 31 | 32 | from absl import logging 33 | 34 | from compare_gan.gans import consts 35 | from compare_gan.tpu import tpu_ops 36 | import gin 37 | from six.moves import range 38 | import tensorflow as tf 39 | 40 | from tensorflow.contrib.tpu.python.tpu import tpu_function 41 | from tensorflow.python.training import moving_averages # pylint: disable=g-direct-tensorflow-import 42 | 43 | 44 | @gin.configurable("weights") 45 | def weight_initializer(initializer=consts.NORMAL_INIT, stddev=0.02): 46 | """Returns the initializer for the given name. 47 | Args: 48 | initializer: Name of the initializer. Use one in consts.INITIALIZERS. 49 | stddev: Standard deviation passed to initializer. 50 | Returns: 51 | Initializer from `tf.initializers`. 52 | """ 53 | if initializer == consts.NORMAL_INIT: 54 | return tf.initializers.random_normal(stddev=stddev) 55 | if initializer == consts.TRUNCATED_INIT: 56 | return tf.initializers.truncated_normal(stddev=stddev) 57 | if initializer == consts.ORTHOGONAL_INIT: 58 | return tf.initializers.orthogonal() 59 | raise ValueError("Unknown weight initializer {}.".format(initializer)) 60 | 61 | 62 | def _moving_moments_for_inference(mean, variance, is_training, decay): 63 | """Use moving averages of moments during inference. 64 | Args: 65 | mean: Tensor of shape [num_channels] with the mean of the current batch. 66 | variance: Tensor of shape [num_channels] with the variance of the current 67 | batch. 68 | is_training: Boolean, whether to construct ops for training or inference 69 | graph. 70 | decay: Decay rate to use for moving averages. 71 | Returns: 72 | Tuple of (mean, variance) to use. This can be the same as the inputs. 73 | """ 74 | # Create the moving average variables and add them to the appropriate 75 | # collections. 76 | variable_collections = [ 77 | tf.GraphKeys.MOVING_AVERAGE_VARIABLES, 78 | tf.GraphKeys.MODEL_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES, 79 | ] 80 | # Disable partition setting for moving_mean and moving_variance 81 | # as assign_moving_average op below doesn't support partitioned variable. 
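# For reference, a rough sketch of the update that assign_moving_average
# applies below (with zero_debias=False), so the EMA behaves as:
#   moving_mean     <- decay * moving_mean     + (1 - decay) * mean
#   moving_variance <- decay * moving_variance + (1 - decay) * variance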
82 | moving_mean = tf.get_variable( 83 | "moving_mean", 84 | shape=mean.shape, 85 | initializer=tf.zeros_initializer(), 86 | trainable=False, 87 | partitioner=None, 88 | collections=variable_collections) 89 | moving_variance = tf.get_variable( 90 | "moving_variance", 91 | shape=variance.shape, 92 | initializer=tf.ones_initializer(), 93 | trainable=False, 94 | partitioner=None, 95 | collections=variable_collections) 96 | if is_training: 97 | logging.debug("Adding update ops for moving averages of mean and variance.") 98 | # Update variables for mean and variance during training. 99 | update_moving_mean = moving_averages.assign_moving_average( 100 | moving_mean, 101 | tf.cast(mean, moving_mean.dtype), 102 | decay, 103 | zero_debias=False) 104 | update_moving_variance = moving_averages.assign_moving_average( 105 | moving_variance, 106 | tf.cast(variance, moving_variance.dtype), 107 | decay, 108 | zero_debias=False) 109 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_mean) 110 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_variance) 111 | return mean, variance 112 | logging.debug("Using moving mean and variance.") 113 | return moving_mean, moving_variance 114 | 115 | 116 | def _accumulated_moments_for_inference(mean, variance, is_training): 117 | """Use accumulated statistics for moments during inference. 118 | After training the user is responsible for filling the accumulators with the 119 | actual values. See _UpdateBnAccumulators() in eval_gan_lib.py for an example. 120 | Args: 121 | mean: Tensor of shape [num_channels] with the mean of the current batch. 122 | variance: Tensor of shape [num_channels] with the variance of the current 123 | batch. 124 | is_training: Boolean, whether to construct ops for training or inference 125 | graph. 126 | Returns: 127 | Tuple of (mean, variance) to use. This can be the same as the inputs. 128 | """ 129 | variable_collections = [ 130 | tf.GraphKeys.MODEL_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES, 131 | ] 132 | with tf.variable_scope("accu", values=[mean, variance]): 133 | # Create variables for accumulating batch statistics and use them during 134 | # inference. The ops for filling the accumulators must be created and run 135 | # before eval. See docstring above. 136 | accu_mean = tf.get_variable( 137 | "accu_mean", 138 | shape=mean.shape, 139 | initializer=tf.zeros_initializer(), 140 | trainable=False, 141 | collections=variable_collections) 142 | accu_variance = tf.get_variable( 143 | "accu_variance", 144 | shape=variance.shape, 145 | initializer=tf.zeros_initializer(), 146 | trainable=False, 147 | collections=variable_collections) 148 | accu_counter = tf.get_variable( 149 | "accu_counter", 150 | shape=[], 151 | initializer=tf.initializers.constant(1e-12), 152 | trainable=False, 153 | collections=variable_collections) 154 | update_accus = tf.get_variable( 155 | "update_accus", 156 | shape=[], 157 | dtype=tf.int32, 158 | initializer=tf.zeros_initializer(), 159 | trainable=False, 160 | collections=variable_collections) 161 | 162 | mean = tf.identity(mean, "mean") 163 | variance = tf.identity(variance, "variance") 164 | 165 | if is_training: 166 | return mean, variance 167 | 168 | logging.debug("Using accumulated moments.") 169 | # Return the accumulated batch statistics and add current batch statistics 170 | # to accumulators if the update_accus variable equals 1. 
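# Rough usage sketch for filling the accumulators after training (see the
# docstring above); `update_accus_var`, `num_batches` and `outputs_op` are
# illustrative names, and a tf.Session `sess` is assumed:
#   sess.run(tf.assign(update_accus_var, 1))  # enable accumulation
#   for _ in range(num_batches):
#     sess.run(outputs_op)                    # forward passes fill the accu_* variables
#   sess.run(tf.assign(update_accus_var, 0))  # freeze; later evals reuse the accumulated moments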
171 | def update_accus_fn(): 172 | return tf.group([ 173 | tf.assign_add(accu_mean, mean), 174 | tf.assign_add(accu_variance, variance), 175 | tf.assign_add(accu_counter, 1), 176 | ]) 177 | dep = tf.cond( 178 | tf.equal(update_accus, 1), 179 | update_accus_fn, 180 | tf.no_op) 181 | with tf.control_dependencies([dep]): 182 | return accu_mean / accu_counter, accu_variance / accu_counter 183 | 184 | 185 | @gin.configurable(whitelist=["decay", "epsilon", "use_cross_replica_mean", 186 | "use_moving_averages"]) 187 | def standardize_batch(inputs, 188 | is_training, 189 | decay=0.999, 190 | epsilon=1e-3, 191 | data_format="NHWC", 192 | use_moving_averages=True, 193 | use_cross_replica_mean=None): 194 | """Adds a TPU-enabled batch normalization layer. 195 | This version does not apply trainable scale or offset! 196 | It normalizes a tensor by mean and variance. 197 | Details on Batch Normalization can be found in "Batch Normalization: 198 | Accelerating Deep Network Training by Reducing Internal Covariate Shift", 199 | Ioffe S. and Szegedy C. 2015 [http://arxiv.org/abs/1502.03167]. 200 | Note #1: This method computes the batch statistics across all TPU replicas, 201 | thus simulating the true batch norm in the distributed setting. If one wants 202 | to avoid the cross-replica communication, set use_cross_replica_mean=False. 203 | Note #2: When is_training is True the moving_mean and moving_variance need 204 | to be updated in each training step. By default, the update_ops are placed 205 | in `tf.GraphKeys.UPDATE_OPS` and they need to be added as a dependency to 206 | the `train_op`. For example: 207 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 208 | if update_ops: 209 | updates = tf.group(*update_ops) 210 | total_loss = control_flow_ops.with_dependencies([updates], total_loss) 211 | Note #3: Reasonable values for `decay` are close to 1.0, typically in the 212 | multiple-nines range: 0.999, 0.99, 0.9, etc. Lower the `decay` value (try 213 | `decay`=0.9) if the model experiences reasonably good training performance but 214 | poor validation and/or test performance. 215 | Args: 216 | inputs: A tensor with 2 or 4 dimensions, where the first dimension is 217 | `batch_size`. The normalization is over all but the last dimension if 218 | `data_format` is `NHWC`, and the second dimension if `data_format` is 219 | `NCHW`. 220 | is_training: Whether or not the layer is in training mode. In training 221 | mode it would accumulate the statistics of the moments into the 222 | `moving_mean` and `moving_variance` using an exponential moving average 223 | with the given `decay`. When is_training=False, these variables are not 224 | updated, and the precomputed values are used verbatim. 225 | decay: Decay for the moving averages. See notes above for reasonable 226 | values. 227 | epsilon: Small float added to variance to avoid dividing by zero. 228 | data_format: Input data format. NHWC or NCHW. 229 | use_moving_averages: If True, keep moving averages of mean and variance that 230 | are used during inference. Otherwise use accumulators. 231 | use_cross_replica_mean: If True, add operations to compute batch norm 232 | statistics across all TPU cores. These ops are not compatible with other 233 | platforms. The default (None) will only add the operations if running 234 | on TPU. 235 | Returns: 236 | The normalized tensor with the same type and shape as `inputs`. 237 | """ 238 | if data_format not in {"NCHW", "NHWC"}: 239 | raise ValueError( 240 | "Invalid data_format {}. 
Allowed: NCHW, NHWC.".format(data_format)) 241 | if use_cross_replica_mean is None: 242 | # Default to global batch norm only on TPUs. 243 | use_cross_replica_mean = ( 244 | tpu_function.get_tpu_context().number_of_shards is not None) 245 | logging.debug("Automatically determined use_cross_replica_mean=%s.", 246 | use_cross_replica_mean) 247 | 248 | inputs = tf.convert_to_tensor(inputs) 249 | inputs_dtype = inputs.dtype 250 | inputs_shape = inputs.get_shape() 251 | 252 | num_channels = inputs.shape[-1].value 253 | if num_channels is None: 254 | raise ValueError("`C` dimension must be known but is None") 255 | 256 | inputs_rank = inputs_shape.ndims 257 | if inputs_rank is None: 258 | raise ValueError("Inputs %s has undefined rank" % inputs.name) 259 | elif inputs_rank not in [2, 4]: 260 | raise ValueError( 261 | "Inputs %s has unsupported rank." 262 | " Expected 2 or 4 but got %d" % (inputs.name, inputs_rank)) 263 | # Bring 2-D inputs into 4-D format. 264 | if inputs_rank == 2: 265 | new_shape = [-1, 1, 1, num_channels] 266 | if data_format == "NCHW": 267 | new_shape = [-1, num_channels, 1, 1] 268 | inputs = tf.reshape(inputs, new_shape) 269 | 270 | # Execute a distributed batch normalization 271 | axis = 1 if data_format == "NCHW" else 3 272 | inputs = tf.cast(inputs, tf.float32) 273 | reduction_axes = [i for i in range(4) if i != axis] 274 | if use_cross_replica_mean: 275 | mean, variance = tpu_ops.cross_replica_moments(inputs, reduction_axes) 276 | else: 277 | counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics( 278 | inputs, reduction_axes, keep_dims=False) 279 | mean, variance = tf.nn.normalize_moments( 280 | counts, mean_ss, variance_ss, shift=None) 281 | 282 | if use_moving_averages: 283 | mean, variance = _moving_moments_for_inference( 284 | mean=mean, variance=variance, is_training=is_training, decay=decay) 285 | else: 286 | mean, variance = _accumulated_moments_for_inference( 287 | mean=mean, variance=variance, is_training=is_training) 288 | 289 | outputs = tf.nn.batch_normalization( 290 | inputs, 291 | mean=mean, 292 | variance=variance, 293 | offset=None, 294 | scale=None, 295 | variance_epsilon=epsilon) 296 | outputs = tf.cast(outputs, inputs_dtype) 297 | 298 | # Bring 2-D inputs back into 2-D format. 299 | if inputs_rank == 2: 300 | outputs = tf.reshape(outputs, [-1] + inputs_shape[1:].as_list()) 301 | outputs.set_shape(inputs_shape) 302 | return outputs 303 | 304 | 305 | @gin.configurable(blacklist=["inputs"]) 306 | def no_batch_norm(inputs): 307 | return inputs 308 | 309 | 310 | @gin.configurable( 311 | blacklist=["inputs", "is_training", "center", "scale", "name"]) 312 | def batch_norm(inputs, is_training, center=True, scale=True, name="batch_norm"): 313 | """Performs the vanilla batch normalization with trainable scaling and offset. 314 | Args: 315 | inputs: A tensor with 2 or 4 dimensions, where the first dimension is 316 | `batch_size`. The normalization is over all but the last dimension if 317 | `data_format` is `NHWC`, and the second dimension if `data_format` is 318 | `NCHW`. 319 | is_training: Whether or not the layer is in training mode. 320 | center: If True, add offset of beta to normalized tensor. 321 | scale: If True, multiply by gamma. When the next layer is linear this can 322 | be disabled since the scaling will be done by the next layer. 323 | name: Name of the variable scope. 324 | Returns: 325 | The normalized tensor with the same type and shape as `inputs`. 
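Example (illustrative; `net` stands for any feature tensor built upstream):
  net = batch_norm(net, is_training=is_training, name="bn1")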
326 | """ 327 | with tf.variable_scope(name, values=[inputs]): 328 | outputs = tf.contrib.layers.batch_norm(inputs, decay=0.9, epsilon=1e-5, center=False, scale=False, updates_collections=None, is_training=is_training) 329 | num_channels = inputs.shape[-1].value 330 | 331 | # Allocate parameters for the trainable variables. 332 | collections = [tf.GraphKeys.MODEL_VARIABLES, 333 | tf.GraphKeys.GLOBAL_VARIABLES] 334 | if scale: 335 | gamma = tf.get_variable( 336 | "gamma", 337 | [num_channels], 338 | collections=collections, 339 | initializer=tf.ones_initializer()) 340 | outputs *= gamma 341 | if center: 342 | beta = tf.get_variable( 343 | "beta", 344 | [num_channels], 345 | collections=collections, 346 | initializer=tf.zeros_initializer()) 347 | outputs += beta 348 | return outputs 349 | 350 | 351 | @gin.configurable(whitelist=["num_hidden"]) 352 | def self_modulated_batch_norm(inputs, z, is_training, use_sn, 353 | center=True, scale=True, 354 | name="batch_norm", num_hidden=32): 355 | """Performs a self-modulated batch normalization. 356 | Details can be found in "On Self Modulation for Generative Adversarial 357 | Networks", Chen T. et al., 2018. [https://arxiv.org/abs/1810.01365] 358 | Like a normal batch normalization but the scale and offset are trainable 359 | transformation of `z`. 360 | Args: 361 | inputs: A tensor with 2 or 4 dimensions, where the first dimension is 362 | `batch_size`. The normalization is over all but the last dimension if 363 | `data_format` is `NHWC`, and the second dimension if `data_format` is 364 | `NCHW`. 365 | z: 2-D tensor with shape [batch_size, ?] with the latent code. 366 | is_training: Whether or not the layer is in training mode. 367 | use_sn: Whether to apply spectral normalization to the weights of the 368 | hidden layer and the linear transformations. 369 | center: If True, add offset of beta to normalized tensor. 370 | scale: If True, multiply by gamma. When the next layer is linear this can 371 | be disabled since the scaling will be done by the next layer. 372 | name: Name of the variable scope. 373 | num_hidden: Number of hidden units in the hidden layer. If 0 the scale and 374 | offset are simple linear transformations of `z`. 
375 | Returns: The normalized tensor with the same type and shape as `inputs`. 376 | """ 377 | if z is None: 378 | raise ValueError("You must provide z for self modulation.") 379 | with tf.variable_scope(name, values=[inputs]): 380 | outputs = standardize_batch(inputs, is_training=is_training) 381 | num_channels = inputs.shape[-1].value 382 | 383 | with tf.variable_scope("sbn", values=[inputs, z]): 384 | h = z 385 | if num_hidden > 0: 386 | h = linear(h, num_hidden, scope="hidden", use_sn=use_sn) 387 | h = tf.nn.relu(h) 388 | if scale: 389 | gamma = linear(h, num_channels, scope="gamma", bias_start=1.0, 390 | use_sn=use_sn) 391 | gamma = tf.reshape(gamma, [-1, 1, 1, num_channels]) 392 | outputs *= gamma 393 | if center: 394 | beta = linear(h, num_channels, scope="beta", use_sn=use_sn) 395 | beta = tf.reshape(beta, [-1, 1, 1, num_channels]) 396 | outputs += beta 397 | return outputs 398 | 399 | 400 | @gin.configurable(whitelist=["use_bias"]) 401 | def conditional_batch_norm(inputs, y, is_training, use_sn, center=True, 402 | scale=True, name="batch_norm", use_bias=False): 403 | """Conditional batch normalization.""" 404 | if y is None: 405 | raise ValueError("You must provide y for conditional batch normalization.") 406 | if y.shape.ndims != 2: 407 | raise ValueError("Conditioning must have rank 2.") 408 | with tf.variable_scope(name, values=[inputs]): 409 | outputs = tf.contrib.layers.batch_norm(inputs, decay=0.9, epsilon=1e-5, center=False, scale=False, updates_collections=None, is_training=is_training) 410 | num_channels = inputs.shape[-1].value 411 | with tf.variable_scope("condition", values=[inputs, y]): 412 | if scale: 413 | gamma = linear(y, num_channels, scope="gamma", use_sn=use_sn, 414 | use_bias=use_bias) 415 | gamma = tf.reshape(gamma, [-1, 1, 1, num_channels]) 416 | outputs *= gamma 417 | if center: 418 | beta = linear(y, num_channels, scope="beta", use_sn=use_sn, 419 | use_bias=use_bias) 420 | beta = tf.reshape(beta, [-1, 1, 1, num_channels]) 421 | outputs += beta 422 | return outputs 423 | 424 | 425 | def layer_norm(input_, is_training, scope): 426 | return tf.contrib.layers.layer_norm( 427 | input_, trainable=is_training, scope=scope) 428 | 429 | 430 | @gin.configurable(blacklist=["inputs"]) 431 | def spectral_norm(inputs, epsilon=1e-12, singular_value="auto"): 432 | """Performs Spectral Normalization on a weight tensor. 433 | Details of why this is helpful for GANs can be found in "Spectral 434 | Normalization for Generative Adversarial Networks", Miyato T. et al., 2018. 435 | [https://arxiv.org/abs/1802.05957]. 436 | Args: 437 | inputs: The weight tensor to normalize. 438 | epsilon: Epsilon for L2 normalization. 439 | singular_value: Which first singular value to store (left or right). Use 440 | "auto" to automatically choose the one that has fewer dimensions. 441 | Returns: 442 | The normalized weight tensor. 443 | """ 444 | if len(inputs.shape) < 2: 445 | raise ValueError( 446 | "Spectral norm can only be applied to multi-dimensional tensors") 447 | 448 | # The paper says to flatten convnet kernel weights from (C_out, C_in, KH, KW) 449 | # to (C_out, C_in * KH * KW). Our Conv2D kernel shape is (KH, KW, C_in, C_out) 450 | # so it should be reshaped to (KH * KW * C_in, C_out), and similarly for other 451 | # layers that put output channels as last dimension. This implies that w 452 | # here is equivalent to w.T in the paper. 453 | w = tf.reshape(inputs, (-1, inputs.shape[-1])) 454 | 455 | # Choose whether to persist the first left or first right singular vector. 
456 | # As the underlying matrix is PSD, this should be equivalent, but in practice 457 | # the shape of the persisted vector is different. Here one can choose whether 458 | # to maintain the left or right one, or pick the one which has the smaller 459 | # dimension. We use the same variable for the singular vector if we switch 460 | # from normal weights to EMA weights. 461 | var_name = inputs.name.replace("/ExponentialMovingAverage", "").split("/")[-1] 462 | var_name = var_name.split(":")[0] + "/u_var" 463 | if singular_value == "auto": 464 | singular_value = "left" if w.shape[0] <= w.shape[1] else "right" 465 | u_shape = (w.shape[0], 1) if singular_value == "left" else (1, w.shape[-1]) 466 | u_var = tf.get_variable( 467 | var_name, 468 | shape=u_shape, 469 | dtype=w.dtype, 470 | initializer=tf.random_normal_initializer(), 471 | trainable=False) 472 | u = u_var 473 | 474 | # Use power iteration method to approximate the spectral norm. 475 | # The authors suggest that one round of power iteration was sufficient in the 476 | # actual experiment to achieve satisfactory performance. 477 | power_iteration_rounds = 1 478 | for _ in range(power_iteration_rounds): 479 | if singular_value == "left": 480 | # `v` approximates the first right singular vector of matrix `w`. 481 | v = tf.math.l2_normalize( 482 | tf.matmul(tf.transpose(w), u), axis=None, epsilon=epsilon) 483 | u = tf.math.l2_normalize(tf.matmul(w, v), axis=None, epsilon=epsilon) 484 | else: 485 | v = tf.math.l2_normalize(tf.matmul(u, w, transpose_b=True), 486 | epsilon=epsilon) 487 | u = tf.math.l2_normalize(tf.matmul(v, w), epsilon=epsilon) 488 | 489 | # Update the approximation. 490 | with tf.control_dependencies([tf.assign(u_var, u, name="update_u")]): 491 | u = tf.identity(u) 492 | 493 | # The authors of SN-GAN chose to stop gradient propagating through u and v 494 | # and we maintain that option. 495 | u = tf.stop_gradient(u) 496 | v = tf.stop_gradient(v) 497 | 498 | if singular_value == "left": 499 | norm_value = tf.matmul(tf.matmul(tf.transpose(u), w), v) 500 | else: 501 | norm_value = tf.matmul(tf.matmul(v, w), u, transpose_b=True) 502 | norm_value.shape.assert_is_fully_defined() 503 | norm_value.shape.assert_is_compatible_with([1, 1]) 504 | 505 | w_normalized = w / norm_value 506 | 507 | # Deflate normalized weights to match the unnormalized tensor. 
508 | w_tensor_normalized = tf.reshape(w_normalized, inputs.shape) 509 | return w_tensor_normalized 510 | 511 | 512 | def linear(inputs, output_size, scope=None, stddev=0.02, bias_start=0.0, 513 | use_sn=False, use_bias=True): 514 | """Linear layer without the non-linear activation applied.""" 515 | shape = inputs.get_shape().as_list() 516 | with tf.variable_scope(scope or "linear"): 517 | kernel = tf.get_variable( 518 | "kernel", 519 | [shape[1], output_size], 520 | initializer=weight_initializer(stddev=stddev)) 521 | if use_sn: 522 | kernel = spectral_norm(kernel) 523 | outputs = tf.matmul(inputs, kernel) 524 | if use_bias: 525 | bias = tf.get_variable( 526 | "bias", 527 | [output_size], 528 | initializer=tf.constant_initializer(bias_start)) 529 | outputs += bias 530 | return outputs 531 | 532 | 533 | def conv2d(inputs, output_dim, k_h, k_w, d_h, d_w, stddev=0.02, name="conv2d", 534 | use_sn=False, use_bias=True): 535 | """Performs 2D convolution of the input.""" 536 | with tf.variable_scope(name): 537 | w = tf.get_variable( 538 | "kernel", [k_h, k_w, inputs.shape[-1].value, output_dim], 539 | initializer=weight_initializer(stddev=stddev)) 540 | if use_sn: 541 | w = spectral_norm(w) 542 | outputs = tf.nn.conv2d(inputs, w, strides=[1, d_h, d_w, 1], padding="SAME") 543 | if use_bias: 544 | bias = tf.get_variable( 545 | "bias", [output_dim], initializer=tf.constant_initializer(0.0)) 546 | outputs += bias 547 | return outputs 548 | 549 | 550 | conv1x1 = functools.partial(conv2d, k_h=1, k_w=1, d_h=1, d_w=1) 551 | 552 | 553 | def deconv2d(inputs, output_shape, k_h, k_w, d_h, d_w, 554 | stddev=0.02, name="deconv2d", use_sn=False): 555 | """Performs transposed 2D convolution of the input.""" 556 | with tf.variable_scope(name): 557 | w = tf.get_variable( 558 | "kernel", [k_h, k_w, output_shape[-1], inputs.get_shape()[-1]], 559 | initializer=weight_initializer(stddev=stddev)) 560 | if use_sn: 561 | w = spectral_norm(w) 562 | deconv = tf.nn.conv2d_transpose( 563 | inputs, w, output_shape=output_shape, strides=[1, d_h, d_w, 1]) 564 | bias = tf.get_variable( 565 | "bias", [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 566 | return tf.reshape(tf.nn.bias_add(deconv, bias), tf.shape(deconv)) 567 | 568 | 569 | def lrelu(inputs, leak=0.2, name="lrelu"): 570 | """Performs leaky-ReLU on the input.""" 571 | return tf.maximum(inputs, leak * inputs, name=name) 572 | 573 | 574 | def weight_norm_linear(input_, output_size, 575 | init=False, init_scale=1.0, 576 | name="wn_linear", 577 | initializer=tf.truncated_normal_initializer, 578 | stddev=0.02): 579 | """Linear layer with Weight Normalization (Salimans, Kingma '16).""" 580 | with tf.variable_scope(name): 581 | if init: 582 | v = tf.get_variable("V", [int(input_.get_shape()[1]), output_size], 583 | tf.float32, initializer(0, stddev), trainable=True) 584 | v_norm = tf.nn.l2_normalize(v.initialized_value(), [0]) 585 | x_init = tf.matmul(input_, v_norm) 586 | m_init, v_init = tf.nn.moments(x_init, [0]) 587 | scale_init = init_scale / tf.sqrt(v_init + 1e-10) 588 | g = tf.get_variable("g", dtype=tf.float32, 589 | initializer=scale_init, trainable=True) 590 | b = tf.get_variable("b", dtype=tf.float32, initializer= 591 | -m_init*scale_init, trainable=True) 592 | x_init = tf.reshape(scale_init, [1, output_size]) * ( 593 | x_init - tf.reshape(m_init, [1, output_size])) 594 | return x_init 595 | else: 596 | # Note that the original implementation uses Polyak averaging. 
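# Rough usage sketch: weight-normalized layers are typically built twice --
# once with init=True on a warm-up batch for the data-dependent init of g and
# b, then reused with init=False; the scope name "net" and the inputs below
# are illustrative:
#   with tf.variable_scope("net"):
#     _ = weight_norm_linear(x_warmup, 256, init=True)
#   with tf.variable_scope("net", reuse=True):
#     y = weight_norm_linear(x, 256, init=False)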
597 | v = tf.get_variable("V") 598 | g = tf.get_variable("g") 599 | b = tf.get_variable("b") 600 | tf.assert_variables_initialized([v, g, b]) 601 | x = tf.matmul(input_, v) 602 | scaler = g / tf.sqrt(tf.reduce_sum(tf.square(v), [0])) 603 | x = tf.reshape(scaler, [1, output_size]) * x + tf.reshape( 604 | b, [1, output_size]) 605 | return x 606 | 607 | 608 | def weight_norm_conv2d(input_, output_dim, 609 | k_h, k_w, d_h, d_w, 610 | init, init_scale, 611 | stddev=0.02, 612 | name="wn_conv2d", 613 | initializer=tf.truncated_normal_initializer): 614 | """Performs convolution with Weight Normalization.""" 615 | with tf.variable_scope(name): 616 | if init: 617 | v = tf.get_variable( 618 | "V", [k_h, k_w] + [int(input_.get_shape()[-1]), output_dim], 619 | tf.float32, initializer(0, stddev), trainable=True) 620 | v_norm = tf.nn.l2_normalize(v.initialized_value(), [0, 1, 2]) 621 | x_init = tf.nn.conv2d(input_, v_norm, strides=[1, d_h, d_w, 1], 622 | padding="SAME") 623 | m_init, v_init = tf.nn.moments(x_init, [0, 1, 2]) 624 | scale_init = init_scale / tf.sqrt(v_init + 1e-8) 625 | g = tf.get_variable( 626 | "g", dtype=tf.float32, initializer=scale_init, trainable=True) 627 | b = tf.get_variable( 628 | "b", dtype=tf.float32, initializer=-m_init*scale_init, trainable=True) 629 | x_init = tf.reshape(scale_init, [1, 1, 1, output_dim]) * ( 630 | x_init - tf.reshape(m_init, [1, 1, 1, output_dim])) 631 | return x_init 632 | else: 633 | v = tf.get_variable("V") 634 | g = tf.get_variable("g") 635 | b = tf.get_variable("b") 636 | tf.assert_variables_initialized([v, g, b]) 637 | w = tf.reshape(g, [1, 1, 1, output_dim]) * tf.nn.l2_normalize( 638 | v, [0, 1, 2]) 639 | x = tf.nn.bias_add( 640 | tf.nn.conv2d(input_, w, [1, d_h, d_w, 1], padding="SAME"), b) 641 | return x 642 | 643 | 644 | def weight_norm_deconv2d(x, output_dim, 645 | k_h, k_w, d_h, d_w, 646 | init=False, init_scale=1.0, 647 | stddev=0.02, 648 | name="wn_deconv2d", 649 | initializer=tf.truncated_normal_initializer): 650 | """Performs Transposed Convolution with Weight Normalization.""" 651 | xs = x.get_shape().as_list() 652 | target_shape = [xs[0], xs[1] * d_h, xs[2] * d_w, output_dim] 653 | with tf.variable_scope(name): 654 | if init: 655 | v = tf.get_variable( 656 | "V", [k_h, k_w] + [output_dim, int(x.get_shape()[-1])], 657 | tf.float32, initializer(0, stddev), trainable=True) 658 | v_norm = tf.nn.l2_normalize(v.initialized_value(), [0, 1, 3]) 659 | x_init = tf.nn.conv2d_transpose(x, v_norm, target_shape, 660 | [1, d_h, d_w, 1], padding="SAME") 661 | m_init, v_init = tf.nn.moments(x_init, [0, 1, 2]) 662 | scale_init = init_scale/tf.sqrt(v_init + 1e-8) 663 | g = tf.get_variable("g", dtype=tf.float32, 664 | initializer=scale_init, trainable=True) 665 | b = tf.get_variable("b", dtype=tf.float32, 666 | initializer=-m_init*scale_init, trainable=True) 667 | x_init = tf.reshape(scale_init, [1, 1, 1, output_dim]) * ( 668 | x_init - tf.reshape(m_init, [1, 1, 1, output_dim])) 669 | return x_init 670 | else: 671 | v = tf.get_variable("V")  # the variable is created under the name "V" in the init branch 672 | g = tf.get_variable("g") 673 | b = tf.get_variable("b") 674 | tf.assert_variables_initialized([v, g, b]) 675 | w = tf.reshape(g, [1, 1, output_dim, 1]) * tf.nn.l2_normalize( 676 | v, [0, 1, 3]) 677 | x = tf.nn.conv2d_transpose(x, w, target_shape, strides=[1, d_h, d_w, 1], 678 | padding="SAME") 679 | x = tf.nn.bias_add(x, b) 680 | return x 681 | 682 | 683 | def non_local_block(x, name, use_sn): 684 | """Self-attention (non-local) block. 
685 | This method is used to exactly reproduce SAGAN and ignores Gin settings on 686 | weight initialization and spectral normalization. 687 | Args: 688 | x: Input tensor of shape [batch, h, w, c]. 689 | name: Name of the variable scope. 690 | use_sn: Apply spectral norm to the weights. 691 | Returns: 692 | A tensor of the same shape after self-attention was applied. 693 | """ 694 | def _spatial_flatten(inputs): 695 | shape = inputs.shape 696 | return tf.reshape(inputs, (-1, shape[1] * shape[2], shape[3])) 697 | 698 | with tf.variable_scope(name): 699 | h, w, num_channels = x.get_shape().as_list()[1:] 700 | num_channels_attn = num_channels // 8 701 | num_channels_g = num_channels // 2 702 | 703 | # Theta path 704 | theta = conv1x1(x, num_channels_attn, name="conv2d_theta", use_sn=use_sn, 705 | use_bias=False) 706 | theta = _spatial_flatten(theta) 707 | 708 | # Phi path 709 | phi = conv1x1(x, num_channels_attn, name="conv2d_phi", use_sn=use_sn, 710 | use_bias=False) 711 | phi = tf.layers.max_pooling2d(inputs=phi, pool_size=[2, 2], strides=2) 712 | phi = _spatial_flatten(phi) 713 | 714 | attn = tf.matmul(theta, phi, transpose_b=True) 715 | attn = tf.nn.softmax(attn) 716 | 717 | # G path 718 | g = conv1x1(x, num_channels_g, name="conv2d_g", use_sn=use_sn, 719 | use_bias=False) 720 | g = tf.layers.max_pooling2d(inputs=g, pool_size=[2, 2], strides=2) 721 | g = _spatial_flatten(g) 722 | 723 | attn_g = tf.matmul(attn, g) 724 | attn_g = tf.reshape(attn_g, [-1, h, w, num_channels_g]) 725 | sigma = tf.get_variable("sigma", [], initializer=tf.zeros_initializer()) 726 | attn_g = conv1x1(attn_g, num_channels, name="conv2d_attn_g", use_sn=use_sn, 727 | use_bias=False) 728 | return x + sigma * attn_g --------------------------------------------------------------------------------
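A minimal smoke test for the self-attention block above -- a sketch assuming TensorFlow 1.x, NumPy, and that arch_ops.py is importable; the batch size and feature-map shape are illustrative:

import numpy as np
import tensorflow as tf

from arch_ops import non_local_block

# Feature maps of shape [batch, h, w, channels]; the block preserves the shape.
x = tf.placeholder(tf.float32, [8, 16, 16, 64])
y = non_local_block(x, name="attn", use_sn=True)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  out = sess.run(y, {x: np.random.randn(8, 16, 16, 64).astype(np.float32)})
  print(out.shape)  # (8, 16, 16, 64)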