├── tflib
    ├── __pycache__
    │   ├── ada.cpython-36.pyc
    │   ├── fid.cpython-36.pyc
    │   ├── mnist.cpython-36.pyc
    │   ├── plot.cpython-36.pyc
    │   ├── utils.cpython-36.pyc
    │   ├── __init__.cpython-36.pyc
    │   ├── cifar10.cpython-36.pyc
    │   ├── imagenet64.cpython-36.pyc
    │   ├── imagenet128.cpython-36.pyc
    │   ├── save_images.cpython-36.pyc
    │   ├── inception_score.cpython-36.pyc
    │   ├── small_imagenet.cpython-36.pyc
    │   └── imagenet128_10classes.cpython-36.pyc
    ├── custom_ops
    │   ├── __pycache__
    │   │   └── custom_ops.cpython-36.pyc
    │   ├── _cudacache
    │   │   ├── upfirdn_2d_3e8f2d34c6404d8005e9f4001e3024c0.so
    │   │   └── fused_bias_act_e0a8650f1ad6d2006c69e2aca94e8f3d.so
    │   ├── custom_ops.py
    │   ├── fused_bias_act.cu
    │   └── upfirdn_2d.cu
    ├── enqueue_pickle.py
    ├── ops
    │   ├── layernorm.py
    │   ├── cond_batchnorm.py
    │   ├── conv1d.py
    │   ├── batchnorm.py
    │   ├── deconv2d.py
    │   ├── linear.py
    │   └── conv2d.py
    ├── stl.py
    ├── small_imagenet.py
    ├── cifar10.py
    ├── imagenet64.py
    ├── plot.py
    ├── save_images.py
    ├── imagenet32.py
    ├── imagenet128_10classes.py
    ├── imagenet128.py
    ├── mnist.py
    ├── ut_zap50k.py
    ├── inception_score.py
    ├── fid.py
    ├── inception_score_tpu.py.bak
    ├── inception_score_tpu.py
    ├── celeba.py
    ├── fid_tpu.py.bak
    ├── fid_tpu.py
    ├── tpu_ops.py
    ├── __init__.py
    ├── utils.py
    ├── classifier_score.py
    ├── ada.py
    └── memory_saving_gradients.py
└── README.md
/tflib/__pycache__/ada.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/ada.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/__pycache__/fid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/fid.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/__pycache__/mnist.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/mnist.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/__pycache__/plot.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/plot.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/__pycache__/cifar10.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/cifar10.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/__pycache__/imagenet64.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/imagenet64.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/__pycache__/imagenet128.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/imagenet128.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/__pycache__/save_images.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/save_images.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/__pycache__/inception_score.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/inception_score.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/__pycache__/small_imagenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/small_imagenet.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/__pycache__/imagenet128_10classes.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/__pycache__/imagenet128_10classes.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/custom_ops/__pycache__/custom_ops.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/custom_ops/__pycache__/custom_ops.cpython-36.pyc -------------------------------------------------------------------------------- /tflib/custom_ops/_cudacache/upfirdn_2d_3e8f2d34c6404d8005e9f4001e3024c0.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/custom_ops/_cudacache/upfirdn_2d_3e8f2d34c6404d8005e9f4001e3024c0.so -------------------------------------------------------------------------------- /tflib/custom_ops/_cudacache/fused_bias_act_e0a8650f1ad6d2006c69e2aca94e8f3d.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsc2017/MIX-GAN/HEAD/tflib/custom_ops/_cudacache/fused_bias_act_e0a8650f1ad6d2006c69e2aca94e8f3d.so -------------------------------------------------------------------------------- /tflib/enqueue_pickle.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | def unpickle(file): 4 | fo = open(file, 'rb') 5 | dict = pickle.load(fo) 6 | fo.close() 7 | return dict['data'], dict['labels'] 8 | 9 | def enqueue_pickle(filename, que, size): 10 | try: 11 | print('trying to read', filename) 12 | data, labels = unpickle(filename) 13 | labels = np.array(labels) 14 | data = data.reshape([-1, 3, size, size]).transpose([0, 2, 3, 1]).reshape([-1, size * size * 3]) 15 | indices = np.arange(len(labels)) 16 | np.random.shuffle(indices) 17 | que.put((data[indices], labels[indices])) 18 | except: 19 | print('Cannot read %s' % filename) 
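
The helper above is a thread target: each call reads one pickled shard, shuffles it, and pushes it onto a queue. The following is a minimal, hypothetical driver (not part of the repo): it assumes `tflib` is importable and that the shard paths exist, and it mirrors the pattern used by `tflib/imagenet32.py` further down in this dump.

```python
# Hypothetical driver for enqueue_pickle: one reader thread per pickle shard,
# results consumed through a bounded queue (mirrors tflib/imagenet32.py).
import queue
import threading

from tflib.enqueue_pickle import enqueue_pickle  # assumes tflib is on PYTHONPATH

pickles_queue = queue.Queue(maxsize=2)  # bounded so readers stay only slightly ahead of the consumer
filenames = ['train_data_batch_1', 'train_data_batch_2']  # placeholder shard names
data_dir = '/path/to/data/'                               # placeholder directory

for name in filenames:
    threading.Thread(target=enqueue_pickle,
                     args=(data_dir + name, pickles_queue, 32)).start()  # 32 = image side length

for _ in filenames:
    data, labels = pickles_queue.get()  # blocks until a shard has been read and shuffled
    print(data.shape, labels.shape)
    pickles_queue.task_done()
```

The bounded `maxsize=2` queue keeps at most two shards in memory at a time, so disk reads overlap with consumption without exhausting RAM.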
-------------------------------------------------------------------------------- /tflib/ops/layernorm.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | def Layernorm(name, norm_axes, inputs): 7 | mean, var = tf.nn.moments(inputs, norm_axes, keep_dims=True) 8 | 9 | # Assume the 'neurons' axis is the first of norm_axes. This is the case for fully-connected and BCHW conv layers. 10 | n_neurons = inputs.get_shape().as_list()[norm_axes[0]] 11 | 12 | offset = lib.param(name+'.offset', np.zeros(n_neurons, dtype='float32')) 13 | scale = lib.param(name+'.scale', np.ones(n_neurons, dtype='float32')) 14 | 15 | # Add broadcasting dims to offset and scale (e.g. BCHW conv data) 16 | offset = tf.reshape(offset, [-1] + [1 for i in range(len(norm_axes)-1)]) 17 | scale = tf.reshape(scale, [-1] + [1 for i in range(len(norm_axes)-1)]) 18 | 19 | result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5) 20 | 21 | return result -------------------------------------------------------------------------------- /tflib/ops/cond_batchnorm.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | dtype='float32' 6 | def Batchnorm(name, axes, inputs, is_training=None, stats_iter=None, update_moving_stats=True, fused=True, labels=None, n_labels=None): 7 | """conditional batchnorm (dumoulin et al 2016) for BCHW conv filtermaps""" 8 | if axes != [0,1,2]: 9 | raise Exception('unsupported') 10 | mean, var = tf.nn.moments(inputs, axes, keep_dims=True) 11 | shape = mean.get_shape().as_list() # shape is [1,n,1,1] 12 | offset_m = lib.param(name+'.offset', np.zeros([n_labels,shape[3]], dtype=dtype)) 13 | scale_m = lib.param(name+'.scale', np.ones([n_labels,shape[3]], dtype=dtype)) 14 | offset = tf.nn.embedding_lookup(offset_m, labels) 15 | scale = tf.nn.embedding_lookup(scale_m, labels) 16 | result = tf.nn.batch_normalization(inputs, mean, var, offset[:,None,None,:], scale[:,None,None,:], 1e-5) 17 | return result -------------------------------------------------------------------------------- /tflib/stl.py: -------------------------------------------------------------------------------- 1 | import scipy.io as sio 2 | import numpy as np 3 | 4 | data_dir='/home/tsc/data/mydata/STL-10/stl10_matlab/' 5 | 6 | def make_generator(mode, image_path, files): 7 | if mode=='TRAIN': 8 | train_data_dir =data_dir+'train.mat' 9 | load_data = sio.loadmat(train_data_dir) 10 | all_images_train=load_data['X'].reshape([-1,3,96,96]).transpose([0,1,3,2]) 11 | all_labels_train=load_data['y'].reshape([-1]) 12 | return all_images_train[np.ix_(files)], all_labels_train[np.ix_(files)]-1 13 | elif mode=='TEST': 14 | test_data_dir =data_dir+'test.mat' 15 | load_data = sio.loadmat(test_data_dir) 16 | all_images_test=load_data['X'].reshape([-1,3,96,96]).transpose([0,1,3,2]) 17 | all_labels_test=load_data['y'].reshape([-1]) 18 | return all_images_test[np.ix_(files)], all_labels_test[np.ix_(files)]-1 19 | def load(mode, data_dir,files): 20 | return make_generator(mode, data_dir, files) 21 | -------------------------------------------------------------------------------- /tflib/small_imagenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import time 4 | 5 | def make_generator(path, n_files, batch_size): 6 | epoch_count = [1] 7 | def 
get_epoch(): 8 | images = np.zeros((batch_size, 3, 64, 64), dtype='int32') 9 | files = range(n_files) 10 | random_state = np.random.RandomState(epoch_count[0]) 11 | random_state.shuffle(files) 12 | epoch_count[0] += 1 13 | for n, i in enumerate(files): 14 | image = scipy.misc.imread("{}/{}.png".format(path, str(i+1).zfill(len(str(n_files))))) 15 | images[n % batch_size] = image.transpose(2,0,1) 16 | if n > 0 and n % batch_size == 0: 17 | yield (images,) 18 | return get_epoch 19 | 20 | def load(batch_size, data_dir='/home/ishaan/data/imagenet64'): 21 | return ( 22 | make_generator(data_dir+'/train_64x64', 1281149, batch_size), 23 | make_generator(data_dir+'/valid_64x64', 49999, batch_size) 24 | ) 25 | 26 | if __name__ == '__main__': 27 | train_gen, valid_gen = load(64) 28 | t0 = time.time() 29 | for i, batch in enumerate(train_gen(), start=1): 30 | print("{}\t{}".format(str(time.time() - t0), batch[0][0,0,0,0])) 31 | if i == 1000: 32 | break 33 | t0 = time.time() -------------------------------------------------------------------------------- /tflib/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os 4 | import urllib 5 | import gzip 6 | import pickle 7 | 8 | def unpickle(file): 9 | fo = open(file, 'rb') 10 | dict = pickle.load(fo,encoding='latin1') 11 | fo.close() 12 | return dict['data'], dict['labels'] 13 | 14 | def cifar_generator(filenames, batch_size, data_dir): 15 | all_data = [] 16 | all_labels = [] 17 | for filename in filenames: 18 | data, labels = unpickle(data_dir + '/' + filename) 19 | all_data.append(data) 20 | all_labels.append(labels) 21 | 22 | images = np.concatenate(all_data, axis=0).reshape([-1,3,32,32]).transpose([0,2,3,1]).reshape([-1,32*32*3]) 23 | labels = np.concatenate(all_labels, axis=0) 24 | 25 | def get_epoch(): 26 | rng_state = np.random.get_state() 27 | np.random.shuffle(images) 28 | np.random.set_state(rng_state) 29 | np.random.shuffle(labels) 30 | 31 | for i in range(len(images) // batch_size): 32 | yield (images[i*batch_size:(i+1)*batch_size], labels[i*batch_size:(i+1)*batch_size]) 33 | 34 | return get_epoch 35 | 36 | 37 | def load(batch_size, data_dir): 38 | return ( 39 | cifar_generator(['data_batch_1','data_batch_2','data_batch_3','data_batch_4','data_batch_5'], batch_size, data_dir), 40 | cifar_generator(['test_batch'], batch_size, data_dir) 41 | ) -------------------------------------------------------------------------------- /tflib/imagenet64.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os 4 | import urllib 5 | import gzip 6 | import pickle 7 | import gc 8 | def unpickle(file): 9 | fo = open(file, 'rb') 10 | dict = pickle.load(fo,encoding='latin1') 11 | fo.close() 12 | return dict['data'], dict['labels'] 13 | 14 | def cifar_generator(filenames, batch_size, data_dir): 15 | all_data = [] 16 | all_labels = [] 17 | for filename in filenames: 18 | data, labels = unpickle(data_dir + '/' + filename) 19 | all_data.append(data) 20 | all_labels.append(labels) 21 | 22 | images = np.concatenate(all_data, axis=0).reshape([-1,3,64,64]).transpose([0,2,3,1]).reshape([-1,64*64*3]) 23 | labels = np.concatenate(all_labels, axis=0)-1 24 | 25 | print('All training data loaded into memory.')# 1281167×64×64×3×8 bits = 15.74 GB 26 | del all_data, all_labels 27 | gc.collect() 28 | def get_epoch(): 29 | rng_state = np.random.get_state() 30 | np.random.shuffle(images) 31 | np.random.set_state(rng_state) 32 | 
np.random.shuffle(labels) 33 | 34 | for i in range(len(images) // batch_size): 35 | yield (images[i*batch_size:(i+1)*batch_size], labels[i*batch_size:(i+1)*batch_size]) 36 | 37 | return get_epoch 38 | 39 | 40 | def load(mode, batch_size, data_dir): 41 | if mode=='TRAIN': 42 | return cifar_generator(['train_data_batch_%i'%(i+1) for i in range(100)], batch_size, data_dir) 43 | else: 44 | return cifar_generator(['val_data_batch_%i' % (i + 1) for i in range(10)], batch_size, data_dir) 45 | -------------------------------------------------------------------------------- /tflib/plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import matplotlib 4 | matplotlib.use('Agg') 5 | import matplotlib.pyplot as plt 6 | 7 | import collections 8 | import time 9 | import pickle 10 | import os, shutil 11 | _since_beginning = collections.defaultdict(lambda: {}) 12 | _since_last_flush = collections.defaultdict(lambda: {}) 13 | first_flush=True 14 | 15 | _iter = [0] 16 | def tick(): 17 | _iter[0] += 1 18 | 19 | def plot(name, value): 20 | _since_last_flush[name][_iter[0]] = value 21 | 22 | def flush(save_path): 23 | global _since_beginning, _since_last_flush, first_flush 24 | if first_flush and os.path.exists(os.path.join(save_path,'log.pkl')): 25 | try: 26 | pkl_file=open(os.path.join(save_path,'log.pkl'), 'rb') 27 | _since_beginning.update(pickle.load(pkl_file)) 28 | except: 29 | pkl_file=open(os.path.join(save_path,'log.pkl.bak'), 'rb') 30 | _since_beginning.update(pickle.load(pkl_file)) 31 | first_flush=False 32 | pkl_file=os.path.join(save_path,'log.pkl') 33 | if os.path.exists(pkl_file): 34 | try: 35 | backup_pkl=shutil.copyfile(pkl_file, pkl_file+'.bak') # shutil.copy does not work well with gcsfuse and is thus not used here 36 | except: 37 | pass 38 | prints = [] 39 | 40 | for name, vals in _since_last_flush.items(): 41 | prints.append("{}={:0,.2f},".format(name, np.mean(list(vals.values())))) 42 | _since_beginning[name].update(vals) 43 | 44 | x_vals = np.sort(list(_since_beginning[name].keys())) 45 | y_vals = [_since_beginning[name][x] for x in x_vals] 46 | 47 | plt.clf() 48 | plt.plot(x_vals, y_vals) 49 | plt.xlabel('iteration') 50 | plt.ylabel(name) 51 | try: 52 | plt.savefig(save_path+name.replace(' ', '_')+'.jpg') 53 | except: 54 | pass 55 | 56 | print("iter {}\t{}".format(_iter[0], " ".join(prints))) 57 | _since_last_flush.clear() 58 | 59 | with open(pkl_file, 'wb') as f: 60 | pickle.dump(dict(_since_beginning), f, pickle.HIGHEST_PROTOCOL) -------------------------------------------------------------------------------- /tflib/save_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image grid saver, based on color_grid_vis from github.com/Newmu 3 | """ 4 | 5 | import numpy as np 6 | import scipy.misc 7 | from imageio import imwrite 8 | import matplotlib.pyplot as plt 9 | import IPython 10 | import PIL 11 | def save_images(X, save_path): 12 | # [0, 1] -> [0,255] 13 | if isinstance(X.flatten()[0], np.floating): 14 | X = (255.*X).astype('uint8') 15 | 16 | n_samples = X.shape[0] 17 | rows = int(np.sqrt(n_samples)) 18 | while n_samples % rows != 0: 19 | rows -= 1 20 | 21 | nh, nw = rows, n_samples//rows 22 | 23 | if X.ndim == 2: 24 | X = np.reshape(X, (X.shape[0], int(np.sqrt(X.shape[1])), int(np.sqrt(X.shape[1])))) 25 | 26 | if X.ndim == 4: 27 | # BCHW -> BHWC 28 | #X = X.transpose(0,2,3,1) 29 | h, w = X[0].shape[:2] 30 | img = np.zeros((h*nh, w*nw, 3), dtype=np.uint8) 31 
| elif X.ndim == 3: 32 | h, w = X[0].shape[:2] 33 | img = np.zeros((h*nh, w*nw), dtype=np.uint8) 34 | 35 | for n, x in enumerate(X): 36 | j = n//nw 37 | i = n%nw 38 | img[j*h:j*h+h, i*w:i*w+w] = x 39 | 40 | imwrite(save_path, img) 41 | 42 | def show_images(X): 43 | # [0, 1] -> [0,255] 44 | if isinstance(X.flatten()[0], np.floating): 45 | X = (255.*X).astype('uint8') 46 | 47 | n_samples = X.shape[0] 48 | rows = int(np.sqrt(n_samples)) 49 | while n_samples % rows != 0: 50 | rows -= 1 51 | 52 | nh, nw = rows, n_samples//rows 53 | 54 | if X.ndim == 2: 55 | X = np.reshape(X, (X.shape[0], int(np.sqrt(X.shape[1])), int(np.sqrt(X.shape[1])))) 56 | 57 | if X.ndim == 4: 58 | # BCHW -> BHWC 59 | #X = X.transpose(0,2,3,1) 60 | h, w = X[0].shape[:2] 61 | img = np.zeros((h*nh, w*nw, 3),np.uint8) 62 | elif X.ndim == 3: 63 | h, w = X[0].shape[:2] 64 | img = np.zeros((h*nh, w*nw),np.uint8) 65 | 66 | for n, x in enumerate(X): 67 | j = n//nw 68 | i = n%nw 69 | img[j*h:j*h+h, i*w:i*w+w] = x 70 | IPython.display.display((PIL.Image.fromarray(img))) -------------------------------------------------------------------------------- /tflib/imagenet32.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | import pickle 3 | import numpy as np 4 | import os 5 | import urllib 6 | import gzip 7 | import time 8 | import multiprocessing 9 | import threading 10 | import queue 11 | 12 | epoch=[1] 13 | pickles_queue =queue.Queue(maxsize = 2) 14 | reading_lock = threading.Lock() 15 | def unpickle(file): 16 | fo = open(file, 'rb') 17 | dict = pickle.load(fo) 18 | fo.close() 19 | return dict['data'], dict['labels'] 20 | 21 | def enqueue_pickle(filename, que, size): 22 | reading_lock.acquire() 23 | try: 24 | data, labels = unpickle(filename) 25 | labels = np.array(labels) 26 | data = data.reshape([-1, 3, size, size]).transpose([0, 2, 3, 1]).reshape([-1, size * size * 3]) 27 | indices = np.arange(len(labels)) 28 | np.random.shuffle(indices) 29 | que.put((data[indices], labels[indices])) 30 | except: 31 | print('Cannot read %s' % filename) 32 | reading_lock.release() 33 | 34 | def cifar_generator(filenames, batch_size, data_dir): 35 | def get_epoch(): 36 | print('Epoch %i' % epoch[0]) 37 | epoch[0]+=1 38 | np.random.shuffle(filenames) 39 | for filename in filenames: 40 | threading.Thread(target=enqueue_pickle, args=(data_dir + filename, pickles_queue, 32)).start() 41 | #lib.read_pickle.read_pickle(data_dir + '/' + filename, pickles_queue) 42 | count = 0 43 | while 1: 44 | #print('queue length: ', pickles_queue.qsize()) 45 | data, labels = pickles_queue.get() 46 | labels = labels -1 47 | count+=1 48 | for i in range(data.shape[0] // batch_size): 49 | yield data[i * batch_size:(i + 1) * batch_size], labels[i * batch_size:(i + 1) * batch_size] 50 | pickles_queue.task_done() 51 | if count == len(filenames)-1: 52 | #assert pickles_queue.qsize()==1 53 | #pool.close() 54 | #pool.join() 55 | break 56 | return get_epoch 57 | 58 | def load(mode, batch_size, data_dir): 59 | if mode=='TRAIN': 60 | return cifar_generator(['train_data_batch_%i'%(i+1) for i in range(100)], batch_size, data_dir) 61 | else: 62 | return cifar_generator(['val_data_batch_%i' % (i + 1) for i in range(10)], batch_size, data_dir) 63 | -------------------------------------------------------------------------------- /tflib/imagenet128_10classes.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | import pickle 3 | import numpy as np 4 | import os 5 | import 
urllib 6 | import gzip 7 | import time 8 | import multiprocessing 9 | import threading 10 | import queue 11 | 12 | epoch=[1] 13 | pickles_queue =queue.Queue(maxsize = 2) 14 | reading_lock = threading.Lock() 15 | def unpickle(file): 16 | fo = open(file, 'rb') 17 | dict = pickle.load(fo) 18 | fo.close() 19 | return dict['data'], dict['labels'] 20 | 21 | def enqueue_pickle(filename, que, size): 22 | reading_lock.acquire() 23 | try: 24 | data, labels = unpickle(filename) 25 | labels = np.array(labels) 26 | data = data.reshape([-1, 3, size, size]).transpose([0, 2, 3, 1]).reshape([-1, size * size * 3]) 27 | indices = np.arange(len(labels)) 28 | np.random.shuffle(indices) 29 | que.put((data[indices], labels[indices])) 30 | except: 31 | print('Cannot read %s' % filename) 32 | reading_lock.release() 33 | 34 | def cifar_generator(filenames, batch_size, data_dir): 35 | def get_epoch(): 36 | print('Epoch %i' % epoch[0]) 37 | epoch[0]+=1 38 | np.random.shuffle(filenames) 39 | for filename in filenames: 40 | threading.Thread(target=enqueue_pickle, args=(data_dir + filename, pickles_queue, 128)).start() 41 | #lib.read_pickle.read_pickle(data_dir + '/' + filename, pickles_queue) 42 | count = 0 43 | while 1: 44 | #print('queue length: ', pickles_queue.qsize()) 45 | data, labels = pickles_queue.get() 46 | labels = labels//100 47 | count+=1 48 | for i in range(data.shape[0] // batch_size): 49 | yield data[i * batch_size:(i + 1) * batch_size], labels[i * batch_size:(i + 1) * batch_size] 50 | pickles_queue.task_done() 51 | if count == len(filenames)-1: 52 | #assert pickles_queue.qsize()==1 53 | #pool.close() 54 | #pool.join() 55 | break 56 | return get_epoch 57 | 58 | def load(mode, batch_size, data_dir): 59 | if mode=='TRAIN': 60 | return cifar_generator(['train_data_batch_%i'%(i+1) for i in range(100)], batch_size, data_dir) 61 | else: 62 | return cifar_generator(['val_data_batch_%i' % (i + 1) for i in range(10)], batch_size, data_dir) 63 | -------------------------------------------------------------------------------- /tflib/imagenet128.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | import pickle 3 | import numpy as np 4 | import os 5 | import urllib 6 | import gzip 7 | import time 8 | import multiprocessing 9 | import threading 10 | import queue 11 | 12 | epoch=[1] 13 | 14 | def unpickle(file): 15 | fo = open(file, 'rb') 16 | dict = pickle.load(fo) 17 | fo.close() 18 | return dict['data'], dict['labels'] 19 | 20 | def enqueue_pickle(filename, que, reading_lock, size): 21 | reading_lock.acquire() 22 | try: 23 | data, labels = unpickle(filename) 24 | labels = np.array(labels) 25 | data = data.reshape([-1, 3, size, size]).transpose([0, 2, 3, 1]).reshape([-1, size * size * 3]) 26 | indices = np.arange(len(labels)) 27 | np.random.shuffle(indices) 28 | que.put((data[indices], labels[indices])) 29 | except: 30 | print('Cannot read %s' % filename) 31 | reading_lock.release() 32 | 33 | def cifar_generator(filenames, batch_size, data_dir): 34 | pickles_queue =queue.Queue(maxsize = 2) 35 | reading_lock = threading.Lock() 36 | def get_epoch(): 37 | print('Epoch %i' % epoch[0]) 38 | epoch[0]+=1 39 | np.random.shuffle(filenames) 40 | for filename in filenames: 41 | threading.Thread(target=enqueue_pickle, args=(data_dir + filename, pickles_queue,reading_lock, 128)).start() 42 | #lib.read_pickle.read_pickle(data_dir + '/' + filename, pickles_queue) 43 | count = 0 44 | while 1: 45 | #print('queue length: ', pickles_queue.qsize()) 46 | data, labels = 
pickles_queue.get() 47 | labels = labels -1 48 | count+=1 49 | for i in range(data.shape[0] // batch_size): 50 | yield data[i * batch_size:(i + 1) * batch_size], labels[i * batch_size:(i + 1) * batch_size] 51 | pickles_queue.task_done() 52 | if count == len(filenames)-1: 53 | #assert pickles_queue.qsize()==1 54 | #pool.close() 55 | #pool.join() 56 | break 57 | return get_epoch 58 | 59 | def load(mode, batch_size, data_dir): 60 | if mode=='TRAIN': 61 | return cifar_generator(['train_data_batch_%i'%(i+1) for i in range(100)], batch_size, data_dir) 62 | else: 63 | return cifar_generator(['val_data_batch_%i' % (i + 1) for i in range(10)], batch_size, data_dir) 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MIX-GAN 2 | ## Code for the paper [**Lessons Learned from the Training of GANs on Artificial Datasets**](https://arxiv.org/abs/2007.06418) and beyond 3 | Some recent state-of-the-art generative models in ONE notebook. 4 | 5 | This repo implements any method that matches the following regular expression: 6 | 7 | `(MIX-)?(GAN|WGAN|BigGAN|MHingeGAN|AMGAN|StyleGAN|StyleGAN2)(\+ADA|\+CR|\+EMA|\+GP|\+R1|\+SA|\+SN)*` 8 | 9 | # Major dependencies 10 | - For the GPU implementation, `tensorflow>=2` or `tensorflow-gpu==1.14` (some modifications to the IS and FID calculations will be necessary; see my other repos). 11 | - For the TPU implementation, `tensorflow>=2.4` or `tf-nightly` will be necessary. 12 | # Free GPU training on Colab 13 | [![Example In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tsc2017/MIX-GAN/blob/main/MIX-MHingeGAN-CIFAR-10.ipynb) 14 | 15 | This implementation supports TensorFlow's automatic mixed-precision training, which can reduce GPU memory usage and training time dramatically. It is therefore recommended to upgrade to [Colab Pro](https://colab.research.google.com/signup) in order to use GPUs with Tensor Cores. Training `MIX-MHingeGAN` with 10 generators and 10 discriminators takes only 1.5 days on a single Tesla V100. 16 | # Free TPU training on Colab 17 | Coming soon... 18 | # Training on Cloud TPUs 19 | - First [disable Stackdriver Logging](https://console.cloud.google.com/logs/router?) to avoid unnecessary charges. 20 | - [Create Cloud TPUs](https://cloud.google.com/tpu/docs/creating-deleting-tpus); the TPU software version should be at least `2.4.0` or `nightly`. 21 | - Fill in `TPU_NAMES` and `ZONE` in the above notebook (see the illustrative setup sketch below), set up the environment variables `LOG` and `DATA`, and run the notebook. 22 | - [Delete the TPUs](https://cloud.google.com/tpu/docs/creating-deleting-tpus) when training is done.
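
Purely as an illustration of what filling in `TPU_NAMES` and `ZONE` corresponds to, a TF2-style setup sketch is shown below; all names are placeholders, and the notebook's actual cell may differ (the repo's own TPU evaluation code goes through the `tf.compat.v1` session API).

```python
# Illustrative TPU setup sketch; names are placeholders, not values from this repo.
import os
import tensorflow as tf

TPU_NAMES = ['node-1']                        # placeholder TPU node name(s)
ZONE = 'europe-west4-a'                       # placeholder zone of the TPU node(s)
os.environ['LOG'] = 'gs://your-bucket/log'    # placeholder log directory
os.environ['DATA'] = 'gs://your-bucket/data'  # placeholder dataset directory

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=TPU_NAMES[0], zone=ZONE)
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
print('TPU replicas:', strategy.num_replicas_in_sync)
```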
23 | # References 24 | https://github.com/igul222/improved_wgan_training 25 | https://github.com/biuyq/CT-GAN 26 | https://github.com/google/compare_gan 27 | https://github.com/ajbrock/BigGAN-PyTorch 28 | https://github.com/taki0112/BigGAN-Tensorflow 29 | https://github.com/brain-research/self-attention-gan 30 | https://github.com/ilyakava/BigGAN-PyTorch 31 | https://github.com/NVlabs/stylegan2 32 | https://github.com/NVlabs/stylegan2-ada 33 | # Citation 34 | ``` 35 | @article{tang2020lessons, 36 | title={Lessons Learned from the Training of GANs on Artificial Datasets}, 37 | author={Tang, Shichang}, 38 | journal={arXiv preprint arXiv:2007.06418}, 39 | year={2020} 40 | } 41 | ``` 42 | -------------------------------------------------------------------------------- /tflib/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import scipy.misc 3 | import os 4 | import urllib 5 | import gzip 6 | import pickle 7 | height=width=32 8 | def mnist_generator(data, batch_size, n_labelled, limit=None): 9 | images, targets = data 10 | 11 | rng_state = numpy.random.get_state() 12 | numpy.random.shuffle(images) 13 | numpy.random.set_state(rng_state) 14 | numpy.random.shuffle(targets) 15 | if limit is not None: 16 | print("WARNING ONLY FIRST {} MNIST DIGITS".format(limit)) 17 | images = images.astype('float32')[:limit] 18 | targets = targets.astype('int32')[:limit] 19 | if n_labelled is not None: 20 | labelled = numpy.zeros(len(images), dtype='int32') 21 | labelled[:n_labelled] = 1 22 | images=images.reshape(-1,28,28) 23 | reshape_images=numpy.zeros((images.shape[0],height,width)) 24 | for i in range(images.shape[0]): 25 | reshape_images[i]=scipy.misc.imresize(images[i], [height, width]) 26 | images=reshape_images/255. 
#scale changes from [0,1] to [0,255], so need to renormalize 27 | def get_epoch(): 28 | rng_state = numpy.random.get_state() 29 | numpy.random.shuffle(images) 30 | numpy.random.set_state(rng_state) 31 | numpy.random.shuffle(targets) 32 | 33 | if n_labelled is not None: 34 | numpy.random.set_state(rng_state) 35 | numpy.random.shuffle(labelled) 36 | image_batches = images.reshape(-1, height*width) 37 | target_batches = targets.reshape(-1) 38 | 39 | if n_labelled is not None: 40 | labelled_batches = labelled.reshape(-1, batch_size) 41 | 42 | for i in range(len(image_batches)//batch_size): 43 | yield (numpy.copy(image_batches[i*batch_size:(i+1)*batch_size]), numpy.copy(target_batches[i*batch_size:(i+1)*batch_size]), numpy.copy(labelled)) 44 | 45 | else: 46 | 47 | for i in range(len(image_batches)//batch_size): 48 | yield (numpy.copy(image_batches[i*batch_size:(i+1)*batch_size]), numpy.copy(target_batches[i*batch_size:(i+1)*batch_size])) 49 | 50 | return get_epoch 51 | 52 | def load(batch_size, test_batch_size, n_labelled=None): 53 | filepath = os.environ['HOME']+'/data/mydata/mnist.pkl.gz' 54 | url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' 55 | 56 | if not os.path.isfile(filepath): 57 | print("Couldn't find MNIST dataset in /tmp, downloading...") 58 | urllib.urlretrieve(url, filepath) 59 | 60 | with gzip.open(filepath, 'rb') as f: 61 | train_data, dev_data, test_data = pickle.load(f) 62 | 63 | return ( 64 | mnist_generator(train_data, batch_size, n_labelled), 65 | mnist_generator(dev_data, test_batch_size, n_labelled), 66 | mnist_generator(test_data, test_batch_size, n_labelled) 67 | ) -------------------------------------------------------------------------------- /tflib/ut_zap50k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import time 4 | import os 5 | #Loading all images into memory all at once will take a long time at startup while not loading them into memory will slow down the program. 6 | #So we adopt the following alternative: 7 | #If an image is accessed for the first time, it will be loaded into memory. 
8 | #Subsequent visits to it will go to memory directly, which results in a 2~3x acceration 9 | data_dir=os.environ['HOME']+'/data/mydata/ut-zap50k' 10 | train_size=40000 11 | test_size=10025 12 | height=width=32 13 | loaded={'TRAIN':np.zeros(train_size),'TEST':np.zeros(test_size)} 14 | all_images={'TRAIN':np.zeros((train_size, 3, height, width), dtype='int32'),'TEST':np.zeros((test_size, 3, height, width), dtype='int32')} 15 | all_sketches={'TRAIN':np.zeros((train_size, 1, height, width), dtype='int32'),'TEST':np.zeros((test_size, 1, height, width), dtype='int32')} 16 | def make_generator(mode, image_path,sketch_path, name_list,permute_channels=True): 17 | epoch_count = [1] 18 | #def get_epoch(): 19 | images = np.zeros((len(name_list), 3, height, width), dtype='int32') 20 | sketches = np.zeros((len(name_list),1, height, width), dtype='int32') 21 | files = name_list 22 | random_state = np.random.RandomState(epoch_count[0]) 23 | random_state.shuffle(files) 24 | epoch_count[0] += 1 25 | perm=np.arange(3) 26 | for n, i in enumerate(files): 27 | if loaded[mode][i]==0: 28 | image = scipy.misc.imread("{}/{}.jpg".format(image_path, str(i+1).zfill(len(str(train_size)))),mode='RGB') 29 | image=scipy.misc.imresize(image, [height, width]) 30 | sketch = scipy.misc.imread("{}/{}.jpg".format(sketch_path, str(i+1).zfill(len(str(train_size))))) 31 | #print "{}/{}.jpg".format(sketch_path, str(i+1).zfill(len(str(train_size)))) 32 | sketch= scipy.misc.imresize(sketch, [height, width]) 33 | if len(image.shape)!=3: 34 | print "{}/{}.jpg".format(image_path, str(i+1).zfill(len(str(train_size)))) 35 | print image.shape 36 | all_images[mode][i] = image.transpose(2,0,1) 37 | all_sketches[mode][i][0] = sketch 38 | loaded[mode][i]=1 39 | #images[n] =all_images[mode][i] 40 | #sketches[n][0] =all_sketches[mode][i][0] 41 | #if n > 0 and n % batch_size == 0: 42 | #yield (images,sketches) 43 | #return images,sketches 44 | np.random.shuffle(perm) 45 | if mode=='TRAIN' and permute_channels: 46 | all_images[mode][i]=all_images[mode][i][perm] 47 | return all_images[mode][np.ix_(files)], all_sketches[mode][np.ix_(files)] 48 | def load(mode, data_dir,sketch_dir,name_list,permute_channels=True): 49 | if mode=='TRAIN': 50 | return make_generator(mode, data_dir+'/train', sketch_dir+'/train',name_list,permute_channels) 51 | elif mode=='TEST': 52 | return make_generator(mode, data_dir+'/test', sketch_dir+'/test',name_list,permute_channels) 53 | 54 | 55 | if __name__ == '__main__': 56 | train_gen, valid_gen = load(64) 57 | t0 = time.time() 58 | for i, batch in enumerate(train_gen(), start=1): 59 | print "{}\t{}".format(str(time.time() - t0), batch[0][0,0,0,0]) 60 | if i == 1000: 61 | break 62 | t0 = time.time() -------------------------------------------------------------------------------- /tflib/ops/conv1d.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | _default_weightnorm = False 7 | def enable_default_weightnorm(): 8 | global _default_weightnorm 9 | _default_weightnorm = True 10 | 11 | def Conv1D(name, input_dim, output_dim, filter_size, inputs, he_init=True, mask_type=None, stride=1, weightnorm=None, biases=True, gain=1.): 12 | """ 13 | inputs: tensor of shape (batch size, num channels, width) 14 | mask_type: one of None, 'a', 'b' 15 | 16 | returns: tensor of shape (batch size, num channels, width) 17 | """ 18 | with tf.name_scope(name) as scope: 19 | 20 | if mask_type is not None: 21 | mask_type, 
mask_n_channels = mask_type 22 | 23 | mask = np.ones( 24 | (filter_size, input_dim, output_dim), 25 | dtype='float32' 26 | ) 27 | center = filter_size // 2 28 | 29 | # Mask out future locations 30 | # filter shape is (width, input channels, output channels) 31 | mask[center+1:, :, :] = 0. 32 | 33 | # Mask out future channels 34 | for i in xrange(mask_n_channels): 35 | for j in xrange(mask_n_channels): 36 | if (mask_type=='a' and i >= j) or (mask_type=='b' and i > j): 37 | mask[ 38 | center, 39 | i::mask_n_channels, 40 | j::mask_n_channels 41 | ] = 0. 42 | 43 | 44 | def uniform(stdev, size): 45 | return np.random.uniform( 46 | low=-stdev * np.sqrt(3), 47 | high=stdev * np.sqrt(3), 48 | size=size 49 | ).astype('float32') 50 | 51 | fan_in = input_dim * filter_size 52 | fan_out = output_dim * filter_size / stride 53 | 54 | if mask_type is not None: # only approximately correct 55 | fan_in /= 2. 56 | fan_out /= 2. 57 | 58 | if he_init: 59 | filters_stdev = np.sqrt(4./(fan_in+fan_out)) 60 | else: # Normalized init (Glorot & Bengio) 61 | filters_stdev = np.sqrt(2./(fan_in+fan_out)) 62 | 63 | filter_values = uniform( 64 | filters_stdev, 65 | (filter_size, input_dim, output_dim) 66 | ) 67 | # print "WARNING IGNORING GAIN" 68 | filter_values *= gain 69 | 70 | filters = lib.param(name+'.Filters', filter_values) 71 | 72 | if weightnorm==None: 73 | weightnorm = _default_weightnorm 74 | if weightnorm: 75 | norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0,1))) 76 | target_norms = lib.param( 77 | name + '.g', 78 | norm_values 79 | ) 80 | with tf.name_scope('weightnorm') as scope: 81 | norms = tf.sqrt(tf.reduce_sum(tf.square(filters), reduction_indices=[0,1])) 82 | filters = filters * (target_norms / norms) 83 | 84 | if mask_type is not None: 85 | with tf.name_scope('filter_mask'): 86 | filters = filters * mask 87 | 88 | result = tf.nn.conv1d( 89 | value=inputs, 90 | filters=filters, 91 | stride=stride, 92 | padding='SAME', 93 | data_format='NHWC' 94 | ) 95 | 96 | if biases: 97 | _biases = lib.param( 98 | name+'.Biases', 99 | np.zeros([output_dim], dtype='float32') 100 | ) 101 | 102 | # result = result + _biases 103 | 104 | result = tf.expand_dims(result, 2) 105 | result = tf.nn.bias_add(result, _biases, data_format='NHWC') 106 | result = tf.squeeze(result) 107 | 108 | return result 109 | -------------------------------------------------------------------------------- /tflib/inception_score.py: -------------------------------------------------------------------------------- 1 | ''' 2 | From https://github.com/tsc2017/Inception-Score 3 | Code derived from https://github.com/openai/improved-gan/blob/master/inception_score/model.py and https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py 4 | 5 | Usage: 6 | Call get_inception_score(images, splits=10) 7 | Args: 8 | images: A numpy array with values ranging from 0 to 255 and shape in the form [N, 3, HEIGHT, WIDTH] where N, HEIGHT and WIDTH can be arbitrary. A dtype of np.uint8 is recommended to save CPU memory. 9 | splits: The number of splits of the images, default is 10. 10 | Returns: 11 | Mean and standard deviation of the Inception Score across the splits. 
12 | ''' 13 | 14 | import tensorflow.compat.v1 as tf 15 | tf.disable_v2_behavior() 16 | import tensorflow_gan as tfgan 17 | import os 18 | import functools 19 | import numpy as np 20 | import time 21 | from tensorflow.python.ops import array_ops 22 | # pip install tensorflow-gan 23 | import tensorflow_gan as tfgan 24 | 25 | config=tf.ConfigProto(log_device_placement=True,allow_soft_placement=True) 26 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 27 | from tensorflow.core.protobuf import rewriter_config_pb2 28 | config.graph_options.rewrite_options.auto_mixed_precision = rewriter_config_pb2.RewriterConfig.ON 29 | session=tf.Session(config=config) 30 | 31 | # A smaller BATCH_SIZE reduces GPU memory usage, but at the cost of a slight slowdown 32 | BATCH_SIZE = 64 33 | INCEPTION_TFHUB = 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1' 34 | INCEPTION_OUTPUT = 'logits' 35 | 36 | # Run images through Inception. 37 | inception_images = tf.compat.v1.placeholder(tf.float32, [None, 3, None, None], name = 'inception_images') 38 | def inception_logits(images = inception_images, num_splits = 1): 39 | images = tf.transpose(images, [0, 2, 3, 1]) 40 | size = 299 41 | images = tf.compat.v1.image.resize_bilinear(images, [size, size]) 42 | generated_images_list = array_ops.split(images, num_or_size_splits = num_splits) 43 | logits = tf.map_fn( 44 | fn = tfgan.eval.classifier_fn_from_tfhub(INCEPTION_TFHUB, INCEPTION_OUTPUT, True), 45 | elems = array_ops.stack(generated_images_list), 46 | parallel_iterations = 8, 47 | back_prop = False, 48 | swap_memory = True, 49 | name = 'RunClassifier') 50 | logits = array_ops.concat(array_ops.unstack(logits), 0) 51 | return logits 52 | 53 | logits=inception_logits() 54 | 55 | def get_inception_probs(inps): 56 | #session=tf.get_default_session() 57 | n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE)) 58 | preds = np.zeros([inps.shape[0], 1000], dtype = np.float32) 59 | for i in range(n_batches): 60 | inp = inps[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] / 255. * 2 - 1 61 | preds[i * BATCH_SIZE : i * BATCH_SIZE + min(BATCH_SIZE, inp.shape[0])] = session.run(logits, {inception_images: inp})[:, :1000] 62 | preds = np.exp(preds) / np.sum(np.exp(preds), 1, keepdims=True) 63 | return preds 64 | 65 | def preds2score(preds, splits=10): 66 | scores = [] 67 | for i in range(splits): 68 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 69 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 70 | kl = np.mean(np.sum(kl, 1)) 71 | scores.append(np.exp(kl)) 72 | return np.mean(scores), np.std(scores) 73 | 74 | def get_inception_score(images, splits=10): 75 | assert(type(images) == np.ndarray) 76 | assert(len(images.shape) == 4) 77 | assert(images.shape[1] == 3) 78 | assert(np.min(images[0]) >= 0 and np.max(images[0]) > 10), 'Image values should be in the range [0, 255]' 79 | print('Calculating Inception Score with %i images in %i splits' % (images.shape[0], splits)) 80 | start_time=time.time() 81 | preds = get_inception_probs(images) 82 | mean, std = preds2score(preds, splits) 83 | print('Inception Score calculation time: %f s' % (time.time() - start_time)) 84 | return mean, std # Reference values: 11.38 for 50000 CIFAR-10 training set images, or mean=11.31, std=0.10 if in 10 splits. 
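
A minimal usage sketch for the scorer above (an editorial example, not part of the repo): it assumes `tflib` is on the Python path and `tensorflow-gan` is installed, and it feeds random `uint8` noise purely to show the expected calling convention; a real evaluation would pass generated samples.

```python
# Minimal usage sketch for tflib/inception_score.py; layout and value range per its docstring.
import numpy as np
from tflib import inception_score  # builds the Inception graph and a session on import

# Placeholder batch: 500 random uint8 "images" in NCHW layout with values in [0, 255].
fake_images = np.random.randint(0, 256, size=(500, 3, 32, 32), dtype=np.uint8)

mean, std = inception_score.get_inception_score(fake_images, splits=10)
print('Inception Score: %.2f +/- %.2f' % (mean, std))
```

Per the docstring, images are NCHW with values in [0, 255]; `splits=10` matches the usual CIFAR-10 reporting convention mentioned in the module's reference comment.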
-------------------------------------------------------------------------------- /tflib/fid.py: -------------------------------------------------------------------------------- 1 | ''' 2 | From https://github.com/tsc2017/Frechet-Inception-Distance 3 | Code derived from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py 4 | 5 | Usage: 6 | Call get_fid(images1, images2) 7 | Args: 8 | images1, images2: Numpy arrays with values ranging from 0 to 255 and shape in the form [N, 3, HEIGHT, WIDTH] where N, HEIGHT and WIDTH can be arbitrary. 9 | dtype of the images is recommended to be np.uint8 to save CPU memory. 10 | Returns: 11 | Frechet Inception Distance between the two image distributions. 12 | ''' 13 | 14 | import tensorflow.compat.v1 as tf 15 | tf.disable_v2_behavior() 16 | import os 17 | import functools 18 | import numpy as np 19 | import time 20 | from tensorflow.python.ops import array_ops 21 | # pip install tensorflow-gan 22 | import tensorflow_gan as tfgan 23 | 24 | config=tf.ConfigProto(log_device_placement=True,allow_soft_placement=True) 25 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 26 | from tensorflow.core.protobuf import rewriter_config_pb2 27 | config.graph_options.rewrite_options.auto_mixed_precision = rewriter_config_pb2.RewriterConfig.ON 28 | session=tf.Session(config=config) 29 | 30 | # A smaller BATCH_SIZE reduces GPU memory usage, but at the cost of a slight slowdown 31 | BATCH_SIZE = 64 32 | 33 | # Run images through Inception. 34 | inception_images = tf.compat.v1.placeholder(tf.float32, [None, 3, None, None], name = 'inception_images') 35 | activations1 = tf.compat.v1.placeholder(tf.float32, [None, None], name = 'activations1') 36 | activations2 = tf.compat.v1.placeholder(tf.float32, [None, None], name = 'activations2') 37 | fcd = tfgan.eval.frechet_classifier_distance_from_activations(activations1, activations2) 38 | 39 | INCEPTION_TFHUB = 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1' 40 | INCEPTION_FINAL_POOL = 'pool_3' 41 | 42 | def inception_activations(images = inception_images, num_splits = 1): 43 | images = tf.transpose(images, [0, 2, 3, 1]) 44 | size = 299 45 | images = tf.compat.v1.image.resize_bilinear(images, [size, size]) 46 | generated_images_list = array_ops.split(images, num_or_size_splits = num_splits) 47 | activations = tf.map_fn( 48 | fn = tfgan.eval.classifier_fn_from_tfhub(INCEPTION_TFHUB, INCEPTION_FINAL_POOL, True), 49 | elems = array_ops.stack(generated_images_list), 50 | parallel_iterations = 1, 51 | back_prop = False, 52 | swap_memory = True, 53 | name = 'RunClassifier') 54 | activations = array_ops.concat(array_ops.unstack(activations), 0) 55 | return activations 56 | 57 | activations =inception_activations() 58 | 59 | def get_inception_activations(inps): 60 | n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE)) 61 | act = np.zeros([inps.shape[0], 2048], dtype = np.float32) 62 | for i in range(n_batches): 63 | inp = inps[i * BATCH_SIZE : (i + 1) * BATCH_SIZE] / 255. 
* 2 - 1 64 | act[i * BATCH_SIZE : i * BATCH_SIZE + min(BATCH_SIZE, inp.shape[0])] = session.run(activations, feed_dict = {inception_images: inp}) 65 | return act 66 | 67 | def activations2distance(act1, act2): 68 | return session.run(fcd, feed_dict = {activations1: act1, activations2: act2}) 69 | 70 | def get_fid(images1, images2): 71 | assert(type(images1) == np.ndarray) 72 | assert(len(images1.shape) == 4) 73 | assert(images1.shape[1] == 3) 74 | assert(np.min(images1[0]) >= 0 and np.max(images1[0]) > 10), 'Image values should be in the range [0, 255]' 75 | assert(type(images2) == np.ndarray) 76 | assert(len(images2.shape) == 4) 77 | assert(images2.shape[1] == 3) 78 | assert(np.min(images2[0]) >= 0 and np.max(images2[0]) > 10), 'Image values should be in the range [0, 255]' 79 | assert(images1.shape == images2.shape), 'The two numpy arrays must have the same shape' 80 | print('Calculating FID with %i images from each distribution' % (images1.shape[0])) 81 | start_time = time.time() 82 | act1 = get_inception_activations(images1) 83 | act2 = get_inception_activations(images2) 84 | fid = activations2distance(act1, act2) 85 | print('FID calculation time: %f s' % (time.time() - start_time)) 86 | return fid -------------------------------------------------------------------------------- /tflib/inception_score_tpu.py.bak: -------------------------------------------------------------------------------- 1 | ''' 2 | From https://github.com/tsc2017/Inception-Score 3 | Code derived from https://github.com/openai/improved-gan/blob/master/inception_score/model.py and https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py 4 | 5 | Usage: 6 | Call get_inception_score(images, splits=10) 7 | Args: 8 | images: A numpy array with values ranging from 0 to 255 and shape in the form [N, 3, HEIGHT, WIDTH] where N, HEIGHT and WIDTH can be arbitrary. A dtype of np.uint8 is recommended to save CPU memory. 9 | splits: The number of splits of the images, default is 10. 10 | Returns: 11 | Mean and standard deviation of the Inception Score across the splits. 12 | ''' 13 | 14 | import tensorflow as tf 15 | import os 16 | import functools 17 | import numpy as np 18 | import time 19 | from tensorflow.python.ops import array_ops 20 | if float('.'.join(tf.__version__.split('.')[:2])) < 1.15: 21 | tfgan = tf.contrib.gan 22 | else: 23 | import tensorflow_gan as tfgan 24 | session=tf.compat.v1.InteractiveSession() 25 | # A smaller BATCH_SIZE reduces GPU memory usage, but at the cost of a slight slowdown 26 | BATCH_SIZE = 8 27 | INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05_v4.tar.gz' 28 | INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score_tpu.pb' 29 | FIRST_RUN=[1] 30 | # Run images through Inception. 
31 | inception_images =[None] 32 | def inception_logits(images = inception_images[0], num_splits = 1): 33 | images = inception_images[0] 34 | images = tf.transpose(images, [0, 2, 3, 1]) 35 | size = 299 36 | images = tf.compat.v1.image.resize_bilinear(images, [size, size]) 37 | generated_images_list = array_ops.split(images, num_or_size_splits = num_splits) 38 | logits = tf.map_fn( 39 | fn = functools.partial( 40 | tfgan.eval.run_inception, 41 | default_graph_def_fn = functools.partial( 42 | tfgan.eval.get_graph_def_from_url_tarball, 43 | INCEPTION_URL, 44 | INCEPTION_FROZEN_GRAPH, 45 | os.path.basename(INCEPTION_URL)), 46 | output_tensor = 'logits:0'), 47 | elems = array_ops.stack(generated_images_list), 48 | parallel_iterations = 8, 49 | back_prop = False, 50 | swap_memory = True, 51 | name = 'RunClassifier') 52 | logits = array_ops.concat(array_ops.unstack(logits), 0) 53 | return logits 54 | 55 | logits=[None] 56 | def get_inception_probs(inps, session=None, strategy=None): 57 | if FIRST_RUN[0]: 58 | print('Running Inception for the first time, compiling...') 59 | with session.graph.as_default(): 60 | inception_images[0]=tf.compat.v1.placeholder(tf.float32, [None, 3, None, None], name = 'inception_images') 61 | logits[0]=strategy.experimental_run(inception_logits).values[0] 62 | FIRST_RUN[0]=0 63 | n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE)) 64 | preds = np.zeros([inps.shape[0], 1000], dtype = np.float32) 65 | for i in range(n_batches): 66 | inp = inps[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] / 255. * 2 - 1 67 | preds[i * BATCH_SIZE : i * BATCH_SIZE + min(BATCH_SIZE, inp.shape[0])] = session.run(logits[0],{inception_images[0]: inp})[:, :1000] 68 | preds = np.exp(preds) / np.sum(np.exp(preds), 1, keepdims=True) 69 | return preds 70 | 71 | def preds2score(preds, splits=10): 72 | scores = [] 73 | for i in range(splits): 74 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 75 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 76 | kl = np.mean(np.sum(kl, 1)) 77 | scores.append(np.exp(kl)) 78 | return np.mean(scores), np.std(scores) 79 | 80 | def get_inception_score(images, splits=10, session=None, strategy=None): 81 | assert(type(images) == np.ndarray) 82 | assert(len(images.shape) == 4) 83 | assert(images.shape[1] == 3) 84 | assert(np.min(images[0]) >= 0 and np.max(images[0]) > 10), 'Image values should be in the range [0, 255]' 85 | print('Calculating Inception Score with %i images in %i splits' % (images.shape[0], splits)) 86 | start_time=time.time() 87 | preds = get_inception_probs(images, session, strategy) 88 | mean, std = preds2score(preds, splits) 89 | print('Inception Score calculation time: %f s' % (time.time() - start_time)) 90 | return mean, std # Reference values: 11.38 for 50000 CIFAR-10 training set images, or mean=11.31, std=0.10 if in 10 splits. 
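
Both TPU variants shift session and strategy management to the caller: `get_inception_score` takes explicit `session` and `strategy` arguments. A hypothetical call site is sketched below; the TPU name and sample array are placeholders, and the `tf.compat.v1` strategy construction is an assumption about the surrounding notebook rather than code from this repo.

```python
# Hypothetical call site for the TPU scorer (graph-mode, tf.compat.v1 style).
import numpy as np
import tensorflow.compat.v1 as tf
from tflib import inception_score_tpu  # assumes tflib is on PYTHONPATH

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu-node')  # placeholder name
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
session = tf.Session(resolver.master())

# Placeholder samples: NCHW uint8 in [0, 255].
samples = np.random.randint(0, 256, size=(10000, 3, 32, 32), dtype=np.uint8)
mean, std = inception_score_tpu.get_inception_score(
    samples, splits=10, session=session, strategy=strategy)
print('Inception Score: %.2f +/- %.2f' % (mean, std))
```

With the current `inception_score_tpu.py` (the next file), the sample count should be a multiple of its `BATCH_SIZE` of 1000, since its internal dataset is batched with `drop_remainder=True`.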
-------------------------------------------------------------------------------- /tflib/inception_score_tpu.py: -------------------------------------------------------------------------------- 1 | ''' 2 | From https://github.com/tsc2017/Inception-Score 3 | Code derived from https://github.com/openai/improved-gan/blob/master/inception_score/model.py and https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py 4 | 5 | Usage: 6 | Call get_inception_score(images, splits=10) 7 | Args: 8 | images: A numpy array with values ranging from 0 to 255 and shape in the form [N, 3, HEIGHT, WIDTH] where N, HEIGHT and WIDTH can be arbitrary. A dtype of np.uint8 is recommended to save CPU memory. 9 | splits: The number of splits of the images, default is 10. 10 | Returns: 11 | Mean and standard deviation of the Inception Score across the splits. 12 | ''' 13 | 14 | import tensorflow.compat.v1 as tf 15 | tf.disable_v2_behavior() 16 | import os 17 | import functools 18 | import numpy as np 19 | import time 20 | from tensorflow.python.ops import array_ops 21 | # pip install tensorflow-gan 22 | import tensorflow_gan as tfgan 23 | session=tf.compat.v1.InteractiveSession() 24 | # A smaller BATCH_SIZE reduces TPU memory usage, but at the cost of a slight slowdown 25 | BATCH_SIZE = 1000 26 | INCEPTION_TFHUB = 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1' 27 | INCEPTION_OUTPUT = 'logits' 28 | FIRST_RUN=[True] 29 | # Run images through Inception. 30 | inception_images =[None] 31 | image_iterator_init=[None] 32 | inception_size = 299 33 | input_size=[32] 34 | def inception_logits(images): 35 | images = tf.transpose(images, [0, 2, 3, 1]) 36 | images = tf.compat.v1.image.resize_bilinear(images, [inception_size, inception_size]) 37 | generated_images_list = array_ops.split(images, num_or_size_splits = 1) 38 | logits = tf.map_fn( 39 | fn = tfgan.eval.classifier_fn_from_tfhub(INCEPTION_TFHUB, INCEPTION_OUTPUT, True), 40 | elems = array_ops.stack(generated_images_list), 41 | parallel_iterations = 1, 42 | back_prop = False, 43 | swap_memory = True, 44 | name = 'RunClassifier') 45 | logits = array_ops.concat(array_ops.unstack(logits), 0) 46 | return logits 47 | 48 | logits=[None] 49 | def get_inception_probs(inps, session=None, strategy=None): 50 | if FIRST_RUN[0]: 51 | print('Running Inception for the first time, compiling...') 52 | with session.graph.as_default(): 53 | inception_images[0]=tf.compat.v1.placeholder(tf.float32, [BATCH_SIZE, 3, input_size[0], input_size[0]], name = 'inception_images') 54 | image_dataset = tf.data.Dataset.from_tensor_slices((inception_images[0])).batch(BATCH_SIZE, drop_remainder=True) 55 | image_iterator = strategy.make_dataset_iterator(image_dataset) 56 | image_iterator_init[0] = image_iterator.initialize() 57 | logits[0]=tf.concat(strategy.experimental_run(inception_logits, image_iterator).values,0) 58 | FIRST_RUN[0]=False 59 | n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE)) 60 | preds = np.zeros([inps.shape[0], 1000], dtype = np.float32) 61 | for i in range(n_batches): 62 | inp = inps[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] / 255. 
* 2 - 1 63 | session.run(image_iterator_init[0],{inception_images[0]: inp}) 64 | preds[i * BATCH_SIZE : i * BATCH_SIZE + min(BATCH_SIZE, inp.shape[0])] = session.run(logits[0])[:, :1000] 65 | preds = np.exp(preds) / np.sum(np.exp(preds), 1, keepdims=True) 66 | return preds 67 | 68 | def preds2score(preds, splits=10): 69 | scores = [] 70 | for i in range(splits): 71 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 72 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 73 | kl = np.mean(np.sum(kl, 1)) 74 | scores.append(np.exp(kl)) 75 | return np.mean(scores), np.std(scores) 76 | 77 | def get_inception_score(images, splits=10, session=None, strategy=None): 78 | assert(type(images) == np.ndarray) 79 | assert(len(images.shape) == 4) 80 | assert(images.shape[1] == 3) 81 | assert(np.min(images[0]) >= 0 and np.max(images[0]) > 10), 'Image values should be in the range [0, 255]' 82 | input_size[0]=images.shape[3] 83 | print('Calculating Inception Score with %i images in %i splits' % (images.shape[0], splits)) 84 | start_time=time.time() 85 | preds = get_inception_probs(images, session, strategy) 86 | mean, std = preds2score(preds, splits) 87 | print('Inception Score calculation time: %f s' % (time.time() - start_time)) 88 | return mean, std # Reference values: 11.38 for 50000 CIFAR-10 training set images, or mean=11.31, std=0.10 if in 10 splits. -------------------------------------------------------------------------------- /tflib/celeba.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import time 4 | import os 5 | #Loading all images into memory all at once will take a long time at startup while not loading them into memory will slow down the program. 6 | #So we adopt the following alternative: 7 | #If an image is accessed for the first time, it will be loaded into memory. 
8 | #Subsequent visits to it will go to memory directly, which results in a 2~3x acceration 9 | data_dir=os.environ['HOME']+'/data/mydata/CelebA_aligned/CelebA_crop' 10 | train_size=200000 # real training set size, should not be modified 11 | test_size=2599 # real test set size, should not be modified 12 | height=width=32 # can be modified 13 | loaded=np.zeros(train_size+test_size) 14 | all_images=[] 15 | all_labels=[] 16 | indices=[] 17 | attr_file=os.environ['HOME']+'/data/mydata/CelebA_aligned/list_attr_celeba.txt' 18 | with open(attr_file, 'r') as csvfile: 19 | line = csvfile.readline().strip() 20 | num_attr_records=int(line) 21 | line = csvfile.readline().strip() 22 | attr_names=line.split(' ') 23 | attrs=np.loadtxt(attr_file,dtype=int,skiprows=2,usecols=np.arange(40)+1) 24 | 25 | def set_height(HEIGHT): 26 | global height,width,all_images 27 | height=width=HEIGHT 28 | all_images=np.zeros((train_size+test_size, 3, height, width), dtype='uint8') 29 | 30 | def set_chosen_attr(chosen_attr_name): 31 | assert(len(chosen_attr_name)>0) 32 | global all_labels,indices,attr_names 33 | chosen_attr_idx=np.zeros(len(chosen_attr_name),dtype=int) 34 | for i in range(len(chosen_attr_name)): 35 | chosen_attr_idx[i]=attr_names.index(chosen_attr_name[i]) 36 | all_labels=attrs[:,chosen_attr_idx] 37 | all_labels[all_labels==-1]=0 38 | print('Balancing training set attributes...') 39 | balanced_indices={} 40 | num_attrs=len(chosen_attr_name) 41 | for i in range(2**num_attrs): 42 | balanced_indices[bin(i)[2:].zfill(num_attrs)]=[] 43 | print( balanced_indices) 44 | for i in range(len(all_labels)): 45 | balanced_indices[''.join(all_labels[i].astype(str)).zfill(num_attrs)].append(i) 46 | min_length=len(all_labels) 47 | for i in range(2**num_attrs): 48 | min_length=min(min_length,len(balanced_indices[bin(i)[2:].zfill(num_attrs)])) 49 | for j in range(min_length): 50 | for i in range(2**num_attrs): 51 | indices+=[balanced_indices[bin(i)[2:].zfill(num_attrs)][j]] 52 | 53 | indices=np.array(indices) 54 | 55 | #np.random.shuffle(indices) 56 | print('Training set size=%i'%len(indices)) 57 | return len(indices) 58 | 59 | def set_chosen_attr(chosen_attr_idx): 60 | print('chosen attrs:', chosen_attr_idx) 61 | global all_labels,indices 62 | all_labels=attrs#[:,chosen_attr_idx] 63 | all_labels[all_labels==-1]=0 64 | #all_labels[:,list(set(np.arange(40))-set(chosen_attr_idx))]=0 65 | #indices=np.array(indices) 66 | indices=np.array(range(train_size)) 67 | #np.random.shuffle(indices) 68 | print('Training set size=%i'%len(indices)) 69 | return len(indices) 70 | 71 | def make_generator(mode, image_path, name_list,permute_channels=False): 72 | global height,width,all_images 73 | epoch_count = [1] 74 | #def get_epoch(): 75 | images = np.zeros((len(name_list), 3, height, width), dtype='uint8') 76 | files = indices[name_list] #if mode=='TRAIN' else name_list 77 | #random_state = np.random.RandomState(epoch_count[0]) 78 | #random_state.shuffle(files) 79 | epoch_count[0] += 1 80 | for n, i in enumerate(files): 81 | if loaded[i]==0: 82 | image = scipy.misc.imread("{}/{}.jpg".format(image_path, str(i+1).zfill(len(str(train_size+test_size)))),mode='RGB') 83 | image=scipy.misc.imresize(image, [height, width]) 84 | all_images[i] = image.transpose(2,0,1) 85 | loaded[i]=1 86 | return all_images[np.ix_(files)], all_labels[np.ix_(files)] 87 | def load(mode, data_dir,name_list,permute_channels=False): 88 | return make_generator(mode, data_dir, name_list,permute_channels) 89 | 90 | 91 | def get_prior(): 92 | return 
all_labels.sum(axis=0)/np.float32(all_labels.shape[0]) 93 | 94 | if __name__ == '__main__': 95 | train_gen, valid_gen = load(64) 96 | t0 = time.time() 97 | for i, batch in enumerate(train_gen(), start=1): 98 | print ("{}\t{}".format(str(time.time() - t0), batch[0][0,0,0,0])) 99 | if i == 1000: 100 | break 101 | t0 = time.time() -------------------------------------------------------------------------------- /tflib/ops/batchnorm.py: -------------------------------------------------------------------------------- 1 | import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | dtype='float32' 6 | def Batchnorm(name, axes, inputs, is_training=None, stats_iter=None, update_moving_stats=True, fused=True): 7 | input_type=inputs.dtype 8 | inputs=tf.cast(inputs,'float32') 9 | if ((axes == [0,1,2]) or (axes == [0,1])) and fused==True: 10 | if axes==[0,1]: 11 | inputs = tf.expand_dims(inputs, 2) 12 | # Old (working but pretty slow) implementation: 13 | ########## 14 | 15 | # inputs = tf.transpose(inputs, [0,2,3,1]) 16 | 17 | # mean, var = tf.nn.moments(inputs, [0,1,2], keep_dims=False) 18 | # offset = lib.param(name+'.offset', np.zeros(mean.get_shape()[-1], dtype='float32')) 19 | # scale = lib.param(name+'.scale', np.ones(var.get_shape()[-1], dtype='float32')) 20 | # result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-4) 21 | 22 | # return tf.transpose(result, [0,3,1,2]) 23 | 24 | # New (super fast but untested) implementation: 25 | offset = lib.param(name+'.offset', np.zeros(inputs.get_shape()[3], dtype=dtype)) 26 | scale = lib.param(name+'.scale', np.ones(inputs.get_shape()[3], dtype=dtype)) 27 | 28 | moving_mean = lib.param(name+'.moving_mean', np.zeros(inputs.get_shape()[3], dtype=dtype), trainable=False) 29 | moving_variance = lib.param(name+'.moving_variance', np.ones(inputs.get_shape()[3], dtype=dtype), trainable=False) 30 | 31 | def _fused_batch_norm_training(): 32 | return tf.compat.v1.nn.fused_batch_norm(inputs, scale, offset, epsilon=1e-5) 33 | def _fused_batch_norm_inference(): 34 | # Version which blends in the current item's statistics 35 | batch_size = tf.cast(tf.shape(inputs)[0], dtype) 36 | mean, var = tf.nn.moments(inputs, [1,2], keep_dims=True) 37 | mean = ((1./batch_size)*mean) + (((batch_size-1.)/batch_size)*moving_mean)[None,None,None,:] 38 | var = ((1./batch_size)*var) + (((batch_size-1.)/batch_size)*moving_variance)[None,None,None,:] 39 | return tf.nn.batch_normalization(inputs, mean, var, offset[None,None,None,:], scale[None,None,None,:], 1e-5), mean, var 40 | 41 | # Standard version 42 | # return tf.nn.fused_batch_norm( 43 | # inputs, 44 | # scale, 45 | # offset, 46 | # epsilon=1e-2, 47 | # mean=moving_mean, 48 | # variance=moving_variance, 49 | # is_training=False, 50 | # data_format='NCHW' 51 | # ) 52 | 53 | if is_training is None: 54 | outputs, batch_mean, batch_var = _fused_batch_norm_training() 55 | 56 | else: 57 | outputs, batch_mean, batch_var = tf.cond(is_training,_fused_batch_norm_training, _fused_batch_norm_inference) 58 | if update_moving_stats: 59 | no_updates = lambda: outputs 60 | def _force_updates(): 61 | """Internal function forces updates moving_vars if is_training.""" 62 | float_stats_iter = tf.cast(stats_iter, dtype) 63 | 64 | update_moving_mean = tf.assign(moving_mean, ((float_stats_iter/(float_stats_iter+1))*moving_mean) + ((1/(float_stats_iter+1))*batch_mean)) 65 | update_moving_variance = tf.assign(moving_variance, ((float_stats_iter/(float_stats_iter+1))*moving_variance) + ((1/(float_stats_iter+1))*batch_var)) 
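# Note: the two assignments above implement a cumulative (running) average rather than
# an exponential moving average. With t = stats_iter (assumed here to count the number
# of updates applied so far, starting at 0), the rule m <- (t/(t+1))*m + (1/(t+1))*batch
# keeps moving_mean / moving_variance equal to the plain arithmetic mean of every batch
# statistic accumulated so far, so all batches are weighted equally.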
66 | 67 | with tf.control_dependencies([update_moving_mean, update_moving_variance]): 68 | return tf.identity(outputs) 69 | outputs = tf.cond(is_training, _force_updates, no_updates) 70 | 71 | if axes == [0,1]: 72 | return outputs[:,:,:,0] # collapse last dim 73 | else: 74 | return tf.cast(outputs,input_type) 75 | else: 76 | # raise Exception('old BN') 77 | # TODO we can probably use nn.fused_batch_norm here too for speedup 78 | mean, var = tf.nn.moments(inputs, axes, keep_dims=True) 79 | shape = mean.get_shape().as_list() 80 | if 0 not in axes: 81 | print("WARNING ({}): didn't find 0 in axes, but not using separate BN params for each item in batch".format(name)) 82 | shape[0] = 1 83 | offset = lib.param(name+'.offset', np.zeros(shape, dtype=dtype)) 84 | scale = lib.param(name+'.scale', np.ones(shape, dtype=dtype)) 85 | result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5) 86 | 87 | return result 88 | -------------------------------------------------------------------------------- /tflib/fid_tpu.py.bak: -------------------------------------------------------------------------------- 1 | ''' 2 | From https://github.com/tsc2017/Frechet-Inception-Distance 3 | Code derived from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py 4 | 5 | Usage: 6 | Call get_fid(images1, images2, session=YOUR_SESSION, strategy=YOUR_TPUSTRATEGY) 7 | Args: 8 | images1, images2: Numpy arrays with values ranging from 0 to 255 and shape in the form [N, 3, HEIGHT, WIDTH] where N, HEIGHT and WIDTH can be arbitrary. 9 | dtype of the images is recommended to be np.uint8 to save CPU memory. 10 | Returns: 11 | Frechet Inception Distance between the two image distributions. 12 | ''' 13 | 14 | import tensorflow as tf 15 | import os 16 | import functools 17 | import numpy as np 18 | import time 19 | from tensorflow.python.ops import array_ops 20 | if float('.'.join(tf.__version__.split('.')[:2])) < 1.15: 21 | tfgan = tf.contrib.gan 22 | else: 23 | import tensorflow_gan as tfgan 24 | FIRST_RUN=[1] 25 | session=tf.compat.v1.InteractiveSession() 26 | # A smaller BATCH_SIZE reduces GPU memory usage, but at the cost of a slight slowdown 27 | BATCH_SIZE = 8 28 | INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05_v4.tar.gz' 29 | INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score_tpu.pb' 30 | # Run images through Inception. 
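# The single-element [None] lists below act as mutable module-level slots: they are
# filled with the real placeholders and ops on the first call (guarded by FIRST_RUN),
# so the Inception graph is built lazily once and then reused on subsequent calls.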
31 | inception_images = [None] 32 | activations1 = [None] 33 | activations2 = [None] 34 | fcd = [None] 35 | 36 | def inception_activations(num_splits = 1): 37 | images = inception_images[0] 38 | images = tf.transpose(images, [0, 2, 3, 1]) 39 | size = 299 40 | images = tf.compat.v1.image.resize_bilinear(images, [size, size]) 41 | generated_images_list = array_ops.split(images, num_or_size_splits = num_splits) 42 | activations = tf.map_fn( 43 | fn = functools.partial( 44 | tfgan.eval.run_inception, 45 | default_graph_def_fn = functools.partial( 46 | tfgan.eval.get_graph_def_from_url_tarball, 47 | INCEPTION_URL, 48 | INCEPTION_FROZEN_GRAPH, 49 | os.path.basename(INCEPTION_URL)), 50 | output_tensor = 'pool_3:0'), 51 | elems = array_ops.stack(generated_images_list), 52 | parallel_iterations = 1, 53 | back_prop = False, 54 | swap_memory = True, 55 | name = 'RunClassifier') 56 | activations = array_ops.concat(array_ops.unstack(activations), 0) 57 | return activations 58 | 59 | activations =[None] 60 | 61 | def get_inception_activations(inps, session=None, strategy=None): 62 | if FIRST_RUN[0]: 63 | with session.graph.as_default(): 64 | inception_images[0] = tf.compat.v1.placeholder(tf.float32, [None, 3, None, None], name = 'inception_images') 65 | activations1[0] = tf.compat.v1.placeholder(tf.float32, [None, None], name = 'activations1') 66 | activations2[0] = tf.compat.v1.placeholder(tf.float32, [None, None], name = 'activations2') 67 | fcd[0] = tfgan.eval.frechet_classifier_distance_from_activations(activations1[0], activations2[0]) 68 | print('Running Inception for the first time, compiling...') 69 | activations[0]=strategy.experimental_run(inception_activations).values[0] 70 | FIRST_RUN[0]=0 71 | n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE)) 72 | act = np.zeros([inps.shape[0], 2048], dtype = np.float32) 73 | for i in range(n_batches): 74 | inp = inps[i * BATCH_SIZE : (i + 1) * BATCH_SIZE] / 255. 
* 2 - 1 75 | act[i * BATCH_SIZE : i * BATCH_SIZE + min(BATCH_SIZE, inp.shape[0])] = session.run(activations[0], feed_dict = {inception_images[0]: inp}) 76 | return act 77 | 78 | def activations2distance(act1, act2, session=None): 79 | return session.run(fcd[0], feed_dict = {activations1[0]: act1, activations2[0]: act2}) 80 | 81 | def get_fid(images1, images2, session=None, strategy=None): 82 | assert(type(images1) == np.ndarray) 83 | assert(len(images1.shape) == 4) 84 | assert(images1.shape[1] == 3) 85 | assert(np.min(images1[0]) >= 0 and np.max(images1[0]) > 10), 'Image values should be in the range [0, 255]' 86 | assert(type(images2) == np.ndarray) 87 | assert(len(images2.shape) == 4) 88 | assert(images2.shape[1] == 3) 89 | assert(np.min(images2[0]) >= 0 and np.max(images2[0]) > 10), 'Image values should be in the range [0, 255]' 90 | assert(images1.shape == images2.shape), 'The two numpy arrays must have the same shape' 91 | print('Calculating FID with %i images from each distribution' % (images1.shape[0])) 92 | start_time = time.time() 93 | act1 = get_inception_activations(images1, session, strategy) 94 | act2 = get_inception_activations(images2, session, strategy) 95 | fid = activations2distance(act1, act2, session) 96 | print('FID calculation time: %f s' % (time.time() - start_time)) 97 | return fid -------------------------------------------------------------------------------- /tflib/fid_tpu.py: -------------------------------------------------------------------------------- 1 | ''' 2 | From https://github.com/tsc2017/Frechet-Inception-Distance 3 | Code derived from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py 4 | 5 | Usage: 6 | Call get_fid(images1, images2, session=YOUR_SESSION, strategy=YOUR_TPUSTRATEGY) 7 | Args: 8 | images1, images2: Numpy arrays with values ranging from 0 to 255 and shape in the form [N, 3, HEIGHT, WIDTH] where N, HEIGHT and WIDTH can be arbitrary. 9 | dtype of the images is recommended to be np.uint8 to save CPU memory. 10 | Returns: 11 | Frechet Inception Distance between the two image distributions. 12 | ''' 13 | 14 | import tensorflow.compat.v1 as tf 15 | tf.disable_v2_behavior() 16 | import os 17 | import functools 18 | import numpy as np 19 | import time 20 | from tensorflow.python.ops import array_ops 21 | if float('.'.join(tf.__version__.split('.')[:2])) < 1.15: 22 | tfgan = tf.contrib.gan 23 | else: 24 | import tensorflow_gan as tfgan 25 | FIRST_RUN=[1] 26 | session=tf.compat.v1.InteractiveSession() 27 | # A smaller BATCH_SIZE reduces TPU memory usage, but at the cost of a slight slowdown 28 | BATCH_SIZE = 1000 29 | INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05_v4.tar.gz' 30 | INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score_tpu.pb' 31 | # Run images through Inception. 
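# A minimal usage sketch for get_fid() as described in the module docstring above.
# It is illustrative only: `tpu_address` is a placeholder, the random uint8 arrays
# merely stand in for real and generated batches (NCHW, values in [0, 255]), and the
# session/strategy construction shown is just one plausible TF 1.x pattern -- the
# exact API location varies across TF releases and TPU setups.
def _example_get_fid_usage(tpu_address):
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
    sess = tf.Session(resolver.master())
    # Image counts that are a multiple of BATCH_SIZE avoid a dropped partial final batch.
    real = np.random.randint(0, 256, size=(5000, 3, 32, 32), dtype=np.uint8)
    fake = np.random.randint(0, 256, size=(5000, 3, 32, 32), dtype=np.uint8)
    return get_fid(real, fake, session=sess, strategy=strategy)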
32 | inception_images =[None] 33 | image_iterator_init=[None] 34 | inception_size = 299 35 | input_size=[32] 36 | activations1 = [None] 37 | activations2 = [None] 38 | fcd = [None] 39 | INCEPTION_TFHUB = 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1' 40 | INCEPTION_OUTPUT = 'logits' 41 | INCEPTION_FINAL_POOL = 'pool_3' 42 | def inception_activations(images): 43 | images = tf.transpose(images, [0, 2, 3, 1]) 44 | images = tf.compat.v1.image.resize_bilinear(images, [inception_size, inception_size]) 45 | generated_images_list = array_ops.split(images, num_or_size_splits = 1) 46 | activations = tf.map_fn( 47 | fn = tfgan.eval.classifier_fn_from_tfhub(INCEPTION_TFHUB, INCEPTION_FINAL_POOL, True), 48 | elems = array_ops.stack(generated_images_list), 49 | parallel_iterations = 1, 50 | back_prop = False, 51 | swap_memory = True, 52 | name = 'RunClassifier') 53 | activations = array_ops.concat(array_ops.unstack(activations), 0) 54 | return activations 55 | 56 | activations =[None] 57 | 58 | def get_inception_activations(inps, session=None, strategy=None): 59 | if FIRST_RUN[0]: 60 | with session.graph.as_default(): 61 | activations1[0] = tf.compat.v1.placeholder(tf.float32, [None, None], name = 'activations1') 62 | activations2[0] = tf.compat.v1.placeholder(tf.float32, [None, None], name = 'activations2') 63 | fcd[0] = tfgan.eval.frechet_classifier_distance_from_activations(activations1[0], activations2[0]) 64 | print('Running Inception for the first time, compiling...') 65 | inception_images[0]=tf.compat.v1.placeholder(tf.float32, [BATCH_SIZE, 3, input_size[0], input_size[0]], name = 'inception_images') 66 | image_dataset = tf.data.Dataset.from_tensor_slices((inception_images[0])).batch(BATCH_SIZE, drop_remainder=True) 67 | image_iterator = strategy.make_dataset_iterator(image_dataset) 68 | image_iterator_init[0] = image_iterator.initialize() 69 | activations[0]=tf.concat(strategy.experimental_run(inception_activations, image_iterator).values,0) 70 | FIRST_RUN[0]=0 71 | n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE)) 72 | act = np.zeros([inps.shape[0], 2048], dtype = np.float32) 73 | for i in range(n_batches): 74 | inp = inps[i * BATCH_SIZE : (i + 1) * BATCH_SIZE] / 255. 
* 2 - 1 75 | session.run(image_iterator_init[0],{inception_images[0]: inp}) 76 | act[i * BATCH_SIZE : i * BATCH_SIZE + min(BATCH_SIZE, inp.shape[0])] = session.run(activations[0]) 77 | return act 78 | 79 | def activations2distance(act1, act2, session=None): 80 | return session.run(fcd[0], feed_dict = {activations1[0]: act1, activations2[0]: act2}) 81 | 82 | def get_fid(images1, images2, session=None, strategy=None): 83 | assert(type(images1) == np.ndarray) 84 | assert(len(images1.shape) == 4) 85 | assert(images1.shape[1] == 3) 86 | assert(np.min(images1[0]) >= 0 and np.max(images1[0]) > 10), 'Image values should be in the range [0, 255]' 87 | assert(type(images2) == np.ndarray) 88 | assert(len(images2.shape) == 4) 89 | assert(images2.shape[1] == 3) 90 | assert(np.min(images2[0]) >= 0 and np.max(images2[0]) > 10), 'Image values should be in the range [0, 255]' 91 | assert(images1.shape == images2.shape), 'The two numpy arrays must have the same shape' 92 | input_size[0]=images1.shape[3] 93 | print('Calculating FID with %i images from each distribution' % (images1.shape[0])) 94 | start_time = time.time() 95 | act1 = get_inception_activations(images1, session, strategy) 96 | act2 = get_inception_activations(images2, session, strategy) 97 | print('Activations calculation time: %f s' % (time.time()-start_time)) 98 | fid = activations2distance(act1, act2, session) 99 | print('FID calculation time: %f s' % (time.time() - start_time)) 100 | return fid -------------------------------------------------------------------------------- /tflib/tpu_ops.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tensorflow operations specific to TPUs.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | #import gin 23 | from six.moves import range 24 | import tensorflow.compat.v1 as tf 25 | 26 | from tensorflow.python.tpu import tpu_function 27 | 28 | 29 | def cross_replica_concat(value, replica_id, num_replicas): 30 | """Reduce a concatenation of the `value` across TPU replicas. 31 | 32 | Args: 33 | value: Tensor to concatenate. 34 | replica_id: Integer tensor that indicates the index of the replica. 35 | num_replicas: Python integer, total number of replicas. 36 | 37 | Returns: 38 | Tensor of the same rank as value with first dimension `num_replicas` 39 | times larger. 40 | 41 | Raises: 42 | ValueError: If `value` is a scalar. 43 | """ 44 | if value.shape.ndims < 1: 45 | raise ValueError("Value must have at least rank 1 but got {}.".format( 46 | value.shape.ndims)) 47 | if num_replicas <= 1: 48 | return value 49 | with tf.name_scope(None, "tpu_cross_replica_concat"): 50 | # Mask is one hot encoded position of the core_index. 
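    # For example, with num_replicas=4 and replica_id=2 the mask is [0, 0, 1, 0]: after
    # the broadcasted multiply and reshape below, this replica contributes its own block
    # of rows and zeros elsewhere, and the final cross_replica_sum fills the zero blocks
    # with the corresponding rows from the other replicas.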
51 | mask = tf.to_float(tf.equal(tf.range(num_replicas), replica_id)) 52 | # Expand dims with 1's to match rank of value. 53 | mask = tf.reshape(mask, [num_replicas] + [1] * value.shape.ndims) 54 | if value.dtype in {tf.bfloat16, tf.float32}: 55 | result = mask * value 56 | else: 57 | result = mask * tf.to_float(value) 58 | # Thanks to broadcasting now result is set only in the position pointed by 59 | # replica_id, the rest of the vector is set to 0's. 60 | # All these steps are basically implementing tf.scatter_nd which is missing 61 | # in TPU's backend since it doesn't support sparse operations. 62 | 63 | # Merge first 2 dimensions. 64 | # This is equivalent to (value.shape[0].value * num_replicas). 65 | # Using [-1] trick to support also scalar input. 66 | result = tf.reshape(result, [-1] + result.shape.as_list()[2:]) 67 | # Each core set the "results" in position pointed by replica_id. When we now 68 | # sum across replicas we exchange the information and fill in local 0's with 69 | # values from other cores. 70 | result = tf.tpu.cross_replica_sum(result) 71 | # Now all the cores see exactly the same data. 72 | return tf.cast(result, dtype=value.dtype) 73 | 74 | 75 | def cross_replica_mean(inputs, group_size=None): 76 | """Calculates the average value of inputs tensor across TPU replicas.""" 77 | num_replicas = tpu_function.get_tpu_context().number_of_shards 78 | if not group_size: 79 | group_size = num_replicas 80 | if group_size == 1: 81 | return inputs 82 | if group_size != num_replicas: 83 | group_assignment = [] 84 | assert num_replicas % group_size == 0 85 | for g in range(num_replicas // group_size): 86 | replica_ids = [g * group_size + i for i in range(group_size)] 87 | group_assignment.append(replica_ids) 88 | else: 89 | group_assignment = None 90 | return tf.tpu.cross_replica_sum(inputs, group_assignment) / tf.cast( 91 | group_size, inputs.dtype) 92 | 93 | 94 | #@gin.configurable(blacklist=["inputs", "axis"]) 95 | def cross_replica_moments(inputs, axis, parallel=True, group_size=None): 96 | """Compute mean and variance of the inputs tensor across TPU replicas. 97 | 98 | Args: 99 | inputs: A tensor with 2 or more dimensions. 100 | axis: Array of ints. Axes along which to compute mean and variance. 101 | parallel: Use E[x^2] - (E[x])^2 to compute variance. Then can be done 102 | in parallel to computing the mean and reducing the communication overhead. 103 | group_size: Integer, the number of replicas to compute moments arcoss. 104 | None or 0 will use all replicas (global). 105 | 106 | Returns: 107 | Two tensors with mean and variance. 108 | """ 109 | # Compute local mean and then average across replicas. 110 | mean = tf.math.reduce_mean(inputs, axis=axis) 111 | mean = cross_replica_mean(mean) 112 | if parallel: 113 | # Compute variance using the E[x^2] - (E[x])^2 formula. This is less 114 | # numerically stable than the E[(x-E[x])^2] formula, but allows the two 115 | # cross-replica sums to be computed in parallel, saving communication 116 | # overhead. 
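      # Concretely: Var[x] = E[x^2] - (E[x])^2. Both expectations are reduced locally
      # over `axis` first and then averaged across replicas (which matches the global
      # expectation as long as every replica reduces over the same number of elements),
      # so this reduction and the cross_replica_mean(mean) above can proceed
      # independently.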
117 | mean_of_squares = tf.reduce_mean(tf.square(inputs), axis=axis) 118 | mean_of_squares = cross_replica_mean(mean_of_squares, group_size=group_size) 119 | mean_squared = tf.square(mean) 120 | variance = mean_of_squares - mean_squared 121 | else: 122 | variance = tf.math.reduce_mean( 123 | tf.math.square(inputs - mean), axis=axis) 124 | variance = cross_replica_mean(variance, group_size=group_size) 125 | return mean, variance -------------------------------------------------------------------------------- /tflib/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | import locale 5 | 6 | locale.setlocale(locale.LC_ALL, '') 7 | weight_init = tf.compat.v1.truncated_normal_initializer(mean=0.0, stddev=0.02) 8 | _params = {} 9 | _param_aliases = {} 10 | 11 | 12 | def param(name, *args, **kwargs): 13 | """ 14 | A wrapper for `tf.Variable` which enables parameter sharing in models. 15 | 16 | Creates and returns theano shared variables similarly to `tf.Variable`, 17 | except if you try to create a param with the same name as a 18 | previously-created one, `param(...)` will just return the old one instead of 19 | making a new one. 20 | 21 | This constructor also adds a `param` attribute to the shared variables it 22 | creates, so that you can easily search a graph for all params. 23 | """ 24 | 25 | if name not in _params: 26 | kwargs['name'] = name 27 | if 'ema' in name: 28 | kwargs['trainable']=False 29 | param = tf.Variable(*args, **kwargs) 30 | param.param = True 31 | _params[name] = param 32 | result = _params[name] 33 | i = 0 34 | while result in _param_aliases: 35 | # print 'following alias {}: {} to {}'.format(i, result, _param_aliases[result]) 36 | i += 1 37 | result = _param_aliases[result] 38 | return result 39 | 40 | def get_param(name, *args, **kwargs): 41 | """ 42 | A wrapper for `tf.get_variable` which enables parameter sharing in models. 43 | 44 | Creates and returns theano shared variables similarly to `tf.get_variable`, 45 | except if you try to create a param with the same name as a 46 | previously-created one, `param(...)` will just return the old one instead of 47 | making a new one. 48 | 49 | This constructor also adds a `param` attribute to the shared variables it 50 | creates, so that you can easily search a graph for all params. 
51 | """ 52 | 53 | if name not in _params: 54 | #with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 55 | if 'ema' in name: 56 | kwargs['trainable']=False 57 | param = tf.compat.v1.get_variable(name, *args, **kwargs) 58 | param.param = True 59 | _params[name] = param 60 | result = _params[name] 61 | i = 0 62 | while result in _param_aliases: 63 | # print 'following alias {}: {} to {}'.format(i, result, _param_aliases[result]) 64 | i += 1 65 | result = _param_aliases[result] 66 | return result 67 | 68 | def params_with_name(name1, name2=None): 69 | if name2: 70 | return [p for n,p in _params.items() if name1 in n and name2 in p.name] 71 | else: 72 | return [p for n,p in _params.items() if name1 in n] 73 | 74 | def delete_all_params(): 75 | _params.clear() 76 | 77 | def alias_params(replace_dict): 78 | for old,new in replace_dict.items(): 79 | # print "aliasing {} to {}".format(old,new) 80 | _param_aliases[old] = new 81 | 82 | def delete_param_aliases(): 83 | _param_aliases.clear() 84 | 85 | # def search(node, critereon): 86 | # """ 87 | # Traverse the Theano graph starting at `node` and return a list of all nodes 88 | # which match the `critereon` function. When optimizing a cost function, you 89 | # can use this to get a list of all of the trainable params in the graph, like 90 | # so: 91 | 92 | # `lib.search(cost, lambda x: hasattr(x, "param"))` 93 | # """ 94 | 95 | # def _search(node, critereon, visited): 96 | # if node in visited: 97 | # return [] 98 | # visited.add(node) 99 | 100 | # results = [] 101 | # if isinstance(node, T.Apply): 102 | # for inp in node.inputs: 103 | # results += _search(inp, critereon, visited) 104 | # else: # Variable node 105 | # if critereon(node): 106 | # results.append(node) 107 | # if node.owner is not None: 108 | # results += _search(node.owner, critereon, visited) 109 | # return results 110 | 111 | # return _search(node, critereon, set()) 112 | 113 | # def print_params_info(params): 114 | # """Print information about the parameters in the given param set.""" 115 | 116 | # params = sorted(params, key=lambda p: p.name) 117 | # values = [p.get_value(borrow=True) for p in params] 118 | # shapes = [p.shape for p in values] 119 | # print "Params for cost:" 120 | # for param, value, shape in zip(params, values, shapes): 121 | # print "\t{0} ({1})".format( 122 | # param.name, 123 | # ",".join([str(x) for x in shape]) 124 | # ) 125 | 126 | # total_param_count = 0 127 | # for shape in shapes: 128 | # param_count = 1 129 | # for dim in shape: 130 | # param_count *= dim 131 | # total_param_count += param_count 132 | # print "Total parameter count: {0}".format( 133 | # locale.format("%d", total_param_count, grouping=True) 134 | # ) 135 | 136 | def print_model_settings(locals_): 137 | print("Uppercase local vars:") 138 | all_vars = [(k,v) for (k,v) in locals_.items() if (k.isupper() and k!='T' and k!='SETTINGS' and k!='ALL_SETTINGS')] 139 | all_vars = sorted(all_vars, key=lambda x: x[0]) 140 | for var_name, var_value in all_vars: 141 | print("\t{}: {}".format(var_name, var_value)) 142 | 143 | 144 | def print_model_settings_dict(settings): 145 | print("Settings dict:") 146 | all_vars = [(k,v) for (k,v) in settings.items()] 147 | all_vars = sorted(all_vars, key=lambda x: x[0]) 148 | for var_name, var_value in all_vars: 149 | print("\t{}: {}".format(var_name, var_value)) -------------------------------------------------------------------------------- /tflib/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 
2 | import tflib as lib 3 | dtype='float32' 4 | def orthogonal_regularizer(scale) : 5 | """ Defining the Orthogonal regularizer and return the function at last to be used in Conv layer as kernel regularizer""" 6 | 7 | def ortho_reg(w) : 8 | """ Reshaping the matrxi in to 2D tensor for enforcing orthogonality""" 9 | w = tf.reshape(w, [-1, int(w.shape[-1])]) 10 | w=tf.transpose(w) if w.shape[0]= j) or (mask_type=='b' and i > j): 57 | mask[ 58 | center, 59 | center, 60 | i::mask_n_channels, 61 | j::mask_n_channels 62 | ] = 0. 63 | 64 | 65 | def uniform(stdev, size): 66 | return np.random.uniform( 67 | low=-stdev * np.sqrt(3), 68 | high=stdev * np.sqrt(3), 69 | size=size 70 | ).astype(dtype) 71 | 72 | 73 | fan_in = input_dim * filter_size**2 74 | fan_out = output_dim * filter_size**2 / (stride**2) 75 | if depthwise: 76 | fan_in = filter_size**2 77 | fan_out = channel_multiplier * filter_size**2 / (stride**2) 78 | if mask_type is not None: # only approximately correct 79 | fan_in /= 2. 80 | fan_out /= 2. 81 | 82 | if he_init: 83 | filters_stdev = np.sqrt(4./(fan_in+fan_out)) 84 | else: # Normalized init (Glorot & Bengio) 85 | filters_stdev = np.sqrt(2./(fan_in+fan_out)) 86 | 87 | if _weights_stdev is not None: 88 | filter_values = uniform( 89 | _weights_stdev, 90 | (filter_size, filter_size, input_dim, output_dim) 91 | ) 92 | else: 93 | filter_values = uniform( 94 | filters_stdev, 95 | (filter_size, filter_size, input_dim, output_dim) 96 | ) 97 | if depthwise: 98 | filter_values = uniform( 99 | filters_stdev, 100 | (filter_size, filter_size, input_dim, channel_multiplier) 101 | ) 102 | 103 | # print "WARNING IGNORING GAIN" 104 | filter_values *= gain 105 | 106 | filters = lib.get_param(name+'.Filters', filter_values.shape, dtype, weight_init, weight_regularizer) 107 | #filters = lib.param(name+'.Filters', filter_values) 108 | 109 | #tf.add_to_collection('G_conv' if 'Generator' in name else 'D_conv',orthogonal_regularizer(0.0001)(filters)) 110 | if depthwise and group: 111 | filters=tf.reshape(filters,[filter_size, filter_size, 1, input_dim*channel_multiplier]) 112 | #filters=tf.reshape(filters,[filter_size, filter_size, input_dim,channel_multiplier]) 113 | if weightnorm==None: 114 | weightnorm = _default_weightnorm 115 | if weightnorm: 116 | norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0,1,2))) 117 | target_norms = lib.param(name + '.g',norm_values) 118 | with tf.name_scope('weightnorm') as scope: 119 | norms = tf.sqrt(tf.reduce_sum(tf.square(filters), reduction_indices=[0,1,2])) 120 | filters = filters * (target_norms / (norms+1e-12)) 121 | 122 | # spectral normalization 123 | power_method_update=tf.zeros([]) 124 | t_update=tf.zeros([]) 125 | if spectralnorm==None: 126 | spectralnorm = _default_spectralnorm 127 | if spectralnorm and not depthwise: 128 | filters=spectral_norm(filters,update_sn=update_sn) 129 | ''' 130 | v=lib.param(name + '.sn.v',np.random.randn(1, output_dim),dtype=dtype,trainable=False) 131 | W=tf.reshape(filters,[filter_size*filter_size*input_dim, output_dim]) 132 | new_u = tf.nn.l2_normalize(tf.matmul(v, tf.transpose(W))) 133 | new_v = tf.nn.l2_normalize(tf.matmul(new_u, W)) 134 | new_u = tf.stop_gradient(new_u) 135 | new_v = tf.stop_gradient(new_v) 136 | #new_u=tf.random_normal(new_u.shape) 137 | #new_v=tf.random_normal(new_v.shape) 138 | #new_u = tf.nn.l2_normalize((new_u),1) 139 | #new_v = tf.nn.l2_normalize((new_v),1) 140 | 141 | spectral_norm = tf.matmul(tf.matmul(new_u, W),tf.transpose(new_v)) 142 | #spectral_norm=tf.math.reduce_logsumexp(tf.abs(W)) 
143 | #spectral_norm=tf.svd(W, compute_uv=False)[0] 144 | #spectral_norm=tf.stop_gradient(spectral_norm) 145 | 146 | #filters/=tf.norm(filters) 147 | if name not in norm_weight_names: 148 | print('first of', name) 149 | norm_weight_names.append(name) 150 | power_method_update = tf.assign(v, new_v) 151 | with tf.control_dependencies([power_method_update]): 152 | filters=tf.reshape(W/spectral_norm, filters.shape)#*target_norm 153 | else: 154 | print('not first of', name) 155 | filters=tf.reshape(W/spectral_norm, filters.shape) 156 | ''' 157 | 158 | if mask_type is not None: 159 | with tf.name_scope('filter_mask'): 160 | filters = filters * mask 161 | 162 | if not depthwise or group: 163 | result = tf.nn.conv2d( 164 | input=inputs, 165 | filter=filters, 166 | strides=[1, stride, stride, 1], 167 | padding='SAME', 168 | data_format='NHWC') 169 | if 'Generator' in name: 170 | rec = tf.nn.conv2d_transpose( 171 | value=result, 172 | filter=filters, 173 | output_shape=inputs.shape.as_list(), 174 | strides=[1,stride,stride, 1], 175 | padding='SAME' 176 | ) 177 | assert inputs.shape==rec.shape 178 | tf.add_to_collection('REC_LOSS', tf.reduce_mean((tf.stop_gradient(inputs)-rec)**2)) 179 | else: 180 | result = tf.nn.depthwise_conv2d( 181 | input=inputs, 182 | filter=filters, 183 | strides=[1, stride, stride, 1], 184 | padding='SAME', 185 | data_format='NHWC') 186 | if biases: 187 | _biases = lib.param( 188 | name+'.Biases', 189 | np.zeros(output_dim, dtype=dtype) 190 | ) 191 | result = tf.nn.bias_add(result, _biases, data_format='NHWC') 192 | 193 | return result 194 | -------------------------------------------------------------------------------- /tflib/custom_ops/fused_bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #define EIGEN_USE_GPU 8 | #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ 9 | #include "tensorflow/core/framework/op.h" 10 | #include "tensorflow/core/framework/op_kernel.h" 11 | #include "tensorflow/core/framework/shape_inference.h" 12 | #include 13 | 14 | using namespace tensorflow; 15 | using namespace tensorflow::shape_inference; 16 | 17 | #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false) 18 | 19 | //------------------------------------------------------------------------ 20 | // CUDA kernel. 21 | 22 | template 23 | struct FusedBiasActKernelParams 24 | { 25 | const T* x; // [sizeX] 26 | const T* b; // [sizeB] or NULL 27 | const T* ref; // [sizeX] or NULL 28 | T* y; // [sizeX] 29 | 30 | int grad; 31 | int axis; 32 | int act; 33 | float alpha; 34 | float gain; 35 | 36 | int sizeX; 37 | int sizeB; 38 | int stepB; 39 | int loopX; 40 | }; 41 | 42 | template 43 | static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams p) 44 | { 45 | const float expRange = 80.0f; 46 | const float halfExpRange = 40.0f; 47 | const float seluScale = 1.0507009873554804934193349852946f; 48 | const float seluAlpha = 1.6732632423543772848170429916717f; 49 | 50 | // Loop over elements. 51 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 52 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 53 | { 54 | // Load and apply bias. 
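        // The bias index (xi / stepB) % sizeB walks the `axis` dimension: stepB is the
        // product of all dimensions after `axis`, so consecutive elements share the same
        // bias entry until the axis coordinate advances, and the index wraps around
        // every sizeB steps along that axis.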
55 | float x = (float)p.x[xi]; 56 | if (p.b) 57 | x += (float)p.b[(xi / p.stepB) % p.sizeB]; 58 | float ref = (p.ref) ? (float)p.ref[xi] : 0.0f; 59 | if (p.gain != 0.0f & p.act != 9) 60 | ref /= p.gain; 61 | 62 | // Evaluate activation func. 63 | float y; 64 | switch (p.act * 10 + p.grad) 65 | { 66 | // linear 67 | default: 68 | case 10: y = x; break; 69 | case 11: y = x; break; 70 | case 12: y = 0.0f; break; 71 | 72 | // relu 73 | case 20: y = (x > 0.0f) ? x : 0.0f; break; 74 | case 21: y = (ref > 0.0f) ? x : 0.0f; break; 75 | case 22: y = 0.0f; break; 76 | 77 | // lrelu 78 | case 30: y = (x > 0.0f) ? x : x * p.alpha; break; 79 | case 31: y = (ref > 0.0f) ? x : x * p.alpha; break; 80 | case 32: y = 0.0f; break; 81 | 82 | // tanh 83 | case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break; 84 | case 41: y = x * (1.0f - ref * ref); break; 85 | case 42: y = x * (1.0f - ref * ref) * (-2.0f * ref); break; 86 | 87 | // sigmoid 88 | case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break; 89 | case 51: y = x * ref * (1.0f - ref); break; 90 | case 52: y = x * ref * (1.0f - ref) * (1.0f - 2.0f * ref); break; 91 | 92 | // elu 93 | case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break; 94 | case 61: y = (ref >= 0.0f) ? x : x * (ref + 1.0f); break; 95 | case 62: y = (ref >= 0.0f) ? 0.0f : x * (ref + 1.0f); break; 96 | 97 | // selu 98 | case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break; 99 | case 71: y = (ref >= 0.0f) ? x * seluScale : x * (ref + seluScale * seluAlpha); break; 100 | case 72: y = (ref >= 0.0f) ? 0.0f : x * (ref + seluScale * seluAlpha); break; 101 | 102 | // softplus 103 | case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break; 104 | case 81: y = x * (1.0f - expf(-ref)); break; 105 | case 82: { float c = expf(-ref); y = x * c * (1.0f - c); } break; 106 | 107 | // swish 108 | case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break; 109 | case 91: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? x : x * c * (ref + d) / (d * d); } break; 110 | case 92: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? 0.0f : x * c * (ref * (2.0f - d) + 2.0f * d) / (d * d * d); } break; 111 | } 112 | 113 | // Apply gain and store. 114 | p.y[xi] = (T)(y * p.gain); 115 | } 116 | } 117 | 118 | //------------------------------------------------------------------------ 119 | // TensorFlow op. 
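// Attribute semantics, as implemented by the kernel switch above: `act` selects the
// nonlinearity (1=linear, 2=relu, 3=lrelu, 4=tanh, 5=sigmoid, 6=elu, 7=selu,
// 8=softplus, 9=swish), `grad` selects the forward value (0) or the first/second
// backward pass (1/2), in which case `x` carries the incoming gradient and `ref` the
// saved forward-pass reference, and the result is scaled by `gain` before being stored.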
120 | 121 | template 122 | struct FusedBiasActOp : public OpKernel 123 | { 124 | FusedBiasActKernelParams m_attribs; 125 | 126 | FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx) 127 | { 128 | memset(&m_attribs, 0, sizeof(m_attribs)); 129 | OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad)); 130 | OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis)); 131 | OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act)); 132 | OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha)); 133 | OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain)); 134 | OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative")); 135 | OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative")); 136 | OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative")); 137 | } 138 | 139 | void Compute(OpKernelContext* ctx) 140 | { 141 | FusedBiasActKernelParams p = m_attribs; 142 | cudaStream_t stream = ctx->eigen_device().stream(); 143 | 144 | const Tensor& x = ctx->input(0); // [...] 145 | const Tensor& b = ctx->input(1); // [sizeB] or [0] 146 | const Tensor& ref = ctx->input(2); // x.shape or [0] 147 | p.x = x.flat().data(); 148 | p.b = (b.NumElements()) ? b.flat().data() : NULL; 149 | p.ref = (ref.NumElements()) ? ref.flat().data() : NULL; 150 | OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds")); 151 | OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1")); 152 | OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements")); 153 | OP_REQUIRES(ctx, ref.NumElements() == ((p.grad == 0) ? 0 : x.NumElements()), errors::InvalidArgument("ref has wrong number of elements")); 154 | OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large")); 155 | 156 | p.sizeX = (int)x.NumElements(); 157 | p.sizeB = (int)b.NumElements(); 158 | p.stepB = 1; 159 | for (int i = m_attribs.axis + 1; i < x.dims(); i++) 160 | p.stepB *= (int)x.dim_size(i); 161 | 162 | Tensor* y = NULL; // x.shape 163 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y)); 164 | p.y = y->flat().data(); 165 | 166 | p.loopX = 4; 167 | int blockSize = 4 * 32; 168 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 169 | void* args[] = {&p}; 170 | OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel, gridSize, blockSize, args, 0, stream)); 171 | } 172 | }; 173 | 174 | REGISTER_OP("FusedBiasAct") 175 | .Input ("x: T") 176 | .Input ("b: T") 177 | .Input ("ref: T") 178 | .Output ("y: T") 179 | .Attr ("T: {float, half}") 180 | .Attr ("grad: int = 0") 181 | .Attr ("axis: int = 1") 182 | .Attr ("act: int = 0") 183 | .Attr ("alpha: float = 0.0") 184 | .Attr ("gain: float = 1.0"); 185 | REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint("T"), FusedBiasActOp); 186 | REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint("T"), FusedBiasActOp); 187 | 188 | //------------------------------------------------------------------------ 189 | -------------------------------------------------------------------------------- /tflib/ada.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf, tensorflow.keras.backend as K 2 | import os, sys 3 | import numpy as np 4 | from tensorflow.python.ops import array_ops 5 | import 
functools 6 | TF_VERSION=float('.'.join(tf.__version__.split('.')[:2])) 7 | ################################################ 8 | def rand_flip_left_right(images): 9 | return tf.image.random_flip_left_right(images) 10 | B, H, W, C = images.shape.as_list() 11 | toss=tf.random_uniform([B])

= 3.5 required to use "@" for matrix mulplication 203 | C@=contrast_matrix(B, p) 204 | C@=luma_flip_matrix(B, p, v) 205 | C@=hue_rotation_matrix(B, p, v) 206 | C@=saturation_matrix(B, p, v) 207 | images=tf.reshape(images@C,[B,H,W,4]) 208 | images=images[:,:,:,:3]#tf.slice(images, [0,0,0,0],[B,H,W,3]) 209 | return images 210 | ################################################################## 211 | #Cutout is based on https://github.com/mit-han-lab/data-efficient-gans/blob/7a1ea3d0a1e467b0c74f3bdb79ef9ace5e41c321/DiffAugment_tf.py#L51-L64 212 | def cutout(x, toss, ratio=[1, 2]): 213 | batch_size = tf.shape(x)[0] 214 | image_size = tf.shape(x)[1:3] 215 | cutout_size = image_size * ratio[0] // ratio[1] 216 | offset_x = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[0] + (1 - cutout_size[0] % 2), dtype=tf.int32) 217 | offset_y = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[1] + (1 - cutout_size[1] % 2), dtype=tf.int32) 218 | grid_batch, grid_x, grid_y = tf.meshgrid(tf.range(batch_size, dtype=tf.int32), tf.range(cutout_size[0], dtype=tf.int32), tf.range(cutout_size[1], dtype=tf.int32), indexing='ij') 219 | cutout_grid = tf.stack([grid_batch, grid_x + offset_x - cutout_size[0] // 2, grid_y + offset_y - cutout_size[1] // 2], axis=-1) 220 | mask_shape = tf.stack([batch_size, image_size[0], image_size[1]]) 221 | cutout_grid = tf.maximum(cutout_grid, 0) 222 | cutout_grid = tf.minimum(cutout_grid, tf.reshape(mask_shape - 1, [1, 1, 1, 3])) 223 | mask = tf.maximum(1 - tf.reshape(toss,[-1,1,1])*tf.scatter_nd(cutout_grid, tf.ones([batch_size, cutout_size[0], cutout_size[1]], dtype=tf.float32), mask_shape), 0) 224 | x = x * tf.expand_dims(mask, axis=3) 225 | return x 226 | def rand_cutout(images,p=1): 227 | B,H,W,C=images.shape.as_list() 228 | toss=tf.cast(tf.random.uniform([B]) 13 | 14 | using namespace tensorflow; 15 | using namespace tensorflow::shape_inference; 16 | 17 | //------------------------------------------------------------------------ 18 | // Helpers. 19 | 20 | #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false) 21 | 22 | static __host__ __device__ __forceinline__ int floorDiv(int a, int b) 23 | { 24 | int c = a / b; 25 | if (c * b > a) 26 | c--; 27 | return c; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | // CUDA kernel params. 32 | 33 | template 34 | struct UpFirDn2DKernelParams 35 | { 36 | const T* x; // [majorDim, inH, inW, minorDim] 37 | const T* k; // [kernelH, kernelW] 38 | T* y; // [majorDim, outH, outW, minorDim] 39 | 40 | int upx; 41 | int upy; 42 | int downx; 43 | int downy; 44 | int padx0; 45 | int padx1; 46 | int pady0; 47 | int pady1; 48 | 49 | int majorDim; 50 | int inH; 51 | int inW; 52 | int minorDim; 53 | int kernelH; 54 | int kernelW; 55 | int outH; 56 | int outW; 57 | int loopMajor; 58 | int loopX; 59 | }; 60 | 61 | //------------------------------------------------------------------------ 62 | // General CUDA implementation for large filter kernels. 63 | 64 | template 65 | static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams p) 66 | { 67 | // Calculate thread index. 
68 | int minorIdx = blockIdx.x * blockDim.x + threadIdx.x; 69 | int outY = minorIdx / p.minorDim; 70 | minorIdx -= outY * p.minorDim; 71 | int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; 72 | int majorIdxBase = blockIdx.z * p.loopMajor; 73 | if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim) 74 | return; 75 | 76 | // Setup Y receptive field. 77 | int midY = outY * p.downy + p.upy - 1 - p.pady0; 78 | int inY = min(max(floorDiv(midY, p.upy), 0), p.inH); 79 | int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY; 80 | int kernelY = midY + p.kernelH - (inY + 1) * p.upy; 81 | 82 | // Loop over majorDim and outX. 83 | for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++) 84 | for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y) 85 | { 86 | // Setup X receptive field. 87 | int midX = outX * p.downx + p.upx - 1 - p.padx0; 88 | int inX = min(max(floorDiv(midX, p.upx), 0), p.inW); 89 | int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX; 90 | int kernelX = midX + p.kernelW - (inX + 1) * p.upx; 91 | 92 | // Initialize pointers. 93 | const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx]; 94 | const T* kp = &p.k[kernelY * p.kernelW + kernelX]; 95 | int xpx = p.minorDim; 96 | int kpx = -p.upx; 97 | int xpy = p.inW * p.minorDim; 98 | int kpy = -p.upy * p.kernelW; 99 | 100 | // Inner loop. 101 | float v = 0.0f; 102 | for (int y = 0; y < h; y++) 103 | { 104 | for (int x = 0; x < w; x++) 105 | { 106 | v += (float)(*xp) * (float)(*kp); 107 | xp += xpx; 108 | kp += kpx; 109 | } 110 | xp += xpy - w * xpx; 111 | kp += kpy - w * kpx; 112 | } 113 | 114 | // Store result. 115 | p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v; 116 | } 117 | } 118 | 119 | //------------------------------------------------------------------------ 120 | // Specialized CUDA implementation for small filter kernels. 121 | 122 | template 123 | static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams p) 124 | { 125 | //assert(kernelW % upx == 0); 126 | //assert(kernelH % upy == 0); 127 | const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1; 128 | const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1; 129 | __shared__ volatile float sk[kernelH][kernelW]; 130 | __shared__ volatile float sx[tileInH][tileInW]; 131 | 132 | // Calculate tile index. 133 | int minorIdx = blockIdx.x; 134 | int tileOutY = minorIdx / p.minorDim; 135 | minorIdx -= tileOutY * p.minorDim; 136 | tileOutY *= tileOutH; 137 | int tileOutXBase = blockIdx.y * p.loopX * tileOutW; 138 | int majorIdxBase = blockIdx.z * p.loopMajor; 139 | if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim) 140 | return; 141 | 142 | // Load filter kernel (flipped). 143 | for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x) 144 | { 145 | int ky = tapIdx / kernelW; 146 | int kx = tapIdx - ky * kernelW; 147 | float v = 0.0f; 148 | if (kx < p.kernelW & ky < p.kernelH) 149 | v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)]; 150 | sk[ky][kx] = v; 151 | } 152 | 153 | // Loop over majorDim and outX. 
154 | for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++) 155 | for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW) 156 | { 157 | // Load input pixels. 158 | int tileMidX = tileOutX * downx + upx - 1 - p.padx0; 159 | int tileMidY = tileOutY * downy + upy - 1 - p.pady0; 160 | int tileInX = floorDiv(tileMidX, upx); 161 | int tileInY = floorDiv(tileMidY, upy); 162 | __syncthreads(); 163 | for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x) 164 | { 165 | int relInY = inIdx / tileInW; 166 | int relInX = inIdx - relInY * tileInW; 167 | int inX = relInX + tileInX; 168 | int inY = relInY + tileInY; 169 | float v = 0.0f; 170 | if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH) 171 | v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx]; 172 | sx[relInY][relInX] = v; 173 | } 174 | 175 | // Loop over output pixels. 176 | __syncthreads(); 177 | for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x) 178 | { 179 | int relOutY = outIdx / tileOutW; 180 | int relOutX = outIdx - relOutY * tileOutW; 181 | int outX = relOutX + tileOutX; 182 | int outY = relOutY + tileOutY; 183 | 184 | // Setup receptive field. 185 | int midX = tileMidX + relOutX * downx; 186 | int midY = tileMidY + relOutY * downy; 187 | int inX = floorDiv(midX, upx); 188 | int inY = floorDiv(midY, upy); 189 | int relInX = inX - tileInX; 190 | int relInY = inY - tileInY; 191 | int kernelX = (inX + 1) * upx - midX - 1; // flipped 192 | int kernelY = (inY + 1) * upy - midY - 1; // flipped 193 | 194 | // Inner loop. 195 | float v = 0.0f; 196 | #pragma unroll 197 | for (int y = 0; y < kernelH / upy; y++) 198 | #pragma unroll 199 | for (int x = 0; x < kernelW / upx; x++) 200 | v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx]; 201 | 202 | // Store result. 203 | if (outX < p.outW & outY < p.outH) 204 | p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v; 205 | } 206 | } 207 | } 208 | 209 | //------------------------------------------------------------------------ 210 | // TensorFlow op. 
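// Per image, UpFirDn2D upsamples by (upx, upy) via zero insertion, zero-pads by
// (padx0, padx1, pady0, pady1), applies the 2-D FIR filter `k`, and then downsamples
// by (downx, downy). Both kernels above compute the same result; the "small" variant
// additionally stages input tiles and the flipped filter in shared memory for the
// common small-filter cases.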
211 | 212 | template 213 | struct UpFirDn2DOp : public OpKernel 214 | { 215 | UpFirDn2DKernelParams m_attribs; 216 | 217 | UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx) 218 | { 219 | memset(&m_attribs, 0, sizeof(m_attribs)); 220 | OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx)); 221 | OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy)); 222 | OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx)); 223 | OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy)); 224 | OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0)); 225 | OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1)); 226 | OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0)); 227 | OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1)); 228 | OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1")); 229 | OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1")); 230 | } 231 | 232 | void Compute(OpKernelContext* ctx) 233 | { 234 | UpFirDn2DKernelParams p = m_attribs; 235 | cudaStream_t stream = ctx->eigen_device().stream(); 236 | 237 | const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim] 238 | const Tensor& k = ctx->input(1); // [kernelH, kernelW] 239 | p.x = x.flat().data(); 240 | p.k = k.flat().data(); 241 | OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4")); 242 | OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2")); 243 | OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large")); 244 | OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large")); 245 | 246 | p.majorDim = (int)x.dim_size(0); 247 | p.inH = (int)x.dim_size(1); 248 | p.inW = (int)x.dim_size(2); 249 | p.minorDim = (int)x.dim_size(3); 250 | p.kernelH = (int)k.dim_size(0); 251 | p.kernelW = (int)k.dim_size(1); 252 | OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1")); 253 | 254 | p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx; 255 | p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy; 256 | OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1")); 257 | 258 | Tensor* y = NULL; // [majorDim, outH, outW, minorDim] 259 | TensorShape ys; 260 | ys.AddDim(p.majorDim); 261 | ys.AddDim(p.outH); 262 | ys.AddDim(p.outW); 263 | ys.AddDim(p.minorDim); 264 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y)); 265 | p.y = y->flat().data(); 266 | OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large")); 267 | 268 | // Choose CUDA kernel to use. 
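        // The ladder below swaps in a shared-memory tiled specialization whenever the
        // up/down factors and filter size match one of the precompiled templates
        // (64x16 output tiles for the no-resampling and 2x-upsampling cases, 32x8 tiles
        // for the 2x-downsampling cases); anything else keeps the general "large" kernel
        // chosen above.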
269 | void* cudaKernel = (void*)UpFirDn2DKernel_large; 270 | int tileOutW = -1; 271 | int tileOutH = -1; 272 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 273 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 274 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 275 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 276 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 277 | if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 278 | if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 279 | if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 280 | if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 281 | if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; } 282 | if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; } 283 | if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; } 284 | if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; } 285 | 286 | // Choose launch params. 287 | dim3 blockSize; 288 | dim3 gridSize; 289 | if (tileOutW > 0 && tileOutH > 0) // small 290 | { 291 | p.loopMajor = (p.majorDim - 1) / 16384 + 1; 292 | p.loopX = 1; 293 | blockSize = dim3(32 * 8, 1, 1); 294 | gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1); 295 | } 296 | else // large 297 | { 298 | p.loopMajor = (p.majorDim - 1) / 16384 + 1; 299 | p.loopX = 4; 300 | blockSize = dim3(4, 32, 1); 301 | gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1); 302 | } 303 | 304 | // Launch CUDA kernel. 
305 | void* args[] = {&p}; 306 | OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream)); 307 | } 308 | }; 309 | 310 | REGISTER_OP("UpFirDn2D") 311 | .Input ("x: T") 312 | .Input ("k: T") 313 | .Output ("y: T") 314 | .Attr ("T: {float, half}") 315 | .Attr ("upx: int = 1") 316 | .Attr ("upy: int = 1") 317 | .Attr ("downx: int = 1") 318 | .Attr ("downy: int = 1") 319 | .Attr ("padx0: int = 0") 320 | .Attr ("padx1: int = 0") 321 | .Attr ("pady0: int = 0") 322 | .Attr ("pady1: int = 0"); 323 | REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint("T"), UpFirDn2DOp); 324 | REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint("T"), UpFirDn2DOp); 325 | 326 | //------------------------------------------------------------------------ -------------------------------------------------------------------------------- /tflib/memory_saving_gradients.py: -------------------------------------------------------------------------------- 1 | from toposort import toposort 2 | import contextlib 3 | import numpy as np 4 | import tensorflow as tf 5 | import tensorflow.contrib.graph_editor as ge 6 | import time 7 | import sys 8 | sys.setrecursionlimit(10000) 9 | # refers back to current module if we decide to split helpers out 10 | util = sys.modules[__name__] 11 | 12 | # getting rid of "WARNING:tensorflow:VARIABLES collection name is deprecated" 13 | setattr(tf.GraphKeys, "VARIABLES", "variables") 14 | 15 | # save original gradients since tf.gradient could be monkey-patched to point 16 | # to our version 17 | from tensorflow.python.ops import gradients as tf_gradients_lib 18 | tf_gradient_function = tf_gradients_lib.gradients 19 | 20 | # ISSUE: https://github.com/cybertronai/gradient-checkpointing/issues/38 21 | def tf_gradients(ys, *args, **kwargs): 22 | """Decorate tf.gradients calls with explicit device placement to avoid memory 23 | leaks when splitting model across multiple GPUs""" 24 | source = ys[0] if isinstance(ys, (list, tuple)) else ys 25 | device = source.op.node_def.device if isinstance(source, tf.Tensor) else None 26 | with tf.device(device): 27 | return tf_gradient_function(ys, *args, **kwargs) 28 | 29 | 30 | MIN_CHECKPOINT_NODE_SIZE=1024 # use lower value during testing 31 | 32 | # specific versions we can use to do process-wide replacement of tf.gradients 33 | def gradients_speed(ys, xs, grad_ys=None, **kwargs): 34 | return gradients(ys, xs, grad_ys, checkpoints='speed', **kwargs) 35 | 36 | def gradients_memory(ys, xs, grad_ys=None, **kwargs): 37 | return gradients(ys, xs, grad_ys, checkpoints='memory', **kwargs) 38 | 39 | def gradients_collection(ys, xs, grad_ys=None, **kwargs): 40 | return gradients(ys, xs, grad_ys, checkpoints='collection', **kwargs) 41 | 42 | def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs): 43 | ''' 44 | Authors: Tim Salimans & Yaroslav Bulatov 45 | 46 | memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost" 47 | by Chen et al. 
2016 (https://arxiv.org/abs/1604.06174) 48 | 49 | ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients 50 | (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients) 51 | 52 | 'checkpoints' can either be 53 | - a list consisting of tensors from the forward pass of the neural net 54 | that we should re-use when calculating the gradients in the backward pass 55 | all other tensors that do not appear in this list will be re-computed 56 | - a string specifying how this list should be determined. currently we support 57 | - 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive, 58 | so checkpointing them maximizes the running speed 59 | (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory) 60 | - 'memory': try to minimize the memory usage 61 | (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint) 62 | - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint 63 | ''' 64 | 65 | # print("Calling memsaving gradients with", checkpoints) 66 | if not isinstance(ys,list): 67 | ys = [ys] 68 | if not isinstance(xs,list): 69 | xs = [xs] 70 | 71 | bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], 72 | inclusive=True) 73 | 74 | debug_print("bwd_ops: %s", bwd_ops) 75 | 76 | # forward ops are all ops that are candidates for recomputation 77 | fwd_ops = ge.get_forward_walk_ops([x.op for x in xs], 78 | inclusive=True, 79 | within_ops=bwd_ops) 80 | debug_print("fwd_ops: %s", fwd_ops) 81 | 82 | # exclude ops with no inputs 83 | fwd_ops = [op for op in fwd_ops if op.inputs] 84 | 85 | # don't recompute xs, remove variables 86 | xs_ops = _to_ops(xs) 87 | fwd_ops = [op for op in fwd_ops if not op in xs_ops] 88 | fwd_ops = [op for op in fwd_ops if not '/assign' in op.name] 89 | fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name] 90 | fwd_ops = [op for op in fwd_ops if not '/read' in op.name] 91 | ts_all = ge.filter_ts(fwd_ops, True) # get the tensors 92 | ts_all = [t for t in ts_all if '/read' not in t.name] 93 | ts_all = set(ts_all) - set(xs) - set(ys) 94 | 95 | # construct list of tensors to checkpoint during forward pass, if not 96 | # given as input 97 | if type(checkpoints) is not list: 98 | if checkpoints == 'collection': 99 | checkpoints = tf.get_collection('checkpoints') 100 | 101 | elif checkpoints == 'speed': 102 | # checkpoint all expensive ops to maximize running speed 103 | checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul') 104 | 105 | elif checkpoints == 'memory': 106 | 107 | # remove very small tensors and some weird ops 108 | def fixdims(t): # tf.Dimension values are not compatible with int, convert manually 109 | try: 110 | return [int(e if e.value is not None else 64) for e in t] 111 | except: 112 | return [0] # unknown shape 113 | ts_all = [t for t in ts_all if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE] 114 | ts_all = [t for t in ts_all if 'L2Loss' not in t.name] 115 | ts_all = [t for t in ts_all if 'entropy' not in t.name] 116 | ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name] 117 | ts_all = [t for t in ts_all if 'Switch' not in t.name] 118 | ts_all = [t for t in ts_all if 'dropout' not in t.name] 119 | # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16 120 | ts_all = [t for t in ts_all if 'Cast' not in t.name] 121 | 122 | # filter out all tensors that are inputs of the 
backward graph 123 | with util.capture_ops() as bwd_ops: 124 | tf_gradients(ys, xs, grad_ys, **kwargs) 125 | 126 | bwd_inputs = [t for op in bwd_ops for t in op.inputs] 127 | # list of tensors in forward graph that is in input to bwd graph 128 | ts_filtered = list(set(bwd_inputs).intersection(ts_all)) 129 | debug_print("Using tensors %s", ts_filtered) 130 | 131 | # try two slightly different ways of getting bottlenecks tensors 132 | # to checkpoint 133 | for ts in [ts_filtered, ts_all]: 134 | 135 | # get all bottlenecks in the graph 136 | bottleneck_ts = [] 137 | for t in ts: 138 | b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops)) 139 | f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops)) 140 | # check that there are not shortcuts 141 | b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all) 142 | f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all) 143 | if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all): 144 | bottleneck_ts.append(t) # we have a bottleneck! 145 | else: 146 | debug_print("Rejected bottleneck candidate and ops %s", [t] + list(set(ts_all) - set(b_inp) - set(f_inp))) 147 | 148 | # success? or try again without filtering? 149 | if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)): # yes, enough bottlenecks found! 150 | break 151 | 152 | if not bottleneck_ts: 153 | raise Exception('unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".') 154 | 155 | # sort the bottlenecks 156 | bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops) 157 | sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts] 158 | 159 | # save an approximately optimal number ~ sqrt(N) 160 | N = len(ts_filtered) 161 | if len(bottleneck_ts) <= np.ceil(np.sqrt(N)): 162 | checkpoints = sorted_bottlenecks 163 | else: 164 | step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N))) 165 | checkpoints = sorted_bottlenecks[step::step] 166 | 167 | else: 168 | raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,)) 169 | 170 | checkpoints = list(ts_all) 171 | 172 | # at this point automatic selection happened and checkpoints is list of nodes 173 | assert isinstance(checkpoints, list) 174 | 175 | print("Checkpoint nodes used: %s"% checkpoints) 176 | # better error handling of special cases 177 | # xs are already handled as checkpoint nodes, so no need to include them 178 | xs_intersect_checkpoints = set(xs).intersection(set(checkpoints)) 179 | if xs_intersect_checkpoints: 180 | print("Warning, some input nodes are also checkpoint nodes: %s", 181 | xs_intersect_checkpoints) 182 | ys_intersect_checkpoints = set(ys).intersection(set(checkpoints)) 183 | print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints, 184 | ys_intersect_checkpoints) 185 | # saving an output node (ys) gives no benefit in memory while creating 186 | # new edge cases, exclude them 187 | if ys_intersect_checkpoints: 188 | print("Warning, some output nodes are also checkpoints nodes: %s", 189 | format_ops(ys_intersect_checkpoints)) 190 | 191 | # remove initial and terminal nodes from checkpoints list if present 192 | checkpoints = list(set(checkpoints) - set(ys) - set(xs)) 193 | 194 | # check that we have some nodes to checkpoint 195 | if not checkpoints: 196 | raise Exception('no checkpoints nodes found or given as input! 
') 197 | 198 | # disconnect dependencies between checkpointed tensors 199 | checkpoints_disconnected = {} 200 | for x in checkpoints: 201 | if x.op and x.op.name is not None: 202 | grad_node = tf.stop_gradient(x, name=x.op.name+"_sg") 203 | else: 204 | grad_node = tf.stop_gradient(x) 205 | grad_node.op._set_device(x.op.node_def.device) 206 | checkpoints_disconnected[x] = grad_node 207 | 208 | # partial derivatives to the checkpointed tensors and xs 209 | ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys], 210 | stop_at_ts=checkpoints, within_ops=fwd_ops) 211 | debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s", 212 | len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints) 213 | debug_print("ops_to_copy = %s", ops_to_copy) 214 | debug_print("Processing list %s", ys) 215 | copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) 216 | for origin_op, op in info._transformed_ops.items(): 217 | op._set_device(origin_op.node_def.device) 218 | copied_ops = info._transformed_ops.values() 219 | debug_print("Copied %s to %s", ops_to_copy, copied_ops) 220 | ge.reroute_ts(checkpoints_disconnected.values(), checkpoints_disconnected.keys(), can_modify=copied_ops) 221 | debug_print("Rewired %s in place of %s restricted to %s", 222 | checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops) 223 | 224 | # get gradients with respect to current boundary + original x's 225 | copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys] 226 | boundary = list(checkpoints_disconnected.values()) 227 | dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs) 228 | debug_print("Got gradients %s", dv) 229 | debug_print("for %s", copied_ys) 230 | debug_print("with respect to %s", boundary+xs) 231 | 232 | inputs_to_do_before = [y.op for y in ys] 233 | if grad_ys is not None: 234 | inputs_to_do_before += grad_ys 235 | wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] 236 | my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) 237 | 238 | # partial derivatives to the checkpointed nodes 239 | # dictionary of "node: backprop" for nodes in the boundary 240 | d_checkpoints = {r: dr for r,dr in zip(checkpoints_disconnected.keys(), 241 | dv[:len(checkpoints_disconnected)])} 242 | # partial derivatives to xs (usually the params of the neural net) 243 | d_xs = dv[len(checkpoints_disconnected):] 244 | 245 | # incorporate derivatives flowing through the checkpointed nodes 246 | checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops) 247 | for ts in checkpoints_sorted_lists[::-1]: 248 | debug_print("Processing list %s", ts) 249 | checkpoints_other = [r for r in checkpoints if r not in ts] 250 | checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other] 251 | 252 | # copy part of the graph below current checkpoint node, stopping at 253 | # other checkpoints nodes 254 | ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[r.op for r in ts], stop_at_ts=checkpoints_other) 255 | debug_print("Found %s ops to copy within %s, seed %s, stop_at %s", 256 | len(ops_to_copy), fwd_ops, [r.op for r in ts], 257 | checkpoints_other) 258 | debug_print("ops_to_copy = %s", ops_to_copy) 259 | if not ops_to_copy: # we're done! 
260 | break 261 | copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) 262 | for origin_op, op in info._transformed_ops.items(): 263 | op._set_device(origin_op.node_def.device) 264 | copied_ops = info._transformed_ops.values() 265 | debug_print("Copied %s to %s", ops_to_copy, copied_ops) 266 | ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops) 267 | debug_print("Rewired %s in place of %s restricted to %s", 268 | checkpoints_disconnected_other, checkpoints_other, copied_ops) 269 | 270 | # gradient flowing through the checkpointed node 271 | boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts] 272 | substitute_backprops = [d_checkpoints[r] for r in ts] 273 | dv = tf_gradients(boundary, 274 | checkpoints_disconnected_other+xs, 275 | grad_ys=substitute_backprops, **kwargs) 276 | debug_print("Got gradients %s", dv) 277 | debug_print("for %s", boundary) 278 | debug_print("with respect to %s", checkpoints_disconnected_other+xs) 279 | debug_print("with boundary backprop substitutions %s", substitute_backprops) 280 | 281 | inputs_to_do_before = [d_checkpoints[r].op for r in ts if d_checkpoints[r] is not None] 282 | wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] 283 | my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) 284 | 285 | # partial derivatives to the checkpointed nodes 286 | for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]): 287 | if dr is not None: 288 | if d_checkpoints[r] is None: 289 | d_checkpoints[r] = dr 290 | else: 291 | d_checkpoints[r] += dr 292 | def _unsparsify(x): 293 | if not isinstance(x, tf.IndexedSlices): 294 | return x 295 | assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape" 296 | indices = x.indices 297 | while indices.shape.ndims < x.values.shape.ndims: 298 | indices = tf.expand_dims(indices, -1) 299 | return tf.scatter_nd(indices, x.values, x.dense_shape) 300 | 301 | # partial derivatives to xs (usually the params of the neural net) 302 | d_xs_new = dv[len(checkpoints_other):] 303 | for j in range(len(xs)): 304 | if d_xs_new[j] is not None: 305 | if d_xs[j] is None: 306 | d_xs[j] = _unsparsify(d_xs_new[j]) 307 | else: 308 | d_xs[j] += _unsparsify(d_xs_new[j]) 309 | 310 | 311 | return d_xs 312 | 313 | def tf_toposort(ts, within_ops=None): 314 | all_ops = ge.get_forward_walk_ops([x.op for x in ts], within_ops=within_ops) 315 | 316 | deps = {} 317 | for op in all_ops: 318 | for o in op.outputs: 319 | deps[o] = set(op.inputs) 320 | sorted_ts = toposort(deps) 321 | 322 | # only keep the tensors from our original list 323 | ts_sorted_lists = [] 324 | for l in sorted_ts: 325 | keep = list(set(l).intersection(ts)) 326 | if keep: 327 | ts_sorted_lists.append(keep) 328 | 329 | return ts_sorted_lists 330 | 331 | def fast_backward_ops(within_ops, seed_ops, stop_at_ts): 332 | bwd_ops = set(ge.get_backward_walk_ops(seed_ops, stop_at_ts=stop_at_ts)) 333 | ops = bwd_ops.intersection(within_ops).difference([t.op for t in stop_at_ts]) 334 | return list(ops) 335 | 336 | @contextlib.contextmanager 337 | def capture_ops(): 338 | """Decorator to capture ops created in the block. 339 | with capture_ops() as ops: 340 | # create some ops 341 | print(ops) # => prints ops created. 
342 | """ 343 | 344 | micros = int(time.time()*10**6) 345 | scope_name = str(micros) 346 | op_list = [] 347 | with tf.name_scope(scope_name): 348 | yield op_list 349 | 350 | g = tf.get_default_graph() 351 | op_list.extend(ge.select_ops(scope_name+"/.*", graph=g)) 352 | 353 | def _to_op(tensor_or_op): 354 | if hasattr(tensor_or_op, "op"): 355 | return tensor_or_op.op 356 | return tensor_or_op 357 | 358 | def _to_ops(iterable): 359 | if not _is_iterable(iterable): 360 | return iterable 361 | return [_to_op(i) for i in iterable] 362 | 363 | def _is_iterable(o): 364 | try: 365 | _ = iter(o) 366 | except Exception: 367 | return False 368 | return True 369 | 370 | DEBUG_LOGGING=False 371 | def debug_print(s, *args): 372 | """Like logger.log, but also replaces all TensorFlow ops/tensors with their 373 | names. Sensitive to value of DEBUG_LOGGING, see enable_debug/disable_debug 374 | 375 | Usage: 376 | debug_print("see tensors %s for %s", tensorlist, [1,2,3]) 377 | """ 378 | 379 | if DEBUG_LOGGING: 380 | formatted_args = [format_ops(arg) for arg in args] 381 | print("DEBUG "+s % tuple(formatted_args)) 382 | 383 | def format_ops(ops, sort_outputs=True): 384 | """Helper method for printing ops. Converts Tensor/Operation op to op.name, 385 | rest to str(op).""" 386 | 387 | if hasattr(ops, '__iter__') and not isinstance(ops, str): 388 | l = [(op.name if hasattr(op, "name") else str(op)) for op in ops] 389 | if sort_outputs: 390 | return sorted(l) 391 | return l 392 | else: 393 | return ops.name if hasattr(ops, "name") else str(ops) 394 | 395 | def my_add_control_inputs(wait_to_do_ops, inputs_to_do_before): 396 | for op in wait_to_do_ops: 397 | ci = [i for i in inputs_to_do_before if op.control_inputs is None or i not in op.control_inputs] 398 | ge.add_control_inputs(op, ci) --------------------------------------------------------------------------------
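The `gradients()` docstring in tflib/memory_saving_gradients.py above describes three checkpointing strategies ('speed', 'memory', 'collection') exposed through the wrappers `gradients_speed`, `gradients_memory`, and `gradients_collection`. The sketch below shows how these entry points might be wired into a TF 1.x training graph. It is illustrative only and not code from this repository: the throwaway MLP, placeholder shapes, and session run are assumptions, and it presumes TensorFlow 1.x graph mode with `tf.contrib` available (the module imports `tensorflow.contrib.graph_editor`) and the repo root on the Python path so `tflib` is importable.

```python
# Minimal usage sketch (assumed setup: TF 1.x graph mode, toposort installed,
# repo root on PYTHONPATH). The toy model below is a stand-in, not repo code.
import numpy as np
import tensorflow as tf
import tflib.memory_saving_gradients as memory_saving_gradients

# Tiny stand-in forward pass so the example is self-contained.
x = tf.placeholder(tf.float32, [None, 64])
h = x
for _ in range(4):
    h = tf.layers.dense(h, 64, activation=tf.nn.relu)
loss = tf.reduce_mean(tf.square(h))

params = tf.trainable_variables()

# Explicit call: keep only the selected bottleneck activations and recompute
# the rest during the backward pass ('memory' strategy).
grads = memory_saving_gradients.gradients_memory(loss, params)

# Alternative: monkey-patch tf.gradients process-wide so that e.g.
# tf.train.AdamOptimizer().minimize(loss) picks up the checkpointed version.
# tf.__dict__["gradients"] = memory_saving_gradients.gradients_memory

# Alternative: hand-pick checkpoints while building the model and use the
# 'collection' strategy (gradients_collection) instead of 'memory'.
# tf.add_to_collection('checkpoints', h)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    g = sess.run(grads, {x: np.random.randn(8, 64).astype(np.float32)})
    print([np.asarray(v).shape for v in g])
```

The explicit-call route keeps the stock `tf.gradients` untouched for the rest of the process; the monkey-patch route is the pattern suggested by the upstream gradient-checkpointing project when every optimizer in the program should use the memory-saving gradients.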
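One step of the 'memory' strategy in the module above that may look arbitrary is keeping only every `step`-th sorted bottleneck, with `step = ceil(len(bottleneck_ts) / sqrt(N))`. This follows the sqrt-spacing idea of the Chen et al. 2016 paper cited in the docstring: storing roughly sqrt(N) evenly spaced activations keeps peak activation memory near O(sqrt(N)), while each segment between checkpoints is recomputed at most once, costing about one extra forward pass overall. The snippet below is a standalone numeric illustration of that slicing, with made-up counts; it is not code from the module.

```python
import numpy as np

# Illustrative only: pretend 12 bottleneck tensors were found among N = 16
# candidate tensors; keep roughly sqrt(N) = 4 of them, evenly spaced.
sorted_bottlenecks = list(range(12))   # stand-ins for tensors in topological order
N = 16
step = int(np.ceil(len(sorted_bottlenecks) / np.sqrt(N)))   # ceil(12 / 4) = 3
checkpoints = sorted_bottlenecks[step::step]                # [3, 6, 9]
print(step, checkpoints)   # 3 checkpoints kept, close to sqrt(N)
```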