├── .gitignore
├── README.md
├── interactive_plotting
│   ├── __init__.py
│   ├── bokeh_plots.py
│   ├── experimental
│   │   ├── __init__.py
│   │   ├── plots.py
│   │   ├── scatter3d.py
│   │   └── surface3d.ts
│   ├── holoviews_plots.py
│   └── utils
│       ├── __init__.py
│       └── _utils.py
├── notebooks
│   └── interactive_plotting_tutorial.ipynb
├── requirements.txt
├── resources
│   └── images
│       ├── dpt_plot.png
│       ├── gene_trend.png
│       ├── graph_plot.png
│       ├── heatmap.png
│       ├── highlight_de.png
│       ├── inter_hist.png
│       ├── link_plot.png
│       ├── scatter3d.png
│       ├── scatter_cat.png
│       ├── scatter_cont.png
│       ├── scatter_general1.png
│       ├── scatter_general2.png
│       └── thresh_hist.png
└── setup.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # ignore folders anywhere in the repo 2 | **/data 3 | **/.ipynb_checkpoints 4 | **/__pycache__ 5 | **/.idea 6 | 7 | # ignore specific folders 8 | write_results/ 9 | figures/ 10 | dataframes/ 11 | write_results/ 12 | 13 | # ignore files anywhere in the repo 14 | paths.py 15 | 16 | # generic python stuff to ignore 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .coverage 58 | .Rhistory 59 | .pybiomart.sqlite 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Interactive Plotting in Scanpy 2 | 3 | 4 | ## About 5 | This repository contains 11 different interactive plotting functions, which may be useful during exploratory analysis. 6 | 7 | Almost every function provides some information when hovering over the plot and some parts of the plots can be hidden by clicking the legend. 8 | 9 | ## Installation 10 | To install this package, do the following: 11 | ```bash 12 | conda install nodejs # >= v6.10.0, for 3D scatterplot 13 | pip install git+https://github.com/theislab/interactive_plotting 14 | ``` 15 | For 3D scatterplot, `node.js >= v6.10.0` is required. 
Go to Node's [website](https://nodejs.org/en/) for installation instructions. 16 | 17 | ## Getting Started 18 | We recommend checking out the [tutorial notebook](./notebooks/interactive_plotting_tutorial.ipynb). 19 | ```ipl.scatter```, ```ipl.scatterc``` and ```ipl.dpt``` can handle a large number of cells (100K+). 20 | 21 | In your Jupyter notebook, execute the following lines: 22 | ```python 23 | import holoviews as hv # needed for scatter, scatterc and dpt 24 | hv.extension('bokeh') 25 | 26 | import interactive_plotting as ipl 27 | 28 | from bokeh.io import output_notebook 29 | output_notebook() 30 | ``` 31 | 32 | ## Gallery 33 | Here are some example figures for each of the plotting functions. 34 | ```python 35 | ipl.ex.scatter 36 | ``` 37 | ![Scatterplot - general](resources/images/scatter_general2.png?raw=true "Scatterplot - general") 38 | 39 | --- 40 | 41 | ```python 42 | ipl.ex.scatter3d 43 | ``` 44 | ![Scatterplot - 3D](resources/images/scatter3d.png?raw=true "Scatterplot - 3D") 45 | 46 | --- 47 | 48 | ```python 49 | ipl.ex.scatter 50 | ``` 51 | ![Scatterplot - general](resources/images/scatter_general1.png?raw=true "Scatterplot - general") 52 | 53 | --- 54 | 55 | ```python 56 | ipl.scatter 57 | ``` 58 | ![Scatterplot (emb. cont.)](resources/images/scatter_cont.png?raw=true "Scatterplot - embedding (continuous)") 59 | 60 | --- 61 | 62 | ```python 63 | ipl.scatterc 64 | ``` 65 | ![Scatterplot (emb.
cat.)](resources/images/scatter_cat.png?raw=true "Scatterplot - embedding (categorical)") 66 | 67 | --- 68 | 69 | ```python 70 | ipl.ex.heatmap 71 | ``` 72 | ![Heatmap](https://raw.githubusercontent.com/theislab/interactive_plotting/experimental/resources/images/heatmap.png "Heatmap") 73 | 74 | --- 75 | 76 | ```python 77 | ipl.dpt 78 | ``` 79 | ![DPT plot](resources/images/dpt_plot.png?raw=true "DPT plot") 80 | 81 | --- 82 | 83 | ```python 84 | ipl.graph 85 | ``` 86 | ![Graph plot](resources/images/graph_plot.png?raw=true "Graph plot") 87 | 88 | --- 89 | 90 | ```python 91 | ipl.link_plot 92 | ``` 93 | ![link plot](resources/images/link_plot.png?raw=true "Link plot") 94 | 95 | --- 96 | 97 | ```python 98 | ipl.highlight_de 99 | ``` 100 | ![highlight differential expression plot](resources/images/highlight_de.png?raw=true "Highlight differential expression") 101 | 102 | --- 103 | 104 | ```python 105 | ipl.gene_trend 106 | ``` 107 | ![gene trend](resources/images/gene_trend.png?raw=true "Gene trend") 108 | 109 | --- 110 | 111 | ```python 112 | ipl.interactive_hist 113 | ``` 114 | ![interactive histogram](resources/images/inter_hist.png?raw=true "Interactive histogram") 115 | 116 | --- 117 | 118 | ```python 119 | ipl.thresholding_hist 120 | ``` 121 | ![thresholding histogram](resources/images/thresh_hist.png?raw=true "Thresholding histogram") 122 | 123 | ## Troubleshooting 124 | * [Notebook size is **huge**](https://github.com/theislab/interactive_plotting/issues/2) - This has to do with ```ipl.link_plot``` and ```ipl.velocity_plot```. Until a fix is found, we suggest removing these figures after you're done using them. 
125 | * [Getting "IOPub data rate exceeded" error](https://github.com/theislab/interactive_plotting/issues/7) - Try starting Jupyter Notebook as follows: 126 | 127 | ```jupyter notebook --NotebookApp.iopub_data_rate_limit=1e10``` 128 | 129 | To set this limit in a Jupyter config file instead, see [here](https://stackoverflow.com/questions/43288550/iopub-data-rate-exceeded-in-jupyter-notebook-when-viewing-image). 130 | -------------------------------------------------------------------------------- /interactive_plotting/__init__.py: -------------------------------------------------------------------------------- 1 | from interactive_plotting.bokeh_plots import interactive_hist, \ 2 | thresholding_hist, \ 3 | highlight_de, \ 4 | link_plot, \ 5 | gene_trend 6 | from interactive_plotting.holoviews_plots import scatter, scatterc, dpt, graph 7 | 8 | import interactive_plotting.experimental as ex 9 | -------------------------------------------------------------------------------- /interactive_plotting/bokeh_plots.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from sklearn.gaussian_process.kernels import * 4 | from sklearn import neighbors 5 | from scipy.sparse import issparse 6 | from scipy.spatial import distance_matrix, ConvexHull 7 | 8 | from functools import reduce 9 | from collections import defaultdict 10 | from itertools import product 11 | 12 | import warnings 13 | 14 | import numpy as np 15 | import pandas as pd 16 | import scanpy as sc 17 | 18 | import matplotlib.cm as cm 19 | import matplotlib 20 | import bokeh 21 | 22 | 23 | from interactive_plotting.utils._utils import sample_unif, sample_density, to_hex_palette 24 | from bokeh.plotting import figure, show, save as bokeh_save 25 | from bokeh.models import ColumnDataSource, Slider, HoverTool, ColorBar, \ 26 | Patches, Legend, CustomJS, TextInput, LabelSet, Select 27 | from bokeh.models.ranges import Range1d 28 | from bokeh.models.mappers import
CategoricalColorMapper, LinearColorMapper 29 | from bokeh.layouts import layout, column, row, GridSpec 30 | from bokeh.transform import linear_cmap, factor_mark, factor_cmap 31 | from bokeh.core.enums import MarkerType 32 | from bokeh.palettes import Set1, Set2, Set3, inferno, viridis 33 | from bokeh.models.widgets.buttons import Button 34 | 35 | _bokeh_version = tuple(map(int, bokeh.__version__.split('.'))) 36 | 37 | 38 | _inter_hist_js_code=""" 39 | // here is where the original data is stored 40 | var x = orig.data['values']; 41 | // sort a copy; sorting `x` in place would reorder `orig` and invalidate the indices below 42 | var xs = Array.from(x).sort((a, b) => a - b); 43 | var n_bins = parseInt(bins.value); // can be either string or int 44 | var bin_size = (xs[xs.length - 1] - xs[0]) / n_bins; 45 | 46 | var hist = new Array(n_bins).fill().map((_, i) => { return 0; }); 47 | var l_edges = new Array(n_bins).fill().map((_, i) => { return xs[0] + bin_size * i; }); 48 | var r_edges = new Array(n_bins).fill().map((_, i) => { return xs[0] + bin_size * (i + 1); }); 49 | var indices = new Array(n_bins).fill().map((_) => { return []; }); 50 | 51 | // create the histogram; indices[j] holds the positions (in `orig`) of the values falling into bin j 52 | for (var i = 0; i < x.length; i++) { 53 | for (var j = 0; j < r_edges.length; j++) { 54 | if (x[i] <= r_edges[j] || j == r_edges.length - 1) { // the last bin also catches rounding overshoot of the maximum 55 | hist[j] += 1; 56 | indices[j].push(i); 57 | break; 58 | } 59 | } 60 | } 61 | 62 | // make it a density 63 | var sum = hist.reduce((a, b) => a + b, 0); 64 | var deltas = r_edges.map((c, i) => { return c - l_edges[i]; }); 65 | // just like in numpy 66 | hist = hist.map((c, i) => { return c / deltas[i] / sum; }); 67 | 68 | source.data['hist'] = hist; 69 | source.data['l_edges'] = l_edges; 70 | source.data['r_edges'] = r_edges; 71 | source.data['indices'] = indices; 72 | 73 | source.change.emit(); 74 | """ 75 | 76 | 77 | def _inter_color_code(*colors): 78 | assert len(colors) > 0, 'Doesn\'t make sense using no colors.'
79 | color_code = '\n'.join((f'renderers[i].glyph.{c} = {{field: cb_obj.value, transform: transform}};' 80 | for c in colors)) 81 | return f""" 82 | var transform = cmaps[cb_obj.value]['transform']; 83 | var low = Math.min.apply(Math, source.data[cb_obj.value]); 84 | var high = Math.max.apply(Math, source.data[cb_obj.value]); 85 | 86 | for (var i = 0; i < renderers.length; i++) {{ 87 | {color_code} 88 | }} 89 | 90 | color_bar.color_mapper.low = low; 91 | color_bar.color_mapper.high = high; 92 | color_bar.color_mapper.palette = transform.palette; 93 | 94 | source.change.emit(); 95 | """ 96 | 97 | 98 | def _set_plot_wh(fig, w, h): 99 | if w is not None: 100 | fig.plot_width = w 101 | if h is not None: 102 | fig.plot_height = h 103 | 104 | 105 | def _create_mapper(adata, key): 106 | """ 107 | Helper function to create a color mapper (categorical or linear) from annotated data. 108 | 109 | Params 110 | -------- 111 | adata: AnnData 112 | annotated data object 113 | key: str 114 | key in `adata.obs_keys()` or `adata.var_names` for which we want the colors; if no colors for the given 115 | column are found in `adata.uns[f'{key}_colors']`, the `viridis` palette is used 116 | 117 | Returns 118 | -------- 119 | mapper: bokeh.models.mappers.CategoricalColorMapper or bokeh.models.mappers.LinearColorMapper 120 | mapper which maps values from `adata.obs[key]` (or the expression of gene `key`) to colors 121 | """ 122 | if key not in adata.obs_keys(): 123 | assert key in adata.var_names, f'`{key}` not found in `adata.obs_keys()` or `adata.var_names`' 124 | ix = np.where(adata.var_names == key)[0][0] 125 | vals = list(adata.X[:, ix]) 126 | palette = cm.get_cmap('viridis', adata.n_obs) 127 | 128 | mapper = dict(zip(vals, range(len(vals)))) 129 | palette = to_hex_palette(palette([mapper[v] for v in vals])) 130 | 131 | return LinearColorMapper(palette=palette, low=np.min(vals), high=np.max(vals)) 132 | 133 | is_categorical = adata.obs[key].dtype.name == 'category' 134 | default_palette = cm.get_cmap('viridis', adata.n_obs if not is_categorical else len(adata.obs[key].unique()))
135 | palette = adata.uns.get(f'{key}_colors', default_palette) 136 | 137 | if palette is default_palette: 138 | vals = adata.obs[key].unique() if is_categorical else adata.obs[key] 139 | mapper = dict(zip(vals, range(len(vals)))) 140 | palette = palette([mapper[v] for v in vals]) 141 | 142 | palette = to_hex_palette(palette) 143 | 144 | if is_categorical: 145 | return CategoricalColorMapper(palette=palette, factors=list(map(str, adata.obs[key].cat.categories))) 146 | 147 | return LinearColorMapper(palette=palette, low=np.min(adata.obs[key]), high=np.max(adata.obs[key])) 148 | 149 | 150 | def _smooth_expression(x, y, n_points=100, time_span=[None, None], mode='gp', kernel_params=dict(), kernel_default_params=dict(), 151 | kernel_expr=None, default=False, verbose=False, **opt_params): 152 | """Smooth out the expression of given values. 153 | 154 | Params 155 | -------- 156 | x: list(number) 157 | list of features 158 | y: list(number) 159 | list of targets 160 | n_points: int, optional (default: `100`) 161 | number of points at which the smoothed curve is evaluated 162 | time_span: list(int), optional (default: `[None, None]`) 163 | start and end of the evaluation range; `None` corresponds to the min/max of `x` 164 | mode: str, optional (default: `'gp'`) 165 | which regressor to use, available (`'gp'`: Gaussian Process, `'krr'`: Kernel Ridge Regression) 166 | kernel_params: dict, optional (default: `dict()`) 167 | dictionary mapping variable names (later combined using `kernel_expr`) to dictionaries of kernel parameters; 168 | the kernel is selected by the `'type'` key (default: `'rbf'`). Supported types: `'const'` (`ConstantKernel`), `'white'` (`WhiteKernel`), 169 | `'rbf'` (`RBF`), `'mat'` (`Matern`), `'rq'` (`RationalQuadratic`), `'esn'` (`ExpSineSquared`), `'dp'` (`DotProduct`), `'pw'` (`PairwiseKernel`). 170 | kernel_default_params: dict, optional (default: `dict()`) 171 | default parameters for a kernel, if not found in `kernel_params` 172 | kernel_expr: str, optional (default: `None`) 173 | expression to combine kernel variables specified in `kernel_params`.
Supported operators are `+`, `*`, `**`; 174 | example: kernel_expr=`'(a + b) ** 2'`, kernel_params=`{'a': {'type': 'const', 'constant_value': 1}, 'b': {'type': 'dp', 'sigma_0': 2}}` 175 | default: bool, optional (default: `False`) 176 | whether to build the variables in `kernel_expr` that are not found in `kernel_params` 177 | from `kernel_default_params` (an RBF kernel unless a `'type'` is specified there); 178 | if `False`, such variables raise a `ValueError` 179 | verbose: bool, optional (default: `False`) 180 | be more verbose 181 | **opt_params: kwargs 182 | keyword arguments passed to the regressor (`GaussianProcessRegressor` or `KernelRidge`) 183 | 184 | Returns 185 | -------- 186 | x_test: np.array 187 | points at which the values are predicted 188 | x_mean: np.array 189 | mean of the response 190 | cov: np.array (a list of `None` for mode=`'krr'`) 191 | covariance matrix of the response 192 | """ 193 | 194 | from sklearn.kernel_ridge import KernelRidge 195 | from sklearn.gaussian_process import GaussianProcessRegressor 196 | import operator as op 197 | import ast 198 | 199 | def _eval(node): 200 | if isinstance(node, ast.Constant): 201 | return node.value 202 | 203 | if isinstance(node, ast.Name): 204 | if not default and node.id not in kernel_params: 205 | raise ValueError(f'Error while parsing `{kernel_expr}`: `{node.id}` is not a valid key in kernel_params.
To use RBF kernel with default parameters, specify default=True.') 206 | params = dict(kernel_params.get(node.id, kernel_default_params)) # copy, so that `pop` below does not mutate the caller's dicts 207 | kernel_type = params.pop('type', 'rbf') 208 | return kernels[kernel_type](**params) 209 | 210 | if isinstance(node, ast.BinOp): 211 | return operators[type(node.op)](_eval(node.left), _eval(node.right)) 212 | 213 | if isinstance(node, ast.UnaryOp): 214 | return operators[type(node.op)](_eval(node.operand)) 215 | 216 | raise TypeError(node) 217 | 218 | operators = {ast.Add: op.add, 219 | ast.Mult: op.mul, 220 | ast.Pow: op.pow} 221 | kernels = dict(const=ConstantKernel, 222 | white=WhiteKernel, 223 | rbf=RBF, 224 | mat=Matern, 225 | rq=RationalQuadratic, 226 | esn=ExpSineSquared, 227 | dp=DotProduct, 228 | pw=PairwiseKernel) 229 | 230 | minn, maxx = time_span 231 | x_test = np.linspace(np.min(x) if minn is None else minn, np.max(x) if maxx is None else maxx, n_points)[:, None] 232 | 233 | if mode == 'krr': 234 | gamma = opt_params.pop('gamma', None) 235 | 236 | if gamma is None: 237 | length_scale = kernel_default_params.get('length_scale', 0.2) 238 | gamma = 1 / (2 * length_scale ** 2) 239 | if verbose: 240 | print(f'Smoothing using KRR with length_scale: {length_scale}.') 241 | 242 | kernel = opt_params.pop('kernel', 'rbf') 243 | model = KernelRidge(gamma=gamma, kernel=kernel, **opt_params) 244 | model.fit(x, y) 245 | 246 | return x_test, model.predict(x_test), [None] * n_points 247 | 248 | if mode == 'gp': 249 | 250 | if kernel_expr is None: 251 | assert len(kernel_params) == 1 252 | kernel_expr, = kernel_params.keys() 253 | 254 | kernel = _eval(ast.parse(kernel_expr, mode='eval').body) 255 | alpha = opt_params.pop('alpha', None) 256 | if alpha is None: 257 | alpha = np.std(y) 258 | 259 | optimizer = opt_params.pop('optimizer', None) 260 | opt_params['kernel'] = kernel 261 | 262 | model = GaussianProcessRegressor(alpha=alpha, optimizer=optimizer, **opt_params) 263 | model.fit(x, y) 264 | 265 | mean, cov = model.predict(x_test,
return_cov=True) 266 | 267 | return x_test, mean, cov 268 | 269 | raise ValueError(f'Unknown mode: `{mode}`.') 270 | 271 | 272 | def _create_gt_fig(adatas, dataframe, color_key, title, color_mapper, show_cont_annot=False, 273 | use_raw=True, genes=[], legend_loc='top_right', 274 | plot_width=None, plot_height=None): 275 | """ 276 | Helper function which creates a figure with smoothed gene expression, including 277 | confidence intervals, if possible. 278 | 279 | Params 280 | -------- 281 | dataframe: pandas.DataFrame 282 | dataframe with one row per path, containing the pseudotime (`'dpt'`), expression (`'expr'`) and smoothing results (`'x_test'`, `'x_mean'`, `'x_cov'`) 283 | color_key: str 284 | column in `dataframe` that is to be mapped to colors 285 | title: str 286 | title of the figure 287 | color_mapper: bokeh.models.mappers.CategoricalColorMapper 288 | transformation which assigns a value from `dataframe[color_key]` to a color; used when `color_key` is categorical 289 | show_cont_annot: bool, optional (default: `False`) 290 | show continuous annotations in `adata.obs`, if `color_key` is 291 | itself a continuous variable 292 | use_raw: bool, optional (default: `True`) 293 | whether to use `adata.raw` to get the expression 294 | genes: list, optional (default: `[]`) 295 | list of possible genes to color in, 296 | only works if `color_key` is a continuous variable 297 | legend_loc: str, optional (default: `'top_right'`) 298 | position of the legend 299 | plot_width: int, optional (default: `None`) 300 | width of the plot 301 | plot_height: int, optional (default: `None`) 302 | height of the plot 303 | 304 | Returns 305 | -------- 306 | fig: bokeh.layouts.column 307 | column containing the figure and, if requested, the color selection widgets 308 | """ 309 | 310 | # these markers are nearly indistinguishable 311 | markers = [marker for marker in MarkerType if marker not in ['circle_cross', 'circle_x']] 312 | fig = figure(title=title) 313 | _set_plot_wh(fig, plot_width, plot_height) 314 | 315 | renderers, color_selects = [], [] 316 | for i, (adata, marker, (path, df)) in enumerate(zip(adatas, markers, dataframe.iterrows())): 317 | ds = {'dpt': df['dpt'], 318 | 'expr':
df['expr'], 319 | f'{color_key}': df[color_key]} 320 | is_categorical = color_key in adata.obs_keys() and adata.obs[color_key].dtype.name == 'category' 321 | if not is_categorical: 322 | ds, mappers = _get_mappers(adata, ds, genes, use_raw=use_raw) 323 | 324 | source = ColumnDataSource(ds) 325 | renderers.append(fig.scatter('dpt', 'expr', source=source, 326 | color={'field': color_key, 'transform': color_mapper if is_categorical else mappers[color_key]['transform']}, 327 | fill_color={'field': color_key, 'transform': color_mapper if is_categorical else mappers[color_key]['transform']}, 328 | line_color={'field': color_key, 'transform': color_mapper if is_categorical else mappers[color_key]['transform']}, 329 | marker=marker, size=10, legend_label=path, muted_alpha=0)) 330 | 331 | fig.xaxis.axis_label = 'dpt' 332 | fig.yaxis.axis_label = 'expression' 333 | if legend_loc is not None: 334 | fig.legend.location = legend_loc 335 | 336 | if not is_categorical and show_cont_annot: 337 | color_selects.append(_add_color_select(color_key, fig, [renderers[-1]], source, mappers, suffix=f' [{path}]')) 338 | 339 | ds = dict(df[['x_test', 'x_mean', 'x_cov']]) 340 | if ds.get('x_test') is not None: 341 | if ds.get('x_mean') is not None: 342 | source = ColumnDataSource(ds) 343 | fig.line('x_test', 'x_mean', source=source, muted_alpha=0, legend_label=path) 344 | if all(map(lambda val: val is not None, ds.get('x_cov', [None]))): 345 | x_mean = ds['x_mean'] 346 | x_cov = ds['x_cov'] 347 | band_x = np.append(ds['x_test'][::-1], ds['x_test']) 348 | # mean +/- one standard deviation (sqrt of the covariance diagonal), traced as a single polygon: lower bound right-to-left, then upper bound left-to-right 349 | band_y = np.append((x_mean - np.sqrt(np.diag(x_cov)))[::-1], (x_mean + np.sqrt(np.diag(x_cov)))) 350 | fig.patch(band_x, band_y, alpha=0.1, line_color='black', fill_color='black', 351 | legend_label=path, line_dash='dotdash', muted_alpha=0) 352 | 353 | if ds.get('x_grad') is not None: 354 | fig.line('x_test', 'x_grad', source=source, muted_alpha=0) 355 | 356 |
fig.legend.click_policy = 'mute' 357 | 358 | return column(fig, *color_selects) 359 | 360 | 361 | def interactive_hist(adata, keys=['n_counts', 'n_genes'], 362 | bins='auto', max_bins=100, 363 | groups=None, fill_alpha=0.4, 364 | palette=None, display_all=True, 365 | tools='pan, reset, wheel_zoom, save', 366 | legend_loc='top_right', 367 | plot_width=None, plot_height=None, save=None, 368 | *args, **kwargs): 369 | """Utility function to plot distributions with a variable number of bins. 370 | 371 | Params 372 | -------- 373 | adata: AnnData object 374 | annotated data object 375 | keys: list(str), optional (default: `['n_counts', 'n_genes']`) 376 | keys in `adata.obs`, `adata.var` or `adata.var_names` where the distributions are stored 377 | bins: int; str, optional (default: `auto`) 378 | number of bins used for the initial plot, or a binning strategy string accepted by `numpy.histogram` 379 | max_bins: int, optional (default: `100`) 380 | maximum number of bins selectable with the slider; raised automatically if `bins` yields more 381 | groups: list(str), optional (default: `None`) 382 | keys in `adata.obs_keys()`; the data is grouped by all possible combinations of their values, e.g. for
for 383 | 3 plates and 2 time points, we would create total of 6 groups 384 | fill_alpha: float[0.0, 1.0], (default: `0.4`) 385 | alpha channel of the fill color 386 | palette: list(str), optional (default: `None`) 387 | palette to use 388 | display_all: bool, optional (default: `True`) 389 | display the statistics for all data 390 | tools: str, optional (default: `'pan,reset, wheel_zoom, save'`) 391 | palette of interactive tools for the user 392 | legend_loc: str, (default: `'top_right'`) 393 | position of the legend 394 | legend_loc: str, default(`'top_left'`) 395 | position of the legend 396 | plot_width: int, optional (default: `None`) 397 | width of the plot 398 | plot_height: int, optional (default: `None`) 399 | height of the plot 400 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`) 401 | path where to save the plot 402 | *args, **kwargs: arguments, keyword arguments 403 | addition argument to bokeh.models.figure 404 | 405 | Returns 406 | -------- 407 | None 408 | """ 409 | 410 | if max_bins < 1: 411 | raise ValueError(f'`max_bins` must >= 1') 412 | 413 | palette = Set1[9] + Set2[8] + Set3[12] if palette is None else palette 414 | 415 | # check the input 416 | for key in keys: 417 | if key not in adata.obs.keys() and \ 418 | key not in adata.var.keys() and \ 419 | key not in adata.var_names: 420 | raise ValueError(f'The key `{key}` does not exist in `adata.obs`, `adata.var` or `adata.var_names`.') 421 | 422 | def _create_adata_groups(): 423 | if groups is None: 424 | return [adata], [('all',)] 425 | 426 | combs = list(product(*[set(adata.obs[g]) for g in groups])) 427 | adatas= [adata[reduce(lambda l, r: l & r, 428 | (adata.obs[k] == v for k, v in zip(groups, vals)), True)] 429 | for vals in combs] + [adata] 430 | 431 | if display_all: 432 | combs += [('all',)] 433 | adatas += [adata] 434 | 435 | return adatas, combs 436 | 437 | # group_v_combs contains the value combinations 438 | ad_gs = _create_adata_groups() 439 | 440 | cols = [] 
441 | for key in keys: 442 | callbacks = [] 443 | fig = figure(*args, tools=tools, **kwargs) 444 | slider = Slider(start=1, end=max_bins, value=1, step=1, 445 | title='Bins') 446 | 447 | plots = [] 448 | for j, (ad, group_vs) in enumerate(filter(lambda ad_g: ad_g[0].n_obs > 0, zip(*ad_gs))): 449 | 450 | if key in ad.obs.keys(): 451 | orig = ad.obs[key] 452 | hist, edges = np.histogram(orig, density=True, bins=bins) 453 | elif key in ad.var.keys(): 454 | orig = ad.var[key] 455 | hist, edges = np.histogram(orig, density=True, bins=bins) 456 | else: 457 | orig = ad[:, key].X 458 | hist, edges = np.histogram(orig, density=True, bins=bins) 459 | 460 | slider.value = len(hist) 461 | # case when automatic bins 462 | max_bins = max(max_bins, slider.value) 463 | 464 | # original data, used for recalculation of histogram in JS code 465 | orig = ColumnDataSource(data=dict(values=orig)) 466 | # data that we update in JS code 467 | source = ColumnDataSource(data=dict(hist=hist, l_edges=edges[:-1], r_edges=edges[1:])) 468 | 469 | legend = ', '.join(': '.join(map(str, gv)) for gv in zip(groups, group_vs)) \ 470 | if groups is not None else 'all' 471 | p = fig.quad(source=source, top='hist', bottom=0, 472 | left='l_edges', right='r_edges', 473 | fill_color=palette[j], legend_label=legend if legend_loc is not None else None, 474 | muted_alpha=0, 475 | line_color="#555555", fill_alpha=fill_alpha) 476 | 477 | # create the callback which recomputes this histogram when the slider changes 478 | callback = CustomJS(args=dict(source=source, orig=orig), code=_inter_hist_js_code) 479 | callback.args['bins'] = slider 480 | callbacks.append(callback) 481 | 482 | # keep the renderer so that the toggle button 483 | # can mute/unmute it in JS code 484 | plots.append(p) 485 | 486 | slider.end = max_bins 487 | 488 | # slider now updates all values 489 | slider.js_on_change('value', *callbacks) 490 | 491 | button = Button(label='Toggle', button_type='primary') 492 | button.js_on_click(CustomJS( 493 | args={'plots': plots}, 494 | code=''' 495 | for (var i = 0; i < plots.length; i++) { 496 | plots[i].muted = !plots[i].muted; 497 | } 498 | ''' 499 | )) 500 | 501 | if legend_loc is not None: 502 | fig.legend.location = legend_loc 503 | fig.legend.click_policy = 'mute' 504 | 505 | fig.xaxis.axis_label = key 506 | fig.yaxis.axis_label = 'normalized frequency' 507 | _set_plot_wh(fig, plot_width, plot_height) 508 | 509 | cols.append(column(slider, button, fig)) 510 | 511 | if _bokeh_version > (1, 0, 4): 512 | from bokeh.layouts import grid 513 | plot = grid(children=cols, ncols=2) 514 | else: 515 | cols = list(map(list, np.array_split(cols, np.ceil(len(cols) / 2)))) 516 | plot = layout(children=cols, sizing_mode='fixed', ncols=2) 517 | 518 | if save is not None: 519 | save = save if str(save).endswith('.html') else str(save) + '.html' 520 | bokeh_save(plot, save) 521 | else: 522 | show(plot) 523 | 524 | 525 | def thresholding_hist(adata, key, categories, basis=['umap'], components=[1, 2], 526 | bins='auto', palette=None, legend_loc='top_right', 527 | plot_width=None, plot_height=None, save=None): 528 | """Histogram with the option to highlight categories based on thresholding binned values.
529 | 530 | Params 531 | -------- 532 | adata: AnnData object 533 | annotated data object 534 | key: str 535 | key in `adata.obs_keys()` where the data is stored 536 | categories: dict 537 | dictionary mapping category names to their initial boundaries `[min, max]` 538 | basis: list, optional (default: `['umap']`) 539 | bases to visualize; the embeddings are read from `adata.obsm[f'X_{basis}']` 540 | components: list(int); list(list(int)), optional (default: `[1, 2]`) 541 | components to use for each basis (1-based, except for `'diffmap'`, which is 0-based) 542 | bins: int; str, optional (default: `auto`) 543 | number of bins used for the initial binning, or a binning strategy string accepted by `numpy.histogram` 544 | palette: list(str), optional (default: `None`) 545 | palette to use for coloring categories 546 | legend_loc: str, optional (default: `'top_right'`) 547 | position of the legend 548 | plot_width: int, optional (default: `None`) 549 | width of the plot 550 | plot_height: int, optional (default: `None`) 551 | height of the plot 552 | save: Union[os.PathLike, str, NoneType], optional (default: `None`) 553 | path where to save the plot 554 | 555 | Returns 556 | -------- 557 | None 558 | """ 559 | 560 | if not isinstance(basis, list): 561 | basis = [basis] 562 | 563 | if not isinstance(components[0], list): 564 | components = [components] 565 | 566 | if len(components) != len(basis): 567 | assert len(basis) % len(components) == 0 and len(basis) >= len(components) 568 | components = components * (len(basis) // len(components)) 569 | 570 | if not isinstance(components, np.ndarray): 571 | components = np.asarray(components) 572 | 573 | palette = Set1[9] + Set2[8] + Set3[12] if palette is None else palette 574 | 575 | hist_fig = figure() 576 | _set_plot_wh(hist_fig, plot_width, plot_height) 577 | 578 | hist_fig.xaxis.axis_label = key 579 | hist_fig.yaxis.axis_label = 'normalized frequency' 580 | hist, edges = np.histogram(adata.obs[key], density=True, bins=bins) 581 | 582 | source = ColumnDataSource(data=dict(hist=hist, l_edges=edges[:-1],
r_edges=edges[1:], 583 | category=['default'] * len(hist), indices=[[]] * len(hist))) 584 | 585 | df = pd.concat([pd.DataFrame(adata.obsm[f'X_{bs}'][:, comp - (bs != 'diffmap')], columns=[f'x_{bs}', f'y_{bs}']) 586 | for bs, comp in zip(basis, components)], axis=1) 587 | df['values'] = list(adata.obs[key]) 588 | df['category'] = 'default' 589 | df['visible_category'] = 'default' 590 | df['cat_stack'] = [['default']] * len(df) 591 | 592 | orig = ColumnDataSource(df) 593 | color = dict(field='category', transform=CategoricalColorMapper(palette=palette, factors=list(categories.keys()))) 594 | hist_fig.quad(source=source, top='hist', bottom=0, 595 | left='l_edges', right='r_edges', color=color, 596 | line_color="#555555", legend_group='category') 597 | if legend_loc is not None: 598 | hist_fig.legend.location = legend_loc 599 | 600 | emb_figs = [] 601 | for bs, comp in zip(basis, components): 602 | fig = figure(title=bs) 603 | 604 | fig.xaxis.axis_label = f'{bs}_{comp[0]}' 605 | fig.yaxis.axis_label = f'{bs}_{comp[1]}' 606 | _set_plot_wh(fig, plot_width, plot_height) 607 | 608 | fig.scatter(f'x_{bs}', f'y_{bs}', source=orig, size=10, color=color, legend_group='category') 609 | if legend_loc is not None: 610 | fig.legend.location = legend_loc 611 | 612 | emb_figs.append(fig) 613 | 614 | inputs, category_cbs = [], [] 615 | code_start, code_mid, code_thresh = [], [], [] 616 | args = {'source': source, 'orig': orig} 617 | 618 | for col, cat_item in zip(palette, categories.items()): 619 | cat, (start, end) = cat_item 620 | inp_min = TextInput(name='test', value=f'{start}', title=f'{cat}/min') 621 | inp_max = TextInput(name='test', value=f'{end}', title=f'{cat}/max') 622 | 623 | code_start.append(f''' 624 | var min_{cat} = parseFloat(inp_min_{cat}.value); 625 | var max_{cat} = parseFloat(inp_max_{cat}.value); 626 | ''') 627 | code_mid.append(f''' 628 | var mid_{cat} = (source.data['r_edges'][i] - source.data['l_edges'][i]) / 2; 629 | ''') 630 | code_thresh.append(f''' 631 | 
if (source.data['l_edges'][i] + mid_{cat} >= min_{cat} && source.data['r_edges'][i] - mid_{cat} <= max_{cat}) {{ 632 | source.data['category'][i] = '{cat}'; 633 | for (var j = 0; j < source.data['indices'][i].length; j++) {{ 634 | var ix = source.data['indices'][i][j]; 635 | orig.data['category'][ix] = '{cat}'; 636 | }} 637 | }} 638 | ''') 639 | args[f'inp_min_{cat}'] = inp_min 640 | args[f'inp_max_{cat}'] = inp_max 641 | min_ds = ColumnDataSource(dict(xs=[start] * 2)) 642 | max_ds = ColumnDataSource(dict(xs=[end] * 2)) 643 | 644 | inputs.extend([inp_min, inp_max]) 645 | 646 | code_thresh.append( 647 | ''' 648 | { 649 | source.data['category'][i] = 'default'; 650 | for (var j = 0; j < source.data['indices'][i].length; j++) { 651 | var ix = source.data['indices'][i][j]; 652 | orig.data['category'][ix] = 'default'; 653 | } 654 | } 655 | ''') 656 | callback = CustomJS(args=args, code=f''' 657 | {';'.join(code_start)} 658 | for (var i = 0; i < source.data['hist'].length; i++) {{ 659 | {';'.join(code_mid)} 660 | {' else '.join(code_thresh)} 661 | }} 662 | orig.change.emit(); 663 | source.change.emit(); 664 | ''') 665 | 666 | for input in inputs: 667 | input.js_on_change('value', callback) 668 | 669 | slider = Slider(start=1, end=100, value=len(hist), title='Bins') 670 | interactive_hist_cb = CustomJS(args={'source': source, 'orig': orig, 'bins': slider}, code=_inter_hist_js_code) 671 | slider.js_on_change('value', interactive_hist_cb, callback) 672 | 673 | plot = column(row(hist_fig, column(slider, *inputs)), *emb_figs) 674 | 675 | if save is not None: 676 | save = save if str(save).endswith('.html') else str(save) + '.html' 677 | bokeh_save(plot, save) 678 | else: 679 | show(plot) 680 | 681 | 682 | def gene_trend(adata, paths, genes=None, mode='gp', exp_key='X', 683 | separate_paths=False, show_cont_annot=False, 684 | extra_genes=[], n_points=100, show_zero_counts=True, 685 | time_span=[None, None], use_raw=True, 686 | n_velocity_genes=5, length_scale=0.2, 687 | 
path_key='louvain', color_key='louvain', 688 | share_y=True, legend_loc='top_right', 689 | plot_width=None, plot_height=None, save=None, **kwargs): 690 | """ 691 | Show expression levels as well as velocity per gene as a function of DPT. 692 | 693 | Params 694 | -------- 695 | adata: AnnData 696 | annotated data object 697 | paths: list(list(str)) 698 | different paths to visualize 699 | genes: list, optional (default: `None`) 700 | list of genes to show; if `None`, take the first `n_velocity_genes` genes 701 | from `adata.var['velocity_genes']` 702 | mode: str, optional (default: `'gp'`) 703 | whether to use Kernel Ridge Regression (`'krr'`) or a Gaussian Process (`'gp'`) for 704 | smoothing the expression values 705 | exp_key: str, optional (default: `'X'`) 706 | key in `adata.layers`, or `'X'`, from which to take the expression values 707 | separate_paths: bool, optional (default: `False`) 708 | whether to show each path for each gene in a separate plot 709 | show_cont_annot: bool, optional (default: `False`) 710 | show continuous annotations in `adata.obs`, 711 | only works if `color_key` is a continuous variable 712 | extra_genes: list(str), optional (default: `[]`) 713 | list of additional genes available for coloring, 714 | only works if `color_key` is a continuous variable 715 | n_points: int, optional (default: `100`) 716 | number of points used for the smoothing 717 | time_span: list(float), optional (default: `[None, None]`) 718 | lower and upper bound of the pseudotime range; `None` corresponds to the min/max 719 | use_raw: bool, optional (default: `True`) 720 | whether to use `adata.raw` to get the expression 721 | show_zero_counts: bool, optional (default: `True`) 722 | whether to show cells with zero counts 723 | n_velocity_genes: int, optional (default: `5`) 724 | number of genes to take from `adata.var['velocity_genes']` 725 | length_scale: float, optional (default: `0.2`) 726 | length scale for RBF kernel 727 | path_key: str, optional (default: `'louvain'`) 728 | key in `adata.obs_keys()`
where to look for the groups specified in the `paths` argument 729 | color_key: str, optional (default: `'louvain'`) 730 | key in `adata.obs_keys()` used to color the plot 731 | share_y: bool, optional (default: `True`) 732 | whether to share the y-axis when plotting paths separately 733 | legend_loc: str, optional (default: `'top_right'`) 734 | position of the legend 735 | plot_width: int, optional (default: `None`) 736 | width of the plot 737 | plot_height: int, optional (default: `None`) 738 | height of the plot 739 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`) 740 | path where to save the plot 741 | **kwargs: kwargs 742 | keyword arguments for KRR or GP 743 | 744 | Returns 745 | -------- 746 | None 747 | """ 748 | 749 | if mode == 'krr': 750 | warnings.warn('KRR is experimental; please consider using `mode=\'gp\'`.') 751 | 752 | for path in paths: 753 | for p in path: 754 | assert p in adata.obs[path_key].cat.categories, f'`{p}` is not in `adata.obs[{path_key!r}]`. Possible values are: `{list(adata.obs[path_key].cat.categories)}`.'
755 | 756 | # check the input 757 | if 'dpt_pseudotime' not in adata.obs.keys(): 758 | raise ValueError('`dpt_pseudotime` is not in `adata.obs.keys()`') 759 | 760 | # check the genes list 761 | if genes is None: 762 | genes = adata[:, adata.var['velocity_genes']].var_names[:n_velocity_genes] 763 | 764 | genes_indicator = np.in1d(genes, adata.var_names) #[gene in adata.var_names for gene in genes] 765 | if not all(genes_indicator): 766 | genes_missing = np.array(genes)[np.invert(genes_indicator)] 767 | print(f'Could not find the following genes: `{genes_missing}`.') 768 | genes = list(np.array(genes)[genes_indicator]) 769 | 770 | mapper = _create_mapper(adata, color_key) 771 | figs, adatas = [], [] 772 | 773 | for gene in genes: 774 | data = defaultdict(list) 775 | row_figs = [] 776 | y_lim_min, y_lim_max = np.inf, -np.inf 777 | for path in paths: 778 | path_ix = np.in1d(adata.obs[path_key], path) 779 | ad = adata[path_ix].copy() 780 | 781 | minn, maxx = time_span 782 | ad.obs['dpt_pseudotime'] = ad.obs['dpt_pseudotime'].replace(np.inf, 1) 783 | minn = np.min(ad.obs['dpt_pseudotime']) if minn is None else minn 784 | maxx = np.max(ad.obs['dpt_pseudotime']) if maxx is None else maxx 785 | 786 | # wish I could get rid of this copy 787 | ad = ad[(ad.obs['dpt_pseudotime'] >= minn) & (ad.obs['dpt_pseudotime'] <= maxx)] 788 | 789 | gene_exp = ad[:, gene].layers[exp_key] if exp_key != 'X' else (ad.raw if use_raw else ad)[:, gene].X 790 | 791 | # exclude dropouts 792 | ix = (gene_exp > 0).squeeze() 793 | indexer = slice(None) if show_zero_counts else ix 794 | # just use for sanity check with asserts 795 | rev_indexer = ix if show_zero_counts else slice(None) 796 | 797 | dpt = ad.obs['dpt_pseudotime'] 798 | 799 | if issparse(gene_exp): 800 | gene_exp = gene_exp.A 801 | 802 | gene_exp = np.squeeze(gene_exp[indexer, None]) 803 | data['expr'].append(gene_exp) 804 | y_lim_min, y_lim_max = min(y_lim_min, np.min(gene_exp)), max(y_lim_max, np.max(gene_exp)) 805 | 806 | # compute 
smoothed values from expression 807 | data['dpt'].append(np.squeeze(dpt[indexer, None])) 808 | data[color_key].append(np.array(ad[indexer].obs[color_key])) 809 | 810 | assert all(gene_exp[rev_indexer] > 0) 811 | 812 | if len(gene_exp[rev_indexer]) == 0: 813 | print(f'All counts are 0 for: `{gene}`.') 814 | continue 815 | 816 | x_test, exp_mean, exp_cov = _smooth_expression(np.expand_dims(dpt[ix], -1), gene_exp[ix if show_zero_counts else slice(None)], mode=mode, 817 | time_span=time_span, n_points=n_points, kernel_params=dict(k=dict(length_scale=length_scale)), 818 | **kwargs) 819 | 820 | data['x_test'].append(x_test) 821 | data['x_mean'].append(exp_mean) 822 | data['x_cov'].append(exp_cov) 823 | 824 | # we need this for the _create mapper 825 | adatas.append(ad[indexer]) 826 | 827 | if separate_paths: 828 | dataframe = pd.DataFrame(data, index=list(map(lambda path: ', '.join(map(str, path)), [path]))) 829 | row_figs.append(_create_gt_fig(adatas, dataframe, color_key, title=gene, color_mapper=mapper, 830 | show_cont_annot=show_cont_annot, legend_loc=legend_loc, genes=extra_genes, 831 | use_raw=use_raw, plot_width=plot_width, plot_height=plot_height)) 832 | adatas = [] 833 | data = defaultdict(list) 834 | 835 | if separate_paths: 836 | if share_y: 837 | # first child is the figure 838 | for fig in map(lambda c: c.children[0], row_figs): 839 | fig.y_range = Range1d(y_lim_min - 0.1, y_lim_max + 0.1) 840 | 841 | figs.append(row(row_figs)) 842 | row_figs = [] 843 | else: 844 | dataframe = pd.DataFrame(data, index=list(map(lambda path: ', '.join(map(str, path)), paths))) 845 | figs.append(_create_gt_fig(adatas, dataframe, color_key, title=gene, color_mapper=mapper, 846 | show_cont_annot=show_cont_annot, legend_loc=legend_loc, genes=extra_genes, 847 | use_raw=use_raw, plot_width=plot_width, plot_height=plot_height)) 848 | 849 | plot = column(*figs) 850 | 851 | if save is not None: 852 | save = save if str(save).endswith('.html') else str(save) + '.html' 853 | 
bokeh_save(plot, save) 854 | else: 855 | show(plot) 856 | 857 | 858 | def highlight_de(adata, basis='umap', components=[1, 2], n_top_genes=10, 859 | de_keys='names, scores, pvals_adj, logfoldchanges', 860 | cell_keys='', n_neighbors=5, fill_alpha=0.1, show_hull=True, 861 | legend_loc='top_right', plot_width=None, plot_height=None, save=None): 862 | """ 863 | Highlight differential expression by hovering over clusters. 864 | 865 | Params 866 | -------- 867 | adata: AnnData 868 | annotated data object 869 | basis: str, optional (default: `'umap'`) 870 | basis used in visualization 871 | components: list(int), optional (default: `[1, 2]`) 872 | components of the basis 873 | n_top_genes: int, optional (default: `10`) 874 | number of differentially expressed genes to display 875 | de_keys: list(str); str, optional (default: `'names, scores, pvals_adj, logfoldchanges'`) 876 | list or comma-separated values of keys in `adata.uns['rank_genes_groups'].keys()` 877 | to be displayed for each cluster 878 | cell_keys: list(str); str, optional (default: `''`) 879 | keys in `adata.obs_keys()` to be displayed when hovering over cells 880 | n_neighbors: int, optional (default: `5`) 881 | number of neighbors for the KNN classifier, which 882 | controls the shape of the convex hull 883 | fill_alpha: float, optional (default: `0.1`) 884 | alpha value of the cluster colors 885 | show_hull: bool, optional (default: `True`) 886 | show the convex hull around each cluster 887 | legend_loc: str, optional (default: `'top_right'`) 888 | position of the legend 889 | plot_width: int, optional (default: `None`) 890 | width of the plot 891 | plot_height: int, optional (default: `None`) 892 | height of the plot 893 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`) 894 | path where to save the plot 895 | 896 | Returns 897 | -------- 898 | None 899 | """ 900 | 901 | if 'rank_genes_groups' not in adata.uns_keys(): 902 | raise ValueError('Run differential expression first (e.g. `sc.tl.rank_genes_groups`).') 903 | 904 | 905 | if isinstance(de_keys, str):
906 | de_keys = list(dict.fromkeys(map(str.strip, de_keys.split(',')))) 907 | if de_keys != ['']: 908 | assert all(map(lambda k: k in adata.uns['rank_genes_groups'].keys(), de_keys)), 'Not all keys are in `adata.uns[\'rank_genes_groups\']`.' 909 | else: 910 | de_keys = [] 911 | 912 | if isinstance(cell_keys, str): 913 | cell_keys = list(dict.fromkeys(map(str.strip, cell_keys.split(',')))) 914 | if cell_keys != ['']: 915 | assert all(map(lambda k: k in adata.obs.keys(), cell_keys)), 'Not all keys are in `adata.obs.keys()`.' 916 | else: 917 | cell_keys = [] 918 | 919 | if f'X_{basis}' not in adata.obsm.keys(): 920 | raise ValueError(f'Key `X_{basis}` not found in adata.obsm.') 921 | 922 | if not isinstance(components, np.ndarray): 923 | components = np.asarray(components) 924 | 925 | key = adata.uns['rank_genes_groups']['params']['groupby'] 926 | if key not in cell_keys: 927 | cell_keys.insert(0, key) 928 | 929 | df = pd.DataFrame(adata.obsm[f'X_{basis}'][:, components - (basis != 'diffmap')], columns=['x', 'y']) 930 | for k in cell_keys: 931 | df[k] = list(map(str, adata.obs[k])) 932 | 933 | knn = neighbors.KNeighborsClassifier(n_neighbors) 934 | knn.fit(df[['x', 'y']], adata.obs[key]) 935 | df['prediction'] = knn.predict(df[['x', 'y']]) 936 | 937 | conv_hulls = df[df[key] == df['prediction']].groupby(key).apply(lambda df: df.iloc[ConvexHull(np.vstack([df['x'], df['y']]).T).vertices]) 938 | 939 | mapper = _create_mapper(adata, key) 940 | categories = adata.obs[key].cat.categories 941 | fig = figure(tools='pan, reset, wheel_zoom, lasso_select, save') 942 | _set_plot_wh(fig, plot_width, plot_height) 943 | legend_dict = defaultdict(list) 944 | 945 | for k in categories: 946 | d = df[df[key] == k] 947 | data_source = ColumnDataSource(d) 948 | legend_dict[k].append(fig.scatter('x', 'y', source=data_source, color={'field': key, 'transform': mapper}, size=5, muted_alpha=0)) 949 | 950 | hover_cell = HoverTool(renderers=[r[0] for r in legend_dict.values()], 
tooltips=[(f'{key}', f'@{key}')] + [(f'{k}', f'@{k}') for k in cell_keys[1:]]) 951 | 952 | c_hulls = conv_hulls.copy() 953 | de_possible = conv_hulls[key].isin(adata.uns['rank_genes_groups']['names'].dtype.names) 954 | ok_patches = [] 955 | prev_cat = [] 956 | for i, isin in enumerate((~de_possible, de_possible)): 957 | conv_hulls = c_hulls[isin] 958 | 959 | if len(conv_hulls) == 0: 960 | continue 961 | 962 | # rename the `key` column to 'group', since after the groupby `key` is also an index level of the MultiIndex 963 | conv_hulls = conv_hulls.rename(columns={key: 'group'}) 964 | xs, ys, ks = zip(*conv_hulls.groupby('group').apply(lambda df: list(map(list, (df['x'], df['y'], df['group']))))) 965 | tmp_data = defaultdict(list) 966 | tmp_data['xs'] = xs 967 | tmp_data['ys'] = ys 968 | tmp_data[key] = list(map(lambda k: k[0], ks)) 969 | 970 | if i == 1: 971 | ix = list(map(lambda k: adata.uns['rank_genes_groups']['names'].dtype.names.index(k), tmp_data[key])) 972 | for k in de_keys: 973 | tmp = np.array(list(zip(*adata.uns['rank_genes_groups'][k])))[ix, :n_top_genes] 974 | for j in range(n_top_genes): 975 | tmp_data[f'{k}_{j}'] = tmp[:, j] 976 | 977 | tmp_data = pd.DataFrame(tmp_data) 978 | for k in categories: 979 | d = tmp_data[tmp_data[key] == k] 980 | source = ColumnDataSource(d) 981 | 982 | patches = fig.patches('xs', 'ys', source=source, fill_alpha=fill_alpha, muted_alpha=0, hover_alpha=0.5, 983 | color={'field': key, 'transform': mapper} if (show_hull and i == 1) else None, 984 | hover_color={'field': key, 'transform': mapper} if (show_hull and i == 1) else None) 985 | legend_dict[k].append(patches) 986 | if i == 1: 987 | ok_patches.append(patches) 988 | 989 | hover_group = HoverTool(renderers=ok_patches, tooltips=[(f'{key}', f'@{key}'), 990 | ('groupby', adata.uns['rank_genes_groups']['params']['groupby']), 991 | ('reference', adata.uns['rank_genes_groups']['params']['reference']), 992 | ('rank', ' | '.join(de_keys))] + [(f'#{i + 1}', ' | '.join((f'@{k}_{i}' for k in de_keys))) for i in
range(n_top_genes)] 993 | ) 994 | 995 | 996 | fig.toolbar.active_inspect = [hover_group] 997 | if len(cell_keys) > 1: 998 | fig.add_tools(hover_group, hover_cell) 999 | else: 1000 | fig.add_tools(hover_group) 1001 | 1002 | if legend_loc is not None: 1003 | legend = Legend(items=list(legend_dict.items()), location=legend_loc) 1004 | fig.add_layout(legend) 1005 | fig.legend.click_policy = 'hide' # 'hide' disables hovering, whereas 'mute' does not 1006 | 1007 | fig.xaxis.axis_label = f'{basis}_{components[0]}' 1008 | fig.yaxis.axis_label = f'{basis}_{components[1]}' 1009 | 1010 | if save is not None: 1011 | save = save if str(save).endswith('.html') else str(save) + '.html' 1012 | bokeh_save(fig, save) 1013 | else: 1014 | show(fig) 1015 | 1016 | 1017 | def link_plot(adata, key, genes=None, basis=['umap', 'pca'], components=[1, 2], 1018 | subsample=None, steps=[40, 40], sample_size=500, 1019 | distance=2, cutoff=True, highlight_only=None, palette=None, 1020 | show_legend=False, legend_loc='top_right', plot_width=None, plot_height=None, save=None, seed=None): 1021 | """ 1022 | Display the distances of cells from the currently highlighted cell.
1023 | 1024 | Params 1025 | -------- 1026 | adata: AnnData 1027 | annotated data object 1028 | key: str 1029 | key in `adata.obs_keys()` used to color the static plots 1030 | genes: list(str), optional (default: `None`) 1031 | list of genes in `adata.var_names`, 1032 | which are used to compute the distance; 1033 | if `None`, take all the genes 1034 | basis: list(str), optional (default: `['umap', 'pca']`) 1035 | list of bases to use for plotting; 1036 | only the first plot is hoverable 1037 | components: list(int); list(list(int)), optional (default: `[1, 2]`) 1038 | list of components for each basis 1039 | subsample: str, optional (default: `None`) 1040 | subsampling strategy to use when there are too many cells; 1041 | possible values are: `"density"`, `"uniform"`, `None` 1042 | steps: int; list(int), optional (default: `[40, 40]`) 1043 | number of steps in each direction when using `subsample="uniform"` 1044 | sample_size: int, optional (default: `500`) 1045 | number of cells to sample based on their density in the respective embedding 1046 | when using `subsample="density"`; should be < `1000` 1047 | distance: int; str, optional (default: `2`) 1048 | if an integer, use the corresponding p-norm; 1049 | if a string, only `'dpt'` is available 1050 | cutoff: bool, optional (default: `True`) 1051 | if `True`, do not color cells whose distance exceeds 1052 | the threshold specified by the slider 1053 | highlight_only: str, optional (default: `None`) 1054 | key in `adata.obs_keys()`; if specified, only cells in the same 1055 | category as the highlighted cell are colored 1056 | palette: matplotlib.colors.Colormap; list(str), optional (default: `None`) 1057 | colormap to use; if `None`, use `matplotlib.cm.RdYlBu` 1058 | show_legend: bool, optional (default: `False`) 1059 | whether to also display the legend in the linked plots 1060 | legend_loc: str, optional (default: `'top_right'`) 1061 | location of the legend 1062 | seed: int, optional (default: `None`) 1063 | random seed, used when `subsample='density'` 1064 |
plot_width: int, optional (default: `None`) 1065 | width of the plot 1066 | plot_height: int, optional (default: `None`) 1067 | height of the plot 1068 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`) 1069 | path where to save the plot 1070 | 1071 | Returns 1072 | -------- 1073 | None 1074 | """ 1075 | 1076 | assert key in adata.obs.keys(), f'`{key}` not found in `adata.obs`.' 1077 | 1078 | if subsample == 'uniform': 1079 | adata, _ = sample_unif(adata, steps, basis[0]) 1080 | elif subsample == 'density': 1081 | adata, _ = sample_density(adata, sample_size, basis[0], seed=seed) 1082 | elif subsample is not None: 1083 | raise ValueError(f'Unknown subsample strategy: `{subsample}`.') 1084 | 1085 | palette = cm.RdYlBu if palette is None else palette 1086 | if isinstance(palette, matplotlib.colors.Colormap): 1087 | palette = to_hex_palette(palette(range(palette.N), 1., bytes=True)) 1088 | 1089 | if not isinstance(components[0], list): 1090 | components = [components] 1091 | 1092 | if len(components) != len(basis): 1093 | assert len(basis) % len(components) == 0 and len(basis) >= len(components) 1094 | components = components * (len(basis) // len(components)) 1095 | 1096 | if not isinstance(components, np.ndarray): 1097 | components = np.asarray(components) 1098 | 1099 | if highlight_only is not None: 1100 | assert highlight_only in adata.obs_keys(), f'`{highlight_only}` is not in adata.obs_keys().' 
1101 | 1102 | genes = adata.var_names if genes is None else genes 1103 | gene_subset = np.in1d(adata.var_names, genes) 1104 | 1105 | if distance != 'dpt': 1106 | d = adata.X[:, gene_subset] 1107 | if issparse(d): 1108 | d = d.A 1109 | dmat = distance_matrix(d, d, p=distance) 1110 | else: 1111 | if not all(gene_subset): 1112 | warnings.warn('`genes` is not None, are you sure this is what you want when using `dpt` distance?') 1113 | 1114 | dmat = [] 1115 | ad_tmp = adata.copy() 1116 | ad_tmp = ad_tmp[:, gene_subset] 1117 | for i in range(ad_tmp.n_obs): 1118 | ad_tmp.uns['iroot'] = i 1119 | sc.tl.dpt(ad_tmp) 1120 | dmat.append(list(ad_tmp.obs['dpt_pseudotime'].replace([np.nan, np.inf], [0, 1]))) 1121 | 1122 | dmat = pd.DataFrame(dmat, columns=list(map(str, range(adata.n_obs)))) 1123 | df = pd.concat([pd.DataFrame(adata.obsm[f'X_{bs}'][:, comp - (bs != 'diffmap')], columns=[f'x{i}', f'y{i}']) 1124 | for i, (bs, comp) in enumerate(zip(basis, components))] + [dmat], axis=1) 1125 | df['hl_color'] = np.nan 1126 | df['index'] = range(len(df)) 1127 | df['hl_key'] = list(adata.obs[highlight_only]) if highlight_only is not None else 0 1128 | df[key] = list(map(str, adata.obs[key])) 1129 | 1130 | start_ix = '0' # our root cell 1131 | ds = ColumnDataSource(df) 1132 | mapper = linear_cmap(field_name='hl_color', palette=palette, 1133 | low=df[start_ix].min(), high=df[start_ix].max()) 1134 | static_fig_mapper = _create_mapper(adata, key) 1135 | 1136 | static_figs = [] 1137 | figs, renderers = [], [] 1138 | for i, bs in enumerate(basis): 1139 | # linked plots 1140 | fig = figure(tools='pan, reset, save, ' + ('zoom_in, zoom_out' if i == 0 else 'wheel_zoom'), 1141 | title=bs, plot_width=400, plot_height=400) 1142 | _set_plot_wh(fig, plot_width, plot_height) 1143 | 1144 | kwargs = {} 1145 | if show_legend and legend_loc is not None: 1146 | kwargs['legend_group'] = 'hl_key' if highlight_only is not None else key 1147 | 1148 | scatter = fig.scatter(f'x{i}', f'y{i}', source=ds, 
line_color=mapper, color=mapper, 1149 | hover_color='black', size=8, line_width=8, line_alpha=0, **kwargs) 1150 | if show_legend and legend_loc is not None: 1151 | fig.legend.location = legend_loc 1152 | 1153 | figs.append(fig) 1154 | renderers.append(scatter) 1155 | 1156 | # static plots 1157 | fig = figure(title=bs, plot_width=400, plot_height=400) 1158 | 1159 | fig.scatter(f'x{i}', f'y{i}', source=ds, size=8, 1160 | color={'field': key, 'transform': static_fig_mapper}, **kwargs) 1161 | 1162 | if legend_loc is not None: 1163 | fig.legend.location = legend_loc 1164 | 1165 | static_figs.append(fig) 1166 | 1167 | fig = figs[0] 1168 | 1169 | end = dmat[~np.isinf(dmat)].max().max() if distance != 'dpt' else 1.0 1170 | slider = Slider(start=0, end=end, value=end / 2, step=end / 1000, 1171 | title='Distance ' + ('(dpt)' if distance == 'dpt' else f'({distance}-norm)')) 1172 | col_ds = ColumnDataSource(dict(value=[start_ix])) 1173 | update_color_code = f''' 1174 | source.data['hl_color'] = source.data[first].map( 1175 | (x, i) => {{ return isNaN(x) || 1176 | {'x > slider.value || ' if cutoff else ''} 1177 | source.data['hl_key'][first] != source.data['hl_key'][i] ?
NaN : x; }} 1178 | ); 1179 | ''' 1180 | slider.callback = CustomJS(args={'slider': slider, 'mapper': mapper['transform'], 'source': ds, 'col': col_ds}, code=f''' 1181 | mapper.high = slider.value; 1182 | var first = col.data['value']; 1183 | {update_color_code} 1184 | source.change.emit(); 1185 | ''') 1186 | 1187 | h_tool = HoverTool(renderers=renderers, tooltips=[], show_arrow=False) 1188 | h_tool.callback = CustomJS(args=dict(source=ds, slider=slider, col=col_ds), code=f''' 1189 | var indices = cb_data.index['1d'].indices; 1190 | if (indices.length == 0) {{ 1191 | source.data['hl_color'] = source.data['hl_color']; 1192 | }} else {{ 1193 | var first = indices[0]; 1194 | source.data['hl_color'] = source.data[first]; 1195 | {update_color_code} 1196 | col.data['value'] = first; 1197 | col.change.emit(); 1198 | }} 1199 | source.change.emit(); 1200 | ''') 1201 | fig.add_tools(h_tool) 1202 | 1203 | color_bar = ColorBar(color_mapper=mapper['transform'], width=12, location=(0,0)) 1204 | fig.add_layout(color_bar, 'left') 1205 | 1206 | 1207 | plot = column(slider, row(*static_figs), row(*figs)) 1208 | 1209 | if save is not None: 1210 | save = save if str(save).endswith('.html') else str(save) + '.html' 1211 | bokeh_save(plot, save) 1212 | else: 1213 | show(plot) 1214 | 1215 | 1216 | def _get_mappers(adata, df, genes=[], use_raw=True, sort=True): 1217 | if sort: 1218 | genes = sorted(genes) 1219 | 1220 | mappers = {c:{'field': c, 'transform': _create_mapper(adata, c)} 1221 | for c in (sorted if sort else list)(filter(lambda c: adata.obs[c].dtype.name != 'category', adata.obs.columns)) + genes} 1222 | 1223 | # assume all non-categorical columns in .obs are numeric 1224 | for k in filter(lambda k: k not in genes, mappers.keys()): 1225 | df[k] = list(adata.obs[k].astype(float)) 1226 | 1227 | indices, = np.where(np.in1d(adata.var_names, genes)) 1228 | for ix in indices: 1229 | df[adata.var_names[ix]] = (adata.raw if use_raw else adata).X[:, ix] 1230 | 1231 | return df,
mappers 1232 | 1233 | 1234 | def _add_color_select(key, fig, renderers, source, mappers, colors=['color', 'fill_color', 'line_color'], 1235 | color_bar_pos='right', suffix=''): 1236 | color_bar = ColorBar(color_mapper=mappers[key]['transform'], width=10, location=(0, 0)) 1237 | fig.add_layout(color_bar, color_bar_pos) 1238 | 1239 | code = _inter_color_code(*colors) 1240 | callback= CustomJS(args=dict(renderers=renderers, source=source, color_bar=color_bar, cmaps=mappers), 1241 | code=code) 1242 | 1243 | return Select(title=f'Select variable to color{suffix}:', value=key, 1244 | options=list(mappers.keys()), callback=callback) 1245 | -------------------------------------------------------------------------------- /interactive_plotting/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | from interactive_plotting.experimental.plots import heatmap, scatter2 as scatter 2 | from interactive_plotting.experimental.scatter3d import scatter3d 3 | -------------------------------------------------------------------------------- /interactive_plotting/experimental/plots.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from interactive_plotting.utils._utils import * 4 | 5 | from pandas.api.types import is_categorical 6 | from collections import OrderedDict as odict 7 | from datashader.colors import * 8 | from bokeh.palettes import Viridis256 9 | from holoviews.streams import Selection1D 10 | from holoviews.operation import decimate 11 | from holoviews.operation.datashader import datashade, dynspread, rasterize 12 | from bokeh.models import HoverTool 13 | 14 | import numpy as np 15 | import pandas as pd 16 | import datashader as ds 17 | import holoviews as hv 18 | import warnings 19 | 20 | 21 | def pad(minn, maxx, padding=0.1): 22 | if minn > maxx: 23 | maxx, minn = minn, maxx 24 | delta = maxx - minn 25 | 26 | return minn - (delta * padding), maxx + 
(delta * padding) 27 | 28 | 29 | def minmax(component, perc=None, is_sorted=False): 30 | if perc is not None: 31 | assert len(perc) == 2, 'Percentile must be of length 2.' 32 | component = np.clip(component, *np.percentile(component, sorted(perc))) 33 | 34 | return (np.nanmin(component), np.nanmax(component)) if not is_sorted else (component[0], component[-1]) 35 | 36 | 37 | def scatter2(adata, x, y, color, order_key=None, indices=None, layer=None, subsample='datashade', use_raw=False, 38 | size=5, jitter=None, perc=None, cmap=None, 39 | hover_keys=None, hover_dims=(10, 10), kde=None, density=None, density_size=150, 40 | keep_frac=0.2, steps=40, seed=None, use_original_limits=False, 41 | legend_loc='top_right', show_legend=True, plot_height=600, plot_width=600, save=None): 42 | ''' 43 | Plot a scatterplot. 44 | 45 | Params 46 | ------- 47 | adata: anndata.AnnData 48 | adata object 49 | x: Union[Int, Str] 50 | values on the x-axis 51 | if of type `NoneType`, x-axis will be based on `order_key` 52 | if of type `Int`, it corresponds to index in `adata.var_names` 53 | if of type `Str`, it can be either: 54 | - key in `adata.var_names` 55 | - key in `adata.obs` 56 | - `'{basis_name}:{index}'` or `'{basis_name}'`, such as 57 | `'{basis}:{index}'` where `'{basis}'` is a basis from `adata.obsm` and 58 | `'{index}'` is the number of the component to choose 59 | y: Union[Int, Str] 60 | values on the y-axis 61 | if of type `Int`, it corresponds to index in `adata.var_names` 62 | if of type `Str`, it can be either: 63 | - key in `adata.var_names` 64 | - key in `adata.obs` 65 | - `'{basis_name}:{index}'` or `'{basis_name}'`, such as 66 | `'{basis}:{index}'` where `'{basis}'` is a basis from `adata.obsm` and 67 | `'{index}'` is the number of the component to choose 68 | color: Union[Str, NoneType], optional (default: `None`) 69 | key in `adata.obs` to color in 70 | order_key: Union[Str, NoneType], optional (default: `None`) 71 | key in `adata.obs` which defines cell-ordering, 72 
such as `'dpt_pseudotime'`, cells will 73 | be sorted in ascending order 74 | indices: Union[np.array[Int], np.array[Bool], NoneType], optional (default: `None`) 75 | subset of cells to plot, 76 | if `None`, plot all cells 77 | layer: Union[Str, NoneType], optional (default: `None`) 78 | key in `adata.layers`, used when `x` or `y` is a gene 79 | subsample: Str, optional (default: `'datashade'`) 80 | subsampling strategy for large data; 81 | possible values are `None, 'none', 'datashade', 'decimate'` 82 | using `subsample='datashade'` is preferred over the other options since it does not subset the data 83 | when using `subsample='datashade'`, the colorbar is not visible 84 | use_raw: Bool, optional (default: `False`) 85 | whether to use the `.raw` attribute 86 | size: Int, optional (default: `5`) 87 | size of the glyphs 88 | jitter: Union[Float, Tuple[Float, Float], NoneType], optional (default: `None`) 89 | variance of the normal distribution used 90 | for jittering 91 | perc: Union[List[Float, Float], NoneType], optional (default: `None`) 92 | percentiles for color clipping 93 | hover_keys: Union[List[Str], NoneType], optional (default: `None`) 94 | keys in `adata.obs` to display when hovering over cells, 95 | if `None`, display nothing, 96 | if `[]`, display cell index, 97 | note that if `subsample='datashade'`, only the cell index can be displayed 98 | kde: Union[Float, NoneType], optional (default: `None`) 99 | whether to show a kernel density estimate; 100 | if a `Float`, it corresponds to the bandwidth 101 | density: Union[Str, None], optional (default: `None`) 102 | whether to show marginal x,y-densities, 103 | can be one of `'group'`, `'all'` 104 | if `'group'` and `color` is a categorical variable, 105 | the density estimate is shown per category 106 | if `'all'`, the density is estimated using all points 107 | density_size: Int, optional (default: `150`) 108 | height and width of density plots 109 | hover_dims: Tuple[Int, Int], optional (default: `(10, 10)`) 110 | number of rows and columns of
hovering tiles, 111 | only used when `subsample='datashade'` 112 | keep_frac: Float, optional (default: `0.2`) 113 | fraction of cells to keep, used when `subsample='decimate'` 114 | steps: Union[Int, Tuple[Int, Int]], optional (default: `40`) 115 | number of steps along each embedding direction, used for uniform subsampling; 116 | a larger value corresponds to a higher density of points 117 | seed: Union[Float, NoneType], optional (default: `None`) 118 | random seed, used when `subsample='decimate'` 119 | use_original_limits: Bool, optional (default: `False`) 120 | internal use only 121 | legend_loc: Str, optional (default: `'top_right'`) 122 | position of the legend 123 | show_legend: Bool, optional (default: `True`) 124 | whether to show the legend 125 | plot_height: Int, optional (default: `600`) 126 | height of the plot 127 | plot_width: Int, optional (default: `600`) 128 | width of the plot 129 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`) 130 | path where to save the plot 131 | 132 | Returns 133 | ------- 134 | plot: hv.ScatterPlot 135 | a scatterplot 136 | ''' 137 | 138 | if hover_keys is not None: 139 | for h in hover_keys: 140 | assert h in adata.obs, f'Hover key `{h}` not found in `adata.obs`.' 141 | 142 | if perc is not None: 143 | assert len(perc) == 2, f'`perc` must be of length `2`, found: `{len(perc)}`.' 144 | assert all((p is not None for p in perc)), '`perc` cannot contain `None`.' 145 | perc = sorted(perc) 146 | 147 | assert len(hover_dims) == 2, f'Expected `hover_dims` to be of length `2`, found `{len(hover_dims)}`.' 148 | assert all((isinstance(d, int) for d in hover_dims)), 'All of `hover_dims` must be of type `int`.' 149 | assert all((d > 1 for d in hover_dims)), 'All of `hover_dims` must be `> 1`.' 150 | 151 | assert keep_frac >= 0 and keep_frac <= 1, f'`keep_frac` must be in interval `[0, 1]`, got `{keep_frac}`.' 152 | assert subsample in (None, 'none', 'datashade', 'decimate'), f'Invalid subsampling strategy `{subsample}`.
' \ 153 | 'Possible values are `None, \'none\', \'datashade\', \'decimate\'`.' 154 | 155 | adata_mraw = get_mraw(adata, use_raw and layer is None) 156 | if adata_mraw is adata and use_raw: 157 | warnings.warn('Failed fetching the `.raw` attribute of `adata`.') 158 | 159 | if indices is None: 160 | indices = np.arange(adata_mraw.n_obs) 161 | 162 | xlim, ylim = None, None 163 | if order_key is not None: 164 | assert order_key in adata.obs, f'`{order_key}` not found in `adata.obs`.' 165 | ixs = np.argsort(adata.obs[order_key][indices]) 166 | else: 167 | ixs = np.arange(len(indices)) 168 | 169 | if x is None: 170 | if order_key is not None: 171 | x, xlabel = adata.obs[order_key][indices][ixs], order_key 172 | else: 173 | x, xlabel = ixs, 'index' 174 | else: 175 | x, xlabel, xlim = get_xy_data(x, adata, adata_mraw, layer, indices, use_original_limits) 176 | 177 | y, ylabel, ylim = get_xy_data(y, adata, adata_mraw, layer, indices, use_original_limits, inc=1) 178 | 179 | # jitter 180 | if jitter is not None: 181 | msg = 'Using `jitter` != `None` and `use_original_limits=True` can negatively impact the limits.' 182 | if isinstance(jitter, (tuple, list)): 183 | assert len(jitter) == 2, f'`jitter` must be of length `2`, found `{len(jitter)}`.' 184 | if any((j is not None for j in jitter)) and use_original_limits: 185 | warnings.warn(msg) 186 | else: 187 | assert isinstance(jitter, float), f'Expected `jitter` to be of type `float`, found `{type(jitter).__name__}`.'
188 | if use_original_limits: warnings.warn(msg) 189 | 190 | x = x.astype(np.float64) 191 | y = y.astype(np.float64) 192 | 193 | adata_mraw = adata_mraw[indices, :] 194 | if color is not None: 195 | if color in adata.obs: 196 | condition = adata.obs[color][indices][ixs] 197 | else: 198 | if isinstance(color, int): 199 | color = adata_mraw.var_names[color] 200 | condition = adata_mraw.obs_vector(color)[ixs] 201 | else: 202 | condition = None 203 | 204 | hover = None 205 | if hover_keys is not None: 206 | hover = {'index': ixs} 207 | for key in hover_keys: 208 | hover[key] = adata.obs[key][indices][ixs] 209 | 210 | plot = _scatter(adata, x=x.copy(), y=y.copy(), 211 | condition=condition, by=color, 212 | xlabel=xlabel, ylabel=ylabel, 213 | title=color, hover=hover, jitter=jitter, 214 | perc=perc, xlim=xlim, ylim=ylim, 215 | hover_width=hover_dims[1], hover_height=hover_dims[0], kde=kde, density=density, density_size=density_size, 216 | subsample=subsample, steps=steps, keep_frac=keep_frac, seed=seed, legend_loc=legend_loc, 217 | size=size, cmap=cmap, show_legend=show_legend, plot_height=plot_height, plot_width=plot_width) 218 | 219 | if save is not None: 220 | hv.renderer('bokeh').save(plot, save) 221 | 222 | return plot 223 | 224 | 225 | def _scatter(adata, x, y, condition, by=None, subsample='datashade', steps=40, keep_frac=0.2, 226 | seed=None, legend_loc='top_right', size=4, xlabel=None, ylabel=None, title=None, 227 | use_raw=True, hover=None, hover_width=10, hover_height=10, kde=None, 228 | density=None, density_size=150, jitter=None, perc=None, xlim=None, ylim=None, 229 | cmap=None, show_legend=True, plot_height=400, plot_width=400): 230 | 231 | _sentinel = object() 232 | 233 | def create_density_plots(df, density, kdims, cmap): 234 | cm = {} 235 | if density == 'all': 236 | dfs = {_sentinel: df} 237 | elif density == 'group': 238 | if 'z' not in df.columns: 239 | warnings.warn('`density=\'group\'` was specified, but no group found. 
Did you specify `color=...`?') 240 | dfs = {_sentinel: df} 241 | elif not is_categorical(df['z']): 242 | warnings.warn(f'`density=\'group\'` was specified, but column `{by}` is not categorical.') 243 | dfs = {_sentinel: df} 244 | else: 245 | dfs = {k: v for k, v in df.groupby('z')} 246 | cm = cmap 247 | else: 248 | raise ValueError(f'Invalid `density` value: `{density}`. Possible values are `\'all\'`, `\'group\'`.') 249 | # assumes x, y order in kdims 250 | return [hv.Overlay([hv.Distribution(df, kdims=dim).opts(color=cm.get(k, 'black'), 251 | framewise=True) 252 | for k, df in dfs.items()]) 253 | for dim in kdims] 254 | 255 | assert keep_frac >= 0 and keep_frac <= 1, f'`keep_frac` must be in interval `[0, 1]`, got `{keep_frac}`.' 256 | 257 | adata_mraw = get_mraw(adata, use_raw) 258 | 259 | if subsample == 'uniform': 260 | cb_kwargs = {'steps': steps} 261 | elif subsample == 'density': 262 | cb_kwargs = {'size': int(keep_frac * adata.n_obs), 'seed': seed} 263 | else: 264 | cb_kwargs = {} 265 | 266 | categorical = False 267 | if condition is None: 268 | cmap = ['black'] * len(x) if subsample == 'datashade' else 'black' 269 | elif is_categorical(condition): 270 | categorical = True 271 | cmap = Sets1to3 if cmap is None else cmap 272 | cmap = odict(zip(condition.cat.categories, adata.uns.get(f'{by}_colors', cmap))) 273 | else: 274 | cmap = Viridis256 if cmap is None else cmap 275 | 276 | jitter_x, jitter_y = None, None 277 | if isinstance(jitter, (tuple, list)): 278 | assert len(jitter) == 2, f'`jitter` must be of length `2`, found `{len(jitter)}`.' 
279 | jitter_x, jitter_y = jitter 280 | elif jitter is not None: 281 | jitter_x, jitter_y = jitter, jitter 282 | 283 | if jitter_x is not None: 284 | x += np.random.normal(0, jitter_x, size=x.shape) 285 | if jitter_y is not None: 286 | y += np.random.normal(0, jitter_y, size=y.shape) 287 | 288 | data = {'x': x, 'y': y, 'z': condition} 289 | vdims = ['z'] 290 | 291 | hovertool = None 292 | if hover is not None: 293 | for k, dt in hover.items(): 294 | vdims.append(k) 295 | data[k] = dt 296 | hovertool = HoverTool(tooltips=[(key.capitalize(), f'@{key}') 297 | for key in (['index'] if subsample == 'datashade' else hover.keys())]) 298 | 299 | data = pd.DataFrame(data) 300 | if categorical: 301 | data['z'] = data['z'].astype('category') 302 | 303 | if not vdims: 304 | vdims = None 305 | 306 | if xlim is None: 307 | xlim = pad(*minmax(x)) 308 | if ylim is None: 309 | ylim = pad(*minmax(y)) 310 | 311 | kdims=[('x', 'x' if xlabel is None else xlabel), 312 | ('y', 'y' if ylabel is None else ylabel)] 313 | 314 | scatter = hv.Scatter(data, kdims=kdims, vdims=vdims).sort('z') 315 | scatter = scatter.opts(size=size, xlim=xlim, ylim=ylim) 316 | 317 | kde_plot= None if kde is None else \ 318 | hv.Bivariate(scatter).opts(bandwidth=kde, show_legend=False, line_width=2) 319 | xdist, ydist = (None, None) if density is None else create_density_plots(data, density, kdims, cmap) 320 | 321 | if categorical: 322 | scatter = scatter.opts(cmap=cmap, color='z', show_legend=show_legend, legend_position=legend_loc) 323 | elif 'z' in data: 324 | scatter = scatter.opts(cmap=cmap, color='z', 325 | clim=tuple(map(float, minmax(data['z'], perc))), 326 | colorbar=True, 327 | colorbar_opts={'width': 20}) 328 | else: 329 | scatter = scatter.opts(color='black') 330 | 331 | legend = None 332 | if subsample == 'datashade': 333 | subsampled = dynspread(datashade(scatter, aggregator=(ds.count_cat('z') if categorical else ds.mean('z')) if vdims is not None else None, 334 | color_key=cmap, cmap=cmap, 335 | 
streams=[hv.streams.RangeXY(transient=True), hv.streams.PlotSize], 336 | min_alpha=255).opts(axiswise=True, framewise=True), threshold=0.8, max_px=5) 337 | if show_legend and categorical: 338 | legend = hv.NdOverlay({k: hv.Points([0, 0], label=str(k)).opts(size=0, color=v) 339 | for k, v in cmap.items()}) 340 | if hover is not None: 341 | t = hv.util.Dynamic(rasterize(scatter, width=hover_width, height=hover_height, streams=[hv.streams.RangeXY], 342 | aggregator=ds.reductions.min('index')), operation=hv.QuadMesh)\ 343 | .opts(tools=[hovertool], axiswise=True, framewise=True, 344 | alpha=0, hover_alpha=0.25, 345 | height=plot_height, width=plot_width) 346 | scatter = t * subsampled 347 | else: 348 | scatter = subsampled 349 | 350 | elif subsample == 'decimate': 351 | scatter = decimate(scatter, max_samples=int(adata.n_obs * keep_frac), 352 | streams=[hv.streams.RangeXY(transient=True)], random_seed=seed) 353 | 354 | if legend is not None: 355 | scatter = (scatter * legend).opts(legend_position=legend_loc) 356 | 357 | if kde_plot is not None: 358 | scatter *= kde_plot 359 | 360 | scatter = scatter.opts(height=plot_height, width=plot_width) 361 | scatter = scatter.opts(hv.opts.Scatter(tools=[hovertool])) if hovertool is not None else scatter 362 | 363 | if xdist is not None and ydist is not None: 364 | scatter = (scatter << ydist.opts(width=density_size)) << xdist.opts(height=density_size) 365 | 366 | return scatter.opts(title=title if title is not None else '') 367 | 368 | 369 | def _heatmap(adata, genes, group, sort_genes=True, use_raw=False, 370 | agg_fns=['mean'], hover=True, 371 | xrotation=90, yrotation=0, colorbar=True, cmap=None, 372 | plot_height=300, plot_width=600): 373 | ''' 374 | Internal heatmap function. 
375 | 376 | Params 377 | ------- 378 | adata: anndata.AnnData 379 | adata object 380 | genes: List[Str] 381 | genes in `adata.var_names` 382 | group: Str 383 | key in `adata.obs`, must be categorical 384 | sort_genes: Bool, optional (default: `True`) 385 | whether to sort the genes 386 | use_raw: Bool, optional (default: `False`) 387 | whether to use `.raw` attribute 388 | agg_fns: List[Str], optional (default: `['mean']`) 389 | list of pandas' aggregation functions 390 | hover: Bool, optional (default: `True`) 391 | whether to show hover information 392 | xrotation: Int, optional (default: `90`) 393 | rotation of labels on x-axis 394 | yrotation: Int, optional (default: `0`) 395 | rotation of labels on y-axis 396 | colorbar: Bool, optional (default: `True`) 397 | whether to show colorbar 398 | cmap: Union[List[Str], NoneType], optional (default: `None`) 399 | colormap of the heatmap, 400 | if `None`, use `Viridis256` 401 | plot_height: Int, optional (default: `300`) 402 | height of the heatmap 403 | plot_width: Int, optional (default: `600`) 404 | width of the heatmap 405 | 406 | Returns 407 | ------- 408 | plot: hv.HeatMap 409 | a heatmap 410 | ''' 411 | 412 | assert group in adata.obs, f'Group `{group}` not found in `adata.obs`.' 413 | assert is_categorical(adata.obs[group]), f'`adata.obs[{group!r}]` must be categorical.' 414 | assert len(agg_fns) > 0, 'At least one aggregation function must be specified.' 415 | 416 | for g in genes: 417 | assert g in adata.var_names, f'Unable to find gene `{g}` in `adata.var_names`.' 
418 | 419 | genes = sorted(genes) if sort_genes else genes 420 | groups = sorted(list(adata.obs[group].cat.categories)) 421 | 422 | adata_mraw = get_mraw(adata, use_raw) 423 | common_subset = list(set(adata.obs_names) & set(adata_mraw.obs_names)) 424 | adata, adata_mraw = adata[common_subset, :], adata_mraw[common_subset, :] 425 | 426 | ixs = np.isin(adata.obs[group], groups) 427 | adata, adata_mraw = adata[ixs, :], adata_mraw[ixs, :] 428 | 429 | df = pd.DataFrame(adata_mraw[:, genes].X, columns=genes) 430 | df['group'] = list(map(str, adata.obs[group])) 431 | groupby = df.groupby('group') 432 | 433 | vals = {agg_fn: groupby.agg(agg_fn) for agg_fn in agg_fns} 434 | z_value = vals.pop(agg_fns[0]) 435 | 436 | x = hv.Dimension('x', label='Gene') 437 | y = hv.Dimension('y', label='Group') 438 | z = hv.Dimension('z', label='Expression') 439 | vdims = [(k, k.capitalize()) for k in vals.keys()] 440 | 441 | heatmap = hv.HeatMap({'x': np.array(genes), 'y': np.array(groups), 'z': z_value, **vals}, 442 | kdims=[('x', 'Gene'), ('y', 'Group')], 443 | vdims=[('z', 'Expression')] + vdims).opts(tools=['box_select'] + (['hover'] if hover else []), 444 | xrotation=xrotation, yrotation=yrotation) 445 | 446 | return heatmap.opts(frame_width=plot_width, frame_height=plot_height, colorbar=colorbar, cmap=cmap) 447 | 448 | 449 | @wrap_as_col 450 | def heatmap(adata, genes, groups=None, compare='genes', agg_fns=['mean', 'var'], use_raw=False, 451 | order_keys=[], hover=True, show_highlight=False, show_scatter=False, 452 | subsample=None, keep_frac=0.2, seed=None, 453 | xrotation=90, yrotation=0, colorbar=True, cont_cmap=None, 454 | height=200, width=600, save=None, **scatter_kwargs): 455 | ''' 456 | Plot a heatmap with groups selected from a drop-down menu. 457 | If `show_highlight=True` and `show_scatter=True`, additional 458 | interaction occurs when clicking on the highlighted heatmap. 
459 | 460 | Params 461 | ------- 462 | adata: anndata.AnnData 463 | adata object 464 | genes: List[Str] 465 | genes in `adata.var_names` 466 | groups: List[Str], optional (default: `None`) 467 | categorical observations in `adata.obs`, 468 | if `None`, get all categorical observations from `adata.obs` 469 | compare: Str, optional (default: `'genes'`) 470 | only used when `show_scatter=True`, 471 | creates a drop-down: 472 | if `'genes'`: 473 | drop-down menu will contain values from `genes`, and clicking 474 | a gene in the highlighted heatmap will plot a scatterplot of the 2 475 | genes with groups colored in 476 | if `'basis'`: 477 | drop-down menu will contain available bases, and clicking 478 | a gene in the highlighted heatmap will plot the gene in the selected 479 | embedding with its expression colored in 480 | if `'order'`: 481 | drop-down menu will contain values from `order_keys`, 482 | and clicking on a gene in the highlighted heatmap will plot its expression 483 | in the selected order 484 | agg_fns: List[Str], optional (default: `['mean', 'var']`) 485 | names of pandas' aggregation functions, such as `'min'`, ... 
486 | the first function specified is mapped to colors 487 | use_raw: Bool, optional (default: `False`) 488 | whether to use `.raw` for gene expression 489 | order_keys: List[Str], optional (default: `[]`) 490 | keys in `adata.obs`, used when `compare='order'` 491 | hover: Bool, optional (default: `True`) 492 | whether to display hover information over the heatmap 493 | show_highlight: Bool, optional (default: `False`) 494 | whether to show a second heatmap containing only the cells selected with box select 495 | show_scatter: Bool, optional (default: `False`) 496 | whether to show a scatterplot, 497 | if `True`, overrides `show_highlight=False` 498 | subsample: Str, optional (default: `None`) 499 | subsampling strategy for large data 500 | possible values are `None, 'none', 'decimate'` 501 | keep_frac: Float, optional (default: `0.2`) 502 | fraction of cells to keep, used when `subsample='decimate'` 503 | seed: Union[Float, NoneType], optional (default: `None`) 504 | random seed, used when `subsample='decimate'` 505 | xrotation: Int, optional (default: `90`) 506 | rotation of labels on x-axis 507 | yrotation: Int, optional (default: `0`) 508 | rotation of labels on y-axis 509 | colorbar: Bool, optional (default: `True`) 510 | whether to show colorbar 511 | cont_cmap: Union[List[Str], NoneType], optional (default: `None`) 512 | colormap of the heatmap, 513 | if `None`, use `Viridis256` 514 | height: Int, optional (default: `200`) 515 | height of the heatmap 516 | width: Int, optional (default: `600`) 517 | width of the heatmap 518 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`) 519 | path where to save the plot 520 | **scatter_kwargs: 521 | additional arguments for `ipl.experimental.scatter`, 522 | only used when `show_scatter=True` 523 | 524 | Returns 525 | ------- 526 | holoviews plot 527 | ''' 528 | 529 | def _highlight(group, index): 530 | original = hm[group] 531 | if not index: 532 | return original 533 | 534 | return original.iloc[sorted(index)] 535 | 536 | def 
_scatter(group, which, gwise, x, y): 537 | indices = adata.obs[group] == y if gwise else np.isin(adata.obs[group], highlight[group].data['y']) 538 | # convert the boolean mask to integer positions, which is what `indices` expects 539 | indices = np.where(indices)[0] 540 | 541 | if is_ordered: 542 | scatter_kwargs['order_key'] = which 543 | x, y = None, x 544 | elif f'X_{which}' in adata.obsm: 545 | group = x 546 | x, y = which, which 547 | else: 548 | x, y = x, which 549 | 550 | return scatter2(adata, x=x, y=y, color=group, indices=indices, 551 | **scatter_kwargs).opts(axiswise=True, framewise=True) 552 | 553 | assert keep_frac >= 0 and keep_frac <= 1, f'`keep_frac` must be in interval `[0, 1]`, got `{keep_frac}`.' 554 | assert subsample in (None, 'none', 'decimate'), f'Invalid subsampling strategy `{subsample}`. ' \ 555 | 'Possible values are `None, \'none\', \'decimate\'`.' 556 | 557 | assert compare in ('genes', 'basis', 'order'), f'`compare` must be one of `\'genes\', \'basis\', \'order\'`, found `{compare}`.' 558 | 559 | if cont_cmap is None: 560 | cont_cmap = Viridis256 561 | 562 | is_ordered = False 563 | scatter_kwargs['use_original_limits'] = True 564 | scatter_kwargs['subsample'] = None 565 | if 'plot_width' not in scatter_kwargs: 566 | scatter_kwargs['plot_width'] = 300 567 | if 'plot_height' not in scatter_kwargs: 568 | scatter_kwargs['plot_height'] = 300 569 | 570 | if groups is not None: 571 | assert len(groups) > 0, 'Number of groups must be `> 0`.' 
572 | else: 573 | groups = [k for k in adata.obs.keys() if is_categorical(adata.obs[k])] 574 | 575 | kdims = [hv.Dimension('Group', values=groups, default=groups[0])] 576 | 577 | hm = hv.DynamicMap(lambda g: _heatmap(adata, genes, agg_fns=agg_fns, group=g, 578 | hover=hover, use_raw=use_raw, 579 | cmap=cont_cmap, 580 | xrotation=xrotation, yrotation=yrotation, 581 | colorbar=colorbar), kdims=kdims).opts(frame_height=height, 582 | frame_width=width) 583 | if not show_highlight and not show_scatter: 584 | return hm 585 | 586 | highlight = hv.DynamicMap(_highlight, kdims=kdims, streams=[Selection1D(source=hm)]) 587 | if not show_scatter: 588 | return (hm + highlight).cols(1) 589 | 590 | if compare == 'basis': 591 | basis = [b[2:] if b.startswith('X_') else b for b in adata.obsm.keys()] # strip only the `X_` prefix, `str.lstrip('X_')` would strip any leading `X` or `_` characters 592 | kdims += [hv.Dimension('Components', values=basis, default=basis[0])] 593 | elif compare == 'genes': 594 | kdims += [hv.Dimension('Genes', values=genes, default=genes[0])] 595 | else: 596 | is_ordered = True 597 | k = scatter_kwargs.pop('order_key', None) 598 | assert k is not None or len(order_keys) > 0, 'No order keys specified.' 599 | 600 | if k is not None and k not in order_keys: 601 | order_keys = list(order_keys) + [k] # avoid mutating the default argument 602 | 603 | for k in order_keys: 604 | assert k in adata.obs, f'Order key `{k}` not found in `adata.obs`.' 
605 | 606 | kdims += [hv.Dimension('Order', values=order_keys)] 607 | 608 | kdims += [hv.Dimension('Groupwise', type=bool, values=[True, False], default=True)] 609 | 610 | scatter_stream = hv.streams.Tap(source=highlight, x=genes[0], y=adata.obs[groups[0]].values[0]) 611 | scatter = hv.DynamicMap(_scatter, kdims=kdims, streams=[scatter_stream]) 612 | 613 | if subsample == 'decimate': 614 | scatter = decimate(scatter, max_samples=int(adata.n_obs * keep_frac), 615 | streams=[hv.streams.RangeXY(transient=True)], random_seed=seed) 616 | 617 | plot = (hm + highlight + scatter).cols(1) 618 | 619 | if save is not None: 620 | hv.renderer('bokeh').save(plot, save) 621 | 622 | return plot 623 | 624 | -------------------------------------------------------------------------------- /interactive_plotting/experimental/scatter3d.py: -------------------------------------------------------------------------------- 1 | from interactive_plotting.utils import * 2 | 3 | from bokeh.core.properties import Any, Dict, Instance, String 4 | from bokeh.models import ( 5 | ColumnDataSource, 6 | LayoutDOM, 7 | Legend, 8 | LegendItem, 9 | ColorBar, 10 | LinearColorMapper, 11 | FixedTicker 12 | ) 13 | from bokeh.io import save 14 | from bokeh.resources import CDN 15 | from bokeh.layouts import row 16 | from bokeh.plotting import figure 17 | from bokeh.colors import RGB 18 | from pandas.api.types import is_categorical_dtype 19 | from anndata import AnnData 20 | from typing import Union, Optional, Sequence, Tuple 21 | from time import sleep 22 | from collections import defaultdict 23 | 24 | import matplotlib 25 | import matplotlib.cm as cm 26 | import numpy as np 27 | import webbrowser 28 | import tempfile 29 | 30 | 31 | _DEFAULT = { 32 | 'width': '600px', 33 | 'height': '600px', 34 | 'style': 'dot-color', 35 | 'showPerspective': False, 36 | 'showGrid': True, 37 | 'keepAspectRatio': True, 38 | 'verticalRatio': 1.0, 39 | 'cameraPosition': { 40 | 'horizontal': 1, 41 | 'vertical': 0.25, 42 | 
'distance': 2, 43 | } 44 | } 45 | 46 | 47 | class Surface3d(LayoutDOM): 48 | __implementation__ = 'surface3d.ts' 49 | __javascript__ = 'https://unpkg.com/vis-graph3d@latest/dist/vis-graph3d.min.js' 50 | 51 | data_source = Instance(ColumnDataSource) 52 | 53 | x = String 54 | y = String 55 | z = String 56 | color = String 57 | 58 | options = Dict(String, Any, default=_DEFAULT) 59 | 60 | 61 | def _to_hex_colors(values, cmap, perc=None): 62 | minn, maxx = minmax(values, perc) 63 | norm = matplotlib.colors.Normalize(vmin=minn, vmax=maxx, clip=True) 64 | 65 | mapper = cm.ScalarMappable(norm=norm, cmap=cmap) 66 | 67 | return [matplotlib.colors.to_hex(mapper.to_rgba(v)) for v in values], minn, maxx 68 | 69 | 70 | def _mpl_to_hex_palette(cmap): 71 | if isinstance(cmap, matplotlib.colors.ListedColormap): 72 | rgb_cmap = (255 * cmap(range(256))[:, :3]).astype('int') # drop the alpha channel, `RGB` expects integer r, g, b values 73 | return [RGB(*tuple(rgb)).to_hex() for rgb in rgb_cmap] 74 | 75 | assert all(map(matplotlib.colors.is_color_like, cmap)), 'Not all colors are color-like.' 76 | 77 | return [matplotlib.colors.to_hex(c) for c in cmap] 78 | 79 | 80 | 81 | def scatter3d(adata: AnnData, 82 | key: str, 83 | basis: str = 'umap', 84 | components: Sequence[int] = (0, 1, 2), 85 | steps: Union[Tuple[int, int], int] = 100, 86 | perc: Optional[Tuple[int, int]] = None, 87 | n_ticks: int = 10, 88 | vertical_ratio: float = 1, 89 | show_axes: bool = False, 90 | keep_aspect_ratio: bool = True, 91 | perspective: bool = True, 92 | tooltips: Optional[Sequence[str]] = [], 93 | cmap: Optional[matplotlib.colors.ListedColormap] = None, 94 | dot_size_ratio: float = 0.01, 95 | show_legend: bool = True, 96 | show_cbar: bool = True, 97 | plot_height: Optional[int] = 1400, 98 | plot_width: Optional[int] = 1400): 99 | """ 100 | Parameters 101 | ---------- 102 | adata : :class:`anndata.AnnData` 103 | Annotated data object. 104 | key 105 | Key in `adata.obs` or `adata.var_names` to color in. 106 | basis 107 | Basis to use. 
108 | components 109 | Components of the basis to plot. 110 | steps 111 | Step size when subsampling the data. If `None`, do not subsample. 112 | Larger step size corresponds to higher density of points. 113 | perc 114 | Percentiles by which to clip colors. 115 | n_ticks 116 | Number of ticks for colorbar if `key` is not categorical. 117 | vertical_ratio 118 | Ratio by which to squish the z-axis. 119 | show_axes 120 | Whether to show axes and grid. 121 | keep_aspect_ratio 122 | Whether to keep aspect ratio. 123 | perspective 124 | Whether to draw the plot in perspective. 125 | tooltips 126 | Keys in `adata.obs` to visualize when hovering over cells. If `None`, use all keys in `adata.obs`. 127 | cmap 128 | Colormap to use. 129 | dot_size_ratio 130 | Size of the dots relative to the width of the plot. 131 | show_legend 132 | Whether to show legend when annotation is categorical. 133 | show_cbar 134 | Whether to show colorbar when annotation is continuous. 135 | plot_height 136 | Height of the plot in pixels. If `None`, try getting the screen height. 137 | plot_width 138 | Width of the plot in pixels. If `None`, try getting the screen width. 139 | 140 | Returns 141 | ------- 142 | None 143 | Nothing, the plot is opened in a new browser tab. 144 | """ 145 | 146 | def _wrap_as_div(row, sep=':'): 147 | res = [] 148 | for kind, val in zip(tooltips, row): 149 | if isinstance(val, float): 150 | res.append(f'<div>{kind}{sep} {val:.04f}</div>') 151 | else: 152 | res.append(f'<div>{kind}{sep} {val}</div>')
153 | 154 | return ''.join(res) 155 | 156 | basis_key = f'X_{basis}' 157 | assert basis_key in adata.obsm, f'Basis `{basis_key}` not found in `adata.obsm`.' 158 | if perc is not None: 159 | assert len(perc) == 2, f'Percentile must be of length `2`, found `{len(perc)}`.' 160 | assert len(components) == 3, f'Number of components must be `3`, found `{len(components)}`.' 161 | assert all(c >= 0 for c in components), f'All components must be non-negative, found `{min(components)}`.' 162 | assert max(components) < adata.obsm[basis_key].shape[-1], \ 163 | f'Component `{max(components)}` is >= the number of components `{adata.obsm[basis_key].shape[-1]}`.' 164 | assert key in adata.obs or key in adata.var_names, f'Key `{key}` not found in `adata.obs` or `adata.var_names`.' 165 | 166 | colors = adata.uns.get(f'{key}_colors', None) 167 | if steps is not None: 168 | # this somehow destroys the colors 169 | adata, _ = sample_unif(adata, steps, bs=basis, components=components) 170 | 171 | data = dict(x=adata.obsm[basis_key][:, components[0]], 172 | y=adata.obsm[basis_key][:, components[1]], 173 | z=adata.obsm[basis_key][:, components[2]]) 174 | 175 | fig = figure(tools=[], outline_line_width=0, toolbar_location='left', disabled=True) 176 | to_add = None 177 | 178 | if key in adata.obs and is_categorical_dtype(adata.obs[key]): 179 | if cmap is None: 180 | cmap = colors if colors is not None else cm.tab20b 181 | 182 | hex_palette = _mpl_to_hex_palette(cmap) 183 | 184 | mapper = defaultdict(lambda: '#AAAAAA', zip(adata.obs[key].cat.categories, hex_palette)) 185 | colors = [str(mapper[c]) for c in adata.obs[key]] 186 | 187 | n_cls = len(adata.obs[key].cat.categories) 188 | _ = fig.circle([0] * n_cls, [0] * n_cls, 189 | color=list(mapper.values()), 190 | visible=False, radius=0) 191 | if show_legend: 192 | to_add = Legend(items=[ 193 | LegendItem(label=str(c), index=i, renderers=[_]) 194 | for i, c in enumerate(mapper.keys()) 195 | ]) 196 | else: 197 | vals = adata.obs_vector(key) if key in 
adata.var_names else adata.obs[key] 198 | 199 | cmap = cm.viridis if cmap is None else cmap 200 | colors, minn, maxx = _to_hex_colors(vals, cmap, perc=perc) 201 | hex_palette = _mpl_to_hex_palette(cmap) 202 | 203 | _ = fig.circle(0, 0, visible=False, radius=0) 204 | 205 | color_mapper = LinearColorMapper(palette=hex_palette, low=minn, high=maxx) 206 | if show_cbar: 207 | to_add = ColorBar(color_mapper=color_mapper, ticker=FixedTicker(ticks=np.linspace(minn, maxx, n_ticks)), 208 | label_standoff=12, border_line_color=None, location=(0, 0)) 209 | 210 | data['color'] = colors 211 | if tooltips is None: 212 | tooltips = adata.obs_keys() 213 | if len(tooltips): 214 | data['tooltip'] = adata.obs[tooltips].apply(_wrap_as_div, axis=1) 215 | 216 | source = ColumnDataSource(data=data) 217 | if plot_width is None or plot_height is None: 218 | try: 219 | import screeninfo 220 | for monitor in screeninfo.get_monitors(): 221 | break 222 | plot_width = max(monitor.width - 300, 0) if plot_width is None else plot_width 223 | plot_height = max(monitor.height, 300) if plot_height is None else plot_height 224 | except ImportError: 225 | print('Unable to get the screen size, please install package `screeninfo` as `pip install screeninfo`.') 226 | plot_width = 1200 if plot_width is None else plot_width 227 | plot_height = 1200 if plot_height is None else plot_height 228 | except Exception: 229 | plot_width = 1200 if plot_width is None else plot_width 230 | plot_height = 1200 if plot_height is None else plot_height 231 | 232 | surface = Surface3d(x="x", y="y", z="z", color="color", 233 | data_source=source, options={**_DEFAULT, 234 | **dict(dotSizeRatio=dot_size_ratio, 235 | showXAxis=show_axes, 236 | showYAxis=show_axes, 237 | showZAxis=show_axes, 238 | xLabel=f'{basis}_{components[0]}', 239 | yLabel=f'{basis}_{components[1]}', 240 | zLabel=f'{basis}_{components[2]}', 241 | showPerspective=perspective, 242 | height=f'{plot_height}px', 243 | width=f'{plot_width}px', 244 | 
verticalRatio=vertical_ratio, 245 | keepAspectRatio=keep_aspect_ratio, 246 | showLegend=False, 247 | tooltip='tooltip' in data, 248 | xCenter='50%', 249 | yCenter='50%', 250 | showGrid=show_axes)}) 251 | if to_add is not None: 252 | fig.add_layout(to_add, 'left') 253 | 254 | # dirty little trick, makes plot disappear 255 | # ideally, one would modify the DOM in the .ts file but I'm just lazy 256 | fig.xgrid.visible = False 257 | fig.ygrid.visible = False 258 | fig.xaxis.visible = False 259 | fig.yaxis.visible = False 260 | 261 | with tempfile.NamedTemporaryFile(suffix='.html') as fout: 262 | path = save(row(surface, fig), fout.name, resources=CDN, title=f'Scatter3D - {key}') 263 | fout.flush() 264 | webbrowser.open_new_tab(path) 265 | sleep(2) # better safe than sorry 266 | -------------------------------------------------------------------------------- /interactive_plotting/experimental/surface3d.ts: -------------------------------------------------------------------------------- 1 | import {HTMLBox, HTMLBoxView} from "models/layouts/html_box" 2 | import {ColumnDataSource} from "models/sources/column_data_source" 3 | import * as p from "core/properties" 4 | 5 | declare namespace vis { 6 | class Graph3d { 7 | constructor(el: HTMLElement, data: object, OPTIONS: object) 8 | setData(data: vis.DataSet): void 9 | } 10 | 11 | class DataSet { 12 | add(data: unknown): void 13 | } 14 | } 15 | 16 | function _tooltip(obj: {x: number, y: number, z: number, data: {tooltip: string}}) { 17 | return obj.data["tooltip"]; 18 | } 19 | 20 | const OPTIONS = { 21 | width: '1200px', 22 | height: '1200px', 23 | style: 'dot-color', 24 | showPerspective: true, 25 | tooltip: _tooltip, 26 | showGrid: false, 27 | showXAxis: false, 28 | showYAxis: false, 29 | showZAxis: false, 30 | keepAspectRatio: true, 31 | verticalRatio: 1.0, 32 | cameraPosition: { 33 | horizontal: 1, 34 | vertical: 0.25, 35 | distance: 2.0, 36 | }, 37 | } 38 | 39 | export class Surface3dView extends HTMLBoxView { 40 | 
model: Surface3d 41 | 42 | private _graph: vis.Graph3d 43 | 44 | render(): void { 45 | super.render() 46 | if (this.model.options["tooltip"]) { // we want to show tooltips 47 | this.model.options["tooltip"] = _tooltip 48 | } 49 | this._graph = new vis.Graph3d(this.el, this.get_data(), this.model.options) 50 | } 51 | 52 | connect_signals(): void { 53 | super.connect_signals() 54 | this.connect(this.model.data_source.change, () => this._graph.setData(this.get_data())) 55 | } 56 | 57 | get_data(): vis.DataSet { 58 | const data = new vis.DataSet() 59 | const source = this.model.data_source 60 | 61 | if ("tooltip" in source.data) { 62 | for (let i = 0; i < source.get_length()!; i++) { 63 | data.add({ 64 | x: source.data[this.model.x][i], 65 | y: source.data[this.model.y][i], 66 | z: source.data[this.model.z][i], 67 | style: source.data[this.model.color][i], 68 | tooltip: source.data["tooltip"][i] 69 | }) 70 | } 71 | } else { 72 | for (let i = 0; i < source.get_length()!; i++) { 73 | data.add({ 74 | x: source.data[this.model.x][i], 75 | y: source.data[this.model.y][i], 76 | z: source.data[this.model.z][i], 77 | style: source.data[this.model.color][i], 78 | }) 79 | } 80 | } 81 | 82 | return data 83 | } 84 | } 85 | 86 | export namespace Surface3d { 87 | export type Attrs = p.AttrsOf<Props> 88 | 89 | export type Props = HTMLBox.Props & { 90 | x: p.Property<string> 91 | y: p.Property<string> 92 | z: p.Property<string> 93 | color: p.Property<string> 94 | data_source: p.Property<ColumnDataSource> 95 | options: p.Property<{[key: string]: unknown}> 96 | } 97 | } 98 | 99 | export interface Surface3d extends Surface3d.Attrs {} 100 | 101 | export class Surface3d extends HTMLBox { 102 | properties: Surface3d.Props 103 | 104 | constructor(attrs?: Partial<Surface3d.Attrs>) { 105 | super(attrs) 106 | } 107 | 108 | static __name__ = "Surface3d" 109 | 110 | static init_Surface3d(): void { 111 | this.prototype.default_view = Surface3dView 112 | 113 | this.define({ 114 | x: [ p.String ], 115 | y: [ p.String ], 116 | z: [ p.String ], 117 | color: [ p.String ], 
118 | data_source: [ p.Instance ], 119 | options: [ p.Any, OPTIONS ] 120 | }) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /interactive_plotting/holoviews_plots.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from interactive_plotting.utils._utils import * 4 | 5 | from collections import OrderedDict as odict 6 | 7 | from pandas.api.types import is_categorical_dtype, is_string_dtype, infer_dtype 8 | from scipy.sparse import issparse 9 | from functools import partial 10 | from bokeh.palettes import Viridis256 11 | from datashader.colors import Sets1to3 12 | from pandas.core.indexes.base import Index 13 | from holoviews.operation.datashader import datashade, bundle_graph, shade, dynspread, rasterize, spread 14 | from holoviews.operation import decimate 15 | from bokeh.models import HoverTool 16 | 17 | import scanpy as sc 18 | import numpy as np 19 | import pandas as pd 20 | import networkx as nx 21 | import holoviews as hv 22 | import datashader as ds 23 | import warnings 24 | 25 | 26 | try: 27 | assert callable(sc.tl.dpt) 28 | dpt_fn = sc.tl.dpt 29 | except (AssertionError, AttributeError): 30 | from scanpy.api.tl import dpt as dpt_fn 31 | 32 | #TODO: DRY 33 | 34 | @wrap_as_panel 35 | def scatter(adata, genes=None, basis=None, components=(1, 2), obs_keys=None, 36 | obsm_keys=None, use_raw=False, subsample='datashade', steps=40, keep_frac=None, lazy_loading=True, 37 | default_obsm_ixs=[0], sort=True, skip=True, seed=None, cols=None, size=4, 38 | perc=None, show_perc=True, cmap=None, plot_height=400, plot_width=400, save=None): 39 | ''' 40 | Scatter plot for continuous observations. 
41 | 42 | Params 43 | -------- 44 | adata: anndata.AnnData 45 | anndata object 46 | genes: List[Str], optional (default: `None`) 47 | list of genes to add for visualization 48 | if `None`, use `adata.var_names` 49 | basis: Union[Str, List[Str]], optional (default: `None`) 50 | basis in `adata.obsm`, if `None`, get all available 51 | components: Union[List[Int], List[List[Int]]], optional (default: `(1, 2)`) 52 | components of specified `basis` 53 | if it's of type `List[Int]`, all bases use the same components 54 | obs_keys: List[Str], optional (default: `None`) 55 | keys of categorical observations in `adata.obs` 56 | if `None`, get all available 57 | obsm_keys: List[Str], optional (default: `None`) 58 | keys of categorical observations in `adata.obsm` 59 | if `None`, get all available 60 | use_raw: Bool, optional (default: `False`) 61 | use `adata.raw` for gene expression levels 62 | subsample: Str, optional (default: `'datashade'`) 63 | subsampling strategy for large data 64 | possible values are `None, 'none', 'datashade', 'decimate', 'density', 'uniform'` 65 | using `subsample='datashade'` is preferred over other options since it does not subset 66 | when using `subsample='datashade'`, colorbar is not visible 67 | `'density'` and `'uniform'` use first element of `basis` for their computation 68 | steps: Union[Int, Tuple[Int, Int]], optional (default: `40`) 69 | step size along the embedding directions when subsampling 70 | larger step size corresponds to higher density of points 71 | keep_frac: Float, optional (default: `adata.n_obs / 5`) 72 | number of observations to keep when `subsample='decimate'` 73 | lazy_loading: Bool, optional (default: `True`) 74 | only visualize when necessary 75 | for notebook sharing, consider using `lazy_loading=False` 76 | default_obsm_ixs: List[Int], optional (default: `[0]`) 77 | indices of 2-D elements in `adata.obsm` to add 78 | when `obsm_keys=None` 79 | by default adds only 1st column 80 | sort: Bool, optional (default: `True`) 81 
| whether to sort the `genes`, `obs_keys` and `obsm_keys` 82 | in ascending order 83 | skip: Bool, optional (default: `True`) 84 | skip all the keys not found in the corresponding collections 85 | seed: Int, optional (default: `None`) 86 | random seed, used when `subsample='decimate'` or `subsample='density'` 87 | cols: Int, optional (default: `2`) 88 | number of columns when plotting multiple bases 89 | if `None`, use a toggle bar 90 | size: Int, optional (default: `4`) 91 | size of the glyphs 92 | works only when `subsample != 'datashade'` 93 | perc: List[Float], optional (default: `None`) 94 | lower and upper percentiles used for the color limits 95 | useful when `lazy_loading = False` 96 | works only when `subsample != 'datashade'` 97 | show_perc: Bool, optional (default: `True`) 98 | show percentile slider when `lazy_loading = True` 99 | works only when `subsample != 'datashade'` 100 | cmap: List[Str], optional (default: `bokeh.palettes.Viridis256`) 101 | continuous colormap in hex format 102 | plot_height: Int, optional (default: `400`) 103 | height of the plot in pixels 104 | plot_width: Int, optional (default: `400`) 105 | width of the plot in pixels 106 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`) 107 | path where to save the plot 108 | 109 | Returns 110 | -------- 111 | plot: panel.panel 112 | holoviews plot wrapped in `panel.panel` 113 | ''' 114 | 115 | def create_scatterplot(gene, perc_low, perc_high, *args, bs=None): 116 | ixs = np.where(basis == bs)[0][0] 117 | is_diffmap = bs == 'diffmap' 118 | 119 | if len(args) > 0: 120 | ixs = np.where(basis == bs)[0][0] * 2 121 | comp = (np.array([args[ixs], args[ixs + 1]]) - (not is_diffmap)) % adata.obsm[f'X_{bs}'].shape[-1] 122 | else: 123 | comp = np.array(components[ixs]) # need to make a copy 124 | 125 | if perc_low is not None and perc_high is not None: 126 | if perc_low > perc_high: 127 | perc_low, perc_high = perc_high, perc_low 128 | perc = [perc_low, perc_high] 129 | else: 130 | perc = None 131 | 132 | ad, _ = alazy[bs, tuple(comp)] 133 | ad_mraw
= ad.raw if use_raw else ad 134 | 135 | # because diffmap has a small range, it interferes with 136 | # the legend, so its coordinates are scaled up 137 | emb = ad.obsm[f'X_{bs}'][:, comp] * (1000 if is_diffmap else 1) 138 | comp += not is_diffmap # naming consistency (1-based labels, except for diffmap) 139 | 140 | bsu = bs.upper() 141 | x = hv.Dimension('x', label=f'{bsu}{comp[0]}') 142 | y = hv.Dimension('y', label=f'{bsu}{comp[1]}') 143 | 144 | #if ignore_after is not None and ignore_after in gene: 145 | if gene in ad.obsm.keys(): 146 | data = ad.obsm[gene][:, 0] 147 | elif gene in ad.obs.keys(): 148 | data = ad.obs[gene].values 149 | elif gene in ad_mraw.var_names: 150 | data = ad_mraw.obs_vector(gene) 151 | else: 152 | gene, ix = gene.split(ignore_after) 153 | ix = int(ix) 154 | data = ad.obsm[gene][:, ix] 155 | 156 | data = np.array(data, dtype=np.float64) 157 | 158 | # we need to clip the data as well 159 | scatter = hv.Scatter({'x': emb[:, 0], 'y': emb[:, 1], 'gene': data}, 160 | kdims=[x, y], vdims='gene') 161 | 162 | return scatter.opts(cmap=cmap, color='gene', 163 | colorbar=True, 164 | colorbar_opts={'width': CBW}, 165 | size=size, 166 | clim=minmax(data, perc=perc), 167 | xlim=minmax(emb[:, 0]), 168 | ylim=minmax(emb[:, 1]), 169 | xlabel=f'{bsu}{comp[0]}', 170 | ylabel=f'{bsu}{comp[1]}') 171 | 172 | def _create_scatterplot_nl(bs, gene, perc_low, perc_high, *args): 173 | # reorder the arguments: the basis is passed first 174 | return create_scatterplot(gene, perc_low, perc_high, *args, bs=bs) 175 | 176 | if perc is None: 177 | perc = [None, None] 178 | assert len(perc) == 2, f'Percentile must be of length 2, found `{len(perc)}`.'
179 | if all(map(lambda p: p is not None, perc)): 180 | perc = sorted(perc) 181 | 182 | if keep_frac is None: 183 | keep_frac = 0.2 184 | 185 | if basis is None: 186 | basis = np.ravel(sorted(filter(len, map(BS_PAT.findall, adata.obsm.keys())))) 187 | elif isinstance(basis, str): 188 | basis = np.array([basis]) 189 | elif not isinstance(basis, np.ndarray): 190 | basis = np.array(basis) 191 | 192 | assert 0 <= keep_frac <= 1, f'`keep_frac` must be in interval `[0, 1]`, got `{keep_frac}`.' 193 | assert subsample in ALL_SUBSAMPLING_STRATEGIES, f'Invalid subsampling strategy `{subsample}`. Possible values are `{ALL_SUBSAMPLING_STRATEGIES}`.' 194 | 195 | if subsample == 'uniform': 196 | cb_kwargs = {'steps': steps} 197 | elif subsample == 'density': 198 | cb_kwargs = {'size': int(keep_frac * adata.n_obs), 'seed': seed} 199 | else: 200 | cb_kwargs = {} 201 | alazy = SamplingLazyDict(adata, subsample, callback_kwargs=cb_kwargs) 202 | adata_mraw = adata.raw if use_raw else adata # maybe raw 203 | 204 | if obs_keys is None: 205 | obs_keys = skip_or_filter(adata, adata.obs.keys(), adata.obs.keys(), dtype=is_numeric, 206 | where='obs', skip=True, warn=False) 207 | else: 208 | if not iterable(obs_keys): 209 | obs_keys = [obs_keys] 210 | obs_keys = skip_or_filter(adata, obs_keys, adata.obs.keys(), dtype=is_numeric, 211 | where='obs', skip=skip) 212 | 213 | if obsm_keys is None: 214 | obsm_keys = get_all_obsm_keys(adata, default_obsm_ixs) 215 | ignore_after = OBSM_SEP 216 | obsm_keys = skip_or_filter(adata, obsm_keys, adata.obsm.keys(), where='obsm', 217 | dtype=is_numeric, skip=True, warn=False, ignore_after=ignore_after) 218 | else: 219 | if not iterable(obsm_keys): 220 | obsm_keys = [obsm_keys] 221 | 222 | ignore_after = OBSM_SEP if any((OBSM_SEP in obs_key for obs_key in obsm_keys)) else None 223 | obsm_keys = skip_or_filter(adata, obsm_keys, adata.obsm.keys(), where='obsm', 224 | dtype=is_numeric, skip=skip, ignore_after=ignore_after) 225 | 226 | if genes is
None: 227 | genes = adata_mraw.var_names 228 | elif not iterable(genes): 229 | genes = [genes] 230 | genes = skip_or_filter(adata_mraw, genes, adata_mraw.var_names, where='adata.var_names', skip=skip) 231 | 232 | if isinstance(genes, Index): 233 | genes = list(genes) 234 | 235 | if sort: 236 | if any(genes[i] > genes[i + 1] for i in range(len(genes) - 1)): 237 | genes = sorted(genes) 238 | if any(obs_keys[i] > obs_keys[i + 1] for i in range(len(obs_keys) - 1)): 239 | obs_keys = sorted(obs_keys) 240 | if any(obsm_keys[i] > obsm_keys[i + 1] for i in range(len(obsm_keys) - 1)): 241 | obsm_keys = sorted(obsm_keys) 242 | 243 | conditions = obs_keys + obsm_keys + genes 244 | if len(conditions) == 0: 245 | warnings.warn('Nothing to plot, no conditions found.') 246 | return 247 | 248 | if not isinstance(components, np.ndarray): 249 | components = np.array(components) 250 | if components.ndim == 1: 251 | components = np.repeat(components[np.newaxis, :], len(basis), axis=0) 252 | 253 | assert components.ndim == 2, f'Only `2` dimensional components are supported, got `{components.ndim}`.' 254 | assert components.shape[-1] == 2, f'Components\' second dimension must be of size `2`, got `{components.shape[-1]}`.' 255 | if not isinstance(basis, np.ndarray): 256 | basis = np.array(basis) 257 | 258 | assert components.shape[0] == len(basis), f'Expected #components == `{len(basis)}`, got `{components.shape[0]}`.' 259 | assert np.all(components >= 0), f'Currently, only non-negative indices are supported, found `{list(map(list, components))}`.' 260 | 261 | diffmap_ix = np.where(basis != 'diffmap')[0] 262 | components[diffmap_ix, :] -= 1 263 | 264 | for bs, comp in zip(basis, components): 265 | assert f'X_{bs}' in adata.obsm.keys(), f'`X_{bs}` not found in `adata.obsm`' 266 | shape = adata.obsm[f'X_{bs}'].shape 267 | assert shape[-1] > np.max(comp), f'Requested invalid components `{list(comp)}` for basis `X_{bs}` with shape `{shape}`.'
268 | 269 | if adata.n_obs > SUBSAMPLE_THRESH and subsample in NO_SUBSAMPLE: 270 | warnings.warn(f'Number of cells `{adata.n_obs}` > `{SUBSAMPLE_THRESH}`. Consider specifying `subsample={SUBSAMPLING_STRATEGIES}`.') 271 | 272 | if len(conditions) > HOLOMAP_THRESH and not lazy_loading: 273 | warnings.warn(f'Number of conditions `{len(conditions)}` > `{HOLOMAP_THRESH}`. Consider specifying `lazy_loading=True`.') 274 | 275 | if cmap is None: 276 | cmap = Viridis256 277 | 278 | kdims = [hv.Dimension('Basis', values=basis), 279 | hv.Dimension('Condition', values=conditions), 280 | hv.Dimension('Percentile (lower)', range=(0, 100), step=0.1, type=float, default=0 if perc[0] is None else perc[0]), 281 | hv.Dimension('Percentile (upper)', range=(0, 100), step=0.1, type=float, default=100 if perc[1] is None else perc[1])] 282 | 283 | cs = create_scatterplot 284 | _cs = _create_scatterplot_nl 285 | if not show_perc or subsample == 'datashade' or not lazy_loading: 286 | kdims = kdims[:2] 287 | cs = lambda gene, *args, **kwargs: create_scatterplot(gene, perc[0], perc[1], *args, **kwargs) 288 | _cs = lambda bs, gene, *args, **kwargs: _create_scatterplot_nl(bs, gene, perc[0], perc[1], *args, **kwargs) 289 | 290 | if not lazy_loading: 291 | dynmaps = [hv.HoloMap({(g, b):cs(g, bs=b) for g in conditions for b in basis}, kdims=kdims[::-1])] 292 | else: 293 | for bs, comp in zip(basis, components): 294 | kdims.append(hv.Dimension(f'{bs.upper()}[X]', 295 | type=int, default=1, step=1, 296 | range=(1, adata.obsm[f'X_{bs}'].shape[-1]))) 297 | kdims.append(hv.Dimension(f'{bs.upper()}[Y]', 298 | type=int, default=2, step=1, 299 | range=(1, adata.obsm[f'X_{bs}'].shape[-1]))) 300 | if cols is None: 301 | dynmaps = [hv.DynamicMap(_cs, kdims=kdims)] 302 | else: 303 | dynmaps = [hv.DynamicMap(partial(cs, bs=bs), kdims=kdims[1:]) for bs in basis] 304 | 305 | if subsample == 'datashade': 306 | dynmaps = [dynspread(datashade(d, aggregator=ds.mean('gene'), color_key='gene', 307 | cmap=cmap, 
streams=[hv.streams.RangeXY(transient=True)]), 308 | threshold=0.8, max_px=5) 309 | for d in dynmaps] 310 | elif subsample == 'decimate': 311 | dynmaps = [decimate(d, max_samples=int(adata.n_obs * keep_frac), 312 | streams=[hv.streams.RangeXY(transient=True)], random_seed=seed) for d in dynmaps] 313 | 314 | dynmaps = [d.opts(framewise=True, axiswise=True, frame_height=plot_height, frame_width=plot_width) for d in dynmaps] 315 | 316 | if cols is None: 317 | plot = dynmaps[0].opts(title='', frame_height=plot_height, frame_width=plot_width) 318 | else: 319 | plot = hv.Layout(dynmaps).opts(title='', height=plot_height, width=plot_width).cols(cols) 320 | 321 | if save is not None: 322 | hv.renderer('bokeh').save(plot, save) 323 | 324 | return plot 325 | 326 | 327 | @wrap_as_panel 328 | def scatterc(adata, basis=None, components=[1, 2], obs_keys=None, 329 | obsm_keys=None, subsample='datashade', steps=40, keep_frac=None, hover=False, lazy_loading=True, 330 | default_obsm_ixs=[0], sort=True, skip=True, seed=None, legend_loc='top_right', cols=None, size=4, 331 | cmap=None, show_legend=True, plot_height=400, plot_width=400, save=None): 332 | ''' 333 | Scatter plot for categorical observations. 
334 | 335 | Params 336 | -------- 337 | adata: anndata.Anndata 338 | anndata object 339 | basis: Union[Str, List[Str]], optional (default: `None`) 340 | basis in `adata.obsm`, if `None`, get all available 341 | components: Union[List[Int], List[List[Int]]], optional (default: `[1, 2]`) 342 | components of specified `basis` 343 | if it's of type `List[Int]`, all the bases use the same components 344 | obs_keys: List[Str], optional (default: `None`) 345 | keys of categorical observations in `adata.obs` 346 | if `None`, get all available 347 | obsm_keys: List[Str], optional (default: `None`) 348 | keys of categorical observations in `adata.obsm` 349 | if `None`, get all available 350 | subsample: Str, optional (default: `'datashade'`) 351 | subsampling strategy for large data 352 | possible values are `None, 'none', 'datashade', 'decimate', 'density', 'uniform'` 353 | using `subsample='datashade'` is preferred over other options since it does not subset 354 | when using `subsample='datashade'`, colorbar is not visible 355 | `'density'` and `'uniform'` use first element of `basis` for their computation 356 | steps: Union[Int, Tuple[Int, Int]], optional (default: `40`) 357 | number of steps along the embedding directions, used when `subsample='uniform'` 358 | a larger value corresponds to a higher density of points 359 | keep_frac: Float, optional (default: `None`) 360 | fraction of observations to keep when `subsample='decimate'` or `subsample='density'`, if `None`, use `0.2` (i.e. `adata.n_obs / 5` observations) 361 | hover: Union[Bool, Int], optional (default: `False`) 362 | whether to display cell index when hovering over a block 363 | if integer, it specifies the number of rows/columns (default: `10`) 364 | lazy_loading: Bool, optional (default: `True`) 365 | only visualize when necessary 366 | for notebook sharing, consider using `lazy_loading=False` 367 | sort: Bool, optional (default: `True`) 368 | whether to sort the `obs_keys` and `obsm_keys` 369 | in ascending order 370 | skip: Bool, optional (default: `True`) 371 | skip all the keys not found in the
corresponding collections 372 | seed: Int, optional (default: `None`) 373 | random seed, used when `subsample='decimate'` or `subsample='density'` 374 | legend_loc: Str, optional (default: `top_right`) 375 | position of the legend 376 | cols: Int, optional (default: `None`) 377 | number of columns when plotting multiple bases 378 | if `None`, use a toggle bar 379 | size: Int, optional (default: `4`) 380 | size of the glyphs 381 | works only when `subsample!='datashade'` 382 | cmap: List[Str], optional (default: `datashader.colors.Sets1to3`) 383 | categorical colormap in hex format 384 | plot_height: Int, optional (default: `400`) 385 | height of the plot in pixels 386 | plot_width: Int, optional (default: `400`) 387 | width of the plot in pixels 388 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`) 389 | path where to save the plot 390 | 391 | Returns 392 | -------- 393 | plot: panel.panel 394 | holoviews plot wrapped in `panel.panel` 395 | ''' 396 | 397 | def create_legend(condition, bs): 398 | # slightly hacky solution to get the correct initial limits 399 | xlim = lims['x'][bs] 400 | ylim = lims['y'][bs] 401 | 402 | return hv.NdOverlay({k: hv.Points([0, 0], label=str(k)).opts(size=0, color=v, xlim=xlim, ylim=ylim) # alpha affects legend 403 | for k, v in cmaps[condition].items()}) 404 | 405 | def add_hover(subsampled, dynmaps=None, by_block=True): 406 | hovertool = HoverTool(tooltips=[('Cell Index', '@index')]) 407 | hover_width, hover_height = (10, 10) if isinstance(hover, bool) else (hover, hover) 408 | 409 | if by_block: 410 | if dynmaps is None: 411 | dynmaps = subsampled 412 | 413 | return [s * hv.util.Dynamic(rasterize(d, width=hover_width, height=hover_height, streams=[hv.streams.RangeXY], 414 | aggregator=ds.reductions.min('index')), operation=hv.QuadMesh)\ 415 | .opts(tools=[hovertool], axiswise=True, framewise=True, alpha=0, hover_alpha=0.25, 416 | height=plot_height, width=plot_width) 417 | for s, d in zip(subsampled, dynmaps)] 418 | 419 | return
[s.opts(tools=[hovertool]) for s in subsampled] 420 | 421 | def create_scatterplot(cond, *args, bs=None): 422 | ixs = np.where(basis == bs)[0][0] 423 | is_diffmap = bs == 'diffmap' 424 | 425 | if len(args) > 0: 426 | ixs = np.where(basis == bs)[0][0] * 2 427 | comp = (np.array([args[ixs], args[ixs + 1]]) - (not is_diffmap)) % adata.obsm[f'X_{bs}'].shape[-1] 428 | else: 429 | comp = np.array(components[ixs]) # need to make a copy 430 | 431 | # `ad` may be subsampled (e.g. `subsample='uniform'` or `'density'`), `ixs` is used as the cell index shown on hover 432 | ad, ixs = alazy[bs, tuple(comp)] 433 | # because diffmap has a small range, it interferes with the legend 434 | emb = ad.obsm[f'X_{bs}'][:, comp] * (1000 if is_diffmap else 1) 435 | comp += not is_diffmap # naming consistency (1-based labels, except for diffmap) 436 | 437 | bsu = bs.upper() 438 | x = hv.Dimension('x', label=f'{bsu}{comp[0]}') 439 | y = hv.Dimension('y', label=f'{bsu}{comp[1]}') 440 | 441 | #if ignore_after is not None and ignore_after in gene: 442 | if cond in ad.obsm.keys(): 443 | data = ad.obsm[cond][:, 0] 444 | elif cond in ad.obs.keys(): 445 | data = ad.obs[cond] 446 | else: 447 | cond, ix = cond.split(ignore_after) 448 | ix = int(ix) 449 | data = ad.obsm[cond][:, ix] 450 | 451 | data = pd.Categorical(data).as_ordered() 452 | scatter = hv.Scatter({'x': emb[:, 0], 'y': emb[:, 1], 'cond': data, 'index': ixs}, 453 | kdims=[x, y], vdims=['cond', 'index']).sort('cond') 454 | 455 | return scatter.opts(color_index='cond', cmap=cmaps[cond], 456 | show_legend=show_legend, 457 | legend_position=legend_loc, 458 | size=size, 459 | xlim=minmax(emb[:, 0]), 460 | ylim=minmax(emb[:, 1]), 461 | xlabel=f'{bsu}{comp[0]}', 462 | ylabel=f'{bsu}{comp[1]}') 463 | 464 | def _cs(bs, cond, *args): 465 | return create_scatterplot(cond, *args, bs=bs) 466 | 467 | if keep_frac is None: 468 | keep_frac = 0.2 469 | 470 | if basis is None: 471 | basis = np.ravel(sorted(filter(len, map(BS_PAT.findall, adata.obsm.keys())))) 472 | elif isinstance(basis, str): 473 | basis = np.array([basis]) 474 | elif not isinstance(basis, np.ndarray): 475 |
basis = np.array(basis) 476 | 477 | if not isinstance(hover, bool): 478 | assert hover > 1, f'Expected `hover` to be `> 1` when being an integer, found: `{hover}`.' 479 | 480 | assert 0 <= keep_frac <= 1, f'`keep_frac` must be in interval `[0, 1]`, got `{keep_frac}`.' 481 | assert subsample in ALL_SUBSAMPLING_STRATEGIES, f'Invalid subsampling strategy `{subsample}`. Possible values are `{ALL_SUBSAMPLING_STRATEGIES}`.' 482 | 483 | if subsample == 'uniform': 484 | cb_kwargs = {'steps': steps} 485 | elif subsample == 'density': 486 | cb_kwargs = {'size': int(keep_frac * adata.n_obs), 'seed': seed} 487 | else: 488 | cb_kwargs = {} 489 | alazy = SamplingLazyDict(adata, subsample, callback_kwargs=cb_kwargs) 490 | 491 | if obs_keys is None: 492 | obs_keys = skip_or_filter(adata, adata.obs.keys(), adata.obs.keys(), 493 | dtype='category', where='obs', skip=True, warn=False) 494 | else: 495 | if not iterable(obs_keys): 496 | obs_keys = [obs_keys] 497 | 498 | obs_keys = skip_or_filter(adata, obs_keys, adata.obs.keys(), 499 | dtype='category', where='obs', skip=skip) 500 | 501 | if obsm_keys is None: 502 | obsm_keys = get_all_obsm_keys(adata, default_obsm_ixs) 503 | ignore_after = OBSM_SEP # also needed by `create_scatterplot` 504 | obsm_keys = skip_or_filter(adata, obsm_keys, adata.obsm.keys(), where='obsm', dtype='category', skip=True, warn=False, ignore_after=ignore_after) 505 | else: 506 | if not iterable(obsm_keys): 507 | obsm_keys = [obsm_keys] 508 | 509 | ignore_after = OBSM_SEP if any((OBSM_SEP in obs_key for obs_key in obsm_keys)) else None 510 | obsm_keys = skip_or_filter(adata, obsm_keys, adata.obsm.keys(), where='obsm', 511 | dtype='category', skip=skip, ignore_after=ignore_after) 512 | 513 | if sort: 514 | if any(obs_keys[i] > obs_keys[i + 1] for i in range(len(obs_keys) - 1)): 515 | obs_keys = sorted(obs_keys) 516 | if any(obsm_keys[i] > obsm_keys[i + 1] for i in range(len(obsm_keys) - 1)): 517 | obsm_keys = sorted(obsm_keys) 518 | 519 | conditions = obs_keys + obsm_keys 520 | 521 | if len(conditions) == 0: 522 |
warnings.warn('Nothing to plot, no conditions found.') 523 | return 524 | 525 | if not isinstance(components, np.ndarray): 526 | components = np.array(components) 527 | if components.ndim == 1: 528 | components = np.repeat(components[np.newaxis, :], len(basis), axis=0) 529 | 530 | assert components.ndim == 2, f'Only `2` dimensional components are supported, got `{components.ndim}`.' 531 | assert components.shape[-1] == 2, f'Components\' second dimension must be of size `2`, got `{components.shape[-1]}`.' 532 | 533 | assert components.shape[0] == len(basis), f'Expected #components == `{len(basis)}`, got `{components.shape[0]}`.' 534 | assert np.all(components >= 0), f'Currently, only non-negative indices are supported, found `{list(map(list, components))}`.' 535 | 536 | diffmap_ix = np.where(basis != 'diffmap')[0] 537 | components[diffmap_ix, :] -= 1 538 | 539 | for bs, comp in zip(basis, components): 540 | assert f'X_{bs}' in adata.obsm.keys(), f'`X_{bs}` not found in `adata.obsm`' 541 | shape = adata.obsm[f'X_{bs}'].shape 542 | assert shape[-1] > np.max(comp), f'Requested invalid components `{list(comp)}` for basis `X_{bs}` with shape `{shape}`.' 543 | 544 | if adata.n_obs > SUBSAMPLE_THRESH and subsample in NO_SUBSAMPLE: 545 | warnings.warn(f'Number of cells `{adata.n_obs}` > `{SUBSAMPLE_THRESH}`. Consider specifying `subsample={SUBSAMPLING_STRATEGIES}`.') 546 | 547 | if len(conditions) > HOLOMAP_THRESH and not lazy_loading: 548 | warnings.warn(f'Number of conditions `{len(conditions)}` > `{HOLOMAP_THRESH}`.
Consider specifying `lazy_loading=True`.') 549 | 550 | if cmap is None: 551 | cmap = Sets1to3 552 | 553 | lims = dict(x=dict(), y=dict()) 554 | for bs in basis: 555 | emb = adata.obsm[f'X_{bs}'] 556 | is_diffmap = bs == 'diffmap' 557 | if is_diffmap: 558 | emb = (emb * 1000).copy() 559 | lims['x'][bs] = minmax(emb[:, 0 + is_diffmap]) 560 | lims['y'][bs] = minmax(emb[:, 1 + is_diffmap]) 561 | 562 | kdims = [hv.Dimension('Basis', values=basis), 563 | hv.Dimension('Condition', values=conditions)] 564 | 565 | cmaps = dict() 566 | for cond in conditions: 567 | color_key = f'{cond}_colors' 568 | # use the datashader default cmap since setting it doesn't work (for multiple conditions) 569 | cmaps[cond] = odict(zip(adata.obs[cond].cat.categories, # adata.uns.get(color_key, cmap))) 570 | cmap if subsample == 'datashade' else adata.uns.get(color_key, cmap))) 571 | # this approach (for datashader) does not really work - the legend gets mixed up 572 | # cmap = dict(ChainMap(*[c.copy() for c in cmaps.values()])) 573 | # if len(cmap.keys()) != len([k for c in conditions for k in cmaps[c].keys()]): 574 | # warnings.warn('Found same key across multiple conditions. 
The colormap/legend may not accurately display the colors.') 575 | 576 | if not lazy_loading: 577 | # have to wrap because of the *args 578 | dynmaps = [hv.HoloMap({(c, b):create_scatterplot(c, bs=b) for c in conditions for b in basis}, kdims=kdims[::-1])] 579 | else: 580 | for bs, comp in zip(basis, components): 581 | kdims.append(hv.Dimension(f'{bs.upper()}[X]', 582 | type=int, default=1, step=1, 583 | range=(1, adata.obsm[f'X_{bs}'].shape[-1]))) 584 | kdims.append(hv.Dimension(f'{bs.upper()}[Y]', 585 | type=int, default=2, step=1, 586 | range=(1, adata.obsm[f'X_{bs}'].shape[-1]))) 587 | 588 | if cols is None: 589 | dynmaps = [hv.DynamicMap(_cs, kdims=kdims)] 590 | else: 591 | dynmaps = [hv.DynamicMap(partial(create_scatterplot, bs=bs), kdims=kdims[1:]) for bs in basis] 592 | 593 | legend = None 594 | if subsample == 'datashade': 595 | subsampled = [dynspread(datashade(d, aggregator=ds.count_cat('cond'), color_key=cmap, 596 | streams=[hv.streams.RangeXY(transient=True), hv.streams.PlotSize], 597 | min_alpha=255).opts(axiswise=True, framewise=True), threshold=0.8, max_px=5) 598 | for d in dynmaps] 599 | dynmaps = add_hover(subsampled, dynmaps) if hover else subsampled 600 | 601 | if show_legend: 602 | warnings.warn('Automatic adjustment of axes is currently not working when ' 603 | '`show_legend=True` and `subsample=\'datashade\'`.') 604 | legend = hv.DynamicMap(create_legend, kdims=kdims[:2][::-1]) 605 | elif hover: 606 | warnings.warn('Automatic adjustment of axes is currently not working when hovering is enabled.') 607 | 608 | elif subsample == 'decimate': 609 | subsampled = [decimate(d, max_samples=int(adata.n_obs * keep_frac), 610 | streams=[hv.streams.RangeXY(transient=True)], random_seed=seed) for d in dynmaps] 611 | dynmaps = add_hover(subsampled, by_block=False) if hover else subsampled 612 | elif hover: 613 | dynmaps = add_hover(dynmaps, by_block=False) 614 | 615 | if cols is None: 616 | dynmap = dynmaps[0].opts(title='', frame_height=plot_height, 
frame_width=plot_width, axiswise=True, framewise=True) 617 | if legend is not None: 618 | dynmap = (dynmap * legend).opts(legend_position=legend_loc) 619 | else: 620 | if legend is not None: 621 | dynmaps = [(d * l).opts(legend_position=legend_loc) 622 | for d, l in zip(dynmaps, legend.layout('bs') if lazy_loading and show_legend else [legend] * len(dynmaps))] 623 | 624 | dynmap = hv.Layout([d.opts(axiswise=True, framewise=True, 625 | frame_height=plot_height, frame_width=plot_width) for d in dynmaps]).cols(cols) 626 | 627 | plot = dynmap.cols(cols).opts(title='', height=plot_height, width=plot_width) if cols is not None else dynmap 628 | if save is not None: 629 | hv.renderer('bokeh').save(plot, save) 630 | 631 | return plot 632 | 633 | 634 | @wrap_as_col 635 | def dpt(adata, key, genes=None, basis=None, components=[1, 2], 636 | subsample='datashade', steps=40, use_raw=False, keep_frac=None, 637 | sort=True, skip=True, seed=None, show_legend=True, root_cell_all=False, 638 | root_cell_hl=True, root_cell_bbox=True, root_cell_size=None, root_cell_color='orange', 639 | legend_loc='top_right', size=4, perc=None, show_perc=True, cat_cmap=None, cont_cmap=None, 640 | plot_height=400, plot_width=400, *args, **kwargs): 641 | ''' 642 | Interactive diffusion pseudotime (DPT) visualization with a selectable root cell.
643 | 644 | Params 645 | -------- 646 | adata: anndata.Anndata 647 | anndata object 648 | key: Str 649 | key in `adata.obs`, `adata.obsm` or `adata.var_names` 650 | to be visualized in top right plot 651 | can be categorical or continuous 652 | genes: List[Str], optional (default: `None`) 653 | list of genes to add for visualization 654 | if `None`, use `adata.var_names` 655 | basis: Union[Str, List[Str]], optional (default: `None`) 656 | basis in `adata.obsm`, if `None`, get all available 657 | components: Union[List[Int], List[List[Int]]], optional (default: `[1, 2]`) 658 | components of specified `basis` 659 | if it's of type `List[Int]`, all the bases use the same components 660 | use_raw: Bool, optional (default: `False`) 661 | use `adata.raw` for gene expression levels 662 | subsample: Str, optional (default: `'datashade'`) 663 | subsampling strategy for large data 664 | possible values are `None, 'none', 'datashade', 'decimate', 'density', 'uniform'` 665 | using `subsample='datashade'` is preferred over other options since it does not subset 666 | when using `subsample='datashade'`, colorbar is not visible 667 | `'density'` and `'uniform'` use first element of `basis` for their computation 668 | steps: Union[Int, Tuple[Int, Int]], optional (default: `40`) 669 | number of steps along the embedding directions, used when `subsample='uniform'` 670 | a larger value corresponds to a higher density of points 671 | keep_frac: Float, optional (default: `None`) 672 | fraction of observations to keep when `subsample='decimate'` or `subsample='density'`, if `None`, use `0.2` (i.e. `adata.n_obs / 5` observations) 673 | sort: Bool, optional (default: `True`) 674 | whether to sort the `genes` 675 | in ascending order 676 | skip: Bool, optional (default: `True`) 677 | skip all the keys not found in the corresponding collections 678 | seed: Int, optional (default: `None`) 679 | random seed, used when `subsample='decimate'` or `subsample='density'` 680 | show_legend: Bool, optional (default: `True`) 681 | whether to show legend 682 | legend_loc: Str, optional (default: `top_right`)
683 | position of the legend 684 | 685 | 686 | 687 | size: Int, optional (default: `4`) 688 | size of the glyphs 689 | works only when `subsample!='datashade'` 690 | perc: List[Float], optional (default: `None`) 691 | percentile for colors when `key` refers to a continuous observation 692 | works only when `subsample != 'datashade'` 693 | show_perc: Bool, optional (default: `True`) 694 | show percentile slider when `key` refers to a continuous observation 695 | works only when `subsample != 'datashade'` 696 | cat_cmap: List[Str], optional (default: `datashader.colors.Sets1to3`) 697 | categorical colormap in hex format 698 | used when `key` is categorical variable 699 | cont_cmap: List[Str], optional (default: `bokeh.palettes.Viridis256`) 700 | continuous colormap in hex format 701 | used when `key` is continuous variable 702 | root_cell_all: Bool, optional (default: `False`) 703 | allow selecting any cell as the root cell, even if it is not in the embedding 704 | (e.g.
when `subsample='uniform'` or `subsample='density'`) 705 | otherwise, only cells present in the (possibly subsampled) embedding of the 1st `basis` can be selected 706 | root_cell_hl: Bool, optional (default: `True`) 707 | highlight the root cell 708 | root_cell_bbox: Bool, optional (default: `True`) 709 | show bounding box around the root cell 710 | root_cell_size: Int, optional (default: `None`) 711 | size of the root cell, if `None`, it's `size * 2` 712 | root_cell_color: Str, optional (default: `'orange'`) 713 | color of the root cell, can be in hex format 714 | plot_height: Int, optional (default: `400`) 715 | height of the plot in pixels 716 | plot_width: Int, optional (default: `400`) 717 | width of the plot in pixels 718 | *args, **kwargs: 719 | additional arguments for `sc.tl.dpt` 720 | 721 | Returns 722 | -------- 723 | plot: panel.Column 724 | holoviews plot wrapped in `panel.Column` 725 | ''' 726 | 727 | def create_scatterplot(root_cell, gene, bs, perc_low, perc_high, *args, typp='expr', ret_hl=False): 728 | ixs = np.where(basis == bs)[0][0] 729 | is_diffmap = bs == 'diffmap' 730 | 731 | if len(args) > 0: 732 | ixs = np.where(basis == bs)[0][0] * 2 733 | comp = (np.array([args[ixs], args[ixs + 1]]) - (not is_diffmap)) % adata.obsm[f'X_{bs}'].shape[-1] 734 | else: 735 | comp = np.array(components[ixs]) # need to make a copy 736 | 737 | ad, _ = alazy[bs, tuple(comp)] 738 | ad_mraw = ad.raw if use_raw else ad 739 | 740 | if perc_low is not None and perc_high is not None: 741 | if perc_low > perc_high: 742 | perc_low, perc_high = perc_high, perc_low 743 | perc = [perc_low, perc_high] 744 | else: 745 | perc = None 746 | 747 | # because diffmap has a small range, it interferes with 748 | # the manually created legend, so its coordinates are scaled up 749 | emb = ad.obsm[f'X_{bs}'][:, comp] * (1000 if is_diffmap else 1) 750 | comp += not is_diffmap # naming consistency (1-based labels, except for diffmap) 751 | 752 | bsu = bs.upper() 753 | x = hv.Dimension('x', label=f'{bsu}{comp[0]}') 754 | y = hv.Dimension('y', label=f'{bsu}{comp[1]}') 755 | xmin, xmax = minmax(emb[:, 0]) 756 | ymin,
ymax = minmax(emb[:, 1]) 757 | 758 | # adata is the original, ad may be subsampled 759 | mask = np.in1d(adata.obs_names, ad.obs_names) 760 | 761 | if typp == 'emb_discrete': 762 | scatter = hv.Scatter({'x': emb[:, 0], 'y': emb[:, 1], 'condition': data[mask]}, 763 | kdims=[x, y], vdims='condition').sort('condition') 764 | 765 | scatter = scatter.opts(title=key, 766 | color='condition', 767 | xlim=(xmin, xmax), 768 | ylim=(ymin, ymax), 769 | size=size, 770 | xlabel=f'{bsu}{comp[0]}', 771 | ylabel=f'{bsu}{comp[1]}') 772 | 773 | if is_cat: 774 | # we're manually creating legend (for datashade) 775 | return scatter.opts(cmap=cat_cmap, show_legend=False) 776 | 777 | return scatter.opts(colorbar=True, colorbar_opts={'width': CBW}, 778 | cmap=cont_cmap, clim=minmax(data, perc=perc)) 779 | 780 | 781 | if typp == 'root_cell_hl': 782 | # find the index of the root cell in maybe subsampled data 783 | rid = np.where(ad.obs_names == root_cell)[0] 784 | if not len(rid): 785 | return hv.Scatter([]).opts(axiswise=True, framewise=True) 786 | 787 | rid = rid[0] 788 | dx, dy = (xmax - xmin) / 25, (ymax - ymin) / 25 789 | rx, ry = emb[rid, 0], emb[rid, 1] 790 | 791 | root_cell_scatter = hv.Scatter({'x': emb[rid, 0], 'y': emb[rid, 1]}).opts(color=root_cell_color, size=root_cell_size) 792 | if root_cell_bbox: 793 | root_cell_scatter *= hv.Bounds((rx - dx, ry - dy, rx + dx, ry + dy)).opts(line_width=4, color=root_cell_color).opts(axiswise=True, framewise=True) 794 | 795 | return root_cell_scatter 796 | 797 | 798 | adata.uns['iroot'] = np.where(adata.obs_names == root_cell)[0][0] 799 | dpt_fn(adata, *args, **kwargs) 800 | 801 | pseudotime = adata.obs['dpt_pseudotime'].values 802 | pseudotime = pseudotime[mask] 803 | pseudotime[pseudotime == np.inf] = 1 804 | pseudotime[pseudotime == -np.inf] = 0 805 | 806 | if typp == 'emb': 807 | 808 | scatter = hv.Scatter({'x': emb[:, 0], 'y': emb[:, 1], 'pseudotime': pseudotime}, 809 | kdims=[x, y], vdims='pseudotime') 810 | 811 | return 
scatter.opts(title='Pseudotime', 812 | cmap=cont_cmap, color='pseudotime', 813 | colorbar=True, 814 | colorbar_opts={'width': CBW}, 815 | size=size, 816 | clim=minmax(pseudotime, perc=perc), 817 | xlim=(xmin, xmax), 818 | ylim=(ymin, ymax), 819 | xlabel=f'{bsu}{comp[0]}', 820 | ylabel=f'{bsu}{comp[1]}')# if not ret_hl else root_cell_scatter 821 | 822 | if typp == 'expr': 823 | expr = ad_mraw.obs_vector(gene) 824 | 825 | x = hv.Dimension('x', label='pseudotime') 826 | y = hv.Dimension('y', label='expression') 827 | # data is in outer scope 828 | scatter_expr = hv.Scatter({'x': pseudotime, 'y': expr, 'condition': data[mask]}, 829 | kdims=[x, y], vdims='condition') 830 | 831 | scatter_expr = scatter_expr.opts(title=key, 832 | color='condition', 833 | size=size, 834 | xlim=minmax(pseudotime), 835 | ylim=minmax(expr)) 836 | if is_cat: 837 | # we're manually creating legend (for datashade) 838 | return scatter_expr.opts(cmap=cat_cmap, show_legend=False) 839 | 840 | return scatter_expr.opts(colorbar=True, colorbar_opts={'width': CBW}, 841 | cmap=cont_cmap, clim=minmax(data, perc=perc)) 842 | 843 | if typp == 'hist': 844 | return hv.Histogram(np.histogram(pseudotime, bins=20)).opts(xlabel='pseudotime', 845 | ylabel='frequency', 846 | color='#f2f2f2') 847 | 848 | raise RuntimeError(f'Unknown type `{typp}` for `create_scatterplot`.') 849 | 850 | # we copy beforehand 851 | if kwargs.pop('copy', False): 852 | adata = adata.copy() 853 | 854 | if keep_frac is None: 855 | keep_frac = 0.2 856 | 857 | if root_cell_size is None: 858 | root_cell_size = size * 2 859 | 860 | if basis is None: 861 | basis = np.ravel(sorted(filter(len, map(BS_PAT.findall, adata.obsm.keys())))) 862 | elif isinstance(basis, str): 863 | basis = np.array([basis]) 864 | elif not isinstance(basis, np.ndarray): 865 | basis = np.array(basis) 866 | 867 | if perc is None: 868 | perc = [None, None] 869 | assert len(perc) == 2, f'Percentile must be of length 2, found `{len(perc)}`.' 
    if all(map(lambda p: p is not None, perc)):
        perc = sorted(perc)

    assert keep_frac >= 0 and keep_frac <= 1, f'`keep_frac` must be in interval `[0, 1]`, got `{keep_frac}`.'
    assert subsample in ALL_SUBSAMPLING_STRATEGIES, f'Invalid subsampling strategy `{subsample}`. Possible values are `{ALL_SUBSAMPLING_STRATEGIES}`.'

    if subsample == 'uniform':
        cb_kwargs = {'steps': steps}
    elif subsample == 'density':
        cb_kwargs = {'size': int(keep_frac * adata.n_obs), 'seed': seed}
    else:
        cb_kwargs = {}
    alazy = SamplingLazyDict(adata, subsample, callback_kwargs=cb_kwargs)
    adata_mraw = adata.raw if use_raw else adata

    if genes is None:
        genes = adata_mraw.var_names
    elif not iterable(genes):
        genes = [genes]
    genes = skip_or_filter(adata_mraw, genes, adata_mraw.var_names, where='adata.var_names', skip=skip)

    if sort:
        if any(genes[i] > genes[i + 1] for i in range(len(genes) - 1)):
            genes = sorted(genes)

    if len(genes) == 0:
        warnings.warn('No genes found. Consider specifying `skip=False`.')
        return

    if not isinstance(components, np.ndarray):
        components = np.array(components)
    if components.ndim == 1:
        components = np.repeat(components[np.newaxis, :], len(basis), axis=0)

    assert components.ndim == 2, f'Only `2` dimensional components are supported, got `{components.ndim}`.'
    assert components.shape[-1] == 2, f'Components\' second dimension must be of size `2`, got `{components.shape[-1]}`.'

    if not isinstance(basis, np.ndarray):
        basis = np.array(basis)

    assert components.shape[0] == len(basis), f'Expected #components == `{len(basis)}`, got `{components.shape[0]}`.'
    assert np.all(components >= 0), f'Currently, only non-negative indices are supported, found `{list(map(list, components))}`.'
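In `dpt`, a single `[x, y]` pair of components is first broadcast to every basis. The indices are then shifted by one for every basis except `diffmap`. The stdlib-only sketch below restates those two steps. It assumes, as the surrounding code appears to, that user-facing components are 1-based and that `diffmap` indices are left unshifted. The helper names `broadcast_components` and `to_column_indices` are hypothetical and not part of this package.

```python
from typing import List, Sequence, Union

Pair = Sequence[int]


def broadcast_components(components: Union[Pair, Sequence[Pair]], n_bases: int) -> List[List[int]]:
    """Repeat a single (x, y) pair once per basis, like the `np.repeat` step."""
    if all(isinstance(c, int) for c in components):
        return [list(components) for _ in range(n_bases)]
    # already one pair per basis
    return [list(c) for c in components]


def to_column_indices(bases: Sequence[str], components: Sequence[Pair]) -> List[List[int]]:
    """Turn 1-based components into 0-based columns, leaving 'diffmap' unshifted."""
    result = []
    for basis, comp in zip(bases, components):
        offset = 0 if basis == 'diffmap' else 1
        result.append([c - offset for c in comp])
    return result


bases = ['umap', 'diffmap']
comps = broadcast_components([1, 2], len(bases))
print(comps)                            # [[1, 2], [1, 2]]
print(to_column_indices(bases, comps))  # [[0, 1], [1, 2]]
```

Under this convention, `components=[1, 2]` never selects column 0 of `X_diffmap`. This is presumably meant to skip the first diffusion component, but that reading is inferred from the code, not stated in it.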

    # components are 1-based; shift all bases except diffmap to 0-based column indices
    diffmap_ix = np.where(basis != 'diffmap')[0]
    components[diffmap_ix, :] -= 1

    for bs, comp in zip(basis, components):
        assert f'X_{bs}' in adata.obsm.keys(), f'`X_{bs}` not found in `adata.obsm`'
        shape = adata.obsm[f'X_{bs}'].shape
        assert shape[-1] > np.max(comp), f'Requested invalid components `{list(comp)}` for basis `X_{bs}` with shape `{shape}`.'

    if adata.n_obs > SUBSAMPLE_THRESH and subsample in NO_SUBSAMPLE:
        warnings.warn(f'Number of cells `{adata.n_obs}` > `{SUBSAMPLE_THRESH}`. Consider specifying `subsample={SUBSAMPLING_STRATEGIES}`.')

    if cat_cmap is None:
        cat_cmap = Sets1to3

    if cont_cmap is None:
        cont_cmap = Viridis256

    kdims = [hv.Dimension('Root cell', values=(adata if root_cell_all else alazy[basis[0], tuple(components[0])][0]).obs_names),
             hv.Dimension('Gene', values=genes),
             hv.Dimension('Basis', values=basis)]
    cs = lambda cell, gene, bs, *args, **kwargs: create_scatterplot(cell, gene, bs, perc[0], perc[1], *args, **kwargs)

    data, is_cat = get_data(adata, key)
    if is_cat:
        data = pd.Categorical(data)
        aggregator = ds.count_cat
        cmap = cat_cmap
        legend = hv.NdOverlay({c: hv.Points([0, 0], label=str(c)).opts(size=0, color=color)
                               for c, color in zip(data.categories, cat_cmap)})
    else:
        data = np.array(data, dtype=np.float64)
        aggregator = ds.mean
        cmap = cont_cmap
        legend = None
    if show_perc and subsample != 'datashade':
        kdims += [
            hv.Dimension('Percentile (lower)', range=(0, 100), step=0.1, type=float, default=0 if perc[0] is None else perc[0]),
            hv.Dimension('Percentile (upper)', range=(0, 100), step=0.1, type=float, default=100 if perc[1] is None else perc[1])
        ]
        cs = create_scatterplot

    emb = hv.DynamicMap(partial(cs, typp='emb'), kdims=kdims)
    if root_cell_hl:
        root_cell = hv.DynamicMap(partial(cs, typp='root_cell_hl'), kdims=kdims)
    emb_d = hv.DynamicMap(partial(cs, typp='emb_discrete'), kdims=kdims)
    expr = hv.DynamicMap(partial(cs, typp='expr'), kdims=kdims)
    hist = hv.DynamicMap(partial(cs, typp='hist'), kdims=kdims)

    if subsample == 'datashade':
        emb = dynspread(datashade(emb, aggregator=ds.mean('pseudotime'), cmap=cont_cmap,
                                  streams=[hv.streams.RangeXY(transient=True), hv.streams.PlotSize],
                                  min_alpha=255),
                        threshold=0.8, max_px=5)
        emb_d = dynspread(datashade(emb_d, aggregator=aggregator('condition'), cmap=cmap,
                                    streams=[hv.streams.RangeXY(transient=True), hv.streams.PlotSize],
                                    min_alpha=255),
                          threshold=0.8, max_px=5)
        expr = dynspread(datashade(expr, aggregator=aggregator('condition'), cmap=cmap,
                                   streams=[hv.streams.RangeXY(transient=True), hv.streams.PlotSize],
                                   min_alpha=255),
                         threshold=0.8, max_px=5)
    elif subsample == 'decimate':
        emb, emb_d, expr = (decimate(d, max_samples=int(adata.n_obs * keep_frac)) for d in (emb, emb_d, expr))

    if root_cell_hl:
        emb *= root_cell  # emb * root_cell.opts(axiswise=True, framewise=True)

    emb = emb.opts(axiswise=False, framewise=True, frame_height=plot_height, frame_width=plot_width)
    expr = expr.opts(axiswise=True, framewise=True, frame_height=plot_height, frame_width=plot_width)
    emb_d = emb_d.opts(axiswise=True, framewise=True, frame_height=plot_height, frame_width=plot_width)
    hist = hist.opts(axiswise=True, framewise=True, frame_height=plot_height, frame_width=plot_width)

    if show_legend and legend is not None:
        emb_d = (emb_d * legend).opts(legend_position=legend_loc, show_legend=True)

    return ((emb + emb_d) + (hist + expr).opts(axiswise=True, framewise=True)).cols(2)


@wrap_as_col
def graph(adata, key, basis=None, components=[1, 2], obs_keys=[], color_key=None, color_key_reduction=np.sum,
          ixs=None, top_n_edges=None, filter_edges=None, directed=True, bundle=False, bundle_kwargs={},
          subsample=None, layouts=None, layout_kwargs={}, force_paga_indices=False,
          degree_by=None, legend_loc='top_right', node_size=12, edge_width=2, arrowhead_length=None,
          perc=None, color_edges_by='weight', hover_selection='nodes',
          node_cmap=None, edge_cmap=None, plot_height=600, plot_width=600):
    '''
    Params
    --------

    adata: anndata.AnnData
        anndata object
    key: Str
        key in `adata.uns`, `adata.obsp`, `adata.uns[\'paga\']` or `adata.uns[\'neighbors\']` which
        represents the graph as an adjacency matrix (can be sparse)
        use `'paga'` to access the PAGA connectivities graph or the prefix `'p:...'`
        to access `adata.uns['paga'][...]`
        for `adata.uns['neighbors'][...]`, use the prefix `'n:...'`
    basis: Union[Str, List[Str]], optional (default: `None`)
        basis in `adata.obsm`, if `None`, get all of them
    components: Union[List[Int], List[List[Int]]], optional (default: `[1, 2]`)
        components of specified `basis`
        if it's of type `List[Int]`, all the bases use the same components
    color_key: Str, optional (default: `None`)
        variable in `adata.obs` with which to color in each node
        or `'incoming'`, `'outgoing'` for coloring values based on weights
    color_key_reduction: Callable, optional (default: `np.sum`)
        numpy reduction applied to the edge weights, such as `np.mean`, `np.max`, ...
        when `color_key` is `'incoming'` or `'outgoing'`
    obs_keys: List[Str], optional (default: `[]`)
        keys of categorical observations in `adata.obs`
        if `None`, get all available, only visible when `hover_selection='nodes'`
    ixs: List[Int], optional (default: `None`)
        list of indices of nodes of graph to visualize
        if `None`, visualize all
    top_n_edges: Union[Int, Tuple[Int, Bool, Str]], optional (default: `None`)
        only for directed graph
        maximum number of outgoing edges per node to keep based on decreasing weight
        if a tuple, the second element specifies whether it's ascending or not
        and the third one whether to consider outgoing (`'out'`) or incoming (`'in'`) edges
    filter_edges: Tuple[Float, Float], optional (default: `None`)
        min and max threshold values for edge visualization
        nodes without edges will *NOT* be removed
    directed: Bool, optional (default: `True`)
        whether the graph is directed or not
    subsample: Str, optional (default: `None`)
        subsampling strategies for edges
        possible values are `None, \'none\', \'datashade\'`
    bundle: Bool, optional (default: `False`)
        whether to bundle edges together (can be computationally expensive)
    bundle_kwargs: Dict, optional (default: `{}`)
        kwargs for bundler, e.g. `iterations=1` (default `4`)
        for more options, see `hv.operation.datashader.bundle_graph`
    layouts: List[Str], optional (default: `None`)
        layout names to use when drawing graph, e.g. `'umap'` in `adata.obsm`
        or `'kamada_kawai'` from `nx.layouts`
        if `None`, use all available layouts
    layout_kwargs: Dict[Str, Dict], optional (default: `{}`)
        kwargs for a given layout
    force_paga_indices: Bool, optional (default: `False`)
        by default, when `key='paga'`, all indices are used
        regardless of `ixs`; if `True`, respect `ixs` instead
    degree_by: Str, optional (default: `None`)
        if `'weight'`, use edge weights when calculating the degree
        only visible when `hover_selection='nodes'`
    legend_loc: Str, optional (default: `'top_right'`)
        location of the legend, if `None`, do not show legend
    node_size: Float, optional (default: `12`)
        size of the graph nodes
    edge_width: Float, optional (default: `2`)
        width of the graph edges
    arrowhead_length: Float, optional (default: `None`)
        length of the arrow when `directed=True`
    perc: List[Float], optional (default: `None`)
        percentile for edge colors
        *WARNING* this can remove nodes and will be fixed in the future
    color_edges_by: Str, optional (default: `'weight'`)
        edge attribute by which to color the edges, if `None`, do not color edges
    hover_selection: Str, optional (default: `'nodes'`)
        whether to define hover over `'nodes'` or `'edges'`
        if `subsample == 'datashade'`, it is always `'nodes'`
    node_cmap: List[Str], optional (default: `datashader.colors.Sets1to3`)
        colormap in hex format for `color_key`
    edge_cmap: List[Str], optional (default: `bokeh.palettes.Viridis256`)
        continuous colormap in hex format for edges
    plot_height: Int, optional (default: `600`)
        height of the plot in pixels
    plot_width: Int, optional (default: `600`)
        width of the plot in pixels

    Returns
    --------
    plot: panel.Column
        `hv.DynamicMap` wrapped in `panel.Column` that displays the graph in various layouts
    '''

    def normalize(emb):
        # TODO: do this once
        # normalize because of arrows...
        emb = emb.copy()
        x_min, y_min = np.min(emb[:, 0]), np.min(emb[:, 1])
        emb[:, 0] = (emb[:, 0] - x_min) / (np.max(emb[:, 0]) - x_min)
        emb[:, 1] = (emb[:, 1] - y_min) / (np.max(emb[:, 1]) - y_min)
        return emb

    def create_graph(adata, data):
        if perc is not None:
            data = percentile(data, perc)
        create_using = nx.DiGraph if directed else nx.Graph
        # `from_scipy_sparse_matrix` was removed in networkx 3.0 in favor of `from_scipy_sparse_array`
        from_sparse = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix
        g = (from_sparse if issparse(data) else nx.from_numpy_array)(data, create_using=create_using)

        if filter_edges is not None:
            minn, maxx = filter_edges
            minn = minn if minn is not None else -np.inf
            maxx = maxx if maxx is not None else np.inf
            for e, attr in list(g.edges.items()):
                if attr['weight'] < minn or attr['weight'] > maxx:
                    g.remove_edge(*e)

        n_top = None
        if top_n_edges is not None:
            if isinstance(top_n_edges, (tuple, list)):
                n_top, ascending, group_by = top_n_edges
            else:
                n_top, ascending, group_by = top_n_edges, False, 'out'

            source, target = zip(*g.edges)
            weights = [v['weight'] for v in g.edges.values()]
            tmp = pd.DataFrame({'out': source, 'in': target, 'w': weights})

            # per source ('out') or target ('in') node, keep the `n_top` edges sorted by weight
            to_keep = set(map(tuple, tmp.groupby(group_by).apply(lambda grp: grp.sort_values('w', ascending=ascending).take(range(min(n_top, len(grp)))))[['out', 'in']].values))

            for e in list(g.edges):
                if e not in to_keep:
                    g.remove_edge(*e)

        if not len(g.nodes):
            raise RuntimeError('Empty graph.')

        if not len(g.edges):
            msg = 'No edges to visualize.'
            if filter_edges is not None:
                msg += f' Consider altering the edge filtering thresholds `{filter_edges}`.'
            if top_n_edges is not None:
                msg += f' Perhaps use more top edges than `{n_top}`.'
            raise RuntimeError(msg)

        if hover_selection == 'nodes':
            if directed:
                nx.set_node_attributes(g, values=dict(g.in_degree(weight=degree_by)),
                                       name='indegree')
                nx.set_node_attributes(g, values=dict(g.out_degree(weight=degree_by)),
                                       name='outdegree')
                nx.set_node_attributes(g, values=nx.in_degree_centrality(g),
                                       name='indegree centrality')
                nx.set_node_attributes(g, values=nx.out_degree_centrality(g),
                                       name='outdegree centrality')
            else:
                nx.set_node_attributes(g, values=dict(g.degree(weight=degree_by)),
                                       name='degree')
                nx.set_node_attributes(g, values=nx.degree_centrality(g),
                                       name='centrality')

        if not is_paga:
            nx.set_node_attributes(g, values=dict(zip(g.nodes.keys(), adata.obs.index)),
                                   name='name')
            for key in list(obs_keys):
                nx.set_node_attributes(g, values=dict(zip(g.nodes.keys(), adata.obs[key])),
                                       name=key)
            if color_key is not None:
                # color_vals has been set beforehand
                nx.set_node_attributes(g, values=dict(zip(g.nodes.keys(), adata.obs[color_key] if color_key in adata.obs.keys() else color_vals)),
                                       name=color_key)

        else:
            nx.set_node_attributes(g, values=dict(zip(g.nodes.keys(), adata.obs[color_key].cat.categories)),
                                   name=color_key)

        return g

    def embed_graph(layout_key, graph):
        bs_key = f'X_{layout_key}'
        if bs_key in adata.obsm.keys():
            emb = adata_ss.obsm[bs_key][:, get_component[layout_key]]
            emb = normalize(emb)
            layout = dict(zip(graph.nodes.keys(), emb))
            l_kwargs = {}
        elif layout_key == 'paga':
            layout = dict(zip(graph.nodes.keys(), paga_pos))
            l_kwargs = {}
        elif layout_key in DEFAULT_LAYOUTS:
            layout = DEFAULT_LAYOUTS[layout_key]
            l_kwargs = layout_kwargs.get(layout_key, {})

        g = hv.Graph.from_networkx(graph, positions=layout, **l_kwargs)
        g = g.opts(inspection_policy='nodes' if subsample == 'datashade' else hover_selection,
                   tools=['hover', 'box_select'],
                   edge_color=hv.dim(color_edges_by) if color_edges_by is not None else None,
                   edge_line_width=edge_width * (hv.dim('weight') if is_paga else 1),
                   edge_cmap=edge_cmap,
                   node_color=color_key,
                   node_cmap=node_cmap,
                   directed=directed,
                   colorbar=True,
                   show_legend=legend_loc is not None
                   )

        return g if arrowhead_length is None else g.opts(arrowhead_length=arrowhead_length)

    def get_nodes(layout_key):  # DRY DRY DRY
        nodes = bundled[layout_key].nodes
        bs_key = f'X_{layout_key}'

        if bs_key in adata.obsm.keys():
            emb = adata_ss.obsm[bs_key][:, get_component[layout_key]]
            emb = normalize(emb)
            xlim = minmax(emb[:, 0])
            ylim = minmax(emb[:, 1])
        elif layout_key == 'paga':
            xlim = minmax(paga_pos[:, 0])
            ylim = minmax(paga_pos[:, 1])
        else:
            xlim, ylim = bundled[layout_key].range('x'), bundled[layout_key].range('y')

        xlim, ylim = pad(*xlim), pad(*ylim)  # for datashade

        # remove axes for datashade
        return nodes.opts(xlim=xlim, ylim=ylim, xaxis=None, yaxis=None, show_legend=legend_loc is not None)

    assert subsample in (None, 'none', 'datashade'), \
        f'Invalid subsampling strategy `{subsample}`. Possible values are None, \'none\', \'datashade\'.'

    if top_n_edges is not None:
        assert directed, '`top_n_edges` works only on directed graphs.'
        if isinstance(top_n_edges, (tuple, list)):
            assert len(top_n_edges) == 3, f'`top_n_edges` must be of length 3, found `{len(top_n_edges)}`.'
            assert isinstance(top_n_edges[0], int), f'`top_n_edges[0]` must be an int, found `{type(top_n_edges[0])}`.'
            assert isinstance(top_n_edges[1], bool), f'`top_n_edges[1]` must be a bool, found `{type(top_n_edges[1])}`.'
            assert top_n_edges[2] in ('in', 'out'), '`top_n_edges[2]` must be either \'in\' or \'out\'.'
        else:
            assert isinstance(top_n_edges, int), f'`top_n_edges` must be an int, found `{type(top_n_edges)}`.'

    if edge_cmap is None:
        edge_cmap = Viridis256
    if node_cmap is None:
        node_cmap = Sets1to3

    if color_key is not None:
        assert color_key in adata.obs or color_key in ('incoming', 'outgoing'), f'Color key `{color_key}` not found in `adata.obs` and is not \'incoming\' or \'outgoing\'.'

    if obs_keys is None:
        obs_keys = adata.obs.keys()
    else:
        for obs_key in obs_keys:
            assert obs_key in adata.obs.keys(), f'Key `{obs_key}` not found in `adata.obs`.'

    if key.startswith('p:') or key.startswith('n:'):
        which, key = key.split(':', 1)
    elif key == 'paga':  # QOL
        which, key = 'p', 'connectivities'
    else:
        which = None

    paga_pos = None
    is_paga = False
    data = None
    if which is None:
        if key in adata.uns.keys():
            data = adata.uns[key]
        elif hasattr(adata, 'obsp') and key in adata.obsp:
            data = adata.obsp[key]
    elif which == 'n' and key in adata.uns['neighbors'].keys():
        data = adata.uns['neighbors'][key]
    elif which == 'p' and key in adata.uns['paga'].keys():
        data = adata.uns['paga'][key]
        is_paga = True
        directed = False
        if 'pos' in adata.uns['paga'].keys():
            paga_pos = adata.uns['paga']['pos']
    if data is None:
        raise ValueError(f'Key `{key}` not found in `adata.uns`, `adata.obsp`, '
                         '`adata.uns[\'neighbors\']` or `adata.uns[\'paga\']`. '
                         'To visualize the graphs in `uns[\'neighbors\']` or `uns[\'paga\']`, '
                         'prefix the key with `n:` or `p:`, respectively (e.g. `n:connectivities`).')
    assert data.ndim == 2, f'Adjacency matrix must be dimension of `2`, found `{data.ndim}`.'
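The graph-key lookup in `graph` reduces to a small prefix parser. The stdlib-only sketch below restates it under the same rules:

- `p:` selects `adata.uns['paga']`.
- `n:` selects `adata.uns['neighbors']`.
- A bare `'paga'` is shorthand for `p:connectivities`.
- Anything else is looked up in `adata.uns` or `adata.obsp`.

`resolve_graph_key` is a hypothetical name, not an API of this package.

```python
from typing import Optional, Tuple


def resolve_graph_key(key: str) -> Tuple[Optional[str], str]:
    """Split a graph key into (namespace, name).

    namespace is 'p' for adata.uns['paga'], 'n' for adata.uns['neighbors'],
    or None for a plain key looked up in adata.uns / adata.obsp.
    """
    if key.startswith(('p:', 'n:')):
        # maxsplit=1 keeps any further ':' inside the name
        which, name = key.split(':', 1)
        return which, name
    if key == 'paga':
        # shorthand for the PAGA connectivities graph
        return 'p', 'connectivities'
    return None, key


for k in ('paga', 'n:connectivities', 'my_graph'):
    print(k, '->', resolve_graph_key(k))
```

Splitting with `maxsplit=1` matters for names that themselves contain `:`. A plain `key.split(':')` would produce more than two parts, and unpacking them into two names raises a `ValueError`.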
    assert data.shape[0] == data.shape[1], f'Adjacency matrix is not square, found shape `{data.shape}`.'

    if ixs is None or (is_paga and not force_paga_indices):
        ixs = np.arange(data.shape[0])
    else:
        assert np.min(ixs) >= 0
        assert np.max(ixs) < adata.shape[0]

    data = data[ixs, :][:, ixs]
    adata_ss = adata[ixs, :] if not is_paga or (len(ixs) != data.shape[0] and force_paga_indices) else adata

    if layouts is None:
        layouts = list(DEFAULT_LAYOUTS.keys())
    if isinstance(layouts, str):
        layouts = [layouts]
    for l in layouts:
        assert l in DEFAULT_LAYOUTS.keys(), f'Unknown layout `{l}`. Available layouts are `{list(DEFAULT_LAYOUTS.keys())}`.'

    if np.min(data) < 0 and 'kamada_kawai' in layouts:
        warnings.warn('`kamada_kawai` layout requires non-negative edges, removing it from the list of possible layouts.')
        layouts.remove('kamada_kawai')

    if basis is None:
        basis = np.ravel(sorted(filter(len, map(BS_PAT.findall, adata.obsm.keys()))))
    elif isinstance(basis, str):
        basis = [basis]
    if not isinstance(basis, np.ndarray):
        basis = np.array(basis)

    if not isinstance(components, np.ndarray):
        components = np.array(components)
    if components.ndim == 1:
        components = np.repeat(components[np.newaxis, :], len(basis), axis=0)
    if len(basis):
        components[np.where(basis != 'diffmap')] -= 1

    if is_paga:
        g_name = adata.uns['paga']['groups']
        if color_key is None or color_key != g_name:
            warnings.warn(f'Color key `{color_key}` differs from PAGA\'s groups `{g_name}`, setting it to `{g_name}`.')
            color_key = g_name
        if len(basis):
            warnings.warn(f'Cannot plot PAGA in the basis `{basis}`, removing them from layouts.')
            basis, components = [], []

    for bs, comp in zip(basis, components):
        assert f'X_{bs}' in adata.obsm.keys(), f'`X_{bs}` not found in `adata.obsm`'
        shape = adata.obsm[f'X_{bs}'].shape
        assert shape[-1] > np.max(comp), f'Requested invalid components `{list(comp)}` for basis `X_{bs}` with shape `{shape}`.'

    if paga_pos is not None:
        basis = ['paga']
        components = [0, 1]
    get_component = dict(zip(basis, components))

    is_categorical = False
    if color_key is not None:
        node_cmap = adata_ss.uns[f'{color_key}_colors'] if f'{color_key}_colors' in adata_ss.uns else node_cmap
        if color_key in adata.obs:
            color_vals = adata_ss.obs[color_key]
            if is_categorical_dtype(color_vals) or is_string_dtype(adata.obs[color_key]):
                color_vals = color_vals.astype('category').cat.categories
                is_categorical = True
                node_cmap = odict(zip(color_vals, to_hex_palette(node_cmap)))
            else:
                color_vals = adata_ss.obs[color_key].values
        else:
            color_vals = np.array(color_key_reduction(data, axis=int(color_key == 'outgoing'))).flatten()

    if not is_categorical:
        legend_loc = None

    layouts = np.append(basis, layouts)
    if len(layouts) == 0:
        warnings.warn('Nothing to plot, no layouts found.')
        return

    # because of the categories
    graph = create_graph(adata_ss, data=data)

    kdims = [hv.Dimension('Layout', values=layouts)]
    g = hv.DynamicMap(partial(embed_graph, graph=graph), kdims=kdims).opts(axiswise=True, framewise=True)  # necessary as well

    if subsample != 'datashade':
        for layout_key in layouts:
            bs_key = f'X_{layout_key}'
            if bs_key in adata.obsm.keys():
                emb = adata_ss.obsm[bs_key][:, get_component[layout_key]]
                emb = normalize(emb)
                xlim = minmax(emb[:, 0])
                ylim = minmax(emb[:, 1])
            elif layout_key == 'paga':
                xlim = minmax(paga_pos[:, 0])
                ylim = minmax(paga_pos[:, 1])
            else:
                xlim, ylim = g[layout_key].range('x'), g[layout_key].range('y')
            xlim, ylim = pad(*xlim), pad(*ylim)
            g[layout_key].opts(xlim=xlim, ylim=ylim)  # other layouts are not normalized

    bundled = bundle_graph(g, **bundle_kwargs, weight=None) if bundle else g.clone()
    nodes = hv.DynamicMap(get_nodes, kdims=kdims).opts(axiswise=True, framewise=True)  # needed for datashade

    if subsample == 'datashade':
        g = (datashade(bundled, normalization='linear', color_key=color_edges_by, min_alpha=128,
                       cmap='black' if color_edges_by is None else edge_cmap,
                       streams=[hv.streams.RangeXY(transient=True), hv.streams.PlotSize]))
        res = (g * nodes).opts(height=plot_height, width=plot_width).opts(
            hv.opts.Nodes(size=node_size, tools=['hover'], cmap=node_cmap,
                          fill_color='orange' if color_key is None else color_key)
        )
    else:
        res = bundled.opts(height=plot_height, width=plot_width).opts(
            hv.opts.Graph(
                node_size=node_size,
                node_fill_color='orange' if color_key is None else color_key,
                node_nonselection_alpha=0.05,
                edge_nonselection_alpha=0.05,
                edge_cmap=edge_cmap,
                node_cmap=node_cmap
            )
        )
        if legend_loc is not None and color_key is not None:
            res *= hv.NdOverlay({k: hv.Points([0, 0], label=str(k)).opts(size=0, color=v)
                                 for k, v in node_cmap.items()})

    if legend_loc is not None and color_key is not None:
        res = res.opts(legend_position=legend_loc)

    return res.opts(hv.opts.Graph(xaxis=None, yaxis=None))

--------------------------------------------------------------------------------
/interactive_plotting/utils/__init__.py:
--------------------------------------------------------------------------------
from interactive_plotting.utils._utils import *

--------------------------------------------------------------------------------
/interactive_plotting/utils/_utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

from functools import wraps
from collections.abc import Iterable  # `collections.Iterable` was removed in Python 3.10
from inspect import signature
from sklearn.neighbors import NearestNeighbors
from scipy.sparse import issparse

import anndata
import matplotlib.colors as colors
import scanpy as sc
import numpy as np
import pandas as pd
import networkx as nx
import panel as pn
import re
import itertools
import warnings


NO_SUBSAMPLE = (None, 'none')
SUBSAMPLING_STRATEGIES = ('datashade', 'decimate', 'density', 'uniform')
ALL_SUBSAMPLING_STRATEGIES = NO_SUBSAMPLE + SUBSAMPLING_STRATEGIES

SUBSAMPLE_THRESH = 30_000
HOLOMAP_THRESH = 50
OBSM_SEP = ':'

CBW = 10  # colorbar width
BS_PAT = re.compile('^X_(.+)')

# for graph
DEFAULT_LAYOUTS = {l.split('_layout')[0]: getattr(nx.layout, l)
                   for l in dir(nx.layout) if l.endswith('_layout')}
DEFAULT_LAYOUTS.pop('bipartite')
DEFAULT_LAYOUTS.pop('rescale')
DEFAULT_LAYOUTS.pop('spectral')


class SamplingLazyDict(dict):

    def __init__(self, adata, subsample, *args, callback_kwargs={}, **kwargs):
        super().__init__(*args, **kwargs)
        self.adata = adata
        self.callback_kwargs = callback_kwargs

        if subsample == 'uniform':
            self.callback = sample_unif
        elif subsample == 'density':
            self.callback = sample_density
        else:
            ixs = list(range(adata.n_obs))
            self.callback = lambda *args, **kwargs: (adata, ixs)

    def __getitem__(self, key):
        if key not in self:
            bs, comps = key
            rev_comps = comps[::-1]

            if (bs, rev_comps) in self.keys():
                res, ixs = self[bs, rev_comps]
            else:
                res, ixs = self.callback(self.adata, bs=bs, components=comps, **self.callback_kwargs)

            self[key] = res, ixs

            return res, ixs

        return super().__getitem__(key)


def to_hex_palette(palette, normalize=True):
    """
    Converts matplotlib color array to hex strings
    """
    if not isinstance(palette, np.ndarray):
        palette = np.array(palette)

    if isinstance(palette[0], str):
        assert all(map(colors.is_color_like, palette)), 'Not all strings are color like.'
        return palette

    if normalize:
        minn = np.min(palette)
        # normalize to [0, 1]
        palette = (palette - minn) / (np.max(palette) - minn)

    return [colors.to_hex(c) if colors.is_color_like(c) else c for c in palette]


def pad(minn, maxx, padding=0.05):
    if minn > maxx:
        maxx, minn = minn, maxx
    return minn - padding, maxx + padding


def iterable(obj):
    '''
    Checks whether the object is iterable non-string.

    Params
    --------
    obj: Object
        Python object

    Returns
    --------
    is_iterable: Bool
        whether the object is not `str` and is
        instance of class `Iterable`
    '''

    return not isinstance(obj, str) and isinstance(obj, Iterable)


def istype(obj):
    '''
    Checks whether the object is of class `type`.

    Params
    --------
    obj: Union[Object, Tuple]
        Python object or a tuple

    Returns
    --------
    is_type: Bool
        `True` if the object is an instance of class `type` or
        all the elements of the tuple are of class `type`
    '''

    return isinstance(obj, type) or (isinstance(obj, tuple) and all(map(lambda o: isinstance(o, type), obj)))


def is_numeric(obj):
    '''
    Checks whether the object supports basic arithmetic.

    Params
    --------
    obj: Object
        Python object

    Returns
    --------
    is_numeric: Bool
        `True` if the object defines `__add__`, `__sub__`, `__mul__`,
        `__truediv__` and `__pow__`, else `False`
    '''
    return all(hasattr(obj, attr)
               for attr in ('__add__', '__sub__', '__mul__', '__truediv__', '__pow__'))


def is_categorical(obj):
    '''
    Is the object categorical?

    Params
    --------
    obj: Python object
        object that has attribute `'dtype'`

    Returns
    --------
    is_categorical: Bool
        `True` if it's categorical else `False`
    '''

    return obj.dtype.name == 'category'


def minmax(component, perc=None, is_sorted=False):
    '''
    Get the minimum and maximum value of an array.

    Params
    --------
    component: Union[np.ndarray, List, Tuple]
        1-D array
    perc: Union[List[Float], Tuple[Float]], optional (default: `None`)
        clip the values by the percentiles
    is_sorted: Bool, optional (default: `False`)
        whether the component is already sorted,
        if `True`, min and max are the first and last
        elements respectively

    Returns
    --------
    min_max: Tuple[Float, Float]
        minimum and maximum values that are not NaN
    '''
    if perc is not None:
        assert len(perc) == 2, 'Percentile must be of length 2.'
        component = np.clip(component, *np.percentile(component, sorted(perc)))

    return (np.nanmin(component), np.nanmax(component)) if not is_sorted else (component[0], component[-1])


def skip_or_filter(adata, needles, haystack, where='', dtype=None,
                   skip=False, warn=True, ignore_after=None):
    '''
    Find all the needles in a given haystack.

    Params
    --------
    adata: anndata.AnnData
        anndata object
    needles: List[Str]
        keys to search for
    haystack: Iterable
        collection to search, e.g. `adata.obs.keys()`
    where: Str, optional (default: `''`)
        attribute of `anndata.AnnData` where to look, e.g. `'obsm'`
    dtype: Union[Callable, Type, Tuple[Type], Str]
        expected datatype of the needles
    skip: Bool, optional (default: `False`)
        whether to skip the needles which do not have
        the expected `dtype`
    warn: Bool
        whether to issue a warning if `skip=True`
    ignore_after: Str, optional (default: `None`)
        token used for extracting the actual key name
        from the needle of form `KeyTokenIndex`, needed when
        `where='obsm'`, useful e.g. for extracting specific components

    Returns
    --------
    found_needles: List[Str]
        list of all the needles in haystack
        if `skip=False`, will throw a `RuntimeError`
        if the needle's type differs from `dtype`
    '''

    needles_f = list(map(lambda n: n[:n.find(ignore_after)], needles)) if ignore_after is not None else needles
    res = []

    for n, nf in zip(needles, needles_f):
        if nf not in haystack:
            msg = f'`{nf}` not found in `adata.{where}.keys()`.'
            if not skip:
                assert False, msg
            if warn:
                warnings.warn(msg + ' Skipping.')
            continue

        if dtype is None and n == nf:
            # no type checking requested and no component to parse
            res.append(n)
            continue

        col = getattr(adata, where)[nf]
        val = col[0] if isinstance(col, np.ndarray) else col.iloc[0]  # np.ndarray or pd.Series
        if n != nf:
            assert where == 'obsm', f'Indexing is only supported for `adata.obsm`, found `{nf}` in `adata.{where}`.'
            _, ix = n.split(ignore_after)
            assert nf == _, 'Unable to parse input.'
            val = val[int(ix)]

        msg = None
        is_tup = isinstance(dtype, tuple)
        if dtype is None:
            pass  # no type checking
        elif isinstance(dtype, type) or is_tup:
            if not isinstance(val, dtype):
                types = dtype.__name__ if not is_tup else f"Union[{', '.join(map(lambda d: d.__name__, dtype))}]"
                msg = f'Expected `{nf}` to be of type `{types}`, found `{type(val).__name__}`.'
        elif callable(dtype):
            if not dtype(val):
                msg = f'`{nf}` did not pass the type checking of `{getattr(dtype, "__name__", repr(dtype))}`.'
        else:
            assert isinstance(dtype, str)
            if not dtype == col.dtype.name:
                msg = f'Expected `{nf}` to be of type `{dtype}`, found `{col.dtype.name}`.'

        if msg is not None:
            if not skip:
                raise RuntimeError(msg)
            if warn:
                warnings.warn(msg + ' Skipping.')
            continue

        res.append(n)

    return res


def has_attributes(*args, **kwargs):
    '''
    Params
    --------
    *args: variable length arguments
        names of parameters that must be present
        in the signature of the wrapped function
    **kwargs: keyword arguments
        attributes to check, keys will be interpreted
        as arguments in function signature to check
        and values are lists annotated as follows:
        `[, 'a:', '', '', 'a:', ...]`
        using type `None` will result in no type checking

    Returns
    --------
    wrapped: Callable
        function, which does the checks at runtime
    '''

    def inner(fn):
        '''
        Binds the arguments of the function and checks the types.
299 | 300 | Params 301 | -------- 302 | fn: Callable 303 | function to wrap 304 | 305 | Returns 306 | -------- 307 | wrapped: Callable 308 | the wrapped function 309 | ''' 310 | 311 | @wraps(fn) 312 | def inner2(*fargs, **fkwargs): 313 | bound = sig.bind(*fargs, **fkwargs) 314 | bound.apply_defaults() 315 | 316 | for k, v in kwargs.items(): 317 | if isinstance(v, type) or istype(v): 318 | assert isinstance(bound.arguments[k], v), f'Argument: `{k}` must be of type: `{v}`.' 319 | elif iterable(v): 320 | if not iterable(v[0]): 321 | v = [v] 322 | 323 | for vals in v: 324 | typp = None 325 | if vals[0] is None or istype(vals[0]): 326 | typp, *vals = vals 327 | if not vals[0].startswith('a:'): 328 | raise ValueError('The first element must be an attribute ' 329 | f'annotated with: `a:`, found: `{vals[0]}`. ' 330 | f'Consider using: `a:{vals[0]}`.') 331 | 332 | obj = None 333 | for val in vals: 334 | if val.startswith('a:'): 335 | obj = getattr(obj if obj is not None else bound.arguments[k], val[2:]) 336 | else: 337 | assert obj is not None 338 | obj = obj[val] 339 | 340 | if typp is not None: 341 | assert isinstance(obj, typp) 342 | else: 343 | raise RuntimeError(f'Unable to decode invariant: `{k}={v}`.') 344 | 345 | return fn(*fargs, **fkwargs) 346 | 347 | sig = signature(fn) 348 | for param in tuple(kwargs.keys()) + args: 349 | if not param in sig.parameters.keys(): 350 | raise ValueError(f'Parameter `{param}` not found in the signature.') 351 | 352 | return inner2 353 | 354 | return inner 355 | 356 | 357 | def wrap_as_panel(fn): 358 | ''' 359 | Wrap the widget inside a panel. 
360 | 361 | Params 362 | -------- 363 | fn: Callable 364 | function that returns a plot, such as `scatter` 365 | 366 | Returns 367 | -------- 368 | wrapper: Callable 369 | function which returns an object of type `pn.panel` 370 | ''' 371 | 372 | @wraps(fn) 373 | def inner(*args, **kwargs): 374 | reverse = kwargs.pop('reverse', True) 375 | res = fn(*args, **kwargs) 376 | if res is None: 377 | return None 378 | 379 | res = pn.panel(res) 380 | if reverse and hasattr(res, 'reverse'): 381 | res.reverse() 382 | 383 | return res 384 | 385 | return inner 386 | 387 | 388 | def wrap_as_col(fn): 389 | ''' 390 | Wrap the widget in a column, placing its 391 | input widgets in rows of (at most) three. 392 | 393 | Params 394 | -------- 395 | fn: Callable 396 | function that returns a plot, such as `dpt` 397 | 398 | Returns 399 | -------- 400 | wrapped: Callable 401 | function which returns an object of type `pn.Column` 402 | ''' 403 | 404 | def chunkify(l, n): 405 | for i in range(0, len(l), n): 406 | yield l[i: i + n] 407 | 408 | @wraps(fn) 409 | def inner(*args, **kwargs): 410 | reverse = kwargs.pop('reverse', True) 411 | res = fn(*args, **kwargs) 412 | if res is None: 413 | return None 414 | 415 | res = pn.panel(res) 416 | if reverse and hasattr(res, 'reverse'): 417 | res.reverse() 418 | 419 | widgets = list(map(lambda w: pn.Row(*w), filter(len, chunkify(res[0], 3)))) 420 | return pn.Column(*(widgets + [res[1]])) 421 | 422 | return inner 423 | 424 | 425 | def get_data(adata, needle, ignore_after=OBSM_SEP, haystacks=['obs', 'obsm', 'var_names']): 426 | f''' 427 | Search for a needle in multiple haystacks. 428 | 429 | Params 430 | -------- 431 | adata: anndata.AnnData 432 | anndata object 433 | needle: Str 434 | needle to search for 435 | ignore_after: Str, optional (default: `'{OBSM_SEP}'`) 436 | token used for extracting the actual key name 437 | from the needle of form `KeyTokenIndex`, needed 438 | when `'obsm' in haystacks`, useful e.g.
for extracting specific components 439 | haystacks: List[Str], optional (default: `['obs', 'obsm', 'var_names']`) 440 | attributes of `anndata.AnnData` to search, in order 441 | 442 | Returns 443 | -------- 444 | (result, is_categorical): Tuple[Object, Bool] 445 | the found object and whether it's categorical 446 | ''' 447 | 448 | for haystack in haystacks: 449 | obj = getattr(adata, haystack) 450 | if ignore_after in needle and haystack == 'obsm': 451 | k, ix = needle.split(ignore_after) 452 | ix = int(ix) 453 | else: 454 | k, ix = needle, None 455 | 456 | if k in obj: 457 | res = obj[k] if haystack != 'var_names' else adata.obs_vector(k) 458 | if ix is not None: 459 | assert res.ndim == 2, f'`adata.{haystack}[{k}]` must have 2 dimensions, found `{res.ndim}`.' 460 | assert res.shape[-1] > ix, f'Index `{ix}` out of bounds for `adata.{haystack}[{k}]` of shape `{res.shape}`.' 461 | res = res[:, ix] 462 | if res.shape != (adata.n_obs, ): 463 | msg = f'`{needle}` in `adata.{haystack}` has wrong shape of `{res.shape}`.' 464 | if haystack == 'obsm': 465 | msg += f' Try using `{needle}{OBSM_SEP}ix`, 0 <= `ix` < {res.shape[-1]}.' 466 | raise ValueError(msg) 467 | 468 | return res, is_categorical(res) 469 | 470 | raise ValueError(f'Unable to find `{needle}` in `adata.{haystacks}`.') 471 | 472 | 473 | def get_all_obsm_keys(adata, ixs): 474 | if not isinstance(ixs, (tuple, list)): 475 | ixs = [ixs] 476 | 477 | assert all(map(lambda ix: ix >= 0, ixs)), 'All indices must be non-negative.'
478 | 479 | return list(itertools.chain.from_iterable((f'{key}{OBSM_SEP}{ix}' 480 | for key in adata.obsm.keys() if isinstance(adata.obsm[key], np.ndarray) and adata.obsm[key].ndim == 2 and adata.obsm[key].shape[-1] > ix) 481 | for ix in ixs)) 482 | # based on: 483 | # https://github.com/velocyto-team/velocyto-notebooks/blob/master/python/DentateGyrus.ipynb 484 | def sample_unif(adata, steps, bs='umap', components=(0, 1)): 485 | if not isinstance(steps, (tuple, list)): 486 | steps = [steps] * len(components) 487 | 488 | assert len(components) 489 | assert min(components) >= 0 490 | 491 | embedding = adata.obsm[f'X_{bs}'][:, components] 492 | n_dim = len(components) 493 | 494 | grs = [] 495 | for i in range(n_dim): 496 | m, M = np.min(embedding[:, i]), np.max(embedding[:, i]) 497 | m = m - 0.025 * np.abs(M - m) 498 | M = M + 0.025 * np.abs(M - m) 499 | gr = np.linspace(m, M, num=steps[i]) 500 | grs.append(gr) 501 | 502 | meshes_tuple = np.meshgrid(*grs) 503 | gridpoints_coordinates = np.vstack([i.flat for i in meshes_tuple]).T 504 | 505 | nn = NearestNeighbors() 506 | nn.fit(embedding) 507 | dist, ixs = nn.kneighbors(gridpoints_coordinates, 1) 508 | 509 | diag_step_dist = np.linalg.norm([grs[dim_i][1] - grs[dim_i][0] for dim_i in range(n_dim)]) 510 | min_dist = diag_step_dist / 2 511 | 512 | ixs = ixs[dist < min_dist] 513 | ixs = np.unique(ixs) 514 | 515 | return adata[ixs, :].copy(), ixs 516 | 517 | 518 | def sample_density(adata, size, bs='umap', seed=None, components=(0, 1)): 519 | if size >= adata.n_obs: 520 | return adata, list(range(adata.n_obs)) # same `(adata, ixs)` shape as the return below 521 | 522 | if components[0] == components[1]: 523 | tmp = pd.DataFrame(np.ones(adata.n_obs) / adata.n_obs, columns=['prob_density']) 524 | else: 525 | # should be unique, using it only once since we cache the results 526 | # we don't need to add the components 527 | key_added = f'{bs}_density_ipl_tmp' 528 | remove_key = False # we may be operating on original object, keep it clean 529 | if key_added not in adata.obs.keys(): 530 |
sc.tl.embedding_density(adata, bs, key_added=key_added) 531 | remove_key = True 532 | tmp = pd.DataFrame(np.exp(adata.obs[key_added]) / np.sum(np.exp(adata.obs[key_added]))) 533 | tmp.rename(columns={key_added: 'prob_density'}, inplace=True) 534 | if remove_key: 535 | del adata.obs[key_added] 536 | 537 | state = np.random.RandomState(seed) 538 | ixs = sorted(state.choice(range(adata.n_obs), size=size, p=tmp['prob_density'], replace=False)) 539 | 540 | return adata[ixs].copy(), ixs 541 | 542 | 543 | def get_xy_data(x, adata, adata_mraw, layer, indices, use_original_limits=False, inc=0): 544 | 545 | def extract(data, ix): 546 | if isinstance(data, anndata._core.anndata.Raw) or \ 547 | isinstance(data, anndata.AnnData): 548 | return data.obs_vector(ix)[indices] 549 | if issparse(data): 550 | return np.squeeze(data.getcol(ix).A)[indices] 551 | 552 | # assume np.array 553 | return data[indices, ix] 554 | 555 | xlim = None 556 | msg = f'Unable to decode key `{x}`.' 557 | 558 | if layer is None: 559 | take_from = adata_mraw 560 | else: 561 | # mraw is always adata in this case 562 | assert layer in adata.layers, f'Layer `\'{layer}\'` not found in `adata.layers`.' 563 | assert adata is adata_mraw, 'Sanity check failed.' 
564 | take_from = adata.layers[layer] 565 | 566 | if not isinstance(x, int): 567 | assert isinstance(x, str) 568 | # can't use `take_from` here, since it can be an array 569 | if x in adata_mraw.var_names: 570 | ix = np.where(x == adata_mraw.var_names)[0][0] 571 | xlabel = adata_mraw.var_names[ix] 572 | x = extract(take_from, ix) 573 | 574 | return x, xlabel, xlim 575 | 576 | if x in adata_mraw.obs_keys(): 577 | x, xlabel = adata_mraw.obs[x].values, x 578 | if use_original_limits: 579 | xlim = pad(*minmax(x)) 580 | 581 | return x[indices], xlabel, xlim 582 | 583 | x = x[2:] if x.startswith('X_') else x # remove the 'X_' prefix; `str.lstrip('X_')` strips characters, not a prefix 584 | comp, *ix = x.split(':') 585 | ix = int(ix[0]) if len(ix) else (int(x == 'diffmap') + inc) 586 | 587 | if f'X_{comp}' in adata.obsm: 588 | xlabel = f'{comp}_{ix}' 589 | x = adata.obsm[f'X_{comp}'][:, ix] 590 | if use_original_limits: 591 | xlim = pad(*minmax(x)) 592 | 593 | return x[indices], xlabel, xlim 594 | 595 | raise RuntimeError(msg) 596 | 597 | xlabel = adata_mraw.var_names[x] 598 | x = extract(take_from, x) 599 | 600 | return x, xlabel, xlim 601 | 602 | 603 | def get_mraw(adata, use_raw): 604 | if not use_raw: 605 | return adata 606 | 607 | return adata.raw if hasattr(adata, 'raw') and adata.raw is not None else adata 608 | 609 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scanpy>=1.4.3 2 | pandas>=0.23.4 3 | matplotlib>=3.0.2 4 | setuptools>=41.0.1 5 | scipy>=1.2.0 6 | numpy>=1.16.4 7 | scikit_learn>=0.21.3 8 | anndata>=0.7.3 9 | bokeh==1.4.0 10 | datashader==0.9.0 11 | panel==0.6.2 12 | holoviews==1.12.7 13 | -------------------------------------------------------------------------------- /resources/images/dpt_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/dpt_plot.png
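`sample_unif` in `_utils.py` follows the grid-based subsampling from the velocyto DentateGyrus notebook. It lays a regular grid over the embedding, finds the cell closest to each grid point, and keeps that cell only if it lies within half a grid diagonal. The sketch below restates the idea in plain NumPy under a few assumptions. `embedding` is a dense `(n_cells, n_dims)` array rather than an `AnnData` object. `sample_uniform_grid` is a hypothetical name. A brute-force distance matrix stands in for scikit-learn's `NearestNeighbors`, which is adequate for a few thousand cells. The 2.5% padding is applied symmetrically.

```python
import numpy as np


def sample_uniform_grid(embedding, steps, pad=0.025):
    """Return sorted, unique indices of cells that are nearest to some grid point.

    Hypothetical NumPy-only restatement of `sample_unif`; brute-force
    distances replace sklearn's `NearestNeighbors`.
    """
    embedding = np.asarray(embedding, dtype=float)
    n_dim = embedding.shape[1]
    if np.isscalar(steps):
        steps = [steps] * n_dim

    # one padded 1-D grid per embedding dimension
    grids = []
    for i in range(n_dim):
        lo, hi = embedding[:, i].min(), embedding[:, i].max()
        margin = pad * (hi - lo)
        grids.append(np.linspace(lo - margin, hi + margin, num=steps[i]))

    # Cartesian product of the 1-D grids -> (n_gridpoints, n_dim)
    gridpoints = np.vstack([m.ravel() for m in np.meshgrid(*grids)]).T

    # distance from every grid point to every cell; keep the closest cell per grid point
    dists = np.linalg.norm(gridpoints[:, None, :] - embedding[None, :, :], axis=-1)
    nearest = dists.argmin(axis=1)
    nearest_dist = dists[np.arange(len(gridpoints)), nearest]

    # drop grid points in empty regions: the closest cell must be within half a diagonal step
    diag_step = np.linalg.norm([g[1] - g[0] for g in grids])
    return np.unique(nearest[nearest_dist < diag_step / 2])


# usage: a dense blob plus a sparse cloud; the blob is thinned out far more
state = np.random.RandomState(0)
emb = np.vstack([state.normal(0.0, 0.1, size=(500, 2)),
                 state.normal(3.0, 1.0, size=(50, 2))])
ixs = sample_uniform_grid(emb, steps=20)
```

Each grid point contributes at most one cell, so dense regions collapse to a handful of representatives while sparse regions keep most of their cells. Subsetting the `AnnData` object (`adata[ixs, :].copy()` in the original) is left out of the sketch.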
-------------------------------------------------------------------------------- /resources/images/gene_trend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/gene_trend.png -------------------------------------------------------------------------------- /resources/images/graph_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/graph_plot.png -------------------------------------------------------------------------------- /resources/images/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/heatmap.png -------------------------------------------------------------------------------- /resources/images/highlight_de.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/highlight_de.png -------------------------------------------------------------------------------- /resources/images/inter_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/inter_hist.png -------------------------------------------------------------------------------- /resources/images/link_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/link_plot.png 
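`sample_density` in `_utils.py` turns the per-cell values written by `sc.tl.embedding_density` into sampling probabilities with `np.exp(...) / np.sum(np.exp(...))`, which is a softmax, and then draws cells without replacement. The NumPy sketch below shows only that sampling step. It assumes the per-cell scores are already available as a 1-D array, and `sample_by_score` is a hypothetical name. Subtracting the maximum before exponentiating avoids overflow without changing the probabilities.

```python
import numpy as np


def sample_by_score(scores, size, seed=None):
    """Draw `size` distinct indices with probability proportional to exp(score)."""
    scores = np.asarray(scores, dtype=float)
    n = scores.shape[0]
    if size >= n:
        # nothing to subsample: keep every index
        return np.arange(n)

    # softmax; subtracting the maximum avoids overflow and leaves p unchanged
    weights = np.exp(scores - scores.max())
    p = weights / weights.sum()

    # seeded RandomState, as in `sample_density`, for reproducible draws
    state = np.random.RandomState(seed)
    return np.sort(state.choice(n, size=size, p=p, replace=False))


# usage: cells with a higher score are more likely to be kept
scores = np.linspace(0.0, 1.0, num=1000)
ixs = sample_by_score(scores, size=100, seed=0)
```

Since scanpy scales its density values to [0, 1], the resulting weights differ by at most a factor of e. The bias toward dense regions is therefore mild.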
-------------------------------------------------------------------------------- /resources/images/scatter3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/scatter3d.png -------------------------------------------------------------------------------- /resources/images/scatter_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/scatter_cat.png -------------------------------------------------------------------------------- /resources/images/scatter_cont.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/scatter_cont.png -------------------------------------------------------------------------------- /resources/images/scatter_general1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/scatter_general1.png -------------------------------------------------------------------------------- /resources/images/scatter_general2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/scatter_general2.png -------------------------------------------------------------------------------- /resources/images/thresh_hist.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/thresh_hist.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | import os 4 | 5 | setup( 6 | name='Interactive Plotting', 7 | version='0.0.5', 8 | description='Interactive plotting functions for scanpy', 9 | url='https://github.com/theislab/interactive_plotting', 10 | license='MIT', 11 | packages=find_packages(), 12 | setup_requires=['setuptools_scm'], 13 | include_package_data=True, 14 | install_requires=list(map(str.strip, 15 | open(os.path.abspath('requirements.txt'), 'r').read().split())), 16 | 17 | zip_safe=False 18 | ) 19 | --------------------------------------------------------------------------------
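Several helpers in `_utils.py` (`get_data`, `get_all_obsm_keys`, `get_xy_data`) address a single column of a 2-D `adata.obsm` entry through a key of the form `KeyTokenIndex`, e.g. `'X_umap:1'`. The hypothetical `parse_obsm_key` below is a minimal sketch of that parsing. It assumes `':'` as the separator; the real value of `OBSM_SEP` is defined elsewhere in the module. When no index is given, it returns `None`. It also shows why the `'X_'` prefix is removed explicitly: `str.lstrip('X_')` strips any leading run of `'X'` and `'_'` characters, not the literal prefix.

```python
def parse_obsm_key(key, sep=':'):
    """Split a key such as 'X_umap:1' into ('umap', 1); the index is None if absent."""
    # remove the literal 'X_' prefix only; key.lstrip('X_') would also eat
    # leading 'X'/'_' characters of the basis name itself
    if key.startswith('X_'):
        key = key[len('X_'):]
    basis, found, index = key.partition(sep)
    return basis, (int(index) if found else None)


# usage
print(parse_obsm_key('X_umap:1'))      # ('umap', 1)
print(parse_obsm_key('diffmap'))       # ('diffmap', None)
print(parse_obsm_key('X_Xenopus:0'))   # ('Xenopus', 0)
print('X_Xenopus'.lstrip('X_'))        # enopus
```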