├── LICENSE
├── README.md
├── aggregate_imputed.py
├── cluster.py
├── connected_components.py
├── data
│   └── markers
│       ├── cell-type-template.tsv
│       └── signature-score-template.txt
├── differential.py
├── download_checkpoints.sh
├── download_demo.sh
├── enrich.py
├── extract_features.py
├── get_mask.py
├── hipt_4k.py
├── hipt_heatmap_utils.py
├── hipt_model_utils.py
├── image.py
├── impute.py
├── impute_by_basic.py
├── marker_score.py
├── pixannot.py
├── plot_imputed.py
├── plot_spots.py
├── preprocess.py
├── reduce_dim.py
├── reorganize_imputed.py
├── requirements.txt
├── rescale.py
├── run.sh
├── run_demo.sh
├── select_genes.py
├── structural_similarity.py
├── train.py
├── utils.py
├── vision_transformer.py
├── vision_transformer4k.py
└── visual.py

/README.md:
--------------------------------------------------------------------------------
1 | # Inferring Super-Resolution Tissue Architecture by Integrating Spatial Transcriptomics and Histology
2 | 
3 | This software package implements iStar
4 | (Inferring Super-resolution Tissue ARchitecture),
5 | which enhances the spatial resolution of spatial transcriptomic data
6 | from a spot level to a near-single-cell level.
7 | The iStar method is presented in the following paper:
8 | 
9 | Daiwei Zhang, Amelia Schroeder, Hanying Yan, Haochen Yang, Jian Hu, Michelle Y. Y. Lee, Kyung S. Cho, Katalin Susztak, George X. Xu, Michael D. Feldman, Edward B. Lee, Emma E. Furth, Linghua Wang, Mingyao Li.
10 | Inferring super-resolution tissue architecture by integrating spatial transcriptomics with histology.
11 | *Nature Biotechnology* (2024).
12 | https://doi.org/10.1038/s41587-023-02019-9
13 | 
14 | ## iStar WebUI (Update 2024-05-11)
15 | 
16 | A web version of iStar is now available at [istar.live](http://istar.live).
17 | New features will be continuously added there as we develop extensions of the model.
18 | Please contact [Daiwei (David) Zhang](mailto:daiwei.zhang@pennmedicine.upenn.edu)
19 | if you encounter any issues or have any questions.
20 | 
21 | ## Get Started
22 | 
23 | To run the demo,
24 | ```bash
25 | # Use Python 3.9 or above
26 | pip install -r requirements.txt
27 | ./run_demo.sh
28 | ```
29 | Using GPUs is highly recommended.
30 | 
31 | ### Data format
32 | 
33 | - `he-raw.jpg`: Raw histology image.
34 | - `cnts.tsv`: Gene count matrix.
35 |     - Row 1: Gene names.
36 |     - Row 2 and after: Each row is a spot.
37 |     - Column 1: Spot ID.
38 |     - Column 2 and after: Each column is a gene.
39 | - `locs-raw.tsv`: Spot locations.
40 |     - Row 1: Header.
41 |     - Row 2 and after: Each row is a spot. Must match the rows in `cnts.tsv`.
42 |     - Column 1: Spot ID.
43 |     - Column 2: x-coordinate (horizontal axis). Must be in the same space as axis-1 (column) of the array indices of pixels in `he-raw.jpg`.
44 |     - Column 3: y-coordinate (vertical axis). Must be in the same space as axis-0 (row) of the array indices of pixels in `he-raw.jpg`.
45 | - `pixel-size-raw.txt`: Side length (in micrometers) of pixels in `he-raw.jpg`. This value is usually between 0.1 and 1.0.
46 |     - For [Visium](https://support.10xgenomics.com/spatial-gene-expression/software/pipelines/latest/output/spatial) data, this value can be approximated by `8000 / 2000 * tissue_hires_scalef`, where `tissue_hires_scalef` is stored in `scalefactors_json.json` (see the example below).
47 | - `radius-raw.txt`: Number of pixels per spot radius in `he-raw.jpg`.
48 |     - For [Visium](https://support.10xgenomics.com/spatial-gene-expression/software/pipelines/latest/output/spatial) data, this value can be computed by `spot_diameter_fullres * 0.5`, where `spot_diameter_fullres` is stored in `scalefactors_json.json`, and should be close to `55 * 0.5 / pixel_size_raw`.
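For Visium data, both `pixel-size-raw.txt` and `radius-raw.txt` can be derived from the Space Ranger scale factors using the formulas above. Below is a minimal sketch; the input path `spatial/scalefactors_json.json` and the output locations are assumptions to be adjusted to your own directory layout.

```python
import json

# Assumed location of the Space Ranger scale factors; adjust as needed.
with open('spatial/scalefactors_json.json') as f:
    scalefactors = json.load(f)

# Side length (in micrometers) of pixels in he-raw.jpg,
# approximated from the hires scale factor as described above.
pixel_size = 8000 / 2000 * scalefactors['tissue_hires_scalef']

# Number of pixels per spot radius in he-raw.jpg;
# this should be close to 55 * 0.5 / pixel_size.
radius = scalefactors['spot_diameter_fullres'] * 0.5

with open('pixel-size-raw.txt', 'w') as f:
    f.write(f'{pixel_size}\n')
with open('radius-raw.txt', 'w') as f:
    f.write(f'{radius}\n')
```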
49 | 
50 | ## License
51 | 
52 | The software package is licensed under GPL-3.0.
53 | For commercial use, please contact
54 | [Daiwei (David) Zhang](mailto:daiwei.zhang@pennmedicine.upenn.edu) and
55 | [Mingyao Li](mailto:mingyao@pennmedicine.upenn.edu).
56 | 
57 | ## Acknowledgements
58 | 
59 | The code for iStar is written by Daiwei (David) Zhang and is under active development.
60 | Please open an issue on GitHub if you have any questions about the software package.
61 | 
62 | The codebase for the hierarchical vision transformer is built upon
63 | [Vision Transformer](https://arxiv.org/abs/2010.11929)
64 | (as implemented by [Hugging Face](https://github.com/huggingface/pytorch-image-models)),
65 | [DINO](https://github.com/facebookresearch/dino), and
66 | [HIPT](https://github.com/mahmoodlab/HIPT).
67 | We thank the authors for releasing the code and the model weights.
68 | 
69 | If you find this work useful, please consider citing
70 | ```bibtex
71 | @article{zhang2024inferring,
72 |   title = {Inferring Super-Resolution Tissue Architecture by Integrating Spatial Transcriptomics with Histology},
73 |   author = {Zhang, Daiwei and Schroeder, Amelia and Yan, Hanying and Yang, Haochen and Hu, Jian and Lee, Michelle Y. Y. and Cho, Kyung S. and Susztak, Katalin and Xu, George X. and Feldman, Michael D. and Lee, Edward B. and Furth, Emma E. and Wang, Linghua and Li, Mingyao},
74 |   year = {2024},
75 |   month = jan,
76 |   journal = {Nature Biotechnology},
77 |   pages = {1--6},
78 |   doi = {10.1038/s41587-023-02019-9},
79 | }
80 | ```
81 | as well as
82 | [Vision Transformer](https://arxiv.org/abs/2010.11929),
83 | [DINO](https://github.com/facebookresearch/dino), and
84 | [HIPT](https://github.com/mahmoodlab/HIPT).
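## Appendix: Download Scripts

For reference, the two download scripts included in this repository can also be run individually. The following is a sketch only, where `data/demo/` is an arbitrary example prefix: `download_demo.sh` writes the demo inputs under the prefix passed as its first argument, and `download_checkpoints.sh` saves the pretrained ViT weights under `checkpoints/`.

```bash
./download_checkpoints.sh      # vit256_small_dino.pth, vit4k_xs_dino.pth -> checkpoints/
./download_demo.sh data/demo/  # he-raw.jpg, cnts.tsv, locs-raw.tsv,
                               # radius-raw.txt, pixel-size-raw.txt -> data/demo/
```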
85 | -------------------------------------------------------------------------------- /aggregate_imputed.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from utils import load_pickle, save_tsv, read_lines 7 | 8 | 9 | def get_masks(labels): 10 | if labels.ndim == 3: 11 | labels, labels_dict = flatten_labels(labels) 12 | else: 13 | uniq = np.unique(labels) 14 | uniq = uniq[uniq >= 0] 15 | labels_dict = uniq[..., np.newaxis] 16 | labels_uniq = np.unique(labels[labels >= 0]) 17 | masks = [labels == lab for lab in labels_uniq] 18 | masks = np.array(masks) 19 | return masks, labels_dict 20 | 21 | 22 | def flatten_labels(labels): 23 | isin = (labels >= 0).all(-1) 24 | flat = np.full_like(labels[..., 0], -99) 25 | flat[~isin] = -1 26 | dic, indices = np.unique( 27 | labels[isin], return_inverse=True, axis=0) 28 | flat[isin] = indices 29 | return flat, dic 30 | 31 | 32 | def to_str(labels): 33 | labels = np.char.zfill(labels.astype(str), 2) 34 | labels = ['_'.join(e) for e in labels] 35 | return labels 36 | 37 | 38 | def aggregate(x, masks, labels): 39 | groups = [x[ma] for ma in masks] 40 | df = pd.DataFrame([[g.size, g.mean(), g.var()] for g in groups]) 41 | df.columns = ['count', 'mean', 'variance'] 42 | labels = to_str(labels) 43 | df.index = labels 44 | df.index.name = 'cluster' 45 | return df 46 | 47 | 48 | def aggregate_files(prefix, gene_names, masks, labels): 49 | for gname in gene_names: 50 | x = load_pickle(f'{prefix}cnts-super/{gname}.pickle') 51 | stats = aggregate(x, masks, labels) 52 | save_tsv(stats, f'{prefix}cnts-clustered/by-genes/{gname}.tsv') 53 | 54 | 55 | def main(): 56 | prefix = sys.argv[1] # e.g. 'data/her2st/B1/' 57 | clus = load_pickle(f'{prefix}clusters-gene/labels.pickle') 58 | masks, labels = get_masks(clus) 59 | gene_names = read_lines(prefix+'gene-names.txt') 60 | aggregate_files(prefix, gene_names, masks, labels) 61 | 62 | 63 | if __name__ == '__main__': 64 | main() 65 | -------------------------------------------------------------------------------- /cluster.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from time import time 3 | 4 | import numpy as np 5 | from sklearn.cluster import MiniBatchKMeans, KMeans, AgglomerativeClustering 6 | from sklearn.mixture import GaussianMixture 7 | # from sklearn.neighbors import kneighbors_graph 8 | # from hdbscan import HDBSCAN 9 | # from einops import reduce 10 | import matplotlib.pyplot as plt 11 | from scipy.cluster.hierarchy import dendrogram 12 | 13 | from utils import load_pickle, save_pickle, sort_labels, load_mask 14 | from image import smoothen, upscale 15 | from visual import plot_labels, plot_label_masks 16 | from connected_components import ( 17 | relabel_small_connected, cluster_connected) 18 | from reduce_dim import reduce_dim 19 | 20 | 21 | def get_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('embeddings', type=str) 24 | parser.add_argument('prefix', type=str) 25 | parser.add_argument('--mask', type=str, default=None) 26 | parser.add_argument('--method', type=str, default='km') 27 | parser.add_argument('--n-clusters', type=int, default=None) 28 | parser.add_argument('--n-components', type=float, default=None) 29 | parser.add_argument('--filter-size', type=int, default=None) 30 | parser.add_argument('--min-cluster-size', type=int, default=None) 31 | # parser.add_argument('--stride', type=int, default=4) 32 | # 
parser.add_argument('--location-weight', type=float, default=None) 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def cluster_sub(embs, labels, n_clusters, location_weight, method): 38 | if labels.ndim == 2: 39 | labels = labels[..., np.newaxis] 40 | labs_uniq = np.unique( 41 | labels.reshape(-1, labels.shape[-1]), 42 | axis=0) 43 | labels_sub = np.full_like(labels[..., [0]], -1) 44 | for lab in labs_uniq: 45 | isin = (labels == lab).all(-1) 46 | if (lab >= 0).all(): 47 | embs_sub = embs.copy().transpose(1, 2, 0) 48 | embs_sub[~isin] = np.nan 49 | embs_sub = embs_sub.transpose(2, 0, 1) 50 | labs_sub, __ = cluster( 51 | embs_sub, n_clusters, method, location_weight) 52 | assert labs_sub[isin].min() == 0 53 | labs_sub[isin] -= labs_sub[isin].min() 54 | labels_sub[isin] = labs_sub[isin][..., np.newaxis] 55 | labels_sub = labels_sub[..., 0] 56 | proba_sub = None 57 | return labels_sub, proba_sub 58 | 59 | 60 | def plot_dendrogram(model, outfile): 61 | # Create linkage matrix and then plot the dendrogram 62 | 63 | # create the counts of samples under each node 64 | counts = np.zeros(model.children_.shape[0]) 65 | n_samples = len(model.labels_) 66 | for i, merge in enumerate(model.children_): 67 | current_count = 0 68 | for child_idx in merge: 69 | if child_idx < n_samples: 70 | current_count += 1 # leaf node 71 | else: 72 | current_count += counts[child_idx - n_samples] 73 | counts[i] = current_count 74 | 75 | linkage_matrix = np.column_stack( 76 | [model.children_, model.distances_, counts] 77 | ).astype(float) 78 | 79 | # Plot the corresponding dendrogram 80 | plt.figure(figsize=(8, 8)) 81 | # color_threshold = model.distances_.max() * 0.5 82 | kwargs = dict( 83 | Z=linkage_matrix, p=model.n_clusters_, truncate_mode='lastp', 84 | color_threshold=-1) 85 | dendro = dendrogram(**kwargs) 86 | plt.savefig(outfile, dpi=300, bbox_inches='tight') 87 | plt.close() 88 | print(outfile) 89 | 90 | return dendro 91 | 92 | 93 | def cluster( 94 | embs, n_clusters, method='mbkm', location_weight=None, 95 | sort=True): 96 | 97 | x, mask = prepare_for_clustering(embs, location_weight) 98 | 99 | print(f'Clustering pixels using {method}...') 100 | t0 = time() 101 | if method == 'mbkm': 102 | model = MiniBatchKMeans( 103 | n_clusters=n_clusters, 104 | # batch_size=x.shape[0]//10, max_iter=1000, 105 | # max_no_improvement=100, n_init=10, 106 | random_state=0, verbose=0) 107 | elif method == 'km': 108 | model = KMeans( 109 | n_clusters=n_clusters, 110 | random_state=0, verbose=0) 111 | elif method == 'gm': 112 | model = GaussianMixture( 113 | n_components=n_clusters, 114 | covariance_type='diag', init_params='k-means++', 115 | random_state=0, verbose=1) 116 | # elif method == 'dbscan': 117 | # eps = x.var(0).sum()**0.5 * 0.5 118 | # min_samples = 5 119 | # model = DBSCAN(eps=eps, min_samples=min_samples, n_jobs=64) 120 | # elif method == 'hdbscan': 121 | # min_cluster_size = min(1000, x.shape[0] // 400 + 1) 122 | # min_samples = min_cluster_size // 10 + 1 123 | # model = HDBSCAN( 124 | # min_cluster_size=min_cluster_size, 125 | # min_samples=min_samples, 126 | # core_dist_n_jobs=64) 127 | elif method == 'agglomerative': 128 | # knn_graph = kneighbors_graph(x, n_neighbors=10, include_self=False) 129 | model = AgglomerativeClustering( 130 | n_clusters=n_clusters, 131 | linkage='ward', compute_distances=True) 132 | else: 133 | raise ValueError(f'Method `{method}` not recognized') 134 | print(x.shape) 135 | labels = model.fit_predict(x) 136 | print(int(time() - t0), 'sec') 137 | print('n_clusters:', 
np.unique(labels).size) 138 | 139 | if sort: 140 | labels = sort_labels(labels)[0] 141 | 142 | labels_arr = np.full(mask.shape, labels.min()-1, dtype=int) 143 | labels_arr[mask] = labels 144 | 145 | # if method == 'gm': 146 | # probs = model.predict_proba(embs) 147 | # probs = probs[:, order] 148 | # assert (probs.argmax(-1) == labels).all() 149 | # probs_arr = np.full( 150 | # mask.shape+(n_clusters,), np.nan, dtype=np.float32) 151 | # probs_arr[mask] = probs 152 | # else: 153 | # probs_arr = None 154 | 155 | return labels_arr, model 156 | 157 | 158 | def prepare_for_clustering(embs, location_weight): 159 | mask = np.all([np.isfinite(c) for c in embs], axis=0) 160 | embs = np.stack([c[mask] for c in embs], axis=-1) 161 | 162 | if location_weight is None: 163 | x = embs 164 | else: 165 | embs -= embs.mean(0) 166 | embs /= embs.var(0).sum()**0.5 167 | # get spatial coordinates 168 | locs = np.meshgrid( 169 | *[np.arange(mask.shape[i]) for i in range(mask.ndim)], 170 | indexing='ij') 171 | locs = np.stack(locs, -1).astype(float) 172 | locs = locs[mask] 173 | locs -= locs.mean(0) 174 | locs /= locs.var(0).sum()**0.5 175 | 176 | # balance embeddings and coordinates 177 | embs *= 1 - location_weight 178 | locs *= location_weight 179 | x = np.concatenate([embs, locs], axis=-1) 180 | return x, mask 181 | 182 | 183 | def reduce_embs_dim(x, **kwargs): 184 | x = reduce_dim(np.stack(x, -1), **kwargs)[0] 185 | x = x.transpose(2, 0, 1) 186 | return x 187 | 188 | 189 | def cluster_hierarchical( 190 | x_major, method, n_clusters, 191 | x_minor=None, min_cluster_size=None, reduce_dimension=False, 192 | location_weight=None): 193 | if reduce_dimension: 194 | x_major = reduce_embs_dim(x_major, method='pca', n_components=0.99) 195 | if x_minor is not None: 196 | x_minor = reduce_embs_dim(x_minor, method='pca', n_components=0.99) 197 | 198 | # compute major clusters 199 | labels_cls, __ = cluster( 200 | x_major, method=method, 201 | n_clusters=n_clusters, 202 | location_weight=location_weight) 203 | if min_cluster_size is not None: 204 | labels_cls = relabel_small_connected( 205 | labels_cls, min_size=min_cluster_size) 206 | 207 | # cluster connected components 208 | labels_con = cluster_connected(labels_cls) 209 | 210 | # compute sub clusters 211 | if x_minor is not None: 212 | labels_sub, __ = cluster_sub( 213 | x_minor, 214 | labels=labels_cls, 215 | method=method, 216 | n_clusters=4, 217 | location_weight=None) 218 | labels = [labels_cls, labels_sub, labels_con] 219 | else: 220 | labels = [labels_cls, labels_con] 221 | 222 | # combine cluster labels hierarchically 223 | labels = np.stack(labels, -1) 224 | 225 | return labels 226 | 227 | 228 | def plot_scatter(x, y, lab, outfile): 229 | plt.figure(figsize=(8, 8)) 230 | plt.scatter(x, y, c=lab, cmap='tab10', alpha=0.2) 231 | plt.savefig(outfile, dpi=300, bbox_inches='tight') 232 | plt.close() 233 | print(outfile) 234 | 235 | 236 | def upscale_label(lab, target_shape): 237 | onehot = [lab == la for la in range(lab.max()+1)] 238 | prob = [ 239 | upscale( 240 | oh.astype(np.float32)[..., np.newaxis], 241 | target_shape)[..., 0] 242 | for oh in onehot] 243 | label = np.argmax(prob, 0) 244 | return label 245 | 246 | 247 | def cluster_rescale(x, stride, method, n_clusters): 248 | 249 | img_shape = x[0].shape 250 | isin = np.isfinite(x[0]) 251 | 252 | start = stride // 2 253 | x = x[:, start::stride, start::stride] 254 | 255 | lab, model = cluster(x, method=method, n_clusters=n_clusters) 256 | 257 | label = upscale_label(lab, img_shape) 258 | label[~isin] = -1 259 
|
260 |         return label, model
261 | 
262 | 
263 | def flatten_label(label):
264 |     img_shape = label.shape[:-1]
265 |     label = label.reshape(-1, label.shape[-1])
266 |     n_bins = label.max() + 1  # base must exceed the largest label value to keep the encoding injective
267 |     label = label[:, ::-1].T
268 |     label = np.sum([lab * n_bins**i for i, lab in enumerate(label)], 0)
269 |     label[label < 0] = -1
270 |     label = np.unique(label, return_inverse=True)[1] - 1
271 |     label = label.reshape(img_shape)
272 |     return label
273 | 
274 | 
275 | def cluster_mbkmagglo(x, n_clusters):
276 | 
277 |     img_shape = x[0].shape
278 |     isin = np.isfinite(x[0])
279 | 
280 |     n_clusters_small = n_clusters * 50
281 |     min_cluster_size = max(1, isin.sum() // n_clusters_small // 1000)
282 |     lab_small = cluster_hierarchical(
283 |         x, method='mbkm', n_clusters=n_clusters_small,
284 |         min_cluster_size=min_cluster_size, location_weight=0.1)
285 | 
286 |     lab_flat = flatten_label(lab_small)  # convert hierarchical label to 1D
287 | 
288 |     centroids = [
289 |         x[:, lab_flat == la].mean(1)
290 |         for la in range(lab_flat.max()+1)]
291 |     model = AgglomerativeClustering(
292 |         n_clusters=n_clusters,
293 |         linkage='ward', compute_distances=True)
294 |     print(np.shape(centroids))
295 |     t0 = time()
296 |     lab_cent = model.fit_predict(centroids)
297 |     print(int(time() - t0), 'sec')
298 |     lab_cent = sort_labels(lab_cent)[0]
299 | 
300 |     lab_super = lab_cent[lab_flat]
301 | 
302 |     lab_super = upscale_label(lab_super, img_shape)
303 |     lab_super[~isin] = -1
304 | 
305 |     return lab_super, model
306 | 
307 | 
308 | def smooth(x, filter_size):
309 |     x = x.transpose(1, 2, 0)
310 |     x = smoothen(x, filter_size)
311 |     x = x.transpose(2, 0, 1)
312 |     return x
313 | 
314 | 
315 | def cluster_and_save(
316 |         x, method, n_clusters, min_cluster_size=None, prefix=None):
317 | 
318 |     labels, __ = cluster(x, method=method, n_clusters=n_clusters)
319 | 
320 |     if min_cluster_size is not None:
321 |         labels = relabel_small_connected(
322 |             labels, min_size=min_cluster_size)
323 | 
324 |     if prefix is not None:
325 |         save_pickle(labels, prefix+'labels.pickle')
326 |         plot_labels(
327 |             labels, prefix+'labels.png',
328 |             white_background=True)
329 |         plot_label_masks(labels, prefix+'masks/')
330 | 
331 |     return labels
332 | 
333 | 
334 | def preprocess_and_cluster(
335 |         x, prefix=None,
336 |         n_components=None, filter_size=None,
337 |         n_clusters=None, min_cluster_size=None,
338 |         method='km'):
339 | 
340 |     if n_components is not None:
341 |         x = reduce_embs_dim(x, method='pca', n_components=n_components)
342 | 
343 |     if filter_size is not None:
344 |         t0 = time()
345 |         print('Smoothing embeddings...')
346 |         x = smooth(x, filter_size)
347 |         print(int(time() - t0), 'sec')
348 | 
349 |     if n_clusters is None:
350 |         n_clusters_list = [10, 20, 30, 40, 50, 60, 70]
351 |     elif np.size(n_clusters) > 1:
352 |         n_clusters_list = n_clusters
353 |     elif np.size(n_clusters) == 1:
354 |         n_clusters_list = [n_clusters]
355 | 
356 |     labels_list = []
357 |     for n_clusters in n_clusters_list:
358 |         if prefix is not None:
359 |             pref = prefix
360 |             if len(n_clusters_list) > 1:
361 |                 pref = f'{pref}nclusters{n_clusters:03d}/'
362 |         else:
363 |             pref = None
364 |         labels = cluster_and_save(
365 |             x, n_clusters=n_clusters,
366 |             min_cluster_size=min_cluster_size,
367 |             method=method,
368 |             prefix=pref)
369 |         labels_list.append(labels)
370 |     return labels_list
371 | 
372 | 
373 | def main():
374 | 
375 |     args = get_args()
376 | 
377 |     embs = load_pickle(args.embeddings)
378 | 
379 |     if isinstance(embs, dict):
380 |         if 'cls' in embs.keys():
381 |             x = embs['cls']
382 |         else:
383 |             x = embs['sub']
384 |         x = 
np.array(x) 385 | else: 386 | x = embs 387 | 388 | if args.mask is not None: 389 | mask = load_mask(args.mask) 390 | x[:, ~mask] = np.nan 391 | 392 | preprocess_and_cluster( 393 | x, 394 | n_components=args.n_components, 395 | filter_size=args.filter_size, 396 | n_clusters=args.n_clusters, 397 | min_cluster_size=args.min_cluster_size, 398 | method=args.method, 399 | prefix=args.prefix) 400 | 401 | 402 | if __name__ == '__main__': 403 | main() 404 | -------------------------------------------------------------------------------- /connected_components.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import label as label_connected 3 | 4 | from utils import sort_labels, get_most_frequent 5 | 6 | 7 | def get_largest_connected(labels): 8 | labels = label_connected(labels)[0] 9 | labels -= 1 10 | labels = sort_labels(labels)[0] 11 | labels = labels == 0 12 | return labels 13 | 14 | 15 | def get_adjacent(ind): 16 | # return the eight adjacent indices 17 | adj = np.meshgrid([-1, 0, 1], [-1, 0, 1], indexing='ij') 18 | adj = np.stack(adj, -1) 19 | adj = adj.reshape(-1, adj.shape[-1]) 20 | adj = adj[(adj != 0).any(-1)] 21 | adj += ind 22 | return adj 23 | 24 | 25 | def split_by_connected_size_single(labels, min_size): 26 | # return labels of small and large connected components 27 | # labels are binary 28 | labels = label_connected(labels)[0] 29 | labels -= 1 30 | labels = sort_labels(labels)[0] 31 | counts = np.unique(labels[labels >= 0], return_counts=True)[1] 32 | cut = np.sum(counts >= min_size) 33 | small = labels - cut 34 | small[small < 0] = -1 35 | large = labels.copy() 36 | large[labels >= cut] = -1 37 | return small, large 38 | 39 | 40 | def split_by_connected_size(labels, min_size): 41 | # return labels of small and large connected components 42 | # labels can be multi-categorical 43 | labs_uniq = np.unique(labels[labels >= 0]) 44 | small = np.full_like(labels, -1) 45 | large = np.full_like(labels, -1) 46 | for lab in labs_uniq: 47 | isin = labels == lab 48 | sma, lar = split_by_connected_size_single(isin, min_size) 49 | issma = sma >= 0 50 | islar = lar >= 0 51 | small[issma] = sma[issma] + small.max() + 1 52 | large[islar] = lar[islar] + large.max() + 1 53 | return small, large 54 | 55 | 56 | def relabel_small_connected(labels, min_size): 57 | # reassign labels of small connected components 58 | labels = labels.copy() 59 | small, __ = split_by_connected_size(labels, min_size) 60 | small = sort_labels(small, descending=False)[0] 61 | small_uniq = np.unique(small[small >= 0]) 62 | lab_na = min(-1, labels.min() - 1) 63 | for lab_small in small_uniq: 64 | 65 | isin = small == lab_small 66 | lab = labels[isin][0] 67 | 68 | # find adjacent labels 69 | indices = np.stack(np.where(isin), -1) 70 | labs_adj = [] 71 | labs_small_adj = [] 72 | for ind in indices: 73 | adj = get_adjacent(ind) 74 | is_within = np.logical_and( 75 | (adj < labels.shape).all(-1), 76 | (adj >= 0).all(-1)) 77 | adj[~is_within] = 0 # dummy index for out-of-bound 78 | la = labels[adj[:, 0], adj[:, 1]] 79 | lsa = small[adj[:, 0], adj[:, 1]] 80 | la[~is_within] = lab_na 81 | lsa[~is_within] = lab_na 82 | labs_adj.append(la) 83 | labs_small_adj.append(lsa) 84 | labs_adj = np.stack(labs_adj) 85 | labs_small_adj = np.stack(labs_small_adj) 86 | # eliminate background and identical labels 87 | is_other = (labs_adj >= 0) * (labs_adj != lab) 88 | if is_other.any(): 89 | # find most frequent adjacent labels 90 | lab_new = 
get_most_frequent(labs_adj[is_other]) 91 | # get location of new label 92 | i_new, i_adj_new = np.stack( 93 | np.where(labs_adj == lab_new), -1)[0] 94 | ind_new = get_adjacent(indices[i_new])[i_adj_new] 95 | # update small components 96 | lab_small_new = small[ind_new[0], ind_new[1]] 97 | else: 98 | lab_new = lab 99 | lab_small_new = lab_small 100 | # relabel to most frequent neighboring label 101 | labels[isin] = lab_new 102 | small[isin] = lab_small_new 103 | 104 | return labels 105 | 106 | 107 | def cluster_connected(labels): 108 | # subcluster labels by connectedness 109 | labels = labels.copy() 110 | isfg = labels >= 0 111 | labels_sub = np.full_like(labels, -1) 112 | labels_sub[~isfg] = labels[~isfg] 113 | 114 | labs_uniq = np.unique(labels[isfg]) 115 | for lab in labs_uniq: 116 | isin = labels == lab 117 | sublabs = label_connected(isin)[0] - 1 118 | sublabs = sort_labels(sublabs)[0] 119 | labels_sub[isin] = sublabs[isin] 120 | 121 | labels_sub[~isfg] = labels[~isfg] 122 | return labels_sub 123 | -------------------------------------------------------------------------------- /data/markers/cell-type-template.tsv: -------------------------------------------------------------------------------- 1 | gene label 2 | GENENAME1 cell_type_1 3 | GENENAME2 cell_type_1 4 | GENENAME3 cell_type_1 5 | GENENAME4 cell_type_2 6 | GENENAME5 cell_type_2 7 | -------------------------------------------------------------------------------- /data/markers/signature-score-template.txt: -------------------------------------------------------------------------------- 1 | GENENAME1 2 | GENENAME2 3 | GENENAME3 4 | GENENAME4 5 | -------------------------------------------------------------------------------- /differential.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from scipy.stats import norm as norm_rv 6 | 7 | from utils import load_tsv, save_tsv 8 | 9 | 10 | def aggregate(mean, variance, count, specificity=0.02, n_top=50): 11 | 12 | assert (mean.index == variance.index).all() 13 | assert (mean.index == count.index).all() 14 | assert (mean.columns == variance.columns).all() 15 | assert (mean.columns == count.columns).all() 16 | 17 | gene_names = mean.columns.to_numpy() 18 | cluster_names = mean.index.to_numpy() 19 | mean = mean.to_numpy() 20 | vari = variance.to_numpy() 21 | count = count.to_numpy() 22 | 23 | weight = count / count.sum(0) 24 | vari_global = (vari * weight).sum(0) 25 | mean /= vari_global**0.5 26 | vari /= vari_global + 1e-12 27 | 28 | n_clusters = len(cluster_names) 29 | results = dict( 30 | mean_interior=[], mean_exterior=[], 31 | variance_interior=[], variance_exterior=[], 32 | count_interior=[], count_exterior=[]) 33 | for i in range(n_clusters): 34 | isin = np.arange(n_clusters) == i 35 | mean_in = mean[isin].flatten() 36 | vari_in = vari[isin].flatten() 37 | count_in = count[isin].flatten() 38 | mean_out = (mean[~isin] * weight[~isin]).sum(0) 39 | vari_out = ( 40 | (mean[~isin]**2 * weight[~isin]).sum(0) 41 | - mean_out**2) 42 | count_out = count[~isin].sum(0) 43 | results['mean_interior'].append(mean_in) 44 | results['mean_exterior'].append(mean_out) 45 | results['variance_interior'].append(vari_in) 46 | results['variance_exterior'].append(vari_out) 47 | results['count_interior'].append(count_in) 48 | results['count_exterior'].append(count_out) 49 | 50 | for key in results.keys(): 51 | df = pd.DataFrame(np.array(results[key])) 52 | df.index = cluster_names 53 | df.index.name = 
'cluster' 54 | df.columns = gene_names 55 | results[key] = df 56 | 57 | return results 58 | 59 | 60 | def two_sample_test(mean, variance, count, mode): 61 | pval = mean[0].copy() 62 | pval[:] = np.nan 63 | mean = np.stack(mean) 64 | variance = np.stack(variance) 65 | count = np.stack(count) 66 | std = (variance / count).sum(0)**0.5 67 | z = (mean[1] - mean[0]) / (std + 1e-12) 68 | if mode == 'positive-sided': 69 | p = norm_rv.sf(z) 70 | elif mode == 'negative-sided': 71 | p = norm_rv.cdf(z) 72 | elif mode == 'two-sided': 73 | p = norm_rv.sf(np.abs(z)) * 2 74 | else: 75 | raise ValueError('mode not recognized') 76 | pval[:] = p 77 | return pval 78 | 79 | 80 | def save_results(results, prefix, sort_key=None): 81 | cluster_names = results[list(results.keys())[0]].index.to_list() 82 | for cname in cluster_names: 83 | rslt = pd.DataFrame( 84 | {key: val.loc[cname] for key, val in results.items()}) 85 | rslt.index.name = 'gene' 86 | if sort_key is not None: 87 | rslt = rslt.sort_values(sort_key, ascending=False) 88 | first_column = rslt.pop(sort_key) 89 | rslt.insert(0, sort_key, first_column) 90 | save_tsv(rslt, f'{prefix}contrast/by-clusters/cluster-{cname}.tsv') 91 | print( 92 | 'Top overexpressed genes for each cluster saved to ' 93 | f'{prefix}contrast/by-clusters/') 94 | for key, val in results.items(): 95 | save_tsv(val, f'{prefix}contrast/by-metrics/{key}.tsv') 96 | 97 | 98 | def main(): 99 | 100 | prefix = sys.argv[1] # e.g. 'data/her2st/B1/' 101 | prefix = f'{prefix}cnts-clustered/by-clusters/' 102 | 103 | x_mean = load_tsv(f'{prefix}mean.tsv') 104 | x_vari = load_tsv(f'{prefix}variance.tsv') 105 | x_count = load_tsv(f'{prefix}count.tsv') 106 | 107 | results = aggregate(x_mean, x_vari, x_count) 108 | results['fold_change'] = ( 109 | results['mean_interior'] / results['mean_exterior']) 110 | results['pvalue_positive_sided'] = two_sample_test( 111 | mean=( 112 | results['mean_exterior'], 113 | results['mean_interior']), 114 | variance=( 115 | results['variance_exterior'], 116 | results['variance_interior']), 117 | count=( 118 | results['count_exterior'], 119 | results['count_interior']), 120 | mode='positive-sided') 121 | results['variance_raw'] = x_vari 122 | save_results(results, prefix, sort_key='fold_change') 123 | 124 | 125 | if __name__ == '__main__': 126 | main() 127 | -------------------------------------------------------------------------------- /download_checkpoints.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | source_256="https://upenn.box.com/shared/static/p0hc12l1bpu5c7fzieotv1d6592btv1l.pth" 5 | source_4k="https://upenn.box.com/shared/static/8qayhxzmdjpcr5loi88xtkfbqomag8a9.pth" 6 | target_256="checkpoints/vit256_small_dino.pth" 7 | target_4k="checkpoints/vit4k_xs_dino.pth" 8 | 9 | mkdir -p checkpoints 10 | wget ${source_256} -O ${target_256} 11 | wget ${source_4k} -O ${target_4k} 12 | -------------------------------------------------------------------------------- /download_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | prefix=$1 5 | 6 | source_img="https://upenn.box.com/shared/static/yya0lvlur8aase29hvy630jd06r64tdn.jpg" 7 | source_cnts="https://upenn.box.com/shared/static/kaoo8j31dx5lupyz8dctay7p5x3exqsa.tsv" 8 | source_locs="https://upenn.box.com/shared/static/7nbnorlr2h6tkeyghjibqitztezkadwh.tsv" 9 | source_radius="https://upenn.box.com/shared/static/a8655bmb02q9cqegndnwhcb5r0mqqphi.txt" 10 | 
source_pixsize="https://upenn.box.com/shared/static/1stmq5ly6iqnljt0uq8rotlki5q8sjfs.txt" 11 | 12 | target_img="${prefix}he-raw.jpg" 13 | target_cnts="${prefix}cnts.tsv" 14 | target_locs="${prefix}locs-raw.tsv" 15 | target_radius="${prefix}radius-raw.txt" 16 | target_pixsize="${prefix}pixel-size-raw.txt" 17 | 18 | mkdir -p `dirname $target_img` 19 | wget ${source_img} -O ${target_img} 20 | wget ${source_cnts} -O ${target_cnts} 21 | wget ${source_locs} -O ${target_locs} 22 | wget ${source_radius} -O ${target_radius} 23 | wget ${source_pixsize} -O ${target_pixsize} 24 | -------------------------------------------------------------------------------- /enrich.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | # from matplotlib.colors import LogNorm 7 | import seaborn as sns 8 | 9 | from utils import load_pickle, read_lines, save_tsv 10 | 11 | 12 | def get_data(prefix0, prefix1): 13 | labels0 = load_pickle(prefix0+'labels.pickle') 14 | labels1 = load_pickle(prefix1+'labels.pickle') 15 | labels1_names = read_lines(prefix1+'label-names.txt') 16 | return labels0, labels1, labels1_names 17 | 18 | 19 | def get_probs(labels0, labels1): 20 | nlabs0 = labels0.max() + 1 21 | nlabs1 = labels1.max() + 1 22 | results = np.full((nlabs0, nlabs1), np.nan) 23 | for l0 in range(nlabs0): 24 | for l1 in range(nlabs1): 25 | m0 = labels0 == l0 26 | m1 = labels1 == l1 27 | results[l0, l1] = (m0 * m1).sum() 28 | results /= np.nansum(results) 29 | return results 30 | 31 | 32 | def plot_probs(df, filename, cmap='tab10'): 33 | 34 | font = {'size': 15} 35 | plt.rc('font', **font) 36 | 37 | cmap = plt.get_cmap(cmap) 38 | color = [cmap(i) for i in range(df.shape[1])] 39 | df = df / np.nansum(df, 1, keepdims=True) 40 | df.plot(kind='bar', stacked=True, color=color) 41 | plt.xlabel('Cluster') 42 | plt.ylabel('Cell type proportion') 43 | plt.legend(loc='center left', bbox_to_anchor=(1, 0.5)) 44 | plt.savefig(filename, dpi=300, bbox_inches='tight') 45 | plt.close() 46 | print(filename) 47 | 48 | 49 | def plot_results(df, filename, cmap='tab10'): 50 | 51 | font = {'size': 15} 52 | plt.rc('font', **font) 53 | 54 | cmap = plt.get_cmap(cmap) 55 | color = [cmap(i) for i in range(df.shape[1])] 56 | df.plot(kind='bar', stacked=True, color=color) 57 | plt.xlabel('Cluster') 58 | plt.ylabel('Cell type proportion') 59 | plt.legend(loc='center left', bbox_to_anchor=(1, 0.5)) 60 | plt.savefig(filename, dpi=300, bbox_inches='tight') 61 | plt.close() 62 | print(filename) 63 | 64 | 65 | def probs_to_oddsratios(df): 66 | x = df.to_numpy() 67 | oddsratios = np.full_like(x, np.nan) 68 | for i in range(x.shape[0]): 69 | for j in range(x.shape[1]): 70 | not_i = [k for k in range(x.shape[0]) if k != i] 71 | not_j = [k for k in range(x.shape[1]) if k != j] 72 | a = x[i, j] * x[not_i][:, not_j].sum() 73 | b = x[i, not_j].sum() * x[not_i, j].sum() 74 | oddsratios[i, j] = a / b 75 | oddsratios = pd.DataFrame(oddsratios) 76 | oddsratios.columns = df.columns 77 | oddsratios.index = df.index 78 | return oddsratios 79 | 80 | 81 | def plot_enrichment(df, filename): 82 | sns.heatmap( 83 | df, cmap='magma', 84 | annot=True, fmt='.1f', annot_kws={'fontsize': 12}, 85 | square=True, linewidth=0.5) 86 | 87 | # # set x-ticks on top 88 | # ax.set(xlabel='', ylabel='') 89 | # ax.xaxis.tick_top() 90 | 91 | plt.savefig(filename, dpi=300, bbox_inches='tight') 92 | plt.close() 93 | print(filename) 94 | 95 | 96 | def 
process_oddsratios(df): 97 | df = df.drop(columns='Unclassified') 98 | x = df.to_numpy() 99 | threshold = 2**0.05 # avoid OR == 0.0 after rounding 100 | x[x < threshold] = np.nan 101 | x = np.log2(x) 102 | df[:] = x 103 | df = df.T 104 | return df 105 | 106 | 107 | def main(): 108 | prefix0 = sys.argv[1] # e.g. 'data/her2st/H123/clusters-gene/' 109 | prefix1 = sys.argv[2] # e.g. 'data/her2st/H123/markers/celltype/' 110 | labels0, labels1, labels1_names = get_data(prefix0, prefix1) 111 | 112 | probs = get_probs(labels0, labels1) 113 | probs = pd.DataFrame(probs) 114 | probs.columns = labels1_names 115 | plot_probs( 116 | probs, 117 | cmap='tab10', 118 | filename=prefix0+'proportions.png') 119 | plot_probs( 120 | probs, 121 | cmap='Set3', 122 | filename=prefix0+'proportions-altcmap.png') 123 | 124 | oddsratios = probs_to_oddsratios(probs) 125 | oddsratios = process_oddsratios(oddsratios) 126 | save_tsv( 127 | oddsratios, prefix0+'enrichment.csv', 128 | sep=',', na_rep='NA') 129 | plot_enrichment(oddsratios, prefix0+'enrichment.png') 130 | 131 | 132 | if __name__ == '__main__': 133 | main() 134 | -------------------------------------------------------------------------------- /extract_features.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import time 3 | import argparse 4 | 5 | from einops import rearrange, reduce, repeat 6 | import numpy as np 7 | import skimage 8 | import torch 9 | 10 | from utils import load_image 11 | from hipt_model_utils import eval_transforms 12 | from hipt_4k import HIPT_4K 13 | from utils import load_pickle, save_pickle, join 14 | from image import upscale, smoothen 15 | # from distill import distill_embeddings 16 | from connected_components import get_largest_connected 17 | from reduce_dim import reduce_dim 18 | 19 | 20 | def load_mask(filename): 21 | mask = load_image(filename) 22 | mask = mask > 0 23 | if mask.ndim == 3: 24 | mask = mask.any(2) 25 | factor = 16 26 | mask = reduce( 27 | mask.astype(np.float32), 28 | '(h0 h1) (w0 w1) -> h0 w0', 'mean', 29 | h1=factor, w1=factor) > 0.5 30 | return mask 31 | 32 | 33 | def match_foregrounds(embs, largest_only=False): 34 | print('Matching foregrounds...') 35 | t0 = time() 36 | channels = np.concatenate(list(embs.values())) 37 | mask = np.isfinite(channels).all(0) 38 | if largest_only: 39 | mask = get_largest_connected(mask) 40 | for group, channels in embs.items(): 41 | for chan in channels: 42 | chan[~mask] = np.nan 43 | print(int(time() - t0), 'sec') 44 | 45 | 46 | def patchify(x, patch_size): 47 | shape_ori = np.array(x.shape[:2]) 48 | shape_ext = ( 49 | (shape_ori + patch_size - 1) 50 | // patch_size * patch_size) 51 | x = np.pad( 52 | x, 53 | ( 54 | (0, shape_ext[0] - x.shape[0]), 55 | (0, shape_ext[1] - x.shape[1]), 56 | (0, 0)), 57 | mode='edge') 58 | tiles_shape = np.array(x.shape[:2]) // patch_size 59 | # x = rearrange( 60 | # x, '(h1 h) (w1 w) c -> h1 w1 h w c', 61 | # h=patch_size, w=patch_size) 62 | # x = rearrange( 63 | # x, '(h1 h) (w1 w) c -> (h1 w1) h w c', 64 | # h=patch_size, w=patch_size) 65 | tiles = [] 66 | for i0 in range(tiles_shape[0]): 67 | a0 = i0 * patch_size # TODO: change to patch_size[0] 68 | b0 = a0 + patch_size # TODO: change to patch_size[0] 69 | for i1 in range(tiles_shape[1]): 70 | a1 = i1 * patch_size # TODO: change to patch_size[1] 71 | b1 = a1 + patch_size # TODO: change to patch_size[1] 72 | tiles.append(x[a0:b0, a1:b1]) 73 | 74 | shapes = dict( 75 | original=shape_ori, 76 | padded=shape_ext, 77 | tiles=tiles_shape) 
78 |     return tiles, shapes
79 | 
80 | 
81 | def get_data(prefix):
82 |     img = load_image(f'{prefix}he.jpg')
83 |     return img
84 | 
85 | 
86 | def get_embeddings_sub(model, x):
87 |     x = x.astype(np.float32) / 255.0
88 |     x = eval_transforms()(x)
89 |     x_cls, x_sub = model.forward_all256(x[None])
90 |     x_cls = x_cls.cpu().detach().numpy()
91 |     x_sub = x_sub.cpu().detach().numpy()
92 |     x_cls = x_cls[0].transpose(1, 2, 0)
93 |     x_sub = x_sub[0].transpose(1, 2, 3, 4, 0)
94 |     return x_cls, x_sub
95 | 
96 | 
97 | def get_embeddings_cls(model, x):
98 |     x = torch.tensor(x.transpose(2, 0, 1))
99 |     with torch.no_grad():
100 |         __, x_sub4k = model.forward_all4k(x[None])
101 |     x_sub4k = x_sub4k.cpu().detach().numpy()
102 |     x_sub4k = x_sub4k[0].transpose(1, 2, 0)
103 |     return x_sub4k
104 | 
105 | 
106 | def get_embeddings(img, pretrained=True, device='cuda'):
107 |     '''
108 |     Extract hierarchical embeddings from a histology image
109 |     Args:
110 |         img: Histology image.
111 |             Shape: (H, W, 3), in RGB format.
112 |             Padded internally into (4096 x 4096)-pixel tiles.
113 |     Returns:
114 |         chans_cls: ViT-4K embeddings of (256 x 256)-pixel patches,
115 |             broadcast to the (16 x 16)-pixel subpatch grid.
116 |             List of 192 channels, each of shape (H/16, W/16).
117 |         chans_sub: ViT-256 embeddings of (16 x 16)-pixel subpatches.
118 |             List of 384 channels, each of shape (H/16, W/16).
119 |     '''
120 |     print('Extracting embeddings...')
121 |     t0 = time()
122 | 
123 |     tile_size = 4096
124 |     tiles, shapes = patchify(img, patch_size=tile_size)
125 | 
126 |     model256_path, model4k_path = None, None
127 |     if pretrained:
128 |         model256_path = 'checkpoints/vit256_small_dino.pth'
129 |         model4k_path = 'checkpoints/vit4k_xs_dino.pth'
130 |     model = HIPT_4K(
131 |         model256_path=model256_path,
132 |         model4k_path=model4k_path,
133 |         device256=device, device4k=device)
134 |     model.eval()
135 |     patch_size = (256, 256)
136 |     subpatch_size = (16, 16)
137 |     n_subpatches = tuple(
138 |         a // b for a, b in zip(patch_size, subpatch_size))
139 | 
140 |     emb_sub = []
141 |     emb_mid = []
142 |     for i in range(len(tiles)):
143 |         if i % 10 == 0:
144 |             print('tile', i, '/', len(tiles))
145 |         x_mid, x_sub = get_embeddings_sub(model, tiles[i])
146 |         emb_mid.append(x_mid)
147 |         emb_sub.append(x_sub)
148 |     del tiles
149 |     torch.cuda.empty_cache()
150 |     emb_mid = rearrange(
151 |         emb_mid, '(h1 w1) h2 w2 k -> (h1 h2) (w1 w2) k',
152 |         h1=shapes['tiles'][0], w1=shapes['tiles'][1])
153 | 
154 |     emb_cls = get_embeddings_cls(model, emb_mid)
155 |     del emb_mid, model
156 |     torch.cuda.empty_cache()
157 | 
158 |     shape_orig = np.array(shapes['original']) // subpatch_size
159 | 
160 |     chans_sub = []
161 |     for i in range(emb_sub[0].shape[-1]):
162 |         chan = rearrange(
163 |             np.array([e[..., i] for e in emb_sub]),
164 |             '(h1 w1) h2 w2 h3 w3 -> (h1 h2 h3) (w1 w2 w3)',
165 |             h1=shapes['tiles'][0], w1=shapes['tiles'][1])
166 |         chan = chan[:shape_orig[0], :shape_orig[1]]
167 |         chans_sub.append(chan)
168 |     del emb_sub
169 | 
170 |     chans_cls = []
171 |     for i in range(emb_cls[0].shape[-1]):
172 |         chan = repeat(
173 |             np.array([e[..., i] for e in emb_cls]),
174 |             'h12 w12 -> (h12 h3) (w12 w3)',
175 |             h3=n_subpatches[0], w3=n_subpatches[1])
176 |         chan = chan[:shape_orig[0], :shape_orig[1]]
177 |         chans_cls.append(chan)
178 |     del emb_cls
179 | 
180 |     print(int(time() - t0), 'sec')
181 | 
182 |     return chans_cls, chans_sub
183 | 
184 | 
185 | def get_embeddings_shift(
186 |         img, margin=256, stride=64,
187 |         pretrained=True, device='cuda'):
188 |     # margin: margin for shifting. Divisible by 256
189 |     # stride: stride for shifting. Divides `margin`.
190 |     factor = 16 # scaling factor between cls and sub. 
Fixed 191 | shape_emb = np.array(img.shape[:2]) // factor 192 | chans_cls = [ 193 | np.zeros(shape_emb, dtype=np.float32) 194 | for __ in range(192)] 195 | chans_sub = [ 196 | np.zeros(shape_emb, dtype=np.float32) 197 | for __ in range(384)] 198 | start_list = list(range(0, margin, stride)) 199 | n_reps = 0 200 | for start0 in start_list: 201 | for start1 in start_list: 202 | print(f'shift {start0}/{margin}, {start1}/{margin}') 203 | t0 = time() 204 | stop0, stop1 = -margin+start0, -margin+start1 205 | im = img[start0:stop0, start1:stop1] 206 | cls, sub = get_embeddings( 207 | im, pretrained=pretrained, device=device) 208 | del im 209 | sta0, sta1 = start0 // factor, start1 // factor 210 | sto0, sto1 = stop0 // factor, stop1 // factor 211 | for i in range(len(chans_cls)): 212 | chans_cls[i][sta0:sto0, sta1:sto1] += cls[i] 213 | del cls 214 | for i in range(len(chans_sub)): 215 | chans_sub[i][sta0:sto0, sta1:sto1] += sub[i] 216 | del sub 217 | n_reps += 1 218 | print(int(time() - t0), 'sec') 219 | 220 | mar = margin // factor 221 | for chan in chans_cls: 222 | chan /= n_reps 223 | chan[-mar:] = 0.0 224 | chan[:, -mar:] = 0.0 225 | for chan in chans_sub: 226 | chan /= n_reps 227 | chan[-mar:] = 0.0 228 | chan[:, -mar:] = 0.0 229 | 230 | return chans_cls, chans_sub 231 | 232 | 233 | def reshape_embeddings(emb_cls, emb_sub, tiles_shape): 234 | # emb_cls = emb_cls.reshape(tiles_shape + emb_cls.shape[1:]) 235 | # emb_sub = emb_sub.reshape(tiles_shape + emb_sub.shape[1:]) 236 | emb_cls = rearrange( 237 | emb_cls, '(h1 w1) h2 w2 k -> (h1 h2) (w1 w2) k', 238 | h1=tiles_shape[0], w1=tiles_shape[1]) 239 | # emb_sub = rearrange( 240 | # emb_sub, 'h1 w1 h2 w2 h3 w3 k -> (h1 h2 h3) (w1 w2 w3) k') 241 | return emb_cls, emb_sub 242 | 243 | 244 | def transpose_channels(x): 245 | return [x[..., i] for i in range(x.shape[-1])] 246 | 247 | 248 | def transpose_embeddings(embs, groups=None): 249 | if groups is None: 250 | groups = embs.keys() 251 | out = {} 252 | for key, chans in embs.items(): 253 | if key in groups: 254 | out[key] = transpose_channels(chans) 255 | else: 256 | out[key] = chans 257 | return out 258 | 259 | 260 | def match_resolutions(embs, target_shape, groups=None): 261 | if groups is None: 262 | groups = embs.keys() 263 | out = {} 264 | for grp, em in embs.items(): 265 | if grp in groups: 266 | print(f'Matching {grp} embedding resolutions...') 267 | t0 = time() 268 | em = [ 269 | upscale(im[..., np.newaxis], target_shape)[..., 0] 270 | for im in em] 271 | print(int(time() - t0), 'sec') 272 | out[grp] = em 273 | 274 | return out 275 | 276 | 277 | def combine_embs(embs): 278 | embs_new = {} 279 | for key, channels in embs.items(): 280 | channels = [c - np.nanmean(c) for c in channels] 281 | variances = [np.nanmean(c**2) for c in channels] 282 | std = np.sum(variances)**0.5 283 | channels = [c / std for c in channels] 284 | embs_new[key] = channels 285 | embs_new = join(list(embs_new.values())) 286 | return embs_new 287 | 288 | 289 | def rearrange_slide(tiles, shape): 290 | tiles = rearrange( 291 | tiles, '(h1 w1) h w c -> (h1 h) (w1 w) c', 292 | h1=shape[0], w1=shape[1]) 293 | return tiles 294 | 295 | 296 | def downscale(x, factors): 297 | x = reduce( 298 | x, '(h1 h) (w1 w) c -> h1 w1 c', 'mean', 299 | h=factors[0], w=factors[1]) 300 | return x 301 | 302 | 303 | def downscale_embedding(emb_dict, factor, groups=None): 304 | if groups is None: 305 | groups = emb_dict.keys() 306 | print('Downscaling slides...') 307 | t0 = time() 308 | factor = (factor, factor) 309 | y = {} 310 | for key, 
channel_list in emb_dict.items(): 311 | if key in groups: 312 | channel_list_new = [ 313 | downscale(channel[..., np.newaxis], factor)[..., 0] 314 | for channel in channel_list] 315 | else: 316 | channel_list_new = channel_list 317 | y[key] = channel_list_new 318 | print(int(time() - t0), 'sec') 319 | return y 320 | 321 | 322 | def save_embeddings(x, outfile): 323 | print('Saving embeddings...') 324 | t0 = time() 325 | save_pickle(x, outfile) 326 | print(int(time() - t0), 'sec') 327 | print('Embeddings saved to', outfile) 328 | 329 | 330 | def reduce_embs_dim( 331 | embs, n_components, method='pca', balance=False, 332 | groups=None): 333 | print(f'Reducing dimension of embeddings using {method}...') 334 | 335 | if groups is None: 336 | groups = embs.keys() 337 | 338 | embs_dict = {} 339 | models_dict = {} 340 | for grp, em in embs.items(): 341 | if grp in groups: 342 | t0 = time() 343 | em, mod = reduce_dim( 344 | em, n_components=n_components, method=method) 345 | else: 346 | mod = None 347 | embs_dict[grp] = em 348 | models_dict[grp] = mod 349 | print('runtime:', int(time() - t0), 'sec') 350 | 351 | return embs_dict, models_dict 352 | 353 | 354 | def get_args(): 355 | parser = argparse.ArgumentParser() 356 | parser.add_argument('prefix', type=str) 357 | parser.add_argument('--device', type=str, default='cuda') 358 | parser.add_argument('--reduction-method', type=str, default=None) 359 | parser.add_argument('--n-components', type=float, default=None) 360 | parser.add_argument('--smoothen-method', type=str, default='cv') 361 | parser.add_argument('--random-weights', action='store_true') 362 | parser.add_argument('--use-cache', action='store_true') 363 | parser.add_argument('--no-shift', action='store_true') 364 | parser.add_argument('--plot', action='store_true') 365 | args = parser.parse_args() 366 | return args 367 | 368 | 369 | # TODO: try more sophisticated methods in HistomicsTK 370 | def color_deconvolution(x): 371 | mask = np.isfinite(x) 372 | x[~mask] = 0.0 373 | x = (x * 255).astype(np.uint8) 374 | x = skimage.color.rgb2hed(x) 375 | x[~mask] = np.nan 376 | return x 377 | 378 | 379 | def recolor(tiles): 380 | h1, w1 = tiles.shape[:2] # number of tiles 381 | h2, w2 = 16, 16 # number of patches 382 | 383 | tiles = rearrange( 384 | tiles, 385 | 'h1 w1 (h2 h) (w2 w) c -> ' 386 | '(h1 w1 h2 w2) h w c', 387 | h2=h2, w2=w2) 388 | tiles = [color_deconvolution(t) for t in tiles] 389 | tiles = rearrange( 390 | tiles, 391 | '(h1 w1 h2 w2) (h w) c ->' 392 | 'h1 w1 (h2 h) (w2 w) c', 393 | h1=h1, w1=w1, h2=h2, w2=w2) 394 | return tiles 395 | 396 | 397 | def smoothen_embeddings( 398 | embs, size, kernel, 399 | method='cv', groups=None, device='cuda'): 400 | if groups is None: 401 | groups = embs.keys() 402 | out = {} 403 | for grp, em in embs.items(): 404 | if grp in groups: 405 | if isinstance(em, list): 406 | smoothened = [ 407 | smoothen( 408 | c[..., np.newaxis], size=size, 409 | kernel=kernel, backend=method, 410 | device=device)[..., 0] 411 | for c in em] 412 | else: 413 | smoothened = smoothen(em, size, method, device=device) 414 | else: 415 | smoothened = em 416 | out[grp] = smoothened 417 | return out 418 | 419 | 420 | def adjust_weights(embs, weights=None): 421 | print('Adjusting weights...') 422 | t0 = time() 423 | if weights is None: 424 | weights = {grp: 1.0 for grp in embs.keys()} 425 | for grp in embs.keys(): 426 | channels = embs[grp] 427 | wt = weights[grp] 428 | means = np.array([np.nanmean(chan) for chan in channels]) 429 | std = np.sum([np.nanvar(chan) for chan in 
channels])**0.5 430 | for chan, me in zip(channels, means): 431 | chan[:] -= me 432 | chan[:] /= std 433 | chan[:] *= wt**0.5 434 | print(int(time() - t0), 'sec') 435 | 436 | 437 | def quantize(x, labels, hardness=0.5): 438 | y = np.full_like(x, np.nan) 439 | for lab in np.unique(labels): 440 | isin = lab == labels 441 | y[isin] = x[isin].mean(0) * hardness + x[isin] * (1 - hardness) 442 | return y 443 | 444 | 445 | def main(): 446 | args = get_args() 447 | 448 | np.random.seed(0) 449 | torch.manual_seed(0) 450 | 451 | # load data 452 | wsi = get_data(prefix=args.prefix) 453 | 454 | if args.use_cache: 455 | cache_file = args.prefix + 'embeddings-hist-raw.pickle' 456 | if args.use_cache and os.path.exists(cache_file): 457 | embs = load_pickle(cache_file) 458 | else: 459 | # extract HIPT embeddings 460 | if not args.no_shift: 461 | emb_cls, emb_sub = get_embeddings_shift( 462 | wsi, pretrained=(not args.random_weights), 463 | device=args.device) 464 | else: 465 | emb_cls, emb_sub = get_embeddings( 466 | wsi, pretrained=(not args.random_weights), 467 | device=args.device) 468 | embs = dict(cls=emb_cls, sub=emb_sub) 469 | if args.use_cache: 470 | save_embeddings(embs, cache_file) 471 | 472 | embs['rgb'] = np.stack([ 473 | reduce( 474 | wsi[..., i].astype(np.float16) / 255.0, 475 | '(h1 h) (w1 w) -> h1 w1', 'mean', 476 | h=16, w=16).astype(np.float32) 477 | for i in range(3)]) 478 | del wsi 479 | 480 | # smoothen embeddings 481 | if args.smoothen_method is not None: 482 | print('Smoothening cls embeddings...') 483 | t0 = time() 484 | embs = smoothen_embeddings( 485 | embs, size=16, kernel='uniform', groups=['cls'], 486 | method=args.smoothen_method, 487 | device=args.device) 488 | print('runtime:', int(time()-t0)) 489 | 490 | print('Smoothening sub embeddings...') 491 | t0 = time() 492 | embs = smoothen_embeddings( 493 | embs, size=4, kernel='uniform', groups=['sub'], 494 | method=args.smoothen_method, 495 | device=args.device) 496 | print('runtime:', int(time()-t0)) 497 | 498 | # reduce embedding dimension 499 | if args.reduction_method is not None: 500 | embs, reducers = reduce_embs_dim( 501 | embs, n_components=args.n_components, 502 | method=args.reduction_method, balance=False, 503 | groups=['cls', 'sub']) 504 | save_pickle(reducers, args.prefix+'reducers.pickle') 505 | 506 | save_embeddings(embs, args.prefix + 'embeddings-hist.pickle') 507 | 508 | 509 | if __name__ == '__main__': 510 | main() 511 | -------------------------------------------------------------------------------- /get_mask.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | 5 | from utils import save_image, load_pickle 6 | from cluster import cluster 7 | from connected_components import relabel_small_connected 8 | from image import crop_image 9 | 10 | 11 | def remove_margins(embs, mar): 12 | for ke, va in embs.items(): 13 | embs[ke] = [ 14 | v[mar[0][0]:-mar[0][1], mar[1][0]:-mar[1][1]] 15 | for v in va] 16 | 17 | 18 | def get_mask_embeddings(embs, mar=16, min_connected=4000): 19 | 20 | n_clusters = 2 21 | 22 | # remove margins to avoid border effects 23 | remove_margins(embs, ((mar, mar), (mar, mar))) 24 | 25 | # get features 26 | x = np.concatenate(list(embs.values())) 27 | 28 | # segment image 29 | labels, __ = cluster(x, n_clusters=n_clusters, method='km') 30 | labels = relabel_small_connected(labels, min_size=min_connected) 31 | 32 | # select cluster for foreground 33 | rgb = np.stack(embs['rgb'], -1) 34 | i_foreground = np.argmax([ 35 | 
rgb[labels == i].std() for i in range(n_clusters)]) 36 | mask = labels == i_foreground 37 | 38 | # restore margins 39 | extent = [(-mar, s+mar) for s in mask.shape] 40 | mask = crop_image( 41 | mask, extent, 42 | mode='constant', constant_values=mask.min()) 43 | 44 | return mask 45 | 46 | 47 | def main(): 48 | 49 | inpfile = sys.argv[1] 50 | outfile = sys.argv[2] 51 | 52 | embs = load_pickle(inpfile) 53 | mask = get_mask_embeddings(embs) 54 | save_image(mask, outfile) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /hipt_4k.py: -------------------------------------------------------------------------------- 1 | # LinAlg / Stats / Plotting Dependencies 2 | from PIL import Image 3 | from einops import rearrange 4 | 5 | # Torch Dependencies 6 | import torch 7 | import torch.multiprocessing 8 | from torchvision import transforms 9 | 10 | # Local Dependencies 11 | # from hipt_heatmap_utils import * 12 | from hipt_model_utils import ( 13 | get_vit256, 14 | get_vit4k, 15 | tensorbatch2im, 16 | eval_transforms, 17 | ) 18 | 19 | 20 | Image.MAX_IMAGE_PIXELS = None 21 | torch.multiprocessing.set_sharing_strategy("file_system") 22 | 23 | 24 | class HIPT_4K(torch.nn.Module): 25 | """ 26 | HIPT Model (ViT-4K) for encoding non-square images (with [256 x 256] patch tokens), with 27 | [256 x 256] patch tokens encoded via ViT-256 using [16 x 16] patch tokens. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | model256_path=None, 33 | model4k_path=None, 34 | device256=torch.device("cuda:0"), 35 | device4k=torch.device("cuda:0"), 36 | ): 37 | super().__init__() 38 | self.model256 = get_vit256(pretrained_weights=model256_path).to( 39 | device256 40 | ) 41 | self.model4k = get_vit4k(pretrained_weights=model4k_path).to(device4k) 42 | self.device256 = device256 43 | self.device4k = device4k 44 | 45 | def forward(self, x): 46 | return self.forward_all(x)[0] 47 | 48 | def forward_all(self, x): 49 | """ 50 | Forward pass of HIPT (given an image tensor x), outputting the [CLS] token from ViT-4K. 51 | 1. x is center-cropped such that the W / H is divisible by the patch token size in ViT-4K (e.g. - 256 x 256). 52 | 2. x then gets unfolded into a "batch" of [256 x 256] images. 53 | 3. A pretrained ViT-256 model extracts the CLS token from each [256 x 256] image in the batch. 54 | 4. These batch-of-features are then reshaped into a 2D feature grid (of width "w_256" and height "h_256".) 55 | 5. This feature grid is then used as the input to ViT-4K, outputting [CLS]_4K. 56 | 57 | Args: 58 | - x (torch.Tensor): [1 x C x W' x H'] image tensor. 59 | 60 | Return: 61 | - features_cls4k (torch.Tensor): [1 x 192] cls token (d_4k = 192 by default). 62 | """ 63 | features_cls256, features_sub256 = self.forward_all256(x) 64 | features_cls4k, features_sub4k = self.forward_all4k(features_cls256) 65 | 66 | return features_cls4k, features_sub4k, features_sub256 67 | 68 | def forward_all256(self, x): 69 | batch_256, w_256, h_256 = self.prepare_img_tensor( 70 | x 71 | ) # 1. [1 x 3 x W x H] 72 | batch_256 = batch_256.unfold(2, 256, 256).unfold( 73 | 3, 256, 256 74 | ) # 2. [1 x 3 x w_256 x h_256 x 256 x 256] 75 | batch_256 = rearrange( 76 | batch_256, "b c p1 p2 w h -> (b p1 p2) c w h" 77 | ) # 2. [B x 3 x 256 x 256], where B = (1*w_256*h_256) 78 | 79 | features_cls256 = [] 80 | features_sub256 = [] 81 | for mini_bs in range( 82 | 0, batch_256.shape[0], 256 83 | ): # 3. B may be too large for ViT-256. We further take minibatches of 256. 
84 | minibatch_256 = batch_256[mini_bs:mini_bs + 256].to( 85 | self.device256, non_blocking=True 86 | ) 87 | fea_all256 = self.model256.forward_all(minibatch_256).cpu() 88 | fea_cls256 = fea_all256[:, 0] 89 | fea_sub256 = fea_all256[:, 1:] 90 | features_cls256.append( 91 | fea_cls256 92 | ) # 3. Extracting ViT-256 features from [256 x 3 x 256 x 256] image batches. 93 | features_sub256.append(fea_sub256) 94 | 95 | features_cls256 = torch.vstack( 96 | features_cls256 97 | ) # 3. [B x 384], where 384 == dim of ViT-256 [ClS] token. 98 | features_sub256 = torch.vstack(features_sub256) 99 | features_cls256 = ( 100 | features_cls256.reshape(w_256, h_256, 384) 101 | .transpose(0, 1) 102 | .transpose(0, 2) 103 | .unsqueeze(dim=0) 104 | ) # [1 x 384 x w_256 x h_256] 105 | features_sub256 = ( 106 | features_sub256.reshape(w_256, h_256, 16, 16, 384) 107 | .permute(4, 0, 1, 2, 3) 108 | .unsqueeze(dim=0) 109 | ) # [1 x 384 x w_256 x h_256 x 16 x 16] 110 | return features_cls256, features_sub256 111 | 112 | def forward_all4k(self, features_cls256): 113 | __, __, w_256, h_256 = features_cls256.shape 114 | features_cls256 = features_cls256.to(self.device4k, non_blocking=True) 115 | features_all4k = self.model4k.forward_all(features_cls256) 116 | # attn_all4k = self.model4k.get_last_selfattention(features_cls256) 117 | features_cls4k = features_all4k[ 118 | :, 0 119 | ] # 5. [1 x 192], where 192 == dim of ViT-4K [ClS] token. 120 | features_sub4k = features_all4k[:, 1:] 121 | features_sub4k = features_sub4k.reshape(1, w_256, h_256, 192).permute( 122 | 0, 3, 1, 2 123 | ) 124 | return features_cls4k, features_sub4k 125 | 126 | def forward_asset_dict(self, x: torch.Tensor): 127 | """ 128 | Forward pass of HIPT (given an image tensor x), with certain intermediate representations saved in 129 | a dictionary (that is to be stored in a H5 file). See walkthrough of how the model works above. 130 | 131 | Args: 132 | - x (torch.Tensor): [1 x C x W' x H'] image tensor. 133 | 134 | Return: 135 | - asset_dict (dict): Dictionary of intermediate feature representations of HIPT and other metadata. 136 | - features_cls256 (np.array): [B x 384] extracted ViT-256 cls tokens 137 | - features_mean256 (np.array): [1 x 384] mean ViT-256 cls token (exluding non-tissue patches) 138 | - features_4k (np.array): [1 x 192] extracted ViT-4K cls token. 
139 |                 - features_mean256_cls4k (np.array): [1 x 576] feature vector (concatenating mean ViT-256 + ViT-4K cls tokens)
140 | 
141 |         """
142 |         batch_256, w_256, h_256 = self.prepare_img_tensor(x)
143 |         batch_256 = batch_256.unfold(2, 256, 256).unfold(3, 256, 256)
144 |         batch_256 = rearrange(batch_256, "b c p1 p2 w h -> (b p1 p2) c w h")
145 | 
146 |         features_cls256 = []
147 |         for mini_bs in range(0, batch_256.shape[0], 256):
148 |             minibatch_256 = batch_256[mini_bs:mini_bs + 256].to(
149 |                 self.device256, non_blocking=True
150 |             )
151 |             features_cls256.append(self.model256(minibatch_256).detach().cpu())
152 | 
153 |         features_cls256 = torch.vstack(features_cls256)
154 |         features_mean256 = features_cls256.mean(dim=0).unsqueeze(dim=0)
155 | 
156 |         features_grid256 = (
157 |             features_cls256.reshape(w_256, h_256, 384)
158 |             .transpose(0, 1)
159 |             .transpose(0, 2)
160 |             .unsqueeze(dim=0)
161 |         )
162 |         features_grid256 = features_grid256.to(
163 |             self.device4k, non_blocking=True
164 |         )
165 |         features_cls4k = self.model4k.forward(features_grid256).detach().cpu()
166 |         features_mean256_cls4k = torch.cat(
167 |             [features_mean256, features_cls4k], dim=1
168 |         )
169 | 
170 |         asset_dict = {
171 |             "features_cls256": features_cls256.numpy(),
172 |             "features_mean256": features_mean256.numpy(),
173 |             "features_cls4k": features_cls4k.numpy(),
174 |             "features_mean256_cls4k": features_mean256_cls4k.numpy(),
175 |         }
176 |         return asset_dict
177 | 
178 |     def _get_region_attention_scores(self, region, scale=1):
179 |         r"""
180 |         Forward pass in hierarchical model with attention scores saved.
181 | 
182 |         Args:
183 |             - region (PIL.Image): 4096 x 4096 Image
184 |             - model256 (torch.nn): 256-Level ViT (taken from self.model256)
185 |             - model4k (torch.nn): 4096-Level ViT (taken from self.model4k)
186 |             - scale (int): How much to scale the output image by (e.g. - scale=4 will resize images to be 1024 x 1024.)
187 | 
188 |         Returns:
189 |             - np.array: [256, 256/scale, 256/scale, 3] np.array sequence of image patches from the 4K x 4K region.
190 |             - attention_256 (np.array): [256, nh, 256/scale, 256/scale] attention maps for 256-sized patches (nh == number of heads).
191 |             - attention_4k (np.array): [nh, 4096/scale, 4096/scale] attention maps for the 4K-sized region.
192 | """ 193 | x = eval_transforms()(region).unsqueeze(dim=0) 194 | 195 | batch_256, w_256, h_256 = self.prepare_img_tensor(x) 196 | batch_256 = batch_256.unfold(2, 256, 256).unfold(3, 256, 256) 197 | batch_256 = rearrange(batch_256, "b c p1 p2 w h -> (b p1 p2) c w h") 198 | batch_256 = batch_256.to(self.device256, non_blocking=True) 199 | features_cls256 = self.model256(batch_256) 200 | 201 | attention_256 = self.model256.get_last_selfattention(batch_256) 202 | nh = attention_256.shape[1] # number of head 203 | attention_256 = attention_256[:, :, 0, 1:].reshape(256, nh, -1) 204 | attention_256 = attention_256.reshape(w_256 * h_256, nh, 16, 16) 205 | attention_256 = ( 206 | torch.nn.functional.interpolate( 207 | attention_256, scale_factor=int(16 / scale), mode="nearest" 208 | ) 209 | .cpu() 210 | .numpy() 211 | ) 212 | 213 | features_grid256 = ( 214 | features_cls256.reshape(w_256, h_256, 384) 215 | .transpose(0, 1) 216 | .transpose(0, 2) 217 | .unsqueeze(dim=0) 218 | ) 219 | features_grid256 = features_grid256.to( 220 | self.device4k, non_blocking=True 221 | ) 222 | # features_cls4k = self.model4k.forward(features_grid256).detach().cpu() 223 | 224 | attention_4k = self.model4k.get_last_selfattention(features_grid256) 225 | nh = attention_4k.shape[1] # number of head 226 | attention_4k = attention_4k[0, :, 0, 1:].reshape(nh, -1) 227 | attention_4k = attention_4k.reshape(nh, w_256, h_256) 228 | attention_4k = ( 229 | torch.nn.functional.interpolate( 230 | attention_4k.unsqueeze(0), 231 | scale_factor=int(256 / scale), 232 | mode="nearest", 233 | )[0] 234 | .cpu() 235 | .numpy() 236 | ) 237 | 238 | if scale != 1: 239 | batch_256 = torch.nn.functional.interpolate( 240 | batch_256, scale_factor=(1 / scale), mode="nearest" 241 | ) 242 | 243 | return tensorbatch2im(batch_256), attention_256, attention_4k 244 | 245 | def prepare_img_tensor(self, img: torch.Tensor, patch_size=256): 246 | """ 247 | Helper function that takes a non-square image tensor, and takes a center crop s.t. the width / height 248 | are divisible by 256. 249 | 250 | (Note: "_256" for w / h is should technically be renamed as "_ps", but may not be easier to read. 251 | Until I need to make HIPT with patch_sizes != 256, keeping the naming convention as-is.) 252 | 253 | Args: 254 | - img (torch.Tensor): [1 x C x W' x H'] image tensor. 255 | - patch_size (int): Desired patch size to evenly subdivide the image. 256 | 257 | Return: 258 | - img_new (torch.Tensor): [1 x C x W x H] image tensor, where W and H are divisble by patch_size. 259 | - w_256 (int): # of [256 x 256] patches of img_new's width (e.g. - W/256) 260 | - h_256 (int): # of [256 x 256] patches of img_new's height (e.g. 
261 |         """
262 |         make_divisible = lambda l, patch_size: (l - (l % patch_size))
263 |         b, c, w, h = img.shape
264 |         load_size = make_divisible(w, patch_size), make_divisible(h, patch_size)
265 |         w_256, h_256 = w // patch_size, h // patch_size
266 |         img_new = transforms.CenterCrop(load_size)(img)
267 |         return img_new, w_256, h_256
268 | 
-------------------------------------------------------------------------------- /hipt_heatmap_utils.py: --------------------------------------------------------------------------------
1 | ### Dependencies
2 | # Base Dependencies
3 | import argparse
4 | import colorsys
5 | from io import BytesIO
6 | import os
7 | import random
8 | import requests
9 | import sys
10 | 
11 | # LinAlg / Stats / Plotting Dependencies
12 | import cv2
13 | import h5py
14 | import matplotlib
15 | import matplotlib.pyplot as plt
16 | from matplotlib.patches import Polygon
17 | import numpy as np
18 | from PIL import Image
19 | from PIL import ImageFont
20 | from PIL import ImageDraw
21 | from scipy.stats import rankdata
22 | import skimage.io
23 | from skimage.measure import find_contours
24 | from tqdm import tqdm
25 | from hipt_model_utils import tensorbatch2im  # local dependency; needed by get_region_attention_scores below
26 | 
27 | # Torch Dependencies
28 | import torch
29 | import torch.multiprocessing
30 | from torch import nn  # needed by get_region_attention_scores below
31 | from torchvision import transforms
32 | from einops import rearrange, repeat
33 | torch.multiprocessing.set_sharing_strategy('file_system')
34 | 
35 | 
36 | def concat_scores256(attns, w_256, h_256, size=(256,256)):
37 |     r"""
38 |     Rank-normalizes each patch's attention scores to [0, 100] percentiles and tiles them into a w_256 x h_256 grid heatmap.
39 |     """
40 |     rank = lambda v: rankdata(v)*100/len(v)
41 |     color_block = [rank(attn.flatten()).reshape(size) for attn in attns]
42 |     color_hm = np.concatenate([
43 |         np.concatenate(color_block[i:(i+h_256)], axis=1)
44 |         for i in range(0,h_256*w_256,h_256)
45 |     ])
46 |     return color_hm
47 | 
48 | 
49 | def concat_scores4k(attn, size=(4096, 4096)):
50 |     r"""
51 |     Rank-normalizes a region-level attention map to [0, 100] percentiles.
52 |     """
53 |     rank = lambda v: rankdata(v)*100/len(v)
54 |     color_hm = rank(attn.flatten()).reshape(size)
55 |     return color_hm
56 | 
57 | 
58 | def get_scores256(attns, size=(256,256)):
59 |     r"""Rank-normalizes a single patch's attention map to [0, 100] percentiles."""
60 | 
61 |     rank = lambda v: rankdata(v)*100/len(v)
62 |     color_block = [rank(attn.flatten()).reshape(size) for attn in attns][0]
63 |     return color_block
64 | 
65 | 
66 | def cmap_map(function, cmap):
67 |     r"""
68 |     Applies function (which should operate on vectors of shape 3: [r, g, b]) on a colormap cmap.
69 |     This routine will break any discontinuous points in a colormap.
70 | 
71 |     Args:
72 |         - function (function)
73 |         - cmap (matplotlib.colormap)
74 | 
75 |     Returns:
76 |         - matplotlib.colormap
77 |     """
78 |     cdict = cmap._segmentdata
79 |     step_dict = {}
80 |     # First get the list of points where the segments start or end
81 |     for key in ('red', 'green', 'blue'):
82 |         step_dict[key] = list(map(lambda x: x[0], cdict[key]))
83 |     step_list = sum(step_dict.values(), [])
84 |     step_list = np.array(list(set(step_list)))
85 |     # Then compute the LUT, and apply the function to the LUT
86 |     reduced_cmap = lambda step: np.array(cmap(step)[0:3])
87 |     old_LUT = np.array(list(map(reduced_cmap, step_list)))
88 |     new_LUT = np.array(list(map(function, old_LUT)))
89 |     # Now try to make a minimal segment definition of the new LUT
90 |     cdict = {}
91 |     for i, key in enumerate(['red', 'green', 'blue']):
92 |         this_cdict = {}
93 |         for j, step in enumerate(step_list):
94 |             if step in step_dict[key]:
95 |                 this_cdict[step] = new_LUT[j, i]
96 |             elif new_LUT[j, i] != old_LUT[j, i]:
97 |                 this_cdict[step] = new_LUT[j, i]
98 |         colorvector = list(map(lambda x: x + (x[1], ), this_cdict.items()))
99 |         colorvector.sort()
100 |         cdict[key] = colorvector
101 | 
102 |     return matplotlib.colors.LinearSegmentedColormap('colormap', cdict, 1024)
103 | 
104 | 
105 | def getConcatImage(imgs, how='horizontal', gap=0):
106 |     r"""
107 |     Function to concatenate a list of images (vertically or horizontally).
108 | 
109 |     Args:
110 |         - imgs (list of PIL.Image): List of PIL Images to concatenate.
111 |         - how (str): How the images are concatenated (either 'horizontal' or 'vertical')
112 |         - gap (int): Gap (in px) between images
113 | 
114 |     Return:
115 |         - dst (PIL.Image): Concatenated image result.
116 |     """
117 |     gap_dist = (len(imgs)-1)*gap
118 | 
119 |     if how == 'vertical':
120 |         w, h = np.max([img.width for img in imgs]), np.sum([img.height for img in imgs])
121 |         h += gap_dist
122 |         curr_h = 0
123 |         dst = Image.new('RGBA', (w, h), color=(255, 255, 255, 0))
124 |         for img in imgs:
125 |             dst.paste(img, (0, curr_h))
126 |             curr_h += img.height + gap
127 | 
128 |     elif how == 'horizontal':
129 |         w, h = np.sum([img.width for img in imgs]), np.min([img.height for img in imgs])
130 |         w += gap_dist
131 |         curr_w = 0
132 |         dst = Image.new('RGBA', (w, h), color=(255, 255, 255, 0))
133 | 
134 |         for idx, img in enumerate(imgs):
135 |             dst.paste(img, (curr_w, 0))
136 |             curr_w += img.width + gap
137 | 
138 |     return dst
139 | 
140 | 
141 | def add_margin(pil_img, top, right, bottom, left, color):
142 |     r"""
143 |     Adds a custom margin to a PIL.Image.
144 |     """
145 |     width, height = pil_img.size
146 |     new_width = width + right + left
147 |     new_height = height + top + bottom
148 |     result = Image.new(pil_img.mode, (new_width, new_height), color)
149 |     result.paste(pil_img, (left, top))
150 |     return result
151 | 
152 | ################################################
153 | # 256 x 256 ("Patch") Attention Heatmap Creation
154 | ################################################
155 | def create_patch_heatmaps_indiv(patch, model256, output_dir, fname, threshold=0.5,
156 |                                 offset=16, alpha=0.5, cmap=plt.get_cmap('coolwarm'), device256=torch.device('cuda:0')):
157 |     r"""
158 |     Creates patch heatmaps (saved individually)
159 | 
160 |     To be refactored!
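    (The patch is scored twice - once as-is and once with its content shifted
    by `offset` px via crop-and-pad - and the two rank-normalized attention
    maps are realigned and averaged to reduce alignment artifacts. Note: this
    function relies on get_patch_attention_scores, which is not defined in
    this file; it comes from the upstream HIPT heatmap utilities.)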
161 | 162 | Args: 163 | - patch (PIL.Image): 256 x 256 Image 164 | - model256 (torch.nn): 256-Level ViT 165 | - output_dir (str): Save directory / subdirectory 166 | - fname (str): Naming structure of files 167 | - offset (int): How much to offset (from top-left corner with zero-padding) the region by for blending 168 | - alpha (float): Image blending factor for cv2.addWeighted 169 | - cmap (matplotlib.pyplot): Colormap for creating heatmaps 170 | 171 | Returns: 172 | - None 173 | """ 174 | patch1 = patch.copy() 175 | patch2 = add_margin(patch.crop((16,16,256,256)), top=0, left=0, bottom=16, right=16, color=(255,255,255)) 176 | b256_1, a256_1 = get_patch_attention_scores(patch1, model256, device256=device256) 177 | b256_1, a256_2 = get_patch_attention_scores(patch2, model256, device256=device256) 178 | save_region = np.array(patch.copy()) 179 | s = 256 180 | offset_2 = offset 181 | 182 | if threshold != None: 183 | for i in range(6): 184 | score256_1 = get_scores256(a256_1[:,i,:,:], size=(s,)*2) 185 | score256_2 = get_scores256(a256_2[:,i,:,:], size=(s,)*2) 186 | new_score256_2 = np.zeros_like(score256_2) 187 | new_score256_2[offset_2:s, offset_2:s] = score256_2[:(s-offset_2), :(s-offset_2)] 188 | overlay256 = np.ones_like(score256_2)*100 189 | overlay256[offset_2:s, offset_2:s] += 100 190 | score256 = (score256_1+new_score256_2)/overlay256 191 | 192 | mask256 = score256.copy() 193 | mask256[mask256 < threshold] = 0 194 | mask256[mask256 > threshold] = 0.95 195 | 196 | color_block256 = (cmap(mask256)*255)[:,:,:3].astype(np.uint8) 197 | region256_hm = cv2.addWeighted(color_block256, alpha, save_region.copy(), 1-alpha, 0, save_region.copy()) 198 | region256_hm[mask256==0] = 0 199 | img_inverse = save_region.copy() 200 | img_inverse[mask256 == 0.95] = 0 201 | Image.fromarray(region256_hm+img_inverse).save(os.path.join(output_dir, '%s_256th[%d].png' % (fname, i))) 202 | 203 | for i in range(6): 204 | score256_1 = get_scores256(a256_1[:,i,:,:], size=(s,)*2) 205 | score256_2 = get_scores256(a256_2[:,i,:,:], size=(s,)*2) 206 | new_score256_2 = np.zeros_like(score256_2) 207 | new_score256_2[offset_2:s, offset_2:s] = score256_2[:(s-offset_2), :(s-offset_2)] 208 | overlay256 = np.ones_like(score256_2)*100 209 | overlay256[offset_2:s, offset_2:s] += 100 210 | score256 = (score256_1+new_score256_2)/overlay256 211 | color_block256 = (cmap(score256)*255)[:,:,:3].astype(np.uint8) 212 | region256_hm = cv2.addWeighted(color_block256, alpha, save_region.copy(), 1-alpha, 0, save_region.copy()) 213 | Image.fromarray(region256_hm).save(os.path.join(output_dir, '%s_256[%s].png' % (fname, i))) 214 | 215 | 216 | def create_patch_heatmaps_concat(patch, model256, output_dir, fname, threshold=0.5, 217 | offset=16, alpha=0.5, cmap=plt.get_cmap('coolwarm'), device256=torch.device('cuda:0')): 218 | r""" 219 | Creates patch heatmaps (concatenated for easy comparison) 220 | 221 | To be refactored! 
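    (Same two-pass scoring as create_patch_heatmaps_indiv above; the six
    per-head heatmaps are then arranged into a 2 x 3 grid with getConcatImage
    and saved as a single image for side-by-side comparison.)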
222 | 223 | Args: 224 | - patch (PIL.Image): 256 x 256 Image 225 | - model256 (torch.nn): 256-Level ViT 226 | - output_dir (str): Save directory / subdirectory 227 | - fname (str): Naming structure of files 228 | - offset (int): How much to offset (from top-left corner with zero-padding) the region by for blending 229 | - alpha (float): Image blending factor for cv2.addWeighted 230 | - cmap (matplotlib.pyplot): Colormap for creating heatmaps 231 | 232 | Returns: 233 | - None 234 | """ 235 | patch1 = patch.copy() 236 | patch2 = add_margin(patch.crop((16,16,256,256)), top=0, left=0, bottom=16, right=16, color=(255,255,255)) 237 | b256_1, a256_1 = get_patch_attention_scores(patch1, model256, device256=device256) 238 | b256_1, a256_2 = get_patch_attention_scores(patch2, model256, device256=device256) 239 | save_region = np.array(patch.copy()) 240 | s = 256 241 | offset_2 = offset 242 | 243 | if threshold != None: 244 | ths = [] 245 | for i in range(6): 246 | score256_1 = get_scores256(a256_1[:,i,:,:], size=(s,)*2) 247 | score256_2 = get_scores256(a256_2[:,i,:,:], size=(s,)*2) 248 | new_score256_2 = np.zeros_like(score256_2) 249 | new_score256_2[offset_2:s, offset_2:s] = score256_2[:(s-offset_2), :(s-offset_2)] 250 | overlay256 = np.ones_like(score256_2)*100 251 | overlay256[offset_2:s, offset_2:s] += 100 252 | score256 = (score256_1+new_score256_2)/overlay256 253 | 254 | mask256 = score256.copy() 255 | mask256[mask256 < threshold] = 0 256 | mask256[mask256 > threshold] = 0.95 257 | 258 | color_block256 = (cmap(mask256)*255)[:,:,:3].astype(np.uint8) 259 | region256_hm = cv2.addWeighted(color_block256, alpha, save_region.copy(), 1-alpha, 0, save_region.copy()) 260 | region256_hm[mask256==0] = 0 261 | img_inverse = save_region.copy() 262 | img_inverse[mask256 == 0.95] = 0 263 | ths.append(region256_hm+img_inverse) 264 | 265 | ths = [Image.fromarray(img) for img in ths] 266 | 267 | getConcatImage([getConcatImage(ths[0:3]), 268 | getConcatImage(ths[3:6])], how='vertical').save(os.path.join(output_dir, '%s_256th.png' % (fname))) 269 | 270 | 271 | hms = [] 272 | for i in range(6): 273 | score256_1 = get_scores256(a256_1[:,i,:,:], size=(s,)*2) 274 | score256_2 = get_scores256(a256_2[:,i,:,:], size=(s,)*2) 275 | new_score256_2 = np.zeros_like(score256_2) 276 | new_score256_2[offset_2:s, offset_2:s] = score256_2[:(s-offset_2), :(s-offset_2)] 277 | overlay256 = np.ones_like(score256_2)*100 278 | overlay256[offset_2:s, offset_2:s] += 100 279 | score256 = (score256_1+new_score256_2)/overlay256 280 | color_block256 = (cmap(score256)*255)[:,:,:3].astype(np.uint8) 281 | region256_hm = cv2.addWeighted(color_block256, alpha, save_region.copy(), 1-alpha, 0, save_region.copy()) 282 | hms.append(region256_hm) 283 | 284 | hms = [Image.fromarray(img) for img in hms] 285 | 286 | getConcatImage([getConcatImage(hms[0:3]), 287 | getConcatImage(hms[3:6])], how='vertical').save(os.path.join(output_dir, '%s_256hm.png' % (fname))) 288 | 289 | 290 | ################################################ 291 | # 4096 x 4096 ("Region") Attention Heatmap Creation 292 | ################################################ 293 | def get_region_attention_scores(region, model256, model4k, scale=1, 294 | device256=torch.device('cuda:0'), 295 | device4k=torch.device('cuda:0')): 296 | r""" 297 | Forward pass in hierarchical model with attention scores saved. 298 | 299 | To be refactored! 
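    (For each head, the [CLS]-token attention over patch tokens is taken from
    the last self-attention layer at both levels and upsampled back to pixel
    space with nearest-neighbor interpolation.)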
300 | 301 | Args: 302 | - region (PIL.Image): 4096 x 4096 Image 303 | - model256 (torch.nn): 256-Level ViT 304 | - model4k (torch.nn): 4096-Level ViT 305 | - scale (int): How much to scale the output image by (e.g. - scale=4 will resize images to be 1024 x 1024.) 306 | 307 | Returns: 308 | - np.array: [256, 256/scale, 256/scale, 3] np.array sequence of image patches from the 4K x 4K region. 309 | - attention_256 (torch.Tensor): [256, 256/scale, 256/scale, 3] torch.Tensor sequence of attention maps for 256-sized patches. 310 | - attention_4k (torch.Tensor): [1, 4096/scale, 4096/scale, 3] torch.Tensor sequence of attention maps for 4k-sized regions. 311 | """ 312 | t = transforms.Compose([ 313 | transforms.ToTensor(), 314 | transforms.Normalize( 315 | [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] 316 | ) 317 | ]) 318 | 319 | with torch.no_grad(): 320 | batch_256 = t(region).unsqueeze(0).unfold(2, 256, 256).unfold(3, 256, 256) 321 | batch_256 = rearrange(batch_256, 'b c p1 p2 w h -> (b p1 p2) c w h') 322 | batch_256 = batch_256.to(device256, non_blocking=True) 323 | features_256 = model256(batch_256) 324 | 325 | attention_256 = model256.get_last_selfattention(batch_256) 326 | nh = attention_256.shape[1] # number of head 327 | attention_256 = attention_256[:, :, 0, 1:].reshape(256, nh, -1) 328 | attention_256 = attention_256.reshape(256, nh, 16, 16) 329 | attention_256 = nn.functional.interpolate(attention_256, scale_factor=int(16/scale), mode="nearest").cpu().numpy() 330 | 331 | features_4096 = features_256.unfold(0, 16, 16).transpose(0,1).unsqueeze(dim=0) 332 | attention_4096 = model4k.get_last_selfattention(features_4096.detach().to(device4k)) 333 | nh = attention_4096.shape[1] # number of head 334 | attention_4096 = attention_4096[0, :, 0, 1:].reshape(nh, -1) 335 | attention_4096 = attention_4096.reshape(nh, 16, 16) 336 | attention_4096 = nn.functional.interpolate(attention_4096.unsqueeze(0), scale_factor=int(256/scale), mode="nearest")[0].cpu().numpy() 337 | 338 | if scale != 1: 339 | batch_256 = nn.functional.interpolate(batch_256, scale_factor=(1/scale), mode="nearest") 340 | 341 | return tensorbatch2im(batch_256), attention_256, attention_4096 342 | 343 | 344 | def create_hierarchical_heatmaps_indiv(region, model256, model4k, output_dir, fname, 345 | offset=128, scale=4, alpha=0.5, cmap = plt.get_cmap('coolwarm'), threshold=None, 346 | device256=torch.device('cuda:0'), device4k=torch.device('cuda:0')): 347 | r""" 348 | Creates hierarchical heatmaps (Raw H&E + ViT-256 + ViT-4K + Blended Heatmaps saved individually). 349 | 350 | To be refactored! 
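    (The region is scored four times - unshifted, plus crops shifted by 128,
    256, and 384 px - and the shifted score maps are realigned and averaged
    via the `overlay` weights to smooth out patch-boundary artifacts.)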
351 | 352 | Args: 353 | - region (PIL.Image): 4096 x 4096 Image 354 | - model256 (torch.nn): 256-Level ViT 355 | - model4k (torch.nn): 4096-Level ViT 356 | - output_dir (str): Save directory / subdirectory 357 | - fname (str): Naming structure of files 358 | - offset (int): How much to offset (from top-left corner with zero-padding) the region by for blending 359 | - scale (int): How much to scale the output image by 360 | - alpha (float): Image blending factor for cv2.addWeighted 361 | - cmap (matplotlib.pyplot): Colormap for creating heatmaps 362 | 363 | Returns: 364 | - None 365 | """ 366 | 367 | region2 = add_margin(region.crop((128,128,4096,4096)), 368 | top=0, left=0, bottom=128, right=128, color=(255,255,255)) 369 | region3 = add_margin(region.crop((128*2,128*2,4096,4096)), 370 | top=0, left=0, bottom=128*2, right=128*2, color=(255,255,255)) 371 | region4 = add_margin(region.crop((128*3,128*3,4096,4096)), 372 | top=0, left=0, bottom=128*4, right=128*4, color=(255,255,255)) 373 | 374 | b256_1, a256_1, a4k_1 = get_region_attention_scores(region, model256, model4k, scale, device256=device256, device4k=device4k) 375 | b256_2, a256_2, a4k_2 = get_region_attention_scores(region2, model256, model4k, scale, device256=device256, device4k=device4k) 376 | b256_3, a256_3, a4k_3 = get_region_attention_scores(region3, model256, model4k, scale, device256=device256, device4k=device4k) 377 | b256_4, a256_4, a4k_4 = get_region_attention_scores(region4, model256, model4k, scale, device256=device256, device4k=device4k) 378 | offset_2 = (offset*1)//scale 379 | offset_3 = (offset*2)//scale 380 | offset_4 = (offset*3)//scale 381 | s = 4096//scale 382 | save_region = np.array(region.resize((s, s))) 383 | 384 | if threshold != None: 385 | for i in range(6): 386 | score256_1 = concat_scores256(a256_1[:,i,:,:], size=(s//16,)*2) 387 | score256_2 = concat_scores256(a256_2[:,i,:,:], size=(s//16,)*2) 388 | new_score256_2 = np.zeros_like(score256_2) 389 | new_score256_2[offset_2:s, offset_2:s] = score256_2[:(s-offset_2), :(s-offset_2)] 390 | overlay256 = np.ones_like(score256_2)*100 391 | overlay256[offset_2:s, offset_2:s] += 100 392 | score256 = (score256_1+new_score256_2)/overlay256 393 | 394 | mask256 = score256.copy() 395 | mask256[mask256 < threshold] = 0 396 | mask256[mask256 > threshold] = 0.95 397 | 398 | color_block256 = (cmap(mask256)*255)[:,:,:3].astype(np.uint8) 399 | region256_hm = cv2.addWeighted(color_block256, alpha, save_region.copy(), 1-alpha, 0, save_region.copy()) 400 | region256_hm[mask256==0] = 0 401 | img_inverse = save_region.copy() 402 | img_inverse[mask256 == 0.95] = 0 403 | Image.fromarray(region256_hm+img_inverse).save(os.path.join(output_dir, '%s_256th[%d].png' % (fname, i))) 404 | 405 | if False: 406 | for j in range(6): 407 | score4k_1 = concat_scores4k(a4k_1[j], size=(s,)*2) 408 | score4k = score4k_1 / 100 409 | color_block4k = (cmap(score4k)*255)[:,:,:3].astype(np.uint8) 410 | region4k_hm = cv2.addWeighted(color_block4k, alpha, save_region.copy(), 1-alpha, 0, save_region.copy()) 411 | Image.fromarray(region4k_hm).save(os.path.join(output_dir, '%s_4k[%s].png' % (fname, j))) 412 | 413 | for j in range(6): 414 | score4k_1 = concat_scores4k(a4k_1[j], size=(s,)*2) 415 | score4k_2 = concat_scores4k(a4k_2[j], size=(s,)*2) 416 | score4k_3 = concat_scores4k(a4k_3[j], size=(s,)*2) 417 | score4k_4 = concat_scores4k(a4k_4[j], size=(s,)*2) 418 | 419 | new_score4k_2 = np.zeros_like(score4k_2) 420 | new_score4k_2[offset_2:s, offset_2:s] = score4k_2[:(s-offset_2), :(s-offset_2)] 421 | 
new_score4k_3 = np.zeros_like(score4k_3) 422 | new_score4k_3[offset_3:s, offset_3:s] = score4k_3[:(s-offset_3), :(s-offset_3)] 423 | new_score4k_4 = np.zeros_like(score4k_4) 424 | new_score4k_4[offset_4:s, offset_4:s] = score4k_4[:(s-offset_4), :(s-offset_4)] 425 | 426 | overlay4k = np.ones_like(score4k_2)*100 427 | overlay4k[offset_2:s, offset_2:s] += 100 428 | overlay4k[offset_3:s, offset_3:s] += 100 429 | overlay4k[offset_4:s, offset_4:s] += 100 430 | score4k = (score4k_1+new_score4k_2+new_score4k_3+new_score4k_4)/overlay4k 431 | 432 | color_block4k = (cmap(score4k)*255)[:,:,:3].astype(np.uint8) 433 | region4k_hm = cv2.addWeighted(color_block4k, alpha, save_region.copy(), 1-alpha, 0, save_region.copy()) 434 | Image.fromarray(region4k_hm).save(os.path.join(output_dir, '%s_1024[%s].png' % (fname, j))) 435 | 436 | for i in range(6): 437 | score256_1 = concat_scores256(a256_1[:,i,:,:], size=(s//16,)*2) 438 | score256_2 = concat_scores256(a256_2[:,i,:,:], size=(s//16,)*2) 439 | new_score256_2 = np.zeros_like(score256_2) 440 | new_score256_2[offset_2:s, offset_2:s] = score256_2[:(s-offset_2), :(s-offset_2)] 441 | overlay256 = np.ones_like(score256_2)*100 442 | overlay256[offset_2:s, offset_2:s] += 100 443 | score256 = (score256_1+new_score256_2)/overlay256 444 | color_block256 = (cmap(score256)*255)[:,:,:3].astype(np.uint8) 445 | region256_hm = cv2.addWeighted(color_block256, alpha, save_region.copy(), 1-alpha, 0, save_region.copy()) 446 | Image.fromarray(region256_hm).save(os.path.join(output_dir, '%s_256[%s].png' % (fname, i))) 447 | 448 | for j in range(6): 449 | score4k_1 = concat_scores4k(a4k_1[j], size=(s,)*2) 450 | score4k_2 = concat_scores4k(a4k_2[j], size=(s,)*2) 451 | score4k_3 = concat_scores4k(a4k_3[j], size=(s,)*2) 452 | score4k_4 = concat_scores4k(a4k_4[j], size=(s,)*2) 453 | 454 | new_score4k_2 = np.zeros_like(score4k_2) 455 | new_score4k_2[offset_2:s, offset_2:s] = score4k_2[:(s-offset_2), :(s-offset_2)] 456 | new_score4k_3 = np.zeros_like(score4k_3) 457 | new_score4k_3[offset_3:s, offset_3:s] = score4k_3[:(s-offset_3), :(s-offset_3)] 458 | new_score4k_4 = np.zeros_like(score4k_4) 459 | new_score4k_4[offset_4:s, offset_4:s] = score4k_4[:(s-offset_4), :(s-offset_4)] 460 | 461 | overlay4k = np.ones_like(score4k_2)*100 462 | overlay4k[offset_2:s, offset_2:s] += 100 463 | overlay4k[offset_3:s, offset_3:s] += 100 464 | overlay4k[offset_4:s, offset_4:s] += 100 465 | score4k = (score4k_1+new_score4k_2+new_score4k_3+new_score4k_4)/overlay4k 466 | 467 | for i in range(6): 468 | score256_1 = concat_scores256(a256_1[:,i,:,:], size=(s//16,)*2) 469 | score256_2 = concat_scores256(a256_2[:,i,:,:], size=(s//16,)*2) 470 | new_score256_2 = np.zeros_like(score256_2) 471 | new_score256_2[offset_2:s, offset_2:s] = score256_2[:(s-offset_2), :(s-offset_2)] 472 | overlay256 = np.ones_like(score256_2)*100*2 473 | overlay256[offset_2:s, offset_2:s] += 100*2 474 | score256 = (score256_1+new_score256_2)*2/overlay256 475 | 476 | factorize = lambda data: (data - np.min(data)) / (np.max(data) - np.min(data)) 477 | score = (score4k*overlay4k+score256*overlay256)/(overlay4k+overlay256) #factorize(score256*score4k) 478 | color_block = (cmap(score)*255)[:,:,:3].astype(np.uint8) 479 | region_hm = cv2.addWeighted(color_block, alpha, save_region.copy(), 1-alpha, 0, save_region.copy()) 480 | Image.fromarray(region_hm).save(os.path.join(output_dir, '%s_factorized_4k[%s]_256[%s].png' % (fname, j, i))) 481 | 482 | return 483 | 484 | 485 | def create_hierarchical_heatmaps_concat(region, model256, model4k, output_dir, 
fname, 486 | offset=128, scale=4, alpha=0.5, cmap = plt.get_cmap('coolwarm'), 487 | device256=torch.device('cuda:0'), device4k=torch.device('cuda:0')): 488 | r""" 489 | Creates hierarchical heatmaps (With Raw H&E + ViT-256 + ViT-4K + Blended Heatmaps concatenated for easy comparison) 490 | 491 | To be refactored! 492 | 493 | Args: 494 | - region (PIL.Image): 4096 x 4096 Image 495 | - model256 (torch.nn): 256-Level ViT 496 | - model4k (torch.nn): 4096-Level ViT 497 | - output_dir (str): Save directory / subdirectory 498 | - fname (str): Naming structure of files 499 | - offset (int): How much to offset (from top-left corner with zero-padding) the region by for blending 500 | - scale (int): How much to scale the output image by 501 | - alpha (float): Image blending factor for cv2.addWeighted 502 | - cmap (matplotlib.pyplot): Colormap for creating heatmaps 503 | 504 | Returns: 505 | - None 506 | """ 507 | 508 | region2 = add_margin(region.crop((128,128,4096,4096)), 509 | top=0, left=0, bottom=128, right=128, color=(255,255,255)) 510 | region3 = add_margin(region.crop((128*2,128*2,4096,4096)), 511 | top=0, left=0, bottom=128*2, right=128*2, color=(255,255,255)) 512 | region4 = add_margin(region.crop((128*3,128*3,4096,4096)), 513 | top=0, left=0, bottom=128*4, right=128*4, color=(255,255,255)) 514 | 515 | b256_1, a256_1, a4k_1 = get_region_attention_scores(region, model256, model4k, scale, device256=device256, device4k=device4k) 516 | b256_2, a256_2, a4k_2 = get_region_attention_scores(region2, model256, model4k, scale, device256=device256, device4k=device4k) 517 | b256_3, a256_3, a4k_3 = get_region_attention_scores(region3, model256, model4k, scale, device256=device256, device4k=device4k) 518 | b256_4, a256_4, a4k_4 = get_region_attention_scores(region4, model256, model4k, scale, device256=device256, device4k=device4k) 519 | offset_2 = (offset*1)//scale 520 | offset_3 = (offset*2)//scale 521 | offset_4 = (offset*3)//scale 522 | s = 4096//scale 523 | save_region = np.array(region.resize((s, s))) 524 | 525 | for j in range(6): 526 | score4k_1 = concat_scores4k(a4k_1[j], size=(s,)*2) 527 | score4k_2 = concat_scores4k(a4k_2[j], size=(s,)*2) 528 | score4k_3 = concat_scores4k(a4k_3[j], size=(s,)*2) 529 | score4k_4 = concat_scores4k(a4k_4[j], size=(s,)*2) 530 | 531 | new_score4k_2 = np.zeros_like(score4k_2) 532 | new_score4k_2[offset_2:s, offset_2:s] = score4k_2[:(s-offset_2), :(s-offset_2)] 533 | new_score4k_3 = np.zeros_like(score4k_3) 534 | new_score4k_3[offset_3:s, offset_3:s] = score4k_3[:(s-offset_3), :(s-offset_3)] 535 | new_score4k_4 = np.zeros_like(score4k_4) 536 | new_score4k_4[offset_4:s, offset_4:s] = score4k_4[:(s-offset_4), :(s-offset_4)] 537 | 538 | overlay4k = np.ones_like(score4k_2)*100 539 | overlay4k[offset_2:s, offset_2:s] += 100 540 | overlay4k[offset_3:s, offset_3:s] += 100 541 | overlay4k[offset_4:s, offset_4:s] += 100 542 | score4k = (score4k_1+new_score4k_2+new_score4k_3+new_score4k_4)/overlay4k 543 | 544 | color_block4k = (cmap(score4k_1/100)*255)[:,:,:3].astype(np.uint8) 545 | region4k_hm = cv2.addWeighted(color_block4k, alpha, save_region.copy(), 1-alpha, 0, save_region.copy()) 546 | 547 | for i in range(6): 548 | score256_1 = concat_scores256(a256_1[:,i,:,:], size=(s//16,)*2) 549 | score256_2 = concat_scores256(a256_2[:,i,:,:], size=(s//16,)*2) 550 | new_score256_2 = np.zeros_like(score256_2) 551 | new_score256_2[offset_2:s, offset_2:s] = score256_2[:(s-offset_2), :(s-offset_2)] 552 | overlay256 = np.ones_like(score256_2)*100*2 553 | overlay256[offset_2:s, offset_2:s] += 
100*2
554 |             score256 = (score256_1+new_score256_2)*2/overlay256
555 | 
556 |             color_block256 = (cmap(score256)*255)[:,:,:3].astype(np.uint8)
557 |             region256_hm = cv2.addWeighted(color_block256, alpha, save_region.copy(), 1-alpha, 0, save_region.copy())
558 | 
559 |             factorize = lambda data: (data - np.min(data)) / (np.max(data) - np.min(data))
560 |             score = (score4k*overlay4k+score256*overlay256)/(overlay4k+overlay256)  # factorize(score256*score4k)
561 |             color_block = (cmap(score)*255)[:,:,:3].astype(np.uint8)
562 |             region_hm = cv2.addWeighted(color_block, alpha, save_region.copy(), 1-alpha, 0, save_region.copy())
563 | 
564 |             pad = 100
565 |             canvas = Image.new('RGB', (s*2+pad,)*2, (255,)*3)
566 |             draw = ImageDraw.Draw(canvas)
567 |             font = ImageFont.truetype("arial.ttf", 50)
568 |             draw.text((1024*0.5-pad*2, pad//4), "ViT-256 (Head: %d)" % i, (0, 0, 0), font=font)
569 |             canvas = canvas.rotate(90)
570 |             draw = ImageDraw.Draw(canvas)
571 |             draw.text((1024*1.5-pad, pad//4), "ViT-4K (Head: %d)" % j, (0, 0, 0), font=font)
572 |             canvas.paste(Image.fromarray(save_region), (pad,pad))
573 |             canvas.paste(Image.fromarray(region4k_hm), (1024+pad,pad))
574 |             canvas.paste(Image.fromarray(region256_hm), (pad,1024+pad))
575 |             canvas.paste(Image.fromarray(region_hm), (s+pad,s+pad))
576 |             canvas.save(os.path.join(output_dir, '%s_4k[%s]_256[%s].png' % (fname, j, i)))
577 | 
578 |     return
579 | 
580 | 
581 | def create_hierarchical_heatmaps_concat_select(region, model256, model4k, output_dir, fname,
582 |                                                offset=128, scale=4, alpha=0.5, cmap=plt.get_cmap('coolwarm'),
583 |                                                device256=torch.device('cuda:0'), device4k=torch.device('cuda:0')):
584 |     r"""
585 |     Creates hierarchical heatmaps (With Raw H&E + ViT-256 + ViT-4K + Blended Heatmaps concatenated for easy comparison), with only select attention heads used.
586 | 
587 |     To be refactored!
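    (As written, "select" is hard-coded below: ViT-4K heads 0 and 5 and
    ViT-256 head 2 are rendered, together with the raw H&E, into a 2 x 3
    canvas via getConcatImage.)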
588 | 589 | Args: 590 | - region (PIL.Image): 4096 x 4096 Image 591 | - model256 (torch.nn): 256-Level ViT 592 | - model4k (torch.nn): 4096-Level ViT 593 | - output_dir (str): Save directory / subdirectory 594 | - fname (str): Naming structure of files 595 | - offset (int): How much to offset (from top-left corner with zero-padding) the region by for blending 596 | - scale (int): How much to scale the output image by 597 | - alpha (float): Image blending factor for cv2.addWeighted 598 | - cmap (matplotlib.pyplot): Colormap for creating heatmaps 599 | 600 | Returns: 601 | - None 602 | """ 603 | 604 | region2 = add_margin(region.crop((128,128,4096,4096)), 605 | top=0, left=0, bottom=128, right=128, color=(255,255,255)) 606 | region3 = add_margin(region.crop((128*2,128*2,4096,4096)), 607 | top=0, left=0, bottom=128*2, right=128*2, color=(255,255,255)) 608 | region4 = add_margin(region.crop((128*3,128*3,4096,4096)), 609 | top=0, left=0, bottom=128*4, right=128*4, color=(255,255,255)) 610 | 611 | b256_1, a256_1, a4k_1 = get_region_attention_scores(region, model256, model4k, scale, device256=device256, device4k=device4k) 612 | b256_2, a256_2, a4k_2 = get_region_attention_scores(region2, model256, model4k, scale, device256=device256, device4k=device4k) 613 | b256_3, a256_3, a4k_3 = get_region_attention_scores(region3, model256, model4k, scale, device256=device256, device4k=device4k) 614 | b256_4, a256_4, a4k_4 = get_region_attention_scores(region4, model256, model4k, scale, device256=device256, device4k=device4k) 615 | offset_2 = (offset*1)//scale 616 | offset_3 = (offset*2)//scale 617 | offset_4 = (offset*3)//scale 618 | s = 4096//scale 619 | save_region = np.array(region.resize((s, s))) 620 | 621 | canvas = [[Image.fromarray(save_region), None, None], [None, None, None]] 622 | for idx_4k, j in enumerate([0,5]): 623 | score4k_1 = concat_scores4k(a4k_1[j], size=(s,)*2) 624 | score4k_2 = concat_scores4k(a4k_2[j], size=(s,)*2) 625 | score4k_3 = concat_scores4k(a4k_3[j], size=(s,)*2) 626 | score4k_4 = concat_scores4k(a4k_4[j], size=(s,)*2) 627 | 628 | new_score4k_2 = np.zeros_like(score4k_2) 629 | new_score4k_2[offset_2:s, offset_2:s] = score4k_2[:(s-offset_2), :(s-offset_2)] 630 | new_score4k_3 = np.zeros_like(score4k_3) 631 | new_score4k_3[offset_3:s, offset_3:s] = score4k_3[:(s-offset_3), :(s-offset_3)] 632 | new_score4k_4 = np.zeros_like(score4k_4) 633 | new_score4k_4[offset_4:s, offset_4:s] = score4k_4[:(s-offset_4), :(s-offset_4)] 634 | 635 | overlay4k = np.ones_like(score4k_2)*100 636 | overlay4k[offset_2:s, offset_2:s] += 100 637 | overlay4k[offset_3:s, offset_3:s] += 100 638 | overlay4k[offset_4:s, offset_4:s] += 100 639 | score4k = (score4k_1+new_score4k_2+new_score4k_3+new_score4k_4)/overlay4k 640 | 641 | color_block4k = (cmap(score4k_1/100)*255)[:,:,:3].astype(np.uint8) 642 | region4k_hm = cv2.addWeighted(color_block4k, alpha, save_region.copy(), 1-alpha, 0, save_region.copy()) 643 | canvas[0][idx_4k+1] = Image.fromarray(region4k_hm) 644 | 645 | for idx_256, i in enumerate([2]): 646 | score256_1 = concat_scores256(a256_1[:,i,:,:], size=(s//16,)*2) 647 | score256_2 = concat_scores256(a256_2[:,i,:,:], size=(s//16,)*2) 648 | new_score256_2 = np.zeros_like(score256_2) 649 | new_score256_2[offset_2:s, offset_2:s] = score256_2[:(s-offset_2), :(s-offset_2)] 650 | overlay256 = np.ones_like(score256_2)*100*2 651 | overlay256[offset_2:s, offset_2:s] += 100*2 652 | score256 = (score256_1+new_score256_2)*2/overlay256 653 | 654 | color_block256 = (cmap(score256)*255)[:,:,:3].astype(np.uint8) 655 | 
region256_hm = cv2.addWeighted(color_block256, alpha, save_region.copy(), 1-alpha, 0, save_region.copy())
656 |             canvas[idx_256+1][0] = Image.fromarray(region256_hm)
657 | 
658 |             factorize = lambda data: (data - np.min(data)) / (np.max(data) - np.min(data))
659 |             score = (score4k*overlay4k+score256*overlay256)/(overlay4k+overlay256)  # factorize(score256*score4k)
660 |             color_block = (cmap(score)*255)[:,:,:3].astype(np.uint8)
661 |             region_hm = cv2.addWeighted(color_block, alpha, save_region.copy(), 1-alpha, 0, save_region.copy())
662 |             canvas[idx_256+1][idx_4k+1] = Image.fromarray(region_hm)
663 | 
664 |     canvas = getConcatImage([getConcatImage(row) for row in canvas], how='vertical')
665 |     canvas.save(os.path.join(output_dir, '%s_heatmap.png' % (fname)))
666 |     return
667 | 
-------------------------------------------------------------------------------- /hipt_model_utils.py: --------------------------------------------------------------------------------
1 | # Dependencies
2 | 
3 | # LinAlg / Stats / Plotting Dependencies
4 | import numpy as np
5 | from PIL import Image
6 | 
7 | # Torch Dependencies
8 | import torch
9 | import torch.multiprocessing
10 | from torchvision import transforms
11 | from einops import rearrange
12 | 
13 | # Local Dependencies
14 | import vision_transformer as vits
15 | import vision_transformer4k as vits4k
16 | 
17 | 
18 | torch.multiprocessing.set_sharing_strategy("file_system")
19 | 
20 | 
21 | def get_vit256(
22 |     pretrained_weights=None, arch="vit_small", device=torch.device("cuda:0")
23 | ):
24 |     r"""
25 |     Builds ViT-256 Model.
26 | 
27 |     Args:
28 |         - pretrained_weights (str): Path to ViT-256 Model Checkpoint.
29 |         - arch (str): Which model architecture.
30 |         - device (torch): Torch device to save model.
31 | 
32 |     Returns:
33 |         - model256 (torch.nn): Initialized model.
34 |     """
35 | 
36 |     checkpoint_key = "teacher"
37 |     device = (  # NOTE: overrides the `device` argument based on CUDA availability
38 |         torch.device("cuda:0")
39 |         if torch.cuda.is_available()
40 |         else torch.device("cpu")
41 |     )
42 |     model256 = vits.__dict__[arch](patch_size=16, num_classes=0)
43 |     for p in model256.parameters():
44 |         p.requires_grad = False
45 |     model256.eval()
46 |     model256.to(device)
47 | 
48 |     if pretrained_weights is not None:
49 |         state_dict = torch.load(pretrained_weights, map_location="cpu")
50 |         if checkpoint_key is not None and checkpoint_key in state_dict:
51 |             # print(f"Take key {checkpoint_key} in provided checkpoint dict")
52 |             state_dict = state_dict[checkpoint_key]
53 |         # remove `module.` prefix
54 |         state_dict = {
55 |             k.replace("module.", ""): v for k, v in state_dict.items()}
56 |         # remove `backbone.` prefix induced by multicrop wrapper
57 |         state_dict = {
58 |             k.replace("backbone.", ""): v for k, v in state_dict.items()}
59 |         model256.load_state_dict(state_dict, strict=False)
60 |         # print("Pretrained weights loaded from {}".format(pretrained_weights))
61 | 
62 |     return model256
63 | 
64 | 
65 | def get_vit4k(
66 |         pretrained_weights=None, arch="vit4k_xs",
67 |         device=torch.device("cuda:0")):
68 |     r"""
69 |     Builds ViT-4K Model.
70 | 
71 |     Args:
72 |         - pretrained_weights (str): Path to ViT-4K Model Checkpoint.
73 |         - arch (str): Which model architecture.
74 |         - device (torch): Torch device to save model.
75 | 
76 |     Returns:
77 |         - model4k (torch.nn): Initialized model.
78 | """ 79 | 80 | checkpoint_key = "teacher" 81 | device = ( 82 | torch.device("cuda:0") 83 | if torch.cuda.is_available() 84 | else torch.device("cpu") 85 | ) 86 | model4k = vits4k.__dict__[arch](num_classes=0) 87 | for p in model4k.parameters(): 88 | p.requires_grad = False 89 | model4k.eval() 90 | model4k.to(device) 91 | 92 | if pretrained_weights is not None: 93 | state_dict = torch.load(pretrained_weights, map_location="cpu") 94 | if checkpoint_key is not None and checkpoint_key in state_dict: 95 | # print(f"Take key {checkpoint_key} in provided checkpoint dict") 96 | state_dict = state_dict[checkpoint_key] 97 | # remove `module.` prefix 98 | state_dict = { 99 | k.replace("module.", ""): v for k, v in state_dict.items()} 100 | # remove `backbone.` prefix induced by multicrop wrapper 101 | state_dict = { 102 | k.replace("backbone.", ""): v for k, v in state_dict.items()} 103 | model4k.load_state_dict(state_dict, strict=False) 104 | # print("Pretrained weights loaded from {}".format(pretrained_weights)) 105 | 106 | return model4k 107 | 108 | 109 | def eval_transforms(): 110 | """ """ 111 | mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) 112 | eval_t = transforms.Compose( 113 | [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)] 114 | ) 115 | return eval_t 116 | 117 | 118 | def roll_batch2img(batch: torch.Tensor, w: int, h: int, patch_size=256): 119 | """ 120 | Rolls an image tensor batch (batch of [256 x 256] images) 121 | into a [W x H] Pil.Image object. 122 | 123 | Args: 124 | batch (torch.Tensor): [B x 3 x 256 x 256] image tensor batch. 125 | 126 | Return: 127 | Image.PIL: [W x H X 3] Image. 128 | """ 129 | batch = batch.reshape(w, h, 3, patch_size, patch_size) 130 | img = rearrange(batch, "p1 p2 c w h-> c (p1 w) (p2 h)").unsqueeze(dim=0) 131 | return Image.fromarray(tensorbatch2im(img)[0]) 132 | 133 | 134 | def tensorbatch2im(input_image, imtype=np.uint8): 135 | r""" " 136 | Converts a Tensor array into a numpy image array. 137 | 138 | Args: 139 | - input_image (torch.Tensor): (B, C, W, H) Torch Tensor. 140 | - imtype (type): the desired type of the converted numpy array 141 | 142 | Returns: 143 | - image_numpy (np.array): (B, W, H, C) Numpy Array. 
144 | """ 145 | if not isinstance(input_image, np.ndarray): 146 | # convert it into a numpy array 147 | image_numpy = input_image.cpu().float().numpy() 148 | # if image_numpy.shape[0] == 1: # grayscale to RGB 149 | # image_numpy = np.tile(image_numpy, (3, 1, 1)) 150 | image_numpy = ( 151 | (np.transpose(image_numpy, (0, 2, 3, 1)) + 1) / 2.0 * 255.0 152 | ) # post-processing: tranpose and scaling 153 | else: # if it is a numpy array, do nothing 154 | image_numpy = input_image 155 | return image_numpy.astype(imtype) 156 | -------------------------------------------------------------------------------- /image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 as cv 3 | import torch 4 | from torch import nn 5 | import skimage 6 | from scipy.ndimage import uniform_filter 7 | 8 | 9 | def impute_missing(x, mask, radius=3, method='ns'): 10 | 11 | method_dict = { 12 | 'telea': cv.INPAINT_TELEA, 13 | 'ns': cv.INPAINT_NS} 14 | method = method_dict[method] 15 | 16 | x = x.copy() 17 | if x.dtype == np.float64: 18 | x = x.astype(np.float32) 19 | 20 | x[mask] = 0 21 | mask = mask.astype(np.uint8) 22 | 23 | expand_dim = np.ndim(x) == 2 24 | if expand_dim: 25 | x = x[..., np.newaxis] 26 | channels = [x[..., i] for i in range(x.shape[-1])] 27 | y = [cv.inpaint(c, mask, radius, method) for c in channels] 28 | y = np.stack(y, -1) 29 | if expand_dim: 30 | y = y[..., 0] 31 | 32 | return y 33 | 34 | 35 | def smoothen( 36 | x, size, kernel='gaussian', backend='cv', mode='mean', 37 | impute_missing_values=True, device='cuda'): 38 | 39 | if x.ndim == 3: 40 | expand_dim = False 41 | elif x.ndim == 2: 42 | expand_dim = True 43 | x = x[..., np.newaxis] 44 | else: 45 | raise ValueError('ndim must be 2 or 3') 46 | 47 | mask = np.isfinite(x).all(-1) 48 | if (~mask).any() and impute_missing_values: 49 | x = impute_missing(x, ~mask) 50 | 51 | if kernel == 'gaussian': 52 | sigma = size / 4 # approximate std of uniform filter 1/sqrt(12) 53 | truncate = 4.0 54 | winsize = np.ceil(sigma * truncate).astype(int) * 2 + 1 55 | if backend == 'cv': 56 | print(f'gaussian filter: winsize={winsize}, sigma={sigma}') 57 | y = cv.GaussianBlur( 58 | x, (winsize, winsize), sigmaX=sigma, sigmaY=sigma, 59 | borderType=cv.BORDER_REFLECT) 60 | elif backend == 'skimage': 61 | y = skimage.filters.gaussian( 62 | x, sigma=sigma, truncate=truncate, 63 | preserve_range=True, channel_axis=-1) 64 | else: 65 | raise ValueError('backend must be cv or skimage') 66 | elif kernel == 'uniform': 67 | if backend == 'cv': 68 | kernel = np.ones((size, size), np.float32) / size**2 69 | y = cv.filter2D( 70 | x, ddepth=-1, kernel=kernel, 71 | borderType=cv.BORDER_REFLECT) 72 | if y.ndim == 2: 73 | y = y[..., np.newaxis] 74 | elif backend == 'torch': 75 | assert isinstance(size, int) 76 | padding = size // 2 77 | size = size + 1 78 | 79 | pool_dict = { 80 | 'mean': nn.AvgPool2d( 81 | kernel_size=size, stride=1, padding=0), 82 | 'max': nn.MaxPool2d( 83 | kernel_size=size, stride=1, padding=0)} 84 | pool = pool_dict[mode] 85 | 86 | mod = nn.Sequential( 87 | nn.ReflectionPad2d(padding), 88 | pool) 89 | y = mod(torch.tensor(x, device=device).permute(2, 0, 1)) 90 | y = y.permute(1, 2, 0) 91 | y = y.cpu().detach().numpy() 92 | else: 93 | raise ValueError('backend must be cv or torch') 94 | else: 95 | raise ValueError('kernel must be gaussian or uniform') 96 | 97 | if not mask.all(): 98 | y[~mask] = np.nan 99 | 100 | if expand_dim and y.ndim == 3: 101 | y = y[..., 0] 102 | 103 | return y 104 | 105 | 
106 | def upscale(x, target_shape): 107 | mask = np.isfinite(x).all(tuple(range(2, x.ndim))) 108 | x = impute_missing(x, ~mask, radius=3) 109 | # TODO: Consider using pytorch with cuda to speed up 110 | # order: 0 == nearest neighbor, 1 == bilinear, 3 == bicubic 111 | dtype = x.dtype 112 | x = skimage.transform.resize( 113 | x, target_shape, order=3, preserve_range=True) 114 | x = x.astype(dtype) 115 | if not mask.all(): 116 | mask = skimage.transform.resize( 117 | mask.astype(float), target_shape, order=3, 118 | preserve_range=True) 119 | mask = mask > 0.5 120 | x[~mask] = np.nan 121 | return x 122 | 123 | 124 | def crop_image(img, extent, mode='edge', constant_values=None): 125 | extent = np.array(extent) 126 | pad = np.zeros((img.ndim, 2), dtype=int) 127 | for i, (lower, upper) in enumerate(extent): 128 | if lower < 0: 129 | pad[i][0] = 0 - lower 130 | if upper > img.shape[i]: 131 | pad[i][1] = upper - img.shape[i] 132 | if (pad != 0).any(): 133 | kwargs = {} 134 | if mode == 'constant' and constant_values is not None: 135 | kwargs['constant_values'] = constant_values 136 | img = np.pad(img, pad, mode=mode, **kwargs) 137 | extent += pad[:extent.shape[0], [0]] 138 | for i, (lower, upper) in enumerate(extent): 139 | img = img.take(range(lower, upper), axis=i) 140 | return img 141 | 142 | 143 | def get_disk_mask(radius, boundary_width=None): 144 | radius_ceil = np.ceil(radius).astype(int) 145 | locs = np.meshgrid( 146 | np.arange(-radius_ceil, radius_ceil+1), 147 | np.arange(-radius_ceil, radius_ceil+1), 148 | indexing='ij') 149 | locs = np.stack(locs, -1) 150 | distsq = (locs**2).sum(-1) 151 | isin = distsq <= radius**2 152 | if boundary_width is not None: 153 | isin *= distsq >= (radius-boundary_width)**2 154 | return isin 155 | 156 | 157 | def shrink_mask(x, size): 158 | size = size * 2 + 1 159 | x = uniform_filter(x.astype(float), size=size) 160 | x = np.isclose(x, 1) 161 | return x 162 | -------------------------------------------------------------------------------- /impute.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | import pytorch_lightning as pl 7 | from torch.optim import Adam 8 | from torch import nn 9 | import numpy as np 10 | 11 | from impute_by_basic import get_gene_counts, get_embeddings, get_locs 12 | from utils import read_lines, read_string, save_pickle 13 | from image import get_disk_mask 14 | from train import get_model as train_load_model 15 | # from reduce_dim import reduce_dim 16 | from visual import plot_matrix, plot_spot_masked_image 17 | 18 | 19 | class FeedForward(nn.Module): 20 | 21 | def __init__( 22 | self, n_inp, n_out, 23 | activation=None, residual=False): 24 | super().__init__() 25 | self.linear = nn.Linear(n_inp, n_out) 26 | if activation is None: 27 | # TODO: change activation to LeakyRelu(0.01) 28 | activation = nn.LeakyReLU(0.1, inplace=True) 29 | self.activation = activation 30 | self.residual = residual 31 | 32 | def forward(self, x, indices=None): 33 | if indices is None: 34 | y = self.linear(x) 35 | else: 36 | weight = self.linear.weight[indices] 37 | bias = self.linear.bias[indices] 38 | y = nn.functional.linear(x, weight, bias) 39 | y = self.activation(y) 40 | if self.residual: 41 | y = y + x 42 | return y 43 | 44 | 45 | class ELU(nn.Module): 46 | 47 | def __init__(self, alpha, beta): 48 | super().__init__() 49 | self.activation = nn.ELU(alpha=alpha, inplace=True) 50 | self.beta = beta 51 | 52 | 
def forward(self, x): 53 | return self.activation(x) + self.beta 54 | 55 | 56 | class ForwardSumModel(pl.LightningModule): 57 | 58 | def __init__(self, lr, n_inp, n_out): 59 | super().__init__() 60 | self.lr = lr 61 | self.net_lat = nn.Sequential( 62 | FeedForward(n_inp, 256), 63 | FeedForward(256, 256), 64 | FeedForward(256, 256), 65 | FeedForward(256, 256)) 66 | self.net_out = FeedForward( 67 | 256, n_out, 68 | activation=ELU(alpha=0.01, beta=0.01)) 69 | self.save_hyperparameters() 70 | 71 | def inp_to_lat(self, x): 72 | return self.net_lat.forward(x) 73 | 74 | def lat_to_out(self, x, indices=None): 75 | x = self.net_out.forward(x, indices) 76 | return x 77 | 78 | def forward(self, x, indices=None): 79 | x = self.inp_to_lat(x) 80 | x = self.lat_to_out(x, indices) 81 | return x 82 | 83 | def training_step(self, batch, batch_idx): 84 | x, y_mean = batch 85 | y_pred = self.forward(x) 86 | y_mean_pred = y_pred.mean(-2) 87 | # TODO: try l1 loss 88 | mse = ((y_mean_pred - y_mean)**2).mean() 89 | loss = mse 90 | self.log('rmse', mse**0.5, prog_bar=True) 91 | return loss 92 | 93 | def configure_optimizers(self): 94 | optimizer = Adam(self.parameters(), lr=self.lr) 95 | return optimizer 96 | 97 | 98 | class SpotDataset(Dataset): 99 | 100 | def __init__(self, x_all, y, locs, radius): 101 | super().__init__() 102 | mask = get_disk_mask(radius) 103 | x = get_patches_flat(x_all, locs, mask) 104 | isin = np.isfinite(x).all((-1, -2)) 105 | self.x = x[isin] 106 | self.y = y[isin] 107 | self.locs = locs[isin] 108 | self.size = x_all.shape[:2] 109 | self.radius = radius 110 | self.mask = mask 111 | 112 | def __len__(self): 113 | return len(self.x) 114 | 115 | def __getitem__(self, idx): 116 | return self.x[idx], self.y[idx] 117 | 118 | def show(self, channel_x, channel_y, prefix): 119 | mask = self.mask 120 | size = self.size 121 | locs = self.locs 122 | xs = self.x 123 | ys = self.y 124 | 125 | plot_spot_masked_image( 126 | locs=locs, values=xs[:, :, channel_x], mask=mask, size=size, 127 | outfile=f'{prefix}x{channel_x:04d}.png') 128 | 129 | plot_spot_masked_image( 130 | locs=locs, values=ys[:, channel_y], mask=mask, size=size, 131 | outfile=f'{prefix}y{channel_y:04d}.png') 132 | 133 | 134 | def get_disk(img, ij, radius): 135 | i, j = ij 136 | patch = img[i-radius:i+radius, j-radius:j+radius] 137 | disk_mask = get_disk_mask(radius) 138 | patch[~disk_mask] = 0.0 139 | return patch 140 | 141 | 142 | def get_patches_flat(img, locs, mask): 143 | shape = np.array(mask.shape) 144 | center = shape // 2 145 | r = np.stack([-center, shape-center], -1) # offset 146 | x_list = [] 147 | for s in locs: 148 | patch = img[ 149 | s[0]+r[0][0]:s[0]+r[0][1], 150 | s[1]+r[1][0]:s[1]+r[1][1]] 151 | if mask.all(): 152 | x = patch 153 | else: 154 | x = patch[mask] 155 | x_list.append(x) 156 | x_list = np.stack(x_list) 157 | return x_list 158 | 159 | 160 | def add_coords(embs): 161 | coords = np.stack(np.meshgrid( 162 | np.linspace(-1, 1, embs.shape[0]), 163 | np.linspace(-1, 1, embs.shape[1]), 164 | indexing='ij'), -1) 165 | coords = coords.astype(embs.dtype) 166 | mask = np.isfinite(embs).all(-1) 167 | coords[~mask] = np.nan 168 | embs = np.concatenate([embs, coords], -1) 169 | return embs 170 | 171 | 172 | # def reduce_embeddings(embs): 173 | # # cls features 174 | # cls, __ = reduce_dim(embs[..., :192], 0.99) 175 | # # sub features 176 | # sub, __ = reduce_dim(embs[..., 192:-3], 0.90) 177 | # rgb = embs[..., -3:] 178 | # embs = np.concatenate([cls, sub, rgb], -1) 179 | # return embs 180 | 181 | 182 | def get_data(prefix): 
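    # Returns three aligned objects (helpers live in impute_by_basic.py):
    #   embs: [H x W x D] histology feature image (ViT cls/sub features + RGB)
    #   cnts: spots-by-genes count table, restricted to gene-names.txt
    #   locs: spot centers as (i, j) pixel indices in the embedding image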
183 |     gene_names = read_lines(f'{prefix}gene-names.txt')
184 |     cnts = get_gene_counts(prefix)
185 |     cnts = cnts[gene_names]
186 |     embs = get_embeddings(prefix)
187 |     # embs = embs[..., :192]  # use high-level features only
188 |     # embs = reduce_embeddings(embs)
189 |     locs = get_locs(prefix, target_shape=embs.shape[:2])
190 |     # embs = add_coords(embs)
191 |     return embs, cnts, locs
192 | 
193 | 
194 | def get_model_kwargs(kwargs):
195 |     return get_model(**kwargs)
196 | 
197 | 
198 | def get_model(
199 |         x, y, locs, radius, prefix, batch_size, epochs, lr,
200 |         load_saved=False, device='cuda'):
201 | 
202 |     print('x:', x.shape, ', y:', y.shape)
203 | 
204 |     x = x.copy()
205 | 
206 |     dataset = SpotDataset(x, y, locs, radius)
207 |     dataset.show(
208 |         channel_x=0, channel_y=0,
209 |         prefix=f'{prefix}training-data-plots/')
210 |     model = train_load_model(
211 |         model_class=ForwardSumModel,
212 |         model_kwargs=dict(
213 |             n_inp=x.shape[-1],
214 |             n_out=y.shape[-1],
215 |             lr=lr),
216 |         dataset=dataset, prefix=prefix,
217 |         batch_size=batch_size, epochs=epochs,
218 |         load_saved=load_saved, device=device)
219 |     model.eval()
220 |     if device == 'cuda':
221 |         torch.cuda.empty_cache()
222 |     return model, dataset
223 | 
224 | 
225 | def normalize(embs, cnts):
226 | 
227 |     embs = embs.copy()
228 |     cnts = cnts.copy()
229 | 
230 |     # TODO: check if adjust_weights in extract_features can be skipped
231 |     embs_mean = np.nanmean(embs, (0, 1))
232 |     embs_std = np.nanstd(embs, (0, 1))
233 |     embs -= embs_mean
234 |     embs /= embs_std + 1e-12
235 | 
236 |     cnts_min = cnts.min(0)
237 |     cnts_max = cnts.max(0)
238 |     cnts -= cnts_min
239 |     cnts /= (cnts_max - cnts_min) + 1e-12
240 | 
241 |     return embs, cnts, (embs_mean, embs_std), (cnts_min, cnts_max)
242 | 
243 | 
244 | def show_results(x, names, prefix):
245 |     for name in ['CD19', 'MS4A1', 'ERBB2', 'GNAS']:
246 |         if name in names:
247 |             idx = np.where(names == name)[0][0]
248 |             plot_matrix(x[..., idx], prefix+name+'.png')
249 | 
250 | 
251 | def predict_single_out(model, z, indices, names, y_range):
252 |     z = torch.tensor(z, device=model.device)
253 |     y = model.lat_to_out(z, indices=indices)
254 |     y = y.cpu().detach().numpy()
255 |     # y[y < 0.01] = 0.0
256 |     # y[y > 1.0] = 1.0
257 |     y *= y_range[:, 1] - y_range[:, 0]
258 |     y += y_range[:, 0]
259 |     return y
260 | 
261 | 
262 | def predict_single_lat(model, x):
263 |     x = torch.tensor(x, device=model.device)
264 |     z = model.inp_to_lat(x)
265 |     z = z.cpu().detach().numpy()
266 |     return z
267 | 
268 | 
269 | # def cluster_lat(x, prefix, device='cuda'):
270 | #     x_minor = x
271 | #     x_major = smoothen(
272 | #         x_minor, size=8, method='cnn', mode='mean',
273 | #         device=device)
274 | #     labels = cluster_hierarchical(
275 | #         x_major.transpose(2, 0, 1), x_minor.transpose(2, 0, 1),
276 | #         method='km', n_clusters=10)
277 | #     # x = reduce_dim(x, method='pca', n_components=0.95)[0]
278 | #     # labels_raw = cluster(
279 | #     #     x.transpose(2, 0, 1), method='km', n_clusters=10)[0]
280 | #     # labels_cls = relabel_small_connected(labels_raw, min_size=1000)
281 | #     # labels_con = cluster_connected(labels_cls)
282 | #     # labels = np.stack([labels_cls, labels_con], -1)
283 | #     plot_labels(labels[..., :2], prefix+'clusters-genes.png')
284 | #     save_pickle(labels, prefix+'clusters-genes.pickle')
285 | #     return labels
286 | 
287 | 
288 | def predict(
289 |         model_states, x_batches, name_list, y_range, prefix,
290 |         device='cuda'):
291 | 
292 |     # states: different initial values for training
293 |     # batches: subsets of observations
294 |     # groups: subsets of outcomes
295 | 
296 | 
batch_size_outcome = 100 297 | 298 | model_states = [mod.to(device) for mod in model_states] 299 | 300 | # get features of second last layer 301 | z_states_batches = [ 302 | [predict_single_lat(mod, x_bat) for mod in model_states] 303 | for x_bat in x_batches] 304 | z_point = np.concatenate([ 305 | np.median(z_states, 0) 306 | for z_states in z_states_batches]) 307 | z_dict = dict(cls=z_point.transpose(2, 0, 1)) 308 | save_pickle( 309 | z_dict, 310 | prefix+'embeddings-gene.pickle') 311 | del z_point 312 | 313 | # predict and save y by batches in outcome dimension 314 | idx_list = np.arange(len(name_list)) 315 | n_groups_outcome = len(idx_list) // batch_size_outcome + 1 316 | idx_groups = np.array_split(idx_list, n_groups_outcome) 317 | for idx_grp in idx_groups: 318 | name_grp = name_list[idx_grp] 319 | y_ran = y_range[idx_grp] 320 | y_grp = np.concatenate([ 321 | np.median([ 322 | predict_single_out(mod, z, idx_grp, name_grp, y_ran) 323 | for mod, z in zip(model_states, z_states)], 0) 324 | for z_states in z_states_batches]) 325 | for i, name in enumerate(name_grp): 326 | save_pickle(y_grp[..., i], f'{prefix}cnts-super/{name}.pickle') 327 | 328 | 329 | def impute( 330 | embs, cnts, locs, radius, epochs, batch_size, prefix, 331 | n_states=1, load_saved=False, device='cuda', n_jobs=1): 332 | 333 | names = cnts.columns 334 | cnts = cnts.to_numpy() 335 | cnts = cnts.astype(np.float32) 336 | 337 | __, cnts, __, (cnts_min, cnts_max) = normalize(embs, cnts) 338 | 339 | # mask = np.isfinite(embs).all(-1) 340 | # embs[~mask] = 0.0 341 | 342 | kwargs_list = [ 343 | dict( 344 | x=embs, y=cnts, locs=locs, radius=radius, 345 | batch_size=batch_size, epochs=epochs, lr=1e-4, 346 | prefix=f'{prefix}states/{i:02d}/', 347 | load_saved=load_saved, device=device) 348 | for i in range(n_states)] 349 | 350 | if n_jobs is None or n_jobs < 1: 351 | n_jobs = n_states 352 | if n_jobs == 1: 353 | out_list = [get_model_kwargs(kwargs) for kwargs in kwargs_list] 354 | else: 355 | with multiprocessing.Pool(processes=n_jobs) as pool: 356 | out_list = pool.map(get_model_kwargs, kwargs_list) 357 | 358 | model_list = [out[0] for out in out_list] 359 | dataset_list = [out[1] for out in out_list] 360 | mask_size = dataset_list[0].mask.sum() 361 | 362 | # embs[~mask] = np.nan 363 | cnts_range = np.stack([cnts_min, cnts_max], -1) 364 | cnts_range /= mask_size 365 | 366 | batch_size_row = 50 367 | n_batches_row = embs.shape[0] // batch_size_row + 1 368 | embs_batches = np.array_split(embs, n_batches_row) 369 | del embs 370 | predict( 371 | model_states=model_list, x_batches=embs_batches, 372 | name_list=names, y_range=cnts_range, 373 | prefix=prefix, device=device) 374 | # show_results(cnts_pred, names, prefix) 375 | 376 | 377 | def get_args(): 378 | parser = argparse.ArgumentParser() 379 | parser.add_argument('prefix', type=str) 380 | parser.add_argument('--epochs', type=int, default=None) # e.g. 
400
381 |     parser.add_argument('--n-states', type=int, default=5)
382 |     parser.add_argument('--device', type=str, default='cuda')
383 |     parser.add_argument('--n-jobs', type=int, default=1)
384 |     parser.add_argument('--load-saved', action='store_true')
385 |     args = parser.parse_args()
386 |     return args
387 | 
388 | 
389 | def main():
390 |     args = get_args()
391 |     embs, cnts, locs = get_data(args.prefix)
392 | 
393 | 
394 |     factor = 16
395 |     radius = int(read_string(f'{args.prefix}radius.txt'))
396 |     radius = radius / factor
397 | 
398 |     n_train = cnts.shape[0]
399 |     batch_size = min(128, n_train//16)
400 | 
401 |     impute(
402 |             embs=embs, cnts=cnts, locs=locs, radius=radius,
403 |             epochs=args.epochs, batch_size=batch_size,
404 |             n_states=args.n_states, prefix=args.prefix,
405 |             load_saved=args.load_saved,
406 |             device=args.device, n_jobs=args.n_jobs)
407 | 
408 | 
409 | if __name__ == '__main__':
410 |     # torch.multiprocessing.set_start_method('spawn')
411 |     main()
412 | 
--------------------------------------------------------------------------------
/impute_by_basic.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | 
4 | import numpy as np
5 | from sklearn.neighbors import NearestNeighbors
6 | from sklearn.neural_network import MLPRegressor
7 | import matplotlib.pyplot as plt
8 | 
9 | from utils import (
10 |         load_tsv, load_pickle, save_pickle, load_image, save_image,
11 |         read_lines)
12 | from image import smoothen
13 | 
14 | 
15 | class NearestNeighborsRegressor():
16 | 
17 |     def __init__(self, n_neighbors, weights):
18 |         self.nbrs = NearestNeighbors(n_neighbors=n_neighbors)
19 |         self.weights = weights
20 | 
21 |     def fit(self, x):
22 |         self.nbrs.fit(x)
23 | 
24 |     def predict_x(self, x):
25 |         # get neighbor indices and weights
26 |         distances, indices = self.nbrs.kneighbors(x)
27 |         if self.weights == 'uniform':
28 |             wts = np.ones_like(distances)
29 |         elif self.weights == 'distance':
30 |             wts = 1.0 / (distances + 1e-12)
31 |         else:
32 |             raise ValueError('Weight function not recognized')
33 |         wts /= wts.sum(1, keepdims=True)
34 |         self.wts = wts
35 |         self.indices = indices
36 | 
37 |     def predict_y(self, y):
38 |         assert y.ndim == 2
39 |         y_neighbors = y[self.indices]
40 |         y_mean = (y_neighbors * self.wts[..., np.newaxis]).sum(1)
41 |         y_diffsq = (y_neighbors - y_mean[..., np.newaxis, :])**2
42 |         y_variance = (y_diffsq * self.wts[..., np.newaxis]).sum(1)
43 |         return y_mean, y_variance
44 | 
45 | 
46 | # def draw_overlay(locs, embs, radius, outfile):
47 | #     em = embs[..., :3]
48 | #     em -= np.nanmin(em)
49 | #     em /= np.nanmax(em)
50 | #     save_image(
51 | #             draw_spots(
52 | #                 locs,
53 | #                 (em*255).astype(np.uint8),
54 | #                 rad=radius, color=128),
55 | #             outfile)
56 | 
57 | 
58 | def log_normal(mean, variance):
59 |     mean_new = np.exp(mean + variance * 0.5)
60 |     variance_new = (
61 |             (np.exp(variance) - 1)
62 |             * np.exp(mean * 2 + variance))
63 |     return mean_new, variance_new
64 | 
65 | 
66 | def impute_by_neighbors(
67 |         y_train, x_train, x_test, prefix,
68 |         n_neighbors=5, weights='uniform'):
69 |     y_train = y_train.astype(np.float32)
70 | 
71 |     model = NearestNeighborsRegressor(
72 |             n_neighbors=n_neighbors, weights=weights)
73 |     model.fit(x=x_train)
74 |     mask = np.isfinite(x_test).all(-1)
75 |     model.predict_x(x_test[mask])
76 | 
77 |     for name, y_tra in y_train.items():
78 |         y_tra = y_tra.to_numpy()
79 |         y_mea, y_var = model.predict_y(y_tra[..., np.newaxis])
80 |         y_mea_arr = np.full(
81 |                 x_test.shape[:-1], np.nan, dtype=y_tra.dtype)
82 |
y_var_arr = np.full( 83 | x_test.shape[:-1], np.nan, dtype=y_tra.dtype) 84 | y_mea_arr[mask] = y_mea[..., 0] 85 | y_var_arr[mask] = y_var[..., 0] 86 | save_pickle(y_mea_arr, f'{prefix}mean/{name}.pickle') 87 | save_pickle(y_var_arr, f'{prefix}variance/{name}.pickle') 88 | 89 | 90 | def impute_by_neural(y_train, x_train, x_test, prefix, **kwargs): 91 | model = MLPRegressor( 92 | hidden_layer_sizes=(128, 128, 128, 128), activation='relu', 93 | learning_rate='adaptive', learning_rate_init=1e-3, 94 | batch_size=100, max_iter=10000, tol=1e-6, 95 | alpha=1e-2, 96 | random_state=0, verbose=True) 97 | 98 | x_train = x_train.copy() 99 | x_test = x_test.copy() 100 | y_train = y_train.copy() 101 | 102 | names = y_train.columns 103 | y_train = y_train.to_numpy() 104 | y_train = y_train.astype(np.float32) 105 | 106 | x_mean = x_train.mean(0) 107 | x_std = x_train.std(0) 108 | x_train -= x_mean 109 | x_train /= x_std + 1e-12 110 | 111 | y_min = y_train.min(0) 112 | y_max = y_train.max(0) 113 | y_train -= y_min 114 | y_train /= (y_max - y_min) + 1e-12 115 | 116 | model.fit(x_train, y_train) 117 | 118 | x_test = x_test - x_mean 119 | x_test = x_test / (x_std + 1e-12) 120 | mask = np.isfinite(x_test).all(-1) 121 | y_test = model.predict(x_test[mask]) 122 | y_test = np.clip(y_test, 0, 1) 123 | # threshold = 0.1 124 | # y_test[y_test < threshold] = 0.0 125 | 126 | y_test *= y_max - y_min 127 | y_test += y_min 128 | 129 | y_test_arr = np.full( 130 | (x_test.shape[:-1] + y_test.shape[-1:]), 131 | np.nan, dtype=y_test.dtype) 132 | y_test_arr[mask] = y_test 133 | 134 | idx = np.where(names == 'MS4A1')[0][0] 135 | aa = y_test_arr[..., idx].copy() 136 | aa -= np.nanmin(aa) 137 | aa /= np.nanmax(aa) 138 | cmap = plt.get_cmap('turbo') 139 | img = cmap(aa)[..., :3] 140 | save_image((img * 255).astype(np.uint8), 'a.png') 141 | 142 | return y_test_arr, names 143 | 144 | 145 | def impute(y_train, x_train, x_test, prefix, method, **kwargs): 146 | 147 | if method == 'neighbors': 148 | impute_by_neighbors(y_train, x_train, x_test, prefix, **kwargs) 149 | elif method == 'neural': 150 | impute_by_neural(y_train, x_train, x_test, prefix, **kwargs) 151 | else: 152 | raise ValueError('Method not recognized.') 153 | 154 | 155 | def get_locs(prefix, target_shape=None): 156 | 157 | locs = load_tsv(f'{prefix}locs.tsv') 158 | 159 | # change xy coordinates to ij coordinates 160 | locs = np.stack([locs['y'], locs['x']], -1) 161 | 162 | # match coordinates of embeddings and spot locations 163 | if target_shape is not None: 164 | wsi = load_image(f'{prefix}he.jpg') 165 | current_shape = np.array(wsi.shape[:2]) 166 | rescale_factor = current_shape // target_shape 167 | locs = locs.astype(float) 168 | locs /= rescale_factor 169 | 170 | # find the nearest pixel 171 | locs = locs.round().astype(int) 172 | 173 | return locs 174 | 175 | 176 | def get_gene_counts(prefix, reorder_genes=True): 177 | cnts = load_tsv(f'{prefix}cnts.tsv') 178 | if reorder_genes: 179 | order = cnts.var().to_numpy().argsort()[::-1] 180 | cnts = cnts.iloc[:, order] 181 | return cnts 182 | 183 | 184 | def get_embeddings(prefix): 185 | embs = load_pickle(f'{prefix}embeddings-hist.pickle') 186 | embs = np.concatenate([embs['cls'], embs['sub'], embs['rgb']]) 187 | embs = embs.transpose(1, 2, 0) 188 | return embs 189 | 190 | 191 | def smoothen_batch(embs, **kwargs): 192 | embs_batches = np.array_split(embs, 8, axis=-1) 193 | embs_smooth = np.concatenate( 194 | [smoothen(e, **kwargs) for e in embs_batches], -1) 195 | return embs_smooth 196 | 197 | 198 | def 
get_training_data(prefix, gene_names, spot_radius, log_counts=False): 199 | # get targets (gene counts) 200 | cnts = get_gene_counts(prefix) 201 | cnts = cnts[gene_names] 202 | # transform gene counts to log scale 203 | if log_counts: 204 | cnts = np.log(1 + cnts) 205 | 206 | # get features (histology embeddings) 207 | embs = get_embeddings(prefix) 208 | locs = get_locs(prefix, target_shape=embs.shape[:2]) 209 | embs_agg = smoothen_batch( 210 | embs, size=spot_radius, method='cnn', fill_missing=True) 211 | embs_spots = embs_agg[locs[:, 0], locs[:, 1]] 212 | return embs_spots, cnts, embs_agg 213 | 214 | 215 | def main(): 216 | prefix = sys.argv[1] # e.g. 'data/her2st/B1/' 217 | spot_radius = 10 218 | 219 | cache_file = prefix + 'a.pickle' 220 | if os.path.exists(cache_file): 221 | embs_train, cnts_train, embs_test, embs_agg = load_pickle(cache_file) 222 | else: 223 | gene_names = read_lines(f'{prefix}gene-names.txt') 224 | embs_train, cnts_train, embs_agg = get_training_data( 225 | prefix, gene_names, spot_radius=spot_radius) 226 | embs_test = get_embeddings(prefix) 227 | save_pickle( 228 | (embs_train, cnts_train, embs_test, embs_agg), cache_file) 229 | 230 | mask = np.isfinite(embs_test).all(-1) 231 | embs_test[mask] = embs_agg[mask] 232 | # super-resolution imputation 233 | impute( 234 | y_train=cnts_train, x_train=embs_train, x_test=embs_test, 235 | method='neural', prefix=prefix+'cnts-super/') 236 | 237 | 238 | if __name__ == '__main__': 239 | main() 240 | -------------------------------------------------------------------------------- /marker_score.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | from einops import reduce 5 | 6 | from utils import read_lines, load_pickle, save_pickle, load_mask 7 | from visual import plot_matrix 8 | 9 | 10 | def compute_score(cnts, mask=None, factor=None): 11 | if mask is not None: 12 | cnts[~mask] = np.nan 13 | 14 | if factor is not None: 15 | cnts = reduce( 16 | cnts, '(h0 h1) (w0 w1) c -> h0 w0 c', 'mean', 17 | h1=factor, w1=factor) 18 | 19 | cnts -= np.nanmin(cnts, (0, 1)) 20 | cnts /= np.nanmax(cnts, (0, 1)) + 1e-12 21 | score = cnts.mean(-1) 22 | 23 | return score 24 | 25 | 26 | def get_marker_score(prefix, genes_marker, threshold=None, factor=1): 27 | 28 | genes = read_lines(prefix+'gene-names.txt') 29 | mask = load_mask(prefix+'mask-small.png', verbose=False) 30 | 31 | gene_names = np.array(list(set(genes_marker).intersection(genes))) 32 | 33 | if len(gene_names) < len(genes_marker): 34 | print('Genes not found:') 35 | print(set(genes_marker).difference(set(gene_names))) 36 | 37 | cnts = [ 38 | load_pickle( 39 | f'{prefix}cnts-super/{gname}.pickle', verbose=False) 40 | for gname in gene_names] 41 | cnts = np.stack(cnts, -1) 42 | 43 | if threshold is not None: 44 | isin = np.nanmax(cnts, (0, 1)) >= threshold 45 | if not isin.all(): 46 | print('Genes that do not pass the threshold:') 47 | print(gene_names[~isin]) 48 | cnts = cnts[:, :, isin] 49 | gene_names = gene_names[isin] 50 | 51 | score = compute_score(cnts, mask=mask, factor=factor) 52 | return score 53 | 54 | 55 | def get_args(): 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('prefix_inp', type=str) 58 | parser.add_argument('genes_marker', type=str) 59 | parser.add_argument('prefix_out', type=str) 60 | parser.add_argument('--threshold', type=float, default=1e-3) 61 | args = parser.parse_args() 62 | return args 63 | 64 | 65 | def main(): 66 | 67 | args = get_args() 68 | 69 | # compute 
marker score
70 |     genes_marker = read_lines(args.genes_marker)
71 |     score = get_marker_score(args.prefix_inp, genes_marker, args.threshold)
72 |     save_pickle(score, args.prefix_out+'.pickle')
73 | 
74 |     # visualize marker score
75 |     score = np.clip(
76 |             score, np.nanquantile(score, 0.001),
77 |             np.nanquantile(score, 0.999))
78 | 
79 |     plot_matrix(score, args.prefix_out+'.png', white_background=True)
80 | 
81 | 
82 | if __name__ == '__main__':
83 |     main()
84 | 
--------------------------------------------------------------------------------
/pixannot.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import re
3 | 
4 | import numpy as np
5 | 
6 | from utils import save_pickle, write_lines, load_tsv
7 | from visual import plot_matrix, plot_labels, plot_label_masks
8 | from marker_score import get_marker_score
9 | 
10 | 
11 | def adjust_temperature(probs, temperature):
12 |     logits = np.log(probs)
13 |     logits /= temperature
14 |     probs = np.exp(logits)
15 |     probs = probs / probs.sum(-1, keepdims=True)
16 |     return probs
17 | 
18 | 
19 | def sample_from_scores(x, temperature=0.05):
20 |     probs_raw = x / x.sum(-1, keepdims=True)
21 |     probs = adjust_temperature(probs_raw, temperature=temperature)
22 |     z = np.random.rand(*probs.shape[:-1], 1)
23 |     threshs = np.cumsum(probs, -1)
24 |     labels = (z > threshs).sum(-1)
25 |     return labels
26 | 
27 | 
28 | def get_scores(prefix, marker_file):
29 | 
30 |     df = load_tsv(marker_file, index=False)
31 |     df = df[['gene', 'label']]
32 |     labels = np.sort(df['label'].unique()).tolist()
33 |     scores = []
34 |     for lab in labels:
35 |         isin = (df['label'] == lab).to_numpy()
36 |         gene_names = df['gene'][isin].to_numpy()
37 |         sco = get_marker_score(prefix, gene_names)
38 |         scores.append(sco)
39 |     scores = np.stack(scores, -1)
40 |     return scores, labels
41 | 
42 | 
43 | def predict(scores, sample=False):
44 |     mask = np.isfinite(scores).all(-1)
45 |     if sample:
46 |         labels = sample_from_scores(scores, temperature=0.05)
47 |         labels[~mask] = -1
48 |     else:
49 |         labels = np.full(mask.shape, -1)
50 |         labels[mask] = scores[mask].argmax(-1)
51 |     return labels
52 | 
53 | 
54 | def clean(s):
55 |     s = re.sub('[^0-9a-zA-Z]+', '-', s)
56 |     s = s.lower()
57 |     return s
58 | 
59 | 
60 | def plot_annot(labels, confidence, threshold, label_names, prefix):
61 | 
62 |     labels = labels.copy()
63 | 
64 |     # treat low-confidence predictions as unclassified
65 |     labels[labels >= 0] += 1
66 |     labels[confidence < threshold] = 0
67 |     lab_names = ['Unclassified'] + label_names
68 | 
69 |     write_lines(lab_names, f'{prefix}label-names.txt')
70 |     save_pickle(labels, f'{prefix}labels.pickle')
71 |     plot_labels(
72 |             labels, f'{prefix}labels.png',
73 |             white_background=True,
74 |             cmap='tab10')
75 |     plot_labels(
76 |             labels, f'{prefix}labels-altcmap.png',
77 |             white_background=True,
78 |             cmap='Set3')
79 |     lab_names_clean = [clean(lname) for lname in lab_names]
80 |     plot_label_masks(
81 |             labels, f'{prefix}masks/',
82 |             names=lab_names_clean,
83 |             white_background=True)
84 | 
85 | 
86 | def main():
87 | 
88 |     np.random.seed(0)
89 | 
90 |     prefix_inp = sys.argv[1]  # e.g. data/her2st/H123/
91 |     marker_file = sys.argv[2]  # e.g. data/markers/cell-type.tsv
92 |     prefix_out = sys.argv[3]  # e.g.
data/her2st/H123/cell-types/ 93 | 94 | scores, lab_names = get_scores(prefix_inp, marker_file) 95 | 96 | for x, lname in zip(scores.transpose(2, 0, 1), lab_names): 97 | plot_matrix( 98 | x, f'{prefix_out}scores/{clean(lname)}.png', 99 | white_background=True) 100 | 101 | confidence = scores.max(-1) 102 | plot_matrix( 103 | confidence, f'{prefix_out}confidence.png', 104 | white_background=True) 105 | 106 | labels = predict(scores) 107 | 108 | for threshold in [0.01, 0.05, 0.10, 0.20]: 109 | plot_annot( 110 | labels, confidence, threshold, lab_names, 111 | f'{prefix_out}threshold{int(threshold*1000):03d}/') 112 | 113 | 114 | if __name__ == '__main__': 115 | main() 116 | -------------------------------------------------------------------------------- /plot_imputed.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | from utils import load_pickle, save_image, read_lines, load_image 7 | # from visual import cmap_turbo_truncated 8 | 9 | 10 | def plot_super( 11 | x, outfile, underground=None, truncate=None): 12 | 13 | x = x.copy() 14 | mask = np.isfinite(x) 15 | 16 | if truncate is not None: 17 | x -= np.nanmean(x) 18 | x /= np.nanstd(x) + 1e-12 19 | x = np.clip(x, truncate[0], truncate[1]) 20 | 21 | x -= np.nanmin(x) 22 | x /= np.nanmax(x) + 1e-12 23 | 24 | cmap = plt.get_cmap('turbo') 25 | # cmap = cmap_turbo_truncated 26 | if underground is not None: 27 | under = underground.mean(-1, keepdims=True) 28 | under -= under.min() 29 | under /= under.max() + 1e-12 30 | 31 | img = cmap(x)[..., :3] 32 | if underground is not None: 33 | img = img * 0.5 + under * 0.5 34 | img[~mask] = 1.0 35 | img = (img * 255).astype(np.uint8) 36 | save_image(img, outfile) 37 | 38 | 39 | def main(): 40 | 41 | prefix = sys.argv[1] # e.g. 
'data/her2st/B1/' 42 | gene_names = read_lines(f'{prefix}gene-names.txt') 43 | mask = load_image(f'{prefix}mask-small.png') > 0 44 | 45 | for gn in gene_names: 46 | cnts = load_pickle(f'{prefix}cnts-super/{gn}.pickle') 47 | cnts[~mask] = np.nan 48 | plot_super(cnts, f'{prefix}cnts-super-plots/{gn}.png') 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /plot_spots.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | from einops import reduce 5 | 6 | from utils import load_image, load_tsv, read_lines, read_string 7 | from visual import plot_spots 8 | 9 | 10 | # def plot_spots(cnts, locs, underground, gene_names, radius, prefix): 11 | # under_weight = 0.2 12 | # cmap = plt.get_cmap('turbo') 13 | # under = underground.mean(-1, keepdims=True) 14 | # under = np.tile(under, 3) 15 | # under -= under.min() 16 | # under /= under.max() + 1e-12 17 | # for k, name in enumerate(gene_names): 18 | # x = cnts[:, k] 19 | # x = x - x.min() 20 | # x = x / (x.max() + 1e-12) 21 | # img = under * under_weight 22 | # for u, ij in zip(x, locs): 23 | # lower = np.clip(ij - radius, 0, None) 24 | # upper = np.clip(ij + radius, None, img.shape[:2]) 25 | # color = np.array(cmap(u)[:3]) * (1 - under_weight) 26 | # img[lower[0]:upper[0], lower[1]:upper[1]] += color 27 | # img = (img * 255).astype(np.uint8) 28 | # save_image(img, f'{prefix}{name}.png') 29 | 30 | 31 | def plot_spots_multi( 32 | cnts, locs, gene_names, radius, img, prefix, 33 | disk_mask=True): 34 | for i, gname in enumerate(gene_names): 35 | ct = cnts[:, i] 36 | outfile = f'{prefix}{gname}.png' 37 | plot_spots( 38 | img=img, cnts=ct, locs=locs, radius=radius, 39 | cmap='turbo', weight=1.0, 40 | disk_mask=disk_mask, 41 | outfile=outfile) 42 | 43 | 44 | def main(): 45 | prefix = sys.argv[1] # e.g. 
'data/her2st/B1/' 46 | factor = 16 47 | 48 | infile_cnts = f'{prefix}cnts.tsv' 49 | infile_locs = f'{prefix}locs.tsv' 50 | infile_img = f'{prefix}he.jpg' 51 | infile_genes = f'{prefix}gene-names.txt' 52 | infile_radius = f'{prefix}radius.txt' 53 | 54 | # load data 55 | cnts = load_tsv(infile_cnts) 56 | locs = load_tsv(infile_locs) 57 | assert (cnts.index == locs.index).all() 58 | spot_radius = int(read_string(infile_radius)) 59 | img = load_image(infile_img) 60 | 61 | if img.dtype == bool: 62 | img = img.astype(np.uint8) * 255 63 | if img.ndim == 2: 64 | img = np.tile(img[..., np.newaxis], 3) 65 | 66 | # select genes 67 | gene_names = read_lines(infile_genes) 68 | cnts = cnts[gene_names] 69 | cnts = cnts.to_numpy() 70 | 71 | # recale image 72 | locs = locs.astype(float) 73 | locs = np.stack([locs['y'], locs['x']], -1) 74 | locs /= factor 75 | locs = locs.round().astype(int) 76 | img = reduce( 77 | img.astype(float), '(h1 h) (w1 w) c -> h1 w1 c', 'mean', 78 | h=factor, w=factor).astype(np.uint8) 79 | 80 | # rescale spot 81 | spot_radius = np.round(spot_radius / factor).astype(int) 82 | 83 | # plot spot-level gene expression measurements 84 | plot_spots_multi( 85 | cnts=cnts, 86 | locs=locs, gene_names=gene_names, 87 | radius=spot_radius, disk_mask=True, 88 | img=img, prefix=prefix+'spots/') 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | from einops import reduce 5 | 6 | from utils import load_image, save_image, load_mask 7 | from image import crop_image 8 | 9 | 10 | def adjust_margins(img, pad, pad_value=None): 11 | extent = np.stack([[0, 0], img.shape[:2]]).T 12 | # make size divisible by pad without changing coords 13 | remainder = (extent[:, 1] - extent[:, 0]) % pad 14 | complement = (pad - remainder) % pad 15 | extent[:, 1] += complement 16 | if pad_value is None: 17 | mode = 'edge' 18 | else: 19 | mode = 'constant' 20 | img = crop_image( 21 | img, extent, mode=mode, constant_values=pad_value) 22 | return img 23 | 24 | 25 | def reduce_mask(mask, factor): 26 | mask = reduce( 27 | mask.astype(np.float32), 28 | '(h0 h1) (w0 w1) -> h0 w0', 'mean', 29 | h1=factor, w1=factor) > 0.5 30 | return mask 31 | 32 | 33 | def get_args(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('prefix', type=str) 36 | parser.add_argument('--image', action='store_true') 37 | parser.add_argument('--mask', action='store_true') 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | def main(): 43 | 44 | pad = 256 45 | args = get_args() 46 | 47 | if args.image: 48 | # load histology image 49 | img = load_image(args.prefix+'he-scaled.jpg') 50 | # pad image with white to make dimension divisible by 256 51 | img = adjust_margins(img, pad=pad, pad_value=255) 52 | # save histology image 53 | save_image(img, f'{args.prefix}he.jpg') 54 | 55 | if args.mask: 56 | # load tissue mask 57 | mask = load_mask(args.prefix+'mask-scaled.png') 58 | # pad mask with False to make dimension divisible by 256 59 | mask = adjust_margins(mask, pad=pad, pad_value=mask.min()) 60 | # save tissue mask 61 | save_image(mask, f'{args.prefix}mask.png') 62 | # save_image(~mask, f'{args.prefix}mask-whitebg.png') 63 | mask = reduce_mask(mask, factor=16) 64 | save_image(mask, f'{args.prefix}mask-small.png') 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | 
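
The padding step in `preprocess.py` above only ever grows the image at the bottom and right, so existing pixel (and spot) coordinates stay valid after preprocessing. Below is a minimal sketch of the same logic, assuming only NumPy; `crop_image` from `image.py` is not shown in this listing, so `np.pad` stands in for it, and `adjust_margins_sketch` is a hypothetical name used only for illustration:

```python
import numpy as np

def adjust_margins_sketch(img, pad, pad_value=None):
    # Rows/columns appended so height and width become multiples of
    # `pad`; nothing is added at the top or left, so pixel coordinates
    # are unchanged. Mimics preprocess.adjust_margins, with np.pad in
    # place of image.crop_image (not shown in this listing).
    h, w = img.shape[:2]
    grow = [(0, (pad - h % pad) % pad), (0, (pad - w % pad) % pad)]
    grow += [(0, 0)] * (img.ndim - 2)  # leave the channel axis alone
    if pad_value is None:
        return np.pad(img, grow, mode='edge')
    return np.pad(img, grow, mode='constant', constant_values=pad_value)

img = np.zeros((300, 500, 3), dtype=np.uint8)
print(adjust_margins_sketch(img, pad=256, pad_value=255).shape)
# (512, 512, 3): both sides rounded up to the next multiple of 256
```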
-------------------------------------------------------------------------------- /reduce_dim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.decomposition import PCA 3 | from umap import UMAP 4 | 5 | 6 | def reduce_dim( 7 | x, n_components, method='pca', 8 | pre_normalize=False, post_normalize=False): 9 | 10 | if n_components >= 1: 11 | n_components = int(n_components) 12 | 13 | isfin = np.isfinite(x).all(-1) 14 | if pre_normalize: 15 | x -= x[isfin].mean(0) 16 | x /= x[isfin].std(0) 17 | 18 | if method == 'pca': 19 | model = PCA(n_components=n_components) 20 | elif method == 'umap': 21 | model = UMAP( 22 | n_components=n_components, n_neighbors=20, min_dist=0.0, 23 | n_jobs=-1, random_state=0, verbose=True) 24 | else: 25 | raise ValueError(f'Method `{method}` not recognized') 26 | 27 | print(x[isfin].shape) 28 | u = model.fit_transform(x[isfin]) 29 | print('n_components:', u.shape[-1], '/', x.shape[-1]) 30 | if method == 'pca': 31 | print('pve:', model.explained_variance_ratio_.sum()) 32 | 33 | # order components by variance 34 | order = np.nanvar(u, axis=0).argsort()[::-1] 35 | u = u[:, order] 36 | # make all components have variance == 1 37 | if post_normalize: 38 | u -= u.mean(0) 39 | u /= u.std(0) 40 | z = np.full( 41 | isfin.shape + (u.shape[-1],), 42 | np.nan, dtype=np.float32) 43 | z[isfin] = u 44 | return z, model 45 | -------------------------------------------------------------------------------- /reorganize_imputed.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | from utils import read_lines, load_tsv, save_tsv 4 | 5 | 6 | def main(): 7 | prefix = sys.argv[1] 8 | gene_names = read_lines(prefix + 'gene-names.txt') 9 | dfs = { 10 | gn: load_tsv(f'{prefix}cnts-clustered/by-genes/{gn}.tsv') 11 | for gn in gene_names} 12 | columns = dfs[list(dfs.keys())[0]].columns 13 | results = {} 14 | for col in columns: 15 | results[col] = pd.DataFrame( 16 | {gn: df[col] for gn, df in dfs.items()}) 17 | results['sum'] = ( 18 | (results['mean'] * results['count']) 19 | .round().astype(int)) 20 | for key, val in results.items(): 21 | save_tsv(val, f'{prefix}cnts-clustered/by-clusters/{key}.tsv') 22 | 23 | 24 | if __name__ == '__main__': 25 | main() 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.25.2 2 | pillow==10.0.0 3 | pandas==2.1.0 4 | scikit-image==0.21.0 5 | opencv-python==4.8.0.76 6 | einops==0.6.1 7 | torch==2.0.1 8 | torchvision==0.15.2 9 | tomli==2.0.1 10 | pytorch-lightning==2.0.8 11 | matplotlib==3.7.2 12 | scikit-learn==1.3.1 13 | umap-learn==0.5.5 14 | seaborn==0.13.1 15 | -------------------------------------------------------------------------------- /rescale.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from time import time 4 | 5 | from skimage.transform import rescale 6 | import numpy as np 7 | 8 | from utils import ( 9 | load_image, save_image, read_string, write_string, 10 | load_tsv, save_tsv) 11 | 12 | 13 | def get_image_filename(prefix): 14 | file_exists = False 15 | for suffix in ['.jpg', '.png', '.tiff']: 16 | filename = prefix + suffix 17 | if os.path.exists(filename): 18 | file_exists = True 19 | break 20 | if not file_exists: 21 | raise FileNotFoundError('Image not found') 22 | return 
filename 23 | 24 | 25 | # def rescale_image(img, scale): 26 | # if img.ndim == 2: 27 | # img = rescale(img, scale, preserve_range=True) 28 | # elif img.ndim == 3: 29 | # channels = img.transpose(2, 0, 1) 30 | # channels = [rescale_image(c, scale) for c in channels] 31 | # img = np.stack(channels, -1) 32 | # else: 33 | # raise ValueError('Unrecognized image ndim') 34 | # return img 35 | 36 | 37 | def rescale_image(img, scale): 38 | if img.ndim == 2: 39 | scale = [scale, scale] 40 | elif img.ndim == 3: 41 | scale = [scale, scale, 1] 42 | else: 43 | raise ValueError('Unrecognized image ndim') 44 | img = rescale(img, scale, preserve_range=True) 45 | return img 46 | 47 | 48 | def get_args(): 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('prefix', type=str) 51 | parser.add_argument('--image', action='store_true') 52 | parser.add_argument('--mask', action='store_true') 53 | parser.add_argument('--locs', action='store_true') 54 | parser.add_argument('--radius', action='store_true') 55 | args = parser.parse_args() 56 | return args 57 | 58 | 59 | def main(): 60 | 61 | args = get_args() 62 | 63 | pixel_size_raw = float(read_string(args.prefix+'pixel-size-raw.txt')) 64 | pixel_size = float(read_string(args.prefix+'pixel-size.txt')) 65 | scale = pixel_size_raw / pixel_size 66 | 67 | if args.image: 68 | img = load_image(get_image_filename(args.prefix+'he-raw')) 69 | img = img.astype(np.float32) 70 | print(f'Rescaling image (scale: {scale:.3f})...') 71 | t0 = time() 72 | img = rescale_image(img, scale) 73 | print(int(time() - t0), 'sec') 74 | img = img.astype(np.uint8) 75 | save_image(img, args.prefix+'he-scaled.jpg') 76 | 77 | if args.mask: 78 | mask = load_image(args.prefix+'mask-raw.png') 79 | mask = mask > 0 80 | if mask.ndim == 3: 81 | mask = mask.any(2) 82 | print(f'Rescaling mask (scale: {scale:.3f})...') 83 | t0 = time() 84 | mask = rescale_image(mask.astype(np.float32), scale) 85 | print(int(time() - t0)) 86 | mask = mask > 0.5 87 | save_image(mask, args.prefix+'mask-scaled.png') 88 | 89 | if args.locs: 90 | locs = load_tsv(args.prefix+'locs-raw.tsv') 91 | locs = locs * scale 92 | locs = locs.round().astype(int) 93 | save_tsv(locs, args.prefix+'locs.tsv') 94 | 95 | if args.radius: 96 | radius = float(read_string(args.prefix+'radius-raw.txt')) 97 | radius = radius * scale 98 | radius = np.round(radius).astype(int) 99 | write_string(radius, args.prefix+'radius.txt') 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | prefix=$1 # e.g. 
data/demo/
5 | 
6 | device="cuda"  # "cuda" or "cpu"
7 | pixel_size=0.5  # desired pixel size for the whole analysis
8 | n_genes=1000  # number of most variable genes to impute
9 | 
10 | # preprocess histology image
11 | echo $pixel_size > ${prefix}pixel-size.txt
12 | python rescale.py ${prefix} --image
13 | python preprocess.py ${prefix} --image
14 | 
15 | # extract histology features
16 | python extract_features.py ${prefix} --device=${device}
17 | # # If you want to rerun feature extraction, you need to delete the existing results:
18 | # rm ${prefix}embeddings-hist-raw.pickle
19 | 
20 | # auto detect tissue mask
21 | # If you have a user-defined tissue mask, put it at `${prefix}mask-raw.png` and comment out the line below
22 | python get_mask.py ${prefix}embeddings-hist.pickle ${prefix}mask-small.png
23 | 
24 | # # segment image by histology features
25 | # python cluster.py --mask=${prefix}mask-small.png --n-clusters=10 ${prefix}embeddings-hist.pickle ${prefix}clusters-hist/
26 | # # # segment image by histology features without tissue mask
27 | # # python cluster.py ${prefix}embeddings-hist.pickle ${prefix}clusters-hist/unmasked/
28 | 
29 | # select most highly variable genes to predict
30 | # If you have a user-defined list of genes, put it at `${prefix}gene-names.txt` and comment out the line below
31 | python select_genes.py --n-top=${n_genes} "${prefix}cnts.tsv" "${prefix}gene-names.txt"
32 | 
33 | # predict super-resolution gene expression
34 | # rescale coordinates and spot radius
35 | python rescale.py ${prefix} --locs --radius
36 | 
37 | # train gene expression prediction model and predict at super-resolution
38 | python impute.py ${prefix} --epochs=400 --device=${device}  # train model from scratch
39 | # # If you want to retrain the model, you need to delete the existing model:
40 | # rm -r ${prefix}states
41 | 
42 | # visualize imputed gene expression
43 | python plot_imputed.py ${prefix}
44 | 
45 | # segment image by gene features
46 | python cluster.py --filter-size=8 --min-cluster-size=20 --n-clusters=10 --mask=${prefix}mask-small.png ${prefix}embeddings-gene.pickle ${prefix}clusters-gene/
47 | # # segment image without tissue mask
48 | # python cluster.py --filter-size=8 --min-cluster-size=20 ${prefix}embeddings-gene.pickle ${prefix}clusters-gene/unmasked/
49 | # # segment image without spatial smoothing
50 | # python cluster.py --mask=${prefix}mask-small.png ${prefix}embeddings-gene.pickle ${prefix}clusters-gene/unsmoothed/
51 | # python cluster.py ${prefix}embeddings-gene.pickle ${prefix}clusters-gene/unsmoothed/unmasked/
52 | 
53 | # differential analysis by clusters
54 | python aggregate_imputed.py ${prefix}
55 | python reorganize_imputed.py ${prefix}
56 | python differential.py ${prefix}
57 | 
58 | # visualize spot-level gene expression data
59 | python plot_spots.py ${prefix}
60 | 
61 | # # cell type inference
62 | # # see data/markers/cell-type-template.tsv for an example of a cell type reference panel
63 | # python pixannot.py ${prefix} data/markers/cell-type.tsv ${prefix}markers/cell-type/
64 | # cp -r ${prefix}markers/cell-type/threshold010/* ${prefix}markers/cell-type/
65 | # python enrich.py ${prefix}clusters-gene/ ${prefix}markers/cell-type/
66 | 
67 | # # user-defined tissue structure signature scores
68 | # # see data/markers/signature-score-template.txt for an example of a signature score reference panel
69 | # python marker_score.py ${prefix} data/markers/signature-score.txt ${prefix}markers/signature-score
70 |
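
For the commented-out cell-type inference step, `pixannot.get_scores` above expects the reference panel to be a TSV with a `gene` column and a `label` column, one row per (marker gene, cell type) pair. A hedged sketch of writing such a panel with pandas follows; the gene symbols are illustrative placeholders, not a curated panel:

```python
import pandas as pd

# Column names must be `gene` and `label` to match pixannot.get_scores;
# the symbols and cell-type names below are placeholders only.
panel = pd.DataFrame({
    'gene':  ['CD19', 'MS4A1', 'CD3D', 'CD3E', 'EPCAM'],
    'label': ['B cells', 'B cells', 'T cells', 'T cells', 'Epithelium'],
})
# index=False matches load_tsv(marker_file, index=False) on the read side
panel.to_csv('data/markers/cell-type.tsv', sep='\t', index=False)
```

Marker genes missing from `gene-names.txt` are reported and skipped by `get_marker_score`, so the panel may include genes that were not imputed.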
-------------------------------------------------------------------------------- /run_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | prefix="data/demo/" 5 | 6 | # download demo data 7 | ./download_demo.sh $prefix 8 | # download pretrained models 9 | ./download_checkpoints.sh 10 | # run pipeline 11 | ./run.sh $prefix 12 | -------------------------------------------------------------------------------- /select_genes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import load_tsv, write_lines, read_lines 3 | 4 | 5 | def get_args(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('inpfile', type=str, help='e.g. demo/cnts.tsv') 8 | parser.add_argument( 9 | 'outfile', type=str, help='e.g. demo/gene-names.txt') 10 | parser.add_argument('--n-top', type=int, default=None, help='e.g. 50') 11 | parser.add_argument( 12 | '--extra', type=str, default=None, 13 | help='demo/marker-genes.txt') 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | def main(): 19 | 20 | args = get_args() 21 | 22 | cnts = load_tsv(args.inpfile) 23 | order = cnts.var().to_numpy().argsort()[::-1] 24 | names = cnts.columns.to_list() 25 | names_all = [names[i] for i in order] 26 | 27 | names_top = names_all 28 | if args.n_top is not None: 29 | names_top = names_top[:args.n_top] 30 | 31 | if args.extra is None: 32 | names_extra = [] 33 | else: 34 | names_extra = read_lines(args.extra) 35 | names_extra = [ 36 | name for name in names_extra 37 | if (name in names_all) and (name not in names_top)] 38 | 39 | names = names_extra + names_top 40 | 41 | write_lines(names, args.outfile) 42 | 43 | 44 | if __name__ == '__main__': 45 | main() 46 | -------------------------------------------------------------------------------- /structural_similarity.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import numpy as np 4 | from scipy.ndimage import uniform_filter 5 | 6 | from skimage._shared import utils 7 | from skimage._shared.filters import gaussian 8 | from skimage._shared.utils import ( 9 | _supported_float_type, check_shape_equality, warn) 10 | from skimage.util.arraycrop import crop 11 | from skimage.util.dtype import dtype_range 12 | 13 | __all__ = ['structural_similarity'] 14 | 15 | 16 | @utils.deprecate_multichannel_kwarg() 17 | def structural_similarity(im1, im2, 18 | *, 19 | win_size=None, gradient=False, data_range=None, 20 | channel_axis=None, multichannel=False, 21 | gaussian_weights=False, full=False, **kwargs): 22 | """ 23 | Compute the mean structural similarity index between two images. 24 | 25 | Parameters 26 | ---------- 27 | im1, im2 : ndarray 28 | Images. Any dimensionality with same shape. 29 | win_size : int or None, optional 30 | The side-length of the sliding window used in comparison. Must be an 31 | odd value. If `gaussian_weights` is True, this is ignored and the 32 | window size will depend on `sigma`. 33 | gradient : bool, optional 34 | If True, also return the gradient with respect to im2. 35 | data_range : float, optional 36 | The data range of the input image (distance between minimum and 37 | maximum possible values). By default, this is estimated from the image 38 | data-type. 39 | channel_axis : int or None, optional 40 | If None, the image is assumed to be a grayscale (single channel) image. 
41 | Otherwise, this parameter indicates which axis of the array corresponds 42 | to channels. 43 | 44 | .. versionadded:: 0.19 45 | ``channel_axis`` was added in 0.19. 46 | multichannel : bool, optional 47 | If True, treat the last dimension of the array as channels. Similarity 48 | calculations are done independently for each channel then averaged. 49 | This argument is deprecated: specify `channel_axis` instead. 50 | gaussian_weights : bool, optional 51 | If True, each patch has its mean and variance spatially weighted by a 52 | normalized Gaussian kernel of width sigma=1.5. 53 | full : bool, optional 54 | If True, also return the full structural similarity image. 55 | 56 | Other Parameters 57 | ---------------- 58 | use_sample_covariance : bool 59 | If True, normalize covariances by N-1 rather than, N where N is the 60 | number of pixels within the sliding window. 61 | K1 : float 62 | Algorithm parameter, K1 (small constant, see [1]_). 63 | K2 : float 64 | Algorithm parameter, K2 (small constant, see [1]_). 65 | sigma : float 66 | Standard deviation for the Gaussian when `gaussian_weights` is True. 67 | 68 | Returns 69 | ------- 70 | mssim : float 71 | The mean structural similarity index over the image. 72 | grad : ndarray 73 | The gradient of the structural similarity between im1 and im2 [2]_. 74 | This is only returned if `gradient` is set to True. 75 | S : ndarray 76 | The full SSIM image. This is only returned if `full` is set to True. 77 | 78 | Notes 79 | ----- 80 | To match the implementation of Wang et. al. [1]_, set `gaussian_weights` 81 | to True, `sigma` to 1.5, and `use_sample_covariance` to False. 82 | 83 | .. versionchanged:: 0.16 84 | This function was renamed from ``skimage.measure.compare_ssim`` to 85 | ``skimage.metrics.structural_similarity``. 86 | 87 | References 88 | ---------- 89 | .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. 90 | (2004). Image quality assessment: From error visibility to 91 | structural similarity. IEEE Transactions on Image Processing, 92 | 13, 600-612. 93 | https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, 94 | :DOI:`10.1109/TIP.2003.819861` 95 | 96 | .. [2] Avanaki, A. N. (2009). Exact global histogram specification 97 | optimized for structural similarity. Optical Review, 16, 613-621. 
98 | :arxiv:`0901.0065` 99 | :DOI:`10.1007/s10043-009-0119-z` 100 | 101 | """ 102 | check_shape_equality(im1, im2) 103 | float_type = _supported_float_type(im1.dtype) 104 | 105 | if channel_axis is not None: 106 | # loop over channels 107 | args = dict(win_size=win_size, 108 | gradient=gradient, 109 | data_range=data_range, 110 | channel_axis=None, 111 | gaussian_weights=gaussian_weights, 112 | full=full) 113 | args.update(kwargs) 114 | nch = im1.shape[channel_axis] 115 | mssim = np.empty(nch, dtype=float_type) 116 | 117 | if gradient: 118 | G = np.empty(im1.shape, dtype=float_type) 119 | if full: 120 | S = np.empty(im1.shape, dtype=float_type) 121 | channel_axis = channel_axis % im1.ndim 122 | _at = functools.partial(utils.slice_at_axis, axis=channel_axis) 123 | for ch in range(nch): 124 | ch_result = structural_similarity(im1[_at(ch)], 125 | im2[_at(ch)], **args) 126 | if gradient and full: 127 | mssim[ch], G[_at(ch)], S[_at(ch)] = ch_result 128 | elif gradient: 129 | mssim[ch], G[_at(ch)] = ch_result 130 | elif full: 131 | mssim[ch], S[_at(ch)] = ch_result 132 | else: 133 | mssim[ch] = ch_result 134 | mssim = mssim.mean() 135 | if gradient and full: 136 | return mssim, G, S 137 | elif gradient: 138 | return mssim, G 139 | elif full: 140 | return mssim, S 141 | else: 142 | return mssim 143 | 144 | K1 = kwargs.pop('K1', 0.01) 145 | K2 = kwargs.pop('K2', 0.03) 146 | K3 = kwargs.pop('K3', K2 / np.sqrt(2)) 147 | alpha = kwargs.pop('alpha', 1) 148 | beta = kwargs.pop('beta', 1) 149 | gamma = kwargs.pop('gamma', 1) 150 | sigma = kwargs.pop('sigma', 1.5) 151 | if K1 < 0: 152 | raise ValueError("K1 must be positive") 153 | if K2 < 0: 154 | raise ValueError("K2 must be positive") 155 | if K3 < 0: 156 | raise ValueError("K3 must be positive") 157 | if sigma < 0: 158 | raise ValueError("sigma must be positive") 159 | use_sample_covariance = kwargs.pop('use_sample_covariance', True) 160 | 161 | if gaussian_weights: 162 | # Set to give an 11-tap filter with the default sigma of 1.5 to match 163 | # Wang et. al. 2004. 164 | truncate = 3.5 165 | 166 | if win_size is None: 167 | if gaussian_weights: 168 | # set win_size used by crop to match the filter size 169 | r = int(truncate * sigma + 0.5) # radius as in ndimage 170 | win_size = 2 * r + 1 171 | else: 172 | win_size = 7 # backwards compatibility 173 | 174 | if np.any((np.asarray(im1.shape) - win_size) < 0): 175 | raise ValueError( 176 | 'win_size exceeds image extent. ' 177 | 'Either ensure that your images are ' 178 | 'at least 7x7; or pass win_size explicitly ' 179 | 'in the function call, with an odd value ' 180 | 'less than or equal to the smaller side of your ' 181 | 'images. If your images are multichannel ' 182 | '(with color channels), set channel_axis to ' 183 | 'the axis number corresponding to the channels.') 184 | 185 | if not (win_size % 2 == 1): 186 | raise ValueError('Window size must be odd.') 187 | 188 | if data_range is None: 189 | if im1.dtype != im2.dtype: 190 | warn("Inputs have mismatched dtype. 
Setting data_range based on " 191 | "im1.dtype.", stacklevel=2) 192 | dmin, dmax = dtype_range[im1.dtype.type] 193 | data_range = dmax - dmin 194 | 195 | ndim = im1.ndim 196 | 197 | if gaussian_weights: 198 | filter_func = gaussian 199 | filter_args = {'sigma': sigma, 'truncate': truncate, 'mode': 'reflect'} 200 | else: 201 | filter_func = uniform_filter 202 | filter_args = {'size': win_size} 203 | 204 | # ndimage filters need floating point data 205 | im1 = im1.astype(float_type, copy=False) 206 | im2 = im2.astype(float_type, copy=False) 207 | 208 | NP = win_size ** ndim 209 | 210 | # filter has already normalized by NP 211 | if use_sample_covariance: 212 | cov_norm = NP / (NP - 1) # sample covariance 213 | else: 214 | cov_norm = 1.0 # population covariance to match Wang et. al. 2004 215 | 216 | # compute (weighted) means 217 | ux = filter_func(im1, **filter_args) 218 | uy = filter_func(im2, **filter_args) 219 | 220 | # compute (weighted) variances and covariances 221 | uxx = filter_func(im1 * im1, **filter_args) 222 | uyy = filter_func(im2 * im2, **filter_args) 223 | uxy = filter_func(im1 * im2, **filter_args) 224 | vx = cov_norm * (uxx - ux * ux) 225 | vy = cov_norm * (uyy - uy * uy) 226 | vxy = cov_norm * (uxy - ux * uy) 227 | 228 | R = data_range 229 | C1 = (K1 * R) ** 2 230 | C2 = (K2 * R) ** 2 231 | C3 = (K3 * R) ** 2 232 | 233 | sxsy = np.sqrt(np.clip(vx * vy, 0, None)) 234 | 235 | A1, A2, A3, B1, B2, B3 = (( 236 | 2 * ux * uy + C1, 237 | 2 * sxsy + C2, 238 | vxy + C3, 239 | ux ** 2 + uy ** 2 + C1, 240 | vx + vy + C2, 241 | sxsy + C3, 242 | )) 243 | S = (A1/B1)**alpha * (A2/B2)**beta * (A3/B3)**gamma 244 | 245 | # to avoid edge effects will ignore filter radius strip around edges 246 | pad = (win_size - 1) // 2 247 | 248 | # compute (weighted) mean of ssim. Use float64 for accuracy. 
249 | mssim = crop(S, pad).mean(dtype=np.float64) 250 | 251 | if full: 252 | return mssim, S 253 | else: 254 | return mssim 255 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from time import time 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import DataLoader 8 | import pytorch_lightning as pl 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | 12 | from utils import load_pickle, save_pickle 13 | 14 | 15 | matplotlib.use('Agg') 16 | 17 | 18 | class MetricTracker(pl.Callback): 19 | 20 | def __init__(self): 21 | self.collection = [] 22 | 23 | def on_train_epoch_end(self, trainer, *args, **kwargs): 24 | metrics = deepcopy(trainer.logged_metrics) 25 | self.collection.append(metrics) 26 | 27 | def clean(self): 28 | keys = [set(e.keys()) for e in self.collection] 29 | keys = set().union(*keys) 30 | for elem in self.collection: 31 | for ke in keys: 32 | if ke in elem.keys(): 33 | if isinstance(elem[ke], torch.Tensor): 34 | elem[ke] = elem[ke].item() 35 | else: 36 | elem[ke] = float('nan') 37 | 38 | 39 | def get_model( 40 | model_class, model_kwargs, dataset, prefix, 41 | epochs=None, device='cuda', load_saved=False, **kwargs): 42 | checkpoint_file = prefix + 'model.pt' 43 | history_file = prefix + 'history.pickle' 44 | 45 | # load model if exists 46 | if load_saved and os.path.exists(checkpoint_file): 47 | model = model_class.load_from_checkpoint(checkpoint_file) 48 | print(f'Model loaded from {checkpoint_file}') 49 | history = load_pickle(history_file) 50 | else: 51 | model = None 52 | history = [] 53 | 54 | # train model 55 | if (epochs is not None) and (epochs > 0): 56 | model, hist, trainer = train_model( 57 | model=model, 58 | model_class=model_class, model_kwargs=model_kwargs, 59 | dataset=dataset, epochs=epochs, device=device, 60 | **kwargs) 61 | trainer.save_checkpoint(checkpoint_file) 62 | print(f'Model saved to {checkpoint_file}') 63 | history += hist 64 | save_pickle(history, history_file) 65 | print(f'History saved to {history_file}') 66 | plot_history(history, prefix) 67 | 68 | return model 69 | 70 | 71 | def train_model( 72 | dataset, batch_size, epochs, 73 | model=None, model_class=None, model_kwargs={}, 74 | device='cuda'): 75 | if model is None: 76 | model = model_class(**model_kwargs) 77 | dataloader = DataLoader( 78 | dataset, batch_size=batch_size, 79 | shuffle=True) 80 | tracker = MetricTracker() 81 | device_accelerator_dict = { 82 | 'cuda': 'gpu', 83 | 'cpu': 'cpu'} 84 | accelerator = device_accelerator_dict[device] 85 | trainer = pl.Trainer( 86 | max_epochs=epochs, 87 | callbacks=[tracker], 88 | deterministic=True, 89 | accelerator=accelerator, 90 | devices=1, 91 | logger=False, 92 | enable_checkpointing=False, 93 | enable_progress_bar=True) 94 | model.train() 95 | t0 = time() 96 | trainer.fit(model=model, train_dataloaders=dataloader) 97 | print(int(time() - t0), 'sec') 98 | tracker.clean() 99 | history = tracker.collection 100 | return model, history, trainer 101 | 102 | 103 | def plot_history(history, prefix): 104 | plt.figure(figsize=(16, 16)) 105 | groups = set([e.split('_')[-1] for e in history[0].keys()]) 106 | groups = np.sort(list(groups)) 107 | for i, grp in enumerate(groups): 108 | plt.subplot(len(groups), 1, 1+i) 109 | for metric in history[0].keys(): 110 | if metric.endswith(grp): 111 | hist = np.array([e[metric] for e in history]) 112 | hmin, hmax = 
hist.min(), hist.max() 113 | label = f'{metric} ({hmin:+013.6f}, {hmax:+013.6f})' 114 | hist -= hmin 115 | hist /= hmax + 1e-12 116 | plt.plot(hist, label=label) 117 | plt.legend() 118 | plt.ylim(0, 1) 119 | outfile = f'{prefix}history.png' 120 | plt.savefig(outfile, dpi=300) 121 | plt.close() 122 | print(outfile) 123 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from PIL import Image 3 | import pickle 4 | import os 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import yaml 9 | 10 | 11 | Image.MAX_IMAGE_PIXELS = None 12 | 13 | 14 | def mkdir(path): 15 | dirname = os.path.dirname(path) 16 | if dirname != '': 17 | os.makedirs(dirname, exist_ok=True) 18 | 19 | 20 | def load_image(filename, verbose=True): 21 | img = Image.open(filename) 22 | img = np.array(img) 23 | if img.ndim == 3 and img.shape[-1] == 4: 24 | img = img[..., :3] # remove alpha channel 25 | if verbose: 26 | print(f'Image loaded from {filename}') 27 | return img 28 | 29 | 30 | def load_mask(filename, verbose=True): 31 | mask = load_image(filename, verbose=verbose) 32 | mask = mask > 0 33 | if mask.ndim == 3: 34 | mask = mask.any(2) 35 | return mask 36 | 37 | 38 | def save_image(img, filename): 39 | mkdir(filename) 40 | Image.fromarray(img).save(filename) 41 | print(filename) 42 | 43 | 44 | def read_lines(filename): 45 | with open(filename, 'r') as file: 46 | lines = [line.rstrip() for line in file] 47 | return lines 48 | 49 | 50 | def read_string(filename): 51 | return read_lines(filename)[0] 52 | 53 | 54 | def write_lines(strings, filename): 55 | mkdir(filename) 56 | with open(filename, 'w') as file: 57 | for s in strings: 58 | file.write(f'{s}\n') 59 | print(filename) 60 | 61 | 62 | def write_string(string, filename): 63 | return write_lines([string], filename) 64 | 65 | 66 | def save_pickle(x, filename): 67 | mkdir(filename) 68 | with open(filename, 'wb') as file: 69 | pickle.dump(x, file) 70 | print(filename) 71 | 72 | 73 | def load_pickle(filename, verbose=True): 74 | with open(filename, 'rb') as file: 75 | x = pickle.load(file) 76 | if verbose: 77 | print(f'Pickle loaded from {filename}') 78 | return x 79 | 80 | 81 | def load_tsv(filename, index=True): 82 | if index: 83 | index_col = 0 84 | else: 85 | index_col = None 86 | df = pd.read_csv(filename, sep='\t', header=0, index_col=index_col) 87 | print(f'Dataframe loaded from {filename}') 88 | return df 89 | 90 | 91 | def save_tsv(x, filename, **kwargs): 92 | mkdir(filename) 93 | if 'sep' not in kwargs.keys(): 94 | kwargs['sep'] = '\t' 95 | x.to_csv(filename, **kwargs) 96 | print(filename) 97 | 98 | 99 | def load_yaml(filename, verbose=False): 100 | with open(filename, 'r') as file: 101 | content = yaml.safe_load(file) 102 | if verbose: 103 | print(f'YAML loaded from {filename}') 104 | return content 105 | 106 | 107 | def save_yaml(filename, content): 108 | with open(filename, 'w') as file: 109 | yaml.dump(content, file) 110 | print(file) 111 | 112 | 113 | def join(x): 114 | return list(itertools.chain.from_iterable(x)) 115 | 116 | 117 | def get_most_frequent(x): 118 | # return the most frequent element in array 119 | uniqs, counts = np.unique(x, return_counts=True) 120 | return uniqs[counts.argmax()] 121 | 122 | 123 | def sort_labels(labels, descending=True): 124 | labels = labels.copy() 125 | isin = labels >= 0 126 | labels_uniq, labels[isin], counts = np.unique( 127 | labels[isin], return_inverse=True, 
return_counts=True) 128 | c = counts 129 | if descending: 130 | c = c * (-1) 131 | order = c.argsort() 132 | rank = order.argsort() 133 | labels[isin] = rank[labels[isin]] 134 | return labels, labels_uniq[order] 135 | -------------------------------------------------------------------------------- /vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """ 12 | Mostly copy-paste from timm library. 13 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 14 | """ 15 | import math 16 | from functools import partial 17 | import warnings 18 | 19 | import torch 20 | import torch.nn as nn 21 | 22 | 23 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 24 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 25 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 26 | def norm_cdf(x): 27 | # Computes standard normal cumulative distribution function 28 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 29 | 30 | if (mean < a - 2 * std) or (mean > b + 2 * std): 31 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 32 | "The distribution of values may be incorrect.", 33 | stacklevel=2) 34 | 35 | with torch.no_grad(): 36 | # Values are generated by using a truncated uniform distribution and 37 | # then using the inverse CDF for the normal distribution. 38 | # Get upper and lower cdf values 39 | l = norm_cdf((a - mean) / std) 40 | u = norm_cdf((b - mean) / std) 41 | 42 | # Uniformly fill tensor with values from [l, u], then translate to 43 | # [2l-1, 2u-1]. 44 | tensor.uniform_(2 * l - 1, 2 * u - 1) 45 | 46 | # Use inverse cdf transform for normal distribution to get truncated 47 | # standard normal 48 | tensor.erfinv_() 49 | 50 | # Transform to proper mean, std 51 | tensor.mul_(std * math.sqrt(2.)) 52 | tensor.add_(mean) 53 | 54 | # Clamp to ensure it's in the proper range 55 | tensor.clamp_(min=a, max=b) 56 | return tensor 57 | 58 | 59 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 60 | # type: (Tensor, float, float, float, float) -> Tensor 61 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 62 | 63 | 64 | def drop_path(x, drop_prob: float = 0., training: bool = False): 65 | if drop_prob == 0. or not training: 66 | return x 67 | keep_prob = 1 - drop_prob 68 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 69 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 70 | random_tensor.floor_() # binarize 71 | output = x.div(keep_prob) * random_tensor 72 | return output 73 | 74 | 75 | class DropPath(nn.Module): 76 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
77 | """ 78 | def __init__(self, drop_prob=None): 79 | super(DropPath, self).__init__() 80 | self.drop_prob = drop_prob 81 | 82 | def forward(self, x): 83 | return drop_path(x, self.drop_prob, self.training) 84 | 85 | 86 | class Mlp(nn.Module): 87 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 88 | super().__init__() 89 | out_features = out_features or in_features 90 | hidden_features = hidden_features or in_features 91 | self.fc1 = nn.Linear(in_features, hidden_features) 92 | self.act = act_layer() 93 | self.fc2 = nn.Linear(hidden_features, out_features) 94 | self.drop = nn.Dropout(drop) 95 | 96 | def forward(self, x): 97 | x = self.fc1(x) 98 | x = self.act(x) 99 | x = self.drop(x) 100 | x = self.fc2(x) 101 | x = self.drop(x) 102 | return x 103 | 104 | 105 | class Attention(nn.Module): 106 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 107 | super().__init__() 108 | self.num_heads = num_heads 109 | head_dim = dim // num_heads 110 | self.scale = qk_scale or head_dim ** -0.5 111 | 112 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 113 | self.attn_drop = nn.Dropout(attn_drop) 114 | self.proj = nn.Linear(dim, dim) 115 | self.proj_drop = nn.Dropout(proj_drop) 116 | 117 | def forward(self, x): 118 | B, N, C = x.shape 119 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 120 | q, k, v = qkv[0], qkv[1], qkv[2] 121 | 122 | attn = (q @ k.transpose(-2, -1)) * self.scale 123 | attn = attn.softmax(dim=-1) 124 | attn = self.attn_drop(attn) 125 | 126 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 127 | x = self.proj(x) 128 | x = self.proj_drop(x) 129 | return x, attn 130 | 131 | 132 | class Block(nn.Module): 133 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 134 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 135 | super().__init__() 136 | self.norm1 = norm_layer(dim) 137 | self.attn = Attention( 138 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 139 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 140 | self.norm2 = norm_layer(dim) 141 | mlp_hidden_dim = int(dim * mlp_ratio) 142 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 143 | 144 | def forward(self, x, return_attention=False): 145 | y, attn = self.attn(self.norm1(x)) 146 | if return_attention: 147 | return attn 148 | x = x + self.drop_path(y) 149 | x = x + self.drop_path(self.mlp(self.norm2(x))) 150 | return x 151 | 152 | 153 | class PatchEmbed(nn.Module): 154 | """ Image to Patch Embedding 155 | """ 156 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 157 | super().__init__() 158 | num_patches = (img_size // patch_size) * (img_size // patch_size) 159 | self.img_size = img_size 160 | self.patch_size = patch_size 161 | self.num_patches = num_patches 162 | 163 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 164 | 165 | def forward(self, x): 166 | B, C, H, W = x.shape 167 | x = self.proj(x).flatten(2).transpose(1, 2) 168 | return x 169 | 170 | 171 | class VisionTransformer(nn.Module): 172 | """ Vision Transformer """ 173 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 174 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 175 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 176 | super().__init__() 177 | self.num_features = self.embed_dim = embed_dim 178 | 179 | self.patch_embed = PatchEmbed( 180 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 181 | num_patches = self.patch_embed.num_patches 182 | 183 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 184 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 185 | self.pos_drop = nn.Dropout(p=drop_rate) 186 | 187 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 188 | self.blocks = nn.ModuleList([ 189 | Block( 190 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 191 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 192 | for i in range(depth)]) 193 | self.norm = norm_layer(embed_dim) 194 | 195 | # Classifier head 196 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 197 | 198 | trunc_normal_(self.pos_embed, std=.02) 199 | trunc_normal_(self.cls_token, std=.02) 200 | self.apply(self._init_weights) 201 | 202 | def _init_weights(self, m): 203 | if isinstance(m, nn.Linear): 204 | trunc_normal_(m.weight, std=.02) 205 | if isinstance(m, nn.Linear) and m.bias is not None: 206 | nn.init.constant_(m.bias, 0) 207 | elif isinstance(m, nn.LayerNorm): 208 | nn.init.constant_(m.bias, 0) 209 | nn.init.constant_(m.weight, 1.0) 210 | 211 | def interpolate_pos_encoding(self, x, w, h): 212 | npatch = x.shape[1] - 1 213 | N = self.pos_embed.shape[1] - 1 214 | if npatch == N and w == h: 215 | return self.pos_embed 216 | class_pos_embed = self.pos_embed[:, 0] 217 | patch_pos_embed = self.pos_embed[:, 1:] 218 | dim = x.shape[-1] 219 | w0 = w // self.patch_embed.patch_size 220 | h0 = h // self.patch_embed.patch_size 221 | # we add a small number to avoid floating point error in the interpolation 222 | # see discussion at https://github.com/facebookresearch/dino/issues/8 223 | w0, h0 = w0 + 0.1, h0 + 0.1 224 | patch_pos_embed = nn.functional.interpolate( 225 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), 
    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        x = self.forward_all(x)
        return x[:, 0]

    def forward_all(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output


def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
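A minimal smoke test for the ViT defined above (a sketch, assuming the file is importable as `vision_transformer` per the repository layout):

```python
import torch
from vision_transformer import vit_small

model = vit_small()               # embed_dim=384, depth=12, heads=6
model.eval()
x = torch.randn(1, 3, 224, 224)   # one RGB image, 224 = 14 * 16
with torch.no_grad():
    cls_feat = model(x)                     # (1, 384): [CLS] embedding
    tokens = model.forward_all(x)           # (1, 197, 384): [CLS] + 14*14 patch tokens
    attn = model.get_last_selfattention(x)  # (1, 6, 197, 197): last-block attention
print(cls_feat.shape, tokens.shape, attn.shape)
```
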
--------------------------------------------------------------------------------
/vision_transformer4k.py:
--------------------------------------------------------------------------------
import math
import warnings
from functools import partial
import torch
import torch.nn as nn


def softmax(x, dim, inplace=False):
    # subtract the per-row max before exponentiating; the subtraction
    # cancels in the normalization and avoids overflow for large logits
    # (this matches the behavior of torch.softmax)
    m = x.max(dim=dim, keepdim=True).values
    if inplace:
        x.sub_(m)
        torch.exp(x, out=x)
    else:
        x = torch.exp(x - m)
    s = torch.sum(x, dim=dim, keepdim=True)
    x /= s
    return x

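This in-place softmax trades a little readability for memory savings in the attention layer below. It should agree with `torch.softmax` up to floating-point error; a quick sanity check (a hypothetical snippet, not part of the original file):

```python
import torch
from vision_transformer4k import softmax

t = torch.randn(2, 4, 8, 8)
out = softmax(t.clone(), dim=-1, inplace=True)
assert torch.allclose(out, t.softmax(dim=-1), atol=1e-6)
```
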
73 | """ 74 | def __init__(self, drop_prob=None): 75 | super(DropPath, self).__init__() 76 | self.drop_prob = drop_prob 77 | 78 | def forward(self, x): 79 | return drop_path(x, self.drop_prob, self.training) 80 | 81 | 82 | class Mlp(nn.Module): 83 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 84 | super().__init__() 85 | out_features = out_features or in_features 86 | hidden_features = hidden_features or in_features 87 | self.fc1 = nn.Linear(in_features, hidden_features) 88 | self.act = act_layer() 89 | self.fc2 = nn.Linear(hidden_features, out_features) 90 | self.drop = nn.Dropout(drop) 91 | 92 | def forward(self, x): 93 | x = self.fc1(x) 94 | x = self.act(x) 95 | x = self.drop(x) 96 | x = self.fc2(x) 97 | x = self.drop(x) 98 | return x 99 | 100 | 101 | class Attention(nn.Module): 102 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 103 | super().__init__() 104 | self.num_heads = num_heads 105 | head_dim = dim // num_heads 106 | self.scale = qk_scale or head_dim ** -0.5 107 | 108 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 109 | self.attn_drop = nn.Dropout(attn_drop) 110 | self.proj = nn.Linear(dim, dim) 111 | self.proj_drop = nn.Dropout(proj_drop) 112 | 113 | def forward(self, x): 114 | B, N, C = x.shape 115 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 116 | q, k, v = qkv[0], qkv[1], qkv[2] 117 | 118 | attn = q @ k.transpose(-2, -1) 119 | attn *= self.scale 120 | softmax(attn, dim=-1, inplace=True) # attn = attn.softmax(dim=-1) 121 | attn = self.attn_drop(attn) 122 | 123 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 124 | del qkv, q, k, v 125 | x = self.proj(x) 126 | x = self.proj_drop(x) 127 | return x, attn 128 | 129 | 130 | class Block(nn.Module): 131 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 132 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 133 | super().__init__() 134 | self.norm1 = norm_layer(dim) 135 | self.attn = Attention( 136 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 137 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
class VisionTransformer4K(nn.Module):
    """ Vision Transformer 4K """
    def __init__(self, num_classes=0, img_size=[224], input_embed_dim=384, output_embed_dim=192,
                 depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, num_prototypes=64, **kwargs):
        super().__init__()
        embed_dim = output_embed_dim
        self.num_features = self.embed_dim = embed_dim
        self.phi = nn.Sequential(*[nn.Linear(input_embed_dim, output_embed_dim), nn.GELU(), nn.Dropout(p=drop_rate)])
        num_patches = int(img_size[0] // 16)**2
        # print("# of Patches:", num_patches)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // 1
        h0 = h // 1
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        # print('preparing tokens (after crop)', x.shape)
        self.mpp_feature = x
        B, embed_dim, w, h = x.shape
        x = x.flatten(2, 3).transpose(1, 2)

        x = self.phi(x)

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        x = self.forward_all(x)
        return x[:, 0]

    def forward_all(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output


def vit4k_xs(patch_size=16, **kwargs):
    model = VisionTransformer4K(
        patch_size=patch_size, input_embed_dim=384, output_embed_dim=192,
        depth=6, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
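Unlike the patch-level ViT, `VisionTransformer4K` consumes a grid of patch embeddings rather than raw pixels: `prepare_tokens` flattens a `(B, 384, w, h)` feature map into tokens and `phi` projects them from 384 to 192 dimensions. A minimal smoke test (a sketch, assuming the file is importable as `vision_transformer4k`):

```python
import torch
from vision_transformer4k import vit4k_xs, count_parameters

model = vit4k_xs()               # 384-d input tokens -> 192-d output tokens
model.eval()
x = torch.randn(1, 384, 14, 14)  # 14 x 14 grid of patch-level features
with torch.no_grad():
    cls_feat = model(x)          # (1, 192): [CLS] embedding of the region
print(cls_feat.shape, count_parameters(model))
```
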
--------------------------------------------------------------------------------
/visual.py:
--------------------------------------------------------------------------------
import os
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from einops import rearrange

from utils import save_image
from image import get_disk_mask


def cmap_myset(x):
    cmap = ListedColormap([
        '#17BECF',  # cyan
        '#FD8D3C',  # orange
        '#A6D854',  # light green
        '#9467BD',  # purple
        '#A5A5A5',  # gray
        '#F4CAE4',  # light pink
        '#C47A3D',  # brown
        '#FFF800',  # yellow
    ])
    return cmap(x)


def cmap_accent(x):
    cmap = ListedColormap([
        '#386cb0',  # blue
        '#fdc086',  # orange
        '#7fc97f',  # green
        '#beaed4',  # purple
        '#f0027f',  # magenta
        '#bf5b17',  # brown
        '#666666',  # gray
        '#ffff99',  # yellow
    ])
    return cmap(x)


def cmap_turbo_adj(x):

    a = 0.70  # lightness adjustment
    b = 0.70  # saturation adjustment

    cmap = plt.get_cmap('turbo')
    x = np.array(cmap(x))[..., :3]
    x = 1 - (1 - x) * a
    lightness = x.mean(-1, keepdims=True)
    x = (x - lightness) * b + lightness
    return x


def cmap_turbo_truncated(x):
    cmap = plt.get_cmap('turbo')
    x = x * 0.9 + 0.05
    y = np.array(cmap(x))[..., :3]
    return y


def cmap_tab30(x):
    n_base = 20
    n_max = 30
    brightness = 0.7
    brightness = (brightness,) * 3 + (1.0,)
    isin_base = (x < n_base)[..., np.newaxis]
    isin_extended = ((x >= n_base) * (x < n_max))[..., np.newaxis]
    isin_beyond = (x >= n_max)[..., np.newaxis]
    color = (
        isin_base * cmap_tab20(x)
        + isin_extended * cmap_tab20(x-n_base) * brightness
        + isin_beyond * (0.0, 0.0, 0.0, 1.0))
    return color


def cmap_tab70(x):
    cmap_base = cmap_tab30
    brightness = 0.5
    brightness = np.array([brightness] * 3 + [1.0])
    color = [
        cmap_base(x),  # same as base colormap
        1 - (1 - cmap_base(x-20)) * brightness,  # brighter
        cmap_base(x-20) * brightness,  # darker
        1 - (1 - cmap_base(x-40)) * brightness**2,  # even brighter
        cmap_base(x-40) * brightness**2,  # even darker
        [0.0, 0.0, 0.0, 1.0],  # black
    ]
    x = x[..., np.newaxis]
    isin = [
        (x < 30),
        (x >= 30) * (x < 40),
        (x >= 40) * (x < 50),
        (x >= 50) * (x < 60),
        (x >= 60) * (x < 70),
        (x >= 70)]
    color_out = np.sum(
        [isi * col for isi, col in zip(isin, color)],
        axis=0)
    return color_out

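`cmap_tab70` above extends the 30-color base map to 70 distinct labels by layering brighter and darker variants of the base colors, with labels at 70 and beyond falling back to black. A quick check of its output (a sketch, assuming the file is importable as `visual`):

```python
import numpy as np
from visual import cmap_tab70

labels = np.arange(70)
colors = cmap_tab70(labels)  # one RGBA row per label
print(colors.shape)          # (70, 4)
```
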
def interlaced_cmap(cmap, stop, stride, start=0):

    def cmap_new(x):
        isin = x >= 0
        x[isin] = (x[isin] * stride + start) % stop
        return cmap(x)

    return cmap_new


def reversed_cmap(cmap, stop):

    def cmap_new(x):
        isin = x >= 0
        x[isin] = stop - x[isin] - 1
        return cmap(x)

    return cmap_new


def cmap_tab20(x):
    cmap = plt.get_cmap('tab20')
    x = x % 20
    x = (x // 10) + (x % 10) * 2
    return cmap(x)


def get_cmap_tab_multi(n_colors, n_shades, paired=True):
    cmap_base = cmap_tab20
    n_base = 10
    assert n_colors <= n_base

    def cmap(x):
        isin = x >= 0
        x = x * isin

        # lightness
        i = x // n_colors
        if paired:
            is_odd = i % 2 == 1
            i[~is_odd] //= 2
            i[is_odd] = n_shades - 1 - (i[is_odd] - 1) // 2

        # hue
        j = x % n_colors
        colors = np.stack(
            [cmap_base(k)[..., :3] for k in [j, j+n_base]])

        # compute color from hue and lightness
        weights = 1 - i / max(1, n_shades - 1)
        weights = np.stack([weights, 1-weights])
        col = (weights[..., np.newaxis] * colors).sum(0)
        col[~isin] = 0
        return col

    return cmap


def get_cmap_discrete(n_colors, cmap_name):
    cmap_base = plt.get_cmap(cmap_name)

    def cmap(x):
        x = x / float(n_colors-1)
        return cmap_base(x)

    return cmap


def plot_labels(
        labels, filename, cmap=None, white_background=True,
        transparent_background=False,
        interlace=False, reverse=False):
    if labels.ndim == 3:
        n_labels = labels[..., 0].max() + 1
        n_shades = labels[..., 1].max() + 1
        isin = (labels >= 0).all(-1)
        labels_uni = labels[..., 0].copy()
        labels_uni[isin] = (
            n_labels * labels[isin][..., -1]
            + labels[isin][..., 0])
        labels = labels_uni
    elif labels.ndim == 2:
        n_labels = labels.max() + 1
        n_shades = 1

    if cmap is None:
        if n_labels <= 70:
            cmap = 'tab70'
        else:
            cmap = 'turbo'

    if cmap == 'tab70':
        cmap = cmap_tab70
    elif cmap == 'turbo':
        cmap = plt.get_cmap('turbo')
        labels = labels / labels.max()
    elif cmap == 'multi':
        cmap = get_cmap_tab_multi(n_labels, n_shades)
    elif isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    if interlace:
        cmap = interlaced_cmap(
            cmap, stop=n_labels, start=0, stride=9)
    if reverse:
        cmap = reversed_cmap(cmap, stop=n_labels)

    image = cmap(labels)

    mask_extra = labels < 0
    mask_background = (labels == labels.min()) * mask_extra
    image[mask_extra] = 0.5  # gray
    if white_background:
        background_color = 1.0  # white
    else:
        background_color = 0.0  # black
    image[mask_background] = background_color
    if transparent_background:
        image[mask_background, -1] = 0.0
    image = (image * 255).astype(np.uint8)

    if filename is not None:
        save_image(image, filename)

    return image

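`plot_labels` above renders an integer label map as an RGB(A) image: non-negative labels are colored by the chosen colormap (defaulting to `cmap_tab70` for up to 70 labels), and negative labels become gray or the background color. A minimal sketch, assuming the file is importable as `visual`:

```python
import numpy as np
from visual import plot_labels

labels = np.random.randint(-1, 8, size=(64, 64))  # 8 clusters; -1 = background
img = plot_labels(labels, 'labels.png')           # writes the PNG, returns uint8 RGBA
print(img.shape, img.dtype)                       # (64, 64, 4) uint8
```
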
def plot_embeddings(embeddings, prefix, groups=None, same_color_scale=True):
    if groups is None:
        groups = embeddings.keys()
    cmap = plt.get_cmap('turbo')
    os.makedirs(os.path.dirname(prefix), exist_ok=True)
    for key in groups:
        emb = embeddings[key]
        mask = np.all([np.isfinite(channel) for channel in emb], 0)
        min_all = np.min([channel[mask].min() for channel in emb], 0)
        max_all = np.max([channel[mask].max() for channel in emb], 0)
        for i, channel in enumerate(emb):
            if same_color_scale:
                min_chan, max_chan = min_all, max_all
            else:
                min_chan = channel[mask].min()
                max_chan = channel[mask].max()
            image = (channel - min_chan) / (max_chan - min_chan)
            image = cmap(image)[..., :3]
            if not mask.all():
                image[~mask] = 0.0
            image = Image.fromarray((image * 255).astype(np.uint8))
            outfile = f'{prefix}{key}-{i:02d}.png'
            image.save(outfile)
            print(outfile)


def plot_spots(
        img, cnts, locs, radius, outfile, cmap='magma',
        weight=0.8, disk_mask=True, standardize_img=False):
    cnts = cnts.astype(np.float32)

    img = img.astype(np.float32)
    img /= 255.0

    if standardize_img:
        if np.isclose(0.0, np.nanstd(img, (0, 1))).all():
            img[:] = 1.0
        else:
            img -= np.nanmin(img)
            img /= np.nanmax(img) + 1e-12

    cnts -= np.nanmin(cnts)
    cnts /= np.nanmax(cnts) + 1e-12

    cmap = plt.get_cmap(cmap)
    if disk_mask:
        mask_patch = get_disk_mask(radius)
    else:
        mask_patch = np.ones((radius*2, radius*2)).astype(bool)
    indices_patch = np.stack(np.where(mask_patch), -1)
    indices_patch -= radius
    for ij, ct in zip(locs, cnts):
        color = np.array(cmap(ct)[:3])
        indices = indices_patch + ij
        img[indices[:, 0], indices[:, 1]] *= 1 - weight
        img[indices[:, 0], indices[:, 1]] += color * weight
    img = (img * 255).astype(np.uint8)
    save_image(img, outfile)


def plot_label_masks(labels, prefix, names=None, white_background=True):

    cmap = plt.get_cmap('tab10')
    color_pos = cmap(1)[:3]  # orange
    color_neg = (0.8, 0.8, 0.8)  # gray
    color_bg = 0.0  # black
    if white_background:
        color_bg = 1.0  # white

    foreground = labels >= 0
    labs_uniq = np.unique(labels)
    labs_uniq = labs_uniq[labs_uniq >= 0]
    if names is None:
        names = [f'{i:03d}' for i in labs_uniq]

    for lab in labs_uniq:
        mask = labels == lab
        img = np.zeros(mask.shape+(3,), dtype=np.float32)
        img[mask] = color_pos
        img[~mask] = color_neg
        img[~foreground] = color_bg
        img = (img * 255).astype(np.uint8)
        nam = names[lab]
        save_image(img, f'{prefix}{nam}.png')


def mat_to_img(
        x, white_background=True, transparent_background=False,
        cmap='turbo', minmax=None):
    mask = np.isfinite(x)
    x = x.astype(np.float32)
    if minmax is None:
        minmax = (np.nanmin(x), np.nanmax(x) + 1e-12)
        print('minmax:', minmax)
    x -= minmax[0]
    x /= minmax[1] - minmax[0]
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    x = cmap(x)
    if white_background:
        x[~mask] = 1.0
    if transparent_background:
        x[~mask, -1] = 0.0
    x = (x * 255).astype(np.uint8)
    return x

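`mat_to_img` above maps a real-valued matrix through a matplotlib colormap after min-max scaling, painting non-finite entries as background; `plot_matrix` below is its write-to-disk wrapper. A minimal sketch, assuming the file is importable as `visual`:

```python
import numpy as np
from visual import mat_to_img

x = np.random.rand(32, 32)
x[:4, :4] = np.nan                          # missing pixels -> background
img = mat_to_img(x, white_background=True)  # uint8 RGBA, NaNs painted white
print(img.shape)                            # (32, 32, 4)
```
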
def plot_matrix(x, outfile, **kwargs):
    img = mat_to_img(x, **kwargs)
    save_image(img, outfile)


def plot_spot_masked_image(
        locs, values, mask, size, outfile=None):
    mask_indices = np.stack(np.where(mask), -1)
    shape = np.array(mask.shape)
    offset = (-1) * shape // 2
    locs = locs + offset
    img = np.full(size, np.nan)
    for loc, val in zip(locs, values):
        indices = loc + mask_indices
        img[indices[:, 0], indices[:, 1]] = val
    plot_matrix(img, outfile)


def plot_labels_3d(labs, filename=None):
    depth = labs.shape[0]
    labs = rearrange(labs, 'd h w -> (d h) w')
    img = plot_labels(labs, filename)
    img = rearrange(img, '(d h) w c -> d h w c', d=depth)
    return img


def compress_indices(indices):
    indices = indices.astype(np.int32)
    indices = np.unique(indices, axis=0)
    return indices


def plot_cells(x, masks, filename, tissue=None, boundaries=None):
    if tissue is None:
        shape = np.max([m.max(0) for m in masks], 0)
        shape = np.ceil(shape).astype(int) + 1
    else:
        shape = tissue.shape
    mat = np.zeros(shape, dtype=np.float32)
    for u, indices in zip(x, masks):
        indices = compress_indices(indices)
        mat[indices[:, 0], indices[:, 1]] = u
    if tissue is not None:
        mat[~tissue] = np.nan
    if boundaries is not None:
        for indices in boundaries:
            indices = compress_indices(indices)
            mat[indices[:, 0], indices[:, 1]] = np.nan
    plot_matrix(mat, filename)


def plot_colorbar(cmap, n_labels, filename):

    size = (200, 200)

    if n_labels is None:
        x = np.linspace(0, 1, size[0]*10)
    else:
        x = np.arange(n_labels)
        x = np.repeat(x, size[0])

    x = np.tile(x[:, np.newaxis], size[1])
    x = x.swapaxes(0, 1)

    if n_labels is None:
        plot_matrix(x, filename, cmap=cmap)
    else:
        plot_labels(x, filename, cmap=cmap)
--------------------------------------------------------------------------------
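Finally, `plot_colorbar` writes a horizontal color legend for either a discrete label set or a continuous map (`n_labels=None`). A minimal usage sketch, assuming the file is importable as `visual`:

```python
from visual import plot_colorbar, cmap_tab70

# discrete legend for 8 cluster labels, one 200-pixel band per label
plot_colorbar(cmap_tab70, n_labels=8, filename='colorbar.png')
# continuous legend for the turbo colormap
plot_colorbar('turbo', n_labels=None, filename='colorbar-turbo.png')
```
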