├── .gitignore ├── README.md ├── README_HPL.md ├── README_additional_cohort.md ├── README_bu.md ├── README_replication.md ├── data_manipulation ├── __init__.py ├── data.py ├── dataset.py ├── preprocessor.py └── utils.py ├── demos ├── HPL_visualizer_mutlticancer.gif ├── framework_methodology.jpg └── slides │ └── HPL Summary.pdf ├── models ├── activations.py ├── clustering │ ├── correlations.py │ ├── cox_proportional_hazard_regression_leiden_clusters.py │ ├── data_processing.py │ ├── evaluation_metrics.py │ ├── leiden_representations.py │ ├── leiden_representations_fold.py │ └── logistic_regression_leiden_clusters.py ├── clustering_utils.py ├── data_augmentation.py ├── evaluation │ ├── evaluation.py │ ├── features.py │ ├── folds.py │ ├── latent_space.py │ ├── metrics.py │ ├── prognosis.py │ └── tools.py ├── loss.py ├── mil │ ├── Attention_MIL.py │ ├── Attention_MIL_TCGA_COAD.py │ ├── Attention_MIL_TCGA_LUAD.py │ ├── Attention_MIL_TCGA_LUAD_Histo.py │ ├── Attention_MIL_TCGA_LUAD_MultiMagFlat.py │ ├── Attention_MIL_TCGA_LUAD_MultiMagFlat_Survival.py │ ├── Attention_MIL_TCGA_LUAD_MultiMagFlat_Survival_HighLow.py │ ├── Attention_MIL_TCGA_LUAD_MultiMagFlat_Survival_backup.py │ ├── Attention_MIL_TCGA_LUAD_MultiMagFlat_labels.py │ └── Attention_MIL_TCGA_LUAD_labels.py ├── networks │ ├── attention.py │ ├── discriminator.py │ ├── encoder_contrastive.py │ ├── encoder_gan.py │ └── generator.py ├── normalization.py ├── nuance.py ├── ops.py ├── optimizer.py ├── regularizers.py ├── score │ ├── crimage_score.py │ ├── frechet_inception_distance.py │ ├── inception_score.py │ ├── k_nearest_neighbor.py │ ├── kernel_inception_distance.py │ ├── mmd.py │ ├── mode_score.py │ ├── score.py │ └── utils.py ├── selfsupervised │ ├── BYOL.py │ ├── BarlowTwins.py │ ├── DINO.py │ ├── RealReas.py │ ├── SimCLR.py │ ├── SimSiam.py │ └── SwAV.py ├── tools.py ├── utils.py ├── visualization │ ├── attention_maps.py │ ├── clusters.py │ ├── forest_plots.py │ ├── latent_space.py │ ├── survival.py │ ├── utils.py │ └── weight_hist.py └── wandb_utils.py ├── report_representationsleiden_cox.py ├── report_representationsleiden_cox_individual.py ├── report_representationsleiden_lr.py ├── report_representationsleiden_samples.py ├── requirements.txt ├── run_representationsleiden.py ├── run_representationsleiden_assignment.py ├── run_representationsleiden_evalutation.py ├── run_representationspathology.py ├── run_representationspathology_projection.py ├── run_representationspathology_projection_dataset.py └── utilities ├── files ├── BLCA │ ├── BLCA_clinical.tsv │ ├── overall_survival_TCGA_folds.csv │ ├── overall_survival_TCGA_folds.jpg │ └── overall_survival_TCGA_folds.pkl ├── BRCA │ ├── BRCA_clinical.tsv │ ├── overall_survival_TCGA_folds.csv │ ├── overall_survival_TCGA_folds.jpg │ └── overall_survival_TCGA_folds.pkl ├── CESC │ ├── CESC_clinical.tsv │ ├── overall_survival_TCGA_folds.csv │ ├── overall_survival_TCGA_folds.jpg │ └── overall_survival_TCGA_folds.pkl ├── COAD │ ├── COAD_clinical.tsv │ ├── overall_survival_TCGA_folds.csv │ ├── overall_survival_TCGA_folds.jpg │ └── overall_survival_TCGA_folds.pkl ├── LUAD │ ├── HPC_annotations │ │ ├── LUAD_HPC_annotations.csv │ │ └── backtrack │ │ │ ├── set_0_train.csv │ │ │ ├── set_10_train.csv │ │ │ ├── set_11_train.csv │ │ │ ├── set_12_train.csv │ │ │ ├── set_13_train.csv │ │ │ ├── set_14_train.csv │ │ │ ├── set_15_train.csv │ │ │ ├── set_16_train.csv │ │ │ ├── set_17_train.csv │ │ │ ├── set_18_train.csv │ │ │ ├── set_19_train.csv │ │ │ ├── set_1_train.csv │ │ │ ├── set_20_train.csv │ │ │ ├── set_21_train.csv │ │ │ ├── set_22_train.csv │ │ │ ├── set_23_train.csv │ │ │ ├── set_24_train.csv │ │ │ ├── set_25_train.csv │ │ │ ├── set_26_train.csv │ │ │ ├── set_27_train.csv │ │ │ ├── set_28_train.csv │ │ │ ├── set_29_train.csv │ │ │ ├── set_2_train.csv │ │ │ ├── set_30_train.csv │ │ │ ├── set_31_train.csv │ │ │ ├── set_32_train.csv │ │ │ ├── set_33_train.csv │ │ │ ├── set_34_train.csv │ │ │ ├── set_35_train.csv │ │ │ ├── set_36_train.csv │ │ │ ├── set_37_train.csv │ │ │ ├── set_38_train.csv │ │ │ ├── set_39_train.csv │ │ │ ├── set_3_train.csv │ │ │ ├── set_40_train.csv │ │ │ ├── set_41_train.csv │ │ │ ├── set_42_train.csv │ │ │ ├── set_43_train.csv │ │ │ ├── set_44_train.csv │ │ │ ├── set_45_train.csv │ │ │ ├── set_4_train.csv │ │ │ ├── set_5_train.csv │ │ │ ├── set_6_train.csv │ │ │ ├── set_7_train.csv │ │ │ ├── set_8_train.csv │ │ │ └── set_9_train.csv │ ├── overall_survival_TCGA_folds.csv │ ├── overall_survival_TCGA_folds.jpg │ ├── overall_survival_TCGA_folds.pkl │ ├── overall_survival_TCGA_folds_SOTA.csv │ └── overall_survival_TCGA_folds_SOTA.pkl ├── LUADLUSC │ ├── LUADLUSC_lungsubtype_overall_survival.csv │ ├── gdc_manifest.txt │ └── lungsubtype_Institutions.pkl ├── LUSC │ ├── overall_survival_TCGA_folds.csv │ ├── overall_survival_TCGA_folds.jpg │ └── overall_survival_TCGA_folds.pkl ├── Multi-Cancer │ └── tcga_v07_10panCancer.pkl ├── PRAD │ ├── PRAD_clinical.tsv │ ├── overall_survival_TCGA_folds.csv │ ├── overall_survival_TCGA_folds.jpg │ └── overall_survival_TCGA_folds.pkl ├── SKCM │ ├── SKCM_clinical.tsv │ ├── overall_survival_TCGA_folds.csv │ ├── overall_survival_TCGA_folds.jpg │ └── overall_survival_TCGA_folds.pkl ├── STAD │ ├── STAD_clinical.tsv │ ├── overall_survival_TCGA_folds.csv │ ├── overall_survival_TCGA_folds.jpg │ └── overall_survival_TCGA_folds.pkl ├── TCGA │ ├── TCGA_Institutions.csv │ ├── TCGA_immune_landscape.csv │ └── raw │ │ └── clinical.tsv ├── UCEC │ ├── UCEC_clinical.tsv │ ├── overall_survival_TCGA_folds.csv │ ├── overall_survival_TCGA_folds.jpg │ └── overall_survival_TCGA_folds.pkl └── indexes_to_remove │ └── TCGAFFPE_LUADLUSC_5x_60pc │ ├── complete.pkl │ ├── test.pkl │ ├── train.pkl │ └── valid.pkl ├── fold_creation ├── class_folds.ipynb └── survival_folds.ipynb ├── h5_handling ├── combine_complete_h5.py ├── create_metadata_h5.py ├── h5_float_to_int.py ├── include_samples_slides_participants_h5.py ├── split_h5_by_pattern.py └── subsample_h5.py ├── hovernet_annotations └── create_hovernet_master.py ├── tile_cleaning ├── remove_indexes_h5.py ├── review_cluster_create_pickles.ipynb └── review_cluster_external_dataset.ipynb └── visualizations ├── SHAP_interpretation.ipynb ├── cluster_WSI_overlay.ipynb ├── cluster_correlations.ipynb ├── visualizations_LUAD.ipynb ├── visualizations_LUADLUSC.ipynb ├── visualizations_multicancer.ipynb └── visualizations_survival.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.icloud 2 | -------------------------------------------------------------------------------- /data_manipulation/__init__.py: -------------------------------------------------------------------------------- 1 | debug = False -------------------------------------------------------------------------------- /data_manipulation/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data_manipulation.dataset import Dataset 3 | 4 | 5 | class Data: 6 | def __init__(self, dataset, marker, patch_h, patch_w, n_channels, batch_size, project_path=os.getcwd(), thresholds=(), labels=None, empty=False, num_clusters=500, clust_percent=1.0, load=True): 7 | 8 | # Directories and file name handling. 9 | self.dataset = dataset 10 | self.marker = marker 11 | self.dataset_name = '%s_%s' % (self.dataset, self.marker) 12 | relative_dataset_path = os.path.join(self.dataset, self.marker) 13 | relative_dataset_path = os.path.join('datasets', relative_dataset_path) 14 | relative_dataset_path = os.path.join(project_path, relative_dataset_path) 15 | self.pathes_path = os.path.join(relative_dataset_path, 'patches_h%s_w%s' % (patch_h, patch_w)) 16 | 17 | self.patch_h = patch_h 18 | self.patch_w = patch_w 19 | self.n_channels = n_channels 20 | self.batch_size = batch_size 21 | 22 | # Train dataset 23 | self.hdf5_train = os.path.join(self.pathes_path, 'hdf5_%s_train.h5' % self.dataset_name) 24 | print('Train Set:', self.hdf5_train) 25 | self.training = None 26 | if os.path.isfile(self.hdf5_train) and load: 27 | self.training = Dataset(self.hdf5_train, patch_h, patch_w, n_channels, batch_size=batch_size, thresholds=thresholds, labels=labels, empty=empty, num_clusters=num_clusters, clust_percent=clust_percent) 28 | 29 | # Validation dataset, some datasets work with those. 30 | self.hdf5_validation = os.path.join(self.pathes_path, 'hdf5_%s_validation.h5' % self.dataset_name) 31 | print('Validation Set:', self.hdf5_validation) 32 | self.validation = None 33 | if os.path.isfile(self.hdf5_validation) and load: 34 | self.validation = Dataset(self.hdf5_validation, patch_h, patch_w, n_channels, batch_size=batch_size, thresholds=thresholds, labels=None, empty=empty) 35 | 36 | # Test dataset 37 | self.hdf5_test = os.path.join(self.pathes_path, 'hdf5_%s_test.h5' % self.dataset_name) 38 | print('Test Set:', self.hdf5_test) 39 | self.test = None 40 | if os.path.isfile(self.hdf5_test) and load: 41 | self.test = Dataset(self.hdf5_test, patch_h, patch_w, n_channels, batch_size=batch_size, thresholds=thresholds, labels=None, empty=empty) 42 | print() 43 | -------------------------------------------------------------------------------- /data_manipulation/dataset.py: -------------------------------------------------------------------------------- 1 | from data_manipulation.utils import inception_feature_labels 2 | import numpy as np 3 | import random 4 | import h5py 5 | 6 | 7 | class Dataset: 8 | def __init__(self, hdf5_path, patch_h, patch_w, n_channels, batch_size, thresholds=(), labels=None, empty=False, num_clusters=500, clust_percent=1.0): 9 | 10 | self.i = 0 11 | self.batch_size = batch_size 12 | self.done = False 13 | self.thresholds = thresholds 14 | self.patch_h = patch_h 15 | self.patch_w = patch_w 16 | self.n_channels = n_channels 17 | 18 | # Options for conditional PathologyGAN 19 | self.num_clusters = num_clusters 20 | self.clust_percent = clust_percent 21 | 22 | self.labels_name = labels 23 | if labels is None: 24 | self.labels_flag = False 25 | else: 26 | self.labels_flag = True 27 | 28 | self.hdf5_path = hdf5_path 29 | # Get images and labels 30 | self.images = list() 31 | self.labels = list() 32 | if not empty: 33 | self.images, self.labels, self.embedding = self.get_hdf5_data() 34 | self.size = len(self.images) 35 | self.iterations = len(self.images)//self.batch_size + 1 36 | 37 | def __iter__(self): 38 | return self 39 | 40 | def __next__(self): 41 | return self.next_batch(self.batch_size) 42 | 43 | @property 44 | def shape(self): 45 | return [len(self.images), self.patch_h, self.patch_w, self.n_channels] 46 | 47 | def get_hdf5_data(self): 48 | hdf5_file = h5py.File(self.hdf5_path, 'r') 49 | 50 | # Legacy code for initial naming of images, label keys. 51 | labels_name = self.labels_name 52 | naming = list(hdf5_file.keys()) 53 | if 'images' in naming: 54 | image_name = 'images' 55 | if labels_name is None: 56 | labels_name = 'labels' 57 | else: 58 | for naming in list(hdf5_file.keys()): 59 | if 'img' in naming or 'image' in naming: 60 | image_name = naming 61 | elif 'labels' in naming and self.labels_name is None: 62 | labels_name = naming 63 | 64 | # Get images, labels, and embeddings if neccesary. 65 | images = hdf5_file[image_name] 66 | embedding = None 67 | labels = np.zeros((images.shape[0])) 68 | if self.labels_flag: 69 | if self.labels_name == 'inception' or self.labels_name == 'self': 70 | labels, embedding = inception_feature_labels(self.hdf5_path, image_name, self.patch_h, self.patch_w, self.n_channels, self.num_clusters, self.clust_percent, set_type=self.labels_name) 71 | labels, embedding = inception_feature_labels(self.hdf5_path, image_name, self.patch_h, self.patch_w, self.n_channels, self.num_clusters, self.clust_percent, set_type=self.labels_name) 72 | else: 73 | labels = hdf5_file[labels_name] 74 | return images, labels, embedding 75 | 76 | def set_pos(self, i): 77 | self.i = i 78 | 79 | def get_pos(self): 80 | return self.i 81 | 82 | def reset(self): 83 | self.set_pos(0) 84 | 85 | def set_batch_size(self, batch_size): 86 | self.batch_size = batch_size 87 | 88 | def set_thresholds(self, thresholds): 89 | self.thresholds = thresholds 90 | 91 | def adapt_label(self, label): 92 | thresholds = self.thresholds + (None,) 93 | adapted = [0.0 for _ in range(len(thresholds))] 94 | i = None 95 | for i, threshold in enumerate(thresholds): 96 | if threshold is None or label < threshold: 97 | break 98 | adapted[i] = label if len(adapted) == 1 else 1.0 99 | return adapted 100 | 101 | def next_batch(self, n): 102 | if self.done: 103 | self.done = False 104 | raise StopIteration 105 | batch_img = self.images[self.i:self.i + n] 106 | batch_labels = self.labels[self.i:self.i + n] 107 | self.i += len(batch_img) 108 | delta = n - len(batch_img) 109 | if delta == n: 110 | raise StopIteration 111 | if 0 < delta: 112 | batch_img = np.concatenate((batch_img, self.images[:delta]), axis=0) 113 | batch_labels = np.concatenate((batch_labels, self.labels[:delta]), axis=0) 114 | self.i = delta 115 | self.done = True 116 | return batch_img/255.0, batch_labels 117 | -------------------------------------------------------------------------------- /demos/HPL_visualizer_mutlticancer.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/demos/HPL_visualizer_mutlticancer.gif -------------------------------------------------------------------------------- /demos/framework_methodology.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/demos/framework_methodology.jpg -------------------------------------------------------------------------------- /demos/slides/HPL Summary.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/demos/slides/HPL Summary.pdf -------------------------------------------------------------------------------- /models/activations.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | # from models.generative.utils import power_iteration_method 3 | 4 | def leakyReLU(x, alpha=0.2): 5 | return tf.maximum(alpha*x, x) 6 | 7 | def ReLU(x): 8 | return tf.nn.relu(x) 9 | 10 | def tanh(x): 11 | return tf.nn.tanh(x) 12 | 13 | def sigmoid(x): 14 | return tf.sigmoid(x) -------------------------------------------------------------------------------- /models/clustering/leiden_representations_fold.py: -------------------------------------------------------------------------------- 1 | # Imports. 2 | import scanpy as sc 3 | import anndata 4 | import argparse 5 | import copy 6 | import csv 7 | 8 | # Add project path 9 | main_path = os.path.dirname(os.path.realpath(__file__)) 10 | main_path = '/'.join(main_path.split('/')[:-2]) 11 | sys.path.append(main_path) 12 | 13 | # Data science packages. 14 | import seaborn as sns 15 | import pandas as pd 16 | import numpy as np 17 | 18 | # Other libraries. 19 | import random 20 | import h5py 21 | import os 22 | import gc 23 | 24 | # Own libs. 25 | from models.evaluation.folds import load_existing_split 26 | 27 | 28 | def adata_to_csv(adata, main_cluster_path, adata_name): 29 | current_df = pd.DataFrame() 30 | for column in adata.obs.columns: 31 | current_df[column] = adata.obs[column].values 32 | current_df.to_csv(os.path.join(main_cluster_path, '%s.csv' % adata_name), index=False) 33 | 34 | 35 | def representations_to_frame(h5_path, rep_key='z_latent'): 36 | if h5_path is not None: 37 | with h5py.File(h5_path, 'r') as content: 38 | for key in content.keys(): 39 | if rep_key in key: 40 | representations = content[key][:] 41 | dim_columns = list(range(representations.shape[1])) 42 | frame = pd.DataFrame(representations, columns=dim_columns) 43 | break 44 | 45 | rest_columns = list() 46 | for key in content.keys(): 47 | if 'latent' in key: 48 | continue 49 | frame[key] = content[key][:].astype(str) 50 | rest_columns.append(key) 51 | else: 52 | frame, dim_columns, rest_columns = None, None, None 53 | 54 | return frame, dim_columns, rest_columns 55 | 56 | 57 | def run_clustering(frame, dim_columns, rest_columns, resolution, groupby, n_neighbors, main_cluster_path, adata_name, subsample=None, save_adata=False, tabs='\t\t'): 58 | 59 | if subsample is not None: 60 | print('%sSubsampling DataFrame to %s samples' % (tabs, subsample)) 61 | frame_sub = frame.sample(n=subsample, random_state=1) 62 | else: 63 | frame_sub = frame.copy(deep=True) 64 | 65 | print('%s%s File' % (tabs, adata_name)) 66 | adata = anndata.AnnData(X=frame_sub[dim_columns].to_numpy(), obs=frame_sub[rest_columns]) 67 | # Nearest Neighbors 68 | print('%sPCA' % tabs) 69 | sc.tl.pca(adata, svd_solver='arpack', n_comps=adata.X.shape[1]-1) 70 | print('%sNearest Neighbors' % tabs) 71 | sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=adata.X.shape[1]-1, method='umap', metric='euclidean', key_added='nn_leiden') 72 | print('%sLeiden' % tabs, resolution) 73 | sc.tl.leiden(adata, resolution=resolution, key_added=groupby, neighbors_key='nn_leiden') 74 | adata_to_csv(adata, main_cluster_path, adata_name) 75 | if save_adata: 76 | adata.write(os.path.join(main_cluster_path, adata_name) + '.h5ad', compression='gzip') 77 | print() 78 | return adata 79 | 80 | def assign_clusters(frame, dim_columns, rest_columns, groupby, adata, main_cluster_path, adata_name, save_adata=False, tabs='\t\t'): 81 | print('%s%s File' % (tabs, adata_name)) 82 | adata_test = anndata.AnnData(X=frame[dim_columns].to_numpy(), obs=frame[rest_columns]) 83 | print('%sNearest Neighbors on data' % tabs) 84 | sc.tl.ingest(adata_test, adata, obs=groupby, embedding_method='pca', neighbors_key='nn_leiden') 85 | adata_to_csv(adata_test, main_cluster_path, adata_name) 86 | if save_adata: 87 | adata_test.write(os.path.join(main_cluster_path, adata_name + '.h5ad'), compression='gzip') 88 | print() 89 | del adata_test 90 | 91 | ########### Code to introduce into run_representationsleiden.py if needed ########### 92 | # After print('\tFold', i) if needed. 93 | # This needs a fix here, AnnData has a memory leak issue: https://github.com/theislab/anndata/issues/360 94 | # script_path = os.path.dirname(os.path.realpath(__file__)) 95 | # command = 'python3 %s/leiden_representations_fold.py --meta_name %s ---matching_field %s --fold %s --resolution %s --n_neighbors %s \ 96 | # --subsample %s --rep_key %s --folds_pickle %s --h5_complete_path %s --h5_additional_path %s ' % (script_path, meta_field, matching_field, i, \ 97 | # resolution, n_neighbors, subsample, rep_key, folds_pickle, h5_complete_path, h5_additional_path) 98 | # if save_adata: 99 | # command += '--save_adata' 100 | # os.system(command) 101 | 102 | ##### Main ####### 103 | parser = argparse.ArgumentParser(description='Script to combine all H5 representation file into a \'complete\' one.') 104 | parser.add_argument('--meta_name', dest='meta_name', type=str, required=True, help='Purpose of the clustering, name of folder.') 105 | parser.add_argument('--matching_field', dest='matching_field', type=str, required=True, help='Key used to match folds split and H5 representation file.') 106 | parser.add_argument('--fold', dest='fold', type=int, required=True, help='Fold number to run flow on.') 107 | parser.add_argument('--resolution', dest='resolution', type=int, required=True, help='Resolution for Leiden algorithm.') 108 | parser.add_argument('--n_neighbors', dest='n_neighbors', type=int, required=True, help='Number of neighbors to use for Leiden.') 109 | parser.add_argument('--subsample', dest='subsample', type=int, required=True, help='Number of samples to use on Leiden (given memory constrains).') 110 | parser.add_argument('--rep_key', dest='rep_key', type=str, required=True, help='Key pattern for representations to grab: z_latent, h_latent.') 111 | parser.add_argument('--folds_pickle', dest='folds_pickle', type=str, required=True, help='Pickle file with folds information.') 112 | parser.add_argument('--h5_complete_path', dest='h5_complete_path', type=str, required=True, help='H5 file path to run the leiden clustering folds.') 113 | parser.add_argument('--h5_additional_path', dest='h5_additional_path', type=str, default=None, help='Additional H5 representation to assign leiden clusters.') 114 | parser.add_argument('--save_adata', dest='save_adata', action='store_true', default=False, help='Save AnnData file for each fold.') 115 | args = parser.parse_args() 116 | meta_name = args.meta_name 117 | matching_field = args.matching_field 118 | fold = args.fold 119 | resolution = args.resolution 120 | n_neighbors = args.n_neighbors 121 | subsample = args.subsample 122 | rep_key = args.rep_key 123 | folds_pickle = args.folds_pickle 124 | main_path = args.main_path 125 | h5_complete_path = args.h5_complete_path 126 | h5_additional_path = args.h5_additional_path 127 | save_adata = args.save_adata 128 | 129 | # Get folds from existing split. 130 | folds = load_existing_split(folds_pickle) 131 | 132 | complete_frame, complete_dims, complete_rest = representations_to_frame(h5_complete_path, rep_key=rep_key) 133 | additional_frame, additional_dims, additional_rest = representations_to_frame(h5_additional_path, rep_key=rep_key) 134 | 135 | # Setup folder esqueme 136 | main_cluster_path = h5_complete_path.split('hdf5_')[0] 137 | main_cluster_path = os.path.join(main_cluster_path, meta_name) 138 | main_cluster_path = os.path.join(main_cluster_path, 'adatas') 139 | 140 | print('Leiden %s' % resolution) 141 | groupby = 'leiden_%s' % resolution 142 | print('\tFold', fold) 143 | 144 | train_samples, valid_samples, test_samples = folds[fold] 145 | 146 | # Train set. 147 | train_frame = complete_frame[complete_frame[matching_field].isin(train_samples)] 148 | adata_name = h5_complete_path.split('/hdf5_')[1].split('.h5')[0] + '_%s__fold%s' % (groupby.replace('.', 'p'), fold) 149 | adata_train = run_clustering(train_frame, complete_dims, complete_rest, resolution, groupby, n_neighbors, main_cluster_path, '%s_train_subsample' % adata_name, subsample=subsample, save_adata=True) 150 | if subsample is not None: 151 | assign_clusters(train_frame, complete_dims, complete_rest, groupby, adata_train, main_cluster_path, '%s_train' % adata_name, save_adata=save_adata) 152 | 153 | # Validation set. 154 | if len(valid_samples) > 0: 155 | valid_frame = complete_frame[complete_frame[matching_field].isin(valid_samples)] 156 | assign_clusters(valid_frame, complete_dims, complete_rest, groupby, adata_train, main_cluster_path, '%s_valid' % adata_name, save_adata=save_adata) 157 | 158 | # Test set. 159 | test_frame = complete_frame[complete_frame[matching_field].isin(test_samples)] 160 | assign_clusters(test_frame, complete_dims, complete_rest, groupby, adata_train, main_cluster_path, '%s_test' % adata_name, save_adata=save_adata) 161 | 162 | if additional_frame is not None: 163 | adata_name = h5_additional_path.split('/hdf5_')[1].split('.h5')[0] + '_%s__fold%s' % (groupby.replace('.', 'p'), fold) 164 | assign_clusters(additional_frame, additional_dims, additional_rest, groupby, adata_train, main_cluster_path, adata_name, save_adata=save_adata) 165 | 166 | del adata_train 167 | -------------------------------------------------------------------------------- /models/clustering_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from sklearn.cluster import KMeans 3 | import numpy as np 4 | import umap 5 | import copy 6 | import h5py 7 | import os 8 | import gc 9 | 10 | # Get discriminator projections for real images. 11 | def get_projections_all_dataset_projections(model, data, session, run_options, data_out_path, num_clusters, batch_size=50): 12 | num_samples = data.training.images.shape[0] 13 | batches = int(num_samples/batch_size) 14 | 15 | hdf5_features = os.path.join(data_out_path, 'checkpoints/selconditioned_labels.h5') 16 | if os.path.isfile(hdf5_features): 17 | os.remove(hdf5_features) 18 | 19 | with h5py.File(hdf5_features, mode='w') as hdf5_features_file: 20 | features_storage = hdf5_features_file.create_dataset(name='features', shape=(num_samples, model.features_fake.shape[1]), dtype=np.float32) 21 | 22 | print('Projecting images...') 23 | ind = 0 24 | for batch_num in range(batches): 25 | batch_images = data.training.images[batch_num*batch_size:(batch_num+1)*batch_size] 26 | if np.amax(batch_images) > 1.0: batch_images = batch_images/255. 27 | feed_dict = {model.real_images:batch_images} 28 | batch_projections = session.run([model.features_real], feed_dict=feed_dict, options=run_options)[0] 29 | features_storage[batch_num*batch_size:(batch_num+1)*batch_size] = batch_projections 30 | ind += batch_size 31 | if ind%10000==0: print('Processed', ind, 'images') 32 | print('Processed', ind, 'images') 33 | 34 | print('Running UMAP...') 35 | umap_reducer = umap.UMAP(n_neighbors=30, min_dist=0.0, n_components=2, random_state=42, low_memory=True) 36 | umap_fitted = umap_reducer.fit(features_storage[model.selected_indx, :]) 37 | embedding_umap_clustering = umap_fitted.transform(features_storage) 38 | 39 | # K-Means 40 | print('Running K_means...') 41 | kmeans = KMeans(init='k-means++', n_clusters=num_clusters, n_init=10).fit(embedding_umap_clustering) 42 | new_classes = kmeans.predict(umap_reducer.transform(features_storage[model.selected_indx, :])) 43 | 44 | if np.unique(model.feature_labels).shape[0] > 1: 45 | print('Hungarian matching...') 46 | match = hungarian_matching(model=model, new_classes=new_classes, current_classes=model.feature_labels, num_clusters=model.k) 47 | model.mapping_ = [int(j) for i, j in sorted(match)] 48 | 49 | # Set labels. 50 | print('Mapping...') 51 | model.feature_labels = np.array([model.mapping_[x] for x in new_classes]) 52 | 53 | feature_labels_storage = hdf5_features_file.create_dataset(name='feat_cluster_labels', shape=[num_samples] + [1], dtype=np.float32) 54 | embedding_storage = hdf5_features_file.create_dataset(name='embedding', shape=[num_samples] + [2], dtype=np.float32) 55 | 56 | print('Finding clusters for embeddings...') 57 | 58 | # Save storage for cluster labels. 59 | for i in range(num_samples): 60 | i_class = kmeans.predict(embedding_umap_clustering[i,:].reshape((1,-1)))[0] 61 | feature_labels_storage[i, 0] = model.mapping_[i_class] 62 | embedding_storage[i, :] = embedding_umap_clustering[i, :] 63 | if i%10000==0: print('Processed', i, 'cluster classes') 64 | 65 | # Get discriminator projections for real images. 66 | def get_projections(model, data, session, run_options): 67 | feature_projection = np.zeros((len(model.selected_indx), model.feature_space)) 68 | indx = 0 69 | while indx < len(model.selected_indx): 70 | # Real images. 71 | if (indx + model.batch_size) < len(model.selected_indx): 72 | current_ind = model.selected_indx[indx:indx+model.batch_size] 73 | else: 74 | current_ind = model.selected_indx[indx:] 75 | batch_images = data.training.images[current_ind, :, :, :]/255. 76 | feed_dict = {model.real_images:batch_images} 77 | batch_projections = session.run([model.features_real], feed_dict=feed_dict, options=run_options)[0] 78 | feature_projection[indx:indx+model.batch_size, :] = batch_projections 79 | indx += model.batch_size 80 | return feature_projection 81 | 82 | # Get centroids for each cluster. 83 | def get_initialization_centroids(model, embedding): 84 | means = np.zeros((model.k, embedding.shape[1])) 85 | for i in range(model.k): 86 | mask = (model.feature_labels == i) 87 | i_mean = np.zeros(embedding.shape[1]) 88 | numels = mask.astype(int).sum() 89 | if numels > 0: 90 | for index, flag in enumerate(mask): 91 | if flag: i_mean += embedding[index, :] 92 | means[i, :] = i_mean/numels 93 | else: 94 | means[i, :] = random.randint(0, embedding.shape[1] - 1) 95 | return means 96 | 97 | # Hungarian matching for classes. 98 | def hungarian_matching(model, new_classes, current_classes, num_clusters): 99 | # from sklearn.utils.linear_assignment_ import linear_assignment 100 | from scipy.optimize import linear_sum_assignment as linear_assignment 101 | num_samples = new_classes.shape[0] 102 | num_correct = np.zeros((num_clusters, num_clusters)) 103 | 104 | for i in range(num_clusters): 105 | for j in range(num_clusters): 106 | coin = int(((new_classes==i)*(current_classes==j)).sum()) 107 | num_correct[i, j] = coin 108 | 109 | match = linear_assignment(num_samples-num_correct) 110 | 111 | res = list() 112 | for out_c, gt_c in np.transpose(np.asarray(match)): 113 | res.append((out_c, gt_c)) 114 | 115 | return res 116 | 117 | # Recluster based on discriminator projections. 118 | def recluster(model, data, session, run_options): 119 | # Get centroids. 120 | print('\tGetting feature projections...') 121 | feature_projection = get_projections(model=model, data=data, session=session, run_options=run_options) 122 | 123 | # Run K-Means. 124 | print('\tFitting UMAP...') 125 | umap_fitted = umap.UMAP(n_components=2, random_state=42, low_memory=True).fit(feature_projection) 126 | print('\tTransforming UMAP...') 127 | embedding = umap_fitted.transform(feature_projection) 128 | 129 | # Get projections. 130 | print('\tGetting previous centroids...') 131 | if np.unique(model.feature_labels).shape[0] > 1: 132 | embedding = umap_fitted.transform(feature_projection) 133 | initialization = get_initialization_centroids(model=model, embedding=embedding) 134 | else: 135 | initialization = 'k-means++' 136 | 137 | # Run K-Means. 138 | print('\tRunning K-Means...') 139 | kmeans = KMeans(init=initialization, n_clusters=model.k, n_init=10).fit(embedding) 140 | new_classes = kmeans.predict(embedding) 141 | 142 | # Compute permutation, Hungarian match. 143 | if np.unique(model.feature_labels).shape[0] > 1: 144 | print('\tHungarian matching...') 145 | match = hungarian_matching(model=model, new_classes=new_classes, current_classes=model.feature_labels, num_clusters=model.k) 146 | model.mapping_ = [int(j) for i, j in sorted(match)] 147 | 148 | # Set labels. 149 | print('\tMapping...') 150 | model.feature_labels = np.array([model.mapping_[x] for x in new_classes]) 151 | 152 | clust_labels, counts = np.unique(model.feature_labels, return_counts=True) 153 | model.categorical = counts/np.sum(counts) 154 | model.reclusters_iter += 1 155 | 156 | return umap_fitted, kmeans 157 | 158 | 159 | # Get cluster labels for images. 160 | def get_labels_cluster(model, images_batch, session, run_options, umap_fitted, kmeans): 161 | feed_dict = {model.real_images:images_batch} 162 | batch_projections = session.run([model.features_real], feed_dict=feed_dict, options=run_options)[0] 163 | embedding_batch = umap_fitted.transform(batch_projections) 164 | batch_classes = kmeans.predict(embedding_batch) 165 | # Takes into account the hungarian matching. 166 | permuted_prediction = np.array([model.mapping_[x] for x in batch_classes]) 167 | 168 | return permuted_prediction 169 | 170 | ## Self-supervised Clustering SwAV. 171 | # Sinkhorn Knopp for Cluster Assignment 172 | # SwAV Paper: https://arxiv.org/abs/2006.09882 173 | # Q+ = Diag(u)*exp(C.t*Z/eps)*Diag(v) 174 | # 175 | # solution for max{ QC.tZ.T } + eps H(Q) 176 | # Q+ e Q 177 | # 178 | # u and v are renormalization vector in Re^K and Re^B respectevely. 179 | def sinkhorn(sample_prototype_batch, batch_size, epsilon=0.05, n_iters=3): 180 | 181 | # Clarify this Q 182 | # sample_prototype_batch (batch_size, prototype_dim) 183 | Q = tf.transpose(tf.exp(sample_prototype_batch/epsilon)) 184 | # Q (batch_size, prototype_dim) 185 | n = tf.reduce_sum(Q) 186 | Q = Q/n 187 | K,B = Q.shape.as_list() 188 | B = batch_size 189 | 190 | u = tf.zeros_like(K, dtype=tf.float32) 191 | r = tf.ones_like(K, dtype=tf.float32)/float(K) 192 | c = tf.ones_like(K, dtype=tf.float32)/float(B) 193 | 194 | for _ in range(n_iters): 195 | u = tf.reduce_sum(Q, axis=1) 196 | Q *= tf.expand_dims((r/u), axis=1) 197 | Q *= tf.expand_dims(c/tf.reduce_sum(Q, axis=0), 0) 198 | 199 | final_quantity = Q/tf.reduce_sum(Q, axis=0, keepdims=True) 200 | final_quantity = tf.transpose(final_quantity) 201 | 202 | return final_quantity 203 | 204 | def sinkhorn_np(sample_prototype_batch, epsilon=0.05, n_iters=3): 205 | 206 | # Clarify this Q 207 | # Q (batch_size, prototype_dim) 208 | # sample_prototype_batch (batch_size, prototype_dim) 209 | Q = np.transpose(np.exp(sample_prototype_batch/epsilon)) 210 | Q /= np.sum(Q) 211 | 212 | K,B = Q.shape 213 | 214 | u = np.zeros(K, dtype=np.float32) 215 | r = np.ones(K, dtype=np.float32)/float(K) 216 | c = np.ones(B, dtype=np.float32)/float(B) 217 | 218 | for _ in range(n_iters): 219 | u = np.sum(Q, axis=1) 220 | Q *= np.expand_dims((r/u), axis=1) 221 | Q *= np.expand_dims(c/np.sum(Q,axis=0), 0) 222 | 223 | final_quantity = Q/np.sum(Q, axis=0, keepdims=True) 224 | final_quantity = np.transpose(final_quantity) 225 | 226 | return final_quantity 227 | -------------------------------------------------------------------------------- /models/evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from scipy import sparse 4 | import scipy.sparse.linalg as sp_linalg 5 | 6 | 7 | # Using tf.gradients, but it's too costly to run. Using numerical_jacobian instead. 8 | def get_gen_jacobian(session, model, z_batch): 9 | # Currently tf.gradients(ys, xs): does sum(dy/dx) for all y in ys 10 | 11 | batch, height, width, channels = model.fake_images.shape.as_list() 12 | jacobian_matrix = np.zeros((height, width, channels, model.z_dim), dtype=np.float32) 13 | 14 | print('Starting Jacobian Calculation') 15 | for h in range(height): 16 | for w in range(width): 17 | for c in range(channels): 18 | # We iterate over the M elements of the output vector 19 | print('tf.gradients') 20 | grad_func = tf.gradients(ys=model.fake_images[:, h, w, c], xs=model.z_input) 21 | gradients = session.run(grad_func, feed_dict={model.z_input: z_batch}) 22 | gradients_avg = np.reshape(np.mean(gradients[0], axis=0), (1,-1)) 23 | jacobian_matrix[h, w, c, :] = gradients_avg 24 | 25 | print('Done Jacobian') 26 | 27 | jacobian_matrix = np.reshape(jacobian_matrix, (-1, model.z_dim)) 28 | return jacobian_matrix 29 | 30 | 31 | def numerical_jacobian(session, model, z_batch, epsilon = 1e-3): 32 | batch_size, height, width, channels = model.fake_images.shape.as_list() 33 | batch_size, z_dim = z_batch.shape 34 | 35 | numerical_jacobian = np.zeros((z_dim, height, width, channels), dtype=np.float32) 36 | im = session.run(model.fake_images, feed_dict={model.z_input: z_batch}) 37 | 38 | for zi in range(z_dim): 39 | zi_batch_sample = np.array(z_batch, copy=True) 40 | zi_batch_sample[:, zi] += epsilon 41 | im_sample = session.run(model.fake_images, feed_dict={model.z_input: zi_batch_sample}) 42 | ep = (im_sample - im)/epsilon 43 | numerical_jacobian[zi, :, :, :] = np.mean(ep, axis=0) 44 | 45 | numerical_jacobian = np.reshape(numerical_jacobian, (model.z_dim, -1)) 46 | # numerical_jacobian = numerical_jacobian.T 47 | 48 | return numerical_jacobian 49 | 50 | 51 | def jacobian_singular_values(session, model, z_batch): 52 | jacobian_matrix = numerical_jacobian(session, model, z_batch) 53 | jac_max_sign = matrix_singular_values(matrix=jacobian_matrix, n_sing=1, mode='LM') 54 | jac_min_sign = matrix_singular_values(matrix=jacobian_matrix, n_sing=1, mode='SM') 55 | return [jac_max_sign, jac_min_sign] 56 | 57 | 58 | def l2_normalize(vec, epsilon=1e-12): 59 | suma = np.sum(vec**2) 60 | norm = np.sqrt(suma+ epsilon) 61 | return vec/norm 62 | 63 | 64 | def power_iteration_method(matrix, power_iterations=10): 65 | filter_shape = matrix.shape 66 | filter_reshape = np.reshape(matrix, [-1, filter_shape[-1]]) 67 | 68 | u_shape = (1, filter_shape[-1]) 69 | # If I put trainable = False, I don't need to use tf.stop_gradient() 70 | u = np.random.normal(size=u_shape) 71 | 72 | u_norm = u 73 | v_norm = None 74 | 75 | for i in range(power_iterations): 76 | v_iter = np.matmul(u_norm, filter_reshape.T) 77 | v_norm = l2_normalize(vec=v_iter, epsilon=1e-12) 78 | u_iter = np.matmul(v_norm, filter_reshape) 79 | u_norm = l2_normalize(vec=u_iter, epsilon=1e-12) 80 | 81 | singular_w = np.matmul(np.matmul(v_norm, filter_reshape), u_norm.T)[0,0] 82 | 83 | return singular_w 84 | 85 | # Singular value desposition with Alrnoldi Iteration Method. 86 | def matrix_singular_values(matrix, n_sing, mode='LM'): 87 | 88 | filter_shape = matrix.shape 89 | matrix_reshape = np.reshape(matrix, [-1, filter_shape[-1]]) 90 | 91 | #Semi-positive definite matrix A*A.T, A.T*A 92 | dim1, dim2 = matrix_reshape.shape 93 | if dim1 > dim2: 94 | aa_t = np.matmul(matrix_reshape.T, matrix_reshape) 95 | else: 96 | aa_t = np.matmul(matrix_reshape, matrix_reshape.T) 97 | 98 | 99 | # RuntimeWarning 100 | # Trows warning to use eig instead if the say too is small.1 101 | # Eigs is an approximation, Eig calculated all eigenvalues, and eigenvectors of the matrix. 102 | 103 | try: 104 | eigenvalues, eigenvectors = sp_linalg.eigs(A=aa_t, k=n_sing, which=mode) 105 | except RuntimeWarning: 106 | pass 107 | except RuntimeError: 108 | eigenvalues = None 109 | pass 110 | 111 | if eigenvalues is None: 112 | return None 113 | 114 | if n_sing > 1: 115 | eigenvalues = np.sort(eigenvalues) 116 | if 'LM' in mode: 117 | eigenvalues = eigenvalues[::-1] 118 | 119 | sing_matrix = np.sqrt(eigenvalues) 120 | 121 | return sing_matrix 122 | 123 | 124 | def filter_singular_values(model, n_sing): 125 | gen_filters = model.gen_filters 126 | dis_filters = model.dis_filters 127 | 128 | gen_singular = dict() 129 | dis_singular = dict() 130 | for filter in gen_filters: 131 | f_name = str(filter.name.split(':')[0].replace('/', '_')) 132 | gen_singular[f_name] = matrix_singular_values(matrix=filter.eval(), n_sing=n_sing) 133 | for filter in dis_filters: 134 | f_name = str(filter.name.split(':')[0].replace('/', '_')) 135 | dis_singular[f_name] = matrix_singular_values(matrix=filter.eval(), n_sing=n_sing) 136 | 137 | return gen_singular, dis_singular 138 | 139 | 140 | def numerical_hessian(session, model, z_batch, epsilon=1e-3): 141 | batch_size, height, width, channels = model.fake_images.shape.as_list() 142 | batch_size, z_dim = z_batch.shape 143 | 144 | numerical_jacobian = np.zeros((batch_size, z_dim, height, width, channels), dtype=np.float32) 145 | 146 | for sample in range(batch_size): 147 | z_batch_sample = np.zeros((z_dim, z_dim), dtype=np.float32) 148 | z_batch_sample_ep = np.zeros((z_dim, z_dim), dtype=np.float32) 149 | for z in range(z_dim): 150 | z_batch_sample[z, :] = np.array(z_batch[sample, :], copy=True) 151 | z_batch_sample_ep[z, :] = np.array(z_batch[sample, :], copy=True) 152 | z_batch_sample_ep[z, z] += epsilon 153 | im_sample = session.run(model.fake_images, feed_dict={model.z_input: z_batch_sample}) 154 | im_sample_ep = session.run(model.fake_images, feed_dict={model.z_input: z_batch_sample_ep}) 155 | numerical_jacobian_sample = (im_sample_ep-im_sample)/epsilon 156 | numerical_jacobian[sample, :, :, :, :] = numerical_jacobian_sample 157 | 158 | numerical_jacobian = np.mean(numerical_jacobian, axis=0) 159 | numerical_jacobian = np.reshape(numerical_jacobian, (model.z_dim, -1)) 160 | return numerical_jacobian -------------------------------------------------------------------------------- /models/evaluation/prognosis.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import numpy as np 3 | 4 | 5 | # Creates an grid with region-of-interest given a range and step size 6 | def create_buckets(x_s, y_s): 7 | x_min, x_max, x_b = x_s 8 | y_min, y_max, y_b = y_s 9 | 10 | x = np.linspace(x_min, x_max, x_b) 11 | x_ranges = list() 12 | for i_x in range(x_b): 13 | x_ranges.append((x[i_x], x[i_x+1])) 14 | if i_x+1 == x_b-1: 15 | break 16 | 17 | y = np.linspace(y_min, y_max, y_b) 18 | y_ranges = list() 19 | for i_y in range(y_b): 20 | y_ranges.append((y[i_y], y[i_y+1])) 21 | if i_y+1 == y_b-1: 22 | break 23 | 24 | zones = list() 25 | for x_range in x_ranges: 26 | for y_range in y_ranges: 27 | zones.append((x_range[0], x_range[1], y_range[0], y_range[1])) 28 | 29 | print('X range:', x_s) 30 | print('Y range:', y_s) 31 | print('Number of Buckets:', len(rois)) 32 | return zones 33 | 34 | # Classifies a tissue patch assigning it to a region-of-interest. 35 | def classify_patch(patch_emb, rois): 36 | x_i, y_i = patch_emb 37 | for roi_i, roi_range in enumerate(rois): 38 | x_min, x_max, y_min, y_max = roi_range 39 | if (x_i < x_max and x_min < x_i) and (y_i < y_max and y_min < y_i): 40 | return roi_i 41 | return len(rois) 42 | 43 | 44 | # Dataset labels, creates a dictionary per patient_id, with keys of 45 | # 'label' and 'patches', list of indeces with patient patches. 46 | def patient_ids_to_data(labels): 47 | ids_to_ind = dict() 48 | for index, fields in enumerate(labels): 49 | patient_id = fields[0] 50 | if patient_id not in ids_to_ind: 51 | ids_to_ind[patient_id] = OrderedDict() 52 | ids_to_ind[patient_id]['patches'] = list() 53 | if fields[1] > 5: 54 | ids_to_ind[patient_id]['label'] = 1 55 | else: 56 | ids_to_ind[patient_id]['label'] = 0 57 | ids_to_ind[patient_id]['patches'].append(index) 58 | return ids_to_ind 59 | 60 | 61 | # Method that prepares data for training, takes in the embedding and patient_ids to indeces dict. 62 | # Gives an array with 0-ID, 1-Label, 2:-Features per patient. 63 | def classify_dataset(embedding, ids_to_ind, rois=None, min_patches=12, norm=True): 64 | print('Minimun number of patches per patient:', min_patches) 65 | dropped = list() 66 | n_patients = 0 67 | for patient in ids_to_ind: 68 | if len(ids_to_ind[patient]['patches']) < min_patches: 69 | dropped.append(patient) 70 | continue 71 | else: 72 | n_patients+=1 73 | 74 | if rois is not None: 75 | print('Using ROIs provided, ID+Label+#ROIs', len(rois)+3) 76 | patient_features = np.zeros((n_patients, len(rois)+3)) 77 | else: 78 | print('No ROIs provided, using encoding as features.') 79 | patient_features = np.zeros((n_patients, embedding.shape[1]+2)) 80 | 81 | print('Dropped patients:', len(dropped), ' IDs:', dropped) 82 | print('Number of Patients:', patient_features.shape[0]) 83 | print('Number of Features:', patient_features.shape[1]) 84 | ind_dic = 0 85 | for patient in ids_to_ind: 86 | patches_patient = ids_to_ind[patient]['patches'] 87 | if len(patches_patient) < min_patches: 88 | continue 89 | patient_features[ind_dic, 0] = patient 90 | patient_features[ind_dic, 1] = ids_to_ind[patient]['label'] 91 | 92 | if rois is not None: 93 | for patch in patches_patient: 94 | roi = classify_patch(embedding[patch], rois) 95 | patient_features[ind_dic, roi+2] += 1 96 | if norm: 97 | total_patient = np.sum(patient_features[ind_dic, 2:]) 98 | patient_features[ind_dic, 2:] = patient_features[ind_dic, 2:]/total_patient 99 | else: 100 | patient_features[ind_dic, 2:] = np.mean(embedding[patches_patient, :]) 101 | ind_dic += 1 102 | return patient_features 103 | 104 | 105 | # Prepares patient features array for training and test, balances both of them so 106 | # there's no more > 5 years. 107 | def prepare_data(patient_roi, ratio_training=0.8, display=True): 108 | l5_patients_ind = np.argwhere(patient_roi[:, 1]==1)[:, 0] 109 | s5_patients_ind = np.argwhere(patient_roi[:, 1]==0)[:, 0] 110 | num_train = int(ratio_training*len(s5_patients_ind)) 111 | num_test = len(s5_patients_ind) - num_train 112 | 113 | np.random.shuffle(l5_patients_ind) 114 | 115 | train_l5 = l5_patients_ind[:num_train] 116 | test_l5 = l5_patients_ind[num_train:num_train+num_test] 117 | train_s5 = s5_patients_ind[:num_train] 118 | test_s5 = s5_patients_ind[num_train:] 119 | 120 | train = np.concatenate([train_s5, train_l5]) 121 | test = np.concatenate([test_s5, test_l5]) 122 | 123 | np.random.shuffle(train) 124 | np.random.shuffle(test) 125 | 126 | all_ = np.concatenate([l5_patients_ind, s5_patients_ind]) 127 | all_features = patient_roi[all_, 2:] 128 | all_labels = patient_roi[all_, 1].astype(np.int) 129 | 130 | train_features = patient_roi[train, 2:] 131 | train_labels = patient_roi[train, 1].astype(np.int) 132 | 133 | test_features = patient_roi[test, 2:] 134 | test_labels = patient_roi[test, 1].astype(np.int) 135 | 136 | if display: 137 | print('Larger 5:') 138 | print('\tTrain:', len(train_l5)) 139 | print('\tTest:', len(test_l5)) 140 | 141 | print('Smaller 5:') 142 | print('\tTrain:', len(train_s5)) 143 | print('\tTest:', len(test_s5)) 144 | 145 | return [train_features, train_labels], [test_features, test_labels], [all_features, all_labels] 146 | 147 | 148 | # Runs Logistic Regression over train and test set. 149 | def logistic_regression(train, test, all_): 150 | from sklearn.linear_model import LogisticRegression 151 | from sklearn import metrics 152 | 153 | logisticRegr = LogisticRegression(solver='liblinear') 154 | logisticRegr.fit(train[0], train[1]) 155 | 156 | predictions = logisticRegr.predict(test[0]) 157 | 158 | train_score = logisticRegr.score(train[0], train[1]) 159 | test_score = logisticRegr.score(test[0], test[1]) 160 | all_score = logisticRegr.score(all_[0], all_[1]) 161 | cm = metrics.confusion_matrix(test[1], predictions) 162 | 163 | print('Train Accuracy', np.round(train_score,2)) 164 | print('Test Accuracy', np.round(test_score,2)) 165 | print('All Accuracy', np.round(all_score,2)) 166 | print('Confusion matrix:') 167 | print(cm) 168 | 169 | 170 | # Pull image from a range of x,y values given the encodings, 171 | # Bucket verifies that the index before to certain list of interest. 172 | def pull_image_index(x_min, x_max, y_min, y_max, encodings, bucket_ind): 173 | image_ind = list() 174 | for index in range(encodings.shape[0]): 175 | x_i, y_i = encodings[index, :] 176 | if (x_i < x_max and x_min < x_i) and (y_i < y_max and y_min < y_i) and index in bucket_ind: 177 | image_ind.append(index) 178 | return image_ind 179 | -------------------------------------------------------------------------------- /models/evaluation/tools.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import h5py 5 | import random 6 | from collections import OrderedDict 7 | from models.score.utils import * 8 | 9 | 10 | def get_top_nearest_neighbors(num_generated, nearneig, real_features_hdf5, real_img_hdf5, gen_features_hdf5, gen_img_hdf5, maximum=False, random_select=False, save_path=None): 11 | 12 | real_features_file = h5py.File(real_features_hdf5, 'r') 13 | gen_features_file = h5py.File(gen_features_hdf5, 'r') 14 | real_img_file = h5py.File(real_img_hdf5, 'r') 15 | gen_img_file = h5py.File(gen_img_hdf5, 'r') 16 | 17 | real_features = real_features_file['features'] 18 | gen_features = gen_features_file['features'] 19 | real_img = real_img_file['images'] 20 | gen_img = gen_img_file['images'] 21 | 22 | with tf.Session() as sess: 23 | real_features = tf.constant(np.array(real_features), dtype=tf.float32) 24 | gen_features = tf.constant(np.array(gen_features), dtype=tf.float32) 25 | 26 | # Get Nearest Neighbors for all generated images. 27 | gen_real_distances = tf.sqrt(tf.abs(euclidean_distance(gen_features, real_features))) 28 | neg = tf.negative(gen_real_distances) 29 | neg_s_distances, s_indices = tf.math.top_k(input=neg, k=nearneig, sorted=True) 30 | s_distances = tf.negative(neg_s_distances) 31 | 32 | 33 | # Getting the top smallest distances between Generated and Real images. 34 | neg_s_distances1, s_indices1 = tf.math.top_k(input=neg, k=1, sorted=True) 35 | neg_s_distances1 = tf.transpose(neg_s_distances1) 36 | if not random_select: 37 | if maximum: 38 | neg_s_distances1 = tf.negative(neg_s_distances1) 39 | neg_s_distances1, s_indices1 = tf.math.top_k(input=neg_s_distances1, k=num_generated, sorted=True) 40 | s_indices1 = tf.transpose(s_indices1) 41 | s_indices1 = s_indices1.eval() 42 | else: 43 | lin = list(range(int(gen_real_distances.shape[0]))) 44 | random.shuffle(lin) 45 | s_indices1 = np.zeros((num_generated,1), dtype=np.int8) 46 | s_indices1[:, 0] = lin[:num_generated] 47 | 48 | s_indices = s_indices.eval() 49 | s_distances = s_distances.eval() 50 | # For the images with top smallest distances, show nearest neighbors. 51 | neighbors = OrderedDict() 52 | for i, ind in enumerate(s_indices1): 53 | ind = ind[0] 54 | neighbors[ind] = list() 55 | for j in range(nearneig): 56 | neighbors[ind].append((s_indices[ind,j], s_distances[ind,j])) 57 | 58 | if save_path is not None: 59 | height, width, channels = real_img.shape[1:] 60 | grid = np.zeros((num_generated*height, (nearneig+1)*width, channels)) 61 | for i, ind in enumerate(s_indices1): 62 | ind = ind[0] 63 | total = gen_img[ind] 64 | for j in range(nearneig): 65 | real = real_img[s_indices[ind,j]]/255. 66 | total = np.concatenate([total, real], axis=1) 67 | grid[i*height:(i+1)*height, :, :] = total 68 | plt.imsave(save_path, grid) 69 | 70 | return neighbors 71 | 72 | 73 | def find_top_nearest_neighbors(generated_list, nearneig, real_features_hdf5, real_img_hdf5, gen_features_hdf5, gen_img_hdf5, maximum=False, save_path=None): 74 | real_features_file = h5py.File(real_features_hdf5, 'r') 75 | gen_features_file = h5py.File(gen_features_hdf5, 'r') 76 | real_img_file = h5py.File(real_img_hdf5, 'r') 77 | gen_img_file = h5py.File(gen_img_hdf5, 'r') 78 | 79 | real_features = real_features_file['features'] 80 | gen_features = gen_features_file['features'] 81 | real_img = real_img_file['images'] 82 | gen_img = gen_img_file['images'] 83 | 84 | with tf.Session() as sess: 85 | real_features = tf.constant(np.array(real_features), dtype=tf.float32) 86 | gen_features = tf.constant(np.array(gen_features), dtype=tf.float32) 87 | 88 | # Get Nearest Neighbors for all generated images. 89 | gen_real_distances = tf.sqrt(tf.abs(euclidean_distance(gen_features, real_features))) 90 | neg = tf.negative(gen_real_distances) 91 | neg_s_distances, s_indices = tf.math.top_k(input=neg, k=nearneig, sorted=True) 92 | s_distances = tf.negative(neg_s_distances) 93 | 94 | s_indices = s_indices.eval() 95 | s_distances = s_distances.eval() 96 | # For the images with top smallest distances, show nearest neighbors. 97 | height, width, channels = real_img.shape[1:] 98 | neighbors = dict() 99 | grid = np.zeros((len(generated_list)*height, (nearneig+1)*width, channels)) 100 | for i, ind in enumerate(generated_list): 101 | total = gen_img[ind] 102 | neighbors[ind] = list() 103 | for j in range(nearneig): 104 | neighbors[ind].append((s_indices[ind,j], s_distances[ind,j])) 105 | real = real_img[s_indices[ind,j]]/255. 106 | total = np.concatenate([total, real], axis=1) 107 | grid[i*height:(i+1)*height, :, :] = total 108 | plt.imshow(grid) 109 | if save_path is not None: 110 | plt.imsave(save_path, grid) 111 | return neighbors 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /models/networks/attention.py: -------------------------------------------------------------------------------- 1 | from models.normalization import * 2 | from models.regularizers import * 3 | from models.activations import * 4 | from models.evaluation import * 5 | from models.optimizer import * 6 | from models.loss import * 7 | from models.ops import * 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | 12 | 13 | display = True 14 | 15 | # Feature Extractor Network 20x. 16 | def feature_extractor_20x(inputs, z_dim, regularizer_scale, use, reuse, scope): 17 | print('Feature Extractor Network 20x:', inputs.shape[-1], 'Dimensions') 18 | interm = inputs 19 | if use: 20 | with tf.variable_scope('feature_extractor_20x_%s' % scope, reuse=reuse): 21 | 22 | interm = tf.reshape(interm, (-1, z_dim)) 23 | net = dense(inputs=interm, out_dim=int(z_dim), scope=1, use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 24 | net = ReLU(net) 25 | net = dense(inputs=net, out_dim=int(z_dim), scope=2, use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 26 | interm = ReLU(net) 27 | interm = tf.reshape(interm, (-1, 16, z_dim)) 28 | print() 29 | return interm 30 | 31 | # Feature Extractor Network 10x. 32 | def feature_extractor_10x(inputs, z_dim, regularizer_scale, use, reuse, scope): 33 | print('Feature Extractor Network 10x:', inputs.shape[-1], 'Dimensions') 34 | interm = inputs 35 | if use: 36 | with tf.variable_scope('feature_extractor_10x_%s' % scope, reuse=reuse): 37 | 38 | interm = tf.reshape(interm, (-1, z_dim)) 39 | net = dense(inputs=interm, out_dim=int(z_dim), scope=1, use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 40 | net = ReLU(net) 41 | net = dense(inputs=net, out_dim=int(z_dim), scope=2, use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 42 | interm = ReLU(net) 43 | interm = tf.reshape(interm, (-1, 4, z_dim)) 44 | print() 45 | return interm 46 | 47 | # Feature Extractor Network 5x 48 | def feature_extractor_5x(inputs, z_dim, regularizer_scale, use, reuse, scope): 49 | print('Feature Extractor Network 5x:', inputs.shape[-1], 'Dimensions') 50 | interm = inputs 51 | if use: 52 | with tf.variable_scope('feature_extractor_5x_%s' % scope, reuse=reuse): 53 | net = dense(inputs=inputs, out_dim=int(z_dim), scope=1, use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 54 | net = ReLU(net) 55 | net = dense(inputs=net, out_dim=int(z_dim), scope=2, use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 56 | interm = ReLU(net) 57 | print() 58 | return interm 59 | 60 | # Attention Network 20x. 61 | def attention_20x(inputs, z_dim, att_dim, regularizer_scale, reuse, scope, use_gated=True): 62 | print('Attention Network 20x:', inputs.shape[-1], 'Dimensions') 63 | with tf.variable_scope('attention_20x_%s' % scope, reuse=reuse): 64 | 65 | net1 = tf.reshape(inputs, (-1, z_dim)) 66 | net1 = dense(inputs=net1, out_dim=att_dim, scope='V_k', use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 67 | net1 = tanh(net1) 68 | 69 | if use_gated: 70 | # GatedAttention. 71 | net2 = dense(inputs=inputs, out_dim=att_dim, scope='U_k', use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 72 | net2 = sigmoid(net2) 73 | net = net1*net2 74 | else: 75 | net = net1 76 | 77 | # Get weights. 78 | net = dense(inputs=net, out_dim=1, scope='W', use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 79 | net_reshape = tf.reshape(net, (-1, 16, 1)) 80 | weights = tf.nn.softmax(net_reshape, axis=1) 81 | print() 82 | return weights 83 | 84 | # Attention Network 10x. 85 | def attention_10x(inputs, z_dim, att_dim, regularizer_scale, reuse, scope, use_gated=True): 86 | print('Attention Network 10x:', inputs.shape[-1], 'Dimensions') 87 | with tf.variable_scope('attention_10x_%s' % scope, reuse=reuse): 88 | 89 | net1 = tf.reshape(inputs, (-1, z_dim)) 90 | net1 = dense(inputs=net1, out_dim=att_dim, scope='V_k', use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 91 | net1 = tanh(net1) 92 | 93 | if use_gated: 94 | # GatedAttention. 95 | net2 = dense(inputs=inputs, out_dim=att_dim, scope='U_k', use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 96 | net2 = sigmoid(net2) 97 | net = net1*net2 98 | else: 99 | net = net1 100 | 101 | # Get weights. 102 | net = dense(inputs=net, out_dim=1, scope='W', use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 103 | net_reshape = tf.reshape(net, (-1, 4, 1)) 104 | weights = tf.nn.softmax(net_reshape, axis=1) 105 | print() 106 | return weights 107 | 108 | # Aggregate 20x representations. 109 | def aggregate_20x_representations(interm, weights, reuse, scope): 110 | print('Aggregate Network 20x:', interm.shape[-1], 'Dimensions') 111 | with tf.variable_scope('aggregate_20x_%s' % scope, reuse=reuse): 112 | weighted_rep = interm*weights 113 | aggregated_rep = tf.reduce_sum(weighted_rep, axis=1) 114 | return aggregated_rep 115 | 116 | # Aggregate 10x representations. 117 | def aggregate_10x_representations(interm, weights, reuse, scope): 118 | print('Aggregate Network 10x:', interm.shape[-1], 'Dimensions') 119 | with tf.variable_scope('aggregate_10x_%s' % scope, reuse=reuse): 120 | weighted_rep = interm*weights 121 | aggregated_rep = tf.reduce_sum(weighted_rep, axis=1) 122 | return aggregated_rep 123 | 124 | # Feature combination for concatenated vector of 5x/10x/20x 125 | def feature_extractor_comb(inputs, z_dim, regularizer_scale, use, reuse, scope): 126 | print('Feature Extractor Network All Magnifications:', inputs.shape[-1], 'Dimensions') 127 | interm = inputs 128 | if use: 129 | with tf.variable_scope('feature_extractor_comb_%s' % scope, reuse=reuse): 130 | net = dense(inputs=inputs, out_dim=int(z_dim)*3, scope=1, use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 131 | net = ReLU(net) 132 | net = dense(inputs=net, out_dim=int(z_dim)*3, scope=2, use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 133 | interm = ReLU(net) 134 | print() 135 | return interm 136 | 137 | # Attention Network. 138 | def attention(inputs, z_dim, att_dim, regularizer_scale, reuse, scope, use_gated=True): 139 | print('Attention Network All Magnifications:', inputs.shape[-1], 'Dimensions') 140 | with tf.variable_scope('attention_%s' % scope, reuse=reuse): 141 | 142 | # 143 | net1 = dense(inputs=inputs, out_dim=att_dim, scope='V_k', use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 144 | net1 = tanh(net1) 145 | 146 | if use_gated: 147 | # GatedAttention. 148 | net2 = dense(inputs=inputs, out_dim=att_dim, scope='U_k', use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 149 | net2 = sigmoid(net2) 150 | net = net1*net2 151 | else: 152 | net = net1 153 | 154 | # Get weights. 155 | net = dense(inputs=net, out_dim=1, scope='W', use_bias=True, spectral=False, init='glorot_uniform', regularizer=l2_reg(regularizer_scale), display=True) 156 | net = tf.transpose(net) 157 | weights = tf.nn.softmax(net) 158 | weights = tf.transpose(weights) 159 | print() 160 | 161 | return weights 162 | 163 | # Patient aggregation for representations and weights. 164 | def patient_aggregation(interm, weights, reuse, scope): 165 | print('Patient Aggregation Network:', interm.shape[-1], 'Dimensions') 166 | with tf.variable_scope('patient_aggregation_%s' % scope, reuse=reuse): 167 | # Weight each sample. 168 | patient_rep = tf.reshape(tf.reduce_sum(weights*interm, axis=0), (-1,1)) 169 | patient_rep = tf.transpose(patient_rep) 170 | print() 171 | return patient_rep 172 | 173 | def attention_network(model, represenation_input_5x, represenation_input_10x, represenation_input_20x, regularizer_scale, reuse, name='attention_network'): 174 | 175 | with tf.variable_scope(name, reuse=reuse): 176 | 177 | # Feature Extractions. 178 | model.interm_5x = feature_extractor_5x(inputs=represenation_input_5x, z_dim=model.z_dim, regularizer_scale=regularizer_scale, use=True, reuse=False, scope=1) 179 | model.interm_10x = feature_extractor_10x(inputs=represenation_input_10x, z_dim=model.z_dim, regularizer_scale=regularizer_scale, use=True, reuse=False, scope=1) 180 | model.interm_20x = feature_extractor_20x(inputs=represenation_input_20x, z_dim=model.z_dim, regularizer_scale=regularizer_scale, use=True, reuse=False, scope=1) 181 | 182 | ################### Multi-Magnification Attention MIL portion of the model. 183 | # Attention and aggregation of 20x. 184 | model.weights_20x = attention_20x(inputs=model.interm_20x, z_dim=model.z_dim, att_dim=model.att_dim, regularizer_scale=regularizer_scale, use_gated=model.use_gated, reuse=False, scope=1) 185 | aggregate_tiles_20x = aggregate_20x_representations(model.interm_20x, model.weights_20x, reuse=False, scope=1) 186 | 187 | # Attention and aggregation of 10x. 188 | model.weights_10x = attention_10x(inputs=model.interm_10x, z_dim=model.z_dim, att_dim=model.att_dim, regularizer_scale=regularizer_scale, use_gated=model.use_gated, reuse=False, scope=1) 189 | aggregate_tiles_10x = aggregate_10x_representations(model.interm_10x, model.weights_10x, reuse=False, scope=1) 190 | 191 | # Concatenate all magnification representations: 3*z_dim. 192 | rep_multimag = tf.concat([model.interm_5x, aggregate_tiles_10x, aggregate_tiles_20x], axis=1) 193 | rep_multimag = feature_extractor_comb(inputs=rep_multimag, z_dim=model.z_dim, regularizer_scale=regularizer_scale, use=True, reuse=False, scope=1) 194 | 195 | # Attention and Patient Represenation. 196 | model.weights = attention(inputs=rep_multimag, z_dim=model.z_dim, att_dim=model.att_dim, regularizer_scale=regularizer_scale, use_gated=model.use_gated, reuse=False, scope=1) 197 | patient_rep_ind = patient_aggregation(interm=rep_multimag, weights=model.weights, reuse=False, scope=1) 198 | 199 | return patient_rep_ind 200 | 201 | 202 | 203 | 204 | 205 | -------------------------------------------------------------------------------- /models/normalization.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from models.activations import * 4 | from models.ops import * 5 | 6 | def batch_norm(inputs, training, c=None, spectral=False, scope=False): 7 | output = tf.layers.batch_normalization(inputs=inputs, training=training) 8 | return output 9 | 10 | def instance_norm(inputs, training, c=None, spectral=False, scope=False): 11 | # Not used: training 12 | output = tf.contrib.layers.instance_norm(inputs=inputs) 13 | return output 14 | 15 | def layer_norm(inputs, training, c=None, spectral=False, scope=False): 16 | # Not used: training 17 | output = tf.contrib.layers.layer_norm(inputs=inputs, scope='layer_norm_%s' % scope) 18 | return output 19 | 20 | def group_norm(inputs, training, c=None, spectral=False, scope=False): 21 | # Not used: training 22 | output = tf.contrib.layers.group_norm(inputs=inputs) 23 | return output 24 | 25 | def conditional_instance_norm(inputs, training, c, scope, spectral=False): 26 | input_dims = inputs.shape.as_list() 27 | if len(input_dims) == 4: 28 | batch, height, width, channels = input_dims 29 | else: 30 | batch, channels = input_dims 31 | 32 | with tf.variable_scope('conditional_instance_norm_%s' % scope): 33 | decay = 0.9 34 | epsilon = 1e-5 35 | 36 | # MLP for gamma, and beta. 37 | inter_dim = int((channels+c.shape.as_list()[-1])/2) 38 | net = dense(inputs=c, out_dim=inter_dim, scope=1, spectral=spectral, display=False) 39 | net = ReLU(net) 40 | gamma = dense(inputs=net, out_dim=channels, scope='gamma', spectral=spectral, display=False) 41 | gamma = ReLU(gamma) 42 | beta = dense(inputs=net, out_dim=channels, scope='beta', spectral=spectral, display=False) 43 | if len(input_dims) == 4: 44 | gamma = tf.expand_dims(tf.expand_dims(gamma, 1), 1) 45 | beta = tf.expand_dims(tf.expand_dims(beta, 1), 1) 46 | 47 | if len(input_dims) == 4: 48 | batch_mean, batch_variance = tf.nn.moments(inputs, axes=[1,2], keep_dims=True) 49 | else: 50 | batch_mean, batch_variance = tf.nn.moments(inputs, axes=[1], keep_dims=True) 51 | 52 | batch_norm_output = tf.nn.batch_normalization(inputs, batch_mean, batch_variance, beta, gamma, epsilon) 53 | 54 | return batch_norm_output 55 | 56 | def conditional_batch_norm(inputs, training, c, scope, spectral=False): 57 | input_dims = inputs.shape.as_list() 58 | if len(input_dims) == 4: 59 | batch, height, width, channels = input_dims 60 | else: 61 | batch, channels = input_dims 62 | 63 | with tf.variable_scope('conditional_batch_norm_%s' % scope) : 64 | decay = 0.9 65 | epsilon = 1e-5 66 | 67 | test_mean = tf.get_variable("pop_mean", shape=[channels], dtype=tf.float32, initializer=tf.constant_initializer(0.0), trainable=False) 68 | test_variance = tf.get_variable("pop_var", shape=[channels], dtype=tf.float32, initializer=tf.constant_initializer(1.0), trainable=False) 69 | 70 | # MLP for gamma, and beta. 71 | inter_dim = int((channels+c.shape.as_list()[-1])/2) 72 | net = dense(inputs=c, out_dim=inter_dim, scope=1, spectral=spectral, display=False) 73 | net = ReLU(net) 74 | gamma = dense(inputs=net, out_dim=channels, scope='gamma', spectral=spectral, display=False) 75 | gamma = ReLU(gamma) 76 | beta = dense(inputs=net, out_dim=channels, scope='beta', spectral=spectral, display=False) 77 | if len(input_dims) == 4: 78 | gamma = tf.expand_dims(tf.expand_dims(gamma, 1), 1) 79 | beta = tf.expand_dims(tf.expand_dims(beta, 1), 1) 80 | 81 | if training: 82 | if len(input_dims) == 4: 83 | batch_mean, batch_variance = tf.nn.moments(inputs, axes=[0, 1, 2]) 84 | # batch_mean, batch_variance = tf.nn.moments(inputs, axes=[0, 1, 2], keep_dims=True) 85 | else: 86 | batch_mean, batch_variance = tf.nn.moments(inputs, axes=[0, 1]) 87 | # batch_mean, batch_variance = tf.nn.moments(inputs, axes=[0, 1], keep_dims=True) 88 | ema_mean = tf.assign(test_mean, test_mean * decay + batch_mean * (1 - decay)) 89 | ema_variance = tf.assign(test_variance, test_variance * decay + batch_variance * (1 - decay)) 90 | with tf.control_dependencies([ema_mean, ema_variance]): 91 | batch_norm_output = tf.nn.batch_normalization(inputs, batch_mean, batch_variance, beta, gamma, epsilon) 92 | else: 93 | batch_norm_output = tf.nn.batch_normalization(inputs, test_mean, test_variance, beta, gamma, epsilon) 94 | return batch_norm_output 95 | -------------------------------------------------------------------------------- /models/nuance.py: -------------------------------------------------------------------------------- 1 | from scipy.special import gamma 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | def bandwith(dim_v): 6 | # Assuming here that after the mapping the latent space has some sort of gaussian distribution. 7 | gz = 2 * gamma(0.5 * (dim_v+1)) / gamma(0.5 * dim_v) 8 | return 1. / (2. * gz) 9 | 10 | def rbf_gauss(u, v, gamma_): 11 | dist_table_matrix = tf.expand_dims(u, 0) - tf.expand_dims(v, 1) 12 | l2_dist = tf.reduce_sum(dist_table_matrix**2, axis=-1) 13 | rbf = tf.exp(-gamma_*l2_dist) 14 | return rbf 15 | 16 | def HSIC(u, v, gamma=None): 17 | dim_u = u.shape.as_list()[1] 18 | dim_v = v.shape.as_list()[1] 19 | 20 | if gamma is None: 21 | gamma_u = bandwith(dim_u) 22 | gamma_v = bandwith(dim_v) 23 | else: 24 | gamma_u = gamma 25 | gamma_v = gamma 26 | 27 | uu = rbf_gauss(u, u, gamma_=gamma_u) 28 | vv = rbf_gauss(v, v, gamma_=gamma_v) 29 | 30 | # HSIC = E_xx'yy'[k(x,x')l(y,y')] + E_xx'[k(x,x')]E_yy'[l(y,y')] - 2 E_xy[ E_x'[k(x,x')] E_y'[l(y,y')] ] 31 | term_1 = tf.reduce_mean(uu * vv) 32 | term_2 = tf.reduce_mean(uu) * tf.reduce_mean(vv) 33 | term_3 = 2 * tf.reduce_mean( tf.reduce_mean(uu, axis=1) * tf.reduce_mean(vv, axis=1) ) 34 | value = tf.sqrt(term_1 + term_2 - term_3) 35 | return value -------------------------------------------------------------------------------- /models/regularizers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def orthogonal_reg(scale): 5 | 6 | def ortho_reg(w): 7 | 8 | if len(w.shape.as_list()) > 2: 9 | filter_size, filter_size, input_channels, output_channels = w.shape.as_list() 10 | w_reshape = tf.reshape(w, (-1, output_channels)) 11 | dim = output_channels 12 | else: 13 | output_dim, input_dim = w.shape.as_list() 14 | dim = input_dim 15 | w_reshape = w 16 | 17 | identity = tf.eye(dim) 18 | 19 | wt_w = tf.matmul(a=w_reshape, b=w_reshape, transpose_a=True) 20 | term = tf.multiply(wt_w, (tf.ones_like(identity)-identity)) 21 | 22 | reg = 2*tf.nn.l2_loss(term) 23 | 24 | return scale*reg 25 | 26 | return ortho_reg 27 | 28 | def perceptual_path_length(model, path_length_decay=0.01, path_length_weight=2.0): 29 | path_length_noise = tf.random_normal(tf.shape(model.fake_images))/np.sqrt(model.image_height*model.image_width) 30 | path_length_grad = tf.gradients(tf.reduce_sum(model.fake_images * path_length_noise), model.w_latent) 31 | 32 | # Shape: (?, 200, 6) 33 | path_length_grad_s2 = tf.square(path_length_grad) 34 | # Shape: (1, ?, 200, 6) 35 | path_length_1 = tf.reduce_sum(path_length_grad_s2, axis=2) 36 | # Shape: (1, ?, 6) 37 | path_length = tf.sqrt(tf.reduce_mean(path_length_1, axis=1)) 38 | # Shape: (1, 6) 39 | 40 | with tf.control_dependencies(None): 41 | path_length_mean_var = tf.Variable(name='path_length_mean', trainable=False, initial_value=0.0, dtype=tf.float32) 42 | path_length_mean = path_length_mean_var + path_length_decay*(tf.reduce_mean(path_length)-path_length_mean_var) 43 | path_length_update = tf.assign(path_length_mean_var, path_length_mean) 44 | 45 | with tf.control_dependencies([path_length_update]): 46 | path_length_penalty = tf.square(path_length-path_length_mean) 47 | 48 | # print('path_length_mean:', path_length_mean.shape) 49 | regularization_term = tf.reduce_mean(path_length_penalty)*path_length_weight 50 | # print('Final variable:', regularization_term.shape) 51 | 52 | return regularization_term 53 | 54 | 55 | def l2_reg(scale): 56 | return tf.contrib.layers.l2_regularizer(scale) 57 | 58 | 59 | def l1_reg(scale): 60 | return tf.contrib.layers.l1_regularizer(scale) -------------------------------------------------------------------------------- /models/score/crimage_score.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow.contrib.gan as tfgan 3 | from data_manipulation.utils import * 4 | from models.score.utils import * 5 | from models.score.frechet_inception_distance import * 6 | from models.score.kernel_inception_distance import * 7 | from models.score.mmd import * 8 | from models.score.k_nearest_neighbor import * 9 | from models.score.inception_score import * 10 | from models.score.mode_score import * 11 | 12 | 13 | class CRImage_Scores(object): 14 | def __init__(self, ref1_crimage, ref2_crimage, name_x, name_y, k=1, GPU=False, display=False): 15 | # super(Scores, self).__init__() 16 | self.ref1_crimage = ref1_crimage 17 | self.ref2_crimage = ref2_crimage 18 | self.name_x = name_x 19 | self.name_y = name_y 20 | self.title = '%s Features - %s Features' % (self.name_x, self.name_y) 21 | self.k = k 22 | 23 | self.fid = None 24 | self.kid = None 25 | self.mmd = None 26 | self.knn_x = None 27 | self.knn_y = None 28 | self.knn = None 29 | 30 | self.display = display 31 | 32 | if GPU: 33 | self.config = tf.ConfigProto() 34 | else: 35 | self.config = tf.ConfigProto(device_count = {'GPU': 0}) 36 | 37 | if self.display: 38 | print('Running:', self.title) 39 | print('Loded CRImage Files') 40 | print(self.name_x, 'File:', self.ref1_crimage) 41 | print(self.name_y, 'Shape:', self.ref2_crimage) 42 | self.build_graph() 43 | if self.display: 44 | print('Created Graph.') 45 | 46 | def build_graph(self): 47 | self.x_input = tf.placeholder(dtype=tf.float32, shape=(None, 3), name='x_features') 48 | self.y_input = tf.placeholder(dtype=tf.float32, shape=(None, 3), name='y_features') 49 | self.fid_output = tfgan.eval.frechet_classifier_distance_from_activations(self.x_input, self.y_input) 50 | self.kid_output = kernel_inception_distance(self.x_input, self.y_input) 51 | self.mmd_output = maximmum_mean_discrepancy_score(self.x_input, self.y_input) 52 | self.indices_output, self.labels_output = k_nearest_neighbor_tf_part(self.x_input, self.y_input, k=self.k) 53 | 54 | def read_crimage(self, file_path): 55 | imgs = list() 56 | with open(file_path) as content: 57 | for line in content: 58 | line = line.replace('\n', '') 59 | # 16 5 0.0003195399 0.7619048 50072 60 | values = line.split(' ') 61 | values.pop() 62 | if len(values) != 3: 63 | values.pop() 64 | if '' == line: 65 | continue 66 | imgs.append(values) 67 | return np.array(imgs) 68 | 69 | def report_scores(self): 70 | print() 71 | print('--------------------------------------------------------') 72 | print(self.title) 73 | print('Frechet Inception Distance:', self.fid) 74 | # print('Kernel Inception Distance:', self.kid) 75 | # print('Mean Minimum Distance:', self.mmd) 76 | # print('%s-NN %s Accuracy:' % (self.k, self.name_x), self.knn_x) 77 | # print('%s-NN %s Accuracy:' % (self.k, self.name_y), self.knn_y) 78 | # print('%s-NN Accuracy:' % (self.k), self.knn) 79 | print() 80 | print('--------------------------------------------------------') 81 | print() 82 | 83 | def run_crimage_scores(self): 84 | score_dict = dict() 85 | features_x = self.read_crimage(self.ref1_crimage) 86 | features_y = self.read_crimage(self.ref2_crimage) 87 | 88 | with tf.Session(config=tf.ConfigProto(device_count = {'GPU': 0})) as sess: 89 | sess.run(tf.global_variables_initializer()) 90 | feed_dict = {self.x_input:features_x, self.y_input:features_y} 91 | self.fid, self.kid, self.mmd, self.indices, self.labels = sess.run([self.fid_output, self.kid_output, self.mmd_output, self.indices_output, self.labels_output], feed_dict) 92 | self.knn_x, self.knn_y, self.knn = k_nearest_neighbor_np_part(self.indices, self.labels, k=self.k, x_samples=features_x.shape[0]) 93 | 94 | if self.display: 95 | self.report_scores() -------------------------------------------------------------------------------- /models/score/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from scipy import linalg 4 | from models.score.utils import * 5 | 6 | 7 | def frechet_inception_distance(x_features, y_features, batch_size, sqrt=False): 8 | batch_scores = list() 9 | batches = int(x_features.shape.as_list()[0]/batch_size) 10 | for i in range(batches): 11 | if batches-1 == i: 12 | x_features_batch = x_features[i*batch_size: , :] 13 | y_features_batch = y_features[i*batch_size: , :] 14 | else: 15 | x_features_batch = x_features[i*batch_size : (i+1)*batch_size, :] 16 | y_features_batch = y_features[i*batch_size : (i+1)*batch_size, :] 17 | 18 | samples = x_features_batch.shape.as_list()[0] 19 | x_feat = tf.reshape(x_features_batch, (samples, -1)) 20 | y_feat = tf.reshape(y_features_batch, (samples, -1)) 21 | 22 | x_mean = tf.reduce_mean(x_feat, axis=0) 23 | y_mean = tf.reduce_mean(y_feat, axis=0) 24 | 25 | # Review this two lines. 26 | x_cov = covariance(x_feat) 27 | y_cov = covariance(y_feat) 28 | 29 | means = dot_product(x_mean, x_mean) + dot_product(y_mean, y_mean) - 2*dot_product(x_mean, y_mean) 30 | cov_s = linalg.sqrtm(tf.matmul(x_cov, y_cov), True) 31 | cov_s = cov_s.real 32 | covas = tf.trace(x_cov + y_cov - 2*cov_s) 33 | 34 | fid = means + covas 35 | if sqrt: 36 | fid = tf.sqrt(fid) 37 | batch_scores.append(np.array(fid)) 38 | return np.mean(batch_scores), np.std(batch_scores) -------------------------------------------------------------------------------- /models/score/inception_score.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def inception_score(p_xi_y, batch_size, epsilon=1e-20): 6 | batch_scores = list() 7 | batches = int(p_xi_y.shape.as_list([0])/batch_size) 8 | for i in range(batches): 9 | if batches-1 == i: 10 | p_xi_y_batch = p_xi_y[i*batch_size: , :] 11 | else: 12 | p_xi_y_batch = p_xi_y[i*batch_size: (i+1)*batch_size, :] 13 | # Marginal label distribution over all batch samples. 14 | p_y_batch = tf.reduce_mean(p_xi_y_batch, axis=0) 15 | kl_dist = p_xi_y_batch * (tf.log(p_xi_y_batch+epsilon) - tf.log(p_y_batch+epsilon)) 16 | is_batch = tf.exp(tf.reduce_mean(kl_dist, axis=-1)) 17 | batch_scores.append(is_batch) 18 | return np.mean(batch_scores), np.std(batch_scores) 19 | -------------------------------------------------------------------------------- /models/score/k_nearest_neighbor.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from models.score.utils import * 4 | 5 | 6 | def k_nearest_neighbor_tf_part(x, y, k): 7 | x_samples = tf.shape(x)[0] 8 | y_samples = tf.shape(y)[0] 9 | 10 | xx_d = euclidean_distance(x, x) 11 | yy_d = euclidean_distance(y, y) 12 | xy_d = euclidean_distance(x, y) 13 | 14 | labels = tf.concat([tf.ones((x_samples,1)), tf.zeros((y_samples,1))], axis=0) 15 | 16 | x_dist = tf.concat([xx_d, xy_d], axis=-1) 17 | y_dist = tf.concat([tf.transpose(xy_d), yy_d], axis=-1) 18 | total_dist = tf.concat([x_dist, y_dist], axis=0) 19 | ''' 20 | x1x1 x1x2 ... x1x100 | x1y1 x1xy2 ... x1y200 21 | ... | ... 22 | x100x1 x100x2 ... x100x100 | x100y1 x100xy2 ... x100y200 23 | ________________________________________________________ 24 | y1x1 y1x2 ... y1x100 | y1y1 y1xy2 ... y1y200 25 | ... | ... 26 | y100x1 y100x2 ... y100x100 | y100y1 y1xy2 ... y100y100 27 | ... | ... 28 | y200x1 y200x2 ... y200x100 | y200y1 y200xy2 ... y200y200 29 | 30 | Diagonals of this tensor are the distance for the vector with itself. 31 | ''' 32 | total_dist = tf.sqrt(tf.abs(total_dist)) 33 | inf_eye = tf.eye(tf.shape(total_dist)[0])*1e+7 34 | 35 | #All element positive now, no smallest elements functions. 36 | all_dist = tf.math.add(inf_eye, total_dist) 37 | neg_all_dist = tf.negative(all_dist) 38 | values, indices = tf.math.top_k(input=neg_all_dist, k=k, sorted=True) 39 | values = tf.negative(values) 40 | 41 | return indices, labels 42 | 43 | # This part is a pain in the ass to do with tensorflow. 44 | # addition_labels = tf.get_variable(initializer=tf.zeros_initializer(), shape=(), name='addition_labels') 45 | # def find_acc_labels(addition_labels, indices, labels, k): 46 | # for i in tf.shape(indices)[0]: 47 | # add = 0 48 | # for j in range(k): 49 | # a = tf.gather_nd(indices, [i, j]) 50 | # add += tf.gather_nd(labels, [a, 0]) 51 | # tf.assign(addition_labels[i], add) 52 | # find_acc_labels(addition_labels, indices, labels, k) 53 | 54 | 55 | def k_nearest_neighbor_np_part(indices, labels, k, x_samples): 56 | num_vectors = indices.shape[0] 57 | addition_labels = np.zeros((num_vectors, 1)) 58 | for i in range(num_vectors): 59 | add = 0 60 | for j in range(k): 61 | add += labels[indices[i, j]] 62 | 63 | addition_labels[i] = add 64 | 65 | # Numpy implemementation. 66 | prediction = 1.*(addition_labels>(k/2.)) 67 | true_positive = np.sum(prediction*labels) 68 | false_positive = np.sum(prediction*(1-labels)) 69 | true_negative = np.sum((1-prediction)*labels) 70 | false_negative = np.sum((1-prediction)*(1-labels)) 71 | 72 | precision = true_positive/(true_positive+false_positive) 73 | recall = true_positive/(true_positive+false_negative) 74 | 75 | accuracy_true = true_positive/(true_positive+false_negative) 76 | accuracy_false = true_negative/(true_negative+false_positive) 77 | 78 | matched = np.equal(labels, prediction)*1. 79 | accuracy_x = np.mean(matched[:x_samples, :]) 80 | accuracy_y = np.mean(matched[x_samples:, :]) 81 | accuracy = np.mean(matched) 82 | 83 | return accuracy_x, accuracy_y, accuracy -------------------------------------------------------------------------------- /models/score/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from models.score.utils import * 4 | 5 | 6 | def kernel_inception_distance(x, y, gamma=1, coef=1, degree=3): 7 | k_xx = polinomial_kernel(x, x) 8 | k_yy = polinomial_kernel(y, y) 9 | k_xy = polinomial_kernel(x, y) 10 | kid = maximum_mean_discrepancy(k_xx, k_yy, k_xy) 11 | return kid -------------------------------------------------------------------------------- /models/score/mmd.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from models.score.utils import * 4 | 5 | 6 | def maximmum_mean_discrepancy_score(x, y, sigma=1): 7 | xx_d = euclidean_distance(x, x) 8 | yy_d = euclidean_distance(y, y) 9 | xy_d = euclidean_distance(x, y) 10 | 11 | scale = tf.reduce_mean(xx_d) 12 | # Gaussian kernel 13 | k_xx = tf.exp(-(xx_d)/(2*scale*(sigma**2))) 14 | k_yy = tf.exp(-(yy_d)/(2*scale*(sigma**2))) 15 | k_xy = tf.exp(-(xy_d)/(2*scale*(sigma**2))) 16 | 17 | mmd = maximum_mean_discrepancy(k_xx, k_yy, k_xy) 18 | return mmd -------------------------------------------------------------------------------- /models/score/mode_score.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def mode_score(p_xi_y, p_ui_y, batch_size, epsilon=1e-20): 6 | batch_scores = list() 7 | batches = int(p_xi_y.shape.as_list([0])/batch_size) 8 | for i in range(batches): 9 | if batches-1 == i: 10 | p_xi_y_batch = p_xi_y[i*batch_size: , :] 11 | p_ui_y_batch = p_ui_y[i*batch_size: , :] 12 | else: 13 | p_xi_y_batch = p_xi_y[i*batch_size: (i+1)*batch_size, :] 14 | p_ui_y_batch = p_ui_y[i*batch_size: (i+1)*batch_size, :] 15 | # Marginal label distribution over all batch samples. 16 | p_y_batch = tf.reduce_mean(p_xi_y_batch, axis=0) 17 | p_yu_batch = tf.reduce_mean(p_ui_y_batch, axis=0) 18 | kl_dist = p_xi_y_batch * (tf.log(p_xi_y_batch+epsilon) - tf.log(p_y_batch+epsilon)) 19 | kl_dist_y = p_y_batch * (tf.log(p_y_batch+epsilon) - tf.log(p_yu_batch+epsilon)) 20 | kl_dist_y = tf.reduce_sum(kl_dist_y) 21 | is_batch = tf.exp(tf.reduce_mean(kl_dist, axis=-1) - kl_dist_y) 22 | batch_scores.append(is_batch) 23 | return np.mean(batch_scores), np.std(batch_scores) 24 | -------------------------------------------------------------------------------- /models/score/score.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.gan as tfgan 3 | from data_manipulation.utils import * 4 | from models.score.utils import * 5 | from models.score.frechet_inception_distance import * 6 | from models.score.kernel_inception_distance import * 7 | from models.score.mmd import * 8 | from models.score.k_nearest_neighbor import * 9 | from models.score.inception_score import * 10 | from models.score.mode_score import * 11 | 12 | 13 | class Scores(object): 14 | def __init__(self, hdf5_features_x, hdf5_features_y, name_x, name_y, x_dict='features', y_dict='features', k=1, GPU=False, display=False): 15 | super(Scores, self).__init__() 16 | self.hdf5_features_x_path = hdf5_features_x 17 | self.hdf5_features_y_path = hdf5_features_y 18 | self.name_x = name_x 19 | self.name_y = name_y 20 | self.title = '%s Features - %s Features' % (self.name_x, self.name_y) 21 | self.k = k 22 | self.x_dict = x_dict 23 | self.y_dict = y_dict 24 | 25 | self.fid = None 26 | self.kid = None 27 | self.mmd = None 28 | self.knn_x = None 29 | self.knn_y = None 30 | self.knn = None 31 | 32 | self.display = display 33 | 34 | if GPU: 35 | self.config = tf.ConfigProto() 36 | else: 37 | self.config = tf.ConfigProto(device_count = {'GPU': 0}) 38 | 39 | self.read_hdfs() 40 | if self.display: 41 | print('Running:', self.title) 42 | print('Loded HDF5 Files') 43 | print(self.name_x, 'Shape:', self.features_x.shape) 44 | print(self.name_y, 'Shape:', self.features_y.shape) 45 | self.build_graph() 46 | if self.display: 47 | print('Created Graph.') 48 | 49 | def read_hdfs(self): 50 | self.features_x = read_hdf5(self.hdf5_features_x_path, self.x_dict) 51 | self.features_y = read_hdf5(self.hdf5_features_y_path, self.y_dict) 52 | 53 | def build_graph(self): 54 | self.x_input = tf.placeholder(dtype=tf.float32, shape=(None, self.features_x.shape[-1]), name='x_features') 55 | self.y_input = tf.placeholder(dtype=tf.float32, shape=(None, self.features_y.shape[-1]), name='y_features') 56 | self.fid_output = tfgan.eval.frechet_classifier_distance_from_activations(self.x_input, self.y_input) 57 | self.kid_output = kernel_inception_distance(self.x_input, self.y_input) 58 | self.mmd_output = maximmum_mean_discrepancy_score(self.x_input, self.y_input) 59 | self.indices_output, self.labels_output = k_nearest_neighbor_tf_part(self.x_input, self.y_input, k=self.k) 60 | 61 | def run_mmd(self): 62 | with tf.Session(config=self.config) as sess: 63 | sess.run(tf.global_variables_initializer()) 64 | feed_dict = {self.x_input:self.features_x, self.y_input:self.features_y} 65 | self.mmd = sess.run([self.mmd_output], feed_dict) 66 | 67 | def run_scores(self): 68 | with tf.Session(config=self.config) as sess: 69 | sess.run(tf.global_variables_initializer()) 70 | feed_dict = {self.x_input:self.features_x, self.y_input:self.features_y} 71 | self.fid, self.kid, self.mmd, self.indices, self.labels = sess.run([self.fid_output, self.kid_output, self.mmd_output, self.indices_output, self.labels_output], feed_dict) 72 | self.knn_x, self.knn_y,self.knn = k_nearest_neighbor_np_part(self.indices, self.labels, k=self.k, x_samples=self.features_x.shape[0]) 73 | if self.display: 74 | self.report_scores() 75 | print() 76 | 77 | def report_scores(self): 78 | print() 79 | print('--------------------------------------------------------') 80 | print(self.title) 81 | print('Frechet Inception Distance:', self.fid) 82 | # print('Kernel Inception Distance:', self.kid) 83 | # print('Mean Minimum Distance:', self.mmd) 84 | # print('%s-NN %s Accuracy:' % (self.k, self.name_x), self.knn_x) 85 | # print('%s-NN %s Accuracy:' % (self.k, self.name_y), self.knn_y) 86 | # print('%s-NN Accuracy:' % (self.k), self.knn) 87 | print() 88 | print('--------------------------------------------------------') 89 | print() 90 | -------------------------------------------------------------------------------- /models/score/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def polinomial_kernel(x, y, gamma=1., coef=1, degree=3): 4 | # Pair-wise dot product. 5 | xy = dot_product(x, y) 6 | g = gamma/float(x.shape.as_list()[1]) 7 | xy_g = xy*g 8 | kernel = tf.pow(xy_g + coef, degree) 9 | return kernel 10 | 11 | 12 | # Euclidean distance, pairwise comparison. 13 | def euclidean_distance(x, y, squared=True): 14 | # num_sample_x = x.shape.as_list()[0] 15 | # num_sample_y = y.shape.as_list()[0] 16 | 17 | # x_reshape = tf.reshape(x, shape=(num_sample_x, -1)) 18 | # y_reshape = tf.reshape(y, shape=(num_sample_y, -1)) 19 | 20 | x_sum_dims = sum(x.shape.as_list()[1:]) 21 | y_sum_dims = sum(y.shape.as_list()[1:]) 22 | 23 | x_reshape = tf.reshape(x, shape=(-1, x_sum_dims)) 24 | y_reshape = tf.reshape(y, shape=(-1, y_sum_dims)) 25 | 26 | x_2 = tf.reduce_sum(x_reshape**2, axis=-1, keepdims=True) 27 | y_2 = tf.reduce_sum(y_reshape**2, axis=-1, keepdims=True) 28 | y_2 = tf.transpose(y_2) 29 | x_y = dot_product(x, y) 30 | 31 | distance = x_2 + y_2 - 2*x_y 32 | 33 | if not squared: 34 | # Maintain distance, removing negative values 35 | distance = (distance + tf.abs(distance))/2 36 | distance = tf.sqrt(distance) 37 | 38 | # print('distance', distance) 39 | 40 | return distance 41 | 42 | # Dot product of all pair-wise vectors. 43 | def dot_product(x, y): 44 | x_sum_dims = sum(x.shape.as_list()[1:]) 45 | y_sum_dims = sum(y.shape.as_list()[1:]) 46 | 47 | x_reshape = tf.reshape(x, shape=(-1, x_sum_dims)) 48 | y_reshape = tf.reshape(y, shape=(-1, y_sum_dims)) 49 | 50 | dot_prod = tf.matmul(x_reshape, y_reshape, transpose_b=True) 51 | return dot_prod 52 | 53 | def covariance(x): 54 | mean_x = tf.reduce_mean(x, axis=0, keepdims=True) 55 | mx = tf.matmul(mean_x, mean_x, transpose_a=True) 56 | vx = tf.matmul(x, x, transpose_a=True)/tf.cast(x.shape.as_list()[0], tf.float32) 57 | covariance = vx - mx 58 | return covariance 59 | 60 | def maximum_mean_discrepancy(k_xx, k_yy, k_xy): 61 | samples_x = tf.cast(tf.shape(k_xx)[0], dtype=tf.float32) 62 | samples_y = tf.cast(tf.shape(k_yy)[0], dtype=tf.float32) 63 | 64 | k_xx_diag = tf.multiply(k_xx, tf.eye(tf.shape(k_xx)[0])) 65 | k_xx = k_xx - k_xx_diag 66 | 67 | k_yy_diag = tf.multiply(k_yy, tf.eye(tf.shape(k_yy)[0])) 68 | k_yy = k_yy - k_yy_diag 69 | 70 | E_xx = tf.reduce_sum(k_xx)/(samples_x*(samples_x-1)) 71 | E_yy = tf.reduce_sum(k_yy)/(samples_y*(samples_y-1)) 72 | E_xy = tf.reduce_mean(k_xy) 73 | mmd_2 = E_xx + E_yy - 2*E_xy 74 | mmd = tf.sqrt(tf.maximum(mmd_2,0)) 75 | return mmd 76 | -------------------------------------------------------------------------------- /models/tools.py: -------------------------------------------------------------------------------- 1 | from tensorflow.contrib.tensorboard.plugins import projector 2 | from models.utils import * 3 | 4 | 5 | # Method to generate random samples from a model, it also dumps a sprite image width them. 6 | def generate_samples(model, n_images, data_out_path, name='geneated_samples.png'): 7 | saver = tf.train.Saver() 8 | with tf.Session() as session: 9 | # Initializer and restoring model. 10 | session.run(tf.global_variables_initializer()) 11 | check = get_checkpoint(data_out_path) 12 | saver.restore(session, check) 13 | # Sample images. 14 | gen_samples, sample_z = show_generated(session=session, z_input=model.z_input, z_dim=model.z_dim, output_fake=model.output_gen, n_images=n_images, show=False) 15 | 16 | images_path = os.path.join(data_out_path, 'images') 17 | 18 | # Dump images into sprite. 19 | # image_sprite = write_sprite_image(filename=os.path.join(images_path, name), data=gen_samples, metadata=False) 20 | 21 | return gen_samples, sample_z 22 | 23 | # Method to generate random samples from a model, it also dumps a sprite image width them. 24 | def generate_from_latent(model, latent_vector, data_out_path): 25 | saver = tf.train.Saver() 26 | with tf.Session() as session: 27 | # Initializer and restoring model. 28 | session.run(tf.global_variables_initializer()) 29 | check = get_checkpoint(data_out_path) 30 | saver.restore(session, check) 31 | # Sample images. 32 | feed_dict = {model.z_input: latent_vector.reshape((-1, model.z_dim))} 33 | gen_batch = session.run(model.output_gen, feed_dict=feed_dict) 34 | return gen_batch 35 | 36 | 37 | # Method to generate images from the linear interpolation of two latent space vectors. 38 | def linear_interpolation(model, n_images, data_out_path, orig_vector, dest_vector): 39 | saver = tf.train.Saver() 40 | with tf.Session() as session: 41 | # Initializer and restoring model. 42 | session.run(tf.global_variables_initializer()) 43 | check = get_checkpoint(data_out_path) 44 | saver.restore(session, check) 45 | 46 | sequence = np.zeros((n_images, model.z_dim)) 47 | # Generate images from model. 48 | alphaValues = np.linspace(0, 1, n_images) 49 | for i, alpha in enumerate(alphaValues): 50 | sequence[i, :] = orig_vector*(1-alpha) + dest_vector*alpha 51 | # Latent space interpolation 52 | 53 | feed_dict = {model.z_input: sequence} 54 | linear_interpolation = session.run(model.output_gen, feed_dict=feed_dict) 55 | 56 | return linear_interpolation, sequence 57 | 58 | 59 | # Generates samples from the latent space to show in tensorboard. 60 | # Restores a model and somples from it. 61 | def run_latent(model, n_images, data_out_path, sprite=True): 62 | 63 | tensorboard_path = os.path.join(data_out_path, 'tensorboard') 64 | saver = tf.train.Saver() 65 | with tf.Session() as session: 66 | 67 | # Initializer and restoring model. 68 | session.run(tf.global_variables_initializer()) 69 | check = get_checkpoint(data_out_path) 70 | saver.restore(session, check) 71 | 72 | # Inputs for tensorboard. 73 | tf_data = tf.Variable(tf.zeros((n_images, model.z_dim)), name='tf_data') 74 | input_sample = tf.placeholder(tf.float32, shape=(n_images, model.z_dim)) 75 | set_tf_data = tf.assign(tf_data, input_sample, validate_shape=False) 76 | 77 | if sprite: 78 | # Sample images. 79 | gen_samples, sample_z = show_generated(session=session, z_input=model.z_input, z_dim=model.z_dim, output_fake=model.output_gen, n_images=n_images, show=False) 80 | # Generate sprite of images. 81 | write_sprite_image(filename=os.path.join(data_out_path, 'gen_sprite.png'), data=gen_samples) 82 | else: 83 | sample_z = np.random.uniform(low=-1., high=1., size=(n_images, model.z_dim)) 84 | 85 | # Variable for embedding. 86 | saver_latent = tf.train.Saver([tf_data]) 87 | session.run(set_tf_data, feed_dict={input_sample: sample_z}) 88 | saver_latent.save(sess=session, save_path=os.path.join(tensorboard_path, 'tf_data.ckpt')) 89 | 90 | # Tensorflow embedding. 91 | config = projector.ProjectorConfig() 92 | embedding = config.embeddings.add() 93 | embedding.tensor_name = tf_data.name 94 | if sprite: 95 | embedding.metadata_path = os.path.join(data_out_path, 'metadata.tsv') 96 | embedding.sprite.image_path = os.path.join(data_out_path, 'gen_sprite.png') 97 | embedding.sprite.single_image_dim.extend([model.image_height, model.image_width]) 98 | projector.visualize_embeddings(tf.summary.FileWriter(tensorboard_path), config) 99 | 100 | 101 | -------------------------------------------------------------------------------- /models/visualization/survival.py: -------------------------------------------------------------------------------- 1 | # Imports. 2 | import matplotlib.pyplot as plt 3 | from decimal import Decimal 4 | import numpy as np 5 | import os 6 | 7 | # Survival libs. 8 | from lifelines import KaplanMeierFitter 9 | from lifelines.statistics import logrank_test 10 | 11 | 12 | def plot_KM_high_low(high_risk_df, low_risk_df, ax_title, ax, event_ind_field, event_data_field, max_months=None): 13 | kmf_l = KaplanMeierFitter(label='Low-Risk') 14 | kmf_l.fit(low_risk_df[event_data_field].astype(float)/12, event_observed=low_risk_df[event_ind_field].astype(float)) 15 | kmf_h = KaplanMeierFitter(label='High-Risk') 16 | kmf_h.fit(high_risk_df[event_data_field].astype(float)/12, event_observed=high_risk_df[event_ind_field].astype(float)) 17 | 18 | kmf_l.plot_survival_function(show_censors=True, ci_show=True, ax=ax) 19 | kmf_h.plot_survival_function(show_censors=True, ci_show=True, ax=ax) 20 | 21 | ax.set_title(ax_title) 22 | ax.set_ylim([0.0,1.10]) 23 | if max_months is not None: 24 | ax.set_xlim([0.0, max_months]) 25 | 26 | 27 | def plot_k_fold_cv_KM(high_risk, low_risk, title, max_months, event_ind_field, event_data_field, file_path=None): 28 | results = logrank_test(high_risk[event_data_field].astype(float), low_risk[event_data_field].astype(float), event_observed_A=high_risk[event_ind_field].astype(float), event_observed_B=low_risk[event_ind_field].astype(float)) 29 | title_add = 'P-Value: %.2E ' % (Decimal(results.p_value)) 30 | mosaic = '''A''' 31 | fig = plt.figure(figsize=(15,7), constrained_layout=True) 32 | ax_dict = fig.subplot_mosaic(mosaic) 33 | plot_KM_high_low(high_risk, low_risk, ax_title=title + title_add, ax=ax_dict['A'], event_ind_field=event_ind_field, event_data_field=event_data_field, max_months=max_months) 34 | plt.savefig(file_path) 35 | plt.close(fig) 36 | 37 | return results.p_value 38 | 39 | def save_fold_KMs(risk_groups, additional_risk, resolution, groupby, cis, event_ind_field, event_data_field, max_months, cox_cluster_path): 40 | cis = np.array(cis) 41 | 42 | # Save folds avg. 43 | test_mean = np.round(np.mean(cis[:, 2]), 2) 44 | title = 'Leiden %s\nC-Index: %s \n' % (resolution, test_mean) 45 | file_path=os.path.join(cox_cluster_path, 'KM_%s_test.jpg' % (str(groupby).replace('.', 'p'))) 46 | test_pval = plot_k_fold_cv_KM(risk_groups[1], risk_groups[0], title=title, max_months=max_months/12, event_ind_field=event_ind_field, event_data_field=event_data_field, file_path=file_path) 47 | 48 | additional_pval = None 49 | if cis[0,3] is not None: 50 | # Save folds avg. 51 | additional_mean = np.round(np.mean(cis[:, 3]), 2) 52 | title = 'Leiden %s\nC-Index: %s \n' % (resolution, additional_mean) 53 | file_path=os.path.join(cox_cluster_path, 'KM_%s_additional.jpg' % (str(groupby).replace('.', 'p'))) 54 | additional_pval = plot_k_fold_cv_KM(additional_risk[1], additional_risk[0], title=title, max_months=max_months/12, event_ind_field=event_ind_field, event_data_field=event_data_field, file_path=file_path) 55 | 56 | return test_pval, additional_pval -------------------------------------------------------------------------------- /models/visualization/utils.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import * 2 | import numpy as np 3 | import h5py 4 | import os 5 | 6 | from models.utils import * 7 | 8 | 9 | def setup_esq_folder(hdf5_projections_path, folder_type, att_model, attention_run, label): 10 | img_path = hdf5_projections_path.split('hdf5_')[0] 11 | img_path = os.path.join(img_path, 'Representations') 12 | img_path = os.path.join(img_path, folder_type) 13 | att_img_path = os.path.join(img_path, '%s_%s' % (att_model, attention_run)) 14 | img_path = os.path.join(att_img_path, label) 15 | histograms_img_mil_path = os.path.join(att_img_path, 'weight_histograms') 16 | 17 | latents_5x_img_mil_path = os.path.join(att_img_path, 'latents_5x') 18 | latents_10x_img_mil_path = os.path.join(att_img_path, 'latents_10x') 19 | latents_20x_img_mil_path = os.path.join(att_img_path, 'latents_20x') 20 | 21 | slides_5x_img_mil_path = os.path.join(att_img_path, 'WSI_5x') 22 | slides_10x_img_mil_path = os.path.join(att_img_path, 'WSI_10x') 23 | slides_20x_img_mil_path = os.path.join(att_img_path, 'WSI_20x') 24 | 25 | miss_slides_5x_img_mil_path = os.path.join(att_img_path, 'WSI_5x_miss') 26 | miss_slides_10x_img_mil_path = os.path.join(att_img_path, 'WSI_10x_miss') 27 | miss_slides_20x_img_mil_path = os.path.join(att_img_path, 'WSI_20x_miss') 28 | 29 | paths = [histograms_img_mil_path, latents_5x_img_mil_path, latents_10x_img_mil_path, latents_20x_img_mil_path, \ 30 | miss_slides_5x_img_mil_path, miss_slides_10x_img_mil_path, miss_slides_20x_img_mil_path, slides_5x_img_mil_path, slides_10x_img_mil_path, slides_20x_img_mil_path] 31 | for path in paths: 32 | if not os.path.isdir(path): 33 | os.makedirs(path) 34 | return histograms_img_mil_path, [latents_5x_img_mil_path, latents_10x_img_mil_path, latents_20x_img_mil_path], \ 35 | [slides_5x_img_mil_path, slides_10x_img_mil_path, slides_20x_img_mil_path], [miss_slides_5x_img_mil_path, miss_slides_10x_img_mil_path, miss_slides_20x_img_mil_path] 36 | 37 | 38 | def gather_projection_var_multimag(hdf5_projections_path): 39 | all_output = gather_content_multi_magnification(hdf5_projections_path, set_type='combined', h_latent=False) 40 | projection_content = h5py.File(hdf5_projections_path, mode='r') 41 | orig_set = projection_content['combined_set_original'] 42 | cluster_labels_20x = None 43 | cluster_labels_10x = None 44 | cluster_labels_5x = None 45 | if 'cluster_labels_20x' in projection_content: 46 | cluster_labels_20x = projection_content['cluster_labels_20x'] 47 | cluster_labels_10x = projection_content['cluster_labels_10x'] 48 | cluster_labels_5x = projection_content['cluster_labels_5x'] 49 | 50 | all_output = list(all_output) 51 | all_output.extend([orig_set, cluster_labels_20x, cluster_labels_10x, cluster_labels_5x]) 52 | return all_output 53 | 54 | 55 | def read_images(hdf5_path_img, set_name): 56 | content = h5py.File(hdf5_path_img, mode='r') 57 | img = content['%s_img' % set_name] 58 | sld = content['%s_slides' % set_name] 59 | pat = content['%s_patterns' % set_name] 60 | return img, sld, pat 61 | 62 | 63 | def get_all_magnification_references(main_path, img_size, dataset_5x, dataset_10x, dataset_20x): 64 | # Test images magnifications. 65 | h5_test_mag = [] 66 | hdf5_path_img_test_5x = '%s/datasets/%s/he/patches_h%s_w%s/hdf5_%s_he_test.h5' % (main_path, dataset_5x, img_size, img_size, dataset_5x) 67 | hdf5_path_img_test_10x = '%s/datasets/%s/he/patches_h%s_w%s/hdf5_%s_he_test.h5' % (main_path, dataset_10x, img_size, img_size, dataset_10x) 68 | hdf5_path_img_test_20x = '%s/datasets/%s/he/patches_h%s_w%s/hdf5_%s_he_test.h5' % (main_path, dataset_20x, img_size, img_size, dataset_20x) 69 | 70 | # Validation images magnifications. 71 | h5_valid_mag = [] 72 | hdf5_path_img_valid_5x = '%s/datasets/%s/he/patches_h%s_w%s/hdf5_%s_he_validation.h5' % (main_path, dataset_5x, img_size, img_size, dataset_5x) 73 | hdf5_path_img_valid_10x = '%s/datasets/%s/he/patches_h%s_w%s/hdf5_%s_he_validation.h5' % (main_path, dataset_10x, img_size, img_size, dataset_10x) 74 | hdf5_path_img_valid_20x = '%s/datasets/%s/he/patches_h%s_w%s/hdf5_%s_he_validation.h5' % (main_path, dataset_20x, img_size, img_size, dataset_20x) 75 | 76 | # Train images magnifications. 77 | h5_train_mag = [] 78 | hdf5_path_img_train_5x = '%s/datasets/%s/he/patches_h%s_w%s/hdf5_%s_he_train.h5' % (main_path, dataset_5x, img_size, img_size, dataset_5x) 79 | hdf5_path_img_train_10x = '%s/datasets/%s/he/patches_h%s_w%s/hdf5_%s_he_train.h5' % (main_path, dataset_10x, img_size, img_size, dataset_10x) 80 | hdf5_path_img_train_20x = '%s/datasets/%s/he/patches_h%s_w%s/hdf5_%s_he_train.h5' % (main_path, dataset_20x, img_size, img_size, dataset_20x) 81 | 82 | h5_test_mag = [hdf5_path_img_test_5x, hdf5_path_img_test_10x, hdf5_path_img_test_20x] 83 | h5_valid_mag = [hdf5_path_img_valid_5x, hdf5_path_img_valid_10x, hdf5_path_img_valid_20x] 84 | h5_train_mag = [hdf5_path_img_train_5x, hdf5_path_img_train_10x, hdf5_path_img_train_20x] 85 | 86 | # Verify all files are there. 87 | flag = False 88 | for h5_file in h5_test_mag + h5_valid_mag + h5_train_mag: 89 | if not os.path.isfile(h5_file): 90 | print('Image H5 file not found:', h5_file) 91 | flag = True 92 | if flag: 93 | print() 94 | print('Missing H5 files with images: Break #1 - Look at Dataset image variables x5/x10/x20. Files could be missing too.') 95 | exit() 96 | 97 | return h5_train_mag, h5_valid_mag, h5_test_mag 98 | 99 | 100 | def gather_original_partition(train_sld, valid_sld, test_sld): 101 | def gather_slides(slides): 102 | unique_slides = np.unique(slides[:].astype(str)).tolist() 103 | # unique_partic = ['-'.join(slide.split('-')[:3]) for slide in unique_slides] 104 | return unique_slides 105 | 106 | def gather_patients(slides): 107 | unique_slides = np.unique(slides[:].astype(str)).tolist() 108 | unique_partic = ['-'.join(slide.split('-')[:3]) for slide in unique_slides] 109 | return unique_partic 110 | 111 | train_sld_20x, train_sld_10x, train_sld_5x = train_sld 112 | valid_sld_20x, valid_sld_10x, valid_sld_5x = valid_sld 113 | test_sld_20x, test_sld_10x, test_sld_5x = test_sld 114 | 115 | part_train_20x = gather_slides(train_sld_20x) 116 | part_train_10x = gather_slides(train_sld_10x) 117 | part_train_5x = gather_slides(train_sld_5x) 118 | orig_part_train = list(set(part_train_20x + part_train_10x + part_train_5x)) 119 | 120 | part_valid_20x = gather_slides(valid_sld_20x) 121 | part_valid_10x = gather_slides(valid_sld_10x) 122 | part_valid_5x = gather_slides(valid_sld_5x) 123 | orig_part_valid = list(set(part_valid_20x + part_valid_10x + part_valid_5x)) 124 | 125 | part_test_20x = gather_slides(test_sld_20x) 126 | part_test_10x = gather_slides(test_sld_10x) 127 | part_test_5x = gather_slides(test_sld_5x) 128 | orig_part_test = list(set(part_test_20x + part_test_10x + part_test_5x)) 129 | 130 | orig_part = [orig_part_train, orig_part_valid, orig_part_test] 131 | 132 | print('Intersection Slides and sets:') 133 | print('Train/Valid:', set(orig_part_train).intersection(set(orig_part_valid))) 134 | print('Train/Test: ', set(orig_part_test).intersection(set(orig_part_valid))) 135 | print('Valid/Test: ', set(orig_part_train).intersection(set(orig_part_test))) 136 | 137 | return orig_part 138 | 139 | 140 | def pull_top_missclassified(test_slides, slides, slides_metrics, probs, patterns, label, top_percent=0.1): 141 | all_slides = list() 142 | all_diff = list() 143 | all_class = list() 144 | all_probs = list() 145 | all_preds = list() 146 | 147 | for slide in test_slides: 148 | inds = np.argwhere(slides_metrics==slide)[0,0] 149 | inds_p = np.argwhere(slides==slide)[0,0] 150 | if label in patterns[inds_p,0]: 151 | class_slide = 1 152 | else: 153 | class_slide = 0 154 | 155 | diff = np.abs(class_slide - probs[inds,1]) 156 | all_slides.append(slide) 157 | all_diff.append(diff) 158 | all_class.append(class_slide) 159 | all_probs.append(probs[inds,1]) 160 | all_preds.append(np.argmax(probs[inds])) 161 | 162 | all_slides = np.vstack(all_slides) 163 | all_diff = np.vstack(all_diff) 164 | all_class = np.vstack(all_class) 165 | all_probs = np.vstack(all_probs) 166 | all_preds = np.vstack(all_preds) 167 | 168 | inds = np.argsort(all_diff[:,0]) 169 | inds_match = np.argwhere(all_diff[:,0]<0.5) 170 | inds_not = np.argwhere(all_diff[:,0]>=0.5) 171 | inds_match = np.intersect1d(inds_match, inds) 172 | inds_not = np.intersect1d(inds_not, inds) 173 | 174 | top_nsamples = math.ceil(top_percent*all_diff.shape[0]) 175 | wrt_ind = inds_not 176 | top_ind = inds_match[:top_nsamples] 177 | 178 | wrt_prob = all_probs[wrt_ind,0] 179 | wrt_slides = all_slides[wrt_ind,0] 180 | wrt_class = all_class[wrt_ind,0] 181 | 182 | top_prob = all_probs[top_ind,0] 183 | top_slides = all_slides[top_ind,0] 184 | top_class = all_class[top_ind,0] 185 | 186 | top_sl = list(zip(top_slides, top_prob, top_class)) 187 | wrt_sl = list(zip(wrt_slides, wrt_prob, wrt_class)) 188 | 189 | return top_sl, wrt_sl 190 | -------------------------------------------------------------------------------- /models/visualization/weight_hist.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('pdf') 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import os 6 | 7 | from models.evaluation.latent_space import * 8 | from models.utils import * 9 | from models.visualization.utils import * 10 | 11 | 12 | def show_weight_distributions(test_slides, weights_5x, weights_10x, weights_20x, slides, patterns, inst_labels, orig_set, fold_number, label, hist_img_mil_path, show_distribution=True): 13 | for slide, prob, class_ in test_slides: 14 | # Slide indices 15 | slide_indices = list(np.argwhere(slides[:]==slide)[:,0]) 16 | 17 | site = inst_labels[slide_indices[0]] 18 | 19 | # Slide weights 5x. 20 | slide_weights_5x = weights_5x[slide_indices,0] 21 | slide_weights_5x_norm = (slide_weights_5x - np.min(slide_weights_5x))/(np.max(slide_weights_5x)-np.min(slide_weights_5x)) 22 | 23 | # Slide weights 10x. 24 | slide_weights_10x = weights_10x[slide_indices, :, 0]*np.reshape(weights_5x[slide_indices,0], (-1, 1)) 25 | slide_weights_10x = np.reshape(slide_weights_10x, (-1,1)) 26 | slide_weights_10x_norm = (slide_weights_10x - np.min(slide_weights_10x))/(np.max(slide_weights_10x)-np.min(slide_weights_10x)) 27 | 28 | # Slide weights 20x. 29 | slide_weights_20x = weights_20x[slide_indices, :, 0]*np.reshape(weights_5x[slide_indices,0], (-1, 1)) 30 | slide_weights_20x = np.reshape(slide_weights_20x, (-1,1)) 31 | slide_weights_20x_norm = (slide_weights_20x - np.min(slide_weights_20x))/(np.max(slide_weights_20x)-np.min(slide_weights_20x)) 32 | 33 | fig, axes = plt.subplots(figsize=(50,15), ncols=1, nrows=3) 34 | for i, values in enumerate([(slide_weights_5x_norm, '5x'), (slide_weights_10x_norm, '10x'), (slide_weights_20x_norm, '20x')]): 35 | weights_slide, magnification = values 36 | axes[i].set_title(magnification, fontsize=20) 37 | axes[i].hist(weights_slide, bins=200, log=True) 38 | plt.suptitle('%s - %s - %s - Predict %s - %s' % (slide, magnification, patterns[slide_indices[0]][0], np.round(prob,4), site), fontsize=24) 39 | plt.savefig(os.path.join(hist_img_mil_path, '%s_prob%s_class%s.jpg' % (slide, np.round(prob,4), class_))) 40 | if show_distribution: 41 | plt.show() 42 | plt.close(fig) 43 | 44 | 45 | def get_weight_distributions(slides, patterns, institutions, orig_set, label, fold_path, directories, num_folds, h5_file_name): 46 | histograms_img_mil_path, latent_paths, wsi_paths, miss_wsi_paths = directories 47 | for fold_number in range(num_folds): 48 | print('\tFold', fold_number) 49 | 50 | hdf5_path_weights_comb = '%s/fold_%s/results/%s' % (fold_path, fold_number, h5_file_name) 51 | 52 | ### Attention runs 53 | weights_20x, weights_10x, weights_5x, probs, slides_metrics, train_slides, valid_slides, test_slides = gather_attention_results(hdf5_path_weights_comb) 54 | 55 | top_slides, wrt_slides = pull_top_missclassified(test_slides, slides, slides_metrics, probs, patterns, label, top_percent=0.10) 56 | 57 | top_slides = top_slides[:10] 58 | wrt_slides = wrt_slides[-10:] 59 | 60 | ### Histogram of weight distributions 61 | show_weight_distributions(top_slides, weights_5x, weights_10x, weights_20x, slides, patterns, institutions, orig_set, fold_number, label, histograms_img_mil_path, show_distribution=False) 62 | show_weight_distributions(wrt_slides, weights_5x, weights_10x, weights_20x, slides, patterns, institutions, orig_set, fold_number, label, histograms_img_mil_path, show_distribution=False) 63 | 64 | 65 | -------------------------------------------------------------------------------- /models/wandb_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def save_model_config(model, data): 4 | model_config = dict() 5 | 6 | ### Input data variables. 7 | model_config['image_height'] = model.image_height 8 | model_config['image_width'] = model.image_width 9 | model_config['image_channels'] = model.image_channels 10 | model_config['batch_size'] = model.batch_size 11 | 12 | ### Architecture parameters. 13 | model_config['attention'] = model.attention 14 | model_config['layers'] = model.layers 15 | model_config['spectral'] = model.spectral 16 | model_config['z_dim'] = model.z_dim 17 | model_config['init'] = model.init 18 | 19 | ### Hyperparameters. 20 | model_config['power_iterations'] = model.power_iterations 21 | model_config['regularizer_scale'] = model.regularizer_scale 22 | model_config['learning_rate_e'] = model.learning_rate_e 23 | model_config['beta_1'] = model.beta_1 24 | if hasattr(model, 'temperature'): 25 | model_config['temperature'] = model.temperature 26 | if hasattr(model, 'lambda_'): 27 | model_config['lambda_'] = model.lambda_ 28 | 29 | ### Data augmentation conditions. 30 | # Spatial transformation. 31 | model_config['crop'] = model.crop 32 | model_config['rotation'] = model.rotation 33 | model_config['flip'] = model.flip 34 | # Color transofrmation. 35 | model_config['color_distort'] = model.color_distort 36 | # Gaussian Blur and Noise. 37 | model_config['g_blur'] = model.g_blur 38 | model_config['g_noise'] = model.g_noise 39 | # Sobel Filter. 40 | model_config['sobel_filter'] = model.sobel_filter 41 | 42 | model_config['model_name'] = model.model_name 43 | 44 | model_config['conv_space_out'] = model.conv_space_out.shape.as_list() 45 | model_config['h_rep_out'] = model.h_rep_out.shape.as_list() 46 | model_config['z_rep_out'] = model.z_rep_out.shape.as_list() 47 | 48 | model_config['dataset'] = data.dataset 49 | 50 | return model_config 51 | 52 | 53 | def save_model_config_att(model): 54 | model_config = dict() 55 | 56 | ### Architecture parameters. 57 | model_config['z_dim'] = model.z_dim 58 | model_config['att_dim'] = model.att_dim 59 | model_config['init'] = model.init 60 | model_config['use_gated'] = model.use_gated 61 | 62 | ### Hyperparameters. 63 | model_config['beta_1'] = model.beta_1 64 | model_config['beta_2'] = model.beta_2 65 | model_config['learning_rate'] = model.learning_rate 66 | model_config['regularizer_scale'] = model.regularizer_scale 67 | 68 | return model_config 69 | -------------------------------------------------------------------------------- /report_representationsleiden_cox.py: -------------------------------------------------------------------------------- 1 | # Imports 2 | import argparse 3 | import os 4 | 5 | # Own libs. 6 | from models.clustering.logistic_regression_leiden_clusters import run_circular_plots 7 | from models.clustering.cox_proportional_hazard_regression_leiden_clusters import * 8 | 9 | 10 | ##### Main ####### 11 | parser = argparse.ArgumentParser(description='Report classification and cluster performance based on Logistic Regression.') 12 | parser.add_argument('--meta_folder', dest='meta_folder', type=str, default=None, help='Purpose of the clustering, name of folder.') 13 | parser.add_argument('--matching_field', dest='matching_field', type=str, default=None, help='Key used to match folds split and H5 representation file.') 14 | parser.add_argument('--event_ind_field', dest='event_ind_field', type=str, default=None, help='Key used to match event indicator field.') 15 | parser.add_argument('--event_data_field', dest='event_data_field', type=str, default=None, help='Key used to match event data field.') 16 | parser.add_argument('--diversity_key', dest='diversity_key', type=str, default=None, help='Key use to check diversity within cluster: Slide, Institution, Sample.') 17 | parser.add_argument('--type_composition', dest='type_composition', type=str, default='clr', help='Space transformation type: percent, clr, ilr, alr.') 18 | parser.add_argument('--min_tiles', dest='min_tiles', type=int, default=100, help='Minimum number of tiles per matching_field.') 19 | parser.add_argument('--folds_pickle', dest='folds_pickle', type=str, default=None, help='Pickle file with folds information.') 20 | parser.add_argument('--force_fold', dest='force_fold', type=int, default=None, help='Force fold of clustering.') 21 | parser.add_argument('--h5_complete_path', dest='h5_complete_path', type=str, required=True, help='H5 file path to run the leiden clustering folds.') 22 | parser.add_argument('--h5_additional_path', dest='h5_additional_path', type=str, default=None, help='Additional H5 representation to assign leiden clusters.') 23 | parser.add_argument('--additional_as_fold', dest='additional_as_fold', action='store_true', default=False, help='Flag to specify if additional H5 file will be used for cross-validation.') 24 | parser.add_argument('--report_clusters', dest='report_clusters', action='store_true', default=False, help='Flag to report cluster circular plots.') 25 | args = parser.parse_args() 26 | meta_folder = args.meta_folder 27 | matching_field = args.matching_field 28 | event_ind_field = args.event_ind_field 29 | event_data_field = args.event_data_field 30 | diversity_key = args.diversity_key 31 | type_composition = args.type_composition 32 | min_tiles = args.min_tiles 33 | folds_pickle = args.folds_pickle 34 | force_fold = args.force_fold 35 | h5_complete_path = args.h5_complete_path 36 | h5_additional_path = args.h5_additional_path 37 | additional_as_fold = args.additional_as_fold 38 | report_clusters = args.report_clusters 39 | 40 | max_months = 15.0*15.0 41 | 42 | # Use connectivity between clusters as features. 43 | use_conn = False 44 | use_ratio = False 45 | top_variance_feat = 99 46 | 47 | # Alphas and resolutions. 48 | l1_ratios = [0.0] 49 | alphas = 10. ** np.linspace(-4, 4, 50) 50 | 51 | # resolutions = [0.4, 0.7, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0] 52 | resolutions = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0] 53 | resolutions = [2.0] 54 | 55 | # Report figures for clusters. 56 | if report_clusters: run_circular_plots(resolutions, meta_folder, event_ind_field, matching_field, folds_pickle, h5_complete_path, h5_additional_path, diversity_key) 57 | 58 | # Run Cox Proportional Hazard Regression with L1/L2 Penalties. 59 | run_cph_regression_exhaustive(alphas, resolutions, meta_folder, matching_field, folds_pickle, event_ind_field, event_data_field, h5_complete_path, h5_additional_path, diversity_key, 60 | type_composition, min_tiles, max_months, additional_as_fold, force_fold, l1_ratios, use_conn=use_conn, use_ratio=use_ratio, top_variance_feat=top_variance_feat) 61 | 62 | # Summarize AUC performance for top 4 penalties. 63 | summary_resolution_cindex(resolutions, h5_complete_path, meta_folder, l1_ratios, min_tiles, force_fold) 64 | 65 | -------------------------------------------------------------------------------- /report_representationsleiden_cox_individual.py: -------------------------------------------------------------------------------- 1 | # Imports 2 | import argparse 3 | import os 4 | 5 | # Own libs. 6 | from models.clustering.logistic_regression_leiden_clusters import run_circular_plots 7 | from models.clustering.cox_proportional_hazard_regression_leiden_clusters import * 8 | 9 | ##### Main ####### 10 | parser = argparse.ArgumentParser(description='Report classification and cluster performance based on Logistic Regression.') 11 | parser.add_argument('--alpha', dest='alpha', type=float, default=None, help='Cox regression penalty value.') 12 | parser.add_argument('--resolution', dest='resolution', type=float, default=1.0, help='Leiden resolution.') 13 | parser.add_argument('--meta_folder', dest='meta_folder', type=str, default=None, help='Purpose of the clustering, name of folder.') 14 | parser.add_argument('--matching_field', dest='matching_field', type=str, default=None, help='Key used to match folds split and H5 representation file.') 15 | parser.add_argument('--event_ind_field', dest='event_ind_field', type=str, default=None, help='Key used to match event indicator field.') 16 | parser.add_argument('--event_data_field', dest='event_data_field', type=str, default=None, help='Key used to match event data field.') 17 | parser.add_argument('--diversity_key', dest='diversity_key', type=str, default=None, help='Key use to check diversity within cluster: Slide, Institution, Sample.') 18 | parser.add_argument('--type_composition', dest='type_composition', type=str, default='clr', help='Space transformation type: percent, clr, ilr, alr.') 19 | parser.add_argument('--l1_ratio', dest='l1_ratio', type=float, default=0.0, help='L1 Penalty for Cox regression.') 20 | parser.add_argument('--min_tiles', dest='min_tiles', type=int, default=100, help='Minimum number of tiles per matching_field.') 21 | parser.add_argument('--force_fold', dest='force_fold', type=int, default=None, help='Force fold of clustering.') 22 | parser.add_argument('--folds_pickle', dest='folds_pickle', type=str, default=None, help='Pickle file with folds information.') 23 | parser.add_argument('--h5_complete_path', dest='h5_complete_path', type=str, required=True, help='H5 file path to run the leiden clustering folds.') 24 | parser.add_argument('--h5_additional_path', dest='h5_additional_path', type=str, default=None, help='Additional H5 representation to assign leiden clusters.') 25 | parser.add_argument('--additional_as_fold', dest='additional_as_fold', action='store_true', default=False, help='Flag to specify if additional H5 file will be used for cross-validation.') 26 | args = parser.parse_args() 27 | alpha = args.alpha 28 | resolution = args.resolution 29 | meta_folder = args.meta_folder 30 | matching_field = args.matching_field 31 | event_ind_field = args.event_ind_field 32 | event_data_field = args.event_data_field 33 | diversity_key = args.diversity_key 34 | type_composition = args.type_composition 35 | min_tiles = args.min_tiles 36 | l1_ratio = args.l1_ratio 37 | folds_pickle = args.folds_pickle 38 | force_fold = args.force_fold 39 | h5_complete_path = args.h5_complete_path 40 | h5_additional_path = args.h5_additional_path 41 | additional_as_fold = args.additional_as_fold 42 | 43 | # Use connectivity between clusters as features. 44 | use_conn = False 45 | use_ratio = False 46 | top_variance_feat = 0 47 | 48 | # Other features 49 | q_buckets = 2 50 | max_months = 15.0*15.0 51 | 52 | # Run Cox Proportional Hazard Regression with L1/L2 Penalties. 53 | run_cph_regression_individual(alpha, resolution, meta_folder, matching_field, folds_pickle, event_ind_field, event_data_field, h5_complete_path, h5_additional_path, diversity_key, type_composition, 54 | min_tiles, max_months, additional_as_fold, force_fold, l1_ratio, q_buckets=q_buckets, use_conn=use_conn, use_ratio=use_ratio, top_variance_feat=top_variance_feat, 55 | remove_clusters=None, p_th=0.05) 56 | 57 | -------------------------------------------------------------------------------- /report_representationsleiden_lr.py: -------------------------------------------------------------------------------- 1 | # Imports 2 | import argparse 3 | import os 4 | 5 | # Own libs. 6 | from models.clustering.logistic_regression_leiden_clusters import * 7 | 8 | 9 | ##### Main ####### 10 | parser = argparse.ArgumentParser(description='Report classification and cluster performance based on Logistic Regression.') 11 | parser.add_argument('--meta_folder', dest='meta_folder', type=str, default=None, help='Purpose of the clustering, name of folder.') 12 | parser.add_argument('--meta_field', dest='meta_field', type=str, default=None, help='Meta field to use for the Logistic Regression.') 13 | parser.add_argument('--matching_field', dest='matching_field', type=str, default=None, help='Key used to match folds split and H5 representation file.') 14 | parser.add_argument('--diversity_key', dest='diversity_key', type=str, default=None, help='Key use to check diversity within cluster: Slide, Institution, Sample.') 15 | parser.add_argument('--type_composition', dest='type_composition', type=str, default='clr', help='Space transformation type: percent, clr, ilr, alr.') 16 | parser.add_argument('--min_tiles', dest='min_tiles', type=int, default=100, help='Minimum number of tiles per matching_field.') 17 | parser.add_argument('--folds_pickle', dest='folds_pickle', type=str, default=None, help='Pickle file with folds information.') 18 | parser.add_argument('--force_fold', dest='force_fold', type=int, default=None, help='Force fold of clustering.') 19 | parser.add_argument('--h5_complete_path', dest='h5_complete_path', type=str, required=True, help='H5 file path to run the leiden clustering folds.') 20 | parser.add_argument('--h5_additional_path', dest='h5_additional_path', type=str, default=None, help='Additional H5 representation to assign leiden clusters.') 21 | parser.add_argument('--additional_as_fold', dest='additional_as_fold', action='store_true', default=False, help='Flag to specify if additional H5 file will be used for cross-validation.') 22 | parser.add_argument('--report_clusters', dest='report_clusters', action='store_true', default=False, help='Flag to report cluster circular plots.') 23 | parser.add_argument('--min_range_auc', dest='min_range_auc', type=float, default=0.87, help='Force fold of clustering.') 24 | args = parser.parse_args() 25 | meta_folder = args.meta_folder 26 | meta_field = args.meta_field 27 | matching_field = args.matching_field 28 | diversity_key = args.diversity_key 29 | type_composition = args.type_composition 30 | min_tiles = args.min_tiles 31 | folds_pickle = args.folds_pickle 32 | force_fold = args.force_fold 33 | h5_complete_path = args.h5_complete_path 34 | h5_additional_path = args.h5_additional_path 35 | report_clusters = args.report_clusters 36 | additional_as_fold = args.additional_as_fold 37 | min_range_auc = args.min_range_auc 38 | 39 | # Use connectivity between clusters as features. 40 | use_conn = False 41 | use_ratio = False 42 | top_variance_feat = 99 43 | 44 | # Default alphas and resolutions. 45 | alphas = [0.1, 0.5, 1.0, 5.0, 10.0, 25.0, 30.0] 46 | alphas = [0.5, 1.0, 5.0, 10.0, 25.0, 30.0] 47 | resolutions = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5] 48 | 49 | # Report figures for clusters. 50 | if report_clusters: 51 | run_circular_plots(resolutions, meta_folder, meta_field, matching_field, folds_pickle, h5_complete_path, h5_additional_path, diversity_key) 52 | 53 | # Run logistic regression for different L1 penalties. 54 | run_logistic_regression(alphas, resolutions, meta_folder, meta_field, matching_field, folds_pickle, h5_complete_path, h5_additional_path, force_fold, additional_as_fold, diversity_key, 55 | use_conn=use_conn, use_ratio=use_ratio, top_variance_feat=top_variance_feat, type_composition=type_composition, min_tiles=min_tiles, p_th=0.05) 56 | 57 | # Summarize results. 58 | summarize_run(alphas, resolutions, meta_folder, meta_field, min_tiles, folds_pickle, h5_complete_path, ylim=[min_range_auc,1.01]) 59 | 60 | 61 | -------------------------------------------------------------------------------- /report_representationsleiden_samples.py: -------------------------------------------------------------------------------- 1 | # Imports 2 | import pandas as pd 3 | import argparse 4 | import os 5 | 6 | # Own libraries 7 | from data_manipulation.data import Data 8 | from models.visualization.clusters import plot_cluster_images, plot_wsi_clusters, plot_wsi_clusters_interactions 9 | 10 | 11 | ##### Main ####### 12 | parser = argparse.ArgumentParser(description='Report cluster images from a given Leiden cluster configuration.') 13 | parser.add_argument('--meta_folder', dest='meta_folder', type=str, default=None, help='Purpose of the clustering, name of folder.') 14 | parser.add_argument('--meta_field', dest='meta_field', type=str, default=None, help='Meta field to use for the Logistic Regression or Cox event indicator.') 15 | parser.add_argument('--matching_field', dest='matching_field', type=str, default=None, help='Key used to match folds split and H5 representation file.') 16 | parser.add_argument('--resolution', dest='resolution', type=float, default=None, help='Minimum number of tiles per matching_field.') 17 | parser.add_argument('--dpi', dest='dpi', type=int, default=1000, help='Highest quality: 1000.') 18 | parser.add_argument('--fold', dest='fold', type=int, default=0, help='Minimum number of tiles per matching_field.') 19 | parser.add_argument('--dataset', dest='dataset', type=str, default='TCGAFFPE_LUADLUSC_5x', help='Dataset to use.') 20 | parser.add_argument('--h5_complete_path', dest='h5_complete_path', type=str, required=True, help='H5 file path to run the leiden clustering folds.') 21 | parser.add_argument('--h5_additional_path', dest='h5_additional_path', type=str, default=None, help='Additional H5 representation to assign leiden clusters.') 22 | parser.add_argument('--min_tiles', dest='min_tiles', type=int, default=400, help='Minimum number of tiles per matching_field.') 23 | parser.add_argument('--dbs_path', dest='dbs_path', type=str, default=None, help='Path for the output run.') 24 | parser.add_argument('--img_size', dest='img_size', type=int, default=224, help='Image size for the model.') 25 | parser.add_argument('--img_ch', dest='img_ch', type=int, default=3, help='Number of channels for the model.') 26 | parser.add_argument('--marker', dest='marker', type=str, default='he', help='Marker of dataset to use.') 27 | parser.add_argument('--tile_img', dest='tile_img', action='store_true', default=False, help='Dump cluster tile images.') 28 | parser.add_argument('--extensive', dest='extensive', action='store_true', default=False, help='Flag to dump test set cluster images in addition to train.') 29 | parser.add_argument('--additional_as_fold', dest='additional_as_fold', action='store_true', default=False, help='Flag to specify if additional H5 file will be used for cross-validation.') 30 | args = parser.parse_args() 31 | meta_folder = args.meta_folder 32 | meta_field = args.meta_field 33 | matching_field = args.matching_field 34 | resolution = args.resolution 35 | dpi = args.dpi 36 | fold = args.fold 37 | min_tiles = args.min_tiles 38 | image_height = args.img_size 39 | image_width = args.img_size 40 | image_channels = args.img_ch 41 | marker = args.marker 42 | dataset = args.dataset 43 | h5_complete_path = args.h5_complete_path 44 | h5_additional_path = args.h5_additional_path 45 | dbs_path = args.dbs_path 46 | tile_img = args.tile_img 47 | extensive = args.extensive 48 | additional_as_fold = args.additional_as_fold 49 | 50 | # Dominating clusters to pull WSI. 51 | value_cluster_ids = dict() 52 | value_cluster_ids[1] = [] 53 | value_cluster_ids[0] = [] 54 | only_id = True 55 | 56 | ######################################################## 57 | ############# LUAD vs LUSC ############################# 58 | # Leiden_2.0 fold 4. 59 | # value_cluster_ids = dict() 60 | # value_cluster_ids[1] = [11,31,28,36,22,35] 61 | # value_cluster_ids[0] = [5, 45,] 62 | # only_id = False 63 | 64 | ######################################################## 65 | ############# LUAD OS ################################## 66 | ## Leiden 2.0 fold 0. 67 | # value_cluster_ids = dict() 68 | # value_cluster_ids[0] = [31, 1,37, 0,16, 8, 5] 69 | # value_cluster_ids[1] = [15,39,41,22,10,14,27] 70 | # only_id = True 71 | 72 | ######################################################## 73 | ############# LUAD PFS ################################# 74 | ## Leiden 2.0 fold 0. 75 | # value_cluster_ids = dict() 76 | # value_cluster_ids[0] = [39,45,29,27,22,36,32, 0,37,21] 77 | # value_cluster_ids[1] = [15,11, 6,44, 5,24] 78 | # only_id = True 79 | 80 | # Default path for GDC manifest. 81 | manifest_csv = '%s/utilities/files/LUADLUSC/gdc_manifest.txt' % os.path.dirname(os.path.realpath(__file__)) 82 | 83 | # Default DBs path. 84 | if dbs_path is None: 85 | dbs_path = os.path.dirname(os.path.realpath(__file__)) 86 | 87 | # Leiden convention name. 88 | groupby = 'leiden_%s' % resolution 89 | 90 | # Dataset images. 91 | data = Data(dataset=dataset, marker=marker, patch_h=image_height, patch_w=image_width, n_channels=image_channels, batch_size=64, project_path=dbs_path, load=True) 92 | 93 | # Dump cluster images. 94 | if tile_img: 95 | plot_cluster_images(groupby, meta_folder, data, fold, h5_complete_path, dpi, value_cluster_ids, extensive=extensive) 96 | 97 | # Save WSI overlay with clusters. 98 | plot_wsi_clusters(groupby, meta_folder, matching_field, meta_field, data, fold, h5_complete_path, h5_additional_path, additional_as_fold, dpi, min_tiles, manifest_csv=manifest_csv, 99 | value_cluster_ids=value_cluster_ids, type_='percent', only_id=only_id, n_wsi_samples=3) 100 | 101 | # Save WSI overlay with clusters. 102 | # plot_wsi_clusters_interactions(groupby, meta_folder, 'slides', meta_field, data, fold, h5_complete_path, h5_additional_path, additional_as_fold, dpi, min_tiles, manifest_csv=manifest_csv, 103 | # inter_dict=inter_dict, type_='percent', only_id=only_id, n_wsi_samples=2) 104 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.2.0 2 | anndata==0.8.0 3 | astor==0.8.1 4 | asttokens==2.0.8 5 | astunparse==1.6.3 6 | autograd==1.4 7 | autograd-gamma==0.5.0 8 | backcall==0.2.0 9 | CacheControl==0.12.11 10 | certifi==2022.6.15 11 | charset-normalizer==2.1.1 12 | click==8.1.3 13 | cloudpickle==2.1.0 14 | configparser==5.3.0 15 | cycler==0.11.0 16 | Cython==0.29.32 17 | decorator==5.1.1 18 | docker-pycreds==0.4.0 19 | dunamai==1.13.0 20 | ecos==2.0.10 21 | executing==1.0.0 22 | fonttools==4.37.1 23 | formulaic==0.2.4 24 | future==0.18.2 25 | gast==0.3.3 26 | get_version==3.5.4 27 | gitdb==4.0.9 28 | GitPython==3.1.27 29 | google-pasta==0.2.0 30 | grpcio==1.47.0 31 | h5py==3.7.0 32 | hdmedians==0.14.2 33 | idna==3.3 34 | importlib-metadata==4.12.0 35 | interface-meta==1.3.0 36 | ipython==8.4.0 37 | jedi==0.18.1 38 | joblib==1.1.0 39 | Keras-Applications==1.0.8 40 | Keras-Preprocessing==1.1.2 41 | kiwisolver==1.4.4 42 | legacy-api-wrap==1.2 43 | leidenalg==0.8.10 44 | lifelines==0.26.3 45 | llvmlite==0.39.1 46 | lockfile==0.12.2 47 | matplotlib==3.5.3 48 | matplotlib-inline==0.1.6 49 | msgpack==1.0.4 50 | natsort==8.2.0 51 | networkx==2.8.6 52 | numba==0.56.2 53 | numexpr==2.8.3 54 | numpy==1.23.3 55 | nvidia-cublas-cu11==11.10.3.66 56 | nvidia-cuda-cupti-cu11==11.7.101 57 | nvidia-cuda-nvcc-cu11==11.7.99 58 | nvidia-cuda-runtime-cu11==11.7.99 59 | nvidia-cudnn-cu11==8.5.0.96 60 | nvidia-cufft-cu11==10.7.2.91 61 | nvidia-curand-cu11==10.2.10.91 62 | nvidia-cusolver-cu11==11.4.0.1 63 | nvidia-cusparse-cu11==11.7.4.91 64 | nvidia-dali-cuda110==1.16.0 65 | nvidia-dali-nvtf-plugin==1.16.0+nv22.8 66 | nvidia-horovod==0.25.0+nv22.8 67 | nvidia-nccl-cu11==2.14.3 68 | nvidia-pyindex==1.0.9 69 | nvidia-tensorflow==1.15.5+nv22.8 70 | opt-einsum==3.3.0 71 | osqp==0.6.2.post5 72 | packaging==21.3 73 | pandas==1.4.4 74 | parso==0.8.3 75 | pathtools==0.1.2 76 | patsy==0.5.2 77 | pexpect==4.8.0 78 | pickleshare==0.7.5 79 | Pillow==9.2.0 80 | promise==2.3 81 | prompt-toolkit==3.0.31 82 | protobuf==3.20.1 83 | psutil==5.9.1 84 | ptyprocess==0.7.0 85 | pure-eval==0.2.2 86 | Pygments==2.13.0 87 | pynndescent==0.5.7 88 | pyparsing==3.0.9 89 | python-dateutil==2.8.2 90 | pytz==2022.2.1 91 | PyWavelets==1.3.0 92 | PyYAML==6.0 93 | qdldl==0.1.5.post2 94 | requests==2.28.1 95 | scanpy==1.8.1 96 | scikit-bio==0.5.6 97 | scikit-image==0.15.0 98 | scikit-learn==1.1.2 99 | scikit-survival==0.18.0 100 | scipy==1.9.1 101 | seaborn==0.11.2 102 | sentry-sdk==1.9.7 103 | shortuuid==1.0.9 104 | sinfo==0.3.4 105 | six==1.16.0 106 | sklearn==0.0 107 | smmap==5.0.0 108 | stack-data==0.5.0 109 | statsmodels==0.13.2 110 | stdlib-list==0.8.0 111 | subprocess32==3.5.4 112 | tables==3.7.0 113 | tensorboard==1.15.0 114 | tensorflow-estimator==1.15.1 115 | termcolor==1.1.0 116 | termcolor-whl==1.1.2 117 | threadpoolctl==3.1.0 118 | tqdm==4.64.1 119 | traitlets==5.3.0 120 | umap-learn==0.5.3 121 | urllib3==1.26.12 122 | wandb==0.12.7 123 | wcwidth==0.2.5 124 | wrapt==1.14.1 125 | yaspin==2.2.0 126 | zipp==3.8.1 127 | lxml==4.9.1 128 | kneed==0.8.2 -------------------------------------------------------------------------------- /run_representationsleiden.py: -------------------------------------------------------------------------------- 1 | # Imports 2 | import argparse 3 | import os 4 | import warnings 5 | warnings.filterwarnings('ignore') 6 | 7 | # Own libs. 8 | from models.clustering.leiden_representations import run_leiden 9 | 10 | # Folder permissions for cluster. 11 | os.umask(0o002) 12 | # H5 File bug over network file system. 13 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' 14 | 15 | 16 | ##### Main ####### 17 | parser = argparse.ArgumentParser(description='Run Leiden Community detection over Self-Supervised representations.') 18 | parser.add_argument('--resolution', dest='resolution', type=float, default=None, help='Leiden resolution.') 19 | parser.add_argument('--subsample', dest='subsample', type=int, default=None, help='Number of sample to run Leiden on, default is None, 200000 works well.') 20 | parser.add_argument('--n_neighbors', dest='n_neighbors', type=int, default=250, help='Number of neighbors to use when creating the graph, default is 250.') 21 | parser.add_argument('--meta_field', dest='meta_field', type=str, default=None, help='Purpose of the clustering, name of folder.') 22 | parser.add_argument('--matching_field', dest='matching_field', type=str, default=None, help='Key used to match folds split and H5 representation file.') 23 | parser.add_argument('--rep_key', dest='rep_key', type=str, default='z_latent', help='Key pattern for representations to grab: z_latent, h_latent.') 24 | parser.add_argument('--folds_pickle', dest='folds_pickle', type=str, default=None, help='Pickle file with folds information.') 25 | parser.add_argument('--main_path', dest='main_path', type=str, default=None, help='Workspace main path.') 26 | parser.add_argument('--h5_complete_path', dest='h5_complete_path', type=str, required=True, help='H5 file path to run the leiden clustering folds.') 27 | parser.add_argument('--h5_additional_path', dest='h5_additional_path', type=str, default=None, help='Additional H5 representation to assign leiden clusters.') 28 | args = parser.parse_args() 29 | subsample = args.subsample 30 | resolution = args.resolution 31 | n_neighbors = args.n_neighbors 32 | meta_field = args.meta_field 33 | matching_field = args.matching_field 34 | rep_key = args.rep_key 35 | folds_pickle = args.folds_pickle 36 | main_path = args.main_path 37 | h5_complete_path = args.h5_complete_path 38 | h5_additional_path = args.h5_additional_path 39 | 40 | if main_path is None: 41 | main_path = os.path.dirname(os.path.realpath(__file__)) 42 | 43 | # Default resolutions. 44 | if resolution is None: 45 | resolutions = [0.4, 0.7, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0] 46 | # resolutions.extend([6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0]) 47 | else: 48 | resolutions = [resolution] 49 | 50 | # Include surrounding cluster tile annotations. 51 | include_connections = False 52 | 53 | # Run leiden clustering. 54 | run_leiden(meta_field, matching_field, rep_key, h5_complete_path, h5_additional_path, folds_pickle, resolutions, n_neighbors=n_neighbors, subsample=subsample, include_connections=include_connections) -------------------------------------------------------------------------------- /run_representationsleiden_assignment.py: -------------------------------------------------------------------------------- 1 | # Imports 2 | import argparse 3 | import os 4 | import warnings 5 | warnings.filterwarnings('ignore') 6 | 7 | # Own libs. 8 | from models.clustering.leiden_representations import assign_additional_only 9 | 10 | # Folder permissions for cluster. 11 | os.umask(0o002) 12 | # H5 File bug over network file system. 13 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' 14 | 15 | 16 | ##### Main ####### 17 | parser = argparse.ArgumentParser(description='Run Leiden Community detection over Self-Supervised representations.') 18 | parser.add_argument('--resolution', dest='resolution', type=float, default=None, help='Leiden resolution.') 19 | parser.add_argument('--meta_field', dest='meta_field', type=str, default=None, help='Purpose of the clustering, name of folder.') 20 | parser.add_argument('--rep_key', dest='rep_key', type=str, default='z_latent', help='Key pattern for representations to grab: z_latent, h_latent.') 21 | parser.add_argument('--folds_pickle', dest='folds_pickle', type=str, default=None, help='Pickle file with folds information.') 22 | parser.add_argument('--main_path', dest='main_path', type=str, default=None, help='Workspace main path.') 23 | parser.add_argument('--h5_complete_path', dest='h5_complete_path', type=str, required=True, help='H5 file path to run the leiden clustering folds.') 24 | parser.add_argument('--h5_additional_path', dest='h5_additional_path', type=str, required=True, help='Additional H5 representation to assign leiden clusters.') 25 | args = parser.parse_args() 26 | resolution = args.resolution 27 | meta_field = args.meta_field 28 | rep_key = args.rep_key 29 | folds_pickle = args.folds_pickle 30 | main_path = args.main_path 31 | h5_complete_path = args.h5_complete_path 32 | h5_additional_path = args.h5_additional_path 33 | 34 | if main_path is None: 35 | main_path = os.path.dirname(os.path.realpath(__file__)) 36 | 37 | # Default resolutions. 38 | if resolution is None: 39 | resolutions = [0.4, 0.7, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0] 40 | # resolutions.extend([6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0]) 41 | else: 42 | resolutions = [resolution] 43 | 44 | # Run leiden clustering. 45 | assign_additional_only(meta_field, rep_key, h5_complete_path, h5_additional_path, folds_pickle, resolutions) -------------------------------------------------------------------------------- /run_representationsleiden_evalutation.py: -------------------------------------------------------------------------------- 1 | # Imports 2 | import argparse 3 | import os 4 | 5 | # Own libs. 6 | from models.clustering.evaluation_metrics import evaluate_cluster_configurations 7 | 8 | # Folder permissions for cluster. 9 | os.umask(0o002) 10 | # H5 File bug over network file system. 11 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' 12 | 13 | 14 | ##### Main ####### 15 | parser = argparse.ArgumentParser(description='Report classification and cluster performance based on Logistic Regression.') 16 | parser.add_argument('--meta_folder', dest='meta_folder', type=str, default=None, help='Purpose of the clustering, name of folder.') 17 | parser.add_argument('--folds_pickle', dest='folds_pickle', type=str, default=None, help='Pickle file with folds information.') 18 | parser.add_argument('--h5_complete_path', dest='h5_complete_path', type=str, required=True, help='H5 file path to run the leiden clustering folds.') 19 | parser.add_argument('--include_silnngraph', dest='include_silnngraph', action='store_true', default=False, help='Flag to specify if silhoutte based on NN graph is run (comp. costly).') 20 | args = parser.parse_args() 21 | meta_folder = args.meta_folder 22 | folds_pickle = args.folds_pickle 23 | h5_complete_path = args.h5_complete_path 24 | include_silnngraph = args.include_silnngraph 25 | 26 | resolutions = [0.25, 0.4, 0.5, 0.7, 0.75, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 6.0, 7.0, 7.0, 8.0, 9.0] 27 | # resolutions = [0.25, 0.4, 0.5, 0.7, 0.75, 1.0] 28 | 29 | evaluate_cluster_configurations(h5_complete_path, meta_folder, folds_pickle, resolutions, threshold_inst=0.01, include_nngraph=include_silnngraph) 30 | -------------------------------------------------------------------------------- /run_representationspathology.py: -------------------------------------------------------------------------------- 1 | # Imports. 2 | from data_manipulation.data import Data 3 | import tensorflow as tf 4 | import argparse 5 | import os 6 | 7 | # Folder permissions for cluster. 8 | os.umask(0o002) 9 | # H5 File bug over network file system. 10 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' 11 | 12 | parser = argparse.ArgumentParser(description='Self-Supervised model training.') 13 | parser.add_argument('--epochs', dest='epochs', type=int, default=90, help='Number epochs to run: default is 45 epochs.') 14 | parser.add_argument('--batch_size', dest='batch_size', type=int, default=64, help='Batch size, default size is 64.') 15 | parser.add_argument('--dataset', dest='dataset', type=str, default='vgh_nki', help='Dataset to use.') 16 | parser.add_argument('--marker', dest='marker', type=str, default='he', help='Marker of dataset to use, default is H&E.') 17 | parser.add_argument('--img_size', dest='img_size', type=int, default=224, help='Image size for the model.') 18 | parser.add_argument('--img_ch', dest='img_ch', type=int, default=3, help='Number of channels for the model.') 19 | parser.add_argument('--z_dim', dest='z_dim', type=int, default=128, help='Latent space size for contrastive loss.') 20 | parser.add_argument('--model', dest='model', type=str, default='ContrastivePathology', help='Model name, used to select the type of model (SimCLR, BYOL, SwAV).') 21 | parser.add_argument('--main_path', dest='main_path', type=str, default=None, help='Path for the output run.') 22 | parser.add_argument('--dbs_path', dest='dbs_path', type=str, default=None, help='Directory with DBs to use.') 23 | parser.add_argument('--check_every', dest='check_every', type=int, default=10, help='Save checkpoint and projects samples every X epcohs.') 24 | parser.add_argument('--restore', dest='restore', action='store_true', default=False, help='Restore previous run and continue.') 25 | parser.add_argument('--report', dest='report', action='store_true', default=False, help='Report Latent Space progress.') 26 | args = parser.parse_args() 27 | epochs = args.epochs 28 | batch_size = args.batch_size 29 | dataset = args.dataset 30 | marker = args.marker 31 | image_width = args.img_size 32 | image_height = args.img_size 33 | image_channels = args.img_ch 34 | z_dim = args.z_dim 35 | model = args.model 36 | main_path = args.main_path 37 | dbs_path = args.dbs_path 38 | check_every = args.check_every 39 | restore = args.restore 40 | report = args.report 41 | 42 | 43 | # Main paths for data output and databases. 44 | if main_path is None: 45 | main_path = os.path.dirname(os.path.realpath(__file__)) 46 | if dbs_path is None: 47 | dbs_path = os.path.dirname(os.path.realpath(__file__)) 48 | 49 | # Directory handling. 50 | name_run = 'h%s_w%s_n%s_zdim%s' % (image_height, image_width, image_channels, z_dim) 51 | data_out_path = os.path.join(main_path, 'data_model_output') 52 | data_out_path = os.path.join(data_out_path, model) 53 | data_out_path = os.path.join(data_out_path, dataset) 54 | data_out_path = os.path.join(data_out_path, name_run) 55 | 56 | # Hyperparameters for training. 57 | regularizer_scale = 1e-4 58 | learning_rate_e = 5e-4 59 | beta_1 = 0.5 60 | 61 | # Model Architecture param. 62 | layers_map = {512:7, 448:6, 256:6, 224:5, 128:5, 112:4, 56:3, 28:2} 63 | layers = layers_map[image_height] 64 | spectral = True 65 | attention = 56 66 | init = 'xavier' 67 | # init = 'orthogonal' 68 | 69 | # Testing with CRC and 56x56. 70 | if image_height == 56: attention = 28 71 | 72 | # Handling of different models. 73 | if 'BYOL' in model: 74 | z_dim = 256 75 | from models.selfsupervised.BYOL import RepresentationsPathology 76 | elif 'SimCLR' in model: 77 | from models.selfsupervised.SimCLR import RepresentationsPathology 78 | elif 'SwAV' in model: 79 | learning_rate_e = 1e-5 80 | from models.selfsupervised.SwAV import RepresentationsPathology 81 | elif 'SimSiam' in model: 82 | from models.selfsupervised.SimSiam import RepresentationsPathology 83 | elif 'Relational' in model: 84 | from models.selfsupervised.RealReas import RepresentationsPathology 85 | elif 'BarlowTwins' in model: 86 | from models.selfsupervised.BarlowTwins import RepresentationsPathology 87 | elif 'DINO' in model: 88 | from models.selfsupervised.DINO import RepresentationsPathology 89 | 90 | # Collect dataset. 91 | data = Data(dataset=dataset, marker=marker, patch_h=image_height, patch_w=image_width, n_channels=image_channels, batch_size=batch_size, project_path=dbs_path) 92 | 93 | # Run PathologyContrastive Encoder. 94 | with tf.Graph().as_default(): 95 | # Instantiate Model. 96 | contrast_pathology = RepresentationsPathology(data=data, z_dim=z_dim, layers=layers, beta_1=beta_1, init=init, regularizer_scale=regularizer_scale, spectral=spectral, attention=attention, 97 | learning_rate_e=learning_rate_e, model_name=model) 98 | # Train Model. 99 | losses = contrast_pathology.train(epochs, data_out_path, data, restore, print_epochs=10, n_images=25, checkpoint_every=check_every, report=report) 100 | -------------------------------------------------------------------------------- /run_representationspathology_projection.py: -------------------------------------------------------------------------------- 1 | # Imports. 2 | from models.evaluation.features import * 3 | from data_manipulation.data import Data 4 | import tensorflow as tf 5 | import argparse 6 | import os 7 | 8 | # Folder permissions for cluster. 9 | os.umask(0o002) 10 | # H5 File bug over network file system. 11 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Project images onto Self-Supervised model latent space.') 15 | parser.add_argument('--checkpoint', dest='checkpoint', required=True, help='Path to pre-trained weights (.ckt) of ContrastivePathology.') 16 | parser.add_argument('--real_hdf5', dest='real_hdf5', type=str, default=200, required=True, help='Path for real image to encode.') 17 | parser.add_argument('--img_size', dest='img_size', type=int, default=224, help='Image size for the model.') 18 | parser.add_argument('--img_ch', dest='img_ch', type=int, default=3, help='Number of channels for the model.') 19 | parser.add_argument('--z_dim', dest='z_dim', type=int, default=128, help='Latent space size, default is 128.') 20 | parser.add_argument('--dataset', dest='dataset', type=str, default='vgh_nki', help='Dataset to use.') 21 | parser.add_argument('--marker', dest='marker', type=str, default='he', help='Marker of dataset to use.') 22 | parser.add_argument('--batch_size', dest='batch_size', type=int, default=64, help='Batch size, default size is 64.') 23 | parser.add_argument('--model', dest='model', type=str, default='ContrastivePathology', help='Model name, used to select the type of model (SimCLR, BYOL, SwAV).') 24 | parser.add_argument('--main_path', dest='main_path', type=str, default=None, help='Path for the output run.') 25 | parser.add_argument('--dbs_path', dest='dbs_path', type=str, default=None, help='Directory with DBs to use.') 26 | parser.add_argument('--save_img', dest='save_img', action='store_true', default=False, help='Save reconstructed images in the H5 file.') 27 | args = parser.parse_args() 28 | checkpoint = args.checkpoint 29 | real_hdf5 = args.real_hdf5 30 | image_width = args.img_size 31 | image_height = args.img_size 32 | image_channels = args.img_ch 33 | z_dim = args.z_dim 34 | dataset = args.dataset 35 | marker = args.marker 36 | batch_size = args.batch_size 37 | model = args.model 38 | main_path = args.main_path 39 | dbs_path = args.dbs_path 40 | save_img = args.save_img 41 | 42 | # Main paths for data output and databases. 43 | if main_path is None: 44 | main_path = os.path.dirname(os.path.realpath(__file__)) 45 | if dbs_path is None: 46 | dbs_path = os.path.dirname(os.path.realpath(__file__)) 47 | 48 | # Directory handling. 49 | name_run = 'h%s_w%s_n%s_zdim%s' % (image_height, image_width, image_channels, z_dim) 50 | data_out_path = os.path.join(main_path, 'data_model_output') 51 | data_out_path = os.path.join(data_out_path, model) 52 | data_out_path = os.path.join(data_out_path, dataset) 53 | data_out_path = os.path.join(data_out_path, name_run) 54 | 55 | # Hyperparameters for training. 56 | regularizer_scale = 1e-4 57 | learning_rate_e = 5e-4 58 | beta_1 = 0.5 59 | 60 | # Model Architecture param. 61 | layers_map = {512:7, 448:6, 256:6, 224:5, 128:5, 112:4, 56:3, 28:2} 62 | layers = layers_map[image_height] 63 | spectral = True 64 | attention = 56 65 | init = 'xavier' 66 | # init = 'orthogonal' 67 | 68 | # Handling of different models. 69 | if 'BYOL' in model: 70 | z_dim = 256 71 | from models.selfsupervised.BYOL import RepresentationsPathology 72 | elif 'SimCLR' in model: 73 | from models.selfsupervised.SimCLR import RepresentationsPathology 74 | elif 'SwAV' in model: 75 | learning_rate_e = 1e-5 76 | from models.selfsupervised.SwAV import RepresentationsPathology 77 | elif 'SimSiam' in model: 78 | from models.selfsupervised.SimSiam import RepresentationsPathology 79 | elif 'Relational' in model: 80 | from models.selfsupervised.RealReas import RepresentationsPathology 81 | elif 'BarlowTwins' in model: 82 | from models.selfsupervised.BarlowTwins import RepresentationsPathology 83 | elif 'DINO' in model: 84 | from models.selfsupervised.DINO import RepresentationsPathology 85 | 86 | # Collect dataset. 87 | data = Data(dataset=dataset, marker=marker, patch_h=image_height, patch_w=image_width, n_channels=image_channels, batch_size=batch_size, project_path=dbs_path) 88 | 89 | # Run PathologyContrastive Encoder. 90 | with tf.Graph().as_default(): 91 | # Instantiate Model. 92 | contrast_pathology = RepresentationsPathology(data=data, z_dim=z_dim, layers=layers, beta_1=beta_1, init=init, regularizer_scale=regularizer_scale, spectral=spectral, attention=attention, learning_rate_e=learning_rate_e, model_name=model) 93 | 94 | # Run projections into H5. 95 | real_encode_contrastive_from_checkpoint(model=contrast_pathology, data=data, data_out_path=main_path, checkpoint=checkpoint, real_hdf5=real_hdf5, batches=batch_size, save_img=save_img) 96 | -------------------------------------------------------------------------------- /run_representationspathology_projection_dataset.py: -------------------------------------------------------------------------------- 1 | # Imports. 2 | from models.evaluation.features import * 3 | from data_manipulation.data import Data 4 | import tensorflow as tf 5 | import argparse 6 | import os 7 | 8 | # Folder permissions for cluster. 9 | os.umask(0o002) 10 | # H5 File bug over network file system. 11 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Project images onto Self-Supervised model latent space.') 15 | parser.add_argument('--checkpoint', dest='checkpoint', required=True, help='Path to pre-trained weights (.ckt) of ContrastivePathology.') 16 | parser.add_argument('--img_size', dest='img_size', type=int, default=224, help='Image size for the model.') 17 | parser.add_argument('--img_ch', dest='img_ch', type=int, default=3, help='Number of channels for the model.') 18 | parser.add_argument('--z_dim', dest='z_dim', type=int, default=128, help='Latent space size, default is 128.') 19 | parser.add_argument('--dataset', dest='dataset', type=str, default='vgh_nki', help='Dataset to use.') 20 | parser.add_argument('--marker', dest='marker', type=str, default='he', help='Marker of dataset to use.') 21 | parser.add_argument('--batch_size', dest='batch_size', type=int, default=64, help='Batch size, default size is 64.') 22 | parser.add_argument('--model', dest='model', type=str, default='ContrastivePathology', help='Model name, used to select the type of model (SimCLR, BYOL, SwAV).') 23 | parser.add_argument('--main_path', dest='main_path', type=str, default=None, help='Path for the output run.') 24 | parser.add_argument('--dbs_path', dest='dbs_path', type=str, default=None, help='Directory with DBs to use.') 25 | parser.add_argument('--save_img', dest='save_img', action='store_true', default=False, help='Save reconstructed images in the H5 file.') 26 | args = parser.parse_args() 27 | checkpoint = args.checkpoint 28 | image_width = args.img_size 29 | image_height = args.img_size 30 | image_channels = args.img_ch 31 | z_dim = args.z_dim 32 | dataset = args.dataset 33 | marker = args.marker 34 | batch_size = args.batch_size 35 | model = args.model 36 | main_path = args.main_path 37 | dbs_path = args.dbs_path 38 | save_img = args.save_img 39 | 40 | # Main paths for data output and databases. 41 | if main_path is None: 42 | main_path = os.path.dirname(os.path.realpath(__file__)) 43 | if dbs_path is None: 44 | dbs_path = os.path.dirname(os.path.realpath(__file__)) 45 | 46 | # Directory handling. 47 | name_run = 'h%s_w%s_n%s_zdim%s' % (image_height, image_width, image_channels, z_dim) 48 | data_out_path = os.path.join(main_path, 'data_model_output') 49 | data_out_path = os.path.join(data_out_path, model) 50 | data_out_path = os.path.join(data_out_path, dataset) 51 | data_out_path = os.path.join(data_out_path, name_run) 52 | 53 | # Hyperparameters for training. 54 | regularizer_scale = 1e-4 55 | learning_rate_e = 5e-4 56 | beta_1 = 0.5 57 | 58 | # Model Architecture param. 59 | layers_map = {512:7, 448:6, 256:6, 224:5, 128:5, 112:4, 56:3, 28:2} 60 | layers = layers_map[image_height] 61 | spectral = True 62 | attention = 56 63 | init = 'xavier' 64 | # init = 'orthogonal' 65 | 66 | # Handling of different models. 67 | if 'BYOL' in model: 68 | z_dim = 256 69 | from models.selfsupervised.BYOL import RepresentationsPathology 70 | elif 'SimCLR' in model: 71 | from models.selfsupervised.SimCLR import RepresentationsPathology 72 | elif 'SwAV' in model: 73 | learning_rate_e = 1e-5 74 | from models.selfsupervised.SwAV import RepresentationsPathology 75 | elif 'SimSiam' in model: 76 | from models.selfsupervised.SimSiam import RepresentationsPathology 77 | elif 'Relational' in model: 78 | from models.selfsupervised.RealReas import RepresentationsPathology 79 | elif 'BarlowTwins' in model: 80 | from models.selfsupervised.BarlowTwins import RepresentationsPathology 81 | elif 'DINO' in model: 82 | from models.selfsupervised.DINO import RepresentationsPathology 83 | 84 | # Collect dataset. 85 | data = Data(dataset=dataset, marker=marker, patch_h=image_height, patch_w=image_width, n_channels=image_channels, batch_size=batch_size, project_path=dbs_path) 86 | 87 | # Run PathologyContrastive Encoder. 88 | with tf.Graph().as_default(): 89 | # Instantiate Model. 90 | contrast_pathology = RepresentationsPathology(data=data, z_dim=z_dim, layers=layers, beta_1=beta_1, init=init, regularizer_scale=regularizer_scale, spectral=spectral, attention=attention, learning_rate_e=learning_rate_e, model_name=model) 91 | 92 | for real_hdf5 in [data.hdf5_train, data.hdf5_validation, data.hdf5_test]: 93 | # Run projections into H5. 94 | real_encode_contrastive_from_checkpoint(model=contrast_pathology, data=data, data_out_path=main_path, checkpoint=checkpoint, real_hdf5=real_hdf5, batches=batch_size, save_img=save_img) 95 | -------------------------------------------------------------------------------- /utilities/files/BLCA/overall_survival_TCGA_folds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/BLCA/overall_survival_TCGA_folds.jpg -------------------------------------------------------------------------------- /utilities/files/BLCA/overall_survival_TCGA_folds.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/BLCA/overall_survival_TCGA_folds.pkl -------------------------------------------------------------------------------- /utilities/files/BRCA/overall_survival_TCGA_folds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/BRCA/overall_survival_TCGA_folds.jpg -------------------------------------------------------------------------------- /utilities/files/BRCA/overall_survival_TCGA_folds.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/BRCA/overall_survival_TCGA_folds.pkl -------------------------------------------------------------------------------- /utilities/files/CESC/overall_survival_TCGA_folds.csv: -------------------------------------------------------------------------------- 1 | samples,os_event_ind,os_event_data 2 | TCGA-C5-A1M6,1,31.3972602739726 3 | TCGA-HM-A3JJ,1,21.665753424657535 4 | TCGA-VS-A9U6,0,43.3972602739726 5 | TCGA-HG-A9SC,0,16.56986301369863 6 | TCGA-ZX-AA5X,0,3.9123287671232876 7 | TCGA-VS-A94W,0,40.865753424657534 8 | TCGA-C5-A3HE,0,18.016438356164386 9 | TCGA-VS-A9UD,0,24.295890410958904 10 | TCGA-MA-AA3X,0,20.284931506849315 11 | TCGA-EA-A50E,1,7.463013698630137 12 | TCGA-C5-A901,0,17.03013698630137 13 | TCGA-FU-A3WB,0,16.142465753424656 14 | TCGA-Q1-A6DT,1,9.04109589041096 15 | TCGA-C5-A7CM,0,20.350684931506848 16 | TCGA-JX-A3Q0,0,209.58904109589042 17 | TCGA-EX-A1H6,0,7.923287671232876 18 | TCGA-EA-A3HU,0,33.30410958904109 19 | TCGA-C5-A1BI,0,36.55890410958904 20 | TCGA-JW-A5VH,1,3.287671232876712 21 | TCGA-DS-A7WI,1,8.284931506849315 22 | TCGA-C5-A2M1,0,38.43287671232876 23 | TCGA-EA-A3HR,0,30.9041095890411 24 | TCGA-LP-A4AX,0,12.493150684931507 25 | TCGA-C5-A905,0,160.40547945205478 26 | TCGA-VS-A9V0,0,18.838356164383562 27 | TCGA-EA-A5FO,0,26.794520547945204 28 | TCGA-EA-A3HS,0,31.528767123287672 29 | TCGA-Q1-A6DV,0,16.142465753424656 30 | TCGA-C5-A8YT,1,20.81095890410959 31 | TCGA-C5-A8YQ,1,23.506849315068493 32 | TCGA-DS-A3LQ,0,22.98082191780822 33 | TCGA-HM-A4S6,0,14.926027397260274 34 | TCGA-EA-A5ZD,0,27.287671232876715 35 | TCGA-C5-A7UE,0,155.76986301369863 36 | TCGA-DS-A0VN,0,118.65205479452055 37 | TCGA-DS-A0VK,1,36.75616438356165 38 | TCGA-EA-A44S,0,12.131506849315068 39 | TCGA-DS-A1O9,1,8.745205479452055 40 | TCGA-JX-A5QV,0,20.90958904109589 41 | TCGA-VS-A952,0,58.45479452054794 42 | TCGA-LP-A4AU,0,11.276712328767124 43 | TCGA-C5-A7XC,0,50.991780821917814 44 | TCGA-C5-A8ZZ,0,20.90958904109589 45 | TCGA-EX-A1H5,0,20.350684931506848 46 | TCGA-HM-A3JK,0,20.778082191780822 47 | TCGA-FU-A3HZ,0,36.26301369863013 48 | TCGA-DR-A0ZL,0,87.74794520547945 49 | TCGA-C5-A902,0,4.898630136986301 50 | TCGA-LP-A5U2,0,0.2958904109589041 51 | TCGA-EA-A6QX,0,24.0 52 | TCGA-Q1-A73Q,0,9.336986301369864 53 | TCGA-C5-A1MH,1,38.991780821917814 54 | TCGA-DS-A7WF,1,16.175342465753424 55 | TCGA-VS-A953,1,15.682191780821919 56 | TCGA-UC-A7PG,1,12.164383561643836 57 | TCGA-RA-A741,0,14.597260273972605 58 | TCGA-C5-A1BQ,1,19.857534246575344 59 | TCGA-EX-A8YF,0,15.550684931506847 60 | TCGA-C5-A8XJ,0,146.86027397260276 61 | TCGA-4J-AA1J,0,17.81917808219178 62 | TCGA-EA-A97N,0,0.36164383561643837 63 | TCGA-C5-A1M5,1,67.46301369863014 64 | TCGA-UC-A7PF,1,93.9945205479452 65 | TCGA-EA-A3QE,0,25.01917808219178 66 | TCGA-DG-A2KM,0,63.978082191780814 67 | TCGA-C5-A1BF,1,18.73972602739726 68 | TCGA-VS-A9UV,1,3.419178082191781 69 | TCGA-VS-A8EC,0,46.52054794520548 70 | TCGA-ZJ-A8QQ,0,67.59452054794521 71 | TCGA-JW-A5VG,0,27.419178082191785 72 | TCGA-C5-A7CG,0,210.67397260273972 73 | TCGA-Q1-A73P,0,15.87945205479452 74 | TCGA-BI-A20A,0,23.671232876712327 75 | TCGA-MU-A51Y,0,28.076712328767123 76 | TCGA-C5-A7CL,1,15.484931506849314 77 | TCGA-FU-A23L,0,23.835616438356162 78 | TCGA-C5-A1BJ,0,144.16438356164383 79 | TCGA-C5-A7CO,0,147.35342465753425 80 | TCGA-IR-A3LH,0,78.7068493150685 81 | TCGA-C5-A7CJ,1,101.8191780821918 82 | TCGA-VS-A9UZ,0,67.19999999999999 83 | TCGA-HG-A2PA,1,25.413698630136984 84 | TCGA-EX-A69M,1,8.317808219178083 85 | TCGA-UC-A7PI,0,69.50136986301371 86 | TCGA-C5-A7UI,1,94.94794520547946 87 | TCGA-MY-A5BF,0,20.843835616438355 88 | TCGA-C5-A1BK,0,177.04109589041096 89 | TCGA-EA-A4BA,0,24.821917808219176 90 | TCGA-JX-A3Q8,0,44.61369863013699 91 | TCGA-C5-A2LY,0,78.34520547945205 92 | TCGA-C5-A3HL,0,20.416438356164385 93 | TCGA-FU-A3HY,0,31.364383561643837 94 | TCGA-EA-A556,0,14.893150684931506 95 | TCGA-VS-A8EG,0,45.56712328767124 96 | TCGA-MU-A5YI,0,34.61917808219178 97 | TCGA-FU-A3TQ,0,26.13698630136986 98 | TCGA-C5-A1M7,0,46.323287671232876 99 | TCGA-IR-A3L7,0,147.38630136986302 100 | TCGA-BI-A0VR,0,49.47945205479452 101 | TCGA-C5-A2LZ,1,100.14246575342466 102 | TCGA-EA-A78R,0,13.479452054794521 103 | TCGA-UC-A7PD,1,11.67123287671233 104 | TCGA-VS-A94X,1,16.635616438356166 105 | TCGA-DS-A0VM,0,117.9945205479452 106 | TCGA-Q1-A5R3,0,15.945205479452055 107 | TCGA-VS-A8QA,0,36.131506849315066 108 | TCGA-EA-A439,0,31.726027397260275 109 | TCGA-FU-A57G,0,35.44109589041096 110 | TCGA-C5-A7X5,1,13.610958904109589 111 | TCGA-C5-A7CK,1,134.3342465753425 112 | TCGA-VS-A8Q8,1,32.153424657534245 113 | TCGA-DG-A2KL,0,44.942465753424656 114 | TCGA-DS-A1OC,0,12.361643835616437 115 | TCGA-MY-A913,0,17.227397260273975 116 | TCGA-IR-A3LF,0,96.95342465753424 117 | TCGA-JW-A5VK,0,20.482191780821918 118 | TCGA-EX-A69L,0,19.79178082191781 119 | TCGA-Q1-A73O,0,14.07123287671233 120 | TCGA-C5-A7UH,0,131.1123287671233 121 | TCGA-LP-A7HU,0,13.347945205479453 122 | TCGA-DG-A2KK,0,82.06027397260274 123 | TCGA-C5-A2LX,0,83.04657534246576 124 | TCGA-Q1-A73R,0,18.64109589041096 125 | TCGA-EA-A5ZF,0,27.221917808219178 126 | TCGA-C5-A7X8,0,2.728767123287671 127 | TCGA-ZJ-AAXA,0,1.4136986301369863 128 | TCGA-C5-A1MQ,0,33.8958904109589 129 | TCGA-MA-AA43,0,11.375342465753425 130 | TCGA-Q1-A5R1,0,15.583561643835615 131 | TCGA-C5-A7UC,1,17.194520547945206 132 | TCGA-DS-A1OD,0,127.36438356164382 133 | TCGA-Q1-A73S,0,22.61917808219178 134 | TCGA-C5-A7X3,1,9.336986301369864 135 | TCGA-MY-A5BE,0,35.04657534246576 136 | TCGA-C5-A1MK,1,2.432876712328767 137 | TCGA-VS-A9UQ,0,41.52328767123288 138 | TCGA-DG-A2KJ,0,95.11232876712329 139 | TCGA-ZJ-A8QR,1,19.134246575342466 140 | TCGA-VS-A9UL,1,14.531506849315068 141 | TCGA-ZJ-AAXU,0,0.1643835616438356 142 | TCGA-VS-A94Y,1,4.734246575342466 143 | TCGA-MA-AA3Y,0,17.81917808219178 144 | TCGA-C5-A1ME,0,57.73150684931507 145 | TCGA-LP-A5U3,0,0.821917808219178 146 | TCGA-C5-A1MN,1,40.93150684931507 147 | TCGA-MY-A5BD,0,54.8054794520548 148 | TCGA-JW-A69B,0,28.372602739726027 149 | TCGA-C5-A1MJ,1,0.4602739726027397 150 | TCGA-MA-AA3Z,0,19.56164383561644 151 | TCGA-C5-A3HD,0,52.010958904109586 152 | TCGA-C5-A1M9,1,35.013698630136986 153 | TCGA-C5-A1BM,1,82.84931506849315 154 | TCGA-IR-A3LA,0,137.16164383561645 155 | TCGA-C5-A0TN,1,11.441095890410958 156 | TCGA-FU-A3EO,0,16.10958904109589 157 | TCGA-LP-A4AW,0,0.8876712328767123 158 | TCGA-VS-A8EL,0,65.4904109589041 159 | TCGA-FU-A2QG,0,19.035616438356165 160 | TCGA-MA-AA41,0,9.172602739726027 161 | TCGA-VS-A9U7,0,48.394520547945206 162 | TCGA-EA-A5O9,0,25.90684931506849 163 | TCGA-EA-A3HQ,0,37.347945205479455 164 | TCGA-2W-A8YY,0,17.52328767123288 165 | TCGA-VS-A9V4,1,4.33972602739726 166 | TCGA-IR-A3LI,0,81.96164383561644 167 | TCGA-HM-A6W2,0,9.435616438356163 168 | TCGA-C5-A2M2,1,33.23835616438356 169 | TCGA-C5-A1MP,0,3.5835616438356164 170 | TCGA-C5-A1MF,0,53.16164383561643 171 | TCGA-VS-A950,0,40.14246575342466 172 | TCGA-FU-A3YQ,0,28.306849315068497 173 | TCGA-VS-A9UH,0,46.915068493150685 174 | TCGA-EX-A449,0,14.695890410958903 175 | TCGA-DG-A2KH,0,1.1178082191780823 176 | TCGA-IR-A3LK,1,29.852054794520548 177 | TCGA-C5-A2LV,0,73.44657534246575 178 | TCGA-DS-A7WH,0,17.52328767123288 179 | TCGA-C5-A1BE,1,68.84383561643835 180 | TCGA-VS-A9UB,0,29.95068493150685 181 | TCGA-C5-A1M8,0,30.21369863013699 182 | TCGA-EA-A410,0,26.400000000000002 183 | TCGA-EA-A3Y4,0,36.88767123287671 184 | TCGA-WL-A834,0,26.005479452054793 185 | TCGA-JW-AAVH,0,18.147945205479452 186 | TCGA-MA-AA42,0,8.515068493150684 187 | TCGA-VS-A954,0,56.35068493150685 188 | TCGA-XS-A8TJ,0,29.26027397260274 189 | TCGA-EA-A43B,0,26.005479452054793 190 | TCGA-GH-A9DA,0,17.753424657534246 191 | TCGA-EX-A3L1,0,15.221917808219178 192 | TCGA-C5-A2LS,0,44.21917808219178 193 | TCGA-FU-A40J,0,14.005479452054796 194 | TCGA-C5-A3HF,1,17.852054794520548 195 | TCGA-FU-A3TX,0,1.4794520547945205 196 | TCGA-EA-A411,0,24.55890410958904 197 | TCGA-FU-A23K,0,12.230136986301371 198 | TCGA-VS-A957,0,55.4958904109589 199 | TCGA-EA-A1QS,0,39.55068493150685 200 | TCGA-EA-A3HT,0,31.364383561643837 201 | TCGA-EA-A5ZE,0,27.254794520547946 202 | TCGA-IR-A3LC,0,129.36986301369862 203 | TCGA-VS-A958,0,50.136986301369866 204 | TCGA-IR-A3LL,0,36.36164383561644 205 | TCGA-VS-A9U5,0,50.465753424657535 206 | TCGA-C5-A8XI,0,8.35068493150685 207 | TCGA-DR-A0ZM,0,58.88219178082192 208 | TCGA-MU-A8JM,0,19.956164383561642 209 | TCGA-C5-A2LT,0,73.18356164383562 210 | TCGA-MA-AA3W,0,22.52054794520548 211 | TCGA-Q1-A6DW,0,17.556164383561644 212 | TCGA-EA-A3QD,0,13.052054794520549 213 | TCGA-FU-A770,0,1.1178082191780823 214 | TCGA-C5-A907,0,14.728767123287671 215 | TCGA-IR-A3LB,1,66.8054794520548 216 | TCGA-EA-A1QT,0,40.865753424657534 217 | TCGA-C5-A1BN,1,5.457534246575342 218 | TCGA-JX-A3PZ,1,21.106849315068494 219 | TCGA-FU-A3NI,1,20.975342465753425 220 | TCGA-C5-A7CH,0,154.32328767123286 221 | -------------------------------------------------------------------------------- /utilities/files/CESC/overall_survival_TCGA_folds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/CESC/overall_survival_TCGA_folds.jpg -------------------------------------------------------------------------------- /utilities/files/CESC/overall_survival_TCGA_folds.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/CESC/overall_survival_TCGA_folds.pkl -------------------------------------------------------------------------------- /utilities/files/COAD/overall_survival_TCGA_folds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/COAD/overall_survival_TCGA_folds.jpg -------------------------------------------------------------------------------- /utilities/files/COAD/overall_survival_TCGA_folds.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/COAD/overall_survival_TCGA_folds.pkl -------------------------------------------------------------------------------- /utilities/files/LUAD/HPC_annotations/LUAD_HPC_annotations.csv: -------------------------------------------------------------------------------- 1 | HPC,Epithelium Stroma Ratio Detailed,Tissue Morphologies,Lymphocytic Infiltration,Epithelium Stroma Ratio,Hot/Warm/Cold Lymphocytic Infiltration,Summary,Exemplar Tiles,Pathologist Agreement 2 | 0,more stroma,Tumor with Stroma,Severe,More Stroma ,Hot,Inflamed irregular acini,"1,1",2 3 | 1,elastosis or collagenosis,Tumor with Stroma,Severe,More Stroma ,Hot,"Inflamed compact stroma, sparse tumor","4,2",1 4 | 2,near-normal lung,Lung Parenchyma,Mild,Non-Tumor,Cold,Compressed normal lung and smaller airways/vessels,"9,1",2 5 | 3,elastosis or collagenosis,Stroma,Mild,Non-Tumor,Cold,Coarse fibrillar stroma,"3,2",3 6 | 4,near-normal lung,Lung Parenchyma,Mild,Non-Tumor,Cold,Open normal lung,"9,1",3 7 | 5,more epithelium,Tumor with Stroma,Moderate,More Epithelium ,Hot,Solid with stromal TILs,"6,1",3 8 | 6,more epithelium,Tumor with Stroma,Mild,More Epithelium ,Warm,Solid and sieve-like cribriform,"19,2",2 9 | 7,more stroma,Tumor with Stroma,Mild,More Stroma ,Warm,Stroma-rich solid,"5,2",2 10 | 8,more epithelium,Tumor with Stroma,Mild,More Epithelium ,Hot,Angulated columnar acini,"10,2",3 11 | 9,reactive lung changes,Tumor with Stroma,Moderate,More Stroma ,Hot,"Diverse, inflamed, sparse tumor","14,3",2 12 | 10,near-normal lung,Lung Parenchyma,Mild,Non-Tumor,Cold,Haemorrhagic lung,"12,3",1 13 | 11,more epithelium,Tumor with Stroma,Moderate,More Epithelium ,Warm,Discohesive solid and compressed lumina,"10,1",2 14 | 12,reactive lung changes,Tumor with Stroma,Mild,More Stroma ,Hot,Diverse metaplastic and acinar,"17,2",1 15 | 13,near-normal lung,Lung Parenchyma,Moderate,Non-Tumor,Warm,Open lung with Mild interstitial thickening,"3,2",2 16 | 14,necrosis,Stroma,N/A,Non-Tumor,Warm,Necrosis,"0,2",2 17 | 15,more epithelium,Tumor with Stroma,Mild,More Epithelium ,Warm,"Solid, cold","5,0",3 18 | 16,elastosis or collagenosis,Stroma,Moderate,Non-Tumor,Hot,Inflamed elastosis/collapse,"4,4",3 19 | 17,reactive lung changes,Tumor with Stroma,Moderate,Equal,Hot,Surfaces and margins,"12,0",2 20 | 18,roughly equal,Tumor with Stroma,Mild,More Epithelium ,Warm,Complex acinar/cribriform nests,"11,3",3 21 | 19,more epithelium,Tumor with Stroma,Very sparse,More Epithelium ,Warm,"Mucinous, lepidic","5,2",2 22 | 20,other connective tissue,Mesenchymal Tissues,Very sparse,Non-Tumor,Warm,Peribronchial fat and glands,"8,0",2 23 | 21,more epithelium,Tumor with Stroma,Mild,More Epithelium ,Warm,Compressed luminal patterns,"0,0",2 24 | 22,vessels,Mesenchymal Tissues,Mild,Non-Tumor,Cold,Large vessel lumina,"5,3",3 25 | 23,elastosis or collagenosis,Mesenchymal Tissues,Mild,Non-Tumor,Cold,Large vessel walls,"17,4",3 26 | 24,more epithelium,Tumor with Stroma,Very sparse,More Epithelium ,Cold,Mucin lakes,"3,1",2 27 | 25,more epithelium,Tumor with Stroma,Severe,More Epithelium ,Hot,"Solid, inflamed","9,2",2 28 | 26,more stroma,Tumor with Stroma,Mild,More Stroma ,Cold,Small angulated acini in dense cold stroma,"5,1",2 29 | 27,more stroma,Tumor with Stroma,Mild,Equal,Warm,Small discohesive nests,"7,0",1 30 | 28,roughly equal,Tumor with Stroma,Mild,More Epithelium ,Warm,"Lepidic, papillary","4,0",2 31 | 29,more epithelium,Tumor with Stroma,Moderate,More Epithelium ,Hot,Solid pattern overrun by TILs,"1,3",3 32 | 30,more epithelium,Tumor with Stroma,Mild,More Epithelium ,Warm,Dispersed micropapillae,"13,2",3 33 | 31,vessels,Mesenchymal Tissues,Moderate,Non-Tumor,Warm,Elastotic vessels,"1,3",2 34 | 32,near-normal lung,Lung Parenchyma,Mild,Non-Tumor,Cold,Mildly fibrotic lung,"6,0",2 35 | 33,more epithelium,Tumor with Stroma,Moderate,More Epithelium ,Hot,"Small nests, retraction artefact","13,3",2 36 | 34,elastosis or collagenosis,Stroma,Moderate,Non-Tumor,Hot,Dense inflamed collagen,"6,2",2 37 | 35,more epithelium,Tumor with Stroma,Mild,More Epithelium ,Warm,Papillae and micropapillae,"2,4",2 38 | 36,reactive lung changes,Lung Parenchyma,Moderate,Non-Tumor,Hot,Inflammatory interstitial expansion,"0,7",2 39 | 37,more stroma,Tumor with Stroma,Mild,Equal,Warm,"Lepidic, acinar/papillary","2,4",2 40 | 38,more epithelium,Tumor with Stroma,Moderate,More Epithelium ,Hot,"Mostly solid, clear-cell change","1,2",3 41 | 39,more epithelium,Tumor with Stroma,Moderate,More Epithelium ,Warm,Crowded discohesive nests,"4,0",2 42 | 40,airway,Mesenchymal Tissues,Mild,Non-Tumor,Hot,Larger bronchi,"0,0",3 43 | 41,roughly equal,Tumor with Stroma,Mild,Equal,Warm,Dense stroma with small nests,"3,0",2 44 | 42,cartilage,Mesenchymal Tissues,Very sparse,Non-Tumor,Cold,Cartilage,"0,0",3 45 | 43,other connective tissue,Lymphocytes,Severe,Non-Tumor,Hot,Confluent lymphocytes,"2,1",1 46 | 44,other connective tissue,Fold Artefact,Moderate,More Stroma ,Hot,Fold Artefact,"4,0",2 47 | 45,more stroma,Tumor with Stroma,Mild,More Epithelium ,Warm,Classical cribriform,"0,3",3 -------------------------------------------------------------------------------- /utilities/files/LUAD/overall_survival_TCGA_folds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/LUAD/overall_survival_TCGA_folds.jpg -------------------------------------------------------------------------------- /utilities/files/LUAD/overall_survival_TCGA_folds.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/LUAD/overall_survival_TCGA_folds.pkl -------------------------------------------------------------------------------- /utilities/files/LUAD/overall_survival_TCGA_folds_SOTA.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/LUAD/overall_survival_TCGA_folds_SOTA.pkl -------------------------------------------------------------------------------- /utilities/files/LUADLUSC/lungsubtype_Institutions.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/LUADLUSC/lungsubtype_Institutions.pkl -------------------------------------------------------------------------------- /utilities/files/LUSC/overall_survival_TCGA_folds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/LUSC/overall_survival_TCGA_folds.jpg -------------------------------------------------------------------------------- /utilities/files/LUSC/overall_survival_TCGA_folds.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/LUSC/overall_survival_TCGA_folds.pkl -------------------------------------------------------------------------------- /utilities/files/Multi-Cancer/tcga_v07_10panCancer.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/Multi-Cancer/tcga_v07_10panCancer.pkl -------------------------------------------------------------------------------- /utilities/files/PRAD/overall_survival_TCGA_folds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/PRAD/overall_survival_TCGA_folds.jpg -------------------------------------------------------------------------------- /utilities/files/PRAD/overall_survival_TCGA_folds.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/PRAD/overall_survival_TCGA_folds.pkl -------------------------------------------------------------------------------- /utilities/files/SKCM/overall_survival_TCGA_folds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/SKCM/overall_survival_TCGA_folds.jpg -------------------------------------------------------------------------------- /utilities/files/SKCM/overall_survival_TCGA_folds.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/SKCM/overall_survival_TCGA_folds.pkl -------------------------------------------------------------------------------- /utilities/files/STAD/overall_survival_TCGA_folds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/STAD/overall_survival_TCGA_folds.jpg -------------------------------------------------------------------------------- /utilities/files/STAD/overall_survival_TCGA_folds.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/STAD/overall_survival_TCGA_folds.pkl -------------------------------------------------------------------------------- /utilities/files/UCEC/overall_survival_TCGA_folds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/UCEC/overall_survival_TCGA_folds.jpg -------------------------------------------------------------------------------- /utilities/files/UCEC/overall_survival_TCGA_folds.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/UCEC/overall_survival_TCGA_folds.pkl -------------------------------------------------------------------------------- /utilities/files/indexes_to_remove/TCGAFFPE_LUADLUSC_5x_60pc/complete.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/indexes_to_remove/TCGAFFPE_LUADLUSC_5x_60pc/complete.pkl -------------------------------------------------------------------------------- /utilities/files/indexes_to_remove/TCGAFFPE_LUADLUSC_5x_60pc/test.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/indexes_to_remove/TCGAFFPE_LUADLUSC_5x_60pc/test.pkl -------------------------------------------------------------------------------- /utilities/files/indexes_to_remove/TCGAFFPE_LUADLUSC_5x_60pc/train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/indexes_to_remove/TCGAFFPE_LUADLUSC_5x_60pc/train.pkl -------------------------------------------------------------------------------- /utilities/files/indexes_to_remove/TCGAFFPE_LUADLUSC_5x_60pc/valid.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdalbertoCq/Histomorphological-Phenotype-Learning/c2372341dddf2f78ecbd9f4e81e378c24f17c5d3/utilities/files/indexes_to_remove/TCGAFFPE_LUADLUSC_5x_60pc/valid.pkl -------------------------------------------------------------------------------- /utilities/fold_creation/class_folds.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "outputs": [], 7 | "source": [ 8 | "import pandas as pd\n", 9 | "import numpy as np\n", 10 | "import pickle\n", 11 | "import copy\n", 12 | "import math\n", 13 | "import os" 14 | ], 15 | "metadata": { 16 | "collapsed": false, 17 | "pycharm": { 18 | "name": "#%%\n" 19 | } 20 | } 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "outputs": [], 26 | "source": [ 27 | "def store_data(data, file_path):\n", 28 | " with open(file_path, 'wb') as file:\n", 29 | " pickle.dump(data, file)\n", 30 | "\n", 31 | "def get_frac_split(meta_df, matching_field, ind_column, num_folds=5):\n", 32 | " # Copy dataframe.\n", 33 | " df = meta_df.copy(deep=True)\n", 34 | "\n", 35 | " # Get unique classes.\n", 36 | " unique_classes = np.unique(meta_df[ind_column])\n", 37 | " # randomize rows\n", 38 | " df = df.sample(frac=1).reset_index(drop=True)\n", 39 | "\n", 40 | " folds = dict()\n", 41 | " for i in range(num_folds):\n", 42 | " folds[i] = dict()\n", 43 | " folds[i]['train'] = list()\n", 44 | " folds[i]['test'] = list()\n", 45 | "\n", 46 | " for class_ in unique_classes:\n", 47 | " # Get slides for class.\n", 48 | " slides = np.unique(df[df[ind_column]==class_][matching_field].values)\n", 49 | "\n", 50 | " # Test size.\n", 51 | " num_samples = len(slides)\n", 52 | " test_size = math.floor(num_samples*(1/num_folds))\n", 53 | "\n", 54 | " # Iterate through chunks and add samples to fold.\n", 55 | " for i in range(num_folds):\n", 56 | " test_sample = slides[i*test_size:(i+1)*test_size]\n", 57 | " train_sample = list(set(slides).difference(set(test_sample)))\n", 58 | " folds[i]['train'].extend(train_sample)\n", 59 | " folds[i]['test'].extend(test_sample)\n", 60 | "\n", 61 | " return folds\n", 62 | "\n", 63 | "def get_folds(meta_df, matching_field, ind_column, num_folds=5, valid_set=False):\n", 64 | "\n", 65 | " # Get initial split for train/test.\n", 66 | " folds = get_frac_split(meta_df, matching_field, ind_column, num_folds=num_folds)\n", 67 | "\n", 68 | " for i in range(num_folds):\n", 69 | " whole_train_samples = folds[i]['train']\n", 70 | " subset_df = meta_df[meta_df[matching_field].isin(whole_train_samples)]\n", 71 | " train_val_folds = get_frac_split(subset_df, matching_field, ind_column, num_folds=num_folds)\n", 72 | " del folds[i]['train']\n", 73 | " folds[i]['train'] = train_val_folds[0]['train']\n", 74 | " folds[i]['valid'] = train_val_folds[0]['test']\n", 75 | "\n", 76 | " return folds\n", 77 | "\n", 78 | "# Verify: This should all be empty.\n", 79 | "def sanity_check_overlap(folds, num_folds):\n", 80 | " # For each fold, no overlap between cells.\n", 81 | " for i in range(num_folds):\n", 82 | " result = set(folds[i]['train']).intersection(set(folds[i]['valid']))\n", 83 | " if len(result) > 0:\n", 84 | " print(result)\n", 85 | "\n", 86 | " result = set(folds[i]['train']).intersection(set(folds[i]['test']))\n", 87 | " if len(result) > 0:\n", 88 | " print(result)\n", 89 | "\n", 90 | " result = set(folds[i]['valid']).intersection(set(folds[i]['test']))\n", 91 | " if len(result) > 0:\n", 92 | " print(result)\n", 93 | "\n", 94 | " # No overlap between test sets of all folds.\n", 95 | " for i in range(num_folds):\n", 96 | " for j in range(num_folds):\n", 97 | " if i==j: continue\n", 98 | " result = set(folds[i]['test']).intersection(set(folds[j]['test']))\n", 99 | " if len(result) > 0:\n", 100 | " print('Fold %s-%s' % (i,j), result)\n", 101 | "\n", 102 | "# Fit for legacy code.\n", 103 | "def fit_format(folds):\n", 104 | " slides_folds = dict()\n", 105 | " for i, fold in enumerate(folds):\n", 106 | " slides_folds[i] = dict()\n", 107 | " slides_folds[i]['train'] = [(slide, None, None) for slide in folds[i]['train']]\n", 108 | " slides_folds[i]['valid'] = [(slide, None, None) for slide in folds[i]['valid']]\n", 109 | " slides_folds[i]['test'] = [(slide, None, None) for slide in folds[i]['test']]\n", 110 | "\n", 111 | " return slides_folds\n" 112 | ], 113 | "metadata": { 114 | "collapsed": false, 115 | "pycharm": { 116 | "name": "#%%\n" 117 | } 118 | } 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "outputs": [], 124 | "source": [ 125 | "\n", 126 | "meta_csv = './tcga_panCancer.csv'\n", 127 | "pickle_path = './tcga_panCancer.pkl'\n", 128 | "\n", 129 | "\n", 130 | "\n", 131 | "# Read meta data file, rename column.\n", 132 | "meta_df = pd.read_csv(meta_csv)\n", 133 | "cancer_types = meta_df['type'].values\n", 134 | "del meta_df['type']\n", 135 | "meta_df['cancer_types'] = cancer_types\n", 136 | "\n", 137 | "# Create mapping for cancer types and integers.\n", 138 | "mapping_cancers = dict(zip(np.unique(cancer_types), range(len(np.unique(cancer_types)))))\n", 139 | "\n", 140 | "# Map new columns for integer indicator.\n", 141 | "meta_df['cancer_types_ind'] = meta_df['cancer_types'].astype(str).map(mapping_cancers)" 142 | ], 143 | "metadata": { 144 | "collapsed": false, 145 | "pycharm": { 146 | "name": "#%%\n", 147 | "is_executing": true 148 | } 149 | } 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 7, 154 | "outputs": [], 155 | "source": [ 156 | "\n", 157 | "folds = get_folds(meta_df, matching_field='slides', ind_column='cancer_types_ind', num_folds=5, valid_set=True)\n", 158 | "final_folds = fit_format(folds)\n", 159 | "\n", 160 | "# If no output, all good.\n", 161 | "sanity_check_overlap(folds, num_folds=5)\n", 162 | "\n", 163 | "store_data(final_folds, pickle_path)" 164 | ], 165 | "metadata": { 166 | "collapsed": false, 167 | "pycharm": { 168 | "name": "#%%\n" 169 | } 170 | } 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "outputs": [], 176 | "source": [], 177 | "metadata": { 178 | "collapsed": false, 179 | "pycharm": { 180 | "name": "#%%\n" 181 | } 182 | } 183 | } 184 | ], 185 | "metadata": { 186 | "kernelspec": { 187 | "display_name": "Python 3", 188 | "language": "python", 189 | "name": "python3" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 2 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | "pygments_lexer": "ipython2", 201 | "version": "2.7.6" 202 | } 203 | }, 204 | "nbformat": 4, 205 | "nbformat_minor": 0 206 | } -------------------------------------------------------------------------------- /utilities/h5_handling/combine_complete_h5.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import h5py 4 | import sys 5 | import os 6 | 7 | # Add project path 8 | main_path = os.path.dirname(os.path.realpath(__file__)) 9 | main_path = '/'.join(main_path.split('/')[:-2]) 10 | sys.path.append(main_path) 11 | 12 | # Folder permissions for cluster. 13 | os.umask(0o002) 14 | # H5 File bug over network file system. 15 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' 16 | 17 | 18 | ##### Methods ####### 19 | # Get set paths. 20 | def representations_h5(main_path, model, dataset, img_size, zdim): 21 | hdf5_path_train = '%s/results/%s/%s/h%s_w%s_n3_zdim%s/hdf5_%s_he_train.h5' % (main_path, model, dataset, img_size, img_size, zdim, dataset) 22 | hdf5_path_valid = '%s/results/%s/%s/h%s_w%s_n3_zdim%s/hdf5_%s_he_validation.h5' % (main_path, model, dataset, img_size, img_size, zdim, dataset) 23 | hdf5_path_test = '%s/results/%s/%s/h%s_w%s_n3_zdim%s/hdf5_%s_he_test.h5' % (main_path, model, dataset, img_size, img_size, zdim, dataset) 24 | return [hdf5_path_train, hdf5_path_valid, hdf5_path_test] 25 | 26 | # Get total number of tiles for all sets. 27 | def get_total_samples(data): 28 | total_samples = 0 29 | for h5_path in [data[0], data[1], data[2]]: 30 | if not os.path.isfile(h5_path): 31 | print('Warning: H5 file not found', h5_path) 32 | continue 33 | with h5py.File(h5_path, 'r') as content: 34 | key_1 = list(content.keys())[0] 35 | total_samples += content[key_1].shape[0] 36 | return total_samples 37 | 38 | # Get key_names, shape, and dtype 39 | def data_specs(data): 40 | key_dict = dict() 41 | with h5py.File(data[0], 'r') as content: 42 | for key in content.keys(): 43 | key_dict[key] = dict() 44 | key_dict[key]['shape'] = content[key].shape[1:] 45 | key_dict[key]['dtype'] = content[key].dtype 46 | 47 | return key_dict 48 | 49 | 50 | def create_complete_h5(data, num_tiles, key_dict, override): 51 | h5_complete_path = data[0].replace('_train.h5', '_complete.h5') 52 | if override: 53 | os.remove(h5_complete_path) 54 | if os.path.isfile(h5_complete_path): 55 | print('File already exists, if you want to overwrite enable the flag --override') 56 | print(h5_complete_path) 57 | print() 58 | exit() 59 | 60 | storage_dict = dict() 61 | content = h5py.File(h5_complete_path, mode='w') 62 | for key in key_dict: 63 | shape = [num_tiles] + list(key_dict[key]['shape']) 64 | dtype = key_dict[key]['dtype'] 65 | storage_dict[key] = content.create_dataset(name=key.replace('train_', ''), shape=shape, dtype=dtype) 66 | 67 | storage_ref_dict = dict() 68 | dt = h5py.special_dtype(vlen=str) 69 | storage_ref_dict['indexes'] = content.create_dataset(name='indexes', shape=[shape[0]], dtype=np.int32) 70 | storage_ref_dict['original_set'] = content.create_dataset(name='original_set', shape=[shape[0]], dtype=dt) 71 | 72 | index = 0 73 | for set_path, set_name in [(data[0], 'train'), (data[1], 'valid'), (data[2], 'test')]: 74 | print('Iterating through %s ...' % set_name) 75 | 76 | with h5py.File(set_path, 'r') as content: 77 | set_dict = dict() 78 | for key in storage_dict: 79 | set_dict[key] = content[key.replace('train_', '%s_'%set_name)] 80 | 81 | for i in range(set_dict[key].shape[0]): 82 | # Original data. 83 | for key in storage_dict: 84 | storage_dict[key][index] = set_dict[key][i] 85 | # Back-referencing to sets. 86 | storage_ref_dict['indexes'][index] = i 87 | storage_ref_dict['original_set'][index] = set_name 88 | 89 | # Verbose. 90 | if i%1e+5==0: 91 | print('\tprocessed %s entries' % i) 92 | index += 1 93 | print() 94 | 95 | ##### Main ####### 96 | parser = argparse.ArgumentParser(description='Script to combine all H5 representation file into a \'complete\' one.') 97 | parser.add_argument('--img_size', dest='img_size', type=int, default=224, help='Image size for the model.') 98 | parser.add_argument('--z_dim', dest='z_dim', type=int, default=128, help='Dimensionality of projections, default is the Z latent of Self-Supervised.') 99 | parser.add_argument('--dataset', dest='dataset', type=str, default='vgh_nki', help='Dataset to use.') 100 | parser.add_argument('--model', dest='model', type=str, default='ContrastivePathology', help='Model name, used to select the type of model (SimCLR, BYOL, SwAV).') 101 | parser.add_argument('--main_path', dest='main_path', type=str, default=None, help='Path for the output run.') 102 | parser.add_argument('--override', dest='override', action='store_true', default=False, help='Override \'complete\' H5 file if it already exists.') 103 | args = parser.parse_args() 104 | img_size = args.img_size 105 | z_dim = args.z_dim 106 | dataset = args.dataset 107 | model = args.model 108 | main_path = args.main_path 109 | override = args.override 110 | 111 | if main_path is None: 112 | main_path = os.path.dirname(os.path.realpath(__file__)) 113 | main_path = '/'.join(main_path.split('/')[:-2]) 114 | 115 | # Get representations paths. 116 | data = representations_h5(main_path, model, dataset, img_size, z_dim) 117 | 118 | # Get total number fo samples. 119 | num_tiles = get_total_samples(data) 120 | 121 | # Dictionary with keys, shapes, and dtypes. 122 | key_dict = data_specs(data) 123 | 124 | # Combine all H5 into a 'complete' one. 125 | create_complete_h5(data, num_tiles, key_dict, override) 126 | -------------------------------------------------------------------------------- /utilities/h5_handling/create_metadata_h5.py: -------------------------------------------------------------------------------- 1 | # Imports. 2 | import pandas as pd 3 | import numpy as np 4 | import argparse 5 | import h5py 6 | import sys 7 | import os 8 | 9 | # Add project path 10 | main_path = os.path.dirname(os.path.realpath(__file__)) 11 | main_path = '/'.join(main_path.split('/')[:-2]) 12 | sys.path.append(main_path) 13 | 14 | # Folder permissions for cluster. 15 | os.umask(0o002) 16 | # H5 File bug over network file system. 17 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' 18 | 19 | ##### Methods ####### 20 | # Read metada data and list of individuals. 21 | def read_meta_data(file_path, matching_field): 22 | frame = pd.read_csv(file_path) 23 | meta_individuals = list(sorted(frame[matching_field].values.astype(str))) 24 | return frame, meta_individuals 25 | 26 | 27 | # Get number of samples that overlap with individuals in the meta file. 28 | def h5_overlap_meta_individuals(h5_file, matching_field, meta_individuals): 29 | h5_samples = 0 30 | with h5py.File(h5_file, 'r') as content: 31 | h5_individual_prev = '' 32 | match_flag = False 33 | for index_h5 in range(content[matching_field].shape[0]): 34 | h5_individual = content[matching_field][index_h5].decode("utf-8") 35 | if h5_individual != h5_individual_prev: 36 | if h5_individual in meta_individuals: 37 | match_flag = True 38 | else: 39 | match_flag = False 40 | h5_individual_prev = h5_individual 41 | if match_flag: 42 | h5_samples +=1 43 | 44 | return h5_samples 45 | 46 | # Get key_names, shape, and dtype 47 | def data_specs(h5_path): 48 | key_dict = dict() 49 | with h5py.File(h5_path, 'r') as content: 50 | for key in content.keys(): 51 | key_dict[key] = dict() 52 | key_dict[key]['shape'] = content[key].shape[1:] 53 | key_dict[key]['dtype'] = content[key].dtype 54 | 55 | return key_dict 56 | 57 | # Create new H5 file with individuals in the meta file. 58 | def create_metadata_h5(h5_file, meta_name, list_meta_field, matching_field, meta_individuals, num_tiles, key_dict, override): 59 | h5_metadata_path = h5_file.replace('.h5', '_%s.h5' % meta_name) 60 | 61 | if os.path.isfile(h5_metadata_path): 62 | if override: 63 | os.remove(h5_metadata_path) 64 | else: 65 | print('File already exists, if you want to overwrite enable the flag --override') 66 | print(h5_metadata_path) 67 | print() 68 | exit() 69 | 70 | storage_dict = dict() 71 | content = h5py.File(h5_metadata_path, mode='w') 72 | for key in key_dict: 73 | shape = [num_tiles] + list(key_dict[key]['shape']) 74 | dtype = key_dict[key]['dtype'] 75 | storage_dict[key] = content.create_dataset(name=key.replace('train_', ''), shape=shape, dtype=dtype) 76 | 77 | dt = h5py.special_dtype(vlen=str) 78 | for meta_field in list_meta_field: 79 | dtype = frame[meta_field].dtype 80 | if str(dtype) == 'object': 81 | dtype = dt 82 | storage_dict[meta_field] = content.create_dataset(name=meta_field, shape=[num_tiles], dtype=dtype) 83 | 84 | index = 0 85 | print('Iterating through %s ...' % h5_file.split('/')[-1]) 86 | with h5py.File(h5_file, 'r') as orig_content: 87 | 88 | set_dict = dict() 89 | for key in storage_dict: 90 | flag_meta_field = False 91 | for meta_field in list_meta_field: 92 | if key == meta_field: 93 | flag_meta_field = True 94 | break 95 | if flag_meta_field: 96 | continue 97 | set_dict[key] = orig_content[key] 98 | 99 | for i in range(set_dict[list(storage_dict.keys())[0]].shape[0]): 100 | # Verbose 101 | if i%100000==0: 102 | print('\tprocessed %s entries' % i) 103 | 104 | # If sample doesn't overlap with meta file, get rid of it. 105 | h5_individual = set_dict[matching_field][i].decode("utf-8") 106 | if h5_individual not in meta_individuals: 107 | continue 108 | for key in storage_dict: 109 | # Check for field to include. 110 | if key in list_meta_field: 111 | storage_dict[key][index] = frame[frame[matching_field].astype(str)==str(h5_individual)][key] 112 | # Copy all other fiels 113 | else: 114 | storage_dict[key][index] = set_dict[key][i] 115 | 116 | index += 1 117 | print() 118 | 119 | 120 | ##### Main ####### 121 | parser = argparse.ArgumentParser(description='Script to create a subset H5 representation file based on meta data file.') 122 | parser.add_argument('--meta_file', dest='meta_file', type=str, required=True, help='Path to CSV file with meta data.') 123 | parser.add_argument('--meta_name', dest='meta_name', type=str, required=True, help='Name to use to rename H5 file.') 124 | parser.add_argument('--list_meta_field', dest='list_meta_field', type=str, required=True, help='Field name that contains the information to include in the H5 file.', nargs="*") 125 | parser.add_argument('--matching_field', dest='matching_field', type=str, required=True, help='Reference filed to use, cross check between original H5 and meta file.') 126 | parser.add_argument('--h5_file', dest='h5_file', type=str, required=True, help='Original H5 file to parse.') 127 | parser.add_argument('--override', dest='override', action='store_true', default=False, help='Override \'complete\' H5 file if it already exists.') 128 | args = parser.parse_args() 129 | meta_file = args.meta_file 130 | meta_name = args.meta_name 131 | list_meta_field = args.list_meta_field 132 | matching_field = args.matching_field 133 | h5_file = args.h5_file 134 | override = args.override 135 | 136 | # Read meta data file and list of individuals according to the matching_field. 137 | frame, meta_individuals = read_meta_data(meta_file, matching_field) 138 | 139 | # Get number of tiles from all individuals in the original H5, <= to the original. 140 | num_tiles = h5_overlap_meta_individuals(h5_file, matching_field, meta_individuals) 141 | 142 | # Dictionary with keys, shapes, and dtypes. 143 | key_dict = data_specs(h5_file) 144 | 145 | # Create H5 with the list of individuals and the field. 146 | create_metadata_h5(h5_file, meta_name, list_meta_field, matching_field, meta_individuals, num_tiles, key_dict, override) 147 | -------------------------------------------------------------------------------- /utilities/h5_handling/h5_float_to_int.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import shutil 4 | import math 5 | import h5py 6 | import glob 7 | import sys 8 | import os 9 | 10 | 11 | parser = argparse.ArgumentParser(description='Convert H5 file image dataset from float32 to uint8, saving storage.') 12 | parser.add_argument('--dataset', dest='dataset', type=str, default=None, help='Dataset to use.') 13 | parser.add_argument('--marker', dest='marker', type=str, default='he', help='Marker of dataset to use.') 14 | parser.add_argument('--img_size', dest='img_size', type=int, default=224, help='Image size for the model.') 15 | parser.add_argument('--img_ch', dest='img_ch', type=int, default=3, help='Number of channels for the model.') 16 | parser.add_argument('--dbs_path', dest='dbs_path', type=str, default=None, help='Path where datasets are stored.') 17 | args = parser.parse_args() 18 | dataset = args.dataset 19 | marker = args.marker 20 | img_size = args.img_size 21 | img_ch = args.img_ch 22 | dbs_path = args.dbs_path 23 | 24 | if dbs_path is None: 25 | dbs_path = os.path.dirname(os.path.realpath(__file__)) 26 | dbs_path = '/'.join(dbs_path.split('/')[:-2]) 27 | 28 | sys.path.append(dbs_path) 29 | from data_manipulation.data import Data 30 | 31 | print('Dataset:', dataset) 32 | batch_size = 100 33 | 34 | data = Data(dataset=dataset, marker=marker, patch_h=img_size, patch_w=img_size, n_channels=img_ch, batch_size=64, project_path=dbs_path, load=False) 35 | 36 | # Dataset files. 37 | for hdf5_path in [data.hdf5_train, data.hdf5_validation, data.hdf5_test]: 38 | 39 | hdf5_path_new = hdf5_path + '_int' 40 | if os.path.isfile(hdf5_path_new): 41 | os.remove(hdf5_path_new) 42 | 43 | print('Current File:', hdf5_path) 44 | with h5py.File(hdf5_path, 'r') as original_content: 45 | with h5py.File(hdf5_path_new, mode='w') as hdf5_content: 46 | for key in original_content.keys(): 47 | print('\t', key, '-', original_content[key].shape) 48 | if 'images' in key or 'img' in key: 49 | normalized_flag = False 50 | if np.amax(original_content[key][:10, :, :, :]) == 1.: 51 | normalized_flag = True 52 | img_storage = hdf5_content.create_dataset(key, shape=original_content[key].shape, dtype=np.uint8) 53 | blocks = math.ceil(original_content[key].shape[0]/batch_size) 54 | for i in range(blocks): 55 | if normalized_flag: 56 | img_storage[i*batch_size:(i+1)*batch_size, :, :, :] = original_content[key][i*batch_size:(i+1)*batch_size, :, :, :]*255 57 | else: 58 | img_storage[i*batch_size:(i+1)*batch_size, :, :, :] = original_content[key][i*batch_size:(i+1)*batch_size, :, :, :] 59 | if i*batch_size%10000==0: print('\t\t', 'Processed', i*batch_size, 'images') 60 | else: 61 | hdf5_content.create_dataset(key, data=original_content[key]) 62 | os.remove(hdf5_path) 63 | shutil.move(hdf5_path_new, hdf5_path) 64 | -------------------------------------------------------------------------------- /utilities/h5_handling/include_samples_slides_participants_h5.py: -------------------------------------------------------------------------------- 1 | # Imports. 2 | import argparse 3 | import h5py 4 | import sys 5 | import os 6 | 7 | # Add project path 8 | main_path = os.path.dirname(os.path.realpath(__file__)) 9 | main_path = '/'.join(main_path.split('/')[:-2]) 10 | sys.path.append(main_path) 11 | 12 | # Own libraries 13 | from data_manipulation.data import Data 14 | 15 | # Folder permissions for cluster. 16 | os.umask(0o002) 17 | # H5 File bug over network file system. 18 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' 19 | 20 | ##### Methods ####### 21 | # Methods to include slide and participant into H5 file. 22 | def include_slide_participants_h5(h5_path, reference_field): 23 | # Sample 'TCGA-AA-3812-01' 24 | # Slide 'TCGA-02-0001-01C-01-TS1' 25 | # Participant 'TCGA-02-0001' 26 | slides_h5 = list() 27 | samples_h5 = list() 28 | participants_h5 = list() 29 | with h5py.File(h5_path, 'r+') as content: 30 | print('Processing file:', h5_path) 31 | if reference_field not in content.keys(): 32 | print('\tReference field', reference_field, 'not found') 33 | print('\tH5 Keys:', ', '.join(content.keys())) 34 | exit() 35 | file_name_prev = '' 36 | for index_h5 in range(content[reference_field].shape[0]): 37 | file_name = content[reference_field][index_h5,0].decode("utf-8") 38 | try: 39 | base_name = file_name.split('.')[0] 40 | base_name = base_name.split('_')[1] 41 | base_name = base_name.split('-') 42 | slide_name = '-'.join(base_name) 43 | participant_name = '-'.join(base_name[:3]) 44 | sample_name = '-'.join(base_name[:4]) 45 | slides_h5.append(slide_name) 46 | samples_h5.append(sample_name) 47 | participants_h5.append(participant_name) 48 | except: 49 | if file_name_prev != file_name: 50 | print('\tCorrupted entry, not able to find mapping:', file_name) 51 | slides_h5.append('None') 52 | samples_h5.append('None') 53 | participants_h5.append('None') 54 | file_name_prev = file_name 55 | 56 | if 'slides' not in content.keys(): 57 | content.create_dataset('slides', data=slides_h5) 58 | if 'samples' not in content.keys(): 59 | content.create_dataset('samples', data=samples_h5) 60 | if 'participants' not in content.keys(): 61 | content.create_dataset('participants', data=participants_h5) 62 | print() 63 | 64 | 65 | ##### Main ####### 66 | parser = argparse.ArgumentParser(description='Script to include sample, slide, and participant information from a reference field in the H5.') 67 | parser.add_argument('--img_size', dest='img_size', type=int, default=224, help='Image size for the model.') 68 | parser.add_argument('--img_ch', dest='img_ch', type=int, default=3, help='Number of channels for the model.') 69 | parser.add_argument('--dataset', dest='dataset', type=str, default='vgh_nki', help='Dataset to use.') 70 | parser.add_argument('--marker', dest='marker', type=str, default='he', help='Marker of dataset to use.') 71 | parser.add_argument('--main_path', dest='main_path', type=str, default=None, help='Path for the output run.') 72 | parser.add_argument('--dbs_path', dest='dbs_path', type=str, default=None, help='Directory with DBs to use.') 73 | parser.add_argument('--ref_field', dest='ref_field', type=str, default='file_name', help='Key name that contains slide and participant information.') 74 | args = parser.parse_args() 75 | image_width = args.img_size 76 | image_height = args.img_size 77 | image_channels = args.img_ch 78 | dataset = args.dataset 79 | marker = args.marker 80 | main_path = args.main_path 81 | dbs_path = args.dbs_path 82 | ref_field = args.ref_field 83 | 84 | if main_path is None: 85 | main_path = os.path.dirname(os.path.realpath(__file__)) 86 | main_path = '/'.join(main_path.split('/')[:-2]) 87 | 88 | if dbs_path is None: 89 | dbs_path = main_path 90 | 91 | # Data Class with all h5 92 | data = Data(dataset=dataset, marker=marker, patch_h=image_height, patch_w=image_width, n_channels=image_channels, batch_size=64, project_path=dbs_path, load=False) 93 | 94 | for h5_path in [data.hdf5_train, data.hdf5_validation, data.hdf5_test]: 95 | if not os.path.isfile(h5_path): 96 | print('Warning: H5 file not found', h5_path) 97 | continue 98 | include_slide_participants_h5(h5_path, ref_field) 99 | 100 | # 101 | # def representations_h5(main_path, model, dataset, img_size, zdim): 102 | # hdf5_path_train = '%s/results/%s/%s/h%s_w%s_n3_zdim%s/hdf5_%s_he_train.h5' % (main_path, model, dataset, img_size, img_size, zdim, dataset) 103 | # hdf5_path_valid = '%s/results/%s/%s/h%s_w%s_n3_zdim%s/hdf5_%s_he_validation.h5' % (main_path, model, dataset, img_size, img_size, zdim, dataset) 104 | # hdf5_path_test = '%s/results/%s/%s/h%s_w%s_n3_zdim%s/hdf5_%s_he_test.h5' % (main_path, model, dataset, img_size, img_size, zdim, dataset) 105 | # return [hdf5_path_train, hdf5_path_valid, hdf5_path_test] 106 | # 107 | # 108 | # data = representations_h5(main_path, 'ContrastivePathology_BarlowTwins_3', dataset, 224, 128) 109 | # for h5_path in [data[0], data[1], data[2]]: 110 | # if not os.path.isfile(h5_path): 111 | # print('Warning: H5 file not found', h5_path) 112 | # continue 113 | # include_slide_participants_h5(h5_path, ref_field) 114 | -------------------------------------------------------------------------------- /utilities/h5_handling/split_h5_by_pattern.py: -------------------------------------------------------------------------------- 1 | # Imports. 2 | import argparse 3 | import h5py 4 | import sys 5 | import os 6 | 7 | # Add project path. 8 | main_path = os.path.dirname(os.path.realpath(__file__)) 9 | main_path = '/'.join(main_path.split('/')[:-2]) 10 | sys.path.append(main_path) 11 | 12 | from data_manipulation.utils import load_data 13 | 14 | # Folder permissions for cluster. 15 | os.umask(0o002) 16 | # H5 File bug over network file system. 17 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' 18 | 19 | 20 | ##### Methods ####### 21 | # Get number of samples that contain pattern in matching field. 22 | def h5_overlap_pattern_individuals(h5_file, matching_field, pattern): 23 | h5_samples = 0 24 | with h5py.File(h5_file, 'r') as content: 25 | for index_h5 in range(content[matching_field].shape[0]): 26 | field_value = content[matching_field][index_h5].decode("utf-8") 27 | if pattern in field_value: 28 | h5_samples +=1 29 | 30 | return h5_samples 31 | 32 | # Get key_names, shape, and dtype. 33 | def data_specs(data_path): 34 | key_dict = dict() 35 | with h5py.File(data_path, 'r') as content: 36 | for key in content.keys(): 37 | key_dict[key] = dict() 38 | key_dict[key]['shape'] = content[key].shape[1:] 39 | key_dict[key]['dtype'] = content[key].dtype 40 | h5_samples = content[key].shape[0] 41 | 42 | return key_dict, h5_samples 43 | 44 | # Filter out given indexes. 45 | def create_complete_h5(data_path, num_tiles, key_dict, pattern, matching_field, override): 46 | h5_complete_path = data_path.replace('.h5', '_%s.h5' % pattern) 47 | if override: 48 | os.remove(h5_complete_path) 49 | if os.path.isfile(h5_complete_path): 50 | print('File already exists, if you want to overwrite enable the flag --override') 51 | print(h5_complete_path) 52 | print() 53 | exit() 54 | 55 | storage_dict = dict() 56 | content = h5py.File(h5_complete_path, mode='w') 57 | for key in key_dict: 58 | shape = [num_tiles] + list(key_dict[key]['shape']) 59 | dtype = key_dict[key]['dtype'] 60 | key_ = key.replace('train_', '') 61 | key_ = key_.replace('valid_', '') 62 | key_ = key_.replace('test_', '') 63 | storage_dict[key] = content.create_dataset(name=key_, shape=shape, dtype=dtype) 64 | 65 | index = 0 66 | print('Iterating through %s ...' % data_path) 67 | with h5py.File(data_path, 'r') as content: 68 | set_dict = dict() 69 | for key in storage_dict: 70 | set_dict[key] = content[key] 71 | 72 | for i in range(set_dict[key].shape[0]): 73 | # Verbose. 74 | if i%1e+5==0: 75 | print('\tprocessed %s entries' % i) 76 | 77 | # Check in pattern is contained in this instance. 78 | if pattern not in set_dict[matching_field][i].decode("utf-8"): 79 | continue 80 | 81 | # Original data. 82 | for key in storage_dict: 83 | storage_dict[key][index] = set_dict[key][i] 84 | 85 | if num_tiles == index: 86 | break 87 | 88 | index += 1 89 | print() 90 | 91 | 92 | ##### Main ####### 93 | parser = argparse.ArgumentParser(description='Script to create a new H5 file that contains a particular pattern.') 94 | parser.add_argument('--h5_file', dest='h5_file', type=str, required=True, help='Original H5 file to parse.') 95 | parser.add_argument('--matching_field', dest='matching_field', type=str, required=True, help='Reference filed to use, cross check between original H5 and meta file.') 96 | parser.add_argument('--pattern', dest='pattern', type=str, required=True, help='Pattern to search for in the matching_field entries.') 97 | parser.add_argument('--override', dest='override', action='store_true', default=False, help='Override \'complete\' H5 file if it already exists.') 98 | args = parser.parse_args() 99 | h5_file = args.h5_file 100 | matching_field = args.matching_field 101 | pattern = args.pattern 102 | override = args.override 103 | 104 | # Get number of tiles from all individuals in the original H5, <= to the original. 105 | num_tiles = h5_overlap_pattern_individuals(h5_file, matching_field, pattern) 106 | 107 | # Information content H5 file. 108 | key_dict, h5_samples = data_specs(h5_file) 109 | 110 | # Remove from H5 file. 111 | create_complete_h5(h5_file, num_tiles, key_dict, pattern, matching_field, override) 112 | 113 | -------------------------------------------------------------------------------- /utilities/h5_handling/subsample_h5.py: -------------------------------------------------------------------------------- 1 | # Imports. 2 | import argparse 3 | import random 4 | import h5py 5 | import sys 6 | import os 7 | 8 | # Add project path. 9 | main_path = os.path.dirname(os.path.realpath(__file__)) 10 | main_path = '/'.join(main_path.split('/')[:-2]) 11 | sys.path.append(main_path) 12 | 13 | from data_manipulation.utils import load_data 14 | 15 | # Folder permissions for cluster. 16 | os.umask(0o002) 17 | # H5 File bug over network file system. 18 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' 19 | 20 | 21 | ##### Methods ####### 22 | # Get key_names, shape, and dtype. 23 | def data_specs(data_path, num_samples): 24 | key_dict = dict() 25 | with h5py.File(data_path, 'r') as content: 26 | for key in content.keys(): 27 | key_dict[key] = dict() 28 | key_dict[key]['shape'] = content[key].shape[1:] 29 | key_dict[key]['dtype'] = content[key].dtype 30 | h5_samples = content[key].shape[0] 31 | 32 | if num_samples > content[key].shape[0]: 33 | print('Number of subsamples %s is smaller than the size of original H5 file %s' % (num_samples,content[key].shape[0])) 34 | exit() 35 | 36 | return key_dict, h5_samples 37 | 38 | # Get subsamples from original H5 file. 39 | def create_complete_h5(data_path, num_samples, key_dict, override): 40 | h5_complete_path = data_path.replace('.h5', '_%s_subsampled.h5' % num_samples) 41 | if override: 42 | os.remove(h5_complete_path) 43 | if os.path.isfile(h5_complete_path): 44 | print('File already exists, if you want to overwrite enable the flag --override') 45 | print(h5_complete_path) 46 | print() 47 | exit() 48 | 49 | storage_dict = dict() 50 | content = h5py.File(h5_complete_path, mode='w') 51 | for key in key_dict: 52 | shape = [num_samples] + list(key_dict[key]['shape']) 53 | dtype = key_dict[key]['dtype'] 54 | key_ = key.replace('train_', '') 55 | key_ = key_.replace('valid_', '') 56 | key_ = key_.replace('test_', '') 57 | storage_dict[key] = content.create_dataset(name=key_, shape=shape, dtype=dtype) 58 | 59 | index = 0 60 | print('Iterating through %s ...' % data_path) 61 | with h5py.File(data_path, 'r') as content: 62 | set_dict = dict() 63 | for key in storage_dict: 64 | set_dict[key] = content[key] 65 | 66 | all_indexes = list(range(content[key].shape[0])) 67 | random.shuffle(all_indexes) 68 | all_indexes = all_indexes[:num_samples] 69 | 70 | for i, random_index in enumerate(all_indexes): 71 | # Original data. 72 | for key in storage_dict: 73 | storage_dict[key][index] = set_dict[key][random_index] 74 | 75 | # Verbose. 76 | if i%1e+5==0: 77 | print('\tprocessed %s entries' % i) 78 | index += 1 79 | print() 80 | 81 | 82 | ##### Main ####### 83 | parser = argparse.ArgumentParser(description='Script to subsample indexes from H5 file.') 84 | parser.add_argument('--h5_file', dest='h5_file', type=str, required=True, help='Original H5 file to parse.') 85 | parser.add_argument('--num_samples', dest='num_samples', type=int, required=True, help='Number of random subsamples to pick from original H5 file.') 86 | parser.add_argument('--override', dest='override', action='store_true', default=False, help='Override \'complete\' H5 file if it already exists.') 87 | args = parser.parse_args() 88 | h5_file = args.h5_file 89 | num_samples = args.num_samples 90 | override = args.override 91 | 92 | 93 | # Check if files exist. 94 | if not os.path.isfile(h5_file): 95 | print('File not found:', h5_file) 96 | exit() 97 | 98 | # Information content H5 file. 99 | key_dict, h5_samples = data_specs(h5_file, num_samples) 100 | 101 | # Remove from H5 file. 102 | create_complete_h5(h5_file, num_samples, key_dict, override) 103 | 104 | -------------------------------------------------------------------------------- /utilities/hovernet_annotations/create_hovernet_master.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import argparse 4 | import math 5 | import glob 6 | import sys 7 | import os 8 | 9 | # Add project path 10 | main_path = os.path.dirname(os.path.realpath(__file__)) 11 | main_path = '/'.join(main_path.split('/')[:-2]) 12 | sys.path.append(main_path) 13 | 14 | # Own library. 15 | from models.visualization.attention_maps import get_x_y 16 | 17 | def create_main_file(cell_types, cell_names, csvs_path, output_csv): 18 | all_csvs = glob.glob(os.path.join(csvs_path, '*_data_details_per_tile_per_celltype.txt')) 19 | for i_csv, csv_path in enumerate(sorted(all_csvs)): 20 | all_data = list() 21 | slide_id = csv_path.split('/')[-1].split('_')[0] 22 | print('(%s/%s) Iterating through slide %s...' % (i_csv+1, len(all_csvs), slide_id)) 23 | slide_df = pd.read_csv(csv_path) 24 | slide_df = slide_df[['TileID', 'type', 'Nb_partices']] 25 | slide_df = slide_df[~slide_df['type'].isna()] 26 | 27 | if slide_df.shape[0] == 0: 28 | print('\t[Warning] Empty slide annotation.') 29 | continue 30 | 31 | # For each tile 32 | slide_df = slide_df.sort_values(by='TileID') 33 | tile_id_prev = '' 34 | cell_counts_tile = [0]*len(cell_types) 35 | for i, values in enumerate(zip(slide_df.TileID.values.tolist(), slide_df.type.values.astype(int).tolist(), slide_df.Nb_partices.values.astype(int).tolist())): 36 | tile_id, type_cell, counts = values 37 | if tile_id!=tile_id_prev: 38 | if i!=0: 39 | all_data.append([slide_id, tile_id_prev, x, y] + cell_counts_tile) 40 | x, y = get_x_y(tile_id) 41 | cell_counts_tile = [0]*len(cell_types) 42 | tile_id_prev = tile_id 43 | cell_counts_tile[type_cell] = counts 44 | 45 | all_data = np.stack(all_data) 46 | all_df = pd.DataFrame(all_data, columns=['slides', 'tile_id', 'x', 'y'] + cell_names) 47 | all_df.to_csv(output_csv, mode='a', header=not os.path.exists(output_csv), index=False) 48 | 49 | def include_5x_xy_annotations(all_df, output_csv): 50 | x_values = all_df.x.values.astype(int).tolist() 51 | y_values = all_df.y.values.astype(int).tolist() 52 | x_5x_values = list() 53 | y_5x_values = list() 54 | for x, y in zip(x_values, y_values): 55 | x_5x_values.append(math.floor(x/4)) 56 | y_5x_values.append(math.floor(y/4)) 57 | 58 | all_df['x_5x'] = x_5x_values 59 | all_df['y_5x'] = y_5x_values 60 | 61 | all_df.to_csv(output_csv, index=False) 62 | return all_df 63 | 64 | def create_5x_from_20x(all_df, cell_names, output_5x_csv): 65 | # Get 5x annotations from 20x. 66 | if 'slide_tile_5x' not in all_df.columns: 67 | all_df['slide_tile_5x'] = all_df.apply(lambda x: '%s_%s_%s' % (x['slides'], x['x_5x'], x['y_5x']), axis=1) 68 | all_df['annotated_20x_tile_count'] = 1 69 | all_5x_df = all_df[['slide_tile_5x', 'annotated_20x_tile_count']+cell_names].groupby('slide_tile_5x').sum() 70 | all_5x_df = all_5x_df.reset_index() 71 | if 'slide' not in all_5x_df.columns: 72 | all_5x_df['slides'] = all_5x_df.apply(lambda x: x['slide_tile_5x'].split('_')[0], axis=1) 73 | all_5x_df['tiles'] = all_5x_df.apply(lambda x: '%s_%s' % (x['slide_tile_5x'].split('_')[1], x['slide_tile_5x'].split('_')[2]), axis=1) 74 | del all_5x_df['slide_tile_5x'] 75 | all_5x_df.to_csv(output_5x_csv, index=False) 76 | 77 | 78 | ##### Main ####### 79 | parser = argparse.ArgumentParser(description='Script to all HoverNet annotations into one file.\nDirectory structure assumption: workspace/datasets/HoverNet/\'dataset\'/\'magnification\'.') 80 | parser.add_argument('--dataset', dest='dataset', type=str, default='NYU_LUADall_5x', help='Dataset to use.') 81 | parser.add_argument('--magnification', dest='magnification', type=str, default='20x', help='Magnification.') 82 | parser.add_argument('--main_path', dest='main_path', type=str, default=None, help='Path for the output run.') 83 | args = parser.parse_args() 84 | dataset = args.dataset 85 | magnification = args.magnification 86 | main_path = args.main_path 87 | 88 | 89 | if main_path is None: 90 | main_path = os.path.dirname(os.path.realpath(__file__)) 91 | main_path = '/'.join(main_path.split('/')[:-2]) 92 | 93 | # Working directories. 94 | csvs_path = '%s/datasets/HoverNet/%s/%s' % (main_path, dataset, magnification) 95 | output_csv = os.path.join(csvs_path, '%s_hovernet_annotations_20x.csv' % dataset) 96 | 97 | # Reference for cell types in the tiles. 98 | cell_types = [0, 1, 2, 3, 4, 5 ] 99 | cell_names = ['cell other', 'cell neoplastic', 'cell inflammatory', 'cell connective', 'cell dead', 'cell non-neoplastic epithelial'] 100 | 101 | # Create main file. 102 | if not os.path.isfile(output_csv): 103 | create_main_file(cell_types, cell_names, csvs_path, output_csv) 104 | all_df = pd.read_csv(output_csv) 105 | 106 | # Include 5x locations. 107 | if magnification == '20x' and (('x_5x' not in all_df.columns) or ('y_5x' not in all_df.columns)): 108 | all_df = include_5x_xy_annotations(all_df, output_csv) 109 | 110 | # Create a reference file for 5x tiles from 20x. 111 | output_csv = os.path.join(csvs_path, '%s_hovernet_annotations_5x.csv' % dataset) 112 | if not os.path.isfile(output_csv): 113 | create_5x_from_20x(all_df, cell_names, output_csv) -------------------------------------------------------------------------------- /utilities/tile_cleaning/remove_indexes_h5.py: -------------------------------------------------------------------------------- 1 | # Imports. 2 | import argparse 3 | import h5py 4 | import sys 5 | import os 6 | 7 | # Add project path. 8 | main_path = os.path.dirname(os.path.realpath(__file__)) 9 | main_path = '/'.join(main_path.split('/')[:-2]) 10 | sys.path.append(main_path) 11 | 12 | # Own libs. 13 | from data_manipulation.utils import load_data 14 | 15 | # Folder permissions for cluster. 16 | os.umask(0o002) 17 | # H5 File bug over network file system. 18 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' 19 | 20 | 21 | ##### Methods ####### 22 | # Get key_names, shape, and dtype. 23 | def data_specs(data_path): 24 | key_dict = dict() 25 | with h5py.File(data_path, 'r') as content: 26 | for key in content.keys(): 27 | key_dict[key] = dict() 28 | key_dict[key]['shape'] = content[key].shape[1:] 29 | key_dict[key]['dtype'] = content[key].dtype 30 | h5_samples = content[key].shape[0] 31 | 32 | return key_dict, h5_samples 33 | 34 | # Filter out given indexes. 35 | def create_complete_h5(data_path, num_tiles, key_dict, indexes_to_remove, override): 36 | h5_complete_path = data_path.replace('.h5', '_filtered.h5') 37 | if override: 38 | os.remove(h5_complete_path) 39 | if os.path.isfile(h5_complete_path): 40 | print('File already exists, if you want to overwrite enable the flag --override') 41 | print(h5_complete_path) 42 | print() 43 | exit() 44 | 45 | storage_dict = dict() 46 | content = h5py.File(h5_complete_path, mode='w') 47 | for key in key_dict: 48 | shape = [num_tiles] + list(key_dict[key]['shape']) 49 | dtype = key_dict[key]['dtype'] 50 | key_ = key.replace('train_', '') 51 | key_ = key_.replace('valid_', '') 52 | key_ = key_.replace('test_', '') 53 | storage_dict[key] = content.create_dataset(name=key_, shape=shape, dtype=dtype) 54 | 55 | index = 0 56 | print('Iterating through %s ...' % data_path) 57 | with h5py.File(data_path, 'r') as content: 58 | set_dict = dict() 59 | for key in storage_dict: 60 | set_dict[key] = content[key] 61 | 62 | for i in range(set_dict[key].shape[0]): 63 | if num_tiles == index: 64 | break 65 | 66 | # Original data. 67 | for key in storage_dict: 68 | storage_dict[key][index] = set_dict[key][i] 69 | 70 | # Check if this is a tile to remove 71 | if i in indexes_to_remove: 72 | indexes_to_remove.remove(i) 73 | continue 74 | 75 | # Verbose. 76 | if i%1e+5==0: 77 | print('\tprocessed %s entries' % i) 78 | index += 1 79 | print() 80 | 81 | 82 | ##### Main ####### 83 | parser = argparse.ArgumentParser(description='Script to remove indexes from H5 file.') 84 | parser.add_argument('--h5_file', dest='h5_file', type=str, required=True, help='Original H5 file to parse.') 85 | parser.add_argument('--pickle_file', dest='pickle_file', type=str, required=True, help='Pickle file with indexes to remove.') 86 | parser.add_argument('--override', dest='override', action='store_true', default=False, help='Override \'complete\' H5 file if it already exists.') 87 | args = parser.parse_args() 88 | h5_file = args.h5_file 89 | pickle_file = args.pickle_file 90 | override = args.override 91 | 92 | 93 | # Check if files exist. 94 | for file_path in [h5_file, pickle_file]: 95 | if not os.path.isfile(file_path): 96 | print('File not found:', file_path) 97 | exit() 98 | 99 | # Grab indexes. 100 | indexes_to_remove = load_data(pickle_file) 101 | 102 | # Information content H5 file. 103 | key_dict, h5_samples = data_specs(h5_file) 104 | 105 | # Remove from H5 file. 106 | remain_samples = h5_samples - len(indexes_to_remove) 107 | create_complete_h5(h5_file, remain_samples, key_dict, indexes_to_remove, override) 108 | 109 | --------------------------------------------------------------------------------