├── .gitignore
├── README.md
├── interactive_plotting
├── __init__.py
├── bokeh_plots.py
├── experimental
│ ├── __init__.py
│ ├── plots.py
│ ├── scatter3d.py
│ └── surface3d.ts
├── holoviews_plots.py
└── utils
│ ├── __init__.py
│ └── _utils.py
├── notebooks
└── interactive_plotting_tutorial.ipynb
├── requirements.txt
├── resources
└── images
│ ├── dpt_plot.png
│ ├── gene_trend.png
│ ├── graph_plot.png
│ ├── heatmap.png
│ ├── highlight_de.png
│ ├── inter_hist.png
│ ├── link_plot.png
│ ├── scatter3d.png
│ ├── scatter_cat.png
│ ├── scatter_cont.png
│ ├── scatter_general1.png
│ ├── scatter_general2.png
│ └── thresh_hist.png
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore folders anywhere in the repo
2 | **/data
3 | **/.ipynb_checkpoints
4 | **/__pycache__
5 | **/.idea
6 |
7 | # ignore specific folders
8 | write_results/
9 | figures/
10 | dataframes/
11 | write_results/
12 |
13 | # ignore files anywhere in the repo
14 | paths.py
15 |
16 | # generic python stuff to ignore
17 | # Byte-compiled / optimized / DLL files
18 | __pycache__/
19 | *.py[cod]
20 | *$py.class
21 |
22 | # C extensions
23 | *.so
24 |
25 | # Distribution / packaging
26 | .Python
27 | build/
28 | develop-eggs/
29 | dist/
30 | downloads/
31 | eggs/
32 | .eggs/
33 | lib/
34 | lib64/
35 | parts/
36 | sdist/
37 | var/
38 | wheels/
39 | *.egg-info/
40 | .installed.cfg
41 | *.egg
42 | MANIFEST
43 |
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 |
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 |
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .coverage
58 | .Rhistory
59 | .pybiomart.sqlite
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *.cover
65 | .hypothesis/
66 | .pytest_cache/
67 |
68 | # Translations
69 | *.mo
70 | *.pot
71 |
72 | # Django stuff:
73 | *.log
74 | local_settings.py
75 | db.sqlite3
76 |
77 | # Flask stuff:
78 | instance/
79 | .webassets-cache
80 |
81 | # Scrapy stuff:
82 | .scrapy
83 |
84 | # Sphinx documentation
85 | docs/_build/
86 |
87 | # PyBuilder
88 | target/
89 |
90 | # Jupyter Notebook
91 | .ipynb_checkpoints
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # celery beat schedule file
97 | celerybeat-schedule
98 |
99 | # SageMath parsed files
100 | *.sage.py
101 |
102 | # Environments
103 | .env
104 | .venv
105 | env/
106 | venv/
107 | ENV/
108 | env.bak/
109 | venv.bak/
110 |
111 | # Spyder project settings
112 | .spyderproject
113 | .spyproject
114 |
115 | # Rope project settings
116 | .ropeproject
117 |
118 | # mkdocs documentation
119 | /site
120 |
121 | # mypy
122 | .mypy_cache/
123 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Interactive Plotting in Scanpy
2 |
3 |
4 | ## About
5 | This repository contains 11 different interactive plotting functions, which may be useful during exploratory analysis.
6 |
7 | Almost every function provides some information when hovering over the plot and some parts of the plots can be hidden by clicking the legend.
8 |
9 | ## Installation
10 | To install this package, do the following:
11 | ```bash
12 | conda install nodejs # >= v6.10.0, for 3D scatterplot
13 | pip install git+https://github.com/theislab/interactive_plotting
14 | ```
15 | For 3D scatterplot, `node.js >= v6.10.0` is required. Go to node's [website](https://nodejs.org/en/) for instructions on how to install it.
16 |
17 | ## Getting Started
18 | We recommend checking out the [tutorial notebook](./notebooks/interactive_plotting_tutorial.ipynb).
19 | ```ipl.scatter```, ```ipl.scatterc``` ```ipl.dpt``` can handle large number of cells (100K+).
20 |
21 | In your Jupyter Notebook, execute the following lines:
22 | ```python
23 | import holoviews as hv # needed for scatter, scatterc and dpt
24 | hv.extension('bokeh')
25 |
26 | import interactive_plotting as ipl
27 |
28 | from bokeh.io import output_notebook
29 | output_notebook()
30 | ```
31 |
32 | ## Gallery
33 | Here are some exemplary figures for each of the plotting functions.
34 | ```python
35 | ipl.ex.scatter
36 | ```
37 | 
38 |
39 | ---
40 |
41 | ```python
42 | ipl.ex.scatter3d
43 | ```
44 | 
45 |
46 | ---
47 |
48 | ```python
49 | ipl.ex.scatter
50 | ```
51 | 
52 |
53 | ---
54 |
55 | ```python
56 | ipl.scatter
57 | ```
58 | ")
59 |
60 | ---
61 |
62 | ```python
63 | ipl.scatterc
64 | ```
65 | ")
66 |
67 | ---
68 |
69 | ```python
70 | ipl.ex.heatmap
71 | ```
72 | 
73 |
74 | ---
75 |
76 | ```python
77 | ipl.dpt
78 | ```
79 | 
80 |
81 | ---
82 |
83 | ```python
84 | ipl.graph
85 | ```
86 | 
87 |
88 | ---
89 |
90 | ```python
91 | ipl.link_plot
92 | ```
93 | 
94 |
95 | ---
96 |
97 | ```python
98 | ipl.highlight_de
99 | ```
100 | 
101 |
102 | ---
103 |
104 | ```python
105 | ipl.gene_trend
106 | ```
107 | 
108 |
109 | ---
110 |
111 | ```python
112 | ipl.interactive_hist
113 | ```
114 | 
115 |
116 | ---
117 |
118 | ```python
119 | ipl.thresholding_hist
120 | ```
121 | 
122 |
123 | ## Troubleshooting
124 | * [Notebook size is **huge**](https://github.com/theislab/interactive_plotting/issues/2) - This has to do with ```ipl.link_plot``` and ```ipl.velocity_plot```. Until a fix is found, we suggest removing these figures after you're done using them.
125 | * [Getting "OPub data rate exceeded" error](https://github.com/theislab/interactive_plotting/issues/7) - Try starting jupyter notebook as following:
126 |
127 | ```jupyter notebook --NotebookApp.iopub_data_rate_limit=1e10```
128 |
129 | For generating jupyter config file, see [here](https://stackoverflow.com/questions/43288550/iopub-data-rate-exceeded-in-jupyter-notebook-when-viewing-image).
130 |
--------------------------------------------------------------------------------
/interactive_plotting/__init__.py:
--------------------------------------------------------------------------------
1 | from interactive_plotting.bokeh_plots import interactive_hist, \
2 | thresholding_hist, \
3 | highlight_de, \
4 | link_plot, \
5 | gene_trend
6 | from interactive_plotting.holoviews_plots import scatter, scatterc, dpt, graph
7 |
8 | import interactive_plotting.experimental as ex
9 |
--------------------------------------------------------------------------------
/interactive_plotting/bokeh_plots.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from sklearn.gaussian_process.kernels import *
4 | from sklearn import neighbors
5 | from scipy.sparse import issparse
6 | from scipy.spatial import distance_matrix, ConvexHull
7 |
8 | from functools import reduce
9 | from collections import defaultdict
10 | from itertools import product
11 |
12 | import warnings
13 |
14 | import numpy as np
15 | import pandas as pd
16 | import scanpy as sc
17 |
18 | import matplotlib.cm as cm
19 | import matplotlib
20 | import bokeh
21 |
22 |
23 | from interactive_plotting.utils._utils import sample_unif, sample_density, to_hex_palette
24 | from bokeh.plotting import figure, show, save as bokeh_save
25 | from bokeh.models import ColumnDataSource, Slider, HoverTool, ColorBar, \
26 | Patches, Legend, CustomJS, TextInput, LabelSet, Select
27 | from bokeh.models.ranges import Range1d
28 | from bokeh.models.mappers import CategoricalColorMapper, LinearColorMapper
29 | from bokeh.layouts import layout, column, row, GridSpec
30 | from bokeh.transform import linear_cmap, factor_mark, factor_cmap
31 | from bokeh.core.enums import MarkerType
32 | from bokeh.palettes import Set1, Set2, Set3, inferno, viridis
33 | from bokeh.models.widgets.buttons import Button
34 |
35 | _bokeh_version = tuple(map(int, bokeh.__version__.split('.')))
36 |
37 |
38 | _inter_hist_js_code="""
39 | // here is where original data is stored
40 | var x = orig.data['values'];
41 |
42 | x = x.sort((a, b) => a - b);
43 | var n_bins = parseInt(bins.value); // can be either string or int
44 | var bin_size = (x[x.length - 1] - x[0]) / n_bins;
45 |
46 | var hist = new Array(n_bins).fill().map((_, i) => { return 0; });
47 | var l_edges = new Array(n_bins).fill().map((_, i) => { return x[0] + bin_size * i; });
48 | var r_edges = new Array(n_bins).fill().map((_, i) => { return x[0] + bin_size * (i + 1); });
49 | var indices = new Array(n_bins).fill().map((_) => { return []; });
50 |
51 | // create the histogram
52 | for (var i = 0; i < x.length; i++) {
53 | for (var j = 0; j < r_edges.length; j++) {
54 | if (x[i] <= r_edges[j]) {
55 | hist[j] += 1;
56 | indices[j].push(i);
57 | break;
58 | }
59 | }
60 | }
61 |
62 | // make it a density
63 | var sum = hist.reduce((a, b) => a + b, 0);
64 | var deltas = r_edges.map((c, i) => { return c - l_edges[i]; });
65 | // just like in numpy
66 | hist = hist.map((c, i) => { return c / deltas[i] / sum; });
67 |
68 | source.data['hist'] = hist;
69 | source.data['l_edges'] = l_edges;
70 | source.data['r_edges'] = r_edges;
71 | source.data['indices'] = indices;
72 |
73 | source.change.emit();
74 | """
75 |
76 |
77 | def _inter_color_code(*colors):
78 | assert len(colors) > 0, 'Doesn\'t make sense using no colors.'
79 | color_code = '\n'.join((f'renderers[i].glyph.{c} = {{field: cb_obj.value, transform: transform}};'
80 | for c in colors))
81 | return f"""
82 | var transform = cmaps[cb_obj.value]['transform'];
83 | var low = Math.min.apply(Math, source.data[cb_obj.value]);
84 | var high = Math.max.apply(Math, source.data[cb_obj.value]);
85 |
86 | for (var i = 0; i < renderers.length; i++) {{
87 | {color_code}
88 | }}
89 |
90 | color_bar.color_mapper.low = low;
91 | color_bar.color_mapper.high = high;
92 | color_bar.color_mapper.palette = transform.palette;
93 |
94 | source.change.emit();
95 | """
96 |
97 |
98 | def _set_plot_wh(fig, w, h):
99 | if w is not None:
100 | fig.plot_width = w
101 | if h is not None:
102 | fig.plot_height = h
103 |
104 |
105 | def _create_mapper(adata, key):
106 | """
107 | Helper function to create CategoricalColorMapper from annotated data.
108 |
109 | Params
110 | --------
111 | adata: AnnData
112 | annotated data object
113 | key: str
114 | key in `adata.obs.obs_keys()` or `adata.var_names`, for which we want the colors; if no colors for given
115 | column are found in `adata.uns[key_colors]`, use `viridis` palette
116 |
117 | Returns
118 | --------
119 | mapper: bokeh.models.mappers.CategoricalColorMapper
120 | mapper which maps valuems from `adata.obs[key]` to colors
121 | """
122 | if not key in adata.obs_keys():
123 | assert key in adata.var_names, f'`{key}` not found in `adata.obs_keys()` or `adata.var_names`'
124 | ix = np.where(adata.var_names == key)[0][0]
125 | vals = list(adata.X[:, ix])
126 | palette = cm.get_cmap('viridis', adata.n_obs)
127 |
128 | mapper = dict(zip(vals, range(len(vals))))
129 | palette = to_hex_palette(palette([mapper[v] for v in vals]))
130 |
131 | return LinearColorMapper(palette=palette, low=np.min(vals), high=np.max(vals))
132 |
133 | is_categorical = adata.obs[key].dtype.name == 'category'
134 | default_palette = cm.get_cmap('viridis', adata.n_obs if not is_categorical else len(adata.obs[key].unique()))
135 | palette = adata.uns.get(f'{key}_colors', default_palette)
136 |
137 | if palette is default_palette:
138 | vals = adata.obs[key].unique() if is_categorical else adata.obs[key]
139 | mapper = dict(zip(vals, range(len(vals))))
140 | palette = palette([mapper[v] for v in vals])
141 |
142 | palette = to_hex_palette(palette)
143 |
144 | if is_categorical:
145 | return CategoricalColorMapper(palette=palette, factors=list(map(str, adata.obs[key].cat.categories)))
146 |
147 | return LinearColorMapper(palette=palette, low=np.min(adata.obs[key]), high=np.max(adata.obs[key]))
148 |
149 |
150 | def _smooth_expression(x, y, n_points=100, time_span=[None, None], mode='gp', kernel_params=dict(), kernel_default_params=dict(),
151 | kernel_expr=None, default=False, verbose=False, **opt_params):
152 | """Smooth out the expression of given values.
153 |
154 | Params
155 | --------
156 | x: list(number)
157 | list of features
158 | y: list(number)
159 | list of targets
160 | n_points: int, optional (default: `100`)
161 | number of points to extrapolate
162 | time_span: list(int), optional (default `[None, None]`)
163 | initial and final start values for range
164 | mode: str, optional (default: `'gp'`)
165 | which regressor to use, available (`'gp'`: Gaussian Process, `'krr'`: Kernel Ridge Regression)
166 | kernel_params: dict, optional (default: `dict()`)
167 | dictionary of kernels with their parameters, keys correspond to variable names
168 | which can be later combined using `kernel_expr`. Supported kernels: `ConstantKernel`, `WhiteKernel`,
169 | `RBF`, `Mattern`, `RationalQuadratic`, `ExpSineSquared`, `DotProduct`, `PairWiseKernel`.
170 | kernel_default_params: dict, optional (default: `dict()`)
171 | default parameters for a kernel, if not found in `kernel_params`
172 | kernel_expr: str, default (`None`)
173 | expression to combine kernel variables specified in `kernel_params`. Supported operators are `+`, `*`, `**`;
174 | example: kernel_expr=`'(a + b) ** 2'`, kernel_params=`{'a': ConstantKernel(1), 'b': DotProduct(2)}`
175 | default: bool, optional (default: `False`)
176 | whether to use default kernel (RBF), if none specified and/or to use default
177 | parameters for kernel variables in` kernel_expr`, not found in `kernel_params`
178 | if False, throws an Exception
179 | verbose: bool, optional (default: `False`)
180 | be more verbose
181 | **opt_params: kwargs
182 | keyword arguments for optimizer
183 |
184 | Returns
185 | --------
186 | x_test: np.array
187 | points for which we predict the values
188 | x_mean: np.array
189 | mean of the response
190 | cov: np.array (`None` for mode=`'krr'`)
191 | covariance matrix of the response
192 | """
193 |
194 | from sklearn.kernel_ridge import KernelRidge
195 | from sklearn.gaussian_process import GaussianProcessRegressor
196 | import operator as op
197 | import ast
198 |
199 | def _eval(node):
200 | if isinstance(node, ast.Num):
201 | return node.n
202 |
203 | if isinstance(node, ast.Name):
204 | if not default and node.id not in kernel_params:
205 | raise ValueError(f'Error while parsing `{kernel_expr}`: `{node.id}` is not a valid key in kernel_params. To use RBF kernel with default parameters, specify default=True.')
206 | params = kernel_params.get(node.id, kernel_default_params)
207 | kernel_type = params.pop('type', 'rbf')
208 | return kernels[kernel_type](**params)
209 |
210 | if isinstance(node, ast.BinOp):
211 | return operators[type(node.op)](_eval(node.left), _eval(node.right))
212 |
213 | if isinstance(node, ast.UnaryOp):
214 | return operators[type(node.op)](_eval(node.operand))
215 |
216 | raise TypeError(node)
217 |
218 | operators = {ast.Add : op.add,
219 | ast.Mult: op.mul,
220 | ast.Pow :op.pow}
221 | kernels = dict(const=ConstantKernel,
222 | white=WhiteKernel,
223 | rbf=RBF,
224 | mat=Matern,
225 | rq=RationalQuadratic,
226 | esn=ExpSineSquared,
227 | dp=DotProduct,
228 | pw=PairwiseKernel)
229 |
230 | minn, maxx = time_span
231 | x_test = np.linspace(np.min(x) if minn is None else minn, np.max(x) if maxx is None else maxx, n_points)[:, None]
232 |
233 | if mode == 'krr':
234 | gamma = opt_params.pop('gamma', None)
235 |
236 | if gamma is None:
237 | length_scale = kernel_default_params.get('length_scale', 0.2)
238 | gamma = 1 / (2 * length_scale ** 2)
239 | if verbose:
240 | print(f'Smoothing using KRR with length_scale: {length_scale}.')
241 |
242 | kernel = opt_params.pop('kernel', 'rbf')
243 | model = KernelRidge(gamma=gamma, kernel=kernel, **opt_params)
244 | model.fit(x, y)
245 |
246 | return x_test, model.predict(x_test), [None] * n_points
247 |
248 | if mode == 'gp':
249 |
250 | if kernel_expr is None:
251 | assert len(kernel_params) == 1
252 | kernel_expr, = kernel_params.keys()
253 |
254 | kernel = _eval(ast.parse(kernel_expr, mode='eval').body)
255 | alpha = opt_params.pop('alpha', None)
256 | if alpha is None:
257 | alpha = np.std(y)
258 |
259 | optimizer = opt_params.pop('optimizer', None)
260 | opt_params['kernel'] = kernel
261 |
262 | model = GaussianProcessRegressor(alpha=alpha, optimizer=optimizer, **opt_params)
263 | model.fit(x, y)
264 |
265 | mean, cov = model.predict(x_test, return_cov=True)
266 |
267 | return x_test, mean, cov
268 |
269 | raise ValueError(f'Uknown type: `{type}`.')
270 |
271 |
272 | def _create_gt_fig(adatas, dataframe, color_key, title, color_mapper, show_cont_annot=False,
273 | use_raw=True, genes=[], legend_loc='top_right',
274 | plot_width=None, plot_height=None):
275 | """
276 | Helper function which create a figure with smoothed velocities, including
277 | confidence intervals, if possible.
278 |
279 | Params:
280 | --------
281 | dataframe: pandas.DataFrame
282 | dataframe containing the velocity data
283 | color_key: str
284 | column in `dataframe` that is to be mapped to colors
285 | title: str
286 | title of the figure
287 | color_mapper: bokeh.models.mappers.CategoricalColorMapper
288 | transformation which assings a value from `dataframe[color_key]` to a color
289 | show_cont_annot: bool, optional (default: `False`)
290 | show continuous annotations in `adata.obs`, if `color_key` is
291 | itself a continuous variable
292 | use_raw: bool, optional (default: `True`)
293 | whether to use adata.raw to get the expression
294 | genes: list, optional (default: `[]`)
295 | list of possible genes to color in,
296 | only works if `color_key` is continuous variable
297 | legend_loc: str, default(`'top_right'`)
298 | position of the legend
299 | plot_width: int, optional (default: `None`)
300 | width of the plot
301 | plot_height: int, optional (default: `None`)
302 | height of the plot
303 |
304 | Returns:
305 | --------
306 | fig: bokeh.plotting.figure
307 | figure containing the plot
308 | """
309 |
310 | # these markers are nearly indistinguishble
311 | markers = [marker for marker in MarkerType if marker not in ['circle_cross', 'circle_x']]
312 | fig = figure(title=title)
313 | _set_plot_wh(fig, plot_width, plot_height)
314 |
315 | renderers, color_selects = [], []
316 | for i, (adata, marker, (path, df)) in enumerate(zip(adatas, markers, dataframe.iterrows())):
317 | ds = {'dpt': df['dpt'],
318 | 'expr': df['expr'],
319 | f'{color_key}': df[color_key]}
320 | is_categorical = color_key in adata.obs_keys() and adata.obs[color_key].dtype.name == 'category'
321 | if not is_categorical:
322 | ds, mappers = _get_mappers(adata, ds, genes, use_raw=use_raw)
323 |
324 | source = ColumnDataSource(ds)
325 | renderers.append(fig.scatter('dpt', 'expr', source=source,
326 | color={'field': color_key, 'transform': color_mapper if is_categorical else mappers[color_key]['transform']},
327 | fill_color={'field': color_key, 'transform': color_mapper if is_categorical else mappers[color_key]['transform']},
328 | line_color={'field': color_key, 'transform': color_mapper if is_categorical else mappers[color_key]['transform']},
329 | marker=marker, size=10, legend_label=path, muted_alpha=0))
330 |
331 | fig.xaxis.axis_label = 'dpt'
332 | fig.yaxis.axis_label = 'expression'
333 | if legend_loc is not None:
334 | fig.legend.location = legend_loc
335 |
336 | if not is_categorical and show_cont_annot:
337 | color_selects.append(_add_color_select(color_key, fig, [renderers[-1]], source, mappers, suffix=f' [{path}]'))
338 |
339 | ds = dict(df[['x_test', 'x_mean', 'x_cov']])
340 | if ds.get('x_test') is not None:
341 | if ds.get('x_mean') is not None:
342 | source = ColumnDataSource(ds)
343 | fig.line('x_test', 'x_mean', source=source, muted_alpha=0, legend_label=path)
344 | if all(map(lambda val: val is not None, ds.get('x_cov', [None]))):
345 | x_mean = ds['x_mean']
346 | x_cov = ds['x_cov']
347 | band_x = np.append(ds['x_test'][::-1], ds['x_test'])
348 | # black magic, known only to the most illustrious of wizards
349 | band_y = np.append((x_mean - np.sqrt(np.diag(x_cov)))[::-1], (x_mean + np.sqrt(np.diag(x_cov))))
350 | fig.patch(band_x, band_y, alpha=0.1, line_color='black', fill_color='black',
351 | legend_label=path, line_dash='dotdash', muted_alpha=0)
352 |
353 | if ds.get('x_grad') is not None:
354 | fig.line('x_test', 'x_grad', source=source, muted_alpha=0)
355 |
356 | fig.legend.click_policy = 'mute'
357 |
358 | return column(fig, *color_selects)
359 |
360 |
361 | def interactive_hist(adata, keys=['n_counts', 'n_genes'],
362 | bins='auto', max_bins=100,
363 | groups=None, fill_alpha=0.4,
364 | palette=None, display_all=True,
365 | tools='pan, reset, wheel_zoom, save',
366 | legend_loc='top_right',
367 | plot_width=None, plot_height=None, save=None,
368 | *args, **kwargs):
369 | """Utility function to plot distributions with variable number of bins.
370 |
371 | Params
372 | --------
373 | adata: AnnData object
374 | annotated data object
375 | keys: list(str), optional (default: `['n_counts', 'n_genes']`)
376 | keys in `adata.obs` or `adata.var` where the distibutions are stored
377 | bins: int; str, optional (default: `auto`)
378 | number of bins used for plotting or str from numpy.histogram
379 | max_bins: int, optional (default: `1000`)
380 | maximum number of bins possible
381 | groups: list(str), (default: `None`)
382 | keys in `adata.obs.obs_keys()`, groups by all possible combinations of values, e.g. for
383 | 3 plates and 2 time points, we would create total of 6 groups
384 | fill_alpha: float[0.0, 1.0], (default: `0.4`)
385 | alpha channel of the fill color
386 | palette: list(str), optional (default: `None`)
387 | palette to use
388 | display_all: bool, optional (default: `True`)
389 | display the statistics for all data
390 | tools: str, optional (default: `'pan,reset, wheel_zoom, save'`)
391 | palette of interactive tools for the user
392 | legend_loc: str, (default: `'top_right'`)
393 | position of the legend
394 | legend_loc: str, default(`'top_left'`)
395 | position of the legend
396 | plot_width: int, optional (default: `None`)
397 | width of the plot
398 | plot_height: int, optional (default: `None`)
399 | height of the plot
400 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`)
401 | path where to save the plot
402 | *args, **kwargs: arguments, keyword arguments
403 | addition argument to bokeh.models.figure
404 |
405 | Returns
406 | --------
407 | None
408 | """
409 |
410 | if max_bins < 1:
411 | raise ValueError(f'`max_bins` must >= 1')
412 |
413 | palette = Set1[9] + Set2[8] + Set3[12] if palette is None else palette
414 |
415 | # check the input
416 | for key in keys:
417 | if key not in adata.obs.keys() and \
418 | key not in adata.var.keys() and \
419 | key not in adata.var_names:
420 | raise ValueError(f'The key `{key}` does not exist in `adata.obs`, `adata.var` or `adata.var_names`.')
421 |
422 | def _create_adata_groups():
423 | if groups is None:
424 | return [adata], [('all',)]
425 |
426 | combs = list(product(*[set(adata.obs[g]) for g in groups]))
427 | adatas= [adata[reduce(lambda l, r: l & r,
428 | (adata.obs[k] == v for k, v in zip(groups, vals)), True)]
429 | for vals in combs] + [adata]
430 |
431 | if display_all:
432 | combs += [('all',)]
433 | adatas += [adata]
434 |
435 | return adatas, combs
436 |
437 | # group_v_combs contains the value combinations
438 | ad_gs = _create_adata_groups()
439 |
440 | cols = []
441 | for key in keys:
442 | callbacks = []
443 | fig = figure(*args, tools=tools, **kwargs)
444 | slider = Slider(start=1, end=max_bins, value=0, step=1,
445 | title='Bins')
446 |
447 | plots = []
448 | for j, (ad, group_vs) in enumerate(filter(lambda ad_g: ad_g[0].n_obs > 0, zip(*ad_gs))):
449 |
450 | if key in ad.obs.keys():
451 | orig = ad.obs[key]
452 | hist, edges = np.histogram(orig, density=True, bins=bins)
453 | elif key in ad.var.keys():
454 | orig = ad.var[key]
455 | hist, edges = np.histogram(orig, density=True, bins=bins)
456 | else:
457 | orig = ad[:, key].X
458 | hist, edges = np.histogram(orig, density=True, bins=bins)
459 |
460 | slider.value = len(hist)
461 | # case when automatic bins
462 | max_bins = max(max_bins, slider.value)
463 |
464 | # original data, used for recalculation of histogram in JS code
465 | orig = ColumnDataSource(data=dict(values=orig))
466 | # data that we update in JS code
467 | source = ColumnDataSource(data=dict(hist=hist, l_edges=edges[:-1], r_edges=edges[1:]))
468 |
469 | legend = ', '.join(': '.join(map(str, gv)) for gv in zip(groups, group_vs)) \
470 | if groups is not None else 'all'
471 | p = fig.quad(source=source, top='hist', bottom=0,
472 | left='l_edges', right='r_edges',
473 | fill_color=palette[j], legend_label=legend if legend_loc is not None else None,
474 | muted_alpha=0,
475 | line_color="#555555", fill_alpha=fill_alpha)
476 |
477 | # create callback and slider
478 | callback = CustomJS(args=dict(source=source, orig=orig), code=_inter_hist_js_code)
479 | callback.args['bins'] = slider
480 | callbacks.append(callback)
481 |
482 | # add the current plot so that we can set it
483 | # visible/invisible in JS code
484 | plots.append(p)
485 |
486 | slider.end = max_bins
487 |
488 | # slider now updates all values
489 | slider.js_on_change('value', *callbacks)
490 |
491 | button = Button(label='Toggle', button_type='primary')
492 | button.callback = CustomJS(
493 | args={'plots': plots},
494 | code='''
495 | for (var i = 0; i < plots.length; i++) {
496 | plots[i].muted = !plots[i].muted;
497 | }
498 | '''
499 | )
500 |
501 | if legend_loc is not None:
502 | fig.legend.location = legend_loc
503 | fig.legend.click_policy = 'mute'
504 |
505 | fig.xaxis.axis_label = key
506 | fig.yaxis.axis_label = 'normalized frequency'
507 | _set_plot_wh(fig, plot_width, plot_height)
508 |
509 | cols.append(column(slider, button, fig))
510 |
511 | if _bokeh_version > (1, 0, 4):
512 | from bokeh.layouts import grid
513 | plot = grid(children=cols, ncols=2)
514 | else:
515 | cols = list(map(list, np.array_split(cols, np.ceil(len(cols) / 2))))
516 | plot = layout(children=cols, sizing_mode='fixed', ncols=2)
517 |
518 | if save is not None:
519 | save = save if str(save).endswith('.html') else str(save) + '.html'
520 | bokeh_save(plot, save)
521 | else:
522 | show(plot)
523 |
524 |
525 | def thresholding_hist(adata, key, categories, basis=['umap'], components=[1, 2],
526 | bins='auto', palette=None, legend_loc='top_right',
527 | plot_width=None, plot_height=None, save=None):
528 | """Histogram with the option to highlight categories based on thresholding binned values.
529 |
530 | Params
531 | --------
532 | adata: AnnData object
533 | annotated data object
534 | key: str
535 | key in `adata.obs_keys()` where the data is stored
536 | categories: dict
537 | dictionary with keys corresponding to group names and values to starting boundaries `[min, max]`
538 | basis: list, optional (default: `['umap']`)
539 | basis in `adata.obsm_keys()` to visualize
540 | components: list(int); list(list(int)), optional (default: `[1, 2]`)
541 | components to use for each basis
542 | bins: int; str, optional (default: `auto`)
543 | number of bins used for initial binning or a string key used in from numpy.histogram
544 | palette: list(str), optional (default: `None`)
545 | palette to use for coloring categories
546 | legend_loc: str, default(`'top_right'`)
547 | position of the legend
548 | plot_width: int, optional (default: `None`)
549 | width of the plot
550 | plot_height: int, optional (default: `None`)
551 | height of the plot
552 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`)
553 | path where to save the plot
554 |
555 | Returns
556 | --------
557 | None
558 | """
559 |
560 | if not isinstance(components[0], list):
561 | components = [components]
562 |
563 | if len(components) != len(basis):
564 | assert len(basis) % len(components) == 0 and len(basis) >= len(components)
565 | components = components * (len(basis) // len(components))
566 |
567 | if not isinstance(components, np.ndarray):
568 | components = np.asarray(components)
569 |
570 | if not isinstance(basis, list):
571 | basis = [basis]
572 |
573 | palette = Set1[9] + Set2[8] + Set3[12] if palette is None else palette
574 |
575 | hist_fig = figure()
576 | _set_plot_wh(hist_fig, plot_width, plot_height)
577 |
578 | hist_fig.xaxis.axis_label = key
579 | hist_fig.yaxis.axis_label = 'normalized frequency'
580 | hist, edges = np.histogram(adata.obs[key], density=True, bins=bins)
581 |
582 | source = ColumnDataSource(data=dict(hist=hist, l_edges=edges[:-1], r_edges=edges[1:],
583 | category=['default'] * len(hist), indices=[[]] * len(hist)))
584 |
585 | df = pd.concat([pd.DataFrame(adata.obsm[f'X_{bs}'][:, comp - (bs != 'diffmap')], columns=[f'x_{bs}', f'y_{bs}'])
586 | for bs, comp in zip(basis, components)], axis=1)
587 | df['values'] = list(adata.obs[key])
588 | df['category'] = 'default'
589 | df['visible_category'] = 'default'
590 | df['cat_stack'] = [['default']] * len(df)
591 |
592 | orig = ColumnDataSource(df)
593 | color = dict(field='category', transform=CategoricalColorMapper(palette=palette, factors=list(categories.keys())))
594 | hist_fig.quad(source=source, top='hist', bottom=0,
595 | left='l_edges', right='r_edges', color=color,
596 | line_color="#555555", legend_group='category')
597 | if legend_loc is not None:
598 | hist_fig.legend.location = legend_loc
599 |
600 | emb_figs = []
601 | for bs, comp in zip(basis, components):
602 | fig = figure(title=bs)
603 |
604 | fig.xaxis.axis_label = f'{bs}_{comp[0]}'
605 | fig.yaxis.axis_label = f'{bs}_{comp[1]}'
606 | _set_plot_wh(fig, plot_width, plot_height)
607 |
608 | fig.scatter(f'x_{bs}', f'y_{bs}', source=orig, size=10, color=color, legend_group='category')
609 | if legend_loc is not None:
610 | fig.legend.location = legend_loc
611 |
612 | emb_figs.append(fig)
613 |
614 | inputs, category_cbs = [], []
615 | code_start, code_mid, code_thresh = [], [], []
616 | args = {'source': source, 'orig': orig}
617 |
618 | for col, cat_item in zip(palette, categories.items()):
619 | cat, (start, end) = cat_item
620 | inp_min = TextInput(name='test', value=f'{start}', title=f'{cat}/min')
621 | inp_max = TextInput(name='test', value=f'{end}', title=f'{cat}/max')
622 |
623 | code_start.append(f'''
624 | var min_{cat} = parseFloat(inp_min_{cat}.value);
625 | var max_{cat} = parseFloat(inp_max_{cat}.value);
626 | ''')
627 | code_mid.append(f'''
628 | var mid_{cat} = (source.data['r_edges'][i] - source.data['l_edges'][i]) / 2;
629 | ''')
630 | code_thresh.append(f'''
631 | if (source.data['l_edges'][i] + mid_{cat} >= min_{cat} && source.data['r_edges'][i] - mid_{cat} <= max_{cat}) {{
632 | source.data['category'][i] = '{cat}';
633 | for (var j = 0; j < source.data['indices'][i].length; j++) {{
634 | var ix = source.data['indices'][i][j];
635 | orig.data['category'][ix] = '{cat}';
636 | }}
637 | }}
638 | ''')
639 | args[f'inp_min_{cat}'] = inp_min
640 | args[f'inp_max_{cat}'] = inp_max
641 | min_ds = ColumnDataSource(dict(xs=[start] * 2))
642 | max_ds = ColumnDataSource(dict(xs=[end] * 2))
643 |
644 | inputs.extend([inp_min, inp_max])
645 |
646 | code_thresh.append(
647 | '''
648 | {
649 | source.data['category'][i] = 'default';
650 | for (var j = 0; j < source.data['indices'][i].length; j++) {
651 | var ix = source.data['indices'][i][j];
652 | orig.data['category'][ix] = 'default';
653 | }
654 | }
655 | ''')
656 | callback = CustomJS(args=args, code=f'''
657 | {';'.join(code_start)}
658 | for (var i = 0; i < source.data['hist'].length; i++) {{
659 | {';'.join(code_mid)}
660 | {' else '.join(code_thresh)}
661 | }}
662 | orig.change.emit();
663 | source.change.emit();
664 | ''')
665 |
666 | for input in inputs:
667 | input.js_on_change('value', callback)
668 |
669 | slider = Slider(start=1, end=100, value=len(hist), title='Bins')
670 | interactive_hist_cb = CustomJS(args={'source': source, 'orig': orig, 'bins': slider}, code=_inter_hist_js_code)
671 | slider.js_on_change('value', interactive_hist_cb, callback)
672 |
673 | plot = column(row(hist_fig, column(slider, *inputs)), *emb_figs)
674 |
675 | if save is not None:
676 | save = save if str(save).endswith('.html') else str(save) + '.html'
677 | bokeh_save(plot, save)
678 | else:
679 | show(plot)
680 |
681 |
682 | def gene_trend(adata, paths, genes=None, mode='gp', exp_key='X',
683 | separate_paths=False, show_cont_annot=False,
684 | extra_genes=[], n_points=100, show_zero_counts=True,
685 | time_span=[None, None], use_raw=True,
686 | n_velocity_genes=5, length_scale=0.2,
687 | path_key='louvain', color_key='louvain',
688 | share_y=True, legend_loc='top_right',
689 | plot_width=None, plot_height=None, save=None, **kwargs):
690 | """
691 | Function which shows expression levels as well as velocity per gene as a function of DPT.
692 |
693 | Params
694 | --------
695 | adata: AnnData
696 | annotated data object
697 | paths: list(list(str))
698 | different paths to visualize
699 | genes: list, optional (default: `None`)
700 | list of genes to show, if `None` take `n_velocity` genes
701 | from `adata.var['velocity_genes']`
702 | mode: str, optional (default: `'gp'`)
703 | whether to use Kernel Ridge Regression (`'krr'`) or a Gaussian Process (`'gp'`) for
704 | smoothing the expression values
705 | exp_key: str, optional (default: `'X'`)
706 | key from adata.layers or just `'X'` to get expression values
707 | separate_paths: bool, optional (default: `False`)
708 | whether to show each path for each gene in a separate plot
709 | show_cont_annot: bool, optional (default: `False`)
710 | show continuous annotations in `adata.obs`,
711 | only works if `color_key` is continuous variable
712 | extra_genes: list(str), optional (default: `[]`)
713 | list of possible genes to color in,
714 | only works if `color_key` is continuous variable
715 | n_points: int, optional (default: `100`)
716 | how many points to use for the smoothing
717 | time_span: list(int), optional (default `[None, None]`)
718 | initial and final start values for range, `None` corresponds to min/max
719 | use_raw: bool, optional (default: `True`)
720 | whether to use adata.raw to get the expression
721 | show_zero_counts: bool, optional (default: `True`)
722 | whether to show cells with zero counts
723 | n_velocity_genes: int, optional (default: `5`)
724 | number of genes to take from` adata.var['velocity_genes']`
725 | length_scale : float, optional (default `0.2`)
726 | length scale for RBF kernel
727 | path_key: str, optional (default: `'louvain'`)
728 | key in `adata.obs_keys()` where to look for groups specified in `paths` argument
729 | color_key: str, optional (default: `'louvain'`)
730 | key in `adata.obs_keys()` which is color in plot
731 | share_y: bool, optional (default: `True`)
732 | whether to share y-axis when plotting paths separately
733 | legend_loc: str, default(`'top_right'`)
734 | position of the legend
735 | plot_width: int, optional (default: `None`)
736 | width of the plot
737 | plot_height: int, optional (default: `None`)
738 | height of the plot
739 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`)
740 | path where to save the plot
741 | **kwargs: kwargs
742 | keyword arguments for KRR or GP
743 |
744 | Returns
745 | --------
746 | None
747 | """
748 |
749 | if mode == 'krr':
750 | warnings.warn('KRR is experimental; please consider using mode=`gp`')
751 |
752 | for path in paths:
753 | for p in path:
754 | assert p in adata.obs[path_key].cat.categories, f'`{p}` is not in `adata.obs[path_key]`. Possible values are: `{list(adata.obs[path_key].cat.categories)}`.'
755 |
756 | # check the input
757 | if 'dpt_pseudotime' not in adata.obs.keys():
758 | raise ValueError('`dpt_pseudotime` is not in `adata.obs.keys()`')
759 |
760 | # check the genes list
761 | if genes is None:
762 | genes = adata[:, adata.var['velocity_genes']].var_names[:n_velocity_genes]
763 |
764 | genes_indicator = np.in1d(genes, adata.var_names) #[gene in adata.var_names for gene in genes]
765 | if not all(genes_indicator):
766 | genes_missing = np.array(genes)[np.invert(genes_indicator)]
767 | print(f'Could not find the following genes: `{genes_missing}`.')
768 | genes = list(np.array(genes)[genes_indicator])
769 |
770 | mapper = _create_mapper(adata, color_key)
771 | figs, adatas = [], []
772 |
773 | for gene in genes:
774 | data = defaultdict(list)
775 | row_figs = []
776 | y_lim_min, y_lim_max = np.inf, -np.inf
777 | for path in paths:
778 | path_ix = np.in1d(adata.obs[path_key], path)
779 | ad = adata[path_ix].copy()
780 |
781 | minn, maxx = time_span
782 | ad.obs['dpt_pseudotime'] = ad.obs['dpt_pseudotime'].replace(np.inf, 1)
783 | minn = np.min(ad.obs['dpt_pseudotime']) if minn is None else minn
784 | maxx = np.max(ad.obs['dpt_pseudotime']) if maxx is None else maxx
785 |
786 | # wish I could get rid of this copy
787 | ad = ad[(ad.obs['dpt_pseudotime'] >= minn) & (ad.obs['dpt_pseudotime'] <= maxx)]
788 |
789 | gene_exp = ad[:, gene].layers[exp_key] if exp_key != 'X' else (ad.raw if use_raw else ad)[:, gene].X
790 |
791 | # exclude dropouts
792 | ix = (gene_exp > 0).squeeze()
793 | indexer = slice(None) if show_zero_counts else ix
794 | # just use for sanity check with asserts
795 | rev_indexer = ix if show_zero_counts else slice(None)
796 |
797 | dpt = ad.obs['dpt_pseudotime']
798 |
799 | if issparse(gene_exp):
800 | gene_exp = gene_exp.A
801 |
802 | gene_exp = np.squeeze(gene_exp[indexer, None])
803 | data['expr'].append(gene_exp)
804 | y_lim_min, y_lim_max = min(y_lim_min, np.min(gene_exp)), max(y_lim_max, np.max(gene_exp))
805 |
806 | # compute smoothed values from expression
807 | data['dpt'].append(np.squeeze(dpt[indexer, None]))
808 | data[color_key].append(np.array(ad[indexer].obs[color_key]))
809 |
810 | assert all(gene_exp[rev_indexer] > 0)
811 |
812 | if len(gene_exp[rev_indexer]) == 0:
813 | print(f'All counts are 0 for: `{gene}`.')
814 | continue
815 |
816 | x_test, exp_mean, exp_cov = _smooth_expression(np.expand_dims(dpt[ix], -1), gene_exp[ix if show_zero_counts else slice(None)], mode=mode,
817 | time_span=time_span, n_points=n_points, kernel_params=dict(k=dict(length_scale=length_scale)),
818 | **kwargs)
819 |
820 | data['x_test'].append(x_test)
821 | data['x_mean'].append(exp_mean)
822 | data['x_cov'].append(exp_cov)
823 |
824 | # we need this for the _create mapper
825 | adatas.append(ad[indexer])
826 |
827 | if separate_paths:
828 | dataframe = pd.DataFrame(data, index=list(map(lambda path: ', '.join(map(str, path)), [path])))
829 | row_figs.append(_create_gt_fig(adatas, dataframe, color_key, title=gene, color_mapper=mapper,
830 | show_cont_annot=show_cont_annot, legend_loc=legend_loc, genes=extra_genes,
831 | use_raw=use_raw, plot_width=plot_width, plot_height=plot_height))
832 | adatas = []
833 | data = defaultdict(list)
834 |
835 | if separate_paths:
836 | if share_y:
837 | # first child is the figure
838 | for fig in map(lambda c: c.children[0], row_figs):
839 | fig.y_range = Range1d(y_lim_min - 0.1, y_lim_max + 0.1)
840 |
841 | figs.append(row(row_figs))
842 | row_figs = []
843 | else:
844 | dataframe = pd.DataFrame(data, index=list(map(lambda path: ', '.join(map(str, path)), paths)))
845 | figs.append(_create_gt_fig(adatas, dataframe, color_key, title=gene, color_mapper=mapper,
846 | show_cont_annot=show_cont_annot, legend_loc=legend_loc, genes=extra_genes,
847 | use_raw=use_raw, plot_width=plot_width, plot_height=plot_height))
848 |
849 | plot = column(*figs)
850 |
851 | if save is not None:
852 | save = save if str(save).endswith('.html') else str(save) + '.html'
853 | bokeh_save(plot, save)
854 | else:
855 | show(plot)
856 |
857 |
858 | def highlight_de(adata, basis='umap', components=[1, 2], n_top_genes=10,
859 | de_keys='names, scores, pvals_adj, logfoldchanges',
860 | cell_keys='', n_neighbors=5, fill_alpha=0.1, show_hull=True,
861 | legend_loc='top_right', plot_width=None, plot_height=None, save=None):
862 | """
863 | Highlight differential expression by hovering over clusters.
864 |
865 | Params
866 | --------
867 | adata: AnnData
868 | annotated data object
869 | basis: str, optional (default: `'umap'`)
870 | basis used in visualization
871 | components: list(int), optional (default: `[1, 2]`)
872 | components of the basis
873 | n_top_genes: int, optional (default: `10`)
874 | number of differentially expressed genes to display
875 | de_keys: list(str); str, optional (default: `'names, scores, pvals_ads, logfoldchanges'`)
876 | list or comma-seperated values of keys in `adata.uns['rank_genes_groups'].keys()`
877 | to be displayed for each cluster
878 | cell_keys: list(str); str, optional (default: '')
879 | keys in `adata.obs_keys()` to be displayed
880 | n_neighbors: int, optional (default: `5`)
881 | number of neighbors for KNN classifier, which
882 | controls how the convex hull looks like
883 | fill_alpha: float, optional (default: `0.1`)
884 | alpha value of the cluster colors
885 | show_hull: bool, optional (default: `True`)
886 | show the convex hull along each cluster
887 | legend_loc: str, default(`'top_right'`)
888 | position of the legend
889 | plot_width: int, optional (default: `None`)
890 | width of the plot
891 | plot_height: int, optional (default: `None`)
892 | height of the plot
893 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`)
894 | path where to save the plot
895 |
896 | Returns
897 | --------
898 | None
899 | """
900 |
901 | if 'rank_genes_groups' not in adata.uns_keys():
902 | raise ValueError('Run differential expression first.')
903 |
904 |
905 | if isinstance(de_keys, str):
906 | de_keys = list(dict.fromkeys(map(str.strip, de_keys.split(','))))
907 | if de_keys != ['']:
908 | assert all(map(lambda k: k in adata.uns['rank_genes_groups'].keys(), de_keys)), 'Not all keys are in `adata.uns[\'rank_genes_groups\']`.'
909 | else:
910 | de_keys = []
911 |
912 | if isinstance(cell_keys, str):
913 | cell_keys = list(dict.fromkeys(map(str.strip, cell_keys.split(','))))
914 | if cell_keys != ['']:
915 | assert all(map(lambda k: k in adata.obs.keys(), cell_keys)), 'Not all keys are in `adata.obs.keys()`.'
916 | else:
917 | cell_keys = []
918 |
919 | if f'X_{basis}' not in adata.obsm.keys():
920 | raise ValueError(f'Key `X_{basis}` not found in adata.obsm.')
921 |
922 | if not isinstance(components, np.ndarray):
923 | components = np.asarray(components)
924 |
925 | key = adata.uns['rank_genes_groups']['params']['groupby']
926 | if key not in cell_keys:
927 | cell_keys.insert(0, key)
928 |
929 | df = pd.DataFrame(adata.obsm[f'X_{basis}'][:, components - (basis != 'diffmap')], columns=['x', 'y'])
930 | for k in cell_keys:
931 | df[k] = list(map(str, adata.obs[k]))
932 |
933 | knn = neighbors.KNeighborsClassifier(n_neighbors)
934 | knn.fit(df[['x', 'y']], adata.obs[key])
935 | df['prediction'] = knn.predict(df[['x', 'y']])
936 |
937 | conv_hulls = df[df[key] == df['prediction']].groupby(key).apply(lambda df: df.iloc[ConvexHull(np.vstack([df['x'], df['y']]).T).vertices])
938 |
939 | mapper = _create_mapper(adata, key)
940 | categories = adata.obs[key].cat.categories
941 | fig = figure(tools='pan, reset, wheel_zoom, lasso_select, save')
942 | _set_plot_wh(fig, plot_width, plot_height)
943 | legend_dict = defaultdict(list)
944 |
945 | for k in categories:
946 | d = df[df[key] == k]
947 | data_source = ColumnDataSource(d)
948 | legend_dict[k].append(fig.scatter('x', 'y', source=data_source, color={'field': key, 'transform': mapper}, size=5, muted_alpha=0))
949 |
950 | hover_cell = HoverTool(renderers=[r[0] for r in legend_dict.values()], tooltips=[(f'{key}', f'@{key}')] + [(f'{k}', f'@{k}') for k in cell_keys[1:]])
951 |
952 | c_hulls = conv_hulls.copy()
953 | de_possible = conv_hulls[key].isin(adata.uns['rank_genes_groups']['names'].dtype.names)
954 | ok_patches = []
955 | prev_cat = []
956 | for i, isin in enumerate((~de_possible, de_possible)):
957 | conv_hulls = c_hulls[isin]
958 |
959 | if len(conv_hulls) == 0:
960 | continue
961 |
962 | # must use 'group' instead of key since key is MultiIndex
963 | conv_hulls.rename(columns={'louvain': 'group'}, inplace=True)
964 | xs, ys, ks = zip(*conv_hulls.groupby('group').apply(lambda df: list(map(list, (df['x'], df['y'], df['group'])))))
965 | tmp_data = defaultdict(list)
966 | tmp_data['xs'] = xs
967 | tmp_data['ys'] = ys
968 | tmp_data[key] = list(map(lambda k: k[0], ks))
969 |
970 | if i == 1:
971 | ix = list(map(lambda k: adata.uns['rank_genes_groups']['names'].dtype.names.index(k), tmp_data[key]))
972 | for k in de_keys:
973 | tmp = np.array(list(zip(*adata.uns['rank_genes_groups'][k])))[ix, :n_top_genes]
974 | for j in range(n_top_genes):
975 | tmp_data[f'{k}_{j}'] = tmp[:, j]
976 |
977 | tmp_data = pd.DataFrame(tmp_data)
978 | for k in categories:
979 | d = tmp_data[tmp_data[key] == k]
980 | source = ColumnDataSource(d)
981 |
982 | patches = fig.patches('xs', 'ys', source=source, fill_alpha=fill_alpha, muted_alpha=0, hover_alpha=0.5,
983 | color={'field': key, 'transform': mapper} if (show_hull and i == 1) else None,
984 | hover_color={'field': key, 'transform': mapper} if (show_hull and i == 1) else None)
985 | legend_dict[k].append(patches)
986 | if i == 1:
987 | ok_patches.append(patches)
988 |
989 | hover_group = HoverTool(renderers=ok_patches, tooltips=[(f'{key}', f'@{key}'),
990 | ('groupby', adata.uns['rank_genes_groups']['params']['groupby']),
991 | ('reference', adata.uns['rank_genes_groups']['params']['reference']),
992 | ('rank', ' | '.join(de_keys))] + [(f'#{i + 1}', ' | '.join((f'@{k}_{i}' for k in de_keys))) for i in range(n_top_genes)]
993 | )
994 |
995 |
996 | fig.toolbar.active_inspect = [hover_group]
997 | if len(cell_keys) > 1:
998 | fig.add_tools(hover_group, hover_cell)
999 | else:
1000 | fig.add_tools(hover_group)
1001 |
1002 | if legend_loc is not None:
1003 | legend = Legend(items=list(legend_dict.items()), location=legend_loc)
1004 | fig.add_layout(legend)
1005 | fig.legend.click_policy = 'hide' # hide does disable hovering, whereas 'mute' does not
1006 |
1007 | fig.xaxis.axis_label = f'{basis}_{components[0]}'
1008 | fig.yaxis.axis_label = f'{basis}_{components[1]}'
1009 |
1010 | if save is not None:
1011 | save = save if str(save).endswith('.html') else str(save) + '.html'
1012 | bokeh_save(fig, save)
1013 | else:
1014 | show(fig)
1015 |
1016 |
1017 | def link_plot(adata, key, genes=None, basis=['umap', 'pca'], components=[1, 2],
1018 | subsample=None, steps=[40, 40], sample_size=500,
1019 | distance=2, cutoff=True, highlight_only=None, palette=None,
1020 | show_legend=False, legend_loc='top_right', plot_width=None, plot_height=None, save=None):
1021 | """
1022 | Display the distances of cells from currently highlighted cell.
1023 |
1024 | Params
1025 | --------
1026 | adata: AnnData
1027 | annotated data object
1028 | key: str
1029 | key in `adata.obs_keys()` to color the static plot
1030 | genes: list(str), optional (default: `None`)
1031 | list of genes in `adata.var_names`,
1032 | which are used to compute the distance;
1033 | if None, take all the genes
1034 | basis: list(str), optional (default:`['umap', 'pca']`)
1035 | list of basis to use when plotting;
1036 | only the first plot is hoverable
1037 | components: list(int); list(list(int)), optional (default: `[1, 2]`)
1038 | list of components for each basis
1039 | subsample: str, optional (default: `None`)
1040 | subsample strategy to use when there are too many cells
1041 | possible values are: `"density"`, `"uniform"`, `None`
1042 | steps: int; list(int), optional (default: `[40, 40]`)
1043 | number of steps in each direction when using `subsample="uniform"`
1044 | sample_size: int, optional (default: `500`)
1045 | number of cells to sample based on their density in the respective embedding
1046 | when using `subsample="density"`; should be < `1000`
1047 | distance: int; str, optional (default: `2`)
1048 | for integers, use p-norm,
1049 | for strings, only `'dpt'` is available
1050 | cutoff: bool, optional (default: `True`)
1051 | if `True`, do not color cells whose distance is further away
1052 | than the threshold specified by the slider
1053 | highlight_only: 'str', optional (default: `None`)
1054 | key in `adata.obs_keys()`, which makes highlighting
1055 | work only on clusters specified by this parameter
1056 | palette: matplotlib.colors.Colormap; list(str), optional (default: `None`)
1057 | colormap to use, if None, use plt.cm.RdYlBu
1058 | show_legend: bool, optional (default: `False`)
1059 | display the legend also in the linked plot
1060 | legend_loc: str, optional (default `'top_right'`)
1061 | location of the legend
1062 | seed: int, optional (default: `None`)
1063 | seed when `subsample='density'`
1064 | plot_width: int, optional (default: `None`)
1065 | width of the plot
1066 | plot_height: int, optional (default: `None`)
1067 | height of the plot
1068 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`)
1069 | path where to save the plot
1070 |
1071 | Returns
1072 | --------
1073 | None
1074 | """
1075 |
1076 | assert key in adata.obs.keys(), f'`{key}` not found in `adata.obs`.'
1077 |
1078 | if subsample == 'uniform':
1079 | adata, _ = sample_unif(adata, steps, basis[0])
1080 | elif subsample == 'density':
1081 | adata, _ = sample_density(adata, sample_size, basis[0], seed=seed)
1082 | elif subsample is not None:
1083 | raise ValueError(f'Unknown subsample strategy: `{subsample}`.')
1084 |
1085 | palette = cm.RdYlBu if palette is None else palette
1086 | if isinstance(palette, matplotlib.colors.Colormap):
1087 | palette = to_hex_palette(palette(range(palette.N), 1., bytes=True))
1088 |
1089 | if not isinstance(components[0], list):
1090 | components = [components]
1091 |
1092 | if len(components) != len(basis):
1093 | assert len(basis) % len(components) == 0 and len(basis) >= len(components)
1094 | components = components * (len(basis) // len(components))
1095 |
1096 | if not isinstance(components, np.ndarray):
1097 | components = np.asarray(components)
1098 |
1099 | if highlight_only is not None:
1100 | assert highlight_only in adata.obs_keys(), f'`{highlight_only}` is not in adata.obs_keys().'
1101 |
1102 | genes = adata.var_names if genes is None else genes
1103 | gene_subset = np.in1d(adata.var_names, genes)
1104 |
1105 | if distance != 'dpt':
1106 | d = adata.X[:, gene_subset]
1107 | if issparse(d):
1108 | d = d.A
1109 | dmat = distance_matrix(d, d, p=distance)
1110 | else:
1111 | if not all(gene_subset):
1112 | warnings.warn('`genes` is not None, are you sure this is what you want when using `dpt` distance?')
1113 |
1114 | dmat = []
1115 | ad_tmp = adata.copy()
1116 | ad_tmp = ad_tmp[:, gene_subset]
1117 | for i in range(ad_tmp.n_obs):
1118 | ad_tmp.uns['iroot'] = i
1119 | sc.tl.dpt(ad_tmp)
1120 | dmat.append(list(ad_tmp.obs['dpt_pseudotime'].replace([np.nan, np.inf], [0, 1])))
1121 |
1122 | dmat = pd.DataFrame(dmat, columns=list(map(str, range(adata.n_obs))))
1123 | df = pd.concat([pd.DataFrame(adata.obsm[f'X_{bs}'][:, comp - (bs != 'diffmap')], columns=[f'x{i}', f'y{i}'])
1124 | for i, (bs, comp) in enumerate(zip(basis, components))] + [dmat], axis=1)
1125 | df['hl_color'] = np.nan
1126 | df['index'] = range(len(df))
1127 | df['hl_key'] = list(adata.obs[highlight_only]) if highlight_only is not None else 0
1128 | df[key] = list(map(str, adata.obs[key]))
1129 |
1130 | start_ix = '0' # our root cell
1131 | ds = ColumnDataSource(df)
1132 | mapper = linear_cmap(field_name='hl_color', palette=palette,
1133 | low=df[start_ix].min(), high=df[start_ix].max())
1134 | static_fig_mapper = _create_mapper(adata, key)
1135 |
1136 | static_figs = []
1137 | figs, renderers = [], []
1138 | for i, bs in enumerate(basis):
1139 | # linked plots
1140 | fig = figure(tools='pan, reset, save, ' + ('zoom_in, zoom_out' if i == 0 else 'wheel_zoom'),
1141 | title=bs, plot_width=400, plot_height=400)
1142 | _set_plot_wh(fig, plot_width, plot_height)
1143 |
1144 | kwargs = {}
1145 | if show_legend and legend_loc is not None:
1146 | kwargs['legend_group'] = 'hl_key' if highlight_only is not None else key
1147 |
1148 | scatter = fig.scatter(f'x{i}', f'y{i}', source=ds, line_color=mapper, color=mapper,
1149 | hover_color='black', size=8, line_width=8, line_alpha=0, **kwargs)
1150 | if show_legend and legend_loc is not None:
1151 | fig.legend.location = legend_loc
1152 |
1153 | figs.append(fig)
1154 | renderers.append(scatter)
1155 |
1156 | # static plots
1157 | fig = figure(title=bs, plot_width=400, plot_height=400)
1158 |
1159 | fig.scatter(f'x{i}', f'y{i}', source=ds, size=8,
1160 | color={'field': key, 'transform': static_fig_mapper}, **kwargs)
1161 |
1162 | if legend_loc is not None:
1163 | fig.legend.location = legend_loc
1164 |
1165 | static_figs.append(fig)
1166 |
1167 | fig = figs[0]
1168 |
1169 | end = dmat[~np.isinf(dmat)].max().max() if distance != 'dpt' else 1.0
1170 | slider = Slider(start=0, end=end, value=end / 2, step=end / 1000,
1171 | title='Distance ' + '(dpt)' if distance == 'dpt' else f'({distance}-norm)')
1172 | col_ds = ColumnDataSource(dict(value=[start_ix]))
1173 | update_color_code = f'''
1174 | source.data['hl_color'] = source.data[first].map(
1175 | (x, i) => {{ return isNaN(x) ||
1176 | {'x > slider.value || ' if cutoff else ''}
1177 | source.data['hl_key'][first] != source.data['hl_key'][i] ? NaN : x; }}
1178 | );
1179 | '''
1180 | slider.callback = CustomJS(args={'slider': slider, 'mapper': mapper['transform'], 'source': ds, 'col': col_ds}, code=f'''
1181 | mapper.high = slider.value;
1182 | var first = col.data['value'];
1183 | {update_color_code}
1184 | source.change.emit();
1185 | ''')
1186 |
1187 | h_tool = HoverTool(renderers=renderers, tooltips=[], show_arrow=False)
1188 | h_tool.callback = CustomJS(args=dict(source=ds, slider=slider, col=col_ds), code=f'''
1189 | var indices = cb_data.index['1d'].indices;
1190 | if (indices.length == 0) {{
1191 | source.data['hl_color'] = source.data['hl_color'];
1192 | }} else {{
1193 | var first = indices[0];
1194 | source.data['hl_color'] = source.data[first];
1195 | {update_color_code}
1196 | col.data['value'] = first;
1197 | col.change.emit();
1198 | }}
1199 | source.change.emit();
1200 | ''')
1201 | fig.add_tools(h_tool)
1202 |
1203 | color_bar = ColorBar(color_mapper=mapper['transform'], width=12, location=(0,0))
1204 | fig.add_layout(color_bar, 'left')
1205 |
1206 | fig.add_tools(h_tool)
1207 | plot = column(slider, row(*static_figs), row(*figs))
1208 |
1209 | if save is not None:
1210 | save = save if str(save).endswith('.html') else str(save) + '.html'
1211 | bokeh_save(plot, save)
1212 | else:
1213 | show(plot)
1214 |
1215 |
1216 | def _get_mappers(adata, df, genes=[], use_raw=True, sort=True):
1217 | if sort:
1218 | genes = sorted(genes)
1219 |
1220 | mappers = {c:{'field': c, 'transform': _create_mapper(adata, c)}
1221 | for c in (sorted if sort else list)(filter(lambda c: adata.obs[c].dtype.name != 'category', adata.obs.columns)) + genes}
1222 |
1223 | # assume all columns in .obs are numbers
1224 | for k in filter(lambda k: k not in genes, mappers.keys()):
1225 | df[k] = list(adata.obs[k].astype(float))
1226 |
1227 | indices, = np.where(np.in1d(adata.var_names, genes))
1228 | for ix in indices:
1229 | df[adata.var_names[ix]] = (adata.raw if use_raw else adata).X[:, ix]
1230 |
1231 | return df, mappers
1232 |
1233 |
1234 | def _add_color_select(key, fig, renderers, source, mappers, colors=['color', 'fill_color', 'line_color'],
1235 | color_bar_pos='right', suffix=''):
1236 | color_bar = ColorBar(color_mapper=mappers[key]['transform'], width=10, location=(0, 0))
1237 | fig.add_layout(color_bar, color_bar_pos)
1238 |
1239 | code = _inter_color_code(*colors)
1240 | callback= CustomJS(args=dict(renderers=renderers, source=source, color_bar=color_bar, cmaps=mappers),
1241 | code=code)
1242 |
1243 | return Select(title=f'Select variable to color{suffix}:', value=key,
1244 | options=list(mappers.keys()), callback=callback)
1245 |
--------------------------------------------------------------------------------
/interactive_plotting/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | from interactive_plotting.experimental.plots import heatmap, scatter2 as scatter
2 | from interactive_plotting.experimental.scatter3d import scatter3d
3 |
--------------------------------------------------------------------------------
/interactive_plotting/experimental/plots.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from interactive_plotting.utils._utils import *
4 |
5 | from pandas.api.types import is_categorical
6 | from collections import OrderedDict as odict
7 | from datashader.colors import *
8 | from bokeh.palettes import Viridis256
9 | from holoviews.streams import Selection1D
10 | from holoviews.operation import decimate
11 | from holoviews.operation.datashader import datashade, dynspread, rasterize
12 | from bokeh.models import HoverTool
13 |
14 | import numpy as np
15 | import pandas as pd
16 | import datashader as ds
17 | import holoviews as hv
18 | import warnings
19 |
20 |
21 | def pad(minn, maxx, padding=0.1):
22 | if minn > maxx:
23 | maxx, minn = minn, maxx
24 | delta = maxx - minn
25 |
26 | return minn - (delta * padding), maxx + (delta * padding)
27 |
28 |
29 | def minmax(component, perc=None, is_sorted=False):
30 | if perc is not None:
31 | assert len(perc) == 2, 'Percentile must be of length 2.'
32 | component = np.clip(component, *np.percentile(component, sorted(perc)))
33 |
34 | return (np.nanmin(component), np.nanmax(component)) if not is_sorted else (component[0], component[-1])
35 |
36 |
37 | def scatter2(adata, x, y, color, order_key=None, indices=None, layer=None, subsample='datashade', use_raw=False,
38 | size=5, jitter=None, perc=None, cmap=None,
39 | hover_keys=None, hover_dims=(10, 10), kde=None, density=None, density_size=150,
40 | keep_frac=0.2, steps=40, seed=None, use_original_limits=False,
41 | legend_loc='top_right', show_legend=True, plot_height=600, plot_width=600, save=None):
42 | '''
43 | Plot a scatterplot.
44 |
45 | Params
46 | -------
47 | adata: anndata.AnnData
48 | adata object
49 | x: Union[Int, Str]
50 | values on the x-axis
51 | if of type `NoneType`, x-axis will be based on `order_key`
52 | if of type `Int`, it corresponds to index in `adata.var_names`
53 | if of type `Str`, it can be either:
54 | - key in `adata.var_names`
55 | - key in `adata.obs`
56 | - `'{basis_name}:{index}'` or `'{basis_name}'`, such as
57 | `'{basis}:{index}'` where `'{basis}'` is a basis from `adata.obsm` and
58 | `'{index}'` is the number of the component to choose
59 | y: Union[Int, Str]
60 | values on the y-axis
61 | if of type `Int`, it corresponds to index in `adata.var_names`
62 | if of type `Str`, it can be either:
63 | - key in `adata.var_names`
64 | - key in `adata.obs`
65 | - `'{basis_name}:{index}'` or `'{basis_name}'`, such as
66 | `'{basis}:{index}'` where `'{basis}'` is a basis from `adata.obsm` and
67 | `'{index}'` is the number of the component to choose
68 | color: Union[Str, NoneType], optional (default: `None`)
69 | key in `adata.obs` to color in
70 | order_key: Union[Str, NoneType], optional (default: `None`)
71 | key in `adata.obs` which defines cell-ordering,
72 | such as `'dpt_pseudotime'`, cells will
73 | be sorted in ascending order
74 | indices: Union[np.array[Int], np.array[Bool], NoneType], optional (default: `None`)
75 | subset of cells to plot,
76 | if `None`, plot all cells
77 | layer: Union[Str, NoneType], optinal (default: `None`)
78 | key in `adata.layers`, used when `x` or `y` is a gene
79 | subsample: Str, optional (default: `'datashade'`)
80 | subsampling strategy for large data
81 | possible values are `None, 'none', 'datashade', 'decimate'`
82 | using `subsample='datashade'` is preferred over other option since it does not subset
83 | when using `subsample='datashade'`, colorbar is not visible
84 | use_raw: Bool, optional (default: `False`)
85 | whether to use `.raw` attribute
86 | size: Int, optional (default: `5`)
87 | size of the glyphs
88 | jitter: Union[Tuple[Float, Float], NoneType], optional (default: `None`)
89 | variance of normal distribution to use
90 | as a jittering
91 | perc: Union[List[Float, Float], NoneType], optional (default: `None`)
92 | percentiles for color clipping
93 | hover_keys: Union[List[Str], NoneType], optional (default: `[]`)
94 | keys in `adata.obs` to display when hovering over cells,
95 | if `None`, display nothing,
96 | if `[]`, display cell index,
97 | note that if `subsample='datashade'`, only cell index can be display
98 | kde: Union[Float, NoneType], optional (default: `None`)
99 | whether to show a kernel density estimate
100 | if a `Float`, it corresponds to the bandwidth
101 | density: Union[Str, None], optional (default: `None`)
102 | whether to show marginal x,y-densities,
103 | can be one of `'group'`, `'all'`
104 | if `'group'` and `color` is categorical variable,
105 | the density estimate is shown per category
106 | if `'all'`, the density is estimated using all points
107 | density_size: Int, optional (default: `150`)
108 | height and width of density plots
109 | hover_dims: Tuple[Int, Int], optional (default: `(10, 10)`)
110 | number of rows and columns of hovering tiles,
111 | only used when `subsample='datashade'`
112 | keep_frac: Float, optional (default: `0.2`)
113 | fraction of cells to keep, used when `subsample='decimate'`
114 | steps: Union[Int, Tuple[Int, Int]], optional (default: `40`)
115 | step size when the embedding directions
116 | larger step size corresponds to higher density of points
117 | seed: Union[Float, NoneType], optional (default: `None`)
118 | random seed, used when `subsample='decimate'`
119 | use_original_limits: Bool, optional (default: `False`)
120 | internal use only
121 | legend_loc: Str, optional (default: `'top_right'`)
122 | position of the legend
123 | show_legend:, Bool, optional (default: `True`)
124 | whether to show legend
125 | plot_height: Int, optional (default: `600`)
126 | height of the plot
127 | plot_width: Int, optional (default: `600`)
128 | width of the plot
129 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`)
130 | path where to save the plot
131 |
132 | Returns
133 | -------
134 | plot: hv.ScatterPlot
135 | a scatterplot
136 | '''
137 |
138 | if hover_keys is not None:
139 | for h in hover_keys:
140 | assert h in adata.obs, f'Hover key `{h}` not found in `adata.obs`.'
141 |
142 | if perc is not None:
143 | assert len(perc) == 2, f'`perc` must be of length `2`, found: `{len(perc)}`.'
144 | assert all((p is not None for p in perc)), '`perc` cannot contain `None`.'
145 | perc = sorted(perc)
146 |
147 | assert len(hover_dims) == 2, f'Expected `hover_dims` to be of length `2`, found `{len(hover_dims)}`.'
148 | assert all((isinstance(d, int) for d in hover_dims)), 'All of `hover_dims` must be of type `int`.'
149 | assert all((d > 1 for d in hover_dims)), 'All of `hover_dims` must be `> 1`.'
150 |
151 | assert keep_frac >= 0 and keep_frac <= 1, f'`keep_perc` must be in interval `[0, 1]`, got `{keep_frac}`.'
152 | assert subsample in (None, 'none', 'datashade', 'decimate'), f'Invalid subsampling strategy `{subsample}`. ' \
153 | 'Possible values are `None, \'none\', \'datashade\', \'decimate\'`.'
154 |
155 | adata_mraw = get_mraw(adata, use_raw and layer is None)
156 | if adata_mraw is adata and use_raw:
157 | warnings.warn('Failed fetching the `.raw`. attribute of `adata`.')
158 |
159 | if indices is None:
160 | indices = np.arange(adata_mraw.n_obs)
161 |
162 | xlim, ylim = None, None
163 | if order_key is not None:
164 | assert order_key in adata.obs, f'`{order_key}` not found in `adata.obs`.'
165 | ixs = np.argsort(adata.obs[order_key][indices])
166 | else:
167 | ixs = np.arange(len(indices))
168 |
169 | if x is None:
170 | if order_key is not None:
171 | x, xlabel = adata.obs[order_key][indices][ixs], order_key
172 | else:
173 | x, xlabel = ixs, 'index'
174 | else:
175 | x, xlabel, xlim = get_xy_data(x, adata, adata_mraw, layer, indices, use_original_limits)
176 |
177 | y, ylabel, ylim = get_xy_data(y, adata, adata_mraw, layer, indices, use_original_limits, inc=1)
178 |
179 | # jitter
180 | if jitter is not None:
181 | msg = 'Using `jitter` != `None` and `use_original_limits=True` can negatively impact the limits.'
182 | if isinstance(jitter, (tuple, list)):
183 | assert len(jitter) == 2, f'`jitter` must be of length `2`, found `{len(jitter)}`.'
184 | if any((j is not None for j in jitter)) and use_original_limits:
185 | warnings.warn(msg)
186 | else:
187 | assert isinstance(jitter, float), 'Expected` jitter` to be of type `float`, found `{type(jitter).__name__}`.'
188 | warnings.warn(msg)
189 |
190 | x = x.astype(np.float64)
191 | y = y.astype(np.float64)
192 |
193 | adata_mraw = adata_mraw[indices, :]
194 | if color is not None:
195 | if color in adata.obs:
196 | condition = adata.obs[color][indices][ixs]
197 | else:
198 | if isinstance(color, int):
199 | color = adata_mraw.var_names[color]
200 | condition = adata_mraw.obs_vector(color)[ixs]
201 | else:
202 | condition = None
203 |
204 | hover = None
205 | if hover_keys is not None:
206 | hover = {'index': ixs}
207 | for key in hover_keys:
208 | hover[key] = adata.obs[key][indices][ixs]
209 |
210 | plot = _scatter(adata, x=x.copy(), y=y.copy(),
211 | condition=condition, by=color,
212 | xlabel=xlabel, ylabel=ylabel,
213 | title=color, hover=hover, jitter=jitter,
214 | perc=perc, xlim=xlim, ylim=ylim,
215 | hover_width=hover_dims[1], hover_height=hover_dims[0], kde=kde, density=density, density_size=density_size,
216 | subsample=subsample, steps=steps, keep_frac=keep_frac, seed=seed, legend_loc=legend_loc,
217 | size=size, cmap=cmap, show_legend=show_legend, plot_height=plot_height, plot_width=plot_width)
218 |
219 | if save is not None:
220 | hv.renderer('bokeh').save(plot, save)
221 |
222 | return plot
223 |
224 |
225 | def _scatter(adata, x, y, condition, by=None, subsample='datashade', steps=40, keep_frac=0.2,
226 | seed=None, legend_loc='top_right', size=4, xlabel=None, ylabel=None, title=None,
227 | use_raw=True, hover=None, hover_width=10, hover_height=10, kde=None,
228 | density=None, density_size=150, jitter=None, perc=None, xlim=None, ylim=None,
229 | cmap=None, show_legend=True, plot_height=400, plot_width=400):
230 |
231 | _sentinel = object()
232 |
233 | def create_density_plots(df, density, kdims, cmap):
234 | cm = {}
235 | if density == 'all':
236 | dfs = {_sentinel: df}
237 | elif density == 'group':
238 | if 'z' not in df.columns:
239 | warnings.warn(f'`density=\'groups\' was specified, but no group found. Did you specify `color=...`?')
240 | dfs = {_sentinel: df}
241 | elif not is_categorical(df['z']):
242 | warnings.warn(f'`density=\'groups\' was specified, but column `{condition}` is not categorical.')
243 | dfs = {_sentinel: df}
244 | else:
245 | dfs = {k:v for k, v in df.groupby('z')}
246 | cm = cmap
247 | else:
248 | raise ValueError(f'Invalid `density` type: \'`{density}`\'. Possible values are `\'all\'`, `\'group\'`.')
249 | # assumes x, y order in kdims
250 | return [hv.Overlay([hv.Distribution(df, kdims=dim).opts(color=cm.get(k, 'black'),
251 | framewise=True)
252 | for k, df in dfs.items()])
253 | for dim in kdims]
254 |
255 | assert keep_frac >= 0 and keep_frac <= 1, f'`keep_perc` must be in interval `[0, 1]`, got `{keep_frac}`.'
256 |
257 | adata_mraw = get_mraw(adata, use_raw)
258 |
259 | if subsample == 'uniform':
260 | cb_kwargs = {'steps': steps}
261 | elif subsample == 'density':
262 | cb_kwargs = {'size': int(keep_frac * adata.n_obs), 'seed': seed}
263 | else:
264 | cb_kwargs = {}
265 |
266 | categorical = False
267 | if condition is None:
268 | cmap = ['black'] * len(x) if subsample == 'datashade' else 'black'
269 | elif is_categorical(condition):
270 | categorical = True
271 | cmap = Sets1to3 if cmap is None else cmap
272 | cmap = odict(zip(condition.cat.categories, adata.uns.get(f'{by}_colors', cmap)))
273 | else:
274 | cmap = Viridis256 if cmap is None else cmap
275 |
276 | jitter_x, jitter_y = None, None
277 | if isinstance(jitter, (tuple, list)):
278 | assert len(jitter) == 2, f'`jitter` must be of length `2`, found `{len(jitter)}`.'
279 | jitter_x, jitter_y = jitter
280 | elif jitter is not None:
281 | jitter_x, jitter_y = jitter, jitter
282 |
283 | if jitter_x is not None:
284 | x += np.random.normal(0, jitter_x, size=x.shape)
285 | if jitter_y is not None:
286 | y += np.random.normal(0, jitter_y, size=y.shape)
287 |
288 | data = {'x': x, 'y': y, 'z': condition}
289 | vdims = ['z']
290 |
291 | hovertool = None
292 | if hover is not None:
293 | for k, dt in hover.items():
294 | vdims.append(k)
295 | data[k] = dt
296 | hovertool = HoverTool(tooltips=[(key.capitalize(), f'@{key}')
297 | for key in (['index'] if subsample == 'datashade' else hover.keys())])
298 |
299 | data = pd.DataFrame(data)
300 | if categorical:
301 | data['z'] = data['z'].astype('category')
302 |
303 | if not vdims:
304 | vdims = None
305 |
306 | if xlim is None:
307 | xlim = pad(*minmax(x))
308 | if ylim is None:
309 | ylim = pad(*minmax(y))
310 |
311 | kdims=[('x', 'x' if xlabel is None else xlabel),
312 | ('y', 'y' if ylabel is None else ylabel)]
313 |
314 | scatter = hv.Scatter(data, kdims=kdims, vdims=vdims).sort('z')
315 | scatter = scatter.opts(size=size, xlim=xlim, ylim=ylim)
316 |
317 | kde_plot= None if kde is None else \
318 | hv.Bivariate(scatter).opts(bandwidth=kde, show_legend=False, line_width=2)
319 | xdist, ydist = (None, None) if density is None else create_density_plots(data, density, kdims, cmap)
320 |
321 | if categorical:
322 | scatter = scatter.opts(cmap=cmap, color='z', show_legend=show_legend, legend_position=legend_loc)
323 | elif 'z' in data:
324 | scatter = scatter.opts(cmap=cmap, color='z',
325 | clim=tuple(map(float, minmax(data['z'], perc))),
326 | colorbar=True,
327 | colorbar_opts={'width': 20})
328 | else:
329 | scatter = scatter.opts(color='black')
330 |
331 | legend = None
332 | if subsample == 'datashade':
333 | subsampled = dynspread(datashade(scatter, aggregator=(ds.count_cat('z') if categorical else ds.mean('z')) if vdims is not None else None,
334 | color_key=cmap, cmap=cmap,
335 | streams=[hv.streams.RangeXY(transient=True), hv.streams.PlotSize],
336 | min_alpha=255).opts(axiswise=True, framewise=True), threshold=0.8, max_px=5)
337 | if show_legend and categorical:
338 | legend = hv.NdOverlay({k: hv.Points([0, 0], label=str(k)).opts(size=0, color=v)
339 | for k, v in cmap.items()})
340 | if hover is not None:
341 | t = hv.util.Dynamic(rasterize(scatter, width=hover_width, height=hover_height, streams=[hv.streams.RangeXY],
342 | aggregator=ds.reductions.min('index')), operation=hv.QuadMesh)\
343 | .opts(tools=[hovertool], axiswise=True, framewise=True,
344 | alpha=0, hover_alpha=0.25,
345 | height=plot_height, width=plot_width)
346 | scatter = t * subsampled
347 | else:
348 | scatter = subsampled
349 |
350 | elif subsample == 'decimate':
351 | scatter = decimate(scatter, max_samples=int(adata.n_obs * keep_frac),
352 | streams=[hv.streams.RangeXY(transient=True)], random_seed=seed)
353 |
354 | if legend is not None:
355 | scatter = (scatter * legend).opts(legend_position=legend_loc)
356 |
357 | if kde_plot is not None:
358 | scatter *= kde_plot
359 |
360 | scatter = scatter.opts(height=plot_height, width=plot_width)
361 | scatter = scatter.opts(hv.opts.Scatter(tools=[hovertool])) if hovertool is not None else scatter
362 |
363 | if xdist is not None and ydist is not None:
364 | scatter = (scatter << ydist.opts(width=density_size)) << xdist.opts(height=density_size)
365 |
366 | return scatter.opts(title=title if title is not None else '')
367 |
368 |
369 | def _heatmap(adata, genes, group, sort_genes=True, use_raw=False,
370 | agg_fns=['mean'], hover=True,
371 | xrotation=90, yrotation=0, colorbar=True, cmap=None,
372 | plot_height=300, plot_width=600):
373 | '''
374 | Internal heatmap function.
375 |
376 | Params
377 | -------
378 | adata: anndata.AnnData
379 | adata object
380 | genes: List[Str]
381 | genes in `adata.var_names`
382 | group: Str
383 | key in `adata.obs`, must be categorical
384 | sort_genes: Bool, optional (default: `True`)
385 | whether to sort the genes
386 | use_raw: Bool, optional (default: `True`)
387 | whether to use `.raw` attribute
388 | agg_fns: List[Str], optional (default: `['mean']`)
389 | list of pandas' aggregation functions
390 | hover: Bool, optional (deault: `True`)
391 | whether to show hover information
392 | xrotation: Int, optional (default: `90`)
393 | rotation of labels on x-axis
394 | yrotation: Int, optional (default: `0`)
395 | rotation of labels on y-axis
396 | colorbar: Bool, optional (default: `True`)
397 | whether to show colorbar
398 | cmap: Union[List[Str], NoneType], optional (default, `None`)
399 | colormap of the heatmap,
400 | if `None`, use `Viridis256`
401 | plot_height: Int, optional (default: `600`)
402 | height of the heatmap
403 | plot_width: Int, optional (default: `200`)
404 | width of the heatmap
405 |
406 | Returns
407 | -------
408 | plot: hv.HeatMap
409 | a heatmap
410 | '''
411 |
412 | assert group in adata.obs
413 | assert is_categorical(adata.obs[group])
414 | assert len(agg_fns) > 0
415 |
416 | for g in genes:
417 | assert g in adata.var_names, f'Unable to find gene `{g}` in `adata.var_names`.'
418 |
419 | genes = sorted(genes) if sort_genes else genes
420 | groups = sorted(list(adata.obs[group].cat.categories))
421 |
422 | adata_mraw = get_mraw(adata, use_raw)
423 | common_subset = list(set(adata.obs_names) & set(adata_mraw.obs_names))
424 | adata, adata_mraw = adata[common_subset, :], adata_mraw[common_subset, :]
425 |
426 | ixs = np.in1d(adata.obs[group], groups)
427 | adata, adata_mraw = adata[ixs, :], adata_mraw[ixs, :]
428 |
429 | df = pd.DataFrame(adata_mraw[:, genes].X, columns=genes)
430 | df['group'] = list(map(str, adata.obs[group]))
431 | groupby = df.groupby('group')
432 |
433 | vals = {agg_fn: groupby.agg(agg_fn) for agg_fn in agg_fns}
434 | z_value = vals.pop(agg_fns[0])
435 |
436 | x = hv.Dimension('x', label='Gene')
437 | y = hv.Dimension('y', label='Group')
438 | z = hv.Dimension('z', label='Expression')
439 | vdims = [(k, k.capitalize()) for k in vals.keys()]
440 |
441 | heatmap = hv.HeatMap({'x': np.array(genes), 'y': np.array(groups), 'z': z_value, **vals},
442 | kdims=[('x', 'Gene'), ('y', 'Group')],
443 | vdims=[('z', 'Expression')] + vdims).opts(tools=['box_select'] + (['hover'] if hover else []),
444 | xrotation=xrotation, yrotation=yrotation)
445 |
446 | return heatmap.opts(frame_width=plot_width, frame_height=plot_height, colorbar=colorbar, cmap=cmap)
447 |
448 |
449 | @wrap_as_col
450 | def heatmap(adata, genes, groups=None, compare='genes', agg_fns=['mean', 'var'], use_raw=False,
451 | order_keys=[], hover=True, show_highlight=False, show_scatter=False,
452 | subsample=None, keep_frac=0.2, seed=None,
453 | xrotation=90, yrotation=0, colorbar=True, cont_cmap=None,
454 | height=200, width=600, save=None, **scatter_kwargs):
455 | '''
456 | Plot a heatmap with groups selected from a drop-down menu.
457 | If `show_highlight=True` and `show_scatterplot=True`, additional
458 | interaction occurrs when clicking on the highlighted heatmap.
459 |
460 | Params
461 | -------
462 | adata: anndata.AnnData
463 | adata object
464 | genes: List[Str]
465 | genes in `adata.var_names`
466 | groups: List[Str], optional (default: `None`)
467 | categorical observation in `adata.obs`,
468 | if `None`, get all groups from `adata.obs`
469 | compare: Str, optional (default: `'genes'`)
470 | only used when `show_scatterplot=True`,
471 | creates a drop-down:
472 | if `'genes'`:
473 | drow-down menu will contain values from `genes` and clicking
474 | a gene in highlighted heatmap will plot scatterplot of the 2
475 | genes with groups colored in
476 | if `'basis'`:
477 | drow-down menu will contain available bases and clicking
478 | a gene in highlighted heatmap will plot the gene in the selected
479 | embedding with its expression colored in
480 | if `'order'`:
481 | drop-down menu will contain values from `order_keys`,
482 | and clicking on a gene in highlighted heatmap will plot its expression
483 | in selected order
484 | agg_fns: List[Str], optional (default: `['mean', 'var']`
485 | names of pandas' aggregation functions, such `'min'`, ...
486 | the first function specified is mapped to colors
487 | use_raw: Bool, optional (default: `False`)
488 | whether to use `.raw` for gene expression
489 | order_keys: List[Str], optional (default: `None`)
490 | keys in `adata.obs`, used when `compare='order'`
491 | hover: Bool, optional (default: `True`)
492 | whether to display hover information over the heatmap
493 | show_highlight: Bool, optional (default: `False`)
494 | whether to show when using boxselect
495 | show_scatter: Bool, optional (default: `False`)
496 | whether to show a scatterplot,
497 | if `True`, overrides `show_highlight=False`
498 | subsample: Str, optional (default: `'decimate'`)
499 | subsampling strategy for large data
500 | possible values are `None, 'none', 'decimate'`
501 | keep_frac: Float, optional (default: `0.2`)
502 | fraction of cells to keep, used when `subsample='decimate'`
503 | seed: Union[Float, NoneType], optional (default: `None`)
504 | random seed, used when `subsample='decimate'`
505 | xrotation: Int, optional (default: `90`)
506 | rotation of labels on x-axis
507 | yrotation: Int, optional (default: `0`)
508 | rotation of labels on y-axis
509 | colorbar: Bool, optional (default: `True`)
510 | whether to show colorbar
511 | cont_cmap: Union[List[Str], NoneType], optional (default, `None`)
512 | colormap of the heatmap,
513 | if `None`, use `Viridis256`
514 | height: Int, optional (default: `200`)
515 | height of the heatmap
516 | width: Int, optional (default: `600`)
517 | width of the heatmap
518 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`)
519 | path where to save the plot
520 | **scatter_kwargs:
521 | additional argument for `ipl.experimental.scatter`,
522 | only used when `show_scatter=True`
523 |
524 | Returns
525 | -------
526 | holoviews plot
527 | '''
528 |
529 | def _highlight(group, index):
530 | original = hm[group]
531 | if not index:
532 | return original
533 |
534 | return original.iloc[sorted(index)]
535 |
536 | def _scatter(group, which, gwise, x, y):
537 | indices = adata.obs[group] == y if gwise else np.isin(adata.obs[group], highlight[group].data['y'])
538 | # this is necessary
539 | indices = np.where(indices)[0]
540 |
541 | if is_ordered:
542 | scatter_kwargs['order_key'] = which
543 | x, y = None, x
544 | elif f'X_{which}' in adata.obsm:
545 | group = x
546 | x, y = which, which
547 | else:
548 | x, y = x, which
549 |
550 | return scatter2(adata, x=x, y=y, color=group, indices=indices,
551 | **scatter_kwargs).opts(axiswise=True, framewise=True)
552 |
553 | assert keep_frac >= 0 and keep_frac <= 1, f'`keep_perc` must be in interval `[0, 1]`, got `{keep_frac}`.'
554 | assert subsample in (None, 'none','decimate'), f'Invalid subsampling strategy `{subsample}`. ' \
555 | 'Possible values are `None, \'none\', \'decimate\'`.'
556 |
557 | assert compare in ('genes', 'basis', 'order'), f'`compare` must be one of `\'genes\', \'basis\', \'order\'`.'
558 |
559 | if cont_cmap is None:
560 | cont_cmap = Viridis256
561 |
562 | is_ordered = False
563 | scatter_kwargs['use_original_limits'] = True
564 | scatter_kwargs['subsample'] = None
565 | if 'plot_width' not in scatter_kwargs:
566 | scatter_kwargs['plot_width'] = 300
567 | if 'plot_height' not in scatter_kwargs:
568 | scatter_kwargs['plot_height'] = 300
569 |
570 | if groups is not None:
571 | assert len(groups) > 0, f'Number of groups `> 1`.'
572 | else:
573 | groups = [k for k in adata.obs.keys() if is_categorical(adata.obs[k])]
574 |
575 | kdims=[hv.Dimension('Group',values=groups, default=groups[0])]
576 |
577 | hm = hv.DynamicMap(lambda g: _heatmap(adata, genes, agg_fns=agg_fns, group=g,
578 | hover=hover, use_raw=use_raw,
579 | cmap=cont_cmap,
580 | xrotation=xrotation, yrotation=yrotation,
581 | colorbar=colorbar), kdims=kdims).opts(frame_height=height,
582 | frame_width=width)
583 | if not show_highlight and not show_scatter:
584 | return hm
585 |
586 | highlight = hv.DynamicMap(_highlight, kdims=kdims, streams=[Selection1D(source=hm)])
587 | if not show_scatter:
588 | return (hm + highlight).cols(1)
589 |
590 | if compare == 'basis':
591 | basis = [b.lstrip('X_') for b in adata.obsm.keys()]
592 | kdims += [hv.Dimension('Components', values=basis, default=basis[0])]
593 | elif compare == 'genes':
594 | kdims += [hv.Dimension('Genes', values=genes, default=genes[0])]
595 | else:
596 | is_ordered = True
597 | k = scatter_kwargs.pop('order_key', None)
598 | assert k is not None or order_keys != [], f'No order keys specified.'
599 |
600 | if k is not None and k not in order_keys:
601 | order_keys.append(k)
602 |
603 | for k in order_keys:
604 | assert k in adata.obs, f'Order key `{k}` not found in `adata.obs`.'
605 |
606 | kdims += [hv.Dimension('Order', values=order_keys)]
607 |
608 | kdims += [hv.Dimension('Groupwise', type=bool, values=[True, False], default=True)]
609 |
610 | scatter_stream = hv.streams.Tap(source=highlight, x=genes[0], y=adata.obs[groups[0]].values[0])
611 | scatter = hv.DynamicMap(_scatter, kdims=kdims, streams=[scatter_stream])
612 |
613 | if subsample == 'decimate':
614 | scatter = decimate(scatter, max_samples=int(adata.n_obs * keep_frac),
615 | streams=[hv.streams.RangeXY(transient=True)], random_seed=seed)
616 |
617 | plot = (hm + highlight + scatter).cols(1)
618 |
619 | if save is not None:
620 | hv.rendered('bokeh').save(plot, save)
621 |
622 | return plot
623 |
624 |
--------------------------------------------------------------------------------
/interactive_plotting/experimental/scatter3d.py:
--------------------------------------------------------------------------------
1 | from interactive_plotting.utils import *
2 |
3 | from bokeh.core.properties import Any, Dict, Instance, String
4 | from bokeh.models import (
5 | ColumnDataSource,
6 | LayoutDOM,
7 | Legend,
8 | LegendItem,
9 | ColorBar,
10 | LinearColorMapper,
11 | FixedTicker
12 | )
13 | from bokeh.io import save
14 | from bokeh.resources import CDN
15 | from bokeh.layouts import row
16 | from bokeh.plotting import figure
17 | from bokeh.colors import RGB
18 | from pandas.api.types import is_categorical_dtype
19 | from anndata import AnnData
20 | from typing import Union, Optional, Sequence, Tuple
21 | from time import sleep
22 | from collections import defaultdict
23 |
24 | import matplotlib
25 | import matplotlib.cm as cm
26 | import numpy as np
27 | import webbrowser
28 | import tempfile
29 |
30 |
31 | _DEFAULT = {
32 | 'width': '600px',
33 | 'height': '600px',
34 | 'style': 'dot-color',
35 | 'showPerspective': False,
36 | 'showGrid': True,
37 | 'keepAspectRatio': True,
38 | 'verticalRatio': 1.0,
39 | 'cameraPosition': {
40 | 'horizontal': 1,
41 | 'vertical': 0.25,
42 | 'distance': 2,
43 | }
44 | }
45 |
46 |
47 | class Surface3d(LayoutDOM):
48 | __implementation__ = 'surface3d.ts'
49 | __javascript__ = 'https://unpkg.com/vis-graph3d@latest/dist/vis-graph3d.min.js'
50 |
51 | data_source = Instance(ColumnDataSource)
52 |
53 | x = String
54 | y = String
55 | z = String
56 | color = String
57 |
58 | options = Dict(String, Any, default=_DEFAULT)
59 |
60 |
61 | def _to_hex_colors(values, cmap, perc=None):
62 | minn, maxx = minmax(values, perc)
63 | norm = matplotlib.colors.Normalize(vmin=minn, vmax=maxx, clip=True)
64 |
65 | mapper = cm.ScalarMappable(norm=norm, cmap=cmap)
66 |
67 | return [matplotlib.colors.to_hex(mapper.to_rgba(v)) for v in values], minn, maxx
68 |
69 |
70 | def _mpl_to_hex_palette(cmap):
71 | if isinstance(cmap, matplotlib.colors.ListedColormap):
72 | rgb_cmap = (255 * cmap(range(256))).astype('int')
73 | return [RGB(*tuple(rgb)).to_hex() for rgb in rgb_cmap]
74 |
75 | assert all(map(lambda c: matplotlib.colors.is_color_like(c), cmap)), 'Not all colors are color-like.'
76 |
77 | return [matplotlib.colors.to_hex(c) for c in cmap]
78 |
79 |
80 |
81 | def scatter3d(adata: AnnData,
82 | key: str,
83 | basis: str = 'umap',
84 | components: Sequence[int] = (0, 1, 2),
85 | steps: Union[Tuple[int, int], int] = 100,
86 | perc: Optional[Tuple[int, int]] = None,
87 | n_ticks: int = 10,
88 | vertical_ratio: float = 1,
89 | show_axes: bool = False,
90 | keep_aspect_ratio: bool = True,
91 | perspective: bool = True,
92 | tooltips: Optional[Sequence[str]] = [],
93 | cmap: Optional[matplotlib.colors.ListedColormap] = None,
94 | dot_size_ratio: float = 0.01,
95 | show_legend: bool = True,
96 | show_cbar: bool = True,
97 | plot_height: Optional[int] = 1400,
98 | plot_width: Optional[int] = 1400):
99 | """
100 | Parameters
101 | ----------
102 | adata : :class:`anndata.AnnData`
103 | Annotated data object.
104 | key
105 | Key in `adata.obs` or `adata.var_names` to color in.
106 | basis
107 | Basis to use.
108 | components
109 | Components of the basis to plot.
110 | steps
111 | Step size when the subsampling the data.
112 | Larger step size corresponds to higher density of points.
113 | perc
114 | Percentile by which to clip colors.
115 | n_ticks
116 | Number of ticks for colorbar if `key` is not categorical.
117 | vertical_ratio
118 | Ratio by which to squish the z-axis.
119 | show_axes
120 | Whether to show axes.
121 | keep_aspect_ratio
122 | Whether to keep aspect ratio.
123 | perspective
124 | Whether to keep the perspective.
125 | tooltips
126 | Keys in `adata.obs` to visualize when hovering over cells.
127 | cmap
128 | Colormap to use.
129 | dot_size_ratio
130 | Ratio of the dots with respect to the plot size.
131 | show_legend
132 | Whether to show legend when annotation is categorical.
133 | show_cbar
134 | Whether to show colorbar when annotation is continuous.
135 | plot_height
136 | Height of the plot in pixels. If `None`, try getting the screen height.
137 | plot_width
138 | Width of the plot in pixels. If `None`, try getting the screen width.
139 |
140 | Returns
141 | -------
142 | None
143 | Nothing, just plots in a new tab.
144 | """
145 |
146 | def _wrap_as_div(row, sep=':'):
147 | res = []
148 | for kind, val in zip(tooltips, row):
149 | if isinstance(val, float):
150 | res.append(f'
{kind}{sep} {val:.04f}
')
151 | else:
152 | res.append(f'{kind}{sep} {val}
')
153 |
154 | return ''.join(res)
155 |
156 | basis_key = f'X_{basis}'
157 | assert basis_key in adata.obsm, f'Basis `{basis_key}` not found in `adata.obsm`.'
158 | if perc is not None:
159 | assert len(perc) == 2, f'Percentile must be of length `2`, found `{len(perc)}`.'
160 | assert len(components) == 3, f'Number of components must be `3`, found `{len(components)}`.'
161 | assert all(c >= 0 for c in components), f'All components must be non-negative, found `{min(components)}`.'
162 | assert max(components) < adata.obsm[basis_key].shape[-1], \
163 | f'Component `{max(components)}` is >= than number of components `{adata.obsm[basis_key].shape[-1]}`.'
164 | assert key in adata.obs or key in adata.var_names, f'Key `{key}` not found in `adata.obs` or `adata.var_names`.'
165 |
166 | colors = adata.uns.get(f'{key}_colors', None)
167 | if steps is not None:
168 | # this somehow destroys the colors
169 | adata, _ = sample_unif(adata, steps, bs=basis, components=components)
170 |
171 | data = dict(x=adata.obsm[basis_key][:, components[0]],
172 | y=adata.obsm[basis_key][:, components[1]],
173 | z=adata.obsm[basis_key][:, components[2]])
174 |
175 | fig = figure(tools=[], outline_line_width=0, toolbar_location='left', disabled=True)
176 | to_add = None
177 |
178 | if key in adata.obs and is_categorical_dtype(adata.obs[key]):
179 | if cmap is None:
180 | cmap = colors or cm.tab20b
181 |
182 | hex_palette = _mpl_to_hex_palette(cmap)
183 |
184 | mapper = defaultdict(lambda: '#AAAAAA', zip(adata.obs[key].cat.categories, hex_palette))
185 | colors = [str(mapper[c]) for c in adata.obs[key]]
186 |
187 | n_cls = len(adata.obs[key].cat.categories)
188 | _ = fig.circle([0] * n_cls, [0] * n_cls,
189 | color=list(mapper.values()),
190 | visible=False, radius=0)
191 | if show_legend:
192 | to_add = Legend(items=[
193 | LegendItem(label=str(c), index=i, renderers=[_])
194 | for i, c in enumerate(mapper.keys())
195 | ])
196 | else:
197 | vals = adata.obs_vector(key) if key in adata.var_names else adata.obs[key]
198 |
199 | cmap = cm.viridis if cmap is None else cmap
200 | colors, minn, maxx = _to_hex_colors(vals, cmap, perc=perc)
201 | hex_palette = _mpl_to_hex_palette(cmap)
202 |
203 | _ = fig.circle(0, 0, visible=False, radius=0)
204 |
205 | color_mapper = LinearColorMapper(palette=hex_palette, low=minn, high=maxx)
206 | if show_cbar:
207 | to_add = ColorBar(color_mapper=color_mapper, ticker=FixedTicker(ticks=np.linspace(minn, maxx, n_ticks)),
208 | label_standoff=12, border_line_color=None, location=(0, 0))
209 |
210 | data['color'] = colors
211 | if tooltips is None:
212 | tooltips = adata.obs_keys()
213 | if len(tooltips):
214 | data['tooltip'] = adata.obs[tooltips].apply(_wrap_as_div, axis=1)
215 |
216 | source = ColumnDataSource(data=data)
217 | if plot_width is None or plot_height is None:
218 | try:
219 | import screeninfo
220 | for monitor in screeninfo.get_monitors():
221 | break
222 | plot_width = max(monitor.width - 300, 0) if plot_width is None else plot_width
223 | plot_height = max(monitor.height, 300) if plot_height is None else plot_height
224 | except ImportError:
225 | print('Unable to get the screen size, please install package `screeninfo` as `pip install screeninfo`.')
226 | plot_width = 1200 if plot_width is None else plot_width
227 | plot_height = 1200 if plot_height is None else plot_height
228 | except:
229 | plot_width = 1200 if plot_width is None else plot_width
230 | plot_height = 1200 if plot_height is None else plot_height
231 |
232 | surface = Surface3d(x="x", y="y", z="z", color="color",
233 | data_source=source, options={**_DEFAULT,
234 | **dict(dotSizeRatio=dot_size_ratio,
235 | showXAxis=show_axes,
236 | showYAxis=show_axes,
237 | showZAxis=show_axes,
238 | xLabel=f'{basis}_{components[0]}',
239 | yLabel=f'{basis}_{components[1]}',
240 | zLabel=f'{basis}_{components[2]}',
241 | showPerspective=perspective,
242 | height=f'{plot_height}px',
243 | width=f'{plot_width}px',
244 | verticalRatio=vertical_ratio,
245 | keepAspectRatio=keep_aspect_ratio,
246 | showLegend=False,
247 | tooltip='tooltip' in data,
248 | xCenter='50%',
249 | yCenter='50%',
250 | showGrid=show_axes)})
251 | if to_add is not None:
252 | fig.add_layout(to_add, 'left')
253 |
254 | # dirty little trick, makes plot disappear
255 | # ideally, one would modify the DOM in the .ts file but I'm just lazy
256 | fig.xgrid.visible = False
257 | fig.ygrid.visible = False
258 | fig.xaxis.visible = False
259 | fig.yaxis.visible = False
260 |
261 | with tempfile.NamedTemporaryFile(suffix='.html') as fout:
262 | path = save(row(surface, fig), fout.name, resources=CDN, title=f'Scatter3D - {key}')
263 | fout.flush()
264 | webbrowser.open_new_tab(path)
265 | sleep(2) # better safe than sorry
266 |
--------------------------------------------------------------------------------
/interactive_plotting/experimental/surface3d.ts:
--------------------------------------------------------------------------------
1 | import {HTMLBox, HTMLBoxView} from "models/layouts/html_box"
2 | import {ColumnDataSource} from "models/sources/column_data_source"
3 | import * as p from "core/properties"
4 |
5 | declare namespace vis {
6 | class Graph3d {
7 | constructor(el: HTMLElement, data: object, OPTIONS: object)
8 | setData(data: vis.DataSet): void
9 | }
10 |
11 | class DataSet {
12 | add(data: unknown): void
13 | }
14 | }
15 |
16 | function _tooltip(obj: {x: number, y: number, z: number, data: {tooltip: string}}) {
17 | return obj.data["tooltip"];
18 | }
19 |
20 | const OPTIONS = {
21 | width: '1200px',
22 | height: '1200px',
23 | style: 'dot-color',
24 | showPerspective: true,
25 | tooltip: _tooltip,
26 | showGrid: false,
27 | showXAxis: false,
28 | showYAxis: false,
29 | showZAxis: false,
30 | keepAspectRatio: true,
31 | verticalRatio: 1.0,
32 | cameraPosition: {
33 | horizontal: 1,
34 | vertical: 0.25,
35 | distance: 2.0,
36 | },
37 | }
38 |
39 | export class Surface3dView extends HTMLBoxView {
40 | model: Surface3d
41 |
42 | private _graph: vis.Graph3d
43 |
44 | render(): void {
45 | super.render()
46 | if (this.model.options["tooltip"]) { // we want to show tooltips
47 | this.model.options["tooltip"] = _tooltip
48 | }
49 | this._graph = new vis.Graph3d(this.el, this.get_data(), this.model.options)
50 | }
51 |
52 | connect_signals(): void {
53 | super.connect_signals()
54 | this.connect(this.model.data_source.change, () => this._graph.setData(this.get_data()))
55 | }
56 |
57 | get_data(): vis.DataSet {
58 | const data = new vis.DataSet()
59 | const source = this.model.data_source
60 |
61 | if ("tooltip" in source.data) {
62 | for (let i = 0; i < source.get_length()!; i++) {
63 | data.add({
64 | x: source.data[this.model.x][i],
65 | y: source.data[this.model.y][i],
66 | z: source.data[this.model.z][i],
67 | style: source.data[this.model.color][i],
68 | tooltip: source.data["tooltip"][i]
69 | })
70 | }
71 | } else {
72 | for (let i = 0; i < source.get_length()!; i++) {
73 | data.add({
74 | x: source.data[this.model.x][i],
75 | y: source.data[this.model.y][i],
76 | z: source.data[this.model.z][i],
77 | style: source.data[this.model.color][i],
78 | })
79 | }
80 | }
81 |
82 | return data
83 | }
84 | }
85 |
86 | export namespace Surface3d {
87 | export type Attrs = p.AttrsOf
88 |
89 | export type Props = HTMLBox.Props & {
90 | x: p.Property
91 | y: p.Property
92 | z: p.Property
93 | color: p.Property
94 | data_source: p.Property
95 | options: p.Property<{[key: string]: unknown}>
96 | }
97 | }
98 |
99 | export interface Surface3d extends Surface3d.Attrs {}
100 |
101 | export class Surface3d extends HTMLBox {
102 | properties: Surface3d.Props
103 |
104 | constructor(attrs?: Partial) {
105 | super(attrs)
106 | }
107 |
108 | static __name__ = "Surface3d"
109 |
110 | static init_Surface3d(): void {
111 | this.prototype.default_view = Surface3dView
112 |
113 | this.define({
114 | x: [ p.String ],
115 | y: [ p.String ],
116 | z: [ p.String ],
117 | color: [ p.String ],
118 | data_source: [ p.Instance ],
119 | options: [ p.Any, OPTIONS ]
120 | })
121 | }
122 | }
123 |
--------------------------------------------------------------------------------
/interactive_plotting/holoviews_plots.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from interactive_plotting.utils._utils import *
4 |
5 | from collections import OrderedDict as odict
6 |
7 | from pandas.api.types import is_categorical_dtype, is_string_dtype, infer_dtype
8 | from scipy.sparse import issparse
9 | from functools import partial
10 | from bokeh.palettes import Viridis256
11 | from datashader.colors import Sets1to3
12 | from pandas.core.indexes.base import Index
13 | from holoviews.operation.datashader import datashade, bundle_graph, shade, dynspread, rasterize, spread
14 | from holoviews.operation import decimate
15 | from bokeh.models import HoverTool
16 |
17 | import scanpy as sc
18 | import numpy as np
19 | import pandas as pd
20 | import networkx as nx
21 | import holoviews as hv
22 | import datashader as ds
23 | import warnings
24 |
25 |
26 | try:
27 | assert callable(sc.tl.dpt)
28 | dpt_fn = sc.tl.dpt
29 | except AssertionError:
30 | from scanpy.api.tl import dpt as dpt_fn
31 |
32 | #TODO: DRY
33 |
34 | @wrap_as_panel
35 | def scatter(adata, genes=None, basis=None, components=(1, 2), obs_keys=None,
36 | obsm_keys=None, use_raw=False, subsample='datashade', steps=40, keep_frac=None, lazy_loading=True,
37 | default_obsm_ixs=[0], sort=True, skip=True, seed=None, cols=None, size=4,
38 | perc=None, show_perc=True, cmap=None, plot_height=400, plot_width=400, save=None):
39 | '''
40 | Scatter plot for continuous observations.
41 |
42 | Params
43 | --------
44 | adata: anndata.Anndata
45 | anndata object
46 | genes: List[Str], optional (default: `None`)
47 | list of genes to add for visualization
48 | if `None`, use `adata.var_names`
49 | basis: Union[Str, List[Str]], optional (default: `None`)
50 | basis in `adata.obsm`, if `None`, get all available
51 | components: Union[List[Int], List[List[Int]]], optional (default: `[1, 2]`)
52 | components of specified `basis`
53 | if it's of type `List[Int]`, all the basis have use the same components
54 | obs_keys: List[Str], optional (default: `None`)
55 | keys of categorical observations in `adata.obs`
56 | if `None`, get all available
57 | obsm_keys: List[Str], optional (default: `None`)
58 | keys of categorical observations in `adata.obsm`
59 | if `None`, get all available
60 | use_raw: Bool, optional (default: `False`)
61 | use `adata.raw` for gene expression levels
62 | subsample: Str, optional (default: `'datashade'`)
63 | subsampling strategy for large data
64 | possible values are `None, 'none', 'datashade', 'decimate', 'density', 'uniform'`
65 | using `subsample='datashade'` is preferred over other options since it does not subset
66 | when using `subsample='datashade'`, colorbar is not visible
67 | `'density'` and `'uniform'` use first element of `basis` for their computation
68 | steps: Union[Int, Tuple[Int, Int]], optional (default: `40`)
69 | step size when the embedding directions
70 | larger step size corresponds to higher density of points
71 | keep_frac: Float, optional (default: `adata.n_obs / 5`)
72 | number of observations to keep when `subsample='decimate'`
73 | lazy_loading: Bool, optional (default: `False`)
74 | only visualize when necessary
75 | for notebook sharing, consider using `lazy_loading=False`
76 | default_obsm_ixs: List[Int], optional (default: `[0]`)
77 | indices of 2-D elements in `adata.obsm` to add
78 | when `obsm_keys=None`
79 | by default adds only 1st column
80 | sort: Bool, optional (default: `True`)
81 | whether sort the `genes`, `obs_keys` and `obsm_keys`
82 | in ascending order
83 | skip: Bool, optional (default: `True`)
84 | skip all the keys not found in the corresponding collections
85 | seed: Int, optional (default: `None`)
86 | random seed, used when `subsample='decimate'``
87 | cols: Int, optional (default: `2`)
88 | number of columns when plotting basis
89 | if `None`, use togglebar
90 | size: Int, optional (default: `4`)
91 | size of the glyphs
92 | works only when `subsample != 'datashade'`
93 | perc: List[Float], optional (default: `None`)
94 | percentile for colors
95 | useful when `lazy_loading = False`
96 | works only when `subsample != 'datashade'`
97 | show_perc: Bool, optional (default: `True`)
98 | show percentile slider when `lazy_loading = True`
99 | works only when `subsample != 'datashade'`
100 | cmap: List[Str], optional (default: `bokeh.palettes.Viridis256`)
101 | continuous colormap in hex format
102 | plot_height: Int, optional (default: `400`)
103 | height of the plot in pixels
104 | plot_width: Int, optional (default: `400`)
105 | width of the plot in pixels
106 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`)
107 | path where to save the plot
108 |
109 | Returns
110 | --------
111 | plot: panel.panel
112 | holoviews plot wrapped in `panel.panel`
113 | '''
114 |
115 | def create_scatterplot(gene, perc_low, perc_high, *args, bs=None):
116 | ixs = np.where(basis == bs)[0][0]
117 | is_diffmap = bs == 'diffmap'
118 |
119 | if len(args) > 0:
120 | ixs = np.where(basis == bs)[0][0] * 2
121 | comp = (np.array([args[ixs], args[ixs + 1]]) - (not is_diffmap)) % adata.obsm[f'X_{bs}'].shape[-1]
122 | else:
123 | comp = np.array(components[ixs]) # need to make a copy
124 |
125 | if perc_low is not None and perc_high is not None:
126 | if perc_low > perc_high:
127 | perc_low, perc_high = perc_high, perc_low
128 | perc = [perc_low, perc_high]
129 | else:
130 | perc = None
131 |
132 | ad, _ = alazy[bs, tuple(comp)]
133 | ad_mraw = ad.raw if use_raw else ad
134 |
135 | # because diffmap has small range, it iterferes with
136 | # the legend created
137 | emb = ad.obsm[f'X_{bs}'][:, comp] * (1000 if is_diffmap else 1)
138 | comp += not is_diffmap # naming consistence
139 |
140 | bsu = bs.upper()
141 | x = hv.Dimension('x', label=f'{bsu}{comp[0]}')
142 | y = hv.Dimension('y', label=f'{bsu}{comp[1]}')
143 |
144 | #if ignore_after is not None and ignore_after in gene:
145 | if gene in ad.obsm.keys():
146 | data = ad.obsm[gene][:, 0]
147 | elif gene in ad.obs.keys():
148 | data = ad.obs[gene].values
149 | elif gene in ad_mraw.var_names:
150 | data = ad_mraw.obs_vector(gene)
151 | else:
152 | gene, ix = gene.split(ignore_after)
153 | ix = int(ix)
154 | data = ad.obsm[gene][:, ix]
155 |
156 | data = np.array(data, dtype=np.float64)
157 |
158 | # we need to clip the data as well
159 | scatter = hv.Scatter({'x': emb[:, 0], 'y': emb[:, 1], 'gene': data},
160 | kdims=[x, y], vdims='gene')
161 |
162 | return scatter.opts(cmap=cmap, color='gene',
163 | colorbar=True,
164 | colorbar_opts={'width': CBW},
165 | size=size,
166 | clim=minmax(data, perc=perc),
167 | xlim=minmax(emb[:, 0]),
168 | ylim=minmax(emb[:, 1]),
169 | xlabel=f'{bsu}{comp[0]}',
170 | ylabel=f'{bsu}{comp[1]}')
171 |
172 | def _create_scatterplot_nl(bs, gene, perc_low, perc_high, *args):
173 | # arg switching
174 | return create_scatterplot(gene, perc_low, perc_high, *args, bs=bs)
175 |
176 | if perc is None:
177 | perc = [None, None]
178 | assert len(perc) == 2, f'Percentile must be of length 2, found `{len(perc)}`.'
179 | if all(map(lambda p: p is not None, perc)):
180 | perc = sorted(perc)
181 |
182 | if keep_frac is None:
183 | keep_frac = 0.2
184 |
185 | if basis is None:
186 | basis = np.ravel(sorted(filter(len, map(BS_PAT.findall, adata.obsm.keys()))))
187 | elif isinstance(basis, str):
188 | basis = np.array([basis])
189 | elif not isinstance(basis, np.ndarray):
190 | basis = np.array(basis)
191 |
192 | assert keep_frac >= 0 and keep_frac <= 1, f'`keep_perc` must be in interval `[0, 1]`, got `{keep_frac}`.'
193 | assert subsample in ALL_SUBSAMPLING_STRATEGIES, f'Invalid subsampling strategy `{subsample}`. Possible values are `{ALL_SUBSAMPLING_STRATEGIES}`.'
194 |
195 | if subsample == 'uniform':
196 | cb_kwargs = {'steps': steps}
197 | elif subsample == 'density':
198 | cb_kwargs = {'size': int(keep_frac * adata.n_obs), 'seed': seed}
199 | else:
200 | cb_kwargs = {}
201 | alazy = SamplingLazyDict(adata, subsample, callback_kwargs=cb_kwargs)
202 | adata_mraw = adata.raw if use_raw else adata # maybe raw
203 |
204 | if obs_keys is None:
205 | obs_keys = skip_or_filter(adata, adata.obs.keys(), adata.obs.keys(), dtype=is_numeric,
206 | where='obs', skip=True, warn=False)
207 | else:
208 | if not iterable(obs_keys):
209 | obs_keys = [obs_keys]
210 | obs_keys = skip_or_filter(adata, obs_keys, adata.obs.keys(), dtype=is_numeric,
211 | where='obs', skip=skip)
212 |
213 | if obsm_keys is None:
214 | obsm_keys = get_all_obsm_keys(adata, default_obsm_ixs)
215 | ignore_after = OBSM_SEP
216 | obsm_keys = skip_or_filter(adata, obsm_keys, adata.obsm.keys(), where='obsm',
217 | dtype=is_numeric, skip=True, warn=False, ignore_after=ignore_after)
218 | else:
219 | if not iterable(obsm_keys):
220 | obsm_keys = [obsm_keys]
221 |
222 | ignore_after = OBSM_SEP if any((OBSM_SEP in obs_key for obs_key in obsm_keys)) else None
223 | obsm_keys = skip_or_filter(adata, obsm_keys, adata.obsm.keys(), where='obsm',
224 | dtype=is_numeric, skip=skip, ignore_after=ignore_after)
225 |
226 | if genes is None:
227 | genes = adata_mraw.var_names
228 | elif not iterable(genes):
229 | genes = [genes]
230 | genes = skip_or_filter(adata_mraw, genes, adata_mraw.var_names, where='adata.var_names', skip=skip)
231 |
232 | if isinstance(genes, Index):
233 | genes = list(genes)
234 |
235 | if sort:
236 | if any(genes[i] > genes[i + 1] for i in range(len(genes) - 1)):
237 | genes = sorted(genes)
238 | if any(obs_keys[i] > obs_keys[i + 1] for i in range(len(obs_keys) - 1)):
239 | obs_keys = sorted(obs_keys)
240 | if any(obsm_keys[i] > obsm_keys[i + 1] for i in range(len(obsm_keys) - 1)):
241 | obsm_keys = sorted(obsm_keys)
242 |
243 | conditions = obs_keys + obsm_keys + genes
244 | if len(conditions) == 0:
245 | warnings.warn(f'Nothing to plot, no conditions found.')
246 | return
247 |
248 | if not isinstance(components, np.ndarray):
249 | components = np.array(components)
250 | if components.ndim == 1:
251 | components = np.repeat(components[np.newaxis, :], len(basis), axis=0)
252 |
253 | assert components.ndim == 2, f'Only `2` dimensional components are supported, got `{components.ndim}`.'
254 | assert components.shape[-1] == 2, f'Components\' second dimension must be of size `2`, got `{components.shape[-1]}`.'
255 | if not isinstance(basis, np.ndarray):
256 | basis = np.array(basis)
257 |
258 | assert components.shape[0] == len(basis), f'Expected #components == `{len(basis)}`, got `{components.shape[0]}`.'
259 | assert np.all(components >= 0), f'Currently, only positive indices are supported, found `{list(map(list, components))}`.'
260 |
261 | diffmap_ix = np.where(basis != 'diffmap')[0]
262 | components[diffmap_ix, :] -= 1
263 |
264 | for bs, comp in zip(basis, components):
265 | shape = adata.obsm[f'X_{bs}'].shape
266 | assert f'X_{bs}' in adata.obsm.keys(), f'`X_{bs}` not found in `adata.obsm`'
267 | assert shape[-1] > np.max(comp), f'Requested invalid components `{list(comp)}` for basis `X_{bs}` with shape `{shape}`.'
268 |
269 | if adata.n_obs > SUBSAMPLE_THRESH and subsample in NO_SUBSAMPLE:
270 | warnings.warn(f'Number of cells `{adata.n_obs}` > `{SUBSAMPLE_THRESH}`. Consider specifying `subsample={SUBSAMPLING_STRATEGIES}`.')
271 |
272 | if len(conditions) > HOLOMAP_THRESH and not lazy_loading:
273 | warnings.warn(f'Number of conditions `{len(conditions)}` > `{HOLOMAP_THRESH}`. Consider specifying `lazy_loading=True`.')
274 |
275 | if cmap is None:
276 | cmap = Viridis256
277 |
278 | kdims = [hv.Dimension('Basis', values=basis),
279 | hv.Dimension('Condition', values=conditions),
280 | hv.Dimension('Percentile (lower)', range=(0, 100), step=0.1, type=float, default=0 if perc[0] is None else perc[0]),
281 | hv.Dimension('Percentile (upper)', range=(0, 100), step=0.1, type=float, default=100 if perc[1] is None else perc[1])]
282 |
283 | cs = create_scatterplot
284 | _cs = _create_scatterplot_nl
285 | if not show_perc or subsample == 'datashade' or not lazy_loading:
286 | kdims = kdims[:2]
287 | cs = lambda gene, *args, **kwargs: create_scatterplot(gene, perc[0], perc[1], *args, **kwargs)
288 | _cs = lambda bs, gene, *args, **kwargs: _create_scatterplot_nl(bs, gene, perc[0], perc[1], *args, **kwargs)
289 |
290 | if not lazy_loading:
291 | dynmaps = [hv.HoloMap({(g, b):cs(g, bs=b) for g in conditions for b in basis}, kdims=kdims[::-1])]
292 | else:
293 | for bs, comp in zip(basis, components):
294 | kdims.append(hv.Dimension(f'{bs.upper()}[X]',
295 | type=int, default=1, step=1,
296 | range=(1, adata.obsm[f'X_{bs}'].shape[-1])))
297 | kdims.append(hv.Dimension(f'{bs.upper()}[Y]',
298 | type=int, default=2, step=1,
299 | range=(1, adata.obsm[f'X_{bs}'].shape[-1])))
300 | if cols is None:
301 | dynmaps = [hv.DynamicMap(_cs, kdims=kdims)]
302 | else:
303 | dynmaps = [hv.DynamicMap(partial(cs, bs=bs), kdims=kdims[1:]) for bs in basis]
304 |
305 | if subsample == 'datashade':
306 | dynmaps = [dynspread(datashade(d, aggregator=ds.mean('gene'), color_key='gene',
307 | cmap=cmap, streams=[hv.streams.RangeXY(transient=True)]),
308 | threshold=0.8, max_px=5)
309 | for d in dynmaps]
310 | elif subsample == 'decimate':
311 | dynmaps = [decimate(d, max_samples=int(adata.n_obs * keep_frac),
312 | streams=[hv.streams.RangeXY(transient=True)], random_seed=seed) for d in dynmaps]
313 |
314 | dynmaps = [d.opts(framewise=True, axiswise=True, frame_height=plot_height, frame_width=plot_width) for d in dynmaps]
315 |
316 | if cols is None:
317 | plot = dynmaps[0].opts(title='', frame_height=plot_height, frame_width=plot_width)
318 | else:
319 | plot = hv.Layout(dynmaps).opts(title='', height=plot_height, width=plot_width).cols(cols)
320 |
321 | if save is not None:
322 | hv.renderer('bokeh').save(plot, save)
323 |
324 | return plot
325 |
326 |
327 | @wrap_as_panel
328 | def scatterc(adata, basis=None, components=[1, 2], obs_keys=None,
329 | obsm_keys=None, subsample='datashade', steps=40, keep_frac=None, hover=False, lazy_loading=True,
330 | default_obsm_ixs=[0], sort=True, skip=True, seed=None, legend_loc='top_right', cols=None, size=4,
331 | cmap=None, show_legend=True, plot_height=400, plot_width=400, save=None):
332 | '''
333 | Scatter plot for categorical observations.
334 |
335 | Params
336 | --------
337 | adata: anndata.Anndata
338 | anndata object
339 | basis: Union[Str, List[Str]], optional (default: `None`)
340 | basis in `adata.obsm`, if `None`, get all available
341 | components: Union[List[Int], List[List[Int]]], optional (default: `[1, 2]`)
342 | components of specified `basis`
343 | if it's of type `List[Int]`, all the basis have use the same components
344 | obs_keys: List[Str], optional (default: `None`)
345 | keys of categorical observations in `adata.obs`
346 | if `None`, get all available
347 | obsm_keys: List[Str], optional (default: `None`)
348 | keys of categorical observations in `adata.obsm`
349 | if `None`, get all available
350 | subsample: Str, optional (default: `'datashade'`)
351 | subsampling strategy for large data
352 | possible values are `None, 'none', 'datashade', 'decimate', 'density', 'uniform'`
353 | using `subsample='datashade'` is preferred over other options since it does not subset
354 | when using `subsample='datashade'`, colorbar is not visible
355 | `'density'` and `'uniform'` use first element of `basis` for their computation
356 | steps: Union[Int, Tuple[Int, Int]], optional (default: `40`)
357 | step size when the embedding directions
358 | larger step size corresponds to higher density of points
359 | keep_frac: Float, optional (default: `adata.n_obs / 5`)
360 | number of observations to keep when `subsample='decimate'`
361 | hover: Union[Bool, Int], optional (default: `False`)
362 | whether to display cell index when hovering over a block
363 | if integer, it specifies the number of rows/columns (defualt: `10`)
364 | lazy_loading: Bool, optional (default: `False`)
365 | only visualize when necessary
366 | for notebook sharing, consider using `lazy_loading=False`
367 | sort: Bool, optional (default: `True`)
368 | whether sort the `genes`, `obs_keys` and `obsm_keys`
369 | in ascending order
370 | skip: Bool, optional (default: `True`)
371 | skip all the keys not found in the corresponding collections
372 | seed: Int, optional (default: `None`)
373 | random seed, used when `subsample='decimate'``
374 | legend_loc: Str, optional (default: `top_right`)
375 | position of the legend
376 | cols: Int, optional (default: `None`)
377 | number of columns when plotting basis
378 | if `None`, use togglebar
379 | size: Int, optional (default: `4`)
380 | size of the glyphs
381 | works only when `subsample!='datashade'`
382 | cmap: List[Str], optional (default: `datashader.colors.Sets1to3`)
383 | categorical colormap in hex format
384 | plot_height: Int, optional (default: `400`)
385 | height of the plot in pixels
386 | plot_width: Int, optional (default: `400`)
387 | width of the plot in pixels
388 | save: Union[os.PathLike, Str, NoneType], optional (default: `None`)
389 | path where to save the plot
390 |
391 | Returns
392 | --------
393 | plot: panel.panel
394 | holoviews plot wrapped in `panel.panel`
395 | '''
396 |
397 | def create_legend(condition, bs):
398 | # slightly hacky solution to get the correct initial limits
399 | xlim = lims['x'][bs]
400 | ylim = lims['y'][bs]
401 |
402 | return hv.NdOverlay({k: hv.Points([0, 0], label=str(k)).opts(size=0, color=v, xlim=xlim, ylim=ylim) # alpha affects legend
403 | for k, v in cmaps[condition].items()})
404 |
405 | def add_hover(subsampled, dynmaps=None, by_block=True):
406 | hovertool = HoverTool(tooltips=[('Cell Index', '@index')])
407 | hover_width, hover_height = (10, 10) if isinstance(hover, bool) else (hover, hover)
408 |
409 | if by_block:
410 | if dynmaps is None:
411 | dynmaps = subsampled
412 |
413 | return [s * hv.util.Dynamic(rasterize(d, width=hover_width, height=hover_height, streams=[hv.streams.RangeXY],
414 | aggregator=ds.reductions.min('index')), operation=hv.QuadMesh)\
415 | .opts(tools=[hovertool], axiswise=True, framewise=True, alpha=0, hover_alpha=0.25,
416 | height=plot_height, width=plot_width)
417 | for s, d in zip(subsampled, dynmaps)]
418 |
419 | return [s.opts(tools=[hovertool]) for s in subsampled]
420 |
421 | def create_scatterplot(cond, *args, bs=None):
422 | ixs = np.where(basis == bs)[0][0]
423 | is_diffmap = bs == 'diffmap'
424 |
425 | if len(args) > 0:
426 | ixs = np.where(basis == bs)[0][0] * 2
427 | comp = (np.array([args[ixs], args[ixs + 1]]) - (not is_diffmap)) % adata.obsm[f'X_{bs}'].shape[-1]
428 | else:
429 | comp = np.array(components[ixs]) # need to make a copy
430 |
431 | # subsample is uniform or density
432 | ad, ixs = alazy[bs, tuple(comp)]
433 | # because diffmap has small range, it interferes with the legend
434 | emb = ad.obsm[f'X_{bs}'][:, comp] * (1000 if is_diffmap else 1)
435 | comp += not is_diffmap # naming consistence
436 |
437 | bsu = bs.upper()
438 | x = hv.Dimension('x', label=f'{bsu}{comp[0]}')
439 | y = hv.Dimension('y', label=f'{bsu}{comp[1]}')
440 |
441 | #if ignore_after is not None and ignore_after in gene:
442 | if cond in ad.obsm.keys():
443 | data = ad.obsm[cond][:, 0]
444 | elif cond in ad.obs.keys():
445 | data = ad.obs[cond]
446 | else:
447 | cond, ix = cond.split(ignore_after)
448 | ix = int(ix)
449 | data = ad.obsm[cond][:, ix]
450 |
451 | data = pd.Categorical(data).as_ordered()
452 | scatter = hv.Scatter({'x': emb[:, 0], 'y': emb[:, 1], 'cond': data, 'index': ixs},
453 | kdims=[x, y], vdims=['cond', 'index']).sort('cond')
454 |
455 | return scatter.opts(color_index='cond', cmap=cmaps[cond],
456 | show_legend=show_legend,
457 | legend_position=legend_loc,
458 | size=size,
459 | xlim=minmax(emb[:, 0]),
460 | ylim=minmax(emb[:, 1]),
461 | xlabel=f'{bsu}{comp[0]}',
462 | ylabel=f'{bsu}{comp[1]}')
463 |
464 | def _cs(bs, cond, *args):
465 | return create_scatterplot(cond, *args, bs=bs)
466 |
467 | if keep_frac is None:
468 | keep_frac = 0.2
469 |
470 | if basis is None:
471 | basis = np.ravel(sorted(filter(len, map(BS_PAT.findall, adata.obsm.keys()))))
472 | elif isinstance(basis, str):
473 | basis = np.array([basis])
474 | elif not isinstance(basis, np.ndarray):
475 | basis = np.array(basis)
476 |
477 | if not isinstance(hover, bool):
478 | assert hover > 1, f'Expected `hover` to be `> 1` when being an integer, found: `{hover}`.'
479 |
480 | assert keep_frac >= 0 and keep_frac <= 1, f'`keep_perc` must be in interval `[0, 1]`, got `{keep_frac}`.'
481 | assert subsample in ALL_SUBSAMPLING_STRATEGIES, f'Invalid subsampling strategy `{subsample}`. Possible values are `{ALL_SUBSAMPLING_STRATEGIES}`.'
482 |
483 | if subsample == 'uniform':
484 | cb_kwargs = {'steps': steps}
485 | elif subsample == 'density':
486 | cb_kwargs = {'size': int(keep_frac * adata.n_obs), 'seed': seed}
487 | else:
488 | cb_kwargs = {}
489 | alazy = SamplingLazyDict(adata, subsample, callback_kwargs=cb_kwargs)
490 |
491 | if obs_keys is None:
492 | obs_keys = skip_or_filter(adata, adata.obs.keys(), adata.obs.keys(),
493 | dtype='category', where='obs', skip=True, warn=False)
494 | else:
495 | if not iterable(obs_keys):
496 | obs_keys = [obs_keys]
497 |
498 | obs_keys = skip_or_filter(adata, obs_keys, adata.obs.keys(),
499 | dtype='category', where='obs', skip=skip)
500 |
501 | if obsm_keys is None:
502 | obsm_keys = get_all_obsm_keys(adata, default_obsm_ixs)
503 | obsm_keys = skip_or_filter(adata, obsm_keys, adata.obsm.keys(), where='obsm',
504 | dtype='category', skip=True, warn=False, ignore_after=OBSM_SEP)
505 | else:
506 | if not iterable(obsm_keys):
507 | obsm_keys = [obsm_keys]
508 |
509 | ignore_after = OBSM_SEP if any((OBSM_SEP in obs_key for obs_key in obsm_keys)) else None
510 | obsm_keys = skip_or_filter(adata, obsm_keys, adata.obsm.keys(), where='obsm',
511 | dtype='category', skip=skip, ignore_after=ignore_after)
512 |
513 | if sort:
514 | if any(obs_keys[i] > obs_keys[i + 1] for i in range(len(obs_keys) - 1)):
515 | obs_keys = sorted(obs_keys)
516 | if any(obsm_keys[i] > obsm_keys[i + 1] for i in range(len(obsm_keys) - 1)):
517 | obsm_keys = sorted(obsm_keys)
518 |
519 | conditions = obs_keys + obsm_keys
520 |
521 | if len(conditions) == 0:
522 | warnings.warn('Nothing to plot, no conditions found.')
523 | return
524 |
525 | if not isinstance(components, np.ndarray):
526 | components = np.array(components)
527 | if components.ndim == 1:
528 | components = np.repeat(components[np.newaxis, :], len(basis), axis=0)
529 |
530 | assert components.ndim == 2, f'Only `2` dimensional components are supported, got `{components.ndim}`.'
531 | assert components.shape[-1] == 2, f'Components\' second dimension must be of size `2`, got `{components.shape[-1]}`.'
532 |
533 | assert components.shape[0] == len(basis), f'Expected #components == `{len(basis)}`, got `{components.shape[0]}`.'
534 | assert np.all(components >= 0), f'Currently, only positive indices are supported, found `{list(map(list, components))}`.'
535 |
536 | diffmap_ix = np.where(basis != 'diffmap')[0]
537 | components[diffmap_ix, :] -= 1
538 |
539 | for bs, comp in zip(basis, components):
540 | shape = adata.obsm[f'X_{bs}'].shape
541 | assert f'X_{bs}' in adata.obsm.keys(), f'`X_{bs}` not found in `adata.obsm`'
542 | assert shape[-1] > np.max(comp), f'Requested invalid components `{list(comp)}` for basis `X_{bs}` with shape `{shape}`.'
543 |
544 | if adata.n_obs > SUBSAMPLE_THRESH and subsample in NO_SUBSAMPLE:
545 | warnings.warn(f'Number of cells `{adata.n_obs}` > `{SUBSAMPLE_THRESH}`. Consider specifying `subsample={SUBSAMPLING_STRATEGIES}`.')
546 |
547 | if len(conditions) > HOLOMAP_THRESH and not lazy_loading:
548 | warnings.warn(f'Number of conditions `{len(conditions)}` > `{HOLOMAP_THRESH}`. Consider specifying `lazy_loading=True`.')
549 |
550 | if cmap is None:
551 | cmap = Sets1to3
552 |
553 | lims = dict(x=dict(), y=dict())
554 | for bs in basis:
555 | emb = adata.obsm[f'X_{bs}']
556 | is_diffmap = bs == 'diffmap'
557 | if is_diffmap:
558 | emb = (emb * 1000).copy()
559 | lims['x'][bs] = minmax(emb[:, 0 + is_diffmap])
560 | lims['y'][bs] = minmax(emb[:, 1 + is_diffmap])
561 |
562 | kdims = [hv.Dimension('Basis', values=basis),
563 | hv.Dimension('Condition', values=conditions)]
564 |
565 | cmaps = dict()
566 | for cond in conditions:
567 | color_key = f'{cond}_colors'
568 | # use the datashader default cmap since setting it doesn't work (for multiple conditions)
569 | cmaps[cond] = odict(zip(adata.obs[cond].cat.categories, # adata.uns.get(color_key, cmap)))
570 | cmap if subsample == 'datashade' else adata.uns.get(color_key, cmap)))
571 | # this approach (for datashader) does not really work - the legend gets mixed up
572 | # cmap = dict(ChainMap(*[c.copy() for c in cmaps.values()]))
573 | # if len(cmap.keys()) != len([k for c in conditions for k in cmaps[c].keys()]):
574 | # warnings.warn('Found same key across multiple conditions. The colormap/legend may not accurately display the colors.')
575 |
576 | if not lazy_loading:
577 | # have to wrap because of the *args
578 | dynmaps = [hv.HoloMap({(c, b):create_scatterplot(c, bs=b) for c in conditions for b in basis}, kdims=kdims[::-1])]
579 | else:
580 | for bs, comp in zip(basis, components):
581 | kdims.append(hv.Dimension(f'{bs.upper()}[X]',
582 | type=int, default=1, step=1,
583 | range=(1, adata.obsm[f'X_{bs}'].shape[-1])))
584 | kdims.append(hv.Dimension(f'{bs.upper()}[Y]',
585 | type=int, default=2, step=1,
586 | range=(1, adata.obsm[f'X_{bs}'].shape[-1])))
587 |
588 | if cols is None:
589 | dynmaps = [hv.DynamicMap(_cs, kdims=kdims)]
590 | else:
591 | dynmaps = [hv.DynamicMap(partial(create_scatterplot, bs=bs), kdims=kdims[1:]) for bs in basis]
592 |
593 | legend = None
594 | if subsample == 'datashade':
595 | subsampled = [dynspread(datashade(d, aggregator=ds.count_cat('cond'), color_key=cmap,
596 | streams=[hv.streams.RangeXY(transient=True), hv.streams.PlotSize],
597 | min_alpha=255).opts(axiswise=True, framewise=True), threshold=0.8, max_px=5)
598 | for d in dynmaps]
599 | dynmaps = add_hover(subsampled, dynmaps) if hover else subsampled
600 |
601 | if show_legend:
602 | warnings.warn('Automatic adjustment of axes is currently not working when '
603 | '`show_legend=True` and `subsample=\'datashade\'`.')
604 | legend = hv.DynamicMap(create_legend, kdims=kdims[:2][::-1])
605 | elif hover:
606 | warnings.warn('Automatic adjustment of axes is currently not working when hovering is enabled.')
607 |
608 | elif subsample == 'decimate':
609 | subsampled = [decimate(d, max_samples=int(adata.n_obs * keep_frac),
610 | streams=[hv.streams.RangeXY(transient=True)], random_seed=seed) for d in dynmaps]
611 | dynmaps = add_hover(subsampled, by_block=False) if hover else subsampled
612 | elif hover:
613 | dynmaps = add_hover(dynmaps, by_block=False)
614 |
615 | if cols is None:
616 | dynmap = dynmaps[0].opts(title='', frame_height=plot_height, frame_width=plot_width, axiswise=True, framewise=True)
617 | if legend is not None:
618 | dynmap = (dynmap * legend).opts(legend_position=legend_loc)
619 | else:
620 | if legend is not None:
621 | dynmaps = [(d * l).opts(legend_position=legend_loc)
622 | for d, l in zip(dynmaps, legend.layout('bs') if lazy_loading and show_legend else [legend] * len(dynmaps))]
623 |
624 | dynmap = hv.Layout([d.opts(axiswise=True, framewise=True,
625 | frame_height=plot_height, frame_width=plot_width) for d in dynmaps]).cols(cols)
626 |
627 | plot = dynmap.cols(cols).opts(title='', height=plot_height, width=plot_width) if cols is not None else dynmap
628 | if save is not None:
629 | hv.renderer('bokeh').save(plot, save)
630 |
631 | return plot
632 |
633 |
634 | @wrap_as_col
635 | def dpt(adata, key, genes=None, basis=None, components=[1, 2],
636 | subsample='datashade', steps=40, use_raw=False, keep_frac=None,
637 | sort=True, skip=True, seed=None, show_legend=True, root_cell_all=False,
638 | root_cell_hl=True, root_cell_bbox=True, root_cell_size=None, root_cell_color='orange',
639 | legend_loc='top_right', size=4, perc=None, show_perc=True, cat_cmap=None, cont_cmap=None,
640 | plot_height=400, plot_width=400, *args, **kwargs):
641 | '''
642 | Scatter plot for categorical observations.
643 |
644 | Params
645 | --------
646 | adata: anndata.Anndata
647 | anndata object
648 | key: Str
649 | key in `adata.obs`, `adata.obsm` or `adata.var_names`
650 | to be visualized in top right plot
651 | can be categorical or continuous
652 | genes: List[Str], optional (default: `None`)
653 | list of genes to add for visualization
654 | if `None`, use `adata.var_names`
655 | basis: Union[Str, List[Str]], optional (default: `None`)
656 | basis in `adata.obsm`, if `None`, get all available
657 | components: Union[List[Int], List[List[Int]]], optional (default: `[1, 2]`)
658 | components of specified `basis`
659 | if it's of type `List[Int]`, all the basis have use the same components
660 | use_raw: Bool, optional (default: `False`)
661 | use `adata.raw` for gene expression levels
662 | subsample: Str, optional (default: `'datashade'`)
663 | subsampling strategy for large data
664 | possible values are `None, 'none', 'datashade', 'decimate', 'density', 'uniform'`
665 | using `subsample='datashade'` is preferred over other options since it does not subset
666 | when using `subsample='datashade'`, colorbar is not visible
667 | `'density'` and `'uniform'` use first element of `basis` for their computation
668 | steps: Union[Int, Tuple[Int, Int]], optional (default: `40`)
669 | step size when the embedding directions
670 | larger step size corresponds to higher density of points
671 | keep_frac: Float, optional (default: `adata.n_obs / 5`)
672 | number of observations to keep when `subsample='decimate'`
673 | sort: Bool, optional (default: `True`)
674 | whether sort the `genes`, `obs_keys` and `obsm_keys`
675 | in ascending order
676 | skip: Bool, optional (default: `True`)
677 | skip all the keys not found in the corresponding collections
678 | seed: Int, optional (default: `None`)
679 | random seed, used when `subsample='decimate'``
680 | show_legend: Bool, optional (default: `True`)
681 | whether to show legend
682 | legend_loc: Str, optional (default: `top_right`)
683 | position of the legend
684 | cols: Int, optional (default: `None`)
685 | number of columns when plotting basis
686 | if `None`, use togglebar
687 | size: Int, optional (default: `4`)
688 | size of the glyphs
689 | works only when `subsample!='datashade'`
690 | perc: List[Float], optional (default: `None`)
691 | percentile for colors when `key` refers to continous observation
692 | works only when `subsample != 'datashade'`
693 | show_perc: Bool, optional (default: `True`)
694 | show percentile slider when `key` refers to continous observation
695 | works only when `subsample != 'datashade'`
696 | cat_cmap: List[Str], optional (default: `datashader.colors.Sets1to3`)
697 | categorical colormap in hex format
698 | used when `key` is categorical variable
699 | cont_cmap: List[Str], optional (default: `bokeh.palettes.Viridis256`)
700 | continuous colormap in hex format
701 | used when `key` is continuous variable
702 | root_cell_all: Bool, optional (default: `False`)
703 | show all root cells, even though they might not be in the embedding
704 | (e.g. when subsample='uniform' or 'density')
705 | otherwise only show in the embedding (based on the data of the 1st `basis`)
706 | root_cell_hl: Bool, optional (default: `True`)
707 | highlight the root cell
708 | root_cell_bbox: Bool, optional (default: `True`)
709 | show bounding box around the root cell
710 | root_cell_size: Int, optional (default `None`)
711 | size of the root cell, if `None`, it's `size * 2`
712 | root_cell_color: Str, optional (default: `red`)
713 | color of the root cell, can be in hex format
714 | plot_height: Int, optional (default: `400`)
715 | height of the plot in pixels
716 | plot_width: Int, optional (default: `400`)
717 | width of the plot in pixels
718 | *args, **kwargs:
719 | additional arguments for `sc.tl.dpt`
720 |
721 | Returns
722 | --------
723 | plot: panel.Column
724 | holoviews plot wrapped in `panel.Column`
725 | '''
726 |
727 | def create_scatterplot(root_cell, gene, bs, perc_low, perc_high, *args, typp='expr', ret_hl=False):
728 | ixs = np.where(basis == bs)[0][0]
729 | is_diffmap = bs == 'diffmap'
730 |
731 | if len(args) > 0:
732 | ixs = np.where(basis == bs)[0][0] * 2
733 | comp = (np.array([args[ixs], args[ixs + 1]]) - (not is_diffmap)) % adata.obsm[f'X_{bs}'].shape[-1]
734 | else:
735 | comp = np.array(components[ixs]) # need to make a copy
736 |
737 | ad, _ = alazy[bs, tuple(comp)]
738 | ad_mraw = ad.raw if use_raw else ad
739 |
740 | if perc_low is not None and perc_high is not None:
741 | if perc_low > perc_high:
742 | perc_low, perc_high = perc_high, perc_low
743 | perc = [perc_low, perc_high]
744 | else:
745 | perc = None
746 |
747 | # because diffmap has small range, it iterferes with
748 | # the legend created
749 | emb = ad.obsm[f'X_{bs}'][:, comp] * (1000 if is_diffmap else 1)
750 | comp += not is_diffmap # naming consistence
751 |
752 | bsu = bs.upper()
753 | x = hv.Dimension('x', label=f'{bsu}{comp[0]}')
754 | y = hv.Dimension('y', label=f'{bsu}{comp[1]}')
755 | xmin, xmax = minmax(emb[:, 0])
756 | ymin, ymax = minmax(emb[:, 1])
757 |
758 | # adata is the original, ad may be subsampled
759 | mask = np.in1d(adata.obs_names, ad.obs_names)
760 |
761 | if typp == 'emb_discrete':
762 | scatter = hv.Scatter({'x': emb[:, 0], 'y': emb[:, 1], 'condition': data[mask]},
763 | kdims=[x, y], vdims='condition').sort('condition')
764 |
765 | scatter = scatter.opts(title=key,
766 | color='condition',
767 | xlim=(xmin, xmax),
768 | ylim=(ymin, ymax),
769 | size=size,
770 | xlabel=f'{bsu}{comp[0]}',
771 | ylabel=f'{bsu}{comp[1]}')
772 |
773 | if is_cat:
774 | # we're manually creating legend (for datashade)
775 | return scatter.opts(cmap=cat_cmap, show_legend=False)
776 |
777 | return scatter.opts(colorbar=True, colorbar_opts={'width': CBW},
778 | cmap=cont_cmap, clim=minmax(data, perc=perc))
779 |
780 |
781 | if typp == 'root_cell_hl':
782 | # find the index of the root cell in maybe subsampled data
783 | rid = np.where(ad.obs_names == root_cell)[0]
784 | if not len(rid):
785 | return hv.Scatter([]).opts(axiswise=True, framewise=True)
786 |
787 | rid = rid[0]
788 | dx, dy = (xmax - xmin) / 25, (ymax - ymin) / 25
789 | rx, ry = emb[rid, 0], emb[rid, 1]
790 |
791 | root_cell_scatter = hv.Scatter({'x': emb[rid, 0], 'y': emb[rid, 1]}).opts(color=root_cell_color, size=root_cell_size)
792 | if root_cell_bbox:
793 | root_cell_scatter *= hv.Bounds((rx - dx, ry - dy, rx + dx, ry + dy)).opts(line_width=4, color=root_cell_color).opts(axiswise=True, framewise=True)
794 |
795 | return root_cell_scatter
796 |
797 |
798 | adata.uns['iroot'] = np.where(adata.obs_names == root_cell)[0][0]
799 | dpt_fn(adata, *args, **kwargs)
800 |
801 | pseudotime = adata.obs['dpt_pseudotime'].values
802 | pseudotime = pseudotime[mask]
803 | pseudotime[pseudotime == np.inf] = 1
804 | pseudotime[pseudotime == -np.inf] = 0
805 |
806 | if typp == 'emb':
807 |
808 | scatter = hv.Scatter({'x': emb[:, 0], 'y': emb[:, 1], 'pseudotime': pseudotime},
809 | kdims=[x, y], vdims='pseudotime')
810 |
811 | return scatter.opts(title='Pseudotime',
812 | cmap=cont_cmap, color='pseudotime',
813 | colorbar=True,
814 | colorbar_opts={'width': CBW},
815 | size=size,
816 | clim=minmax(pseudotime, perc=perc),
817 | xlim=(xmin, xmax),
818 | ylim=(ymin, ymax),
819 | xlabel=f'{bsu}{comp[0]}',
820 | ylabel=f'{bsu}{comp[1]}')# if not ret_hl else root_cell_scatter
821 |
822 | if typp == 'expr':
823 | expr = ad_mraw.obs_vector(gene)
824 |
825 | x = hv.Dimension('x', label='pseudotime')
826 | y = hv.Dimension('y', label='expression')
827 | # data is in outer scope
828 | scatter_expr = hv.Scatter({'x': pseudotime, 'y': expr, 'condition': data[mask]},
829 | kdims=[x, y], vdims='condition')
830 |
831 | scatter_expr = scatter_expr.opts(title=key,
832 | color='condition',
833 | size=size,
834 | xlim=minmax(pseudotime),
835 | ylim=minmax(expr))
836 | if is_cat:
837 | # we're manually creating legend (for datashade)
838 | return scatter_expr.opts(cmap=cat_cmap, show_legend=False)
839 |
840 | return scatter_expr.opts(colorbar=True, colorbar_opts={'width': CBW},
841 | cmap=cont_cmap, clim=minmax(data, perc=perc))
842 |
843 | if typp == 'hist':
844 | return hv.Histogram(np.histogram(pseudotime, bins=20)).opts(xlabel='pseudotime',
845 | ylabel='frequency',
846 | color='#f2f2f2')
847 |
848 | raise RuntimeError(f'Unknown type `{typp}` for `create_scatterplot`.')
849 |
850 | # we copy beforehand
851 | if kwargs.pop('copy', False):
852 | adata = adata.copy()
853 |
854 | if keep_frac is None:
855 | keep_frac = 0.2
856 |
857 | if root_cell_size is None:
858 | root_cell_size = size * 2
859 |
860 | if basis is None:
861 | basis = np.ravel(sorted(filter(len, map(BS_PAT.findall, adata.obsm.keys()))))
862 | elif isinstance(basis, str):
863 | basis = np.array([basis])
864 | elif not isinstance(basis, np.ndarray):
865 | basis = np.array(basis)
866 |
867 | if perc is None:
868 | perc = [None, None]
869 | assert len(perc) == 2, f'Percentile must be of length 2, found `{len(perc)}`.'
870 | if all(map(lambda p: p is not None, perc)):
871 | perc = sorted(perc)
872 |
873 | assert keep_frac >= 0 and keep_frac <= 1, f'`keep_perc` must be in interval `[0, 1]`, got `{keep_frac}`.'
874 | assert subsample in ALL_SUBSAMPLING_STRATEGIES, f'Invalid subsampling strategy `{subsample}`. Possible values are `{ALL_SUBSAMPLING_STRATEGIES}`.'
875 |
876 | if subsample == 'uniform':
877 | cb_kwargs = {'steps': steps}
878 | elif subsample == 'density':
879 | cb_kwargs = {'size': int(keep_frac * adata.n_obs), 'seed': seed}
880 | else:
881 | cb_kwargs = {}
882 | alazy = SamplingLazyDict(adata, subsample, callback_kwargs=cb_kwargs)
883 | adata_mraw = adata.raw if use_raw else adata
884 |
885 | if genes is None:
886 | genes = adata_mraw.var_names
887 | elif not iterable(genes):
888 | genes = [genes]
889 | genes = skip_or_filter(adata_mraw, genes, adata_mraw.var_names, where='adata.var_names', skip=skip)
890 |
891 | if sort:
892 | if any(genes[i] > genes[i + 1] for i in range(len(genes) - 1)):
893 | genes = sorted(genes)
894 |
895 | if len(genes) == 0:
896 | warnings.warn(f'No genes found. Consider speciying `skip=False`.')
897 | return
898 |
899 | if not isinstance(components, np.ndarray):
900 | components = np.array(components)
901 | if components.ndim == 1:
902 | components = np.repeat(components[np.newaxis, :], len(basis), axis=0)
903 |
904 | assert components.ndim == 2, f'Only `2` dimensional components are supported, got `{components.ndim}`.'
905 | assert components.shape[-1] == 2, f'Components\' second dimension must be of size `2`, got `{components.shape[-1]}`.'
906 |
907 | if not isinstance(basis, np.ndarray):
908 | basis = np.array(basis)
909 |
910 | assert components.shape[0] == len(basis), f'Expected #components == `{len(basis)}`, got `{components.shape[0]}`.'
911 | assert np.all(components >= 0), f'Currently, only positive indices are supported, found `{list(map(list, components))}`.'
912 |
913 | diffmap_ix = np.where(basis != 'diffmap')[0]
914 | components[diffmap_ix, :] -= 1
915 |
916 | for bs, comp in zip(basis, components):
917 | shape = adata.obsm[f'X_{bs}'].shape
918 | assert f'X_{bs}' in adata.obsm.keys(), f'`X_{bs}` not found in `adata.obsm`'
919 | assert shape[-1] > np.max(comp), f'Requested invalid components `{list(comp)}` for basis `X_{bs}` with shape `{shape}`.'
920 |
921 | if adata.n_obs > SUBSAMPLE_THRESH and subsample in NO_SUBSAMPLE:
922 | warnings.warn(f'Number of cells `{adata.n_obs}` > `{SUBSAMPLE_THRESH}`. Consider specifying `subsample={SUBSAMPLING_STRATEGIES}`.')
923 |
924 | if cat_cmap is None:
925 | cat_cmap = Sets1to3
926 |
927 | if cont_cmap is None:
928 | cont_cmap = Viridis256
929 |
930 | kdims = [hv.Dimension('Root cell', values=(adata if root_cell_all else alazy[basis[0], tuple(components[0])][0]).obs_names),
931 | hv.Dimension('Gene', values=genes),
932 | hv.Dimension('Basis', values=basis)]
933 | cs = lambda cell, gene, bs, *args, **kwargs: create_scatterplot(cell, gene, bs, perc[0], perc[1], *args, **kwargs)
934 |
935 | data, is_cat = get_data(adata, key)
936 | if is_cat:
937 | data = pd.Categorical(data)
938 | aggregator = ds.count_cat
939 | cmap = cat_cmap
940 | legend = hv.NdOverlay({c: hv.Points([0, 0], label=str(c)).opts(size=0, color=color)
941 | for c, color in zip(data.categories, cat_cmap)})
942 | else:
943 | data = np.array(data, dtype=np.float64)
944 | aggregator = ds.mean
945 | cmap = cont_cmap
946 | legend = None
947 | if show_perc and subsample != 'datashade':
948 | kdims += [
949 | hv.Dimension('Percentile (lower)', range=(0, 100), step=0.1, type=float, default=0 if perc[0] is None else perc[0]),
950 | hv.Dimension('Percentile (upper)', range=(0, 100), step=0.1, type=float, default=100 if perc[1] is None else perc[1])
951 | ]
952 | cs = create_scatterplot
953 |
954 | emb = hv.DynamicMap(partial(cs, typp='emb'), kdims=kdims)
955 | if root_cell_hl:
956 | root_cell = hv.DynamicMap(partial(cs, typp='root_cell_hl'), kdims=kdims)
957 | emb_d = hv.DynamicMap(partial(cs, typp='emb_discrete'), kdims=kdims)
958 | expr = hv.DynamicMap(partial(cs, typp='expr'), kdims=kdims)
959 | hist = hv.DynamicMap(partial(cs, typp='hist'), kdims=kdims)
960 |
961 | if subsample == 'datashade':
962 | emb = dynspread(datashade(emb, aggregator=ds.mean('pseudotime'), cmap=cont_cmap,
963 | streams=[hv.streams.RangeXY(transient=True), hv.streams.PlotSize],
964 | min_alpha=255),
965 | threshold=0.8, max_px=5)
966 | emb_d = dynspread(datashade(emb_d, aggregator=aggregator('condition'), cmap=cmap,
967 | streams=[hv.streams.RangeXY(transient=True), hv.streams.PlotSize],
968 | min_alpha=255),
969 | threshold=0.8, max_px=5)
970 | expr = dynspread(datashade(expr, aggregator=aggregator('condition'), cmap=cmap,
971 | streams=[hv.streams.RangeXY(transient=True), hv.streams.PlotSize],
972 | min_alpha=255),
973 | threshold=0.8, max_px=5)
974 | elif subsample == 'decimate':
975 | emb, emb_d, expr = (decimate(d, max_samples=int(adata.n_obs * keep_frac)) for d in (emb, emb_d, expr))
976 |
977 | if root_cell_hl:
978 | emb *= root_cell # emb * root_cell.opts(axiswise=True, framewise=True)
979 |
980 | emb = emb.opts(axiswise=False, framewise=True, frame_height=plot_height, frame_width=plot_width)
981 | expr = expr.opts(axiswise=True, framewise=True, frame_height=plot_height, frame_width=plot_width)
982 | emb_d = emb_d.opts(axiswise=True, framewise=True, frame_height=plot_height, frame_width=plot_width)
983 | hist = hist.opts(axiswise=True, framewise=True, frame_height=plot_height, frame_width=plot_width)
984 |
985 | if show_legend and legend is not None:
986 | emb_d = (emb_d * legend).opts(legend_position=legend_loc, show_legend=True)
987 |
988 | return ((emb + emb_d) + (hist + expr).opts(axiswise=True, framewise=True)).cols(2)
989 |
990 |
991 | @wrap_as_col
992 | def graph(adata, key, basis=None, components=[1, 2], obs_keys=[], color_key=None, color_key_reduction=np.sum,
993 | ixs=None, top_n_edges=None, filter_edges=None, directed=True, bundle=False, bundle_kwargs={},
994 | subsample=None, layouts=None, layout_kwargs={}, force_paga_indices=False,
995 | degree_by=None, legend_loc='top_right', node_size=12, edge_width=2, arrowhead_length=None,
996 | perc=None, color_edges_by='weight', hover_selection='nodes',
997 | node_cmap=None, edge_cmap=None, plot_height=600, plot_width=600):
998 | '''
999 | Params
1000 | --------
1001 |
1002 | adata: anndata.Anndata
1003 | anndata object
1004 | key: Str
1005 | key in `adata.uns`, `adata.uns[\'paga\'] or adata.uns[\'neighbors\']` which
1006 | represents the graph as an adjacency matrix (can be sparse)
1007 | use `'paga'` to access PAGA connectivies graph or (prefix `'p:...'`)
1008 | to access `adata.uns['paga'][...]`
1009 | for `adata.uns['neighbors'][...]`, use prefix `'n:...'`
1010 | basis: Union[Str, List[Str]], optional (default: `None`)
1011 | basis in `adata.obsm`, if `None`, get all of them
1012 | components: Union[List[Int], List[List[Int]]], optional (default: `[1, 2]`)
1013 | components of specified `basis`
1014 | if it's of type `List[Int]`, all the basis have use the same components
1015 | color_key: Str, optional (default: `None`)
1016 | variable in `adata.obs` with which to color in each node
1017 | or `'incoming'`, `'outgoing'` for coloring values based on weights
1018 | color_key_reduction: Callable, optional (default: `np.sum`)
1019 | a numpy function, such as `np.mean`, `np.max`, ... when
1020 | `color_key` is `'incoming'` or `'outgoing'`
1021 | obs_keys: List[Str], optional (default: `None`)
1022 | keys of categorical observations in `adata.obs`
1023 | if `None`, get all available, only visible when `hover_selection='nodes'`
1024 | ixs: List[Int], optional (default: `None`)
1025 | list of indices of nodes of graph to visualize
1026 | if `None`, visualize all
1027 | top_n_edges: Union[Int, Tuple[Int, Bool, Str]], optional (default: `None`)
1028 | only for directed graph
1029 | maximum number of outgoing edges per node to keep based on decreasing weight
1030 | if a tuple, the second element specifies whether it's ascending or not
1031 | the third one whether whether to consider outgoing ('out') or ('in') incoming edges
1032 | filter_edges: Tuple[Float, Float], optional (default: `None`)
1033 | min and max threshold values for edge visualization
1034 | nodes without edges will *NOT* be removed
1035 | directed: Bool, optional (default: `True`)
1036 | whether the graph is directed or not
1037 | subsample: Str, optional (default: `None`)
1038 | subsampling strategies for edges
1039 | possible values are `None, \'none\', \'datashade\'`
1040 | bundle: Bool, optional (default: `False`)
1041 | whether to bundle edges together (can be computationally expensive)
1042 | bundle_kwargs: Dict, optional (defaul: `{}`)
1043 | kwargs for bundler, e.g. `iterations=1` (default `4`)
1044 | for more options, see `hv.operation.datashader.bundle_graph`
1045 | layouts: List[Str], optional (default: `None`)
1046 | layout names to use when drawing graph, e.g. `'umap'` in `adata.obsm`
1047 | or `'kamada_kawai'` from `nx.layouts`
1048 | if `None`, use all available layouts
1049 | layout_kwargs: Dict[Str, Dict], optional (default: `{}`)
1050 | kwargs for a given layout
1051 | force_paga_indices: Bool, optional (default: `False`)
1052 | by default, when `key='paga'`, all indices are used
1053 | regardless of what was specified
1054 | degree_by: Str, optional (default: `None`)
1055 | if `'weights'`, use edge weights when calculating the degree
1056 | only visible when `hover_selection='nodes'`
1057 | legend_loc: Str, optional (default: `'top_right'`)
1058 | locations of the legend, if `None`, do not show legend
1059 | node_size: Float, optional (default: `12`)
1060 | size of the graph nodes
1061 | edge_width: Float, optional (default: `2`)
1062 | width of the graph edges
1063 | arrowhead_length: Float, optional (default: `None`)
1064 | length of the arrow when `directed=True`
1065 | perc: List[Float], optional (default: `None`)
1066 | percentile for edge colors
1067 | *WARNING* this can remove nodes and will be fixed in the future
1068 | color_edges_by: Str, optional (default: `weight`)
1069 | whether to color edges, if `None` do not color edges
1070 | hover_selection: Str, optional (default: `'nodes'`)
1071 | whether to define hover over `'nodes'` or `'edges'`
1072 | if `subsample == 'datashade'`, it is always `'nodes'`
1073 | node_cmap: List[Str], optional (default: `datashader.colors.Sets1to3`)
1074 | colormap in hex format for `color_key`
1075 | edge_cmap: List[Str], optional (default: `bokeh.palettes.Viridis256`)
1076 | continuous colormap in hex format for edges
1077 | plot_height: Int, optional (default: `600`)
1078 | height of the plot in pixels
1079 | plot_width: Int, optional (default: `600`)
1080 | width of the plot in pixels
1081 |
1082 | Returns
1083 | --------
1084 | plot: panel.Column
1085 | `hv.DynamicMap` wrapped in `panel.Column` that displays the graph in various layouts
1086 | '''
1087 |
1088 | def normalize(emb):
1089 | # TODO: to this once
1090 | # normalize because of arrows...
1091 | emb = emb.copy()
1092 | x_min, y_min = np.min(emb[:, 0]), np.min(emb[:, 1])
1093 | emb[:, 0] = (emb[:, 0] - x_min) / (np.max(emb[:, 0]) - x_min)
1094 | emb[:, 1] = (emb[:, 1] - y_min) / (np.max(emb[:, 1]) - y_min)
1095 | return emb
1096 |
1097 | def create_graph(adata, data):
1098 | if perc is not None:
1099 | data = percentile(data, perc)
1100 | create_using = nx.DiGraph if directed else nx.Graph
1101 | g = (nx.from_scipy_sparse_matrix if issparse(data) else nx.from_numpy_array)(data, create_using=create_using)
1102 |
1103 | if filter_edges is not None:
1104 | minn, maxx = filter_edges
1105 | minn = minn if minn is not None else -np.inf
1106 | maxx = maxx if maxx is not None else np.inf
1107 | for e, attr in list(g.edges.items()):
1108 | if attr['weight'] < minn or attr['weight'] > maxx:
1109 | g.remove_edge(*e)
1110 |
1111 | to_keep = None
1112 | if top_n_edges is not None:
1113 | if isinstance(top_n_edges, (tuple, list)):
1114 | to_keep, ascending, group_by = top_n_edges
1115 | else:
1116 | to_keep, ascending, group_by = top_n_edges, False, 'out'
1117 |
1118 | source, target = zip(*g.edges)
1119 | weights = [v['weight'] for v in g.edges.values()]
1120 | tmp = pd.DataFrame({'out': source, 'in': target, 'w': weights})
1121 |
1122 | to_keep = set(map(tuple, tmp.groupby(group_by).apply(lambda g: g.sort_values('w', ascending=ascending).take(range(min(to_keep, len(g)))))[['out', 'in']].values))
1123 |
1124 | for e in list(g.edges):
1125 | if e not in to_keep:
1126 | g.remove_edge(*e)
1127 |
1128 | if not len(g.nodes):
1129 | raise RuntimeError('Empty graph.')
1130 |
1131 | if not len(g.edges):
1132 | msg = 'No edges to visualize.'
1133 | if filter_edges is not None:
1134 | msg += f' Consider altering the edge filtering thresholds `{filter_edges}`.'
1135 | if top_n_edges is not None:
1136 | msg += f' Perhaps use more top edges than `{to_keep}`.'
1137 | raise RuntimeError(msg)
1138 |
1139 | if hover_selection == 'nodes':
1140 | if directed:
1141 | nx.set_node_attributes(g, values=dict(g.in_degree(weight=degree_by)),
1142 | name='indegree')
1143 | nx.set_node_attributes(g, values=dict(g.out_degree(weight=degree_by)),
1144 | name='outdegree')
1145 | nx.set_node_attributes(g, values=nx.in_degree_centrality(g),
1146 | name='indegree centrality')
1147 | nx.set_node_attributes(g, values=nx.out_degree_centrality(g),
1148 | name='outdegree centrality')
1149 | else:
1150 | nx.set_node_attributes(g, values=dict(g.degree(weight=degree_by)),
1151 | name='degree')
1152 | nx.set_node_attributes(g, values=nx.degree_centrality(g),
1153 | name='centrality')
1154 |
1155 | if not is_paga:
1156 | nx.set_node_attributes(g, values=dict(zip(g.nodes.keys(), adata.obs.index)),
1157 | name='name')
1158 | for key in list(obs_keys):
1159 | nx.set_node_attributes(g, values=dict(zip(g.nodes.keys(), adata.obs[key])),
1160 | name=key)
1161 | if color_key is not None:
1162 | # color_vals has been set beforehand
1163 | nx.set_node_attributes(g, values=dict(zip(g.nodes.keys(), adata.obs[color_key] if color_key in adata.obs.keys() else color_vals)),
1164 | name=color_key)
1165 |
1166 | else:
1167 | nx.set_node_attributes(g, values=dict(zip(g.nodes.keys(), adata.obs[color_key].cat.categories)),
1168 | name=color_key)
1169 |
1170 | return g
1171 |
1172 | def embed_graph(layout_key, graph):
1173 | bs_key = f'X_{layout_key}'
1174 | if bs_key in adata.obsm.keys():
1175 | emb = adata_ss.obsm[bs_key][:, get_component[layout_key]]
1176 | emb = normalize(emb)
1177 | layout = dict(zip(graph.nodes.keys(), emb))
1178 | l_kwargs = {}
1179 | elif layout_key == 'paga':
1180 | layout = dict(zip(graph.nodes.keys(), paga_pos))
1181 | l_kwargs = {}
1182 | elif layout_key in DEFAULT_LAYOUTS:
1183 | layout = DEFAULT_LAYOUTS[layout_key]
1184 | l_kwargs = layout_kwargs.get(layout_key, {})
1185 |
1186 | g = hv.Graph.from_networkx(graph, positions=layout, **l_kwargs)
1187 | g = g.opts(inspection_policy='nodes' if subsample == 'datashade' else hover_selection,
1188 | tools=['hover', 'box_select'],
1189 | edge_color=hv.dim(color_edges_by) if color_edges_by is not None else None,
1190 | edge_line_width=edge_width * (hv.dim('weight') if is_paga else 1),
1191 | edge_cmap=edge_cmap,
1192 | node_color=color_key,
1193 | node_cmap=node_cmap,
1194 | directed=directed,
1195 | colorbar=True,
1196 | show_legend=legend_loc is not None
1197 | )
1198 |
1199 | return g if arrowhead_length is None else g.opts(arrowhead_length=arrowhead_length)
1200 |
1201 | def get_nodes(layout_key): # DRY DRY DRY
1202 | nodes = bundled[layout_key].nodes
1203 | bs_key = f'X_{layout_key}'
1204 |
1205 | if bs_key in adata.obsm.keys():
1206 | emb = adata_ss.obsm[bs_key][:, get_component[layout_key]]
1207 | emb = normalize(emb)
1208 | xlim = minmax(emb[:, 0])
1209 | ylim = minmax(emb[:, 1])
1210 | elif layout_key == 'paga':
1211 | xlim = minmax(paga_pos[:, 0])
1212 | ylim = minmax(paga_pos[:, 1])
1213 | else:
1214 | xlim, ylim = bundled[layout_key].range('x'), bundled[layout_key].range('y')
1215 |
1216 | xlim, ylim = pad(*xlim), pad(*ylim) # for datashade
1217 |
1218 | # remove axes for datashade
1219 | return nodes.opts(xlim=xlim, ylim=ylim, xaxis=None, yaxis=None, show_legend=legend_loc is not None)
1220 |
1221 | assert subsample in (None, 'none', 'datashade'), \
1222 | f'Invalid subsampling strategy `{subsample}`. Possible values are None, \'none\', \'datashade\'.`'
1223 |
1224 | if top_n_edges is not None:
1225 | assert directed, f'`n_top_edges` works only on directed graphs.`'
1226 | if isinstance(top_n_edges, (tuple, list)):
1227 | assert len(top_n_edges) == 3, f'`top_n_edges` must be of length 3, found `{len(top_n_edges)}`.'
1228 | assert isinstance(top_n_edges[0], int), f'`top_n_edges[0]` must be an int, found `{type(top_n_edges[0])}`.'
1229 | assert isinstance(top_n_edges[1], bool), f'`top_n_edges[1]` must be a bool, found `{type(top_n_edges[1])}`.'
1230 | assert top_n_edges[2] in ('in', 'out') , '`top_n_edges[2]` must be either \'in\' or \'out\'.'
1231 | else:
1232 | assert isinstance(top_n_edges, int), f'`top_n_edges` must be an int, found `{type(top_n_edges)}`.'
1233 |
1234 | if edge_cmap is None:
1235 | edge_cmap = Viridis256
1236 | if node_cmap is None:
1237 | node_cmap = Sets1to3
1238 |
1239 | if color_key is not None:
1240 | assert color_key in adata.obs or color_key in ('incoming', 'outgoing'), f'Color key `{color_key}` not found in `adata.obs` and is not \'incoming\' or \'outgoing\'.'
1241 |
1242 | if obs_keys is None:
1243 | obs_keys = adata.obs.keys()
1244 | else:
1245 | for obs_key in obs_keys:
1246 | assert obs_key in adata.obs.keys(), f'Key `{obs_key}` not found in `adata.obs`.'
1247 |
1248 | if key.startswith('p:') or key.startswith('n:'):
1249 | which, key = key.split(':')
1250 | elif key == 'paga': # QOL
1251 | which, key = 'p', 'connectivities'
1252 | else:
1253 | which = None
1254 |
1255 | paga_pos = None
1256 | is_paga = False
1257 | if which is None:
1258 | if key in adata.uns.keys():
1259 | data = adata.uns[key]
1260 | elif hasattr(adata, 'obsp') and key in adata.obsp:
1261 | data = adata.obsp[key]
1262 | elif which == 'n' and key in adata.uns['neighbors'].keys():
1263 | data = adata.uns['neighbors'][key]
1264 | elif which == 'p' and key in adata.uns['paga'].keys():
1265 | data = adata.uns['paga'][key]
1266 | is_paga = True
1267 | directed = False
1268 | if 'pos' in adata.uns['paga'].keys():
1269 | paga_pos = adata.uns['paga']['pos']
1270 | else:
1271 | raise ValueError(f'Key `{key}` not found in `adata.uns` or '
1272 | '`adata.uns[\'neighbors\']` or `adata.uns[\'paga\']`. '
1273 | 'To visualize the graphs in `uns[\'neighbors\']` or uns[\'paga\'] '
1274 | 'prefix the key with `n:` or `p:`, respectively (e.g. `n:connectivities`).')
1275 | assert data.ndim == 2, f'Adjacency matrix must be dimension of `2`, found `{adata.ndim}`.'
1276 | assert data.shape[0] == data.shape[1], 'Adjacency matrix is not square, found shape `{data.shape}`.'
1277 |
1278 | if ixs is None or (is_paga and not force_paga_indices):
1279 | ixs = np.arange(data.shape[0])
1280 | else:
1281 | assert np.min(ixs) >= 0
1282 | assert np.max(ixs) < adata.shape[0]
1283 |
1284 | data = data[ixs, :][:, ixs]
1285 | adata_ss = adata[ixs, :] if not is_paga or (len(ixs) != data.shape[0] and force_paga_indices) else adata
1286 |
1287 | if layouts is None:
1288 | layouts = list(DEFAULT_LAYOUTS.keys())
1289 | if isinstance(layouts, str):
1290 | layouts = [layouts]
1291 | for l in layouts:
1292 | assert l in DEFAULT_LAYOUTS.keys(), f'Unknown layout `{l}`. Available layouts are `{list(DEFAULT_LAYOUTS.keys())}`.'
1293 |
1294 | if np.min(data) < 0 and 'kamada_kawai' in layouts:
1295 | warnings.warn('`kamada_kawai` layout required non-negative edges, removing it from the list of possible layouts.')
1296 | layouts.remove('kamada_kawai')
1297 |
1298 | if basis is None:
1299 | basis = np.ravel(sorted(filter(len, map(BS_PAT.findall, adata.obsm.keys()))))
1300 | elif basis is str:
1301 | basis = [basis]
1302 | if not isinstance(basis, np.ndarray):
1303 | basis = np.array(basis)
1304 |
1305 | if not isinstance(components, np.ndarray):
1306 | components = np.array(components)
1307 | if components.ndim == 1:
1308 | components = np.repeat(components[np.newaxis, :], len(basis), axis=0)
1309 | if len(basis):
1310 | components[np.where(basis != 'diffmap')] -= 1
1311 |
1312 | if is_paga:
1313 | g_name = adata.uns['paga']['groups']
1314 | if color_key is None or color_key != g_name:
1315 | warnings.warn(f'Color key `{color_key}` differs from PAGA\'s groups `{g_name}`, setting it to `{g_name}`.')
1316 | color_key = g_name
1317 | if len(basis):
1318 | warnings.warn(f'Cannot plot PAGA in the basis `{basis}`, removing them from layouts.')
1319 | basis, components = [], []
1320 |
1321 | for bs, comp in zip(basis, components):
1322 | shape = adata.obsm[f'X_{bs}'].shape
1323 | assert f'X_{bs}' in adata.obsm.keys(), f'`X_{bs}` not found in `adata.obsm`'
1324 | assert shape[-1] > np.max(comp), f'Requested invalid components `{list(comp)}` for basis `X_{bs}` with shape `{shape}`.'
1325 |
1326 | if paga_pos is not None:
1327 | basis = ['paga']
1328 | components = [0, 1]
1329 | get_component = dict(zip(basis, components))
1330 |
1331 | is_categorical = False
1332 | if color_key is not None:
1333 | node_cmap = adata_ss.uns[f'{color_key}_colors'] if f'{color_key}_colors' in adata_ss.uns else node_cmap
1334 | if color_key in adata.obs:
1335 | color_vals = adata_ss.obs[color_key]
1336 | if is_categorical_dtype(color_vals) or is_string_dtype(adata.obs[color_key]):
1337 | color_vals = color_vals.astype('category').cat.categories
1338 | is_categorical = True
1339 | node_cmap = odict(zip(color_vals, to_hex_palette(node_cmap)))
1340 | else:
1341 | color_vals = adata_ss.obs[color_key].values
1342 | else:
1343 | color_vals = np.array(color_key_reduction(data, axis=int(color_key == 'outgoing'))).flatten()
1344 |
1345 | if not is_categorical:
1346 | legend_loc = None
1347 |
1348 | layouts = np.append(basis, layouts)
1349 | if len(layouts) == 0:
1350 | warnings.warn('Nothing to plot, no layouts found.')
1351 | return
1352 |
1353 | # because of the categories
1354 | graph = create_graph(adata_ss, data=data)
1355 |
1356 | kdims = [hv.Dimension('Layout', values=layouts)]
1357 | g = hv.DynamicMap(partial(embed_graph, graph=graph), kdims=kdims).opts(axiswise=True, framewise=True) # necessary as well
1358 |
1359 | if subsample != 'datashade':
1360 | for layout_key in layouts:
1361 | bs_key = f'X_{layout_key}'
1362 | if bs_key in adata.obsm.keys():
1363 | emb = adata_ss.obsm[bs_key][:, get_component[layout_key]]
1364 | emb = normalize(emb)
1365 | xlim = minmax(emb[:, 0])
1366 | ylim = minmax(emb[:, 1])
1367 | elif layout_key == 'paga':
1368 | xlim = minmax(paga_pos[:, 0])
1369 | ylim = minmax(paga_pos[:, 1])
1370 | else:
1371 | xlim, ylim = g[layout_key].range('x'), g[layout_key].range('y')
1372 | xlim, ylim = pad(*xlim), pad(*ylim)
1373 | g[layout_key].opts(xlim=xlim, ylim=ylim) # other layouts are not normalized
1374 |
1375 | bundled = bundle_graph(g, **bundle_kwargs, weight=None) if bundle else g.clone()
1376 | nodes = hv.DynamicMap(get_nodes, kdims=kdims).opts(axiswise=True, framewise=True) # needed for datashade
1377 |
1378 | if subsample == 'datashade':
1379 | g = (datashade(bundled, normalization='linear', color_key=color_edges_by, min_alpha=128,
1380 | cmap='black' if color_edges_by is None else edge_cmap,
1381 | streams=[hv.streams.RangeXY(transient=True), hv.streams.PlotSize]))
1382 | res = (g * nodes).opts(height=plot_height, width=plot_width).opts(
1383 | hv.opts.Nodes(size=node_size, tools=['hover'], cmap=node_cmap,
1384 | fill_color='orange' if color_key is None else color_key)
1385 | )
1386 | else:
1387 | res = bundled.opts(height=plot_height, width=plot_width).opts(
1388 | hv.opts.Graph(
1389 | node_size=node_size,
1390 | node_fill_color='orange' if color_key is None else color_key,
1391 | node_nonselection_alpha=0.05,
1392 | edge_nonselection_alpha=0.05,
1393 | edge_cmap=edge_cmap,
1394 | node_cmap=node_cmap
1395 | )
1396 | )
1397 | if legend_loc is not None and color_key is not None:
1398 | res *= hv.NdOverlay({k: hv.Points([0,0], label=str(k)).opts(size=0, color=v)
1399 | for k, v in node_cmap.items()})
1400 |
1401 | if legend_loc is not None and color_key is not None:
1402 | res = res.opts(legend_position=legend_loc)
1403 |
1404 | return res.opts(hv.opts.Graph(xaxis=None, yaxis=None))
1405 |
--------------------------------------------------------------------------------
/interactive_plotting/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from interactive_plotting.utils._utils import *
2 |
--------------------------------------------------------------------------------
/interactive_plotting/utils/_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from functools import wraps
4 | from collections import Iterable
5 | from inspect import signature
6 | from sklearn.neighbors import NearestNeighbors
7 | from scipy.sparse import issparse
8 |
9 | import anndata
10 | import matplotlib.colors as colors
11 | import scanpy as sc
12 | import numpy as np
13 | import pandas as pd
14 | import networkx as nx
15 | import panel as pn
16 | import re
17 | import itertools
18 | import warnings
19 |
20 |
21 | NO_SUBSAMPLE = (None, 'none')
22 | SUBSAMPLING_STRATEGIES = ('datashade', 'decimate', 'density', 'uniform')
23 | ALL_SUBSAMPLING_STRATEGIES = NO_SUBSAMPLE + SUBSAMPLING_STRATEGIES
24 |
25 | SUBSAMPLE_THRESH = 30_000
26 | HOLOMAP_THRESH = 50
27 | OBSM_SEP = ':'
28 |
29 | CBW = 10 # colorbar width
30 | BS_PAT = re.compile('^X_(.+)')
31 |
32 | # for graph
33 | DEFAULT_LAYOUTS = {l.split('_layout')[0]:getattr(nx.layout, l)
34 | for l in dir(nx.layout) if l.endswith('_layout')}
35 | DEFAULT_LAYOUTS.pop('bipartite')
36 | DEFAULT_LAYOUTS.pop('rescale')
37 | DEFAULT_LAYOUTS.pop('spectral')
38 |
39 |
40 | class SamplingLazyDict(dict):
41 |
42 | def __init__(self, adata, subsample, *args, callback_kwargs={}, **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.adata = adata
45 | self.callback_kwargs = callback_kwargs
46 |
47 | if subsample == 'uniform':
48 | self.callback = sample_unif
49 | elif subsample == 'density':
50 | self.callback = sample_density
51 | else:
52 | ixs = list(range(adata.n_obs))
53 | self.callback = lambda *args, **kwargs: (adata, ixs)
54 |
55 | def __getitem__(self, key):
56 | if key not in self:
57 | bs, comps = key
58 | rev_comps = comps[::-1]
59 |
60 | if (bs, rev_comps) in self.keys():
61 | res, ixs = self[bs, rev_comps]
62 | else:
63 | res, ixs = self.callback(self.adata, bs=bs, components=comps, **self.callback_kwargs)
64 |
65 | self[key] = res, ixs
66 |
67 | return res, ixs
68 |
69 | return super().__getitem__(key)
70 |
71 |
72 | def to_hex_palette(palette, normalize=True):
73 | """
74 | Converts matplotlib color array to hex strings
75 | """
76 | if not isinstance(palette, np.ndarray):
77 | palette = np.array(palette)
78 |
79 | if isinstance(palette[0], str):
80 | assert all(map(colors.is_color_like, palette)), 'Not all strings are color like.'
81 | return palette
82 |
83 | if normalize:
84 | minn = np.min(palette)
85 | # normalize to [0, 1]
86 | palette = (palette - minn) / (np.max(palette) - minn)
87 |
88 | return [colors.to_hex(c) if colors.is_color_like(c) else c for c in palette]
89 |
90 |
91 | def pad(minn, maxx, padding=0.05):
92 | if minn > maxx:
93 | maxx, minn = minn, maxx
94 | return minn - padding, maxx + padding
95 |
96 |
97 | def iterable(obj):
98 | '''
99 | Checks whether the object is iterable non-string.
100 |
101 | Params
102 | --------
103 | obj: Object
104 | Python object
105 |
106 | Returns
107 | --------
108 | is_iterable: Bool
109 | whether the object is not `str` and is
110 | instance of class `Iterable`
111 | '''
112 |
113 | return not isinstance(obj, str) and isinstance(obj, Iterable)
114 |
115 |
116 | def istype(obj):
117 | '''
118 | Checks whether the object is of class `type`.
119 |
120 | Params
121 | --------
122 | obj: Union[Object, Tuple]
123 | Python object or a tuple
124 |
125 | Returns
126 | --------
127 | is_type: Bool
128 | `True` if the objects is instance of class `type` or
129 | all the element of the tuple is of class `type`
130 | '''
131 |
132 | return isinstance(obj, type) or (isinstance(obj, tuple) and all(map(lambda o: isinstance(o, type), obj)))
133 |
134 | def is_numeric(obj):
135 | '''
136 | Params
137 | obj: Object
138 | Python object
139 |
140 | --------
141 | Returns
142 | is_numeric: Bool
143 | `True` if the object is numeric, else `False`
144 | --------
145 | '''
146 | return all(hasattr(obj, attr)
147 | for attr in ('__add__', '__sub__', '__mul__', '__truediv__', '__pow__'))
148 |
149 |
150 | def is_categorical(obj):
151 | '''
152 | Is the object categorical?
153 |
154 | Params
155 | --------
156 | obj: Python object
157 | object that has attribute `'dtype'`,
158 |
159 | Returns
160 | --------
161 | is_categorical: Bool
162 | `True` if it's categorical else `False`
163 | '''
164 |
165 | return obj.dtype.name == 'category'
166 |
167 |
168 | def minmax(component, perc=None, is_sorted=False):
169 | '''
170 | Get the minimum and maximum value of an array.
171 |
172 | Params
173 | --------
174 | component: Union[np.ndarray, List, Tuple]
175 | 1-D array
176 | perc: Union[List[Float], Tuple[Float]]
177 | clip the values by the percentiles
178 | is_sorted: Bool, optional (default: `False`)
179 | whether the component is already sorted,
180 | if `True`, min and max are the first and last
181 | elements respectively
182 |
183 | Returns
184 | --------
185 | min_max: Tuple[Float, Float]
186 | minimum and maximum values that are not NaN
187 | '''
188 | if perc is not None:
189 | assert len(perc) == 2, 'Percentile must be of length 2.'
190 | component = np.clip(component, *np.percentile(component, sorted(perc)))
191 |
192 | return (np.nanmin(component), np.nanmax(component)) if not is_sorted else (component[0], component[-1])
193 |
194 |
195 | def skip_or_filter(adata, needles, haystack, where='', dtype=None,
196 | skip=False, warn=True, ignore_after=None):
197 | '''
198 | Find all the needles in a given haystack.
199 |
200 | Params
201 | --------
202 | adata: anndata.AnnData
203 | anndata object
204 | needles: List[Str]
205 | keys to search for
206 | haystack: Iterable
207 | collection to search, e.g. `adata.obs.keys()`
208 | where: Str, optional, default (`''`)
209 | attribute of `anndata.AnnData` where to look, e. g. `'obsm'`
210 | dtype: Union[Callable, Type]
211 | expected datatype of the needles
212 | skip: Bool, optional (default: `False`)
213 | whether to skip the needles which do not have
214 | the expected `dtype`
215 | warn: Bool
216 | whether to issue a warning if `skip=True`
217 | ignore_after: Str, optional (default: `None`)
218 | token used for extracting the actual key name
219 | from the needle of form `KeyTokenIndex`, neeeded when
220 | `where='obsm'`, useful e.g. for extracting specific components
221 |
222 | Returns
223 | --------
224 | found_needles: List[Str]
225 | list of all the needles in haystack
226 | if `skip=False`, will throw a `RuntimeError`
227 | if the needle's type differs from `dtype`
228 | '''
229 |
230 | needles_f = list(map(lambda n: n[:n.find(ignore_after)], needles)) if ignore_after is not None else needles
231 | res = []
232 |
233 | for n, nf in zip(needles, needles_f):
234 | if nf not in haystack:
235 | msg = f'`{nf}` not found in `adata.{where}.keys()`.'
236 | if not skip:
237 | assert False, msg
238 | if warn:
239 | warnings.warn(msg + ' Skipping.')
240 | continue
241 |
242 | col = getattr(adata, where)[nf]
243 | val = col[0] if isinstance(col, np.ndarray) else col.iloc[0] # np.ndarray of pd.DataFrame
244 | if n != nf:
245 | assert where == 'obsm', f'Indexing is only supported for `adata.obsm`, found {nf} in adata.`{where}`.'
246 | _, ix = n.split(ignore_after)
247 | assert nf == _, 'Unable to parse input.'
248 | val = val[int(ix)]
249 |
250 | msg = None
251 | is_tup = isinstance(dtype, tuple)
252 | if isinstance(dtype, type) or is_tup:
253 | if not isinstance(val, dtype):
254 | types = dtype.__name__ if not is_tup else f"Union[{', '.join(map(lambda d: d.__name__, dtype))}]"
255 | msg = f'Expected `{nf}` to be of type `{types}`, found `{type(val).__name__}`.'
256 | elif callable(dtype):
257 | if not dtype(val):
258 | msg = f'`{nf}` did not pass the type checking of `{callable.__name__}`.'
259 | else:
260 | assert isinstance(dtype, str)
261 | if not dtype == col.dtype.name:
262 | msg = f'Expected `{nf}` to be of type `{dtype}`, found `{col.dtype.name}`.'
263 |
264 | if msg is not None:
265 | if not skip:
266 | raise RuntimeError(msg)
267 | if warn:
268 | warnings.warn(msg + ' Skipping.')
269 | continue
270 |
271 | res.append(n)
272 |
273 | return res
274 |
275 |
276 | def has_attributes(*args, **kwargs):
277 | '''
278 | Params
279 | --------
280 | *args: variable length arguments
281 | key to check in for
282 | **kwargs: keyword arguments
283 | attributes to check, keys will be interpreted
284 | as arguments in function signature to check
285 | and values are lists annotated as follows:
286 | `[, 'a:', '', '', 'a:', ...]`
287 | using type `None` will result in no type checking
288 |
289 | Returns
290 | --------
291 | wrapped: Callable
292 | function, which does the checks at runtime
293 | --------
294 | '''
295 |
296 | def inner(fn):
297 | '''
298 | Binds the arguments of the function and checks the types.
299 |
300 | Params
301 | --------
302 | fn: Callable
303 | function to wrap
304 |
305 | Returns
306 | --------
307 | wrapped: Callable
308 | the wrapped function
309 | '''
310 |
311 | @wraps(fn)
312 | def inner2(*fargs, **fkwargs):
313 | bound = sig.bind(*fargs, **fkwargs)
314 | bound.apply_defaults()
315 |
316 | for k, v in kwargs.items():
317 | if isinstance(v, type) or istype(v):
318 | assert isinstance(bound.arguments[k], v), f'Argument: `{k}` must be of type: `{v}`.'
319 | elif iterable(v):
320 | if not iterable(v[0]):
321 | v = [v]
322 |
323 | for vals in v:
324 | typp = None
325 | if vals[0] is None or istype(vals[0]):
326 | typp, *vals = vals
327 | if not vals[0].startswith('a:'):
328 | raise ValueError('The first element must be an attribute '
329 | f'annotated with: `a:`, found: `{vals[0]}`. '
330 | f'Consider using: `a:{vals[0]}`.')
331 |
332 | obj = None
333 | for val in vals:
334 | if val.startswith('a:'):
335 | obj = getattr(obj if obj is not None else bound.arguments[k], val[2:])
336 | else:
337 | assert obj is not None
338 | obj = obj[val]
339 |
340 | if typp is not None:
341 | assert isinstance(obj, typp)
342 | else:
343 | raise RuntimeError(f'Unable to decode invariant: `{k}={v}`.')
344 |
345 | return fn(*fargs, **fkwargs)
346 |
347 | sig = signature(fn)
348 | for param in tuple(kwargs.keys()) + args:
349 | if not param in sig.parameters.keys():
350 | raise ValueError(f'Parameter `{param}` not found in the signature.')
351 |
352 | return inner2
353 |
354 | return inner
355 |
356 |
357 | def wrap_as_panel(fn):
358 | '''
359 | Wrap the widget inside a panel.
360 |
361 | Params
362 | --------
363 | fn: Callable
364 | funtion that returns a plot, such as `scatter`
365 |
366 | Returns
367 | --------
368 | wrapper: Callable
369 | function which return object of type `pn.panel`
370 | '''
371 |
372 | @wraps(fn)
373 | def inner(*args, **kwargs):
374 | reverse = kwargs.pop('reverse', True)
375 | res = fn(*args, **kwargs)
376 | if res is None:
377 | return None
378 |
379 | res = pn.panel(res)
380 | if reverse and hasattr(res, 'reverse'):
381 | res.reverse()
382 |
383 | return res
384 |
385 | return inner
386 |
387 |
388 | def wrap_as_col(fn):
389 | '''
390 | Wrap the widget in a column, having it's
391 | input in one row.
392 |
393 | Params
394 | --------
395 | fn: Callable
396 | funtion that returns a plot, such as `dpt`
397 |
398 | Returns
399 | --------
400 | wrapped: Callable
401 | function which return object of type `pn.Column`
402 | '''
403 |
404 | def chunkify(l, n):
405 | for i in range(0, len(l), n):
406 | yield l[i: i + n]
407 |
408 | @wraps(fn)
409 | def inner(*args, **kwargs):
410 | reverse = kwargs.pop('reverse', True)
411 | res = fn(*args, **kwargs)
412 | if res is None:
413 | return None
414 |
415 | res = pn.panel(res)
416 | if reverse and hasattr(res, 'reverse'):
417 | res.reverse()
418 |
419 | widgets = list(map(lambda w: pn.Row(*w), filter(len, chunkify(res[0], 3))))
420 | return pn.Column(*(widgets + [res[1]]))
421 |
422 | return inner
423 |
424 |
425 | def get_data(adata, needle, ignore_after=OBSM_SEP, haystacks=['obs', 'obsm', 'var_names']):
426 | f'''
427 | Search for a needle in multiple haystacks.
428 |
429 | Params
430 | --------
431 | adata: anndata.AnnData
432 | anndata object
433 | needle: Str
434 | needle to search for
435 | ignore_after: Str, optional (default: `'{OBSM_SEP}'`)
436 | token used for extracting the actual key name
437 | from the needle of form `KeyTokenIndex`, neeeded
438 | when `'obsm' in haystacks`, useful e.g. for extracting specific components
439 | haystack: List[Str], optional (default: `['obs', 'obsm', 'var_names']`)
440 | attributes of `anndata.AnnData`
441 |
442 | Returns
443 | --------
444 | (result, is_categorical): Tuple[Object, Bool]
445 | the found object and whether it's categorical
446 | '''
447 |
448 | for haystack in haystacks:
449 | obj = getattr(adata, haystack)
450 | if ignore_after in needle and haystack == 'obsm':
451 | k, ix = needle.split(ignore_after)
452 | ix = int(ix)
453 | else:
454 | k, ix = needle, None
455 |
456 | if k in obj:
457 | res = obj[k] if haystack != 'var_names' else adata.obs_vector(k)
458 | if ix is not None:
459 | assert res.ndim == 2, f'`adata.{haystack}[{k}]` must have a dimension of 2, found `{res.dim}`.'
460 | assert res.shape[-1] > ix, f'Index `{ix}` out of bounds for `adata.{haystack}[{k}]` of shape `{res.shape}`.'
461 | res = res[:, ix]
462 | if res.shape != (adata.n_obs, ):
463 | msg = f'`{needle}` in `adata.{haystack}` has wrong shape of `{res.shape}`.'
464 | if haystack == 'obsm':
465 | msg += f' Try using `{needle}{OBSM_SEP}ix`, 0 <= `ix` < {res.shape[-1]}.'
466 | raise ValueError(msg)
467 |
468 | return res, is_categorical(res)
469 |
470 | raise ValueError(f'Unable to find `{needle}` in `adata.{haystacks}`.')
471 |
472 |
473 | def get_all_obsm_keys(adata, ixs):
474 | if not isinstance(ixs, (tuple, list)):
475 | ixs = [ixs]
476 |
477 | assert all(map(lambda ix: ix >= 0, ixs)), f'All indices must be non-negative.'
478 |
479 | return list(itertools.chain.from_iterable((f'{key}{OBSM_SEP}{ix}'
480 | for key in adata.obsm.keys() if isinstance(adata.obsm[key], np.ndarray) and adata.obsm[key].ndim == 2 and adata.obsm[key].shape[-1] > ix)
481 | for ix in ixs))
482 | # based on:
483 | # https://github.com/velocyto-team/velocyto-notebooks/blob/master/python/DentateGyrus.ipynb
484 | def sample_unif(adata, steps, bs='umap', components=(0, 1)):
485 | if not isinstance(steps, (tuple, list)):
486 | steps = [steps] * len(components)
487 |
488 | assert len(components)
489 | assert min(components) >= 0
490 |
491 | embedding = adata.obsm[f'X_{bs}'][:, components]
492 | n_dim = len(components)
493 |
494 | grs = []
495 | for i in range(n_dim):
496 | m, M = np.min(embedding[:, i]), np.max(embedding[:, i])
497 | m = m - 0.025 * np.abs(M - m)
498 | M = M + 0.025 * np.abs(M - m)
499 | gr = np.linspace(m, M, num=steps[i])
500 | grs.append(gr)
501 |
502 | meshes_tuple = np.meshgrid(*grs)
503 | gridpoints_coordinates = np.vstack([i.flat for i in meshes_tuple]).T
504 |
505 | nn = NearestNeighbors()
506 | nn.fit(embedding)
507 | dist, ixs = nn.kneighbors(gridpoints_coordinates, 1)
508 |
509 | diag_step_dist = np.linalg.norm([grs[dim_i][1] - grs[dim_i][0] for dim_i in range(n_dim)])
510 | min_dist = diag_step_dist / 2
511 |
512 | ixs = ixs[dist < min_dist]
513 | ixs = np.unique(ixs)
514 |
515 | return adata[ixs, :].copy(), ixs
516 |
517 |
518 | def sample_density(adata, size, bs='umap', seed=None, components=[0, 1]):
519 | if size >= adata.n_obs:
520 | return adata
521 |
522 | if components[0] == components[1]:
523 | tmp = pd.DataFrame(np.ones(adata.n_obs) / adata.n_obs, columns=['prob_density'])
524 | else:
525 | # should be unique, using it only once since we cache the results
526 | # we don't need to add the components
527 | key_added = f'{bs}_density_ipl_tmp'
528 | remove_key = False # we may be operating on original object, keep it clean
529 | if key_added not in adata.obs.keys():
530 | sc.tl.embedding_density(adata, bs, key_added=key_added)
531 | remove_key = True
532 | tmp = pd.DataFrame(np.exp(adata.obs[key_added]) / np.sum(np.exp(adata.obs[key_added])))
533 | tmp.rename(columns={key_added: 'prob_density'}, inplace=True)
534 | if remove_key:
535 | del adata.obs[key_added]
536 |
537 | state = np.random.RandomState(seed)
538 | ixs = sorted(state.choice(range(adata.n_obs), size=size, p=tmp['prob_density'], replace=False))
539 |
540 | return adata[ixs].copy(), ixs
541 |
542 |
543 | def get_xy_data(x, adata, adata_mraw, layer, indices, use_original_limits=False, inc=0):
544 |
545 | def extract(data, ix):
546 | if isinstance(data, anndata._core.anndata.Raw) or \
547 | isinstance(data, anndata.AnnData):
548 | return data.obs_vector(ix)[indices]
549 | if issparse(data):
550 | return np.squeeze(data.getcol(ix).A)[indices]
551 |
552 | # assume np.array
553 | return data[indices, ix]
554 |
555 | xlim = None
556 | msg = f'Unable to decode key `{x}`.'
557 |
558 | if layer is None:
559 | take_from = adata_mraw
560 | else:
561 | # mraw is always adata in this case
562 | assert layer in adata.layers, f'Layer `\'{layer}\'` not found in `adata.layers`.'
563 | assert adata is adata_mraw, 'Sanity check failed.'
564 | take_from = adata.layers[layer]
565 |
566 | if not isinstance(x, int):
567 | assert isinstance(x, str)
568 | # can't use take from, since it can be an array
569 | if x in adata_mraw.var_names:
570 | ix = np.where(x == adata_mraw.var_names)[0][0]
571 | xlabel = adata_mraw.var_names[ix]
572 | x = extract(take_from, ix)
573 |
574 | return x, xlabel, xlim
575 |
576 | if x in adata_mraw.obs_keys():
577 | x, xlabel = adata_mraw.obs[x].values, x
578 | if use_original_limits:
579 | xlim = pad(*minmax(x))
580 |
581 | return x[indices], xlabel, xlim
582 |
583 | x = x.lstrip('X_')
584 | comp, *ix = x.split(':')
585 | ix = int(ix[0]) if len(ix) else (int(x == 'diffmap') + inc)
586 |
587 | if f'X_{comp}' in adata.obsm:
588 | xlabel = f'{comp}_{ix}'
589 | x = adata.obsm[f'X_{comp}'][:, ix]
590 | if use_original_limits:
591 | xlim = pad(*minmax(x))
592 |
593 | return x[indices], xlabel, xlim
594 |
595 | raise RuntimeError(msg)
596 |
597 | xlabel = adata_mraw.var_names[x]
598 | x = extract(take_from, x)
599 |
600 | return x, xlabel, xlim
601 |
602 |
603 | def get_mraw(adata, use_raw):
604 | if not use_raw:
605 | return adata
606 |
607 | return adata.raw if hasattr(adata, 'raw') and adata.raw is not None else adata
608 |
609 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scanpy>=1.4.3
2 | pandas>=0.23.4
3 | matplotlib>=3.0.2
4 | setuptools>=41.0.1
5 | scipy>=1.2.0
6 | numpy>=1.16.4
7 | scikit_learn>=0.21.3
8 | anndata>=0.7.3
9 | bokeh==1.4.0
10 | datashader==0.9.0
11 | panel==0.6.2
12 | holoviews==1.12.7
13 |
--------------------------------------------------------------------------------
/resources/images/dpt_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/dpt_plot.png
--------------------------------------------------------------------------------
/resources/images/gene_trend.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/gene_trend.png
--------------------------------------------------------------------------------
/resources/images/graph_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/graph_plot.png
--------------------------------------------------------------------------------
/resources/images/heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/heatmap.png
--------------------------------------------------------------------------------
/resources/images/highlight_de.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/highlight_de.png
--------------------------------------------------------------------------------
/resources/images/inter_hist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/inter_hist.png
--------------------------------------------------------------------------------
/resources/images/link_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/link_plot.png
--------------------------------------------------------------------------------
/resources/images/scatter3d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/scatter3d.png
--------------------------------------------------------------------------------
/resources/images/scatter_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/scatter_cat.png
--------------------------------------------------------------------------------
/resources/images/scatter_cont.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/scatter_cont.png
--------------------------------------------------------------------------------
/resources/images/scatter_general1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/scatter_general1.png
--------------------------------------------------------------------------------
/resources/images/scatter_general2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/scatter_general2.png
--------------------------------------------------------------------------------
/resources/images/thresh_hist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/interactive_plotting/58b127e3e386496ef056cc3ce828499ee8f8ccc0/resources/images/thresh_hist.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | import os
4 |
5 | setup(
6 | name='Interactive Plotting',
7 | version='0.0.5',
8 | description='Interactive plotting functions for scanpy',
9 | url='https://github.com/theislab/interactive_plotting',
10 | license='MIT',
11 | packages=find_packages(),
12 | setup_requires=['setuptools_scm'],
13 | include_package_data=True,
14 | install_requires=list(map(str.strip,
15 | open(os.path.abspath('requirements.txt'), 'r').read().split())),
16 |
17 | zip_safe=False
18 | )
19 |
--------------------------------------------------------------------------------