├── detector
│   ├── __init__.py
│   ├── layers.py
│   ├── c_512_4x4_32.py
│   └── nn.py
├── github
│   ├── teaser.png
│   ├── detection.png
│   ├── lesion_manipulation.png
│   └── real_and_synthesized.png
├── data
│   ├── IDRiD
│   │   ├── mask.png
│   │   ├── to_npy.py
│   │   └── convert.py
│   ├── detector.h5.download.txt
│   ├── download_datasets.txt
│   ├── imagenet-vgg-verydeep-19.mat.download.txt
│   ├── retinal-lesions
│   │   ├── to_npy.py
│   │   └── convert.py
│   └── FGADR
│       ├── to_npy.py
│       └── convert.py
├── requirements.txt
├── .gitignore
├── utils.py
├── vgg.py
├── DMB_build.py
├── DMB_build_test_samples.py
├── Opts.py
├── tfpipe_dump_activation.py
├── DMB_build_FGADR.py
├── Test_reconstruct_DMB.py
├── Test_reconstruct_DMB_FGADR.py
├── README.md
├── Net.py
├── Test_reconstruct_DMB_randomize.py
├── Test_reconstruct_DMB_randomize_FGADR.py
├── Test_reconstruct_DMB_numberadjust.py
├── Test_reconstruct_DMB_numberadjust_FGADR.py
├── StyleFeature.py
├── DMB_fragment.py
├── dataBlocks.py
└── Train.py

/detector/__init__.py:
--------------------------------------------------------------------------------
1 | from .nn import *
2 | from . import layers
--------------------------------------------------------------------------------
/github/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzdyyy/Patho-GAN/HEAD/github/teaser.png
--------------------------------------------------------------------------------
/data/IDRiD/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzdyyy/Patho-GAN/HEAD/data/IDRiD/mask.png
--------------------------------------------------------------------------------
/github/detection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzdyyy/Patho-GAN/HEAD/github/detection.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.8
2 | Keras==2.2
3 | scikit-image
4 | opencv-python
5 | gdown
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | /DMB/
3 | /Model_and_Result/
4 | /Visualization/
5 | /Test/
6 | 
7 | *.h5
8 | *.npy
9 | *.mat
--------------------------------------------------------------------------------
/github/lesion_manipulation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzdyyy/Patho-GAN/HEAD/github/lesion_manipulation.png
--------------------------------------------------------------------------------
/github/real_and_synthesized.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzdyyy/Patho-GAN/HEAD/github/real_and_synthesized.png
--------------------------------------------------------------------------------
/data/detector.h5.download.txt:
--------------------------------------------------------------------------------
1 | Download detector.h5 from https://drive.google.com/open?id=1OI1d3XWM7IyW2igIEq8s-ZyF9vw0vTiw
2 | $ md5sum detector.h5
3 | 09963e46e5ae7ea537fbb070f4251668 detector.h5
--------------------------------------------------------------------------------
/data/download_datasets.txt:
--------------------------------------------------------------------------------
1 | IDRiD: 
https://ieee-dataport.org/open-access/indian-diabetic-retinopathy-image-dataset-idrid
2 | retinal-lesions: https://github.com/WeiQijie/retinal-lesions
3 | FGADR: https://csyizhou.github.io/FGADR/
--------------------------------------------------------------------------------
/data/imagenet-vgg-verydeep-19.mat.download.txt:
--------------------------------------------------------------------------------
1 | $ wget http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat
2 | $ md5sum imagenet-vgg-verydeep-19.mat
3 | 8ee3263992981a1d26e73b3ca028a123 imagenet-vgg-verydeep-19.mat
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import scipy.misc, numpy as np, os, sys
2 | 
3 | 
4 | def save_img(out_path, img):
5 |     img = np.clip(img, 0, 255).astype(np.uint8)
6 |     scipy.misc.imsave(out_path, img)
7 | 
8 | def scale_img(style_path, style_scale):
9 |     scale = float(style_scale)
10 |     o0, o1, o2 = scipy.misc.imread(style_path, mode='RGB').shape
11 |     new_shape = (int(o0 * scale), int(o1 * scale), o2)
12 |     style_target = get_img(style_path, img_size=new_shape)
13 |     return style_target
14 | 
15 | def get_img(src, img_size=False):
16 |     img = scipy.misc.imread(src, mode='RGB')  # misc.imresize(, (256, 256, 3))
17 |     if not (len(img.shape) == 3 and img.shape[2] == 3):
18 |         img = np.dstack((img, img, img))
19 |     if img_size != False:
20 |         img = scipy.misc.imresize(img, img_size)
21 |     return img
22 | 
23 | def exists(p, msg):
24 |     assert os.path.exists(p), msg
25 | 
26 | def list_files(in_path):
27 |     files = []
28 |     for (dirpath, dirnames, filenames) in os.walk(in_path):
29 |         files.extend(filenames)
30 |         break
31 | 
32 |     return files
33 | 
34 | 
--------------------------------------------------------------------------------
/data/IDRiD/to_npy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import glob
4 | import yaml
5 | import os
6 | 
7 | dataset_name='IDRiD'
8 | for mode in ['train', 'test']:
9 |     filelist = sorted(glob.glob(mode+'_512/*.jpg'))
10 | 
11 |     images = np.stack(list(map(cv2.imread, filelist)))
12 |     images = images/255.
13 |     images = images[..., ::-1] # BGR2RGB
14 |     images = images.astype('float32')
15 |     print('images', images.shape, images.dtype)
16 |     np.save('../'+dataset_name+'_'+mode+'_image.npy', images)
17 | 
18 |     mask = cv2.imread('mask.png', cv2.IMREAD_GRAYSCALE)
19 |     mask = mask/255.
20 |     mask = mask.astype('float32')
21 |     mask = np.tile(mask[None, ...], (images.shape[0], 1, 1))
22 |     print('mask', mask.shape, mask.dtype)
23 |     np.save('../'+dataset_name+'_'+mode+'_mask.npy', mask)
24 | 
25 |     gt1 = np.stack(list(map(lambda x: cv2.imread(x.replace('.jpg', '_VS.png'), cv2.IMREAD_GRAYSCALE), filelist)))
26 |     gt1 = gt1/255.
27 | gt1 = gt1.astype('float32') 28 | # gt2 = np.load('../../Visualization/'+dataset_name+'_'+mode+'_dump.npy') 29 | # gt = np.concatenate([gt1[..., None], gt2], 3) 30 | # print('gt1', gt1.shape, gt1.dtype, 'gt2', gt2.shape, gt2.dtype, 'gt', gt.shape, gt.dtype) 31 | np.save('../'+dataset_name+'_'+mode+'_gt.npy', gt1[..., None]) 32 | 33 | # dump file_name list 34 | with open('../'+dataset_name+'_'+mode+'.list', 'w') as f: 35 | yaml.dump(list(map(os.path.basename, filelist)), f) 36 | -------------------------------------------------------------------------------- /data/retinal-lesions/to_npy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import glob 4 | import yaml 5 | import os 6 | 7 | dataset_name='retinal-lesions' 8 | for mode in ['train', 'test']: 9 | filelist = { 10 | 'train': sorted(glob.glob('resized_512/3*.jpg')), # 337 images 11 | 'test': sorted(glob.glob('resized_512/[124-9]*.jpg')) # 1256 images 12 | }[mode] 13 | 14 | images = np.stack(list(map(cv2.imread, filelist))) 15 | images = images/255. 16 | images = images[..., ::-1] # BGR2RGB 17 | images = images.astype('float32') 18 | print('images', images.shape, images.dtype) 19 | np.save('../'+dataset_name+'_'+mode+'_image.npy', images) 20 | 21 | mask = np.stack(list(map(lambda x:cv2.imread(x.replace('.jpg', '_MASK.png'), cv2.IMREAD_GRAYSCALE), filelist))) 22 | mask = mask/255. 23 | mask = mask.astype('float32') 24 | print('mask', mask.shape, mask.dtype) 25 | np.save('../'+dataset_name+'_'+mode+'_mask.npy', mask) 26 | 27 | gt1 = np.stack(list(map(lambda x: cv2.imread(x.replace('.jpg', '_VS.png'), cv2.IMREAD_GRAYSCALE), filelist))) 28 | gt1 = gt1/255. 29 | gt1 = gt1.astype('float32') 30 | # gt2 = np.load('../../Visualization/'+dataset_name+'_'+mode+'_dump.npy') 31 | # gt = np.concatenate([gt1[..., None], gt2], 3) 32 | # print('gt1', gt1.shape, gt1.dtype, 'gt2', gt2.shape, gt2.dtype, 'gt', gt.shape, gt.dtype) 33 | np.save('../'+dataset_name+'_'+mode+'_gt.npy', gt1[..., None]) 34 | 35 | # dump file_name list 36 | with open('../'+dataset_name+'_'+mode+'.list', 'w') as f: 37 | yaml.dump(list(map(os.path.basename, filelist)), f) 38 | -------------------------------------------------------------------------------- /data/FGADR/to_npy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import glob 4 | import yaml 5 | import os 6 | 7 | dataset_name='FGADR' 8 | for mode in ['train', 'test']: 9 | filelist_all = sorted(glob.glob('resized_512/*.jpg')) # 1842 images 10 | myslice = { 11 | 'train': slice(0, 500), # 500 images 12 | 'test': slice(500, None) # 1342 images 13 | }[mode] 14 | filelist = filelist_all[myslice] 15 | 16 | images = np.stack(list(map(cv2.imread, filelist))) 17 | images = images/255. 18 | images = images[..., ::-1] # BGR2RGB 19 | images = images.astype('float32') 20 | print('images', images.shape, images.dtype) 21 | np.save('../'+dataset_name+'_'+mode+'_image.npy', images) 22 | 23 | mask = np.stack(list(map(lambda x:cv2.imread(x.replace('.jpg', '_MASK.png'), cv2.IMREAD_GRAYSCALE), filelist))) 24 | mask = mask/255. 25 | mask = mask.astype('float32') 26 | print('mask', mask.shape, mask.dtype) 27 | np.save('../'+dataset_name+'_'+mode+'_mask.npy', mask) 28 | 29 | gt1 = np.stack(list(map(lambda x: cv2.imread(x.replace('.jpg', '_VS.png'), cv2.IMREAD_GRAYSCALE), filelist))) 30 | gt1 = gt1/255. 
31 | gt1 = gt1.astype('float32') 32 | # gt2 = np.load('../../Visualization/'+dataset_name+'_dump.npy')[myslice, ...] 33 | # gt = np.concatenate([gt1[..., None], gt2], 3) 34 | # print('gt1', gt1.shape, gt1.dtype, 'gt2', gt2.shape, gt2.dtype, 'gt', gt.shape, gt.dtype) 35 | np.save('../'+dataset_name+'_'+mode+'_gt.npy', gt1[..., None]) 36 | 37 | # dump file_name list 38 | with open('../'+dataset_name+'_'+mode+'.list', 'w') as f: 39 | yaml.dump(list(map(os.path.basename, filelist)), f) 40 | -------------------------------------------------------------------------------- /detector/layers.py: -------------------------------------------------------------------------------- 1 | """prepare some layers and their default parameters for configs/*.py""" 2 | 3 | import keras 4 | from keras.layers import * 5 | from keras import backend as K 6 | from keras.regularizers import l1_l2, l2 7 | 8 | leaky_rectify_alpha = 0.01 9 | regular_factor_l1 = 0. 10 | regular_factor_l2 = 5e-4 # weight_decay 11 | 12 | def conv_params(filters, **kwargs): 13 | """default Conv2d arguments""" 14 | args = { 15 | 'filters': filters, 16 | 'kernel_size': (3, 3), 17 | 'padding': 'same', 18 | 'activation': lambda x: keras.activations.relu(x, leaky_rectify_alpha), 19 | 'use_bias': True, 20 | 'kernel_initializer': 'zero', 21 | 'bias_initializer': 'zero', 22 | } 23 | args.update(kwargs) 24 | return args 25 | 26 | 27 | def pool_params(**kwargs): 28 | """default MaxPool2d/RMSPoolLayer arguments""" 29 | args = { 30 | 'pool_size': 3, 31 | 'strides': (2, 2), 32 | } 33 | args.update(kwargs) 34 | return args 35 | 36 | 37 | def dense_params(num_units, **kwargs): 38 | """default dense layer arguments""" 39 | args = { 40 | 'units': num_units, 41 | 'activation': lambda x: keras.activations.relu(x, leaky_rectify_alpha), 42 | 'kernel_initializer': 'zero', 43 | 'bias_initializer': 'zero', 44 | } 45 | args.update(kwargs) 46 | return args 47 | 48 | 49 | class RMSPoolLayer(keras.layers.pooling._Pooling2D): 50 | """Use RMS(Root Mean Squared) as pooling function. 51 | 52 | origin version from https://github.com/benanne/kaggle-ndsb/blob/master/tmp_dnn.py 53 | """ 54 | def __init__(self, *args, **kwargs): 55 | super(RMSPoolLayer, self).__init__(*args, **kwargs) 56 | 57 | def _pooling_function(self, inputs, pool_size, strides, 58 | padding, data_format): 59 | output = K.pool2d(K.square(inputs), pool_size, strides, 60 | padding, data_format, pool_mode='avg') 61 | return K.sqrt(output + K.epsilon()) 62 | -------------------------------------------------------------------------------- /vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-2016 Anish Athalye. Released under GPLv3. 
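# --- Editor's note -----------------------------------------------------------
# A minimal usage sketch of this module (illustration only, not part of the
# original file). It assumes imagenet-vgg-verydeep-19.mat has been downloaded
# as described in data/imagenet-vgg-verydeep-19.mat.download.txt:
#
#   import scipy.io, tensorflow as tf
#   import vgg
#
#   data = scipy.io.loadmat('data/imagenet-vgg-verydeep-19.mat')  # MatConvNet weights
#   image = tf.placeholder(tf.float32, [None, 512, 512, 3])       # RGB, 0..255
#   features = vgg.net(data, vgg.preprocess(image))               # mean-subtracted input
#   relu4_2 = features['relu4_2']                                 # any layer by name
# -----------------------------------------------------------------------------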
2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import scipy.io 6 | import pdb 7 | 8 | MEAN_PIXEL = np.array([ 123.68 , 116.779, 103.939]) 9 | 10 | def net(data, input_image): 11 | with tf.variable_scope('vgg'): 12 | layers = ( 13 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 14 | 15 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 16 | 17 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 18 | 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 19 | 20 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 21 | 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 22 | 23 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 24 | 'relu5_3', 'conv5_4', 'relu5_4' 25 | ) 26 | 27 | #data = scipy.io.loadmat(data_path) 28 | mean = data['normalization'][0][0][0] 29 | mean_pixel = np.mean(mean, axis=(0, 1)) 30 | weights = data['layers'][0] 31 | 32 | net = {} 33 | current = input_image 34 | for i, name in enumerate(layers): 35 | kind = name[:4] 36 | if kind == 'conv': 37 | kernels, bias = weights[i][0][0][0][0] 38 | # matconvnet: weights are [width, height, in_channels, out_channels] 39 | # tensorflow: weights are [height, width, in_channels, out_channels] 40 | kernels = np.transpose(kernels, (1, 0, 2, 3)) 41 | bias = bias.reshape(-1) 42 | current = _conv_layer(current, kernels, bias) 43 | elif kind == 'relu': 44 | current = tf.nn.relu(current) 45 | elif kind == 'pool': 46 | current = _pool_layer(current) 47 | net[name] = current 48 | 49 | assert len(net) == len(layers) 50 | return net 51 | 52 | 53 | def _conv_layer(input, weights, bias): 54 | conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1), 55 | padding='SAME') 56 | return tf.nn.bias_add(conv, bias) 57 | 58 | 59 | def _pool_layer(input): 60 | return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), 61 | padding='SAME') 62 | 63 | 64 | def preprocess(image): 65 | return image - MEAN_PIXEL 66 | 67 | 68 | def unprocess(image): 69 | return image + MEAN_PIXEL 70 | -------------------------------------------------------------------------------- /detector/c_512_4x4_32.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # from config import Config 4 | # from data import BALANCE_WEIGHTS 5 | from .layers import * 6 | 7 | cnf = { 8 | 'name': __name__.split('.')[-1], 9 | 'w': 448, 10 | 'h': 448, 11 | # 'train_dir': 'data/train_medium', 12 | # 'test_dir': 'data/test_medium', 13 | # 'batch_size_train': 48, 14 | # 'batch_size_test': 8, 15 | # 'balance_weights': np.array(BALANCE_WEIGHTS), 16 | # 'balance_ratio': 0.975, 17 | # 'final_balance_weights': np.array([1, 2, 2, 2, 2], dtype=float), 18 | # 'aug_params': { 19 | # 'zoom_range': (1 / 1.15, 1.15), 20 | # 'rotation_range': (0, 360), 21 | # 'shear_range': (0, 0), 22 | # 'translation_range': (-40, 40), 23 | # 'do_flip': True, 24 | # 'allow_stretch': True, 25 | # }, 26 | # 'sigma': 0.25, 27 | # 'schedule': { 28 | # 0: 0.003, 29 | # 150: 0.0003, 30 | # 220: 0.00003, 31 | # 251: 'stop', 32 | # }, 33 | } 34 | 35 | def cp(filters, **kwargs): 36 | args = { 37 | 'filters': filters, 38 | 'kernel_size': (4, 4), 39 | } 40 | args.update(kwargs) 41 | return conv_params(**args) 42 | 43 | n = 32 44 | 45 | layers = [ 46 | (InputLayer, {'input_shape': (cnf['h'], cnf['w'], 3)}), 47 | (Conv2D, cp(n, strides=(2, 2))), 48 | (ZeroPadding2D, {'padding': 2}), (Conv2D, cp(n, 49 | kernel_regularizer=l1_l2(regular_factor_l1, regular_factor_l2), 50 | bias_regularizer=l1_l2(regular_factor_l1, regular_factor_l2))), 51 | 
(MaxPool2D, pool_params()), 52 | (Conv2D, cp(2 * n, strides=(2, 2))), 53 | (ZeroPadding2D, {'padding': 2}), (Conv2D, cp(2 * n)), 54 | (Conv2D, cp(2 * n)), 55 | (MaxPool2D, pool_params()), 56 | (ZeroPadding2D, {'padding': 2}), (Conv2D, cp(4 * n)), 57 | (Conv2D, cp(4 * n)), 58 | (ZeroPadding2D, {'padding': 2}), (Conv2D, cp(4 * n)), 59 | (MaxPool2D, pool_params()), 60 | (ZeroPadding2D, {'padding': 2}), (Conv2D, cp(8 * n)), 61 | (Conv2D, cp(8 * n)), 62 | (ZeroPadding2D, {'padding': 2}), (Conv2D, cp(8 * n)), 63 | (MaxPool2D, pool_params()), 64 | (Conv2D, cp(16 * n)), 65 | (RMSPoolLayer, pool_params()), 66 | (Dropout, {'rate': 0.5}), 67 | (Flatten, {}), (Dense, dense_params(1024)), 68 | (Reshape, {'target_shape': (-1, 1)}), (MaxPooling1D, {'pool_size': 2}), 69 | (Dropout, {'rate': 0.5}), 70 | (Flatten, {}), (Dense, dense_params(1024)), 71 | (Reshape, {'target_shape': (-1, 1)}), (MaxPooling1D, {'pool_size': 2}), 72 | (Flatten, {}), (Dense, dense_params(1, activation='linear')), 73 | ] 74 | 75 | # config = Config(layers=layers, cnf=cnf) 76 | -------------------------------------------------------------------------------- /DMB_build.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('dataset_name') 5 | parser.add_argument('--gpus') 6 | args = parser.parse_args() 7 | print(args) 8 | 9 | 10 | import os 11 | if args.gpus: 12 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 13 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 14 | print('CUDA_VISIBLE_DEVICES:', os.environ["CUDA_VISIBLE_DEVICES"]) 15 | 16 | import keras.backend as K 17 | import tensorflow as tf 18 | import numpy as np 19 | import StyleFeature 20 | import cv2 21 | 22 | from DMB_fragment import extract_descriptors 23 | import pickle 24 | 25 | # =============================== path set =============================================== # 26 | dataset = args.dataset_name 27 | 28 | img_channel = 3 29 | img_size = 512 30 | img_x = 512 31 | img_y = 512 32 | img = tf.placeholder(shape=[None, img_size, img_size, img_channel], dtype=tf.float32) 33 | mask = tf.placeholder(shape=[None, img_size, img_size, 1], dtype=tf.float32) 34 | 35 | img_detector = StyleFeature.get_style_model(img, mask, with_feature_mask_from=('dense_2', [])) 36 | act_input = { 37 | size: tf.image.resize_images((img_detector.get_layer(layer_name).related_projection.output - mean)/std, [size,size]) 38 | for layer_name, size, mean, std in zip(StyleFeature.STYLE_LAYERS, 39 | StyleFeature.STYLE_LAYERS_SIZE, 40 | StyleFeature.STYLE_LAYERS_MEAN, 41 | StyleFeature.STYLE_LAYERS_STD) 42 | } 43 | 44 | img_sample = np.load('data/{}_test_image.npy'.format(dataset)) # [n, h, w, 3] original fundus 45 | mask_sample = np.load('data/{}_test_mask.npy'.format(dataset)) # [n, h, w] FOV mask 46 | activation_maps = np.load('Visualization/{}_test.npy'.format(dataset)) # [n, h, w, 3] AMaps generated with tfpipe_dump_activation.py 47 | segmentation_labels = np.load('data/{}_test_mask.npy'.format(dataset))[..., 1:] # Deprecated 48 | 49 | img_sample = (np.reshape(img_sample, [-1, img_x, img_y, img_channel]) - 0.5) * 2.0 50 | mask_sample = (np.reshape(mask_sample, [-1, img_x, img_y, 1]) - 0.5) * 2.0 51 | activation_maps = (activation_maps - 0.5) * 2.0 52 | 53 | 54 | # extract all descriptors 55 | descriptors = [] 56 | for i, (img_array, mask_array, amap_array, seg_label) in enumerate(zip(img_sample, mask_sample, activation_maps, segmentation_labels)): 57 | print('img:', i) 58 | 59 | 
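    # (Editor's note, added for clarity) act_input, built above, maps an output
    # size to a resized, whitened back-projection of the detector's activations.
    # The run() call below evaluates those tensors for one image, giving a dict
    # of numpy arrays keyed by size (e.g. 256 and 64, per
    # StyleFeature.STYLE_LAYERS_SIZE), from which extract_descriptors() crops
    # the per-lesion pathological descriptors.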
intermed_amap = K.get_session().run(act_input, feed_dict={img: [img_array], mask: [mask_array]}) 60 | 61 | descriptors.extend(extract_descriptors(intermed_amap, amap_array, seg_label, dataset, i)) 62 | 63 | if not os.path.exists('DMB'): 64 | os.makedirs('DMB') 65 | 66 | # resort fragments by category 67 | fragments_by_category = [[frag for frag in descriptors if frag[-2] == i] for i in range(len(img_sample))] 68 | with open('DMB/{}.by_img'.format(dataset), 'wb') as file: 69 | pickle.dump(fragments_by_category, file) 70 | 71 | # resort fragments by category 72 | fragments_by_category = [[frag for frag in descriptors if frag[-1] == i] for i in range(-1, segmentation_labels.shape[-1])] 73 | with open('DMB/{}.by_cat'.format(dataset), 'wb') as file: 74 | pickle.dump(fragments_by_category, file) -------------------------------------------------------------------------------- /DMB_build_test_samples.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('dataset_name') 5 | parser.add_argument('test1') 6 | parser.add_argument('test2') 7 | parser.add_argument('test3') 8 | parser.add_argument('test4') 9 | parser.add_argument('--gpus') 10 | args = parser.parse_args() 11 | print(args) 12 | 13 | 14 | import os 15 | if args.gpus: 16 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 17 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 18 | print('CUDA_VISIBLE_DEVICES:', os.environ["CUDA_VISIBLE_DEVICES"]) 19 | 20 | import keras.backend as K 21 | import tensorflow as tf 22 | import numpy as np 23 | import StyleFeature 24 | import cv2 25 | 26 | from DMB_fragment import extract_descriptors, rebuild_AMaps_by_img 27 | import pickle 28 | import yaml 29 | 30 | # =============================== path set =============================================== # 31 | dataset = args.dataset_name 32 | 33 | img_channel = 3 34 | img_size = 512 35 | img_x = 512 36 | img_y = 512 37 | img = tf.placeholder(shape=[None, img_size, img_size, img_channel], dtype=tf.float32) 38 | mask = tf.placeholder(shape=[None, img_size, img_size, 1], dtype=tf.float32) 39 | 40 | img_detector = StyleFeature.get_style_model(img, mask, with_feature_mask_from=('dense_2', [])) 41 | act_input = { 42 | size: tf.image.resize_images((img_detector.get_layer(layer_name).related_projection.output - mean)/std, [size,size]) 43 | for layer_name, size, mean, std in zip(StyleFeature.STYLE_LAYERS, 44 | StyleFeature.STYLE_LAYERS_SIZE, 45 | StyleFeature.STYLE_LAYERS_MEAN, 46 | StyleFeature.STYLE_LAYERS_STD) 47 | } 48 | 49 | with open('data/'+dataset+'_test.list', 'r') as f: 50 | file_list = yaml.safe_load(f) 51 | select = [file_list.index(args.test1), file_list.index(args.test2), file_list.index(args.test3), file_list.index(args.test4), ] 52 | 53 | img_sample = np.load('data/{}_test_image.npy'.format(dataset))[select, ...] # [4, h, w, 3] original fundus 54 | mask_sample = np.load('data/{}_test_mask.npy'.format(dataset))[select, ...] # [4, h, w] FOV mask 55 | activation_maps = np.load('Visualization/{}_test.npy'.format(dataset))[select, ...] 
# [4, h, w, 3] AMaps generated with tfpipe_dump_activation.py 56 | segmentation_labels = np.load('data/{}_test_mask.npy'.format(dataset))[select, ..., 1:] # Labels - fake, not used 57 | 58 | img_sample = (np.reshape(img_sample, [-1, img_x, img_y, img_channel]) - 0.5) * 2.0 59 | mask_sample = (np.reshape(mask_sample, [-1, img_x, img_y, 1]) - 0.5) * 2.0 60 | activation_maps = (activation_maps - 0.5) * 2.0 61 | 62 | 63 | # extract all descriptors 64 | descriptors = [] 65 | for i, (img_array, mask_array, amap_array, seg_label) in enumerate(zip(img_sample, mask_sample, activation_maps, segmentation_labels)): 66 | print('img:', i) 67 | 68 | intermed_amap = K.get_session().run(act_input, feed_dict={img: [img_array], mask: [mask_array]}) 69 | 70 | descriptors.extend(extract_descriptors(intermed_amap, amap_array, seg_label, dataset, i)) 71 | 72 | if not os.path.exists('DMB'): 73 | os.makedirs('DMB') 74 | 75 | # resort fragments by img_id 76 | fragments_by_imgid = [[frag for frag in descriptors if frag[-2] == i] for i in range(len(img_sample))] 77 | test_amaps_reconstruction = [rebuild_AMaps_by_img(i, fragments_by_imgid) for i in range(len(img_sample))] 78 | print(test_amaps_reconstruction) 79 | with open('DMB/{}.test_amaps_reconstruction'.format(dataset), 'wb') as file: 80 | pickle.dump(test_amaps_reconstruction, file) 81 | 82 | -------------------------------------------------------------------------------- /Opts.py: -------------------------------------------------------------------------------- 1 | # author is He Zhao 2 | # The time to create is 8:49 PM, 28/11/16 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | import scipy.io as sio 7 | import scipy.misc 8 | import scipy 9 | import tensorflow.contrib.slim as slim 10 | 11 | 12 | def lrelu(x, leak=0.2, name='lrelu'): 13 | with tf.variable_scope(name): 14 | f1 = 0.5 * (1 + leak) 15 | f2 = 0.5 * (1 - leak) 16 | return f1 * x + f2 * abs(x) 17 | 18 | def lrelu1(x, leak=0.2, name="lrelu"): 19 | return tf.maximum(x, leak*x) 20 | 21 | 22 | def save_images(images, size, image_path): 23 | return imsave(inverse_transform(images), size, image_path) 24 | 25 | 26 | def merge(images, size): 27 | h, w = images.shape[1], images.shape[2] 28 | img = np.zeros((h * size[0], w * size[1], 3)) 29 | for idx, image in enumerate(images): 30 | i = idx % size[1] 31 | j = idx // size[1] 32 | img[j*h:j*h+h, i*w:i*w+w, :] = image 33 | 34 | return img 35 | 36 | 37 | def imsave(images, size, path): 38 | return scipy.misc.imsave(path, merge(images, size)) 39 | 40 | 41 | def inverse_transform(images): 42 | return (images+1.)/2. 
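# (Editor's note) Example of how merge()/save_images() tile a batch into an
# image grid -- hypothetical values, for illustration only. Images are expected
# in [-1, 1] and are mapped to [0, 1] by inverse_transform() before saving:
#
#   batch = np.random.uniform(-1, 1, size=(4, 512, 512, 3)).astype(np.float32)
#   save_images(batch, [2, 2], 'grid.jpg')   # 2 rows x 2 columns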
43 | 44 | def matTonpy(): 45 | 46 | # img = sio.loadmat('sample.mat')['imgSample'] 47 | # gt = sio.loadmat('sample.mat')['gtSample'] 48 | 49 | img = sio.loadmat('test.mat')['imgAllTest'] 50 | gt = sio.loadmat('test.mat')['gtAllTest'] 51 | 52 | # with open('img_sample.npy', 'wb') as fout: 53 | # np.save(fout, img) 54 | # with open('gt_sample.npy', 'wb') as fout: 55 | # np.save(fout, gt) 56 | return img, gt 57 | 58 | 59 | def TestImgForTest(dataPath): 60 | 61 | img = sio.loadmat(dataPath)['imgAllTest'] 62 | gt = sio.loadmat(dataPath)['gtAllTest'] 63 | 64 | return img, gt 65 | 66 | 67 | def TrainImgForTest(dataPath): 68 | 69 | img = sio.loadmat(dataPath)['imgAllTrain'] 70 | gt = sio.loadmat(dataPath)['gtAllTrain'] 71 | 72 | return img, gt 73 | 74 | def resUnit(input_layer, i, out_size): 75 | with tf.variable_scope("g_res_unit" + str(i)): 76 | net = slim.conv2d(inputs=input_layer, normalizer_fn=slim.batch_norm, activation_fn=lrelu, 77 | num_outputs=out_size, kernel_size=[4, 4], stride=2, padding='SAME') 78 | 79 | net = slim.conv2d(inputs=net, normalizer_fn=slim.batch_norm, activation_fn=lrelu, 80 | num_outputs=out_size, kernel_size=[4, 4], stride=1, padding='SAME') 81 | 82 | res = slim.conv2d(inputs=input_layer, normalizer_fn=slim.batch_norm, activation_fn=lrelu, 83 | num_outputs=out_size, kernel_size=[1, 1], stride=2, padding='SAME') 84 | 85 | output = net + res 86 | return output 87 | 88 | 89 | def resUnit_up(input_layer, i, out_size): 90 | with tf.variable_scope("g_res_unit_up" + str(i)): 91 | net = slim.conv2d_transpose(inputs=input_layer, normalizer_fn=slim.batch_norm, activation_fn=lrelu, 92 | num_outputs=out_size, kernel_size=[4, 4], stride=1, padding='SAME') 93 | 94 | net = slim.conv2d_transpose(inputs=net, normalizer_fn=slim.batch_norm, activation_fn=lrelu, 95 | num_outputs=out_size, kernel_size=[4, 4], stride=2, padding='SAME') 96 | 97 | res = slim.conv2d_transpose(inputs=input_layer, normalizer_fn=slim.batch_norm, activation_fn=lrelu, 98 | num_outputs=out_size, kernel_size=[1, 1], stride=2, padding='SAME') 99 | 100 | output = net + res 101 | return output 102 | -------------------------------------------------------------------------------- /tfpipe_dump_activation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('files', help='a path with wildcard, remember to quote it in shell.') 5 | parser.add_argument('--gpus') 6 | parser.add_argument('--dump_to') 7 | parser.add_argument('--visualize', action="store_true") 8 | args = parser.parse_args() 9 | print(args) 10 | 11 | import cv2 12 | import os 13 | if args.gpus: 14 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 15 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 16 | print('CUDA_VISIBLE_DEVICES:', os.environ["CUDA_VISIBLE_DEVICES"]) 17 | 18 | import numpy as np 19 | 20 | import tensorflow as tf 21 | import keras 22 | import keras.backend as K 23 | import detector 24 | import glob 25 | 26 | project_name = args.dump_to # for example 'IDRiD_train_activation_dump' 27 | project_path = os.path.join('Visualization', project_name) 28 | if not os.path.exists('Visualization'): 29 | os.makedirs('Visualization') 30 | 31 | input_list = sorted(glob.glob(args.files)) 32 | print('total files:', len(input_list)) 33 | 34 | batch_size = 100 35 | n_fill = -len(input_list) % batch_size # fill n_fill elements to the last batch 36 | input_list += [input_list[-1]] * n_fill 37 | 38 | ds_fnames = 
tf.data.Dataset.from_tensor_slices(input_list) 39 | # ds_images = ds_fnames.map(lambda fname: tf.image.decode_png(tf.read_file(fname), channels=3), num_parallel_calls=20) 40 | ds_images = ds_fnames.map(lambda fname: tf.py_func(lambda fname: cv2.imread(fname.decode()), [fname], [tf.uint8], stateful=False), num_parallel_calls=24) 41 | ds_images = ds_images.batch(batch_size) 42 | ds_images = ds_images.prefetch(buffer_size=1) 43 | images = ds_images.make_one_shot_iterator().get_next() 44 | images = images[0] 45 | images = images[..., ::-1] # bgr to rgb 46 | 47 | images = tf.cast(images, tf.float32)/255.0 48 | images = (images - 0.5) * 2.0 # rescale to [-1, 1] 49 | images.set_shape([None, 512, 512, 3]) 50 | images = tf.image.resize_images(images, [448, 448]) 51 | 52 | 53 | model = detector.get_layers_model(images, ['my_input'], 'scope_name', 54 | with_projection_output_from=('dense_2', [])) 55 | model.summary() 56 | projections = model.get_layer('my_input').related_projection.output 57 | p_mean = tf.reduce_mean(projections, axis=[1, 2, 3], keepdims=True) 58 | p_range = tf.reduce_max(projections, axis=[1, 2, 3], keepdims=True) - tf.reduce_min(projections, axis=[1, 2, 3], keepdims=True) 59 | projections = (projections - p_mean) / 0.1 / 255 * 255. 60 | projections = projections+0.5 61 | projections = tf.clip_by_value(projections, 0., 1.) 62 | projections = tf.image.resize_images(projections, [512, 512]) 63 | 64 | sess = K.get_session() 65 | 66 | projection_list = [] 67 | try: 68 | i = 0 69 | while True: 70 | print(i*batch_size) 71 | projection_list.append( 72 | sess.run(projections, {K.learning_phase(): 0}) # K.learning_phase() 0:testing 1:training 73 | ) 74 | i += 1 75 | except tf.errors.OutOfRangeError: 76 | print('done.') 77 | pass 78 | 79 | projection_npy = np.concatenate(projection_list, axis=0) 80 | projection_npy = projection_npy[:len(projection_npy)-n_fill, ...] 
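# (Editor's note) The slice above undoes the padding applied earlier: the file
# list was padded with n_fill copies of its last entry so that every batch is
# full, so the duplicated trailing projections are dropped before saving.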
81 | print(len(projection_npy)) 82 | np.save(project_path+'.npy', projection_npy) 83 | print('Result saved to', project_path+'.npy') 84 | 85 | if args.visualize: 86 | if not os.path.exists(project_path): 87 | os.mkdir(project_path) 88 | 89 | for img, fname in zip(projection_npy, input_list): 90 | cv2.imwrite( 91 | os.path.join(project_path, os.path.basename(fname).replace('.jpg', '_A0.jpg')), 92 | img[..., ::-1]*255 # RGB2BGR 93 | ) 94 | print('visualizing', os.path.join(project_path, os.path.basename(fname)), '...') -------------------------------------------------------------------------------- /DMB_build_FGADR.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('dataset_name') 5 | parser.add_argument('--gpus') 6 | args = parser.parse_args() 7 | print(args) 8 | 9 | 10 | import os 11 | if args.gpus: 12 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 13 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 14 | print('CUDA_VISIBLE_DEVICES:', os.environ["CUDA_VISIBLE_DEVICES"]) 15 | 16 | import keras.backend as K 17 | import tensorflow as tf 18 | import numpy as np 19 | import StyleFeature 20 | import cv2 21 | 22 | from DMB_fragment import extract_descriptors 23 | import pickle 24 | 25 | # =============================== path set =============================================== # 26 | dataset = args.dataset_name 27 | 28 | img_channel = 3 29 | img_size = 512 30 | img_x = 512 31 | img_y = 512 32 | img = tf.placeholder(shape=[None, img_size, img_size, img_channel], dtype=tf.float32) 33 | mask = tf.placeholder(shape=[None, img_size, img_size, 1], dtype=tf.float32) 34 | 35 | img_detector = StyleFeature.get_style_model(img, mask, with_feature_mask_from=('dense_2', [])) 36 | act_input = { 37 | size: tf.image.resize_images((img_detector.get_layer(layer_name).related_projection.output - mean)/std, [size,size]) 38 | for layer_name, size, mean, std in zip(StyleFeature.STYLE_LAYERS, 39 | StyleFeature.STYLE_LAYERS_SIZE, 40 | StyleFeature.STYLE_LAYERS_MEAN, 41 | StyleFeature.STYLE_LAYERS_STD) 42 | } 43 | 44 | img_sample = np.load('data/{}_test_image.npy'.format(dataset)) # [n, h, w, 3] original fundus 45 | mask_sample = np.load('data/{}_test_mask.npy'.format(dataset)) # [n, h, w] FOV mask 46 | activation_maps = np.load('Visualization/{}_test.npy'.format(dataset)) # [n, h, w, 3] AMaps generated with tfpipe_dump_activation.py 47 | segmentation_labels = np.load('data/{}_test_mask.npy'.format(dataset))[..., 1:] # Deprecated 48 | 49 | img_sample = (np.reshape(img_sample, [-1, img_x, img_y, img_channel]) - 0.5) * 2.0 50 | mask_sample = (np.reshape(mask_sample, [-1, img_x, img_y, 1]) - 0.5) * 2.0 51 | activation_maps = (activation_maps - 0.5) * 2.0 52 | 53 | per_part = 250 54 | for part, part_start in enumerate(range(0, len(img_sample), per_part)): 55 | 56 | print('part', part, 'from', part_start, 'to', part_start+per_part) 57 | 58 | # extract all descriptors 59 | descriptors = [] 60 | for i, (img_array, mask_array, amap_array, seg_label) in enumerate(zip(img_sample[part_start:part_start+per_part, ...], 61 | mask_sample[part_start:part_start+per_part, ...], 62 | activation_maps[part_start:part_start+per_part, ...], 63 | segmentation_labels[part_start:part_start+per_part, ...])): 64 | print('part', part, 'img:', i) 65 | 66 | intermed_amap = K.get_session().run(act_input, feed_dict={img: [img_array], mask: [mask_array]}) 67 | 68 | descriptors.extend(extract_descriptors(intermed_amap, amap_array, 
                                              seg_label, dataset, i))
69 | 
70 |     if not os.path.exists('DMB'):
71 |         os.makedirs('DMB')
72 | 
73 |     # resort fragments by image
74 |     fragments_by_category = [[frag for frag in descriptors if frag[-2] == i] for i in range(len(img_sample))]
75 |     with open('DMB/{}.by_img.{}'.format(dataset, part), 'wb') as file:
76 |         pickle.dump(fragments_by_category, file)
77 |     del fragments_by_category
78 | 
79 |     # resort fragments by category
80 |     fragments_by_category = [[frag for frag in descriptors if frag[-1] == i] for i in range(-1, segmentation_labels.shape[-1])]
81 |     with open('DMB/{}.by_cat.{}'.format(dataset, part), 'wb') as file:
82 |         pickle.dump(fragments_by_category, file)
83 |     del fragments_by_category
84 | 
85 |     del descriptors
--------------------------------------------------------------------------------
/detector/nn.py:
--------------------------------------------------------------------------------
1 | """create CNN model and train it"""
2 | 
3 | import tensorflow as tf
4 | from . import c_512_4x4_32 as config
5 | from .layers import *
6 | 
7 | _detector = None
8 | models = {}
9 | 
10 | 
11 | def get_detector():
12 |     """return the detector singleton model, which is used to build the layers model"""
13 |     global _detector
14 |     if _detector is not None:
15 |         return _detector
16 |     with tf.name_scope('detector_prototype'):
17 |         _detector = keras.Sequential(name=config.cnf['name'])
18 |         for layer, kwargs in config.layers:
19 |             if 'activation' not in kwargs:
20 |                 _detector.add(layer(**kwargs))
21 |                 if layer is InputLayer:
22 |                     _detector.add(Lambda(lambda x:x, name='my_input'))
23 |             else:
24 |                 del kwargs['activation']
25 |                 new_layer = layer(**kwargs)
26 |                 new_layer.related_activation = LeakyReLU(leaky_rectify_alpha, name=new_layer.name+'_act')
27 |                 _detector.add(new_layer)
28 |                 _detector.add(new_layer.related_activation)
29 |     return _detector
30 | 
31 | 
32 | def get_layers_model(input, layer_names, scope_name='detector', load_weights_from='data/detector.h5', with_projection_output_from=None):
33 |     """return a model that outputs the features of the given layer_names
34 |     argument:
35 |         with_projection_output_from: (layer_name, channels); leave channels an empty list to select all channels
36 |     """
37 |     proto = get_detector()
38 |     with tf.name_scope(scope_name):
39 |         model = keras.models.clone_model(proto, input)  # the model has its own weights and needs load_weights
40 |         outputs = [
41 |             model.get_layer(name).related_activation.output
42 |             if hasattr(model.get_layer(name), 'related_activation')
43 |             else model.get_layer(name).get_output_at(-1)
44 |             for name in layer_names
45 |         ]
46 |         layers_model = keras.Model(input, outputs)
47 |         models[scope_name] = layers_model
48 | 
49 |     if load_weights_from:
50 |         model.load_weights(load_weights_from, by_name=True)
51 | 
52 |     layers_model.original_layers = model.layers
53 |     if with_projection_output_from:
54 |         layername, channels = with_projection_output_from
55 |         _add_projection_network(layers_model, model.get_layer(layername), channels)
56 |         layers_model.output_projection = [model.get_layer(name).related_projection.output for name in layer_names if hasattr(model.get_layer(name), 'related_projection')]
57 | 
58 |     return layers_model
59 | 
60 | 
61 | def _add_projection_network(model, from_layer, channels):
62 |     """add a projection to the model output, as originally proposed for visualization in the paper
63 |     'Visualizing and Understanding Convolutional Networks'
64 | 
65 |     note that from_layer accepts any layer except [LeakyReLU, Dropout]
66 |     """
67 |     # get all layers to process on
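    # (Editor's note) How the projection works: for each layer between
    # from_layer and the input, proj_func() (defined below) calls
    #   tf.gradients(layer.output, layer.input, grad_ys=input)
    # i.e. a vector-Jacobian product that backpropagates the (ReLU-rectified)
    # activation through the layer. For a Conv2D this is the transposed
    # convolution used by the Zeiler & Fergus deconvnet visualization, so
    # chaining these Lambdas maps a high-level activation back to input-pixel
    # space.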
68 | layers = [layer for layer in model.original_layers if not isinstance(layer, (LeakyReLU, Dropout))] 69 | 70 | def proj_func(input, layer: Layer): 71 | if isinstance(layer, (Dense, Conv2D)): 72 | input = tf.nn.relu(input) 73 | # input = input - layer.bias # biased? 74 | output = tf.gradients(layer.output, layer.input, grad_ys=input) 75 | return output 76 | 77 | from_idx = layers.index(from_layer) 78 | to_idx = layers.index(model.get_layer('my_input')) 79 | x = from_layer.output 80 | if channels is not None and channels != []: 81 | new_layer = Lambda(lambda x: tf.gather(x, channels, axis=-1), name='gather_channel') 82 | x = new_layer(x) 83 | from_idx += 1 84 | layers[from_idx] = new_layer 85 | 86 | for idx in range(from_idx, to_idx, -1): 87 | new_layer = Lambda(proj_func, arguments={'layer': layers[idx]}, name=layers[idx-1].name+'_proj') 88 | x = new_layer(x) 89 | layers[idx-1].related_projection = new_layer 90 | 91 | 92 | def preprocess(img): 93 | return tf.image.resize_images(img, [448, 448]) 94 | -------------------------------------------------------------------------------- /Test_reconstruct_DMB.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('dataset_name') 5 | parser.add_argument('--gpus') 6 | args = parser.parse_args() 7 | print(args) 8 | 9 | import os 10 | if args.gpus: 11 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 12 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 13 | print('CUDA_VISIBLE_DEVICES:', os.environ["CUDA_VISIBLE_DEVICES"]) 14 | 15 | 16 | import keras.backend as K 17 | import tensorflow as tf 18 | import Net 19 | import numpy as np 20 | import StyleFeature 21 | import scipy.io as sio 22 | import cv2 23 | from DMB_fragment import rebuild_AMaps_by_img 24 | import pickle 25 | import random 26 | import yaml 27 | 28 | from Opts import save_images, matTonpy 29 | # =============================== path set =============================================== # 30 | out_dir = args.dataset_name + '_Reconstruct' 31 | load_model = args.dataset_name 32 | db_dataset = args.dataset_name 33 | real_img_dataset = args.dataset_name # 'DRIVE' 34 | real_img_test_dataset = args.dataset_name # 'DRIVE' 35 | 36 | result_dir = 'Test' + '/' + out_dir + '' 37 | model_directory = 'Model_and_Result' + '/' + load_model + '/models' # Directory to restore trained model from. 38 | 39 | if tf.gfile.Exists(result_dir): 40 | print('Result dir exists! 
Press Enter to OVERRIDE...', end='') 41 | input() 42 | tf.gfile.DeleteRecursively(result_dir) 43 | if not os.path.exists(result_dir): 44 | os.makedirs(result_dir) 45 | 46 | os.system('cp {} {}'.format(__file__, result_dir)) 47 | 48 | # ============================== parameters set ========================================== # 49 | max_epoch = 1 50 | 51 | 52 | img_channel = 3 53 | img_size = 512 54 | img_x = 512 55 | img_y = 512 56 | gt_channel = 1 57 | 58 | z_size = 400 59 | 60 | 61 | 62 | # =============================== model and data definition ================================ # 63 | generator = Net.generator 64 | 65 | tf.reset_default_graph() 66 | 67 | gt = tf.placeholder(shape=[None, img_size, img_size, gt_channel], dtype=tf.float32) 68 | img = tf.placeholder(shape=[None, img_size, img_size, img_channel], dtype=tf.float32) 69 | mask = tf.placeholder(shape=[None, img_size, img_size, 1], dtype=tf.float32) 70 | z = tf.placeholder(shape=[None, z_size], dtype=tf.float32) 71 | 72 | # gt_mask = tf.concat([gt, mask], 3) 73 | 74 | img_detector = StyleFeature.get_style_model(img, mask, with_feature_mask_from=('dense_2', [])) 75 | act_input = { 76 | size: tf.image.resize_images((img_detector.get_layer(layer_name).related_projection.output - mean)/std, [size,size]) 77 | for layer_name, size, mean, std in zip(StyleFeature.STYLE_LAYERS, 78 | StyleFeature.STYLE_LAYERS_SIZE, 79 | StyleFeature.STYLE_LAYERS_MEAN, 80 | StyleFeature.STYLE_LAYERS_STD) 81 | } 82 | 83 | syn = generator(gt, act_input, z) 84 | 85 | 86 | # =============================== init ============================================= # 87 | t_vars = tf.trainable_variables() 88 | g_vars = list(set(t_vars) & set(tf.global_variables('generator'))) 89 | 90 | init = tf.variables_initializer(g_vars) 91 | sess = K.get_session() 92 | saver = tf.train.Saver(g_vars, max_to_keep=None) 93 | 94 | sess.run(init) 95 | #writer = tf.train.SummaryWriter(model_directory, sess.graph) 96 | 97 | # ==================================== restore weights ================================ # 98 | ckpt = tf.train.get_checkpoint_state(model_directory) 99 | restore_path = ckpt.model_checkpoint_path.replace('-9000', '-9000') 100 | print(restore_path) 101 | saver.restore(sess, restore_path) 102 | 103 | 104 | 105 | # ==================================== start training ===================================== # 106 | 107 | if real_img_test_dataset in ['FGADR', 'retinal-lesions', 'IDRiD']: 108 | # img_sample = np.load('data/{}_test_image.npy'.format(real_img_test_dataset)) 109 | gt_sample = np.load('data/{}_test_gt.npy'.format(real_img_test_dataset))[..., [0]] 110 | mask_sample = np.load('data/{}_test_mask.npy'.format(real_img_test_dataset)) 111 | elif real_img_test_dataset == 'DRIVE': 112 | img_sample, gt_sample, mask_sample = Net.matTonpy_35() 113 | 114 | 115 | # img_sample = (np.reshape(img_sample, [-1, img_x, img_y, img_channel]) - 0.5) * 2.0 116 | gt_sample = (np.reshape(gt_sample, [-1, img_x, img_y, gt_channel]) - 0.5) * 2.0 117 | mask_sample = (np.reshape(mask_sample, [-1, img_x, img_y, 1]) - 0.5) * 2.0 118 | 119 | with open('DMB/{}.by_img'.format(db_dataset), 'rb') as file: 120 | fragments_DB = pickle.load(file) 121 | 122 | with open('data/'+real_img_test_dataset+'_test.list', 'r') as f: 123 | fname_list = yaml.safe_load(f) 124 | 125 | # amaps = rebuild_AMaps_by_img(0, fragments_DB) 126 | 127 | # generate images with fragmentDB 128 | for epoch in range(max_epoch): 129 | print('epoch:', epoch) 130 | batchNum = 1 131 | 132 | for i, (gt_array, mask_array) in 
enumerate(zip(gt_sample, mask_sample)): 133 | print('img:', batchNum) 134 | 135 | zs = np.random.normal(0, 0.001, size=[1, z_size]).astype(np.float32) 136 | 137 | amaps = rebuild_AMaps_by_img(i, fragments_DB) 138 | syn_array = sess.run(syn, feed_dict={gt: [gt_array], z:zs, mask: [mask_array], 139 | act_input[256]: [amaps[256]], 140 | act_input[64]: [amaps[64]]}) 141 | syn_array = (syn_array + 1) / 2 142 | syn_sample_m = syn_array * ((mask_array + 1) / 2) 143 | 144 | syn_sample_m = np.reshape(syn_sample_m, [img_x, img_y, img_channel]) 145 | save_images(np.reshape(syn_sample_m, [1, img_x, img_y, img_channel]), 146 | [1, 1], 147 | result_dir + '/{}_{}.jpg'.format(fname_list[i], epoch)) 148 | 149 | batchNum += 1 150 | 151 | sess.close() 152 | -------------------------------------------------------------------------------- /Test_reconstruct_DMB_FGADR.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('dataset_name') 5 | parser.add_argument('--gpus') 6 | args = parser.parse_args() 7 | print(args) 8 | 9 | import os 10 | if args.gpus: 11 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 12 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 13 | print('CUDA_VISIBLE_DEVICES:', os.environ["CUDA_VISIBLE_DEVICES"]) 14 | 15 | 16 | import keras.backend as K 17 | import tensorflow as tf 18 | import Net 19 | import numpy as np 20 | import StyleFeature 21 | import scipy.io as sio 22 | import cv2 23 | from DMB_fragment import rebuild_AMaps_by_img 24 | import pickle 25 | import random 26 | import yaml 27 | 28 | from Opts import save_images, matTonpy 29 | # =============================== path set =============================================== # 30 | out_dir = args.dataset_name + '_Reconstruct' 31 | load_model = args.dataset_name 32 | db_dataset = args.dataset_name 33 | real_img_dataset = args.dataset_name # 'DRIVE' 34 | real_img_test_dataset = args.dataset_name # 'DRIVE' 35 | 36 | result_dir = 'Test' + '/' + out_dir + '' 37 | model_directory = 'Model_and_Result' + '/' + load_model + '/models' # Directory to restore trained model from. 38 | 39 | if tf.gfile.Exists(result_dir): 40 | print('Result dir exists! 
Press Enter to OVERRIDE...', end='') 41 | input() 42 | tf.gfile.DeleteRecursively(result_dir) 43 | if not os.path.exists(result_dir): 44 | os.makedirs(result_dir) 45 | 46 | os.system('cp {} {}'.format(__file__, result_dir)) 47 | 48 | # ============================== parameters set ========================================== # 49 | max_epoch = 1 50 | 51 | 52 | img_channel = 3 53 | img_size = 512 54 | img_x = 512 55 | img_y = 512 56 | gt_channel = 1 57 | 58 | z_size = 400 59 | 60 | 61 | 62 | # =============================== model and data definition ================================ # 63 | generator = Net.generator 64 | 65 | tf.reset_default_graph() 66 | 67 | gt = tf.placeholder(shape=[None, img_size, img_size, gt_channel], dtype=tf.float32) 68 | img = tf.placeholder(shape=[None, img_size, img_size, img_channel], dtype=tf.float32) 69 | mask = tf.placeholder(shape=[None, img_size, img_size, 1], dtype=tf.float32) 70 | z = tf.placeholder(shape=[None, z_size], dtype=tf.float32) 71 | 72 | # gt_mask = tf.concat([gt, mask], 3) 73 | 74 | img_detector = StyleFeature.get_style_model(img, mask, with_feature_mask_from=('dense_2', [])) 75 | act_input = { 76 | size: tf.image.resize_images((img_detector.get_layer(layer_name).related_projection.output - mean)/std, [size,size]) 77 | for layer_name, size, mean, std in zip(StyleFeature.STYLE_LAYERS, 78 | StyleFeature.STYLE_LAYERS_SIZE, 79 | StyleFeature.STYLE_LAYERS_MEAN, 80 | StyleFeature.STYLE_LAYERS_STD) 81 | } 82 | 83 | syn = generator(gt, act_input, z) 84 | 85 | 86 | # =============================== init ============================================= # 87 | t_vars = tf.trainable_variables() 88 | g_vars = list(set(t_vars) & set(tf.global_variables('generator'))) 89 | 90 | init = tf.variables_initializer(g_vars) 91 | sess = K.get_session() 92 | saver = tf.train.Saver(g_vars, max_to_keep=None) 93 | 94 | sess.run(init) 95 | #writer = tf.train.SummaryWriter(model_directory, sess.graph) 96 | 97 | # ==================================== restore weights ================================ # 98 | ckpt = tf.train.get_checkpoint_state(model_directory) 99 | restore_path = ckpt.model_checkpoint_path.replace('-9000', '-9000') 100 | print(restore_path) 101 | saver.restore(sess, restore_path) 102 | 103 | 104 | 105 | # ==================================== start training ===================================== # 106 | 107 | if real_img_test_dataset in ['FGADR', 'retinal-lesions', 'IDRiD']: 108 | # img_sample = np.load('data/{}_test_image.npy'.format(real_img_test_dataset)) 109 | gt_sample = np.load('data/{}_test_gt.npy'.format(real_img_test_dataset))[..., [0]] 110 | mask_sample = np.load('data/{}_test_mask.npy'.format(real_img_test_dataset)) 111 | elif real_img_test_dataset == 'DRIVE': 112 | img_sample, gt_sample, mask_sample = Net.matTonpy_35() 113 | 114 | 115 | # img_sample = (np.reshape(img_sample, [-1, img_x, img_y, img_channel]) - 0.5) * 2.0 116 | gt_sample = (np.reshape(gt_sample, [-1, img_x, img_y, gt_channel]) - 0.5) * 2.0 117 | mask_sample = (np.reshape(mask_sample, [-1, img_x, img_y, 1]) - 0.5) * 2.0 118 | 119 | per_part = 250 120 | part_id = -1 121 | 122 | with open('data/'+real_img_test_dataset+'_test.list', 'r') as f: 123 | fname_list = yaml.safe_load(f) 124 | 125 | # amaps = rebuild_AMaps_by_img(0, fragments_DB) 126 | 127 | # generate images with fragmentDB 128 | for epoch in range(max_epoch): 129 | print('epoch:', epoch) 130 | batchNum = 1 131 | 132 | for i, (gt_array, mask_array) in enumerate(zip(gt_sample, mask_sample)): 133 | print('img:', batchNum) 134 | 135 | 
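        # (Editor's note) FGADR's descriptor bank is sharded into files of
        # per_part (=250) images each ('DMB/<dataset>.by_img.<part>', written
        # by DMB_build_FGADR.py), so the matching shard is (re)loaded lazily
        # whenever the image index crosses a shard boundary.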
135 |         if i//per_part != part_id:
136 |             part_id = i//per_part
137 |             with open('DMB/{}.by_img.{}'.format(db_dataset, part_id), 'rb') as file:
138 |                 fragments_DB = pickle.load(file)
139 | 
140 |         zs = np.random.normal(0, 0.001, size=[1, z_size]).astype(np.float32)
141 | 
142 |         amaps = rebuild_AMaps_by_img(i % per_part, fragments_DB)
143 |         syn_array = sess.run(syn, feed_dict={gt: [gt_array], z: zs, mask: [mask_array],
144 |                                              act_input[256]: [amaps[256]],
145 |                                              act_input[64]: [amaps[64]]})
146 |         syn_array = (syn_array + 1) / 2
147 |         syn_sample_m = syn_array * ((mask_array + 1) / 2)
148 | 
149 |         syn_sample_m = np.reshape(syn_sample_m, [img_x, img_y, img_channel])
150 |         save_images(np.reshape(syn_sample_m, [1, img_x, img_y, img_channel]),
151 |                     [1, 1],
152 |                     result_dir + '/{}_{}.jpg'.format(fname_list[i], epoch))
153 | 
154 |         batchNum += 1
155 | 
156 | sess.close()
157 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Patho-GAN: Interpretation + Medical Data Augmentation
2 | 
3 | This repo is the project for the IEEE JBHI paper "[Explainable Diabetic Retinopathy Detection and Retinal Image Generation](https://doi.org/10.1109/JBHI.2021.3110593)" ([arxiv](https://arxiv.org/abs/2107.00296)), based on a previous AAAI'19 work "[Pathological Evidence Exploration in Deep Retinal Image Diagnosis](https://ojs.aaai.org//index.php/AAAI/article/view/3901)" ([arxiv](https://arxiv.org/abs/1812.02640)). Inspired by Koch's Postulates, the foundation of evidence-based medicine (EBM) for identifying pathogens, we propose to exploit the interpretability of deep learning applications in medical diagnosis.
4 | 
5 | This is a comprehensive medical image framework featuring:
6 | 
7 | - Unsupervised lesion detection & rough segmentation
8 | - Detection of lesions related to disease diagnosis, such as microaneurysms, hemorrhages, and soft and hard exudates
9 | - Fundus image synthesis with full controllability of lesion location and number
10 | - Data augmentation, fast and photo-realistic
11 | - CNN interpretability framework
12 | 
13 | With a diabetic retinopathy (DR) fundus image as input, pathological descriptors can be extracted from [pretrained DR detectors](https://github.com/zzdyyy/kaggle_diabetic_keras). Patho-GAN can then generate a DR fundus image given the pathological descriptors and a vessel segmentation.
14 | 
15 | ![teaser](github/teaser.png)
16 | 
17 | **Interpretation**: We can determine the symptoms that the DR detector identifies as evidence for its prediction. This explainability work helps the medical community further understand how deep learning makes predictions, and encourages more collaboration.
18 | 
19 | ![detection](github/detection.png)
20 | 
21 | **Augmentation**: We can generate high-quality medical images with various lesions (not only DR retinal fundus), suitable for medical data augmentation. The synthesized lesions are controllable in location and quantity.
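For orientation, the conditioning interface is small: a vessel segmentation, activation maps rebuilt from the stored pathological descriptors, and a noise vector. The snippet below is condensed from `Test_reconstruct_DMB.py` in this repo; it is an excerpt, not a standalone API:

```python
amaps = rebuild_AMaps_by_img(i, fragments_DB)      # pathological descriptors from the DMB
zs = np.random.normal(0, 0.001, size=[1, z_size])  # latent noise
syn_array = sess.run(syn, feed_dict={gt: [gt_array],  # vessel segmentation as structure
                                     z: zs, mask: [mask_array],
                                     act_input[256]: [amaps[256]],  # descriptor maps
                                     act_input[64]: [amaps[64]]})   # at two scales
```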
22 | 
23 | ![real_and_synthesized](github/real_and_synthesized.png)
24 | 
25 | ![lesion_manipulation](github/lesion_manipulation.png)
26 | 
27 | # Paper
28 | 
29 | IEEE JBHI paper "[Explainable Diabetic Retinopathy Detection and Retinal Image Generation](https://doi.org/10.1109/JBHI.2021.3110593)" ([arxiv](https://arxiv.org/abs/2107.00296))
30 | 
31 | AAAI'19 work "[Pathological Evidence Exploration in Deep Retinal Image Diagnosis](https://ojs.aaai.org//index.php/AAAI/article/view/3901)" ([arxiv](https://arxiv.org/abs/1812.02640))
32 | 
33 | Please consider citing us.
34 | 
35 | ```
36 | @article{niu2021explainable,
37 |   author={Niu, Yuhao and Gu, Lin and Zhao, Yitian and Lu, Feng},
38 |   journal={IEEE Journal of Biomedical and Health Informatics},
39 |   title={Explainable Diabetic Retinopathy Detection and Retinal Image Generation},
40 |   year={2021},
41 |   doi={10.1109/JBHI.2021.3110593}
42 | }
43 | ```
44 | 
45 | # Requirements
46 | 
47 | The code is tested under Ubuntu 16.04 + Python 3.5 + TensorFlow 1.8 + Keras 2.2. You can set up a Python 3.5 environment and run:
48 | 
49 | ```bash
50 | pip install --upgrade pip
51 | pip install -r requirements.txt
52 | ```
53 | 
54 | Alternatively, you can find a Docker image at https://hub.docker.com/r/zzdyyy/patho-gan for Linux machines with GPU(s).
55 | 
56 | 
57 | # Testing
58 | 
59 | To synthesize DR images with the pre-trained model:
60 | 
61 | ```bash
62 | # Download pretrained VGG-19 model
63 | wget -O data/imagenet-vgg-verydeep-19.mat 'http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat'
64 | 
65 | # Download pretrained o_O detector model
66 | gdown -O data/detector.h5 'https://drive.google.com/uc?id=1OI1d3XWM7IyW2igIEq8s-ZyF9vw0vTiw'
67 | 
68 | # Download IDRiD vessel segmentation, descriptors, and pretrained Patho-GAN model
69 | gdown -O idrid_testing.tar.xz 'https://drive.google.com/uc?id=1Cf1WoaoGf6m7t6z70kpEl1SXOxTeM6Qu'
70 | tar -xvf idrid_testing.tar.xz
71 | 
72 | # Synthesize the original image, output to `Test/IDRiD_Reconstruct/`
73 | python Test_reconstruct_DMB.py IDRiD
74 | 
75 | # Lesion Relocation: relocate the lesions in IDRiD_67, output to `Test/IDRiD_Randomize_IDRiD_67.jpg/`
76 | python Test_reconstruct_DMB_randomize.py --dataset_name IDRiD --img_name IDRiD_67.jpg
77 | 
78 | # Number Manipulation: decrease or increase the lesions in IDRiD_67, output to `Test/IDRiD_NumberAdjust_IDRiD_67.jpg/`
79 | python Test_reconstruct_DMB_numberadjust.py --dataset_name IDRiD --img_name IDRiD_67.jpg
80 | ```
81 | 
82 | To generate activation maps (assuming the IDRiD fundus images are already in `data/IDRiD/test_512/*.jpg`):
83 | 
84 | ```bash
85 | # this will output the activation maps in Visualization/IDRiD_test/
86 | python tfpipe_dump_activation.py 'data/IDRiD/test_512/*.jpg' --dump_to IDRiD_test --visualize
87 | ```
88 | 
89 | # Training
90 | 
91 | We take the IDRiD dataset as an example.
92 | 
93 | 1. Download the dataset from [this link](https://ieee-dataport.org/open-access/indian-diabetic-retinopathy-image-dataset-idrid), and extract it.
94 | 2. To crop and resize the images to 512x512, cd into `data/IDRiD` and run the following:
95 | ```
96 | python convert.py --directory '/path_to_extracted_data/IDRiD/A. Segmentation/1. Original Images/a. Training Set' --convert_directory train_512/
97 | python convert.py --directory '/path_to_extracted_data/IDRiD/A. Segmentation/1. Original Images/b. Testing Set' --convert_directory test_512/
98 | ```
99 | 3. Generate vessel segmentation.
Clone modified version of [SA-UNet](https://github.com/zzdyyy/SA-UNet), and run in its root directory: 100 | ``` 101 | python Test_PathoGAN.py IDRiD 102 | ``` 103 | 4. Generate the numpy of the dataset. Run `python to_npy.py` in `data/IDRiD/`. 104 | 5. Generate descriptors for test samples. Download data/imagenet-vgg-verydeep-19.mat, run in Patho-GAN's root directory: 105 | ``` 106 | python DMB_build_test_samples.py IDRiD IDRiD_55.jpg IDRiD_61.jpg IDRiD_73.jpg IDRiD_81.jpg 107 | python DMB_build_test_samples.py retinal-lesions 250_right.jpg 2016_right.jpg 2044_left.jpg 2767_left.jpg 108 | python DMB_build_test_samples.py FGADR 0508.jpg 0549.jpg 0515.jpg 0529.jpg 109 | ``` 110 | 6. Start training: 111 | ``` 112 | python Train.py IDRiD IDRiD_55.jpg IDRiD_61.jpg IDRiD_73.jpg IDRiD_81.jpg 113 | python Train.py retinal-lesions 250_right.jpg 2016_right.jpg 2044_left.jpg 2767_left.jpg 114 | python Train.py FGADR 0508.jpg 0549.jpg 0515.jpg 0529.jpg 115 | ``` 116 | -------------------------------------------------------------------------------- /Net.py: -------------------------------------------------------------------------------- 1 | # author is He Zhao 2 | # The time to create is 8:47 PM, 28/11/16 3 | 4 | import tensorflow as tf 5 | import tensorflow.contrib.slim as slim 6 | from Opts import lrelu, resUnit 7 | from dataBlocks import DataBlocks 8 | import pickle 9 | import scipy.io as sio 10 | import numpy as np 11 | 12 | 13 | initializer = tf.truncated_normal_initializer(stddev=0.02) 14 | bias_initializer = tf.constant_initializer(0.0) 15 | 16 | 17 | def build_data(batchsize, dataset_name='DRIVE'): 18 | #readpath = open('img.pkl', 'rb') 19 | #datapaths = pickle.load(readpath) 20 | datapaths = [['data/{}_train_image.npy'.format(dataset_name), 21 | 'data/{}_train_gt.npy'.format(dataset_name), 22 | 'data/{}_train_mask.npy'.format(dataset_name)]] 23 | db = DataBlocks(data_paths=datapaths, train_valid_ratio=[39, 0], batchsize=batchsize, allow_preload=False) 24 | return db 25 | 26 | 27 | 28 | def matTonpy_35(): 29 | 30 | img = sio.loadmat('test_1To4.mat')['imgAllTest'] 31 | gt = sio.loadmat('test_1To4.mat')['gtAllTest'] 32 | mask = sio.loadmat('test_1To4.mat')['maskAllTest'] 33 | 34 | return img, gt, mask 35 | 36 | def discriminator(image, reuse=False): 37 | n=32 38 | bn = slim.batch_norm 39 | with tf.variable_scope("discriminator"): 40 | # original 41 | dis1 = slim.convolution2d(image, n, [4, 4], 2, activation_fn=lrelu, 42 | reuse=reuse, scope='d_conv1', weights_initializer=initializer) # 256 256 64 43 | 44 | dis2 = slim.convolution2d(dis1, 2*n, [4, 4], 2, normalizer_fn=bn, activation_fn=lrelu, 45 | reuse=reuse, scope='d_conv2', weights_initializer=initializer) # 128 128 64 46 | 47 | dis3 = slim.convolution2d(dis2, 4*n, [4, 4], 2, normalizer_fn=bn, activation_fn=lrelu, 48 | reuse=reuse, scope='d_conv3', weights_initializer=initializer) # 64 64 128 49 | 50 | dis4 = slim.convolution2d(dis3, 8*n, [4, 4], 2, normalizer_fn=bn, activation_fn=lrelu, 51 | reuse=reuse, scope='d_conv4', weights_initializer=initializer) # 32 32 256 52 | 53 | dis5 = slim.convolution2d(dis4, 16*n, [4, 4], 2, normalizer_fn=bn, activation_fn=lrelu, 54 | reuse=reuse, scope='d_conv5', weights_initializer=initializer) # 16 16 512 55 | 56 | 57 | d_out_logits = slim.fully_connected(slim.flatten(dis5), 1, activation_fn=None, reuse=reuse, scope='d_out', 58 | weights_initializer=initializer) 59 | 60 | d_out = tf.nn.sigmoid(d_out_logits) 61 | return d_out, d_out_logits 62 | 63 | def generator(image, act_input, z): 64 | n = 64 65 | with 
tf.variable_scope("generator"): 66 | # original 67 | e1 = slim.conv2d(image, n, [4, 4], 2, activation_fn=lrelu, scope='g_e1_conv', 68 | weights_initializer=initializer) 69 | # 256 70 | e2 = slim.conv2d(tf.concat([lrelu(e1), act_input[256]], 3), 2 * n, [4, 4], 2, normalizer_fn=slim.batch_norm, activation_fn=None, scope='g_e2_conv', 71 | weights_initializer=initializer) 72 | # 128 73 | e3 = slim.conv2d(lrelu(e2), 4 * n, [4, 4], 2, normalizer_fn=slim.batch_norm, activation_fn=None, scope='g_e3_conv', 74 | weights_initializer=initializer) 75 | # 64 76 | e4 = slim.conv2d(tf.concat([lrelu(e3), act_input[64]], 3), 8 * n, [4, 4], 2, normalizer_fn=slim.batch_norm, activation_fn=None, scope='g_e4_conv', 77 | weights_initializer=initializer) 78 | # 32 79 | e5 = slim.conv2d(lrelu(e4), 8*n, [4, 4], 2, normalizer_fn=slim.batch_norm, activation_fn=None, scope='g_e5_conv', 80 | weights_initializer=initializer) 81 | # 16 82 | e6 = slim.conv2d(lrelu(e5), 8*n, [4, 4], 2, normalizer_fn=slim.batch_norm, activation_fn=None, scope='g_e6_conv', 83 | weights_initializer=initializer) 84 | # 8 85 | 86 | 87 | zP = slim.fully_connected(z, 4 * 4 * n, normalizer_fn=None, activation_fn=lrelu, scope='g_project', 88 | weights_initializer=initializer) 89 | zCon = tf.reshape(zP, [-1, 4, 4, n]) 90 | 91 | gen1 = slim.conv2d(tf.image.resize_nearest_neighbor(lrelu(zCon), [8, 8]), 92 | 2 * n, [3, 3], 1, padding='SAME', normalizer_fn=slim.batch_norm, activation_fn=None, 93 | scope='g_dconv1', weights_initializer=initializer) 94 | # 8 95 | gen1 = tf.concat([gen1, e6], 3) 96 | 97 | gen2 = slim.conv2d(tf.image.resize_nearest_neighbor(lrelu(gen1), [16, 16]), 98 | 4 * n, [3, 3], 1, normalizer_fn=slim.batch_norm, activation_fn=None, 99 | scope='g_dconv2', weights_initializer=initializer) 100 | # 16 101 | gen2 = tf.concat([gen2, e5], 3) 102 | 103 | gen3 = slim.conv2d(tf.image.resize_nearest_neighbor(lrelu(gen2), [32, 32]), 104 | 8 * n, [3, 3], 1, normalizer_fn=slim.batch_norm, activation_fn=None, 105 | scope='g_dconv3', weights_initializer=initializer) 106 | gen3 = tf.concat([gen3, e4], 3) 107 | 108 | # 32 109 | gen6 = slim.conv2d(tf.image.resize_nearest_neighbor(tf.nn.relu(gen3), [64, 64]), 110 | 4 * n, [3, 3], 1, normalizer_fn=slim.batch_norm, activation_fn=None, 111 | scope='g_dconv6', weights_initializer=initializer) 112 | gen6 = tf.concat([gen6, e3], 3) 113 | 114 | # 64 115 | gen7 = slim.conv2d(tf.image.resize_nearest_neighbor(tf.nn.relu(gen6), [128, 128]), 116 | 2 * n, [3, 3], 1, normalizer_fn=slim.batch_norm, activation_fn=None, 117 | scope='g_dconv7', weights_initializer=initializer) 118 | gen7 = tf.concat([gen7, e2], 3) 119 | 120 | # 128 121 | gen8 = slim.conv2d(tf.image.resize_nearest_neighbor(tf.nn.relu(gen7), [256, 256]), 122 | n, [3, 3], 1, normalizer_fn=slim.batch_norm, activation_fn=None, 123 | scope='g_dconv8', weights_initializer=initializer) 124 | # 256 125 | # gen8 = tf.nn.dropout(gen8, 0.5) 126 | gen8 = tf.concat([gen8, e1], 3) 127 | gen8 = tf.nn.relu(gen8) 128 | 129 | # 256 130 | gen_out = slim.conv2d(tf.image.resize_nearest_neighbor(gen8, [512, 512]), 131 | 3, [3, 3], 1, activation_fn=tf.nn.tanh, scope='g_out', 132 | weights_initializer=initializer) 133 | 134 | return gen_out 135 | 136 | -------------------------------------------------------------------------------- /Test_reconstruct_DMB_randomize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--dataset_name', default='IDRiD') 5 | 
parser.add_argument('--img_name', default='IDRiD_67.jpg') 6 | parser.add_argument('--gpus') 7 | args = parser.parse_args() 8 | print(args) 9 | 10 | import os 11 | if args.gpus: 12 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 13 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 14 | print('CUDA_VISIBLE_DEVICES:', os.environ["CUDA_VISIBLE_DEVICES"]) 15 | 16 | 17 | import keras.backend as K 18 | import tensorflow as tf 19 | import Net 20 | import numpy as np 21 | import StyleFeature 22 | import scipy.io as sio 23 | import cv2 24 | from DMB_fragment import rebuild_AMaps_by_img 25 | import pickle 26 | import random 27 | import yaml 28 | 29 | from Opts import save_images, matTonpy 30 | # =============================== path set =============================================== # 31 | out_dir = args.dataset_name + '_Randomize_' + args.img_name 32 | load_model = args.dataset_name 33 | db_dataset = args.dataset_name 34 | real_img_dataset = args.dataset_name # 'DRIVE' 35 | real_img_test_dataset = args.dataset_name # 'DRIVE' 36 | 37 | result_dir = 'Test' + '/' + out_dir + '' 38 | model_directory = 'Model_and_Result' + '/' + load_model + '/models' # Directory to restore trained model from. 39 | 40 | if tf.gfile.Exists(result_dir): 41 | print('Result dir exists! Press Enter to OVERRIDE...', end='') 42 | input() 43 | tf.gfile.DeleteRecursively(result_dir) 44 | if not os.path.exists(result_dir): 45 | os.makedirs(result_dir) 46 | 47 | os.system('cp {} {}'.format(__file__, result_dir)) 48 | 49 | # ============================== parameters set ========================================== # 50 | max_epoch = 20 51 | 52 | 53 | img_channel = 3 54 | img_size = 512 55 | img_x = 512 56 | img_y = 512 57 | gt_channel = 1 58 | 59 | z_size = 400 60 | 61 | 62 | 63 | # =============================== model and data definition ================================ # 64 | generator = Net.generator 65 | 66 | tf.reset_default_graph() 67 | 68 | gt = tf.placeholder(shape=[None, img_size, img_size, gt_channel], dtype=tf.float32) 69 | img = tf.placeholder(shape=[None, img_size, img_size, img_channel], dtype=tf.float32) 70 | mask = tf.placeholder(shape=[None, img_size, img_size, 1], dtype=tf.float32) 71 | z = tf.placeholder(shape=[None, z_size], dtype=tf.float32) 72 | 73 | # gt_mask = tf.concat([gt, mask], 3) 74 | 75 | img_detector = StyleFeature.get_style_model(img, mask, with_feature_mask_from=('dense_2', [])) 76 | act_input = { 77 | size: tf.image.resize_images((img_detector.get_layer(layer_name).related_projection.output - mean)/std, [size,size]) 78 | for layer_name, size, mean, std in zip(StyleFeature.STYLE_LAYERS, 79 | StyleFeature.STYLE_LAYERS_SIZE, 80 | StyleFeature.STYLE_LAYERS_MEAN, 81 | StyleFeature.STYLE_LAYERS_STD) 82 | } 83 | 84 | syn = generator(gt, act_input, z) 85 | 86 | 87 | # =============================== init ============================================= # 88 | t_vars = tf.trainable_variables() 89 | g_vars = list(set(t_vars) & set(tf.global_variables('generator'))) 90 | 91 | init = tf.variables_initializer(g_vars) 92 | sess = K.get_session() 93 | saver = tf.train.Saver(g_vars, max_to_keep=None) 94 | 95 | sess.run(init) 96 | #writer = tf.train.SummaryWriter(model_directory, sess.graph) 97 | 98 | # ==================================== restore weights ================================ # 99 | ckpt = tf.train.get_checkpoint_state(model_directory) 100 | restore_path = ckpt.model_checkpoint_path.replace('-9000', '-9000') 101 | print(restore_path) 102 | saver.restore(sess, restore_path) 103 | 104 | 105 | 106 | # 
==================================== start training ===================================== # 107 | 108 | if real_img_test_dataset in ['FGADR', 'retinal-lesions', 'IDRiD']: 109 | # img_sample = np.load('data/{}_test_image.npy'.format(real_img_test_dataset)) 110 | gt_sample = np.load('data/{}_test_gt.npy'.format(real_img_test_dataset))[..., [0]] 111 | mask_sample = np.load('data/{}_test_mask.npy'.format(real_img_test_dataset)) 112 | elif real_img_test_dataset == 'DRIVE': 113 | img_sample, gt_sample, mask_sample = Net.matTonpy_35() 114 | 115 | 116 | # img_sample = (np.reshape(img_sample, [-1, img_x, img_y, img_channel]) - 0.5) * 2.0 117 | gt_sample = (np.reshape(gt_sample, [-1, img_x, img_y, gt_channel]) - 0.5) * 2.0 118 | mask_sample = (np.reshape(mask_sample, [-1, img_x, img_y, 1]) - 0.5) * 2.0 119 | 120 | with open('DMB/{}.by_img'.format(db_dataset), 'rb') as file: 121 | fragments_DB = pickle.load(file) 122 | 123 | with open('data/'+real_img_test_dataset+'_test.list', 'r') as f: 124 | fname_list = yaml.safe_load(f) 125 | 126 | # amaps = rebuild_AMaps_by_img(0, fragments_DB) 127 | 128 | # generate images with fragmentDB 129 | for epoch in range(max_epoch): 130 | print('epoch:', epoch) 131 | batchNum = 1 132 | 133 | if True: # for i, (gt_array, mask_array) in enumerate(zip(gt_sample, mask_sample)): 134 | i = fname_list.index(args.img_name) 135 | i, (gt_array, mask_array) = i, (gt_sample[i], mask_sample[i]) 136 | 137 | print('img:', batchNum) 138 | 139 | zs = np.random.normal(0, 0.001, size=[1, z_size]).astype(np.float32) 140 | 141 | amaps, lesion_map = rebuild_AMaps_by_img(i, fragments_DB, randomize=True, lesion_map=True) 142 | syn_array = sess.run(syn, feed_dict={gt: [gt_array], z:zs, mask: [mask_array], 143 | act_input[256]: [amaps[256]], 144 | act_input[64]: [amaps[64]]}) 145 | syn_array = (syn_array + 1) / 2 146 | syn_sample_m = syn_array * ((mask_array + 1) / 2) 147 | 148 | syn_sample_m = np.reshape(syn_sample_m, [img_x, img_y, img_channel]) 149 | save_images(np.reshape(syn_sample_m, [1, img_x, img_y, img_channel]), 150 | [1, 1], 151 | result_dir + '/{}_{}.jpg'.format(fname_list[i], epoch)) 152 | 153 | lesion_map = cv2.resize(lesion_map, (512, 512), interpolation=cv2.INTER_NEAREST) 154 | lesion_map = lesion_map * ((mask_array + 1) / 2) # crop by mask 155 | _, binary = cv2.threshold((((mask_array + 1) / 2) * 255).astype('uint8'), 0, 255, cv2.THRESH_BINARY) 156 | contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 157 | lesion_map = cv2.drawContours(lesion_map, contours, -1, (255, 255, 255)) 158 | cv2.imwrite(result_dir + '/{}_{}_lesion_map.jpg'.format(fname_list[i], epoch), 159 | lesion_map) 160 | 161 | batchNum += 1 162 | 163 | sess.close() 164 | -------------------------------------------------------------------------------- /Test_reconstruct_DMB_randomize_FGADR.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--dataset_name', default='FGADR') 5 | parser.add_argument('--img_name', default='1709.jpg') 6 | parser.add_argument('--gpus') 7 | args = parser.parse_args() 8 | print(args) 9 | 10 | import os 11 | if args.gpus: 12 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 13 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 14 | print('CUDA_VISIBLE_DEVICES:', os.environ["CUDA_VISIBLE_DEVICES"]) 15 | 16 | 17 | import keras.backend as K 18 | import tensorflow as tf 19 | import Net 20 | import numpy as np 21 | import StyleFeature 22 | import 
scipy.io as sio 23 | import cv2 24 | from DMB_fragment import rebuild_AMaps_by_img 25 | import pickle 26 | import random 27 | import yaml 28 | 29 | from Opts import save_images, matTonpy 30 | # =============================== path set =============================================== # 31 | out_dir = args.dataset_name + '_Randomize_' + args.img_name 32 | load_model = args.dataset_name 33 | db_dataset = args.dataset_name 34 | real_img_dataset = args.dataset_name # 'DRIVE' 35 | real_img_test_dataset = args.dataset_name # 'DRIVE' 36 | 37 | result_dir = 'Test' + '/' + out_dir + '' 38 | model_directory = 'Model_and_Result' + '/' + load_model + '/models' # Directory to restore trained model from. 39 | 40 | if tf.gfile.Exists(result_dir): 41 | print('Result dir exists! Press Enter to OVERRIDE...', end='') 42 | input() 43 | tf.gfile.DeleteRecursively(result_dir) 44 | if not os.path.exists(result_dir): 45 | os.makedirs(result_dir) 46 | 47 | os.system('cp {} {}'.format(__file__, result_dir)) 48 | 49 | # ============================== parameters set ========================================== # 50 | max_epoch = 20 51 | 52 | 53 | img_channel = 3 54 | img_size = 512 55 | img_x = 512 56 | img_y = 512 57 | gt_channel = 1 58 | 59 | z_size = 400 60 | 61 | 62 | 63 | # =============================== model and data definition ================================ # 64 | generator = Net.generator 65 | 66 | tf.reset_default_graph() 67 | 68 | gt = tf.placeholder(shape=[None, img_size, img_size, gt_channel], dtype=tf.float32) 69 | img = tf.placeholder(shape=[None, img_size, img_size, img_channel], dtype=tf.float32) 70 | mask = tf.placeholder(shape=[None, img_size, img_size, 1], dtype=tf.float32) 71 | z = tf.placeholder(shape=[None, z_size], dtype=tf.float32) 72 | 73 | # gt_mask = tf.concat([gt, mask], 3) 74 | 75 | img_detector = StyleFeature.get_style_model(img, mask, with_feature_mask_from=('dense_2', [])) 76 | act_input = { 77 | size: tf.image.resize_images((img_detector.get_layer(layer_name).related_projection.output - mean)/std, [size,size]) 78 | for layer_name, size, mean, std in zip(StyleFeature.STYLE_LAYERS, 79 | StyleFeature.STYLE_LAYERS_SIZE, 80 | StyleFeature.STYLE_LAYERS_MEAN, 81 | StyleFeature.STYLE_LAYERS_STD) 82 | } 83 | 84 | syn = generator(gt, act_input, z) 85 | 86 | 87 | # =============================== init ============================================= # 88 | t_vars = tf.trainable_variables() 89 | g_vars = list(set(t_vars) & set(tf.global_variables('generator'))) 90 | 91 | init = tf.variables_initializer(g_vars) 92 | sess = K.get_session() 93 | saver = tf.train.Saver(g_vars, max_to_keep=None) 94 | 95 | sess.run(init) 96 | #writer = tf.train.SummaryWriter(model_directory, sess.graph) 97 | 98 | # ==================================== restore weights ================================ # 99 | ckpt = tf.train.get_checkpoint_state(model_directory) 100 | restore_path = ckpt.model_checkpoint_path.replace('-9000', '-9000') 101 | print(restore_path) 102 | saver.restore(sess, restore_path) 103 | 104 | 105 | 106 | # ==================================== start training ===================================== # 107 | 108 | if real_img_test_dataset in ['FGADR', 'retinal-lesions', 'IDRiD']: 109 | # img_sample = np.load('data/{}_test_image.npy'.format(real_img_test_dataset)) 110 | gt_sample = np.load('data/{}_test_gt.npy'.format(real_img_test_dataset))[..., [0]] 111 | mask_sample = np.load('data/{}_test_mask.npy'.format(real_img_test_dataset)) 112 | elif real_img_test_dataset == 'DRIVE': 113 | img_sample, gt_sample, 
mask_sample = Net.matTonpy_35() 114 | 115 | 116 | # img_sample = (np.reshape(img_sample, [-1, img_x, img_y, img_channel]) - 0.5) * 2.0 117 | gt_sample = (np.reshape(gt_sample, [-1, img_x, img_y, gt_channel]) - 0.5) * 2.0 118 | mask_sample = (np.reshape(mask_sample, [-1, img_x, img_y, 1]) - 0.5) * 2.0 119 | 120 | per_part = 250 121 | part_id = -1 122 | 123 | with open('data/'+real_img_test_dataset+'_test.list', 'r') as f: 124 | fname_list = yaml.safe_load(f) 125 | 126 | # amaps = rebuild_AMaps_by_img(0, fragments_DB) 127 | 128 | # generate images with fragmentDB 129 | for epoch in range(max_epoch): 130 | print('epoch:', epoch) 131 | batchNum = 1 132 | 133 | if True: # for i, (gt_array, mask_array) in enumerate(zip(gt_sample, mask_sample)): 134 | i = fname_list.index(args.img_name) 135 | i, (gt_array, mask_array) = i, (gt_sample[i], mask_sample[i]) 136 | 137 | print('img:', batchNum) 138 | 139 | if i//per_part != part_id: 140 | part_id = i//per_part 141 | with open('DMB/{}.by_img.{}'.format(db_dataset, part_id), 'rb') as file: 142 | fragments_DB = pickle.load(file) 143 | 144 | zs = np.random.normal(0, 0.001, size=[1, z_size]).astype(np.float32) 145 | 146 | amaps, lesion_map = rebuild_AMaps_by_img(i % per_part, fragments_DB, randomize=True, lesion_map=True) 147 | syn_array = sess.run(syn, feed_dict={gt: [gt_array], z:zs, mask: [mask_array], 148 | act_input[256]: [amaps[256]], 149 | act_input[64]: [amaps[64]]}) 150 | syn_array = (syn_array + 1) / 2 151 | syn_sample_m = syn_array * ((mask_array + 1) / 2) 152 | 153 | syn_sample_m = np.reshape(syn_sample_m, [img_x, img_y, img_channel]) 154 | save_images(np.reshape(syn_sample_m, [1, img_x, img_y, img_channel]), 155 | [1, 1], 156 | result_dir + '/{}_{}.jpg'.format(fname_list[i], epoch)) 157 | 158 | lesion_map = cv2.resize(lesion_map, (512, 512), interpolation=cv2.INTER_NEAREST) 159 | lesion_map = lesion_map * ((mask_array + 1) / 2) # crop by mask 160 | _, binary = cv2.threshold((((mask_array + 1) / 2) * 255).astype('uint8'), 0, 255, cv2.THRESH_BINARY) 161 | contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 162 | lesion_map = cv2.drawContours(lesion_map, contours, -1, (255, 255, 255)) 163 | cv2.imwrite(result_dir + '/{}_{}_lesion_map.jpg'.format(fname_list[i], epoch), 164 | lesion_map) 165 | 166 | batchNum += 1 167 | 168 | sess.close() 169 | -------------------------------------------------------------------------------- /Test_reconstruct_DMB_numberadjust.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--dataset_name', default='IDRiD') 5 | parser.add_argument('--img_name', default='IDRiD_67.jpg') 6 | parser.add_argument('--gpus') 7 | args = parser.parse_args() 8 | print(args) 9 | 10 | import os 11 | if args.gpus: 12 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 13 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 14 | print('CUDA_VISIBLE_DEVICES:', os.environ["CUDA_VISIBLE_DEVICES"]) 15 | 16 | 17 | import keras.backend as K 18 | import tensorflow as tf 19 | import Net 20 | import numpy as np 21 | import StyleFeature 22 | import scipy.io as sio 23 | import cv2 24 | from DMB_fragment import rebuild_AMaps_by_img 25 | import pickle 26 | import random 27 | import yaml 28 | 29 | from Opts import save_images, matTonpy 30 | # =============================== path set =============================================== # 31 | out_dir = args.dataset_name + '_NumberAdjust_' + args.img_name 32 | load_model = 
args.dataset_name 33 | db_dataset = args.dataset_name 34 | real_img_dataset = args.dataset_name # 'DRIVE' 35 | real_img_test_dataset = args.dataset_name # 'DRIVE' 36 | 37 | result_dir = 'Test' + '/' + out_dir + '' 38 | model_directory = 'Model_and_Result' + '/' + load_model + '/models' # Directory to restore trained model from. 39 | 40 | if tf.gfile.Exists(result_dir): 41 | print('Result dir exists! Press Enter to OVERRIDE...', end='') 42 | input() 43 | tf.gfile.DeleteRecursively(result_dir) 44 | if not os.path.exists(result_dir): 45 | os.makedirs(result_dir) 46 | 47 | os.system('cp {} {}'.format(__file__, result_dir)) 48 | 49 | # ============================== parameters set ========================================== # 50 | max_epoch = 20 51 | 52 | 53 | img_channel = 3 54 | img_size = 512 55 | img_x = 512 56 | img_y = 512 57 | gt_channel = 1 58 | 59 | z_size = 400 60 | 61 | 62 | 63 | # =============================== model and data definition ================================ # 64 | generator = Net.generator 65 | 66 | tf.reset_default_graph() 67 | 68 | gt = tf.placeholder(shape=[None, img_size, img_size, gt_channel], dtype=tf.float32) 69 | img = tf.placeholder(shape=[None, img_size, img_size, img_channel], dtype=tf.float32) 70 | mask = tf.placeholder(shape=[None, img_size, img_size, 1], dtype=tf.float32) 71 | z = tf.placeholder(shape=[None, z_size], dtype=tf.float32) 72 | 73 | # gt_mask = tf.concat([gt, mask], 3) 74 | 75 | img_detector = StyleFeature.get_style_model(img, mask, with_feature_mask_from=('dense_2', [])) 76 | act_input = { 77 | size: tf.image.resize_images((img_detector.get_layer(layer_name).related_projection.output - mean)/std, [size,size]) 78 | for layer_name, size, mean, std in zip(StyleFeature.STYLE_LAYERS, 79 | StyleFeature.STYLE_LAYERS_SIZE, 80 | StyleFeature.STYLE_LAYERS_MEAN, 81 | StyleFeature.STYLE_LAYERS_STD) 82 | } 83 | 84 | syn = generator(gt, act_input, z) 85 | 86 | 87 | # =============================== init ============================================= # 88 | t_vars = tf.trainable_variables() 89 | g_vars = list(set(t_vars) & set(tf.global_variables('generator'))) 90 | 91 | init = tf.variables_initializer(g_vars) 92 | sess = K.get_session() 93 | saver = tf.train.Saver(g_vars, max_to_keep=None) 94 | 95 | sess.run(init) 96 | #writer = tf.train.SummaryWriter(model_directory, sess.graph) 97 | 98 | # ==================================== restore weights ================================ # 99 | ckpt = tf.train.get_checkpoint_state(model_directory) 100 | restore_path = ckpt.model_checkpoint_path.replace('-9000', '-9000') 101 | print(restore_path) 102 | saver.restore(sess, restore_path) 103 | 104 | 105 | 106 | # ==================================== start training ===================================== # 107 | 108 | if real_img_test_dataset in ['FGADR', 'retinal-lesions', 'IDRiD']: 109 | # img_sample = np.load('data/{}_test_image.npy'.format(real_img_test_dataset)) 110 | gt_sample = np.load('data/{}_test_gt.npy'.format(real_img_test_dataset))[..., [0]] 111 | mask_sample = np.load('data/{}_test_mask.npy'.format(real_img_test_dataset)) 112 | elif real_img_test_dataset == 'DRIVE': 113 | img_sample, gt_sample, mask_sample = Net.matTonpy_35() 114 | 115 | 116 | # img_sample = (np.reshape(img_sample, [-1, img_x, img_y, img_channel]) - 0.5) * 2.0 117 | gt_sample = (np.reshape(gt_sample, [-1, img_x, img_y, gt_channel]) - 0.5) * 2.0 118 | mask_sample = (np.reshape(mask_sample, [-1, img_x, img_y, 1]) - 0.5) * 2.0 119 | 120 | with open('DMB/{}.by_img'.format(db_dataset), 'rb') as file: 
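# 'DMB/<dataset>.by_img' is the pickled per-image lesion-fragment database (built beforehand, presumably by DMB_build.py)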
121 | fragments_DB = pickle.load(file) 122 | 123 | with open('data/'+real_img_test_dataset+'_test.list', 'r') as f: 124 | fname_list = yaml.safe_load(f) 125 | 126 | # amaps = rebuild_AMaps_by_img(0, fragments_DB) 127 | 128 | # generate images with fragmentDB 129 | for epoch in range(max_epoch+16): 130 | print('epoch:', epoch) 131 | batchNum = 1 132 | 133 | if True: # for i, (gt_array, mask_array) in enumerate(zip(gt_sample, mask_sample)): 134 | i = fname_list.index(args.img_name) 135 | i, (gt_array, mask_array) = i, (gt_sample[i], mask_sample[i]) 136 | 137 | print('img:', batchNum) 138 | 139 | zs = np.random.normal(0, 0.001, size=[1, z_size]).astype(np.float32) 140 | 141 | if epoch <= max_epoch: 142 | amaps, lesion_map = rebuild_AMaps_by_img(i, fragments_DB, lesion_map=True, quantity=epoch/max_epoch) 143 | else: 144 | amaps, lesion_map = rebuild_AMaps_by_img(i, fragments_DB, randomize=True, lesion_map=True, multiple=epoch-max_epoch+1) 145 | syn_array = sess.run(syn, feed_dict={gt: [gt_array], z:zs, mask: [mask_array], 146 | act_input[256]: [amaps[256]], 147 | act_input[64]: [amaps[64]]}) 148 | syn_array = (syn_array + 1) / 2 149 | syn_sample_m = syn_array * ((mask_array + 1) / 2) 150 | 151 | syn_sample_m = np.reshape(syn_sample_m, [img_x, img_y, img_channel]) 152 | save_images(np.reshape(syn_sample_m, [1, img_x, img_y, img_channel]), 153 | [1, 1], 154 | result_dir + '/{}_{}.jpg'.format(fname_list[i], epoch)) 155 | 156 | lesion_map = cv2.resize(lesion_map, (512, 512), interpolation=cv2.INTER_NEAREST) 157 | lesion_map = lesion_map * ((mask_array + 1) / 2) # crop by mask 158 | _, binary = cv2.threshold((((mask_array + 1) / 2) * 255).astype('uint8'), 0, 255, cv2.THRESH_BINARY) 159 | contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 160 | lesion_map = cv2.drawContours(lesion_map, contours, -1, (255, 255, 255)) 161 | cv2.imwrite(result_dir + '/{}_{}_lesion_map.jpg'.format(fname_list[i], epoch), 162 | lesion_map) 163 | 164 | batchNum += 1 165 | 166 | sess.close() 167 | -------------------------------------------------------------------------------- /Test_reconstruct_DMB_numberadjust_FGADR.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--dataset_name', default='FGADR') 5 | parser.add_argument('--img_name', default='0516.jpg') 6 | parser.add_argument('--gpus') 7 | args = parser.parse_args() 8 | print(args) 9 | 10 | import os 11 | if args.gpus: 12 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 13 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 14 | print('CUDA_VISIBLE_DEVICES:', os.environ["CUDA_VISIBLE_DEVICES"]) 15 | 16 | 17 | import keras.backend as K 18 | import tensorflow as tf 19 | import Net 20 | import numpy as np 21 | import StyleFeature 22 | import scipy.io as sio 23 | import cv2 24 | from DMB_fragment import rebuild_AMaps_by_img 25 | import pickle 26 | import random 27 | import yaml 28 | 29 | from Opts import save_images, matTonpy 30 | # =============================== path set =============================================== # 31 | out_dir = args.dataset_name + '_NumberAdjust_' + args.img_name 32 | load_model = args.dataset_name 33 | db_dataset = args.dataset_name 34 | real_img_dataset = args.dataset_name # 'DRIVE' 35 | real_img_test_dataset = args.dataset_name # 'DRIVE' 36 | 37 | result_dir = 'Test' + '/' + out_dir + '' 38 | model_directory = 'Model_and_Result' + '/' + load_model + '/models' # Directory to restore trained model 
from. 39 | 40 | if tf.gfile.Exists(result_dir): 41 | print('Result dir exists! Press Enter to OVERRIDE...', end='') 42 | input() 43 | tf.gfile.DeleteRecursively(result_dir) 44 | if not os.path.exists(result_dir): 45 | os.makedirs(result_dir) 46 | 47 | os.system('cp {} {}'.format(__file__, result_dir)) 48 | 49 | # ============================== parameters set ========================================== # 50 | max_epoch = 20 51 | 52 | 53 | img_channel = 3 54 | img_size = 512 55 | img_x = 512 56 | img_y = 512 57 | gt_channel = 1 58 | 59 | z_size = 400 60 | 61 | 62 | 63 | # =============================== model and data definition ================================ # 64 | generator = Net.generator 65 | 66 | tf.reset_default_graph() 67 | 68 | gt = tf.placeholder(shape=[None, img_size, img_size, gt_channel], dtype=tf.float32) 69 | img = tf.placeholder(shape=[None, img_size, img_size, img_channel], dtype=tf.float32) 70 | mask = tf.placeholder(shape=[None, img_size, img_size, 1], dtype=tf.float32) 71 | z = tf.placeholder(shape=[None, z_size], dtype=tf.float32) 72 | 73 | # gt_mask = tf.concat([gt, mask], 3) 74 | 75 | img_detector = StyleFeature.get_style_model(img, mask, with_feature_mask_from=('dense_2', [])) 76 | act_input = { 77 | size: tf.image.resize_images((img_detector.get_layer(layer_name).related_projection.output - mean)/std, [size,size]) 78 | for layer_name, size, mean, std in zip(StyleFeature.STYLE_LAYERS, 79 | StyleFeature.STYLE_LAYERS_SIZE, 80 | StyleFeature.STYLE_LAYERS_MEAN, 81 | StyleFeature.STYLE_LAYERS_STD) 82 | } 83 | 84 | syn = generator(gt, act_input, z) 85 | 86 | 87 | # =============================== init ============================================= # 88 | t_vars = tf.trainable_variables() 89 | g_vars = list(set(t_vars) & set(tf.global_variables('generator'))) 90 | 91 | init = tf.variables_initializer(g_vars) 92 | sess = K.get_session() 93 | saver = tf.train.Saver(g_vars, max_to_keep=None) 94 | 95 | sess.run(init) 96 | #writer = tf.train.SummaryWriter(model_directory, sess.graph) 97 | 98 | # ==================================== restore weights ================================ # 99 | ckpt = tf.train.get_checkpoint_state(model_directory) 100 | restore_path = ckpt.model_checkpoint_path.replace('-9000', '-9000') 101 | print(restore_path) 102 | saver.restore(sess, restore_path) 103 | 104 | 105 | 106 | # ==================================== start training ===================================== # 107 | 108 | if real_img_test_dataset in ['FGADR', 'retinal-lesions', 'IDRiD']: 109 | # img_sample = np.load('data/{}_test_image.npy'.format(real_img_test_dataset)) 110 | gt_sample = np.load('data/{}_test_gt.npy'.format(real_img_test_dataset))[..., [0]] 111 | mask_sample = np.load('data/{}_test_mask.npy'.format(real_img_test_dataset)) 112 | elif real_img_test_dataset == 'DRIVE': 113 | img_sample, gt_sample, mask_sample = Net.matTonpy_35() 114 | 115 | 116 | # img_sample = (np.reshape(img_sample, [-1, img_x, img_y, img_channel]) - 0.5) * 2.0 117 | gt_sample = (np.reshape(gt_sample, [-1, img_x, img_y, gt_channel]) - 0.5) * 2.0 118 | mask_sample = (np.reshape(mask_sample, [-1, img_x, img_y, 1]) - 0.5) * 2.0 119 | 120 | per_part = 250 121 | part_id = -1 122 | 123 | with open('data/'+real_img_test_dataset+'_test.list', 'r') as f: 124 | fname_list = yaml.safe_load(f) 125 | 126 | # amaps = rebuild_AMaps_by_img(0, fragments_DB) 127 | 128 | # generate images with fragmentDB 129 | for epoch in range(max_epoch+16): 130 | print('epoch:', epoch) 131 | batchNum = 1 132 | 133 | if True: # for i, (gt_array, 
mask_array) in enumerate(zip(gt_sample, mask_sample)): 134 | i = fname_list.index(args.img_name) 135 | i, (gt_array, mask_array) = i, (gt_sample[i], mask_sample[i]) 136 | 137 | print('img:', batchNum) 138 | 139 | if i//per_part != part_id: 140 | part_id = i//per_part 141 | with open('DMB/{}.by_img.{}'.format(db_dataset, part_id), 'rb') as file: 142 | fragments_DB = pickle.load(file) 143 | 144 | zs = np.random.normal(0, 0.001, size=[1, z_size]).astype(np.float32) 145 | 146 | if epoch <= max_epoch: 147 | amaps, lesion_map = rebuild_AMaps_by_img(i, fragments_DB, lesion_map=True, quantity=epoch/max_epoch) 148 | else: 149 | amaps, lesion_map = rebuild_AMaps_by_img(i, fragments_DB, randomize=True, lesion_map=True, multiple=epoch-max_epoch+1) 150 | syn_array = sess.run(syn, feed_dict={gt: [gt_array], z:zs, mask: [mask_array], 151 | act_input[256]: [amaps[256]], 152 | act_input[64]: [amaps[64]]}) 153 | syn_array = (syn_array + 1) / 2 154 | syn_sample_m = syn_array * ((mask_array + 1) / 2) 155 | 156 | syn_sample_m = np.reshape(syn_sample_m, [img_x, img_y, img_channel]) 157 | save_images(np.reshape(syn_sample_m, [1, img_x, img_y, img_channel]), 158 | [1, 1], 159 | result_dir + '/{}_{}.jpg'.format(fname_list[i], epoch)) 160 | 161 | lesion_map = cv2.resize(lesion_map, (512, 512), interpolation=cv2.INTER_NEAREST) 162 | lesion_map = lesion_map * ((mask_array + 1) / 2) # crop by mask 163 | _, binary = cv2.threshold((((mask_array + 1) / 2) * 255).astype('uint8'), 0, 255, cv2.THRESH_BINARY) 164 | contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 165 | lesion_map = cv2.drawContours(lesion_map, contours, -1, (255, 255, 255)) 166 | cv2.imwrite(result_dir + '/{}_{}_lesion_map.jpg'.format(fname_list[i], epoch), 167 | lesion_map) 168 | 169 | batchNum += 1 170 | 171 | sess.close() 172 | -------------------------------------------------------------------------------- /data/IDRiD/convert.py: -------------------------------------------------------------------------------- 1 | """Resize and crop images to square, save as tiff.""" 2 | 3 | import os 4 | from multiprocessing.pool import Pool 5 | 6 | import click 7 | import numpy as np 8 | from PIL import Image, ImageFilter 9 | import cv2 10 | 11 | N_PROC = 20 12 | 13 | def convert(fname, crop_size, convert_fname): 14 | img = cv2.imread(fname) 15 | img_MA = cv2.imread(fname.replace('1. Original Images/b. Testing Set', '2. All Segmentation Groundtruths/b. Testing Set/1. Microaneurysms').replace('.jpg', '_MA.tif')) 16 | img_HE = cv2.imread(fname.replace('1. Original Images/b. Testing Set', '2. All Segmentation Groundtruths/b. Testing Set/2. Haemorrhages').replace('.jpg', '_HE.tif')) 17 | img_EX = cv2.imread(fname.replace('1. Original Images/b. Testing Set', '2. All Segmentation Groundtruths/b. Testing Set/3. Hard Exudates').replace('.jpg', '_EX.tif')) 18 | img_SE = cv2.imread(fname.replace('1. Original Images/b. Testing Set', '2. All Segmentation Groundtruths/b. Testing Set/4. Soft Exudates').replace('.jpg', '_SE.tif')) 19 | img_OD = cv2.imread(fname.replace('1. Original Images/b. Testing Set', '2. All Segmentation Groundtruths/b. Testing Set/5. 
Optic Disc').replace('.jpg', '_OD.tif')) 20 | img_MA = img_MA[..., 2] if img_MA is not None else img[..., 2] * 0 21 | img_HE = img_HE[..., 2] if img_HE is not None else img[..., 2] * 0 22 | img_EX = img_EX[..., 2] if img_EX is not None else img[..., 2] * 0 23 | img_SE = img_SE[..., 2] if img_SE is not None else img[..., 2] * 0 24 | img_OD = img_OD[..., 2] if img_OD is not None else img[..., 2] * 0 25 | 26 | ba = img 27 | h, w, _ = ba.shape 28 | 29 | if w > 1.2 * h: 30 | # to get the threshold, compute the maximum value of left and right 1/32-width part 31 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 32 | left_max = gray[:, : w // 32].max().astype(int) 33 | right_max = gray[:, - w // 32:].max().astype(int) 34 | max_bg = np.maximum(left_max, right_max) 35 | 36 | # print(max_bg) # TODO: DEBUG 37 | _, foreground = cv2.threshold(gray, max_bg + 20, 255, cv2.THRESH_BINARY) 38 | bbox = cv2.boundingRect(cv2.findNonZero(foreground)) # (x, y, width, height) 39 | 40 | if bbox is None: 41 | print('bbox none for {} (???)'.format(fname)) 42 | else: 43 | left, upper, width, height = bbox 44 | 45 | # if we selected less than 80% of the original 46 | # height, just crop the square 47 | if width < 0.8 * h or height < 0.8 * h: 48 | print('bbox too small for {}'.format(fname)) 49 | bbox = None 50 | else: 51 | bbox = None 52 | 53 | if bbox is None: 54 | bbox = square_bbox(w, h) 55 | print(bbox, fname) 56 | 57 | # do cropping 58 | left, upper, width, height = bbox 59 | img = img[upper:upper+height, left:left+width, ...] 60 | img_MA = img_MA[upper:upper+height, left:left+width] 61 | img_HE = img_HE[upper:upper+height, left:left+width] 62 | img_EX = img_EX[upper:upper+height, left:left+width] 63 | img_SE = img_SE[upper:upper+height, left:left+width] 64 | img_OD = img_OD[upper:upper+height, left:left+width] 65 | 66 | # padding 67 | if width != height: 68 | if width > height: 69 | pad_width = width - height 70 | pad = ((pad_width//2, pad_width-pad_width//2), (0, 0)) 71 | else: 72 | pad_width = height - width 73 | pad = ((0, 0), (pad_width // 2, pad_width - pad_width // 2)) 74 | img = np.pad(img, (pad[0], pad[1], (0,0)), 'constant', constant_values=0) 75 | img_MA = np.pad(img_MA, pad, 'constant', constant_values=0) 76 | img_HE = np.pad(img_HE, pad, 'constant', constant_values=0) 77 | img_EX = np.pad(img_EX, pad, 'constant', constant_values=0) 78 | img_SE = np.pad(img_SE, pad, 'constant', constant_values=0) 79 | img_OD = np.pad(img_OD, pad, 'constant', constant_values=0) 80 | 81 | # resizing 82 | img = cv2.resize(img, (crop_size, crop_size), interpolation=cv2.INTER_CUBIC) 83 | img_MA = cv2.resize(img_MA, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 84 | img_HE = cv2.resize(img_HE, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 85 | img_EX = cv2.resize(img_EX, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 86 | img_SE = cv2.resize(img_SE, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 87 | img_OD = cv2.resize(img_OD, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 88 | 89 | 90 | cv2.imwrite(convert_fname, img) 91 | cv2.imwrite(convert_fname.replace('.jpg', '_MA.png'), img_MA) 92 | cv2.imwrite(convert_fname.replace('.jpg', '_HE.png'), img_HE) 93 | cv2.imwrite(convert_fname.replace('.jpg', '_EX.png'), img_EX) 94 | cv2.imwrite(convert_fname.replace('.jpg', '_SE.png'), img_SE) 95 | cv2.imwrite(convert_fname.replace('.jpg', '_OD.png'), img_OD) 96 | 97 | 98 | def square_bbox(w, h): 99 | left = (w - h) // 2 100 | upper = 0 101 | right = left + h 102 | lower = h 103 | return 
(left, upper, right-left, lower-upper) 104 | 105 | 106 | def convert_square(fname, crop_size): 107 | img = Image.open(fname) 108 | bbox = square_bbox(img) 109 | cropped = img.crop(bbox) 110 | resized = cropped.resize([crop_size, crop_size]) 111 | return resized 112 | 113 | 114 | def get_convert_fname(fname, extension, directory, convert_directory): 115 | return fname.replace('jpg', extension).replace(directory, 116 | convert_directory) 117 | 118 | 119 | def process(args): 120 | fun, arg = args 121 | directory, convert_directory, fname, crop_size, extension = arg 122 | convert_fname = get_convert_fname(fname, extension, directory, 123 | convert_directory) 124 | if not os.path.exists(convert_fname): 125 | img = fun(fname, crop_size, convert_fname) 126 | 127 | 128 | def save(img, fname): 129 | img.save(fname, quality=97) 130 | 131 | @click.command() 132 | @click.option('--directory', default='IDRiD/A. Segmentation/1. Original Images/b. Testing Set', show_default=True, 133 | help="Directory with original images.") 134 | @click.option('--convert_directory', default='test_512/', show_default=True, 135 | help="Where to save converted images.") 136 | @click.option('--crop_size', default=512, show_default=True, 137 | help="Size of converted images.") 138 | @click.option('--extension', default='jpg', show_default=True, 139 | help="Filetype of converted images.") 140 | def main(directory, convert_directory, crop_size, extension): 141 | 142 | try: 143 | os.mkdir(convert_directory) 144 | except OSError: 145 | pass 146 | 147 | filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(directory) 148 | for f in fn if f.endswith('jpeg') or f.endswith('jpg') or f.endswith('png') or f.endswith('tiff')] 149 | filenames = sorted(filenames) 150 | 151 | print("Resizing images in {} to {}, this takes a while." 
152 | "".format(directory, convert_directory)) 153 | 154 | n = len(filenames) 155 | # process in batches, sometimes weird things happen with Pool on my machine 156 | batchsize = 20 157 | batches = n // batchsize + 1 158 | pool = Pool(N_PROC) 159 | 160 | args = [] 161 | 162 | for f in filenames: 163 | args.append((convert, (directory, convert_directory, f, crop_size, 164 | extension))) 165 | # break # TODO: Debug 166 | 167 | for i in range(batches): 168 | print("batch {:>2} / {}".format(i + 1, batches)) 169 | pool.map(process, args[i * batchsize: (i + 1) * batchsize]) 170 | 171 | pool.close() 172 | 173 | print('done') 174 | 175 | if __name__ == '__main__': 176 | main() 177 | -------------------------------------------------------------------------------- /data/FGADR/convert.py: -------------------------------------------------------------------------------- 1 | """Resize and crop images to square, save as tiff.""" 2 | 3 | import os 4 | from multiprocessing.pool import Pool 5 | 6 | import click 7 | import numpy as np 8 | from PIL import Image, ImageFilter 9 | import cv2 10 | 11 | N_PROC = 20 12 | 13 | def convert(fname, crop_size, convert_fname): 14 | img = cv2.imread(fname) 15 | img_MA = cv2.imread(fname.replace('Original_Images', 'Microaneurysms_Masks')) 16 | img_HE = cv2.imread(fname.replace('Original_Images', 'Hemohedge_Masks')) 17 | img_EX = cv2.imread(fname.replace('Original_Images', 'HardExudate_Masks')) 18 | img_SE = cv2.imread(fname.replace('Original_Images', 'SoftExudate_Masks')) 19 | img_IM = cv2.imread(fname.replace('Original_Images', 'IRMA_Masks')) 20 | img_NE = cv2.imread(fname.replace('Original_Images', 'Neovascularization_Masks')) 21 | img_MA = img_MA[..., 2] if img_MA is not None else img[..., 2] * 0 22 | img_HE = img_HE[..., 2] if img_HE is not None else img[..., 2] * 0 23 | img_EX = img_EX[..., 2] if img_EX is not None else img[..., 2] * 0 24 | img_SE = img_SE[..., 2] if img_SE is not None else img[..., 2] * 0 25 | img_IM = img_IM[..., 2] if img_IM is not None else img[..., 2] * 0 26 | img_NE = img_NE[..., 2] if img_NE is not None else img[..., 2] * 0 27 | 28 | ba = img 29 | h, w, _ = ba.shape 30 | 31 | if w > 1.2 * h: 32 | # to get the threshold, compute the maximum value of left and right 1/32-width part 33 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 34 | left_max = gray[:, : w // 32].max().astype(int) 35 | right_max = gray[:, - w // 32:].max().astype(int) 36 | max_bg = np.maximum(left_max, right_max) 37 | 38 | # print(max_bg) # TODO: DEBUG 39 | _, foreground = cv2.threshold(gray, max_bg + 20, 255, cv2.THRESH_BINARY) 40 | bbox = cv2.boundingRect(cv2.findNonZero(foreground)) # (x, y, width, height) 41 | 42 | if bbox is None: 43 | print('bbox none for {} (???)'.format(fname)) 44 | else: 45 | left, upper, width, height = bbox 46 | 47 | # if we selected less than 80% of the original 48 | # height, just crop the square 49 | if width < 0.8 * h or height < 0.8 * h: 50 | print('bbox too small for {}'.format(fname)) 51 | bbox = None 52 | else: 53 | bbox = None 54 | 55 | if bbox is None: 56 | bbox = square_bbox(w, h) 57 | 58 | # do croping 59 | left, upper, width, height = bbox 60 | img = img[upper:upper+height, left:left+width, ...] 
61 | img_MA = img_MA[upper:upper+height, left:left+width] 62 | img_HE = img_HE[upper:upper+height, left:left+width] 63 | img_EX = img_EX[upper:upper+height, left:left+width] 64 | img_SE = img_SE[upper:upper+height, left:left+width] 65 | img_IM = img_IM[upper:upper+height, left:left+width] 66 | img_NE = img_NE[upper:upper+height, left:left+width] 67 | 68 | #padding 69 | if width != height: 70 | if width > height: 71 | pad_width = width - height 72 | pad = ((pad_width//2, pad_width-pad_width//2), (0, 0)) 73 | else: 74 | pad_width = height - width 75 | pad = ((0, 0), (pad_width // 2, pad_width - pad_width // 2)) 76 | img = np.pad(img, (pad[0], pad[1], (0,0)), 'constant', constant_values=0) 77 | img_MA = np.pad(img_MA, pad, 'constant', constant_values=0) 78 | img_HE = np.pad(img_HE, pad, 'constant', constant_values=0) 79 | img_EX = np.pad(img_EX, pad, 'constant', constant_values=0) 80 | img_SE = np.pad(img_SE, pad, 'constant', constant_values=0) 81 | img_IM = np.pad(img_IM, pad, 'constant', constant_values=0) 82 | img_NE = np.pad(img_NE, pad, 'constant', constant_values=0) 83 | 84 | # resizing 85 | img = cv2.resize(img, (crop_size, crop_size), interpolation=cv2.INTER_CUBIC) 86 | img_MA = cv2.resize(img_MA, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 87 | img_HE = cv2.resize(img_HE, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 88 | img_EX = cv2.resize(img_EX, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 89 | img_SE = cv2.resize(img_SE, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 90 | img_IM = cv2.resize(img_IM, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 91 | img_NE = cv2.resize(img_NE, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 92 | 93 | convert_fname = convert_fname[:-6] + '.jpg' 94 | cv2.imwrite(convert_fname, img) 95 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 96 | _, binary = cv2.threshold(gray, 15, 255, cv2.THRESH_BINARY) 97 | cv2.imwrite(convert_fname.replace('.jpg', '_MASK.png'), binary) 98 | 99 | cv2.imwrite(convert_fname.replace('.jpg', '_MA.png'), img_MA) 100 | cv2.imwrite(convert_fname.replace('.jpg', '_HE.png'), img_HE) 101 | cv2.imwrite(convert_fname.replace('.jpg', '_EX.png'), img_EX) 102 | cv2.imwrite(convert_fname.replace('.jpg', '_SE.png'), img_SE) 103 | cv2.imwrite(convert_fname.replace('.jpg', '_IM.png'), img_IM) 104 | cv2.imwrite(convert_fname.replace('.jpg', '_NE.png'), img_NE) 105 | 106 | 107 | def square_bbox(w, h): 108 | left = (w - h) // 2 109 | upper = 0 110 | right = left + h 111 | lower = h 112 | return (left, upper, right-left, lower-upper) 113 | 114 | 115 | def convert_square(fname, crop_size): 116 | img = Image.open(fname) 117 | bbox = square_bbox(img) 118 | cropped = img.crop(bbox) 119 | resized = cropped.resize([crop_size, crop_size]) 120 | return resized 121 | 122 | 123 | def get_convert_fname(fname, extension, directory, convert_directory): 124 | return fname.replace('png', extension).replace(directory, 125 | convert_directory) 126 | 127 | 128 | def process(args): 129 | fun, arg = args 130 | directory, convert_directory, fname, crop_size, extension = arg 131 | convert_fname = get_convert_fname(fname, extension, directory, 132 | convert_directory) 133 | if not os.path.exists(convert_fname): 134 | img = fun(fname, crop_size, convert_fname) 135 | 136 | 137 | def save(img, fname): 138 | img.save(fname, quality=97) 139 | 140 | @click.command() 141 | @click.option('--directory', default='FGADR/Seg-set/Original_Images', show_default=True, 142 | help="Directory with original images.") 143 
| @click.option('--convert_directory', default='resized_512/', show_default=True, 144 | help="Where to save converted images.") 145 | @click.option('--crop_size', default=512, show_default=True, 146 | help="Size of converted images.") 147 | @click.option('--extension', default='jpg', show_default=True, 148 | help="Filetype of converted images.") 149 | def main(directory, convert_directory, crop_size, extension): 150 | 151 | try: 152 | os.mkdir(convert_directory) 153 | except OSError: 154 | pass 155 | 156 | filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(directory) 157 | for f in fn if f.endswith('jpeg') or f.endswith('jpg') or f.endswith('png') or f.endswith('tiff')] 158 | filenames = sorted(filenames) 159 | 160 | print("Resizing images in {} to {}, this takes a while." 161 | "".format(directory, convert_directory)) 162 | 163 | n = len(filenames) 164 | # process in batches, sometimes weird things happen with Pool on my machine 165 | batchsize = 20 166 | batches = n // batchsize + 1 167 | pool = Pool(N_PROC) 168 | 169 | args = [] 170 | 171 | for f in filenames: 172 | args.append((convert, (directory, convert_directory, f, crop_size, 173 | extension))) 174 | # break # TODO: Debug 175 | 176 | for i in range(batches): 177 | print("batch {:>2} / {}".format(i + 1, batches)) 178 | pool.map(process, args[i * batchsize: (i + 1) * batchsize]) 179 | 180 | pool.close() 181 | 182 | print('done') 183 | 184 | if __name__ == '__main__': 185 | main() 186 | -------------------------------------------------------------------------------- /StyleFeature.py: -------------------------------------------------------------------------------- 1 | # derived from He Zhao's version created at 3:40 PM, 23/3/17 2 | import tensorflow as tf 3 | import vgg 4 | import detector 5 | import numpy as np 6 | import scipy.io as sio 7 | from skimage import measure 8 | from scipy import ndimage 9 | import cv2 10 | 11 | 12 | STYLE_LAYERS = ('conv2d_1', 'conv2d_3') 13 | STYLE_LAYERS_SIZE = (256, 64) 14 | STYLE_LAYERS_CHANNELS = (32, 64) 15 | STYLE_LAYERS_MEAN = (2e-7, -2e-5) 16 | STYLE_LAYERS_STD = (0.05, 0.03) 17 | CONTENT_LAYER = ('relu4_2',) 18 | 19 | 20 | 21 | 22 | 23 | class Lesion: 24 | """Descriptor of a single Lesion""" 25 | def __init__(self): 26 | self.sty_bbox = (0, 0, 0, 0) # bbox(y1, y2, x1, x2) of the lesion in the style reference image 27 | self.inmask = None # a patch of the input_mask that include all the lesion region 28 | self.feature = {} # the feature map dictionary, {"layer_name": (bbox, activation, GRAM)} 29 | 30 | 31 | def gauss_kernel(l=5, sig=1.): 32 | """ 33 | creates gaussian kernel with side length l and a sigma of sig 34 | """ 35 | 36 | ax = np.arange(-l // 2 + 1., l // 2 + 1.) 37 | xx, yy = np.meshgrid(ax, ax) 38 | 39 | kernel = np.exp(-(xx**2 + yy**2) / (2. * sig**2)) 40 | 41 | return kernel / np.max(kernel) 42 | 43 | 44 | def projection_to_mask(projection): 45 | with tf.name_scope('proj_to_fmask'): 46 | # projection = tf.reduce_sum(projection, axis=-1, keep_dims=True) # fused across channels? 47 | projection = tf.abs(projection) 48 | projection = projection / (1e-20 + tf.reduce_max(projection)) 49 | projection = tf.maximum(projection, 0.) 50 | # projection = tf.sqrt(projection) # broaden? 51 | return projection 52 | 53 | 54 | def get_input_mask(projection, dilation=2, threshold=0.33): 55 | with tf.name_scope('input_mask'): 56 | projection = tf.reduce_sum(projection, axis=-1, keepdims=True) # fused across channels? 
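# (i.e., the multi-channel projection is summed into a single-channel map before thresholding)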
57 | projection = tf.abs(projection) 58 | projection = projection / (1e-20 + tf.reduce_max(projection)) 59 | projection = tf.maximum(projection, 0.) 60 | projection = tf.to_float(projection > threshold) 61 | r = dilation 62 | x, y = np.ogrid[-r:r+1, -r:r+1] 63 | filter = (x**2 + y**2 <= r**2).astype(np.float32)[..., np.newaxis] - 1 64 | projection = tf.nn.dilation2d(projection, filter, [1, 1, 1, 1], [1, 1, 1, 1], 'SAME') 65 | projection = projection 66 | return projection 67 | 68 | 69 | def get_lesion_descriptors(sty_inmask, label_mask=None): 70 | """return a descriptor frameworks 71 | sty_inmask: mask to restrict style loss 72 | label_mask: used to find bbox and separate lesions 73 | """ 74 | assert (sty_inmask.shape[0] == 448) 75 | assert (sty_inmask.shape[1] == 448) 76 | assert (label_mask.shape[0] == 448) 77 | assert (label_mask.shape[1] == 448) 78 | 79 | descriptors = [] 80 | 81 | if label_mask is None: 82 | print('Use sty_inmask as label_mask!') 83 | label_mask = sty_inmask 84 | 85 | sty_inmask_labeled = measure.label(label_mask > 0, connectivity=2) # find connected components 86 | assert sty_inmask_labeled.max() >= 1 87 | positions = ndimage.find_objects(sty_inmask_labeled) # get bboxes 88 | 89 | for i in range(sty_inmask_labeled.max()): 90 | lesion = Lesion() 91 | slice1, slice2 = positions[i] 92 | lesion.sty_bbox = (slice1.start, slice1.stop, slice2.start, slice2.stop) 93 | lesion.inmask = ((sty_inmask[slice1, slice2] > 0) & (sty_inmask_labeled[slice1, slice2] == i+1)).astype(np.float32) 94 | descriptors.append(lesion) 95 | 96 | return descriptors 97 | 98 | 99 | def fill_features_into_descriptors(descriptors, model, sess, feed_dict, activation_restriction=True): 100 | """use model to get style feature and put them in descriptors 101 | activation_restriction=True, use activation_map as a restriction 102 | activation_restriction=False, use simply in_mask to restrict 103 | """ 104 | for i, layer in enumerate(STYLE_LAYERS): 105 | features = model.outputs[i] 106 | activation = projection_to_mask(model.output_projection[i]) 107 | 108 | for j, lesion in enumerate(descriptors): # type:Lesion 109 | y1, y2, x1, x2 = lesion.sty_bbox 110 | input_mask = np.pad(lesion.inmask, [[y1, 448-y2], [x1, 448-x2]], mode='constant') 111 | scaled_inmask = cv2.resize(input_mask, tuple(model.outputs[i].shape.as_list()[1:3])) 112 | slice1, slice2 = ndimage.find_objects(scaled_inmask > 0)[0] 113 | local_bbox = (slice1.start, slice1.stop, slice2.start, slice2.stop) 114 | local_inmask = scaled_inmask[None, slice1, slice2, None] 115 | local_activation = activation[:, slice1, slice2, :] * local_inmask \ 116 | if activation_restriction else \ 117 | (tf.ones([1, slice1.stop-slice1.start, slice2.stop-slice2.start, 1]) * local_inmask * 1.e-1) 118 | local_features = features[:, slice1, slice2, :] * local_activation 119 | local_features = tf.reshape(local_features, shape=[-1, local_features.shape.as_list()[1] * local_features.shape.as_list()[2], 120 | local_features.shape.as_list()[3]])[0] 121 | local_features_T = tf.transpose(local_features) 122 | local_gram = tf.matmul(local_features_T, local_features) / float(local_features.shape.as_list()[0] * local_features.shape.as_list()[1]) 123 | lesion.feature[layer] = (local_bbox,) + sess.run((local_activation, local_gram), feed_dict=feed_dict) 124 | 125 | 126 | def get_style_model(image, mask, with_feature_mask_from=None): 127 | 128 | if mask is not None: 129 | image = (image+1)*((mask+1)/2)-1 130 | 131 | if image._shape_as_list()[1] != 448: 132 | image = 
tf.image.resize_images(image, [448,448]) 133 | 134 | model = detector.get_layers_model(image, STYLE_LAYERS + ('dense_3',), 'style_model', 135 | with_projection_output_from=with_feature_mask_from) 136 | return model 137 | 138 | 139 | def get_patho_loss(img_model, syn_model): 140 | return tf.reduce_mean(tf.square( 141 | img_model.get_layer('my_input').related_projection.output 142 | - syn_model.get_layer('my_input').related_projection.output 143 | )) # MSE 144 | 145 | 146 | def get_severity_loss(img_model, syn_model): 147 | return tf.reduce_mean(tf.square( 148 | img_model.get_layer('dense_3').output 149 | - syn_model.get_layer('dense_3').output 150 | )) # MSE 151 | 152 | 153 | def get_content_features(image, mask): 154 | image = tf.multiply(image + 1, 127.5) 155 | if mask is not None: 156 | image = image * ((mask + 1) / 2) 157 | 158 | img_features = {} 159 | 160 | if image._shape_as_list()[1] != 512: 161 | image = tf.image.resize_images(image, [512, 512]) 162 | 163 | # with tf.device('/cpu:0'): 164 | 165 | img_pre = vgg.preprocess(image) 166 | vgg_path = 'data/imagenet-vgg-verydeep-19.mat' 167 | data = sio.loadmat(vgg_path) 168 | net = vgg.net(data, img_pre) 169 | 170 | for layer in CONTENT_LAYER: 171 | features = net[layer] 172 | img_features[layer] = features 173 | 174 | return img_features 175 | 176 | 177 | def get_retinal_loss(img, syn, mask): 178 | 179 | img_features = get_content_features(img, mask) 180 | syn_features = get_content_features(syn, mask) 181 | 182 | content_lossE = 0 183 | for content_layer in CONTENT_LAYER: 184 | coff = float(1.0 / len(CONTENT_LAYER)) 185 | img_content = img_features[content_layer] 186 | syn_content = syn_features[content_layer] 187 | content_lossE += coff * tf.reduce_mean(tf.abs(img_content - syn_content)) 188 | 189 | content_loss = tf.reduce_mean(content_lossE) 190 | 191 | return content_loss 192 | 193 | 194 | def get_tv_loss(img, mask, input_mask=None): 195 | # mask: [-1, 1] 196 | # input_mask: [0, 1] 197 | 198 | img = img*((mask+1)/2) 199 | # x = tf.reduce_sum(tf.abs(img[:, 1:, :, :] - img[:, :-1, :, :])) 200 | # y = tf.reduce_sum(tf.abs(img[:, :, 1:, :] - img[:, :, :-1, :])) 201 | 202 | if input_mask is not None: 203 | x = tf.reduce_sum(input_mask[:, :-1, :, :] * tf.abs(img[:, 1:, :, :] - img[:, :-1, :, :])) / (1e-8 + 3*tf.reduce_sum(input_mask)) 204 | y = tf.reduce_sum(input_mask[:, :, :-1, :] * tf.abs(img[:, :, 1:, :] - img[:, :, :-1, :])) / (1e-8 + 3*tf.reduce_sum(input_mask)) 205 | else: 206 | x = tf.reduce_mean(tf.abs(img[:, 1:, :, :] - img[:, :-1, :, :])) 207 | y = tf.reduce_mean(tf.abs(img[:, :, 1:, :] - img[:, :, :-1, :])) 208 | 209 | return x+y 210 | 211 | 212 | -------------------------------------------------------------------------------- /data/retinal-lesions/convert.py: -------------------------------------------------------------------------------- 1 | """Resize and crop images to square, save as tiff.""" 2 | 3 | import os 4 | from multiprocessing.pool import Pool 5 | 6 | import click 7 | import numpy as np 8 | from PIL import Image, ImageFilter 9 | import cv2 10 | 11 | N_PROC = 20 12 | 13 | def convert(fname, crop_size, convert_fname): 14 | img = cv2.imread(fname) 15 | img_CW = cv2.imread(fname.replace('images_896x896', 'lesion_segs_896x896').replace('.jpg', '/cotton_wool_spots.png')) 16 | img_FP = cv2.imread(fname.replace('images_896x896', 'lesion_segs_896x896').replace('.jpg', '/fibrous_proliferation.png')) 17 | img_EX = cv2.imread(fname.replace('images_896x896', 'lesion_segs_896x896').replace('.jpg', '/hard_exudate.png')) 18 | 
img_MA = cv2.imread(fname.replace('images_896x896', 'lesion_segs_896x896').replace('.jpg', '/microaneurysm.png')) 19 | img_NS = cv2.imread(fname.replace('images_896x896', 'lesion_segs_896x896').replace('.jpg', '/neovascularization.png')) 20 | img_PH = cv2.imread(fname.replace('images_896x896', 'lesion_segs_896x896').replace('.jpg', '/preretinal_hemorrhage.png')) 21 | img_RH = cv2.imread(fname.replace('images_896x896', 'lesion_segs_896x896').replace('.jpg', '/retinal_hemorrhage.png')) 22 | img_VH = cv2.imread(fname.replace('images_896x896', 'lesion_segs_896x896').replace('.jpg', '/vitreous_hemorrhage.png')) 23 | 24 | img_CW = img_CW[..., 2] if img_CW is not None else img[..., 2] * 0 25 | img_FP = img_FP[..., 2] if img_FP is not None else img[..., 2] * 0 26 | img_EX = img_EX[..., 2] if img_EX is not None else img[..., 2] * 0 27 | img_MA = img_MA[..., 2] if img_MA is not None else img[..., 2] * 0 28 | img_NS = img_NS[..., 2] if img_NS is not None else img[..., 2] * 0 29 | img_PH = img_PH[..., 2] if img_PH is not None else img[..., 2] * 0 30 | img_RH = img_RH[..., 2] if img_RH is not None else img[..., 2] * 0 31 | img_VH = img_VH[..., 2] if img_VH is not None else img[..., 2] * 0 32 | 33 | ba = img 34 | h, w, _ = ba.shape 35 | 36 | if w > 1.2 * h: 37 | # to get the threshold, compute the maximum value of left and right 1/32-width part 38 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 39 | left_max = gray[:, : w // 32].max().astype(int) 40 | right_max = gray[:, - w // 32:].max().astype(int) 41 | max_bg = np.maximum(left_max, right_max) 42 | 43 | # print(max_bg) # TODO: DEBUG 44 | _, foreground = cv2.threshold(gray, max_bg + 20, 255, cv2.THRESH_BINARY) 45 | bbox = cv2.boundingRect(cv2.findNonZero(foreground)) # (x, y, width, height) 46 | 47 | if bbox is None: 48 | print('bbox none for {} (???)'.format(fname)) 49 | else: 50 | left, upper, width, height = bbox 51 | 52 | # if we selected less than 80% of the original 53 | # height, just crop the square 54 | if width < 0.8 * h or height < 0.8 * h: 55 | print('bbox too small for {}'.format(fname)) 56 | bbox = None 57 | else: 58 | bbox = None 59 | 60 | if bbox is None: 61 | bbox = square_bbox(w, h) 62 | 63 | # do croping 64 | left, upper, width, height = bbox 65 | img = img[upper:upper+height, left:left+width, ...] 
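# all eight lesion-type masks are cropped with the same bbox as the image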
66 | img_CW = img_CW[upper:upper+height, left:left+width] 67 | img_FP = img_FP[upper:upper+height, left:left+width] 68 | img_EX = img_EX[upper:upper+height, left:left+width] 69 | img_MA = img_MA[upper:upper+height, left:left+width] 70 | img_NS = img_NS[upper:upper+height, left:left+width] 71 | img_PH = img_PH[upper:upper+height, left:left+width] 72 | img_RH = img_RH[upper:upper+height, left:left+width] 73 | img_VH = img_VH[upper:upper+height, left:left+width] 74 | 75 | #padding 76 | if width != height: 77 | if width > height: 78 | pad_width = width - height 79 | pad = ((pad_width//2, pad_width-pad_width//2), (0, 0)) 80 | else: 81 | pad_width = height - width 82 | pad = ((0, 0), (pad_width // 2, pad_width - pad_width // 2)) 83 | img = np.pad(img, (pad[0], pad[1], (0,0)), 'constant', constant_values=0) 84 | img_CW = np.pad(img_CW, pad, 'constant', constant_values=0) 85 | img_FP = np.pad(img_FP, pad, 'constant', constant_values=0) 86 | img_EX = np.pad(img_EX, pad, 'constant', constant_values=0) 87 | img_MA = np.pad(img_MA, pad, 'constant', constant_values=0) 88 | img_NS = np.pad(img_NS, pad, 'constant', constant_values=0) 89 | img_PH = np.pad(img_PH, pad, 'constant', constant_values=0) 90 | img_RH = np.pad(img_RH, pad, 'constant', constant_values=0) 91 | img_VH = np.pad(img_VH, pad, 'constant', constant_values=0) 92 | 93 | # resizing 94 | img = cv2.resize(img, (crop_size, crop_size), interpolation=cv2.INTER_CUBIC) 95 | img_CW = cv2.resize(img_CW, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 96 | img_FP = cv2.resize(img_FP, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 97 | img_EX = cv2.resize(img_EX, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 98 | img_MA = cv2.resize(img_MA, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 99 | img_NS = cv2.resize(img_NS, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 100 | img_PH = cv2.resize(img_PH, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 101 | img_RH = cv2.resize(img_RH, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 102 | img_VH = cv2.resize(img_VH, (crop_size, crop_size), interpolation=cv2.INTER_NEAREST) 103 | 104 | 105 | cv2.imwrite(convert_fname, img) 106 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 107 | _, binary = cv2.threshold(gray, 15, 255, cv2.THRESH_BINARY) 108 | cv2.imwrite(convert_fname.replace('.jpg', '_MASK.png'), binary) 109 | 110 | cv2.imwrite(convert_fname.replace('.jpg', '_CW.png'), img_CW) 111 | cv2.imwrite(convert_fname.replace('.jpg', '_FP.png'), img_FP) 112 | cv2.imwrite(convert_fname.replace('.jpg', '_EX.png'), img_EX) 113 | cv2.imwrite(convert_fname.replace('.jpg', '_MA.png'), img_MA) 114 | cv2.imwrite(convert_fname.replace('.jpg', '_NS.png'), img_NS) 115 | cv2.imwrite(convert_fname.replace('.jpg', '_PH.png'), img_PH) 116 | cv2.imwrite(convert_fname.replace('.jpg', '_RH.png'), img_RH) 117 | cv2.imwrite(convert_fname.replace('.jpg', '_VH.png'), img_VH) 118 | 119 | 120 | def square_bbox(w, h): 121 | left = (w - h) // 2 122 | upper = 0 123 | right = left + h 124 | lower = h 125 | return (left, upper, right-left, lower-upper) 126 | 127 | 128 | def convert_square(fname, crop_size): 129 | img = Image.open(fname) 130 | bbox = square_bbox(img) 131 | cropped = img.crop(bbox) 132 | resized = cropped.resize([crop_size, crop_size]) 133 | return resized 134 | 135 | 136 | def get_convert_fname(fname, extension, directory, convert_directory): 137 | return fname.replace('jpg', extension).replace(directory, 138 | convert_directory) 139 | 140 | 141 | def 
141 | def process(args):
142 |     fun, arg = args
143 |     directory, convert_directory, fname, crop_size, extension = arg
144 |     convert_fname = get_convert_fname(fname, extension, directory,
145 |                                       convert_directory)
146 |     if not os.path.exists(convert_fname):
147 |         fun(fname, crop_size, convert_fname)  # convert() writes its outputs itself
148 |
149 |
150 | def save(img, fname):
151 |     img.save(fname, quality=97)
152 |
153 | @click.command()
154 | @click.option('--directory', default='retinal-lesions/images_896x896', show_default=True,
155 |               help="Directory with original images.")
156 | @click.option('--convert_directory', default='resized_512/', show_default=True,
157 |               help="Where to save converted images.")
158 | @click.option('--crop_size', default=512, show_default=True,
159 |               help="Size of converted images.")
160 | @click.option('--extension', default='jpg', show_default=True,
161 |               help="Filetype of converted images.")
162 | def main(directory, convert_directory, crop_size, extension):
163 |
164 |     try:
165 |         os.mkdir(convert_directory)
166 |     except OSError:
167 |         pass
168 |
169 |     filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(directory)
170 |                  for f in fn if f.endswith(('jpeg', 'jpg', 'png', 'tiff'))]
171 |     filenames = sorted(filenames)
172 |
173 |     print("Resizing images in {} to {}, this takes a while."
174 |           "".format(directory, convert_directory))
175 |
176 |     n = len(filenames)
177 |     # process in batches, sometimes weird things happen with Pool on my machine
178 |     batchsize = 20
179 |     batches = (n + batchsize - 1) // batchsize  # ceiling division
180 |     pool = Pool(N_PROC)
181 |
182 |     args = []
183 |
184 |     for f in filenames:
185 |         args.append((convert, (directory, convert_directory, f, crop_size,
186 |                                extension)))
187 |         # break  # TODO: Debug
188 |
189 |     for i in range(batches):
190 |         print("batch {:>2} / {}".format(i + 1, batches))
191 |         pool.map(process, args[i * batchsize: (i + 1) * batchsize])
192 |
193 |     pool.close()
194 |
195 |     print('done')
196 |
197 | if __name__ == '__main__':
198 |     main()
199 |
--------------------------------------------------------------------------------
/DMB_fragment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import glob
3 | import numpy as np
4 | from scipy import ndimage
5 | from random import randint
6 | from StyleFeature import STYLE_LAYERS, STYLE_LAYERS_SIZE, STYLE_LAYERS_CHANNELS
7 |
8 |
9 | palette = [
10 |     [182,182,254], [255,219,152], [168,255,153],
11 |
12 |     [ 31, 119, 180],
13 |     [255, 127, 14],
14 |     [ 44, 160, 44],
15 |     [214, 39, 40],
16 |     [148, 103, 189],
17 |     [140, 86, 75],
18 |     [227, 119, 194],
19 |     [127, 127, 127],
20 |     [188, 189, 34],
21 |     [ 23, 190, 207],
22 |
23 |     [161, 201, 244],
24 |     [255, 180, 130],
25 |     [141, 229, 161],
26 |     [255, 159, 155],
27 |     [208, 187, 255],
28 |     [222, 187, 155],
29 |     [250, 176, 228],
30 |     [207, 207, 207],
31 |     [255, 254, 163],
32 |     [185, 242, 240]
33 | ]
34 | palette = np.array(palette, dtype='uint8')
35 | np.random.seed(299792458)
36 | palette = np.concatenate([palette, np.random.randint(256, size=[1000,3], dtype='uint8')])
37 |
38 | def extract_descriptors(intermed_amap: dict, featmap: np.ndarray, seg_label: np.ndarray,
39 |                         dataset_name: str, img_id: int, fname_debug='debug/test.png'):
40 |     """Segment the lesions in featmap, predict each segment's label
41 |     from seg_label, and return the descriptors (fragments)
42 |     together with their attributes."""
43 |
44 |     height, width = featmap.shape[:2]
45 |
46 |     # to gray scale
47 |     gray = np.abs(featmap)
48 |     gray = np.mean(gray, axis=2)
49 |     gray = cv2.GaussianBlur(gray, (51, 51), 10)
50 |     gray = gray / gray.max()
51 |     cv2.imwrite(fname_debug.replace('.png', '_aSeg2.png'), gray * 255)
52 |
53 |     # to binary and segmentation
54 |     ret, binary = cv2.threshold((gray*255).astype('uint8'), 0, 255, cv2.THRESH_OTSU)
55 |     contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
56 |     img_contours = cv2.drawContours(binary*0, contours, -1, (255, 255, 255))
57 |     cv2.imwrite(fname_debug.replace('.png', '_aSeg3.png'), img_contours)
58 |
59 |     seg_label = (seg_label*255).astype('uint8')
60 |
61 |     # extract each fragment and its attributes
62 |     fragments = []
63 |     for i in range(len(contours)):
64 |         mask_i = cv2.drawContours(np.zeros([height, width], dtype=np.uint8),
65 |                                   contours, i, (255, 255, 255), cv2.FILLED)
66 |         cv2.imwrite(fname_debug.replace('.png', '_aSeg4.png'), mask_i)
67 |
68 |         # compute scale and location
69 |         original_scale = (mask_i > 0).sum() ** 0.5  # sqrt of the area
70 |         (x_min, y_min), = contours[i].min(axis=0)
71 |         (x_max, y_max), = contours[i].max(axis=0)
72 |
73 |         # prepare intermed_amap fragment
74 |         fragment_data = {}
75 |         for amap_size, amap_data in intermed_amap.items():
76 |             ratio = 512 // amap_size  # down-sampling ratio
77 |             mask_frag = cv2.resize(mask_i, (amap_size, amap_size), interpolation=cv2.INTER_LINEAR)
78 |
79 |             fragment = amap_data[0].copy()  # fragment.shape = [256, 256, 32] or [64, 64, 64]
80 |             fragment[mask_frag==0] = 0.
81 |             fragment = fragment[y_min//ratio:y_max//ratio+1, x_min//ratio:x_max//ratio+1]
82 |             # cv2.imwrite(fname_debug.replace('.png', '_aSeg6.png'), (fragment[..., :3]/0.05+0.5)*255)
83 |             fragment_data[amap_size] = fragment
84 |
85 |         # debug: visualize seg_label
86 |         # vis = np.zeros([height, width, 3], dtype='uint8')
87 |         # vis[..., 2] = vis[..., 0] | seg_label[..., 0]
88 |         # vis[..., 1] = vis[..., 1] | seg_label[..., 1]
89 |         # vis[..., 1:] = vis[..., 1:] | seg_label[..., 2:3]
90 |         # vis[..., :2] = vis[..., :2] | seg_label[..., 3:]
91 |         # cv2.imwrite(fname_debug.replace('.png', '_aSeg5.png'), vis)
92 |
93 |         # predict the fragment's label from the overlapping ground-truth segmentation
94 |         overlap = (seg_label & mask_i[..., None])
95 |         overlap_score = (overlap>0).sum(axis=(0,1))
96 |         predict_label = np.argmax(overlap_score)
97 |         if overlap_score[predict_label] == 0:
98 |             predict_label = -1  # no overlap with any known lesion class
99 |
100 |
101 |         fragments.append((x_min, y_min, x_max, y_max, original_scale,  # location, scale
102 |                           fragment_data,  # data
103 |                           dataset_name,
104 |                           img_id,
105 |                           predict_label,  # label
106 |                           ))
107 |
108 |     # debug: visualize with labels
109 |     # labels = ['MA','HE','EX','SE','UNK',]
110 |     # colors = np.random.randint(0, 256, [len(contours), 3])
111 |     # colors[17] = [182,182,254]; colors[11] = [255,219,152]; colors[1] = [168,255,153]
112 |     # img_contours = cv2.imread('/home/nyh/tllt/IDRiD/test_512/IDRiD_69.jpg')  # np.zeros([512, 512, 3], 'uint8')
113 |     # for i in range(len(contours)):
114 |     #     img_contours = cv2.drawContours(img_contours, contours, i, tuple(colors[i].tolist()), 2)
115 |     #     cv2.putText(img_contours,
116 |     #                 labels[fragments[i][-1]],
117 |     #                 (fragments[i][0], fragments[i][1]), cv2.FONT_HERSHEY_PLAIN, 1, tuple(colors[i].tolist()), 1)
118 |     # cv2.imwrite(fname_debug.replace('.png', '_aSeg3.add.png'), img_contours)
119 |
120 |     return fragments
121 |
122 |
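# Illustrative sketch of how extract_descriptors is expected to be called (the
# variable names here are hypothetical; shapes follow the code above, which
# assumes 512x512 inputs and activation maps keyed by size, e.g. 256 and 64):
#
#     amaps = {256: act256, 64: act64}          # each act*: [1, size, size, channels]
#     frags = extract_descriptors(amaps,
#                                 featmap,      # [512, 512, C] projection features
#                                 seg_label,    # [512, 512, n_classes] binary masks
#                                 'IDRiD', img_id=0)
#     x_min, y_min, x_max, y_max, scale, data, dataset, iid, label = frags[0]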
123 | def rebuild_AMaps_by_cat(fragments, fragments_DB_by_cat, height=512, width=512, fname_debug='debug/test.png'):
124 |     amap = {size: np.zeros([size, size, n_channel], dtype='float32')
125 |             for size, n_channel in zip(STYLE_LAYERS_SIZE, STYLE_LAYERS_CHANNELS)}
126 |     for x, y, category, seed, scale, rotation in fragments:
127 |         seed = seed % len(fragments_DB_by_cat[category])
128 |         x_min, y_min, x_max, y_max, original_scale,\
129 |             feat_fragment, _, _, predict_label = fragments_DB_by_cat[category][seed]
130 |
131 |         for size in STYLE_LAYERS_SIZE:
132 |             ratio = height // size  # down-sampling ratio
133 |
134 |             # padding
135 |             padding = (y//ratio, max(0, height//ratio - (y//ratio + y_max//ratio - y_min//ratio + 1))), \
136 |                       (x//ratio, max(0, width//ratio - (x//ratio + x_max//ratio - x_min//ratio + 1))), \
137 |                       (0, 0)
138 |             feat_fragment_to_add = np.pad(feat_fragment[size], padding, 'constant')
139 |             feat_fragment_to_add = feat_fragment_to_add[:height//ratio, :width//ratio, :]
140 |             # cv2.imwrite(fname_debug.replace('.png', '_aSeg7.png'), (feat_fragment[size]/2+0.5) * 255)
141 |
142 |             # rotating
143 |             M_rot = cv2.getRotationMatrix2D(center=(x//ratio, y//ratio), angle=rotation, scale=scale)
144 |             feat_fragment_to_add = cv2.warpAffine(feat_fragment_to_add, M_rot, (width//ratio, height//ratio))  # dsize is (width, height)
145 |             # cv2.imwrite(fname_debug.replace('.png', '_aSeg8.png'), (feat_fragment[size] / 2 + 0.5) * 255)
146 |
147 |             amap[size] += feat_fragment_to_add
148 |
149 |     return amap
150 |
151 |
152 | def rebuild_AMaps_by_img(imgid, fragments_DB_by_img, height=512, width=512, fname_debug='debug/test.png',
153 |                          randomize=False, quantity=None, multiple=None, lesion_map=None):
154 |     """Reconstruct AMaps from the descriptors extracted from one reference image."""
155 |     amap = {size: np.zeros([size, size, n_channel], dtype='float32')
156 |             for size, n_channel in zip(STYLE_LAYERS_SIZE, STYLE_LAYERS_CHANNELS)}
157 |     fragments = fragments_DB_by_img[imgid]
158 |     if lesion_map is not None:  # set palette
159 |         my_palette = palette[:len(fragments)]
160 |         lesion_map = np.zeros([256, 256, 3], dtype='uint8')
161 |     if quantity is not None:  # manipulate lesion quantity (0.x times)
162 |         import random
163 |         index = random.sample(range(len(fragments)), int(quantity*len(fragments)))
164 |         if lesion_map is not None:
165 |             my_palette = my_palette[index]
166 |         fragments = [fragments[i] for i in index]
167 |     if multiple is not None:  # manipulate lesion quantity (n times)
168 |         fragments = fragments*multiple
169 |     for fid, fragment in enumerate(fragments):
170 |         x_min, y_min, x_max, y_max, original_scale, \
171 |             feat_fragment, _, _, predict_label = fragment
172 |
173 |         if randomize:
174 |             new_x, new_y = 8*np.random.randint(0, 64), 8*np.random.randint(0, 64)
175 |             x_min, x_max = new_x, new_x + x_max - x_min
176 |             y_min, y_max = new_y, new_y + y_max - y_min
177 |
178 |         for size in STYLE_LAYERS_SIZE:
179 |             ratio = height // size
180 |
181 |             # padding
182 |             padding = (y_min//ratio, max(0, height//ratio - (y_max//ratio + 1))), \
183 |                       (x_min//ratio, max(0, width//ratio - (x_max//ratio + 1))), \
184 |                       (0, 0)
185 |             feat_fragment_to_add = np.pad(feat_fragment[size], padding, 'constant')
186 |             feat_fragment_to_add = feat_fragment_to_add[:height//ratio, :width//ratio, :]
187 |             # cv2.imwrite(fname_debug.replace('.png', '_aSeg7.png'), (feat_fragment[size]/2+0.5) * 255)
188 |
189 |             amap[size] += feat_fragment_to_add
190 |
191 |             if lesion_map is not None and size == 256:
192 |                 lesion_map = lesion_map | (
193 |                     my_palette[fid % len(my_palette)] &
194 |                     (255*np.any(feat_fragment_to_add != 0, axis=2, keepdims=True)).astype('uint8')
195 |                 )
196 |
197 |     ## DEBUG
198 |     # def reg(x):
199 |     #     return (x-x.min())/(x.max()-x.min())
200 |     # cv2.imwrite(fname_debug.replace('.png', '_aSeg8.256.{}.png'.format(imgid)), reg(np.sum(amap[256], -1)) * 255)
201 |     # cv2.imwrite(fname_debug.replace('.png', '_aSeg8.64.{}.png'.format(imgid)), reg(np.sum(amap[64], -1)) * 255)
202 |     if lesion_map is not None:
203 |         return amap, lesion_map
204 |     return amap
205 |
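# Illustrative sketch (hypothetical names; `frag_db` stands for an unpickled
# fragments_DB_by_img mapping, presumably built by DMB_build.py) of the
# manipulation knobs exposed by rebuild_AMaps_by_img:
#
#     amap = rebuild_AMaps_by_img(3, frag_db)                   # faithful reconstruction
#     amap = rebuild_AMaps_by_img(3, frag_db, randomize=True)   # shuffle lesion positions
#     amap = rebuild_AMaps_by_img(3, frag_db, quantity=0.5)     # keep about half the lesions
#     amap, vis = rebuild_AMaps_by_img(3, frag_db, multiple=2,  # duplicate every lesion and
#                                      lesion_map=True)         # also get a 256x256 color map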
--------------------------------------------------------------------------------
/dataBlocks.py:
--------------------------------------------------------------------------------
1 | # author: He Zhao
2 | # created: 2:18 PM, 7/12/16
3 |
4 | import numpy as np
5 | from multiprocessing import Process, Queue
6 |
7 |
8 | class DataIterator(object):
9 |     def __init__(self, *data, **params):
10 |         '''
11 |         PARAMS:
12 |             fullbatch (bool): decides whether every iteration must return a
13 |                 full batch (when True, a short final batch is dropped).
14 |         '''
15 |         self.data = data
16 |         self.batchsize = params['batchsize']
17 |         if 'fullbatch' in params:
18 |             self.fullbatch = params['fullbatch']
19 |         else:
20 |             self.fullbatch = False
21 |
22 |     def __iter__(self):
23 |         self.first = 0
24 |         return self
25 |
26 |     def __len__(self):
27 |         return len(self.data[0])
28 |
29 |     def __getitem__(self, key):
30 |         outs = []
31 |         for val in self.data:
32 |             outs.append(val[key])
33 |         return self.__class__(*outs, batchsize=self.batchsize, fullbatch=self.fullbatch)
34 |
35 |
36 | class SequentialIterator(DataIterator):
37 |     '''
38 |     batchsize = 3
39 |     [0, 1, 2], [3, 4, 5], [6, 7, 8]
40 |     '''
41 |     def __next__(self):
42 |         if self.fullbatch and self.first+self.batchsize > len(self):
43 |             raise StopIteration()
44 |         elif self.first >= len(self):
45 |             raise StopIteration()
46 |
47 |         outs = []
48 |         for val in self.data:
49 |             outs.append(val[self.first:self.first+self.batchsize])
50 |         self.first += self.batchsize
51 |         return outs
52 |
53 |
54 | class StepIterator(DataIterator):
55 |     '''
56 |     batchsize = 3
57 |     step = 1
58 |     [0, 1, 2], [1, 2, 3], [2, 3, 4]
59 |     '''
60 |     def __init__(self, *data, **params):
61 |         super(StepIterator, self).__init__(*data, **params)
62 |         self.step = params['step']
63 |
64 |     def __next__(self):
65 |         if self.fullbatch and self.first+self.batchsize > len(self):
66 |             raise StopIteration()
67 |         elif self.first >= len(self):
68 |             raise StopIteration()
69 |
70 |         outs = []
71 |         for val in self.data:
72 |             outs.append(val[self.first:self.first+self.batchsize])
73 |         self.first += self.step
74 |         return outs
75 |
76 |
77 | def np_load_func(path):
78 |     arr = np.load(path)
79 |     return arr
80 |
81 |
82 | class DataBlocks(object):  # NOTE: shadowed by the DataBlocks(SimpleBlocks) redefined at the end of this file
83 |
84 |     def __init__(self, data_paths, batchsize=32, load_func=np_load_func, allow_preload=False):
85 |         """
86 |         DESCRIPTION:
87 |             This is a class for processing blocks of data, whereby the dataset is
88 |             loaded and unloaded into memory one block at a time.
89 |         PARAM:
90 |             data_paths (list or list of list): contains the list of paths for data loading,
91 |                 example:
92 |                 [f1a.npy, f1b.npy, f1c.npy] or
93 |                 [(f1a.npy, f1b.npy, f1c.npy), (f2a.npy, f2b.npy, f2c.npy)]
94 |             load_func (function): function for loading the data_paths, defaults to
95 |                 the numpy file loader
96 |             allow_preload (bool): if True, the next data block is preloaded
97 |                 while training on the current data block; this trades extra
98 |                 memory for reduced waiting time.
99 | """ 100 | 101 | assert isinstance(data_paths, (list)), "data_paths is not a list" 102 | self.data_paths = data_paths 103 | self.batchsize = batchsize 104 | self.load_func = load_func 105 | self.allow_preload = allow_preload 106 | self.q = Queue() 107 | 108 | 109 | def __iter__(self): 110 | self.files = iter(self.data_paths) 111 | if self.allow_preload: 112 | self.lastblock = False 113 | bufile = next(self.files) 114 | self.load_file(bufile, self.q) 115 | return self 116 | 117 | 118 | def __next__(self): 119 | if self.allow_preload: 120 | if self.lastblock: 121 | raise StopIteration 122 | 123 | try: 124 | arr = self.q.get(block=True, timeout=None) 125 | self.iterator = SequentialIterator(*arr, batchsize=self.batchsize) 126 | bufile = next(self.files) 127 | p = Process(target=self.load_file, args=(bufile, self.q)) 128 | p.start() 129 | except: 130 | self.lastblock = True 131 | else: 132 | fpath = next(self.files) 133 | arr = self.load_file(fpath) 134 | self.iterator = SequentialIterator(*arr, batchsize=self.batchsize) 135 | 136 | return self.iterator 137 | 138 | 139 | def load_file(self, paths, queue=None): 140 | ''' 141 | paths (list or str): [] 142 | ''' 143 | data = [] 144 | if isinstance(paths, (list, tuple)): 145 | for path in paths: 146 | data.append(self.load_func(path)) 147 | else: 148 | data.append(self.load_func(paths)) 149 | if queue: 150 | queue.put(data) 151 | return data 152 | 153 | 154 | @property 155 | def nblocks(self): 156 | return len(self.data_paths) 157 | 158 | 159 | class SimpleBlocks(object): 160 | 161 | def __init__(self, data_paths, batchsize=32, load_func=np_load_func, allow_preload=False): 162 | """ 163 | DESCRIPTION: 164 | This is class for processing blocks of data, whereby dataset is loaded 165 | and unloaded into memory one block at a time. 166 | PARAM: 167 | data_paths (list or list of list): contains list of paths for data loading, 168 | example: 169 | [f1a.npy, f2a.npy, f3a.npy] ==> 1 col, 3 blocks or 170 | [(f1a.npy, f1b.npy, f1c.npy), (f2a.npy, f2b.npy, f2c.npy)] ==> 3 cols, 2 blocks 171 | load_func (function): function for loading the data_paths, default to 172 | numpy file loader 173 | allow_preload (bool): by allowing preload, it will preload the next data block 174 | while training at the same time on the current datablock, 175 | this will reduce time but will also cost more memory. 
176 | """ 177 | 178 | assert isinstance(data_paths, (list)), "data_paths is not a list" 179 | self.data_paths = data_paths 180 | self.batchsize = batchsize 181 | self.load_func = load_func 182 | self.allow_preload = allow_preload 183 | self.q = Queue() 184 | 185 | 186 | def __iter__(self): 187 | self.files = iter(self.data_paths) 188 | if self.allow_preload: 189 | self.lastblock = False 190 | bufile = next(self.files) 191 | self.load_file(bufile, self.q) 192 | return self 193 | 194 | 195 | def __next__(self): 196 | if self.allow_preload: 197 | if self.lastblock: 198 | raise StopIteration 199 | 200 | try: 201 | arr = self.q.get(block=True, timeout=None) 202 | self.iterator = SequentialIterator(*arr, batchsize=self.batchsize) 203 | bufile = next(self.files) 204 | p = Process(target=self.load_file, args=(bufile, self.q)) 205 | p.start() 206 | except: 207 | self.lastblock = True 208 | else: 209 | fpath = next(self.files) 210 | arr = self.load_file(fpath) 211 | self.iterator = SequentialIterator(*arr, batchsize=self.batchsize) 212 | 213 | return self.iterator 214 | 215 | 216 | def load_file(self, paths, queue=None): 217 | ''' 218 | paths (list or str): [] 219 | ''' 220 | data = [] 221 | if isinstance(paths, (list, tuple)): 222 | for path in paths: 223 | data.append(self.load_func(path)) 224 | else: 225 | data.append(self.load_func(paths)) 226 | if queue: 227 | queue.put(data) 228 | return data 229 | 230 | 231 | @property 232 | def nblocks(self): 233 | return len(self.data_paths) 234 | 235 | 236 | class DataBlocks(SimpleBlocks): 237 | 238 | def __init__(self, data_paths, train_valid_ratio=[5,1], batchsize=32, load_func=np_load_func, allow_preload=False): 239 | """ 240 | DESCRIPTION: 241 | This is class for processing blocks of data, whereby dataset is loaded 242 | and unloaded into memory one block at a time. 243 | PARAM: 244 | data_paths (list or list of list): contains list of paths for data loading, 245 | example: 246 | [f1a.npy, f1b.npy, f1c.npy] or 247 | [(f1a.npy, f1b.npy, f1c.npy), (f2a.npy, f2b.npy, f2c.npy)] 248 | load_func (function): function for loading the data_paths, default to 249 | numpy file loader 250 | allow_preload (bool): by allowing preload, it will preload the next data block 251 | while training at the same time on the current datablock, 252 | this will reduce time but will also cost more memory. 
253 | """ 254 | 255 | assert isinstance(data_paths, (list)), "data_paths is not a list" 256 | self.data_paths = data_paths 257 | self.train_valid_ratio = train_valid_ratio 258 | self.batchsize = batchsize 259 | self.load_func = load_func 260 | self.allow_preload = allow_preload 261 | self.q = Queue() 262 | 263 | 264 | def __next__(self): 265 | if self.allow_preload: 266 | if self.lastblock: 267 | raise StopIteration 268 | 269 | try: 270 | train, valid = self.q.get(block=True, timeout=None) 271 | self.train_iterator = SequentialIterator(*train, batchsize=self.batchsize) 272 | self.valid_iterator = SequentialIterator(*valid, batchsize=self.batchsize) 273 | bufile = next(self.files) 274 | p = Process(target=self.load_file, args=(bufile, self.q)) 275 | p.start() 276 | except: 277 | self.lastblock = True 278 | else: 279 | fpath = next(self.files) 280 | train, valid = self.load_file(fpath) 281 | self.train_iterator = SequentialIterator(*train, batchsize=self.batchsize, fullbatch=True) 282 | self.valid_iterator = SequentialIterator(*valid, batchsize=self.batchsize, fullbatch=True) 283 | return self.train_iterator, self.valid_iterator 284 | 285 | 286 | def load_file(self, paths, queue=None): 287 | ''' 288 | paths (list or str): [] 289 | ''' 290 | train = [] 291 | valid = [] 292 | if isinstance(paths, (list, tuple)): 293 | for path in paths: 294 | X = self.load_func(path) 295 | num_train = len(X) * self.train_valid_ratio[0] * 1.0 / sum(self.train_valid_ratio) 296 | num_train = int(num_train) 297 | train.append(X[:num_train]) 298 | valid.append(X[num_train:]) 299 | else: 300 | X = self.load_func(paths) 301 | # np.random.shuffle(X) 302 | num_train = len(X) * self.train_valid_ratio[0] * 1.0 / sum(self.train_valid_ratio) 303 | num_train = int(num_train) 304 | train.append(X[:num_train]) 305 | valid.append(X[num_train:]) 306 | data = [train, valid] 307 | if queue: 308 | queue.put(data) 309 | return data 310 | -------------------------------------------------------------------------------- /Train.py: -------------------------------------------------------------------------------- 1 | # derived from He Zhao's version created at 10:23 AM, 29/11/16 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('dataset_name') 6 | parser.add_argument('test1') 7 | parser.add_argument('test2') 8 | parser.add_argument('test3') 9 | parser.add_argument('test4') 10 | parser.add_argument('--gpus') 11 | args = parser.parse_args() 12 | print(args) 13 | 14 | import os 15 | if args.gpus: 16 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 17 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 18 | print('CUDA_VISIBLE_DEVICES:', os.environ["CUDA_VISIBLE_DEVICES"]) 19 | 20 | import keras.backend as K 21 | import tensorflow as tf 22 | import detector 23 | from tensorflow.python.training import training_util 24 | import Net 25 | import numpy as np 26 | import os 27 | import time 28 | import StyleFeature 29 | import scipy.io as sio 30 | from scipy import ndimage 31 | import imageio 32 | import utils 33 | import pdb 34 | import cv2 35 | import threading 36 | import queue 37 | #import matplotlib.pyplot as plt 38 | import pickle 39 | import yaml 40 | 41 | from Opts import save_images, matTonpy 42 | # =============================== path set =============================================== # 43 | load_model = None # 'pe_nosty_base' # 'initial_model' 44 | save_model = False 45 | real_img_dataset = args.dataset_name #'retinal-lesions' # 'DRIVE' 46 | real_img_test_dataset = args.dataset_name #'retinal-lesions' 
59 |
60 |
61 | # ============================== model set ========================================== #
62 | model_name = args.dataset_name
63 |
64 | result_dir = 'Model_and_Result' + '/' + model_name
65 | sample_directory = result_dir + '/figs'  # directory to save sample images from the generator
66 | sample_directory2 = result_dir + '/figs_mask'
67 | summary_dir = result_dir + '/summary'
68 |
69 | model_directory = result_dir + '/models'  # directory to save the trained model to
70 |
71 | print('Experiment name:', model_name)
72 |
73 | if tf.gfile.Exists(result_dir):
74 |     print('Result dir exists! Press Enter to OVERRIDE...', end='')
75 |     input()
76 |     tf.gfile.DeleteRecursively(result_dir)
77 | if not os.path.exists(sample_directory):
78 |     os.makedirs(sample_directory)
79 | if not os.path.exists(sample_directory2):
80 |     os.makedirs(sample_directory2)
81 | if not os.path.exists(summary_dir):
82 |     os.makedirs(summary_dir)
83 | if not os.path.exists(model_directory):
84 |     os.makedirs(model_directory)
85 |
86 | os.system('cp {} {}'.format(__file__, result_dir))
87 | os.system('cp {} {}'.format('Net.py', result_dir))
88 | os.system('cp {} {}'.format('Opts.py', result_dir))
89 | os.system('cp {} {}'.format('StyleFeature.py', result_dir))
90 |
91 | with open(model_directory + '/training_log.txt', 'w'):
92 |     pass  # just create an empty training log
93 | # ============================== parameters set ========================================== #
94 |
95 | learning_rate = 0.0002 / 5
96 | beta1 = 0.5
97 |
98 | batch_size = 1  # Size of image batch to apply at each iteration.
99 | max_epoch = 1000 100 | 101 | img_channel = 3 102 | img_size = 512 103 | img_x = 512 104 | img_y = 512 105 | padding_l = 0 106 | padding_r = 0 107 | padding_t = 0 108 | padding_d = 0 109 | gt_channel = 1 110 | 111 | style_size = 512 112 | 113 | sample_batch = 4 114 | z_size = 400 115 | 116 | 117 | 118 | # =============================== model and data definition ================================ # 119 | from Net import generator, discriminator, build_data 120 | 121 | tf.reset_default_graph() 122 | 123 | gt = tf.placeholder(shape=[None, img_size, img_size, gt_channel], dtype=tf.float32) 124 | img = tf.placeholder(shape=[None, img_size, img_size, img_channel], dtype=tf.float32) 125 | mask = tf.placeholder(shape=[None, img_size, img_size, 1], dtype=tf.float32) 126 | z = tf.placeholder(shape=[None, z_size], dtype=tf.float32) 127 | 128 | # gt_mask = tf.concat([gt, mask], 3) 129 | # gt_mask = tf.Print(gt_mask, [tf.reduce_any(tf.is_nan(gt_mask))], 'is_nan(gt_mask):') 130 | # zz = tf.Print(z, [tf.reduce_any(tf.is_nan(z))], 'is_nan(z):') 131 | 132 | img_detector = StyleFeature.get_style_model(img, mask, with_feature_mask_from=('dense_2', [])) 133 | act_input = { 134 | size: tf.image.resize_images((img_detector.get_layer(layer_name).related_projection.output - mean)/std, [size,size]) 135 | for layer_name, size, mean, std in zip(StyleFeature.STYLE_LAYERS, 136 | StyleFeature.STYLE_LAYERS_SIZE, 137 | StyleFeature.STYLE_LAYERS_MEAN, 138 | StyleFeature.STYLE_LAYERS_STD) 139 | } 140 | 141 | projection = (img_detector.get_layer('my_input').related_projection.output - 1e-6) / 0.05 142 | projection = tf.clip_by_value(projection, -0.5, 0.5) 143 | projection = projection * 2 144 | projection = tf.abs(projection) 145 | projection = tf.reduce_mean(projection, 3, keepdims=True) 146 | projection = tf.nn.conv2d(projection, StyleFeature.gauss_kernel(31, 10)[..., None, None], [1,1,1,1],'SAME') # gauss blur 147 | projection = projection / (tf.reduce_max(projection) + 1e-7) 148 | binary_mask = tf.py_func(lambda gray: cv2.threshold((gray[0, ..., 0]*255).astype('uint8'), 0, 255, cv2.THRESH_OTSU)[1][None, ..., None], 149 | [projection], tf.uint8) 150 | binary_mask256 = tf.cast(tf.image.resize_bilinear(binary_mask, [256, 256]) > 0, 'float32') 151 | binary_mask64 = tf.cast(tf.image.resize_bilinear(binary_mask, [64, 64]) > 0, 'float32') 152 | 153 | masked_act_input = { 154 | 256: binary_mask256 * act_input[256], 155 | 64: binary_mask64 * act_input[64], 156 | } 157 | 158 | syn = generator(gt, masked_act_input, z) 159 | # syn = tf.Print(syn, [tf.reduce_any(tf.is_nan(syn))], 'is_nan(syn):') 160 | 161 | real_img_gt = tf.concat([img*((mask+1)/2), gt], 3) 162 | fake_syn_gt = tf.concat([syn*((mask+1)/2), gt], 3) 163 | # real_img_gt = tf.Print(real_img_gt, [tf.reduce_any(tf.is_nan(real_img_gt))], 'is_nan(real_img_gt):') 164 | # fake_syn_gt = tf.Print(fake_syn_gt, [tf.reduce_any(tf.is_nan(fake_syn_gt))], 'is_nan(fake_syn_gt):') 165 | 166 | Dx, Dx_logits = discriminator(real_img_gt) 167 | # Dx = tf.Print(Dx, [tf.reduce_any(tf.is_nan(Dx))], 'is_nan(Dx):') 168 | Dg, Dg_logits = discriminator(fake_syn_gt, reuse=True) 169 | 170 | db = build_data(batch_size, dataset_name=real_img_dataset) 171 | 172 | syn_detector = StyleFeature.get_style_model(syn, mask, with_feature_mask_from=('dense_2', [])) 173 | 174 | # ============================================================================================# 175 | # discriminator loss 176 | with tf.name_scope('d_loss'): 177 | d_loss_real = 
tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=Dx_logits, labels=tf.ones_like(Dx)))
178 |     d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=Dg_logits, labels=tf.zeros_like(Dg)))
179 |     d_loss = d_loss_real + d_loss_fake
180 |
181 | # generator loss
182 | with tf.name_scope('g_loss'):
183 |
184 |     g_loss_adversarial = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=Dg_logits, labels=tf.ones_like(Dg)))
185 |
186 |     g_loss_patho = StyleFeature.get_patho_loss(img_detector, syn_detector)
187 |
188 |     g_loss_severity = StyleFeature.get_severity_loss(img_detector, syn_detector)
189 |
190 |     g_loss_retinal = StyleFeature.get_retinal_loss(img, syn, mask)
191 |
192 |     g_loss_tv = StyleFeature.get_tv_loss(syn, mask)  # + maskedtv_weight * StyleFeature.get_tv_loss(syn, mask, tv_mask)
193 |
194 |     g_loss = w_adv * g_loss_adversarial \
195 |            + w_patho * g_loss_patho \
196 |            + w_severity * g_loss_severity \
197 |            + w_retinal * g_loss_retinal \
198 |            + w_tv * g_loss_tv
199 |
200 |
201 | # split the trainable variables between the two networks
202 | t_vars = tf.trainable_variables()
203 | d_vars = list(set(t_vars) & set(tf.global_variables('discriminator')))
204 | g_vars = list(set(t_vars) & set(tf.global_variables('generator')))
205 |
206 | # optimizer
207 | global_step = tf.Variable(0, trainable=False)
208 | with tf.name_scope('train'):
209 |     d_optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate*0.4).minimize(d_loss, var_list=d_vars, global_step=global_step)
210 |     g_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)
211 |
212 | # =============================== summary prepare ============================================= #
213 |
214 | # write summary
215 | Dx_sum = tf.summary.histogram("Dx", Dx)
216 | Dg_sum = tf.summary.histogram("Dg", Dg)
217 |
218 | Dx_sum_scalar = tf.summary.scalar("Dx_value", tf.reduce_mean(Dx))
219 | Dg_sum_scalar = tf.summary.scalar("Dg_value", tf.reduce_mean(Dg))
220 |
221 | # syn_sum = tf.image_summary("synthesize", syn)
222 |
223 | d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
224 | d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
225 | d_loss_sum = tf.summary.scalar("d_loss", d_loss)
226 | g_loss_patho_sum = tf.summary.scalar("g_loss_patho", g_loss_patho)
227 | g_loss_severity_sum = tf.summary.scalar("g_loss_severity", g_loss_severity)
228 | g_loss_retinal_sum = tf.summary.scalar("g_loss_retinal", g_loss_retinal)
229 | g_loss_tv_sum = tf.summary.scalar("g_loss_tv", g_loss_tv)
230 | g_loss_adversarial_sum = tf.summary.scalar("g_loss_adversarial", g_loss_adversarial)
231 | g_loss_sum = tf.summary.scalar("g_loss", g_loss)
232 |
233 | # g_sum = tf.merge_summary([Dg_sum, syn_sum, d_loss_fake_sum, g_loss_sum])
234 | # d_sum = tf.merge_summary([Dx_sum, d_loss_real_sum, d_loss_fake_sum, d_loss_sum])
235 |
236 | # g_sum = tf.merge_summary([Dg_sum_scalar])
237 | # d_sum = tf.merge_summary([Dx_sum_scalar])
238 |
239 | sum_merged = tf.summary.merge_all()
240 | sum_writer = tf.summary.FileWriter(summary_dir)  # graph=tf.get_default_graph()
241 |
242 | # =============================== train phase ============================================= #
243 | init = tf.variables_initializer(d_vars + g_vars)
244 | sess = K.get_session()
245 | saver = tf.train.Saver(d_vars + g_vars, max_to_keep=None)
246 |
247 | sess.run(init)
248 | #writer = tf.train.SummaryWriter(model_directory, sess.graph)
249 |
250 | # ==================================== save initialization ================================ #
251 | if load_model:
252 |     ckpt = tf.train.get_checkpoint_state('Model_and_Result/' + load_model + '/models')
253 |     saver.restore(sess, ckpt.model_checkpoint_path)
254 |     # saver.save(sess, model_directory + '/model-' + str(0) + '.cptk')
255 |     print("Loaded saved model")
256 | elif save_model:
257 |     saver.save(sess, model_directory + '/model-' + str(0) + '.cptk')
258 |     print("Saved beginning model")
259 |
260 | def augmented(db, q):
261 |     for data_train, _ in db:
262 |         for batch in data_train:
263 |             xs, ys, ms = batch
264 |             for i in range(batch_size):
265 |                 M = cv2.getRotationMatrix2D((np.random.randint(240, 272), np.random.randint(240, 272)),
266 |                                             np.random.randint(-5, 5), np.random.uniform(0.97, 1.05))
267 |                 xs[i] = cv2.warpAffine(xs[i], M, (512, 512))
268 |                 ys[i] = cv2.warpAffine(ys[i], M, (512, 512))
269 |                 ms[i] = cv2.warpAffine(ms[i], M, (512, 512))
270 |             # import IPython
271 |             # IPython.embed()
272 |             q.put((xs, ys, ms))
273 |     q.put(None)
274 |
275 |
276 | with open('DMB/{}.test_amaps_reconstruction'.format(real_img_test_dataset), 'rb') as file:
277 |     test_amaps_reconstruction = pickle.load(file)
278 |
279 | with open('data/'+real_img_test_dataset+'_test.list', 'r') as f:
280 |     file_list = yaml.safe_load(f)
281 | select = [file_list.index(args.test1), file_list.index(args.test2), file_list.index(args.test3), file_list.index(args.test4)]
282 |
283 | img_sample = np.load('data/{}_test_image.npy'.format(real_img_test_dataset))[select, ...]
284 | gt_sample = np.load('data/{}_test_gt.npy'.format(real_img_test_dataset))[select, ..., [0]]
285 | mask_sample = np.load('data/{}_test_mask.npy'.format(real_img_test_dataset))[select, ...]
286 |
287 | img_sample = (np.reshape(img_sample, [-1, img_x, img_y, img_channel]) - 0.5) * 2.0
288 | gt_sample = (np.reshape(gt_sample, [-1, img_x, img_y, gt_channel]) - 0.5) * 2.0
289 | mask_sample = (np.reshape(mask_sample, [-1, img_x, img_y, 1]) - 0.5) * 2.0
290 |
291 |
292 | # ==================================== start training ===================================== #
293 | stime = time.time()
294 | q = queue.Queue()
295 | for epoch in range(max_epoch):
296 |     batchNum = 1
297 |
298 |     thread_aug = threading.Thread(target=augmented, args=[db, q])
299 |     thread_aug.daemon = True
300 |     thread_aug.start()
301 |
302 |     while True:
303 |         batch = q.get()
304 |         if batch is None:
305 |             break
306 |         else:
307 |             img_id = batchNum - 1
308 |
309 |             z_sample = np.random.normal(0, 0.001, size=[batch_size, z_size]).astype(np.float32)  # small-std z around the mean
310 |             zs = z_sample
311 |
312 |             xs, ys, ms = batch
313 |             if xs.shape[0] != batch_size:
314 |                 continue
315 |
316 |             # xs = np.transpose(xs, (0, 2, 3, 1))
317 |             xs = (np.reshape(xs, [batch_size, img_size, img_size, img_channel]) - 0.5) * 2.0
318 |             xs = np.lib.pad(xs, ((0, 0), (padding_l, padding_r), (padding_t, padding_d), (0, 0)), 'constant',
319 |                             constant_values=(-1, -1))  # pad the images (all paddings are 0 here)
320 |
321 |             # ys = np.transpose(ys, (0, 2, 3, 1))
322 |             ys = ys[..., [0]]
323 |             ys = (np.reshape(ys, [batch_size, img_size, img_size, gt_channel]) - 0.5) * 2.0
324 |             ys = np.lib.pad(ys, ((0, 0), (padding_l, padding_r), (padding_t, padding_d), (0, 0)), 'constant',
325 |                             constant_values=(-1, -1))  # pad the images (all paddings are 0 here)
326 |
327 |             ms = (np.reshape(ms, [batch_size, img_size, img_size, 1]) - 0.5) * 2.0
328 |             ms = np.lib.pad(ms, ((0, 0), (padding_l, padding_r), (padding_t, padding_d), (0, 0)), 'constant',
329 |                             constant_values=(-1, -1))
330 |
331 |             feed_dict = {img: xs, gt: 
ys, z: zs, mask: ms} 332 | # Update the discriminator 333 | _, dLoss = sess.run([d_optimizer, d_loss], feed_dict=feed_dict) 334 | 335 | # Update the generator, twice for good measure. 336 | _ = sess.run([g_optimizer], feed_dict=feed_dict) 337 | 338 | 339 | # _, gLoss, advL, pathoL, retinalL, tvL = sess.run([g_optimizer, g_loss, g_loss_adversarial, g_loss_patho, g_loss_retinal, g_loss_tv], feed_dict=feed_dict) 340 | _, gLoss, advL, pathoL, severityL, retinalL, tvL, summ = sess.run([g_optimizer, g_loss, g_loss_adversarial, g_loss_patho, g_loss_severity, g_loss_retinal, g_loss_tv, sum_merged], feed_dict=feed_dict) 341 | 342 | sum_writer.add_summary(summ, training_util.global_step(sess, global_step)) 343 | print("[Epoch: %2d.%2d / %2d] [%4d]G Loss: %.4f D Loss: %.4f, patho: %.4f, severity: %.4f, retinal: %.4f, adv: %.4f, tv: %.4f" \ 344 | % (epoch, img_id, max_epoch, batchNum, gLoss, dLoss, w_patho*pathoL, w_severity*severityL, w_retinal*retinalL, w_adv*advL, w_tv*tvL)) 345 | with open(model_directory + '/training_log.txt', 'a') as text_file: 346 | text_file.write( 347 | "[Epoch: %2d.%2d / %2d] [%4d]G Loss: %.4f D Loss: %.4f, patho: %.4f, severity: %.4f, retinal: %.4f, adv: %.4f, tv: %.4f \n" 348 | % (epoch, img_id, max_epoch, batchNum, gLoss, dLoss, w_patho*pathoL, w_severity*severityL, w_retinal*retinalL, w_adv*advL, w_tv*tvL)) 349 | batchNum += 1 350 | if training_util.global_step(sess, global_step) % 100 == 0: 351 | 352 | 353 | 354 | z1 = np.random.normal(0, 0.001, size=[1, z_size]).astype(np.float32) 355 | z2 = np.random.normal(0, 1.0, size=[1, z_size]).astype(np.float32) 356 | z3 = np.random.normal(0, 0.01, size=[1, z_size]).astype(np.float32) 357 | 358 | sa, sb, sc, sd = 0, 1, 2, 3 359 | 360 | 361 | syn_sample_a, dLreal_val_a, dLfake_val_a = sess.run([syn, Dx, Dg], 362 | feed_dict={img: [img_sample[sa]], gt: [gt_sample[sa]], z: z1, mask:[mask_sample[sa]], 363 | act_input[64]: [test_amaps_reconstruction[sa][64]], 364 | act_input[256]: [test_amaps_reconstruction[sa][256]] 365 | }) 366 | syn_sample_b, dLreal_val_b, dLfake_val_b = sess.run([syn, Dx, Dg], 367 | feed_dict={img: [img_sample[sb]], gt: [gt_sample[sb]], z: z3, mask:[mask_sample[sb]], 368 | act_input[64]: [test_amaps_reconstruction[sb][64]], 369 | act_input[256]: [test_amaps_reconstruction[sb][256]] 370 | }) 371 | syn_sample_c, dLreal_val_c, dLfake_val_c = sess.run([syn, Dx, Dg], 372 | feed_dict={img: [img_sample[sc]], gt: [gt_sample[sc]], z: z2, mask:[mask_sample[sc]], 373 | act_input[64]: [test_amaps_reconstruction[sc][64]], 374 | act_input[256]: [test_amaps_reconstruction[sc][256]] 375 | }) 376 | syn_sample_d, dLreal_val_d, dLfake_val_d = sess.run([syn, Dx, Dg], 377 | feed_dict={img: [img_sample[sd]], gt: [gt_sample[sd]], z: zs[:1], mask:[mask_sample[sd]], 378 | act_input[64]: [test_amaps_reconstruction[sd][64]], 379 | act_input[256]: [test_amaps_reconstruction[sd][256]] 380 | }) 381 | 382 | syn_sample = np.concatenate((syn_sample_a, syn_sample_b, syn_sample_c,syn_sample_d),axis=0) 383 | 384 | syn_sample_am = (syn_sample_a + 1) * ((mask_sample[sa] + 1) / 2) - 1 385 | syn_sample_bm = (syn_sample_b + 1) * ((mask_sample[sb] + 1) / 2) - 1 386 | syn_sample_cm = (syn_sample_c + 1) * ((mask_sample[sc] + 1) / 2) - 1 387 | syn_sample_dm = (syn_sample_d + 1) * ((mask_sample[sd] + 1) / 2) - 1 388 | syn_sample_m = np.concatenate((syn_sample_am, syn_sample_bm, syn_sample_cm, syn_sample_dm), axis=0) 389 | 390 | dLreal_val = (dLreal_val_a + dLreal_val_b + dLreal_val_c + dLreal_val_d) / 4 391 | dLfake_val = (dLfake_val_a + dLfake_val_b + 
dLfake_val_c + dLfake_val_d) / 4 392 | 393 | # Save sample generator images for viewing training progress. 394 | save_images(np.reshape(syn_sample, [sample_batch, img_x, img_y, img_channel]), 395 | [int(np.sqrt(sample_batch)), int(np.sqrt(sample_batch))], 396 | sample_directory + '/fig' + str(training_util.global_step(sess, global_step)) + '.png') 397 | 398 | save_images(np.reshape(syn_sample_m, [sample_batch, img_x, img_y, img_channel]), 399 | [int(np.sqrt(sample_batch)), int(np.sqrt(sample_batch))], 400 | sample_directory2 + '/fig' + str(training_util.global_step(sess, global_step)) + '.png') 401 | 402 | print("[Sample (global_step = %d)] real: %.4f fake: %.4f" \ 403 | % (training_util.global_step(sess, global_step), np.mean(dLreal_val), np.mean(dLfake_val))) 404 | with open(model_directory + '/training_log.txt', 'a') as text_file: 405 | text_file.write("[Sample (global_step = %d)] real: %.4f fake: %.4f \n" 406 | % (training_util.global_step(sess, global_step), np.mean(dLreal_val), 407 | np.mean(dLfake_val))) 408 | 409 | if training_util.global_step(sess, global_step) % 30000 == 0: 410 | saver.save(sess, 411 | model_directory + '/model-' + str(training_util.global_step(sess, global_step)) + '.cptk') 412 | print("Saved Model %d, time: %.4f" % (training_util.global_step(sess, global_step), time.time()-stime)) 413 | 414 | 415 | 416 | saver.save(sess, model_directory + '/model-' + str(training_util.global_step(sess, global_step)) + '.cptk') 417 | print("Saved Model %d, time: %.4f" % (training_util.global_step(sess, global_step), time.time()-stime)) 418 | 419 | sess.close() 420 | 421 | --------------------------------------------------------------------------------