├── .idea
├── .gitignore
├── inspectionProfiles
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
├── spateo-viewer.iml
└── vcs.xml
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── datasets_website
└── index.html
├── requirements.txt
├── stv_explorer.py
├── stv_reconstructor.py
├── stviewer
├── Explorer
│ ├── __init__.py
│ ├── pv_pipeline
│ │ ├── __init__.py
│ │ ├── init_parameters.py
│ │ ├── pv_actors.py
│ │ ├── pv_callback.py
│ │ ├── pv_custom.py
│ │ ├── pv_interpolation.py
│ │ ├── pv_morphogenesis.py
│ │ └── pv_plotter.py
│ └── ui
│ │ ├── __init__.py
│ │ ├── container.py
│ │ ├── drawer
│ │ ├── __init__.py
│ │ ├── adata_obj.py
│ │ ├── custom_card.py
│ │ ├── main.py
│ │ ├── model_mesh.py
│ │ ├── model_point.py
│ │ ├── morphogenesis.py
│ │ ├── output.py
│ │ └── pipeline.py
│ │ ├── layout.py
│ │ ├── toolbar.py
│ │ └── utils.py
├── Reconstructor
│ ├── __init__.py
│ ├── pv_pipeline
│ │ ├── __init__.py
│ │ ├── alignment_utils.py
│ │ ├── init_parameters.py
│ │ ├── pv_alignment.py
│ │ ├── pv_callback.py
│ │ ├── pv_custom.py
│ │ ├── pv_models.py
│ │ ├── pv_plotter.py
│ │ └── pv_tdr.py
│ └── ui
│ │ ├── __init__.py
│ │ ├── container.py
│ │ ├── drawer
│ │ ├── __init__.py
│ │ ├── alignment.py
│ │ ├── custom_card.py
│ │ ├── main.py
│ │ ├── model_point.py
│ │ └── reconstruction.py
│ │ ├── layout.py
│ │ ├── toolbar.py
│ │ └── utils.py
├── __init__.py
├── assets
│ ├── __init__.py
│ ├── anndata_preprocess.py
│ ├── dataset
│ │ ├── drosophila_S11
│ │ │ ├── h5ad
│ │ │ │ └── S11_cellbin_demo.h5ad
│ │ │ ├── mesh_models
│ │ │ │ ├── 0_Embryo_S11_aligned_mesh_model.vtk
│ │ │ │ ├── 1_CNS_S11_aligned_mesh_model.vtk
│ │ │ │ ├── 2_Midgut_S11_aligned_mesh_model.vtk
│ │ │ │ ├── 3_Hindgut_S11_aligned_mesh_model.vtk
│ │ │ │ ├── 4_Muscle_S11_aligned_mesh_model.vtk
│ │ │ │ ├── 5_SalivaryGland_S11_aligned_mesh_model.vtk
│ │ │ │ └── 6_Amnioserosa_S11_aligned_mesh_model.vtk
│ │ │ └── pc_models
│ │ │ │ ├── 0_Embryo_S11_aligned_pc_model.vtk
│ │ │ │ ├── 1_CNS_S11_aligned_pc_model.vtk
│ │ │ │ ├── 2_Midgut_S11_aligned_pc_model.vtk
│ │ │ │ ├── 3_Hindgut_S11_aligned_pc_model.vtk
│ │ │ │ ├── 4_Muscle_S11_aligned_pc_model.vtk
│ │ │ │ ├── 5_SalivaryGland_S11_aligned_pc_model.vtk
│ │ │ │ └── 6_Amnioserosa_S11_aligned_pc_model.vtk
│ │ └── mouse_E95
│ │ │ ├── h5ad
│ │ │ └── mouse_E95_demo.h5ad
│ │ │ ├── matrices
│ │ │ └── X_sparse_matrix.npz
│ │ │ ├── mesh_models
│ │ │ └── 0_Embryo_mouse_E95_mesh_model.vtk
│ │ │ └── pc_models
│ │ │ └── 0_Embryo_mouse_E95_pc_model.vtk
│ ├── dataset_acquisition.py
│ ├── dataset_manager.py
│ ├── image
│ │ ├── interactive_viewer.png
│ │ ├── spateo_logo.png
│ │ ├── spateoviewer.png
│ │ ├── static_viewer.png
│ │ ├── upload_file.png
│ │ └── upload_folder.png
│ └── image_manager.py
├── explorer_app.py
├── reconstructor_app.py
└── server.py
└── usage
├── ExplorerUsage.md
├── ReconstructorUsage.md
└── spateo-viewer.pdf
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/spateo-viewer.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/python/black
3 | rev: 23.3.0
4 | hooks:
5 | - id: black
6 |
7 | - repo: https://github.com/pycqa/isort
8 | rev: 5.12.0
9 | hooks:
10 | - id: isort
11 | args: ["--profile", "black", "--filter-files"]
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2023, Aristotle
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ## Spateo-viewer: the "Google Earth" browser of spatial transcriptomics
6 |
7 | [**Spateo-viewer**](https://github.com/aristoteleo/spateo-viewer) is the “Google Earth” of spatial transcriptomics.
8 | Relying on a set of powerful libraries and tools in the Python ecosystem, such as Trame, PyVista, and VTK, it delivers
9 | a complete web application for creating convenient, vivid, and lightweight interfaces for the 3D reconstruction and
10 | visualization of [**Spateo**](https://github.com/aristoteleo/spateo-release) downstream analysis results. Currently,
11 | Spateo-viewer includes two major applications, ***Explorer*** and ***Reconstructor***, which are respectively
12 | dedicated to the visualization of spatial transcriptomics analysis results and the 3D reconstruction of spatial transcriptomics data.
13 |
14 | Please download and read the corresponding [**Slides**](https://github.com/aristoteleo/spateo-viewer/blob/main/usage/spateo-viewer.pdf) to learn more about Spateo-viewer.
15 |
16 | ## Highlights
17 |
18 | * In the ***Reconstructor***, serial slices of spatial transcriptomics datasets can be aligned to create 3D models. The 3D models can also be cleaned up by freely clipping and editing them.
19 | * In the ***Explorer***, users can not only visualize gene expression, but also easily switch between raw and different types of normalized data or data layers. Users can also visualize all cell annotation information such as cell size, cell type, tissue type, etc. All done in 3D space!
20 | * Static-viewer allows users to visualize the point cloud and mesh models not only of the whole embryo but also of individual organs or tissue types at the same time. It can even visualize a morphogenesis vector field model to animate how cells move in physical 3D space.
21 | * Spateo-viewer can not only run on the local computer, but also run freely on the remote server.
22 | * Users can upload custom files in the web application, or access custom files in local folders when running Spateo-viewer as a standalone app (see [**ExplorerUsage**](https://github.com/aristoteleo/spateo-viewer/blob/main/usage/ExplorerUsage.md) or [**ReconstructorUsage**](https://github.com/aristoteleo/spateo-viewer/blob/main/usage/ReconstructorUsage.md)).
23 |
24 | ## Installation
25 |
26 | You can clone the [**Spateo-viewer**](https://github.com/aristoteleo/spateo-viewer) with ``git`` and install dependencies with ``pip``:
27 |
28 | git clone https://github.com/aristoteleo/spateo-viewer.git
29 | cd spateo-viewer
30 | pip install -r requirements.txt
31 |
32 | ## Usage
33 |
34 | #### Run the *Explorer* application:
35 |
36 | python stv_explorer.py --port 1234
37 |
38 | See the [**ExplorerUsage**](https://github.com/aristoteleo/spateo-viewer/blob/main/usage/ExplorerUsage.md) for more details.
39 |
40 | #### Run the *Reconstructor* application:
41 |
42 | python stv_reconstructor.py --port 1234
43 |
44 | See the [**ReconstructorUsage**](https://github.com/aristoteleo/spateo-viewer/blob/main/usage/ReconstructorUsage.md) for more details.
45 |
46 | ## Sample Datasets
47 |
48 | #### [**Mouse E9.5 dataset**](https://github.com/aristoteleo/spateo-viewer/tree/main/stviewer/assets/dataset/mouse_E95):
49 | - **h5ad/mouse_E95_demo.h5ad**: Single-cell resolution Stereo-seq data with alignment and cell annotation by the Spateo team. This dataset contains only 1,000 highly variable genes. If you need the raw data, please visit the [CNGB website](https://db.cngb.org/stomics/mosta/download/).
50 | - **matrices**: Contains various gene expression matrices for the .h5ad data.
51 | - **mesh_models**: Contains the mesh model of the mouse embryo.
52 | - **pc_models**: Contains the point cloud model of the mouse embryo.
53 |
54 | #### [**Drosophila S11 dataset**](https://github.com/aristoteleo/spateo-viewer/tree/main/stviewer/assets/dataset/drosophila_S11):
55 | - **h5ad/S11_cellbin_demo.h5ad**: Single-cell resolution Stereo-seq data with alignment and cell annotation by the Spateo team. This dataset contains only 1,000 highly variable genes.
56 | - **mesh_models**: Contains mesh models of the Drosophila embryo and various organs.
57 | - **pc_models**: Contains point cloud models of the Drosophila embryo and various organs.
58 |
59 | ## Citation
60 |
61 | [ Spatiotemporal modeling of molecular holograms ](https://www.cell.com/cell/fulltext/S0092-8674(24)01159-0)
62 |
63 | Xiaojie Qiu1, 7, 8\$\*, Daniel Y. Zhu3\$, Yifan Lu1, 7, 8, 9\$, Jiajun Yao2, 4, 10\$, Zehua Jing2, 4, 11\$, Kyung Hoi (Joseph) Min12\$, Mengnan Cheng2,6\$, Hailin Pan6, Lulu Zuo6, Samuel King13, Qi Fang2, 6, Huiwen Zheng2, 11, Mingyue Wang2, 14, Shuai Wang2, 11, Qingquan Zhang25, Sichao Yu5, Sha Liao6, 17, 18, Chao Liu15, Xinchao Wu2, 4, 16, Yiwei Lai6, Shijie Hao2, Zhewei Zhang2, 4, 16, Liang Wu18, Yong Zhang15, Mei Li17, Zhencheng Tu2, 11, Jinpei Lin2, 4, Zhuoxuan Yang2, 16, Yuxiang Li15, Ying Gu2, 6, 11, Ao Chen6, 17, 18, Longqi Liu2, 19, 20, Jonathan S. Weissman5, 22, 23, Jiayi Ma9*, Xun Xu2, 11, 21*, Shiping Liu2, 19, 20, 24*, Yinqi Bai4, 26*
64 |
65 | $Co-first authors; *:Corresponding authors
66 |
--------------------------------------------------------------------------------
/datasets_website/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Spateo Datasets
6 |
7 |
44 |
45 |
46 |
56 |
174 |
177 |
178 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | anndata>=0.8.0
2 | av>=11.0.0
3 | imageio>=2.33.1
4 | imageio-ffmpeg>=0.4.9
5 | gpytorch>=1.11
6 | numpy>=1.18.1
7 | matplotlib>=3.5.3
8 | POT>=0.8.1
9 | pyacvd>=0.2.9
10 | pyautogui>=0.9.54
11 | PyMCubes>=0.1.4
12 | pymeshfix>=0.16.2
13 | pyvista==0.40.0
14 | trame==2.5.2
15 | trame-server>=2.15.0
--------------------------------------------------------------------------------
/stv_explorer.py:
--------------------------------------------------------------------------------
1 | import getopt
2 | import sys
3 |
4 | from stviewer.explorer_app import state, static_server
5 |
6 | if __name__ == "__main__":
7 | # upload anndata
8 | state.selected_dir = None
9 |
10 | opts, args = getopt.getopt(sys.argv[1:], "p", ["port="])
11 | port = "1234" if len(opts) == 0 else opts[0][1]
12 | static_server.start(port=port)
13 |
--------------------------------------------------------------------------------
/stv_reconstructor.py:
--------------------------------------------------------------------------------
1 | import getopt
2 | import sys
3 |
4 | from stviewer.reconstructor_app import interactive_server, state
5 |
6 | if __name__ == "__main__":
7 | # upload anndata
8 | state.upload_anndata = None
9 | opts, args = getopt.getopt(sys.argv[1:], "p", ["port="])
10 | port = "8888" if len(opts) == 0 else opts[0][1]
11 | interactive_server.start(port=port)
12 |
--------------------------------------------------------------------------------
/stviewer/Explorer/__init__.py:
--------------------------------------------------------------------------------
1 | from .pv_pipeline import *
2 | from .ui import *
3 |
--------------------------------------------------------------------------------
/stviewer/Explorer/pv_pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | from .init_parameters import *
2 | from .pv_actors import generate_actors, generate_actors_tree, init_actors
3 | from .pv_callback import PVCB, SwitchModels, Viewer, vuwrap
4 | from .pv_plotter import add_single_model, create_plotter
5 |
--------------------------------------------------------------------------------
/stviewer/Explorer/pv_pipeline/init_parameters.py:
--------------------------------------------------------------------------------
1 | import matplotlib.colors as mcolors
2 |
3 | # Init parameters
# Initial visibility of the UI cards.
init_card_parameters = dict(
    show_anndata_card=False,
    show_model_card=True,
    show_output_card=True,
)
# Initial AnnData upload state.
init_adata_parameters = dict(
    uploaded_anndata_path=None,
)
# Initial point-cloud settings; the color list holds every CSS4 color name.
init_pc_parameters = dict(
    pc_obs_value=None,
    pc_gene_value=None,
    pc_scalars_raw={"None": "None"},
    pc_matrix_value="X",
    pc_coords_value="spatial",
    pc_opacity_value=1.0,
    pc_ambient_value=0.2,
    pc_color_value=None,
    pc_colormap_value="Spectral",
    pc_point_size_value=4,
    pc_add_legend=False,
    pc_picking_group=None,
    pc_overwrite=False,
    pc_reload=False,
    pc_colors_list=list(mcolors.CSS4_COLORS),
)
# Initial mesh settings; the color list holds every CSS4 color name.
init_mesh_parameters = dict(
    mesh_opacity_value=0.2,
    mesh_ambient_value=0.2,
    mesh_color_value="gainsboro",
    mesh_style_value="surface",
    mesh_morphology=False,
    mesh_colors_list=list(mcolors.CSS4_COLORS),
)
# Initial values of the morphogenesis-related state keys.
init_morphogenesis_parameters = dict(
    cal_morphogenesis=False,
    morpho_target_anndata_path=None,
    morpho_uploaded_target_anndata_path=None,
    morpho_mapping_method="GP",
    morpho_mapping_device="cpu",
    morpho_mapping_factor=0.2,
    morphofield_factor=3000,
    morphopath_t_end=10000,
    morphopath_downsampling=500,
    morphofield_visibile=False,
    morphopath_visibile=False,
    morphopath_predicted_models=None,
    morphopath_animation_path=None,
)
# Initial values of the interpolation-related state keys.
init_interpolation_parameters = dict(
    cal_interpolation=False,
    interpolation_device="cpu",
)
# Initial values of the output (screenshot / animation) state keys.
init_output_parameters = dict(
    screenshot_path=None,
    animation_path=None,
    animation_npoints=50,
    animation_framerate=10,
)

# Custom init parameters
init_custom_parameters = dict(
    custom_func=False,
    custom_analysis=False,
    custom_model=None,
    custom_model_visible=False,
    custom_parameter1="X",
    custom_parameter2="recipe_monocle",
    custom_parameter3="pca",
    custom_parameter4="umap",
    custom_parameter5=False,
    custom_parameter6="None",
    custom_parameter7=30,
    custom_parameter8=30,
    custom_parameter9="None",
    custom_parameter10=1,
)
80 |
--------------------------------------------------------------------------------
/stviewer/Explorer/pv_pipeline/pv_actors.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 |
4 | warnings.filterwarnings("ignore")
5 |
6 | from .pv_plotter import add_single_model
7 |
8 | try:
9 | from typing import Literal
10 | except ImportError:
11 | from typing_extensions import Literal
12 |
13 | from typing import Optional
14 |
15 | from stviewer.assets import sample_dataset
16 |
17 |
def generate_actors(
    plotter,
    pc_models: Optional[list] = None,
    mesh_models: Optional[list] = None,
    pc_model_names: Optional[list] = None,
    mesh_model_names: Optional[list] = None,
):
    """Add point-cloud and mesh models to ``plotter`` and collect their actors.

    Point clouds are drawn as points of size 5. Meshes are drawn as gainsboro
    surfaces with opacity 0.2. Models are paired with names positionally.

    Returns:
        ``(pc_actors, mesh_actors)``. Each is a list of the actors returned by
        ``add_single_model``, or ``None`` when the matching model list is ``None``.
    """

    def _add_group(models, names, **style):
        # One actor per (model, name) pair, in order.
        if models is None:
            return None
        return [
            add_single_model(plotter=plotter, model=mdl, model_name=label, **style)
            for mdl, label in zip(models, names)
        ]

    pc_actors = _add_group(
        pc_models, pc_model_names, model_style="points", model_size=5
    )
    mesh_actors = _add_group(
        mesh_models,
        mesh_model_names,
        opacity=0.2,
        model_style="surface",
        color="gainsboro",
    )
    return pc_actors, mesh_actors
49 |
50 |
def standard_tree(actors: list, base_id: int = 0):
    """Describe ``actors`` as a two-level tree and set their initial visibility.

    The first actor is visible and is parented to ``"0"``. Every other actor is
    hidden and parented to the first one. Ids are consecutive strings starting at
    ``base_id + 1``.

    Args:
        actors: objects exposing ``SetVisibility`` and a ``name`` attribute.
        base_id: offset applied to every id so several groups can share one tree.

    Returns:
        ``(actors, actor_names, actor_tree)``. ``actors`` is the same object that
        was passed in, ``actor_names`` holds the names as strings, and
        ``actor_tree`` is a list of ``{"id", "parent", "visible", "name"}`` dicts.
    """
    head_id = str(base_id + 1)
    names, tree = [], []
    for offset, actor in enumerate(actors):
        is_head = offset == 0
        actor.SetVisibility(is_head)
        label = str(actor.name)
        names.append(label)
        tree.append(
            {
                "id": str(base_id + 1 + offset),
                "parent": "0" if is_head else head_id,
                "visible": is_head,
                "name": label,
            }
        )
    return actors, names, tree
69 |
70 |
def generate_actors_tree(
    pc_actors: Optional[list] = None,
    mesh_actors: Optional[list] = None,
):
    """Merge point-cloud and mesh actors into one flat list plus a shared tree.

    Point-cloud entries take ids ``1..P``. Mesh entries continue from ``P + 1``,
    so ids never collide. A ``None`` group contributes nothing.

    Returns:
        ``(actors, actor_names, actor_tree)`` with point-cloud entries first.
    """
    if pc_actors is None:
        pc_actors, pc_actor_names, pc_tree = [], [], []
    else:
        pc_actors, pc_actor_names, pc_tree = standard_tree(actors=pc_actors, base_id=0)

    if mesh_actors is None:
        mesh_actors, mesh_actor_names, mesh_tree = [], [], []
    else:
        # pc_actors is always a list at this point, so its length is the id offset.
        mesh_actors, mesh_actor_names, mesh_tree = standard_tree(
            actors=mesh_actors,
            base_id=len(pc_actors),
        )

    return (
        pc_actors + mesh_actors,
        pc_actor_names + mesh_actor_names,
        pc_tree + mesh_tree,
    )
92 |
93 |
def init_actors(plotter, path):
    """Load the sample dataset at ``path`` and build actors for ``plotter``.

    Args:
        plotter: plotter the model actors are added to.
        path: dataset location forwarded to ``sample_dataset``.

    Returns:
        ``(anndata_info, actors, actor_names, actor_tree, custom_colors)``.
        The middle triple comes from ``generate_actors_tree``.
    """
    anndata_info, pc_models, pc_ids, mesh_models, mesh_ids, custom_colors = (
        sample_dataset(path=path)
    )

    pc_actors, mesh_actors = generate_actors(
        plotter=plotter,
        pc_models=pc_models,
        pc_model_names=pc_ids,
        mesh_models=mesh_models,
        mesh_model_names=mesh_ids,
    )
    actors, actor_names, actor_tree = generate_actors_tree(
        pc_actors=pc_actors, mesh_actors=mesh_actors
    )
    return anndata_info, actors, actor_names, actor_tree, custom_colors
126 |
--------------------------------------------------------------------------------
/stviewer/Explorer/pv_pipeline/pv_custom.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional, Union
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pyvista as pv
6 | from anndata import AnnData
7 | from pyvista import PolyData
8 |
9 |
def _copy_obs_to_point_data(obs, model, keys):
    """Copy the ``obs`` columns named in ``keys`` onto ``model.point_data``.

    Each point is matched to its observation row through the model's
    ``"obs_index"`` point array. Points sharing an ``obs_index`` receive the
    same value.

    Args:
        obs: DataFrame indexed by observation name (an ``AnnData.obs``).
        model: PyVista dataset whose ``point_data`` contains ``"obs_index"``.
        keys: names of the ``obs`` columns to copy.
    """
    obs_index = np.asarray(model.point_data["obs_index"])
    for key in keys:
        model.point_data[key] = np.asarray(obs.loc[obs_index, key])


def RNAvelocity(
    adata: AnnData,
    pc_model: PolyData,
    layer: str = "X",
    basis_pca: str = "pca",
    basis_umap: str = "umap",
    data_preprocess: Literal[
        "False", "recipe_monocle", "pearson_residuals"
    ] = "recipe_monocle",
    harmony_debatch: bool = False,
    group_key: Optional[str] = None,
    n_neighbors: int = 30,
    n_pca_components: int = 30,
    n_vectors_downsampling: Optional[int] = None,
    vectors_size: Union[float, int] = 1,
):
    """Run a dynamo RNA-velocity analysis on the cells of ``pc_model``.

    The cells listed in ``pc_model.point_data["obs_index"]`` are copied out of
    ``adata``; ``adata`` itself is not modified. The copy is optionally
    preprocessed and then passed through dynamo's dimension reduction, velocity,
    vector-field, ddhodge and differential-geometry tools. The resulting
    per-cell scalars are written to ``pc_model`` and to a new glyph model of the
    ``basis_umap`` velocities.

    Args:
        adata: AnnData containing every cell referenced by ``pc_model``.
        pc_model: point cloud whose ``point_data["obs_index"]`` holds obs names.
        layer: ``"X"`` or the name of an ``adata.layers`` entry to analyse.
        basis_pca: basis used for the PCA-space computations.
        basis_umap: basis whose first two dimensions position the glyphs.
        data_preprocess: ``"recipe_monocle"`` or ``"pearson_residuals"`` runs the
            matching dynamo ``Preprocessor`` recipe; any other value skips it.
        harmony_debatch: correct ``obsm[basis_pca]`` with harmonypy using
            ``group_key``. Only applied when ``basis_pca`` was not already present.
        group_key: obs column holding the batch labels for harmony.
        n_neighbors: forwarded to ``dyn.tl.reduceDimension``.
        n_pca_components: forwarded to ``dyn.tl.reduceDimension``.
        n_vectors_downsampling: number of cells to draw velocity glyphs for.
            ``None``, ``"None"`` or ``"none"`` keeps all cells. Numeric strings
            are accepted, and values above the cell count are clipped.
        vectors_size: glyph scale factor.

    Returns:
        ``(pc_model, vectors)``: the updated ``pc_model`` and the glyph model.

    Raises:
        ImportError: if dynamo, or harmonypy when ``harmony_debatch`` is set, is
            not installed.
    """
    try:
        import dynamo as dyn
        from dynamo.preprocessing import Preprocessor
    except ImportError as err:
        raise ImportError(
            "You need to install the package `dynamo`. "
            "\nInstall dynamo via `pip install dynamo-release`."
        ) from err
    if harmony_debatch:
        # Fail fast, before any expensive computation, when harmony is requested.
        try:
            import harmonypy
        except ImportError as err:
            raise ImportError(
                "You need to install the package `harmonypy`. "
                "\nInstall harmonypy via `pip install harmonypy`."
            ) from err

    # Work on a copy restricted to the cells of the point cloud, so every
    # in-place dynamo step below leaves the caller's AnnData untouched.
    _obs_index = pc_model.point_data["obs_index"]
    dyn_adata = adata[_obs_index, :].copy()
    dyn_adata.X = dyn_adata.X if layer == "X" else dyn_adata.layers[layer]

    # Data preprocess (applied to the working copy that the pipeline uses).
    has_pca = basis_pca in dyn_adata.obsm.keys()
    recipe = {
        "recipe_monocle": "monocle",
        "pearson_residuals": "pearson_residuals",
    }.get(data_preprocess)
    if recipe is not None:
        Preprocessor().preprocess_adata(dyn_adata, recipe=recipe)
    dyn.tl.reduceDimension(
        dyn_adata,
        basis=basis_pca,
        n_neighbors=n_neighbors,
        n_pca_components=n_pca_components,
    )
    if harmony_debatch and not has_pca:
        harmony_out = harmonypy.run_harmony(
            dyn_adata.obsm[basis_pca], dyn_adata.obs, group_key, max_iter_harmony=20
        )
        dyn_adata.obsm[basis_pca] = harmony_out.Z_corr.T
        dyn.tl.reduceDimension(
            dyn_adata,
            X_data=dyn_adata.obsm[basis_pca],
            enforce=True,
            n_neighbors=n_neighbors,
            n_pca_components=n_pca_components,
        )

    # RNA velocity
    if basis_umap in dyn_adata.obsm.keys():
        if f"X_{basis_umap}" not in dyn_adata.obsm.keys():
            dyn_adata.obsm[f"X_{basis_umap}"] = dyn_adata.obsm[basis_umap]

    dyn.tl.dynamics(dyn_adata, model="stochastic", cores=3)
    dyn.tl.cell_velocities(
        dyn_adata,
        basis=basis_umap,
        method="pearson",
        other_kernels_dict={"transform": "sqrt"},
    )
    dyn.tl.cell_velocities(dyn_adata, basis=basis_pca)

    # Vector field
    dyn.vf.VectorField(dyn_adata, basis=basis_umap)
    dyn.vf.VectorField(dyn_adata, basis=basis_pca)

    # Pseudotime
    dyn.ext.ddhodge(dyn_adata, basis=basis_umap)
    dyn.ext.ddhodge(dyn_adata, basis=basis_pca)

    # Differential geometry
    dyn.vf.speed(dyn_adata, basis=basis_pca)
    dyn.vf.acceleration(dyn_adata, basis=basis_pca)
    dyn.vf.curvature(dyn_adata, basis=basis_pca)
    dyn.vf.curl(dyn_adata, basis=basis_pca)
    dyn.vf.divergence(dyn_adata, basis=basis_pca)

    # Choose which cells get a velocity glyph. The value may be a string
    # (e.g. "None" from the UI parameters), and it must not exceed the cell count.
    n_obs = dyn_adata.shape[0]
    if n_vectors_downsampling in (None, "None", "none"):
        ix_choice = np.arange(n_obs)
    else:
        n_sampled = min(int(n_vectors_downsampling), n_obs)
        ix_choice = np.random.choice(n_obs, size=n_sampled, replace=False)

    # Glyph positions and directions: the first two basis_umap dimensions, with
    # z fixed at 0. Subsetting happens exactly once, and obs_index is taken from
    # the same rows so it has one entry per point.
    X = np.asarray(dyn_adata.obsm[f"X_{basis_umap}"])[ix_choice, :2]
    V = np.asarray(dyn_adata.obsm[f"velocity_{basis_umap}"])[ix_choice, :2]
    z = np.zeros(shape=(X.shape[0]))
    point_cloud = pv.PolyData(np.column_stack((X[:, 0], X[:, 1], z)))
    point_cloud["vectors"] = np.column_stack((V[:, 0], V[:, 1], z))
    point_cloud.point_data["obs_index"] = dyn_adata.obs.index[ix_choice].tolist()
    vectors = point_cloud.glyph(orient="vectors", factor=vectors_size)

    obs_keys = [
        f"{basis_umap}_ddhodge_potential",
        f"{basis_pca}_ddhodge_potential",
        f"speed_{basis_pca}",
        f"acceleration_{basis_pca}",
        f"divergence_{basis_pca}",
        f"curvature_{basis_pca}",
        f"curl_{basis_pca}",
    ]
    _copy_obs_to_point_data(dyn_adata.obs, vectors, obs_keys)
    _copy_obs_to_point_data(dyn_adata.obs, pc_model, obs_keys)

    return pc_model, vectors
224 |
--------------------------------------------------------------------------------
/stviewer/Explorer/pv_pipeline/pv_interpolation.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | from tqdm import tqdm
4 |
5 | try:
6 | from typing import Literal
7 | except ImportError:
8 | from typing_extensions import Literal
9 |
10 | # GP model
11 | import gpytorch
12 | import numpy as np
13 | import ot
14 | import pandas as pd
15 | import torch
16 | from anndata import AnnData
17 | from gpytorch.likelihoods import GaussianLikelihood
18 | from gpytorch.models import ApproximateGP, ExactGP
19 | from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
20 | from numpy import ndarray
21 | from scipy.sparse import issparse
22 |
23 |
class Approx_GPModel(ApproximateGP):
    """Variational GP with a constant mean and a scaled RBF kernel.

    Args:
        inducing_points: tensor of inducing locations. Its first dimension is the
            number of inducing points, and the locations are learned during
            training (``learn_inducing_locations=True``).
    """

    def __init__(self, inducing_points):
        num_inducing = inducing_points.size(0)
        # The strategy is built before the base-class __init__, matching how it
        # is passed to ApproximateGP below.
        strategy = VariationalStrategy(
            self,
            inducing_points,
            CholeskyVariationalDistribution(num_inducing),
            learn_inducing_locations=True,
        )
        super().__init__(strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel()
        )

    def forward(self, x):
        """Return the normal distribution given by the mean and covariance modules at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )
43 |
44 |
class Exact_GPModel(ExactGP):
    """Exact GP regression model with a zero mean and a scaled RBF kernel.

    Args:
        train_x: training inputs.
        train_y: training targets.
        likelihood: likelihood passed through to ``ExactGP``.
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel()
        )

    def forward(self, x):
        """Return the normal distribution given by the mean and covariance modules at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )
55 |
56 |
def gp_train(model, likelihood, train_loader, train_epochs, method, N, device):
    """Optimise ``model`` and ``likelihood`` with Adam (lr=0.01) for ``train_epochs`` epochs.

    Args:
        model: GPyTorch model to train; it is updated in place.
        likelihood: likelihood trained jointly with ``model``.
        train_loader: for ``method == "SVGP"``, an iterable of
            ``(x_batch, y_batch)`` pairs. Otherwise, a mapping with
            ``"train_x"`` and ``"train_y"`` entries.
        train_epochs: number of passes over the training data.
        method: ``"SVGP"`` uses the variational ELBO with minibatches; any other
            value uses the exact marginal log-likelihood on the full data.
        N: number of training points, passed to the variational ELBO.
        device: any value other than ``"cpu"`` moves model and likelihood to
            CUDA when CUDA is available.
    """
    if torch.cuda.is_available() and device != "cpu":
        model = model.cuda()
        likelihood = likelihood.cuda()

    model.train()
    likelihood.train()

    is_svgp = method == "SVGP"
    # Negated in the loop below to obtain the loss.
    mll = (
        gpytorch.mlls.VariationalELBO(likelihood, model, num_data=N)
        if is_svgp
        else gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    )
    optimizer = torch.optim.Adam(
        [{"params": model.parameters()}, {"params": likelihood.parameters()}],
        lr=0.01,
    )

    for _epoch in tqdm(range(train_epochs), desc="Epoch"):
        if not is_svgp:
            # One full-batch gradient step per epoch.
            optimizer.zero_grad()
            loss = -mll(model(train_loader["train_x"]), train_loader["train_y"])
            loss.backward()
            optimizer.step()
            continue

        # One gradient step per minibatch.
        batches = tqdm(train_loader, desc="Minibatch", leave=True)
        for x_batch, y_batch in batches:
            optimizer.zero_grad()
            loss = -mll(model(x_batch), y_batch)
            batches.set_postfix(loss=loss.item())
            loss.backward()
            optimizer.step()
99 |
100 |
def nx_torch(nx):
    """Return True when the POT backend ``nx`` is the PyTorch backend."""
    return isinstance(nx, ot.backend.TorchBackend)


def _chunk(nx, x, chunk_num, dim):
    """Split ``x`` along ``dim`` with the library that matches backend ``nx``.

    Uses ``torch.chunk`` for the torch backend, which may return fewer than
    ``chunk_num`` pieces. Otherwise uses ``np.array_split``, which always
    returns ``chunk_num`` pieces.
    """
    if nx_torch(nx):
        return torch.chunk(x, chunk_num, dim=dim)
    return np.array_split(x, chunk_num, axis=dim)


def _unsqueeze(nx):
    """Return the expand-dims function for backend ``nx``.

    Returns ``torch.unsqueeze`` for the torch backend and ``np.expand_dims``
    otherwise.
    """
    return torch.unsqueeze if nx_torch(nx) else np.expand_dims
108 |
109 |
class Imputation_GPR:
    """
    Gaussian-process regressor from spatial coordinates to per-cell values.

    The training targets are the ``.obs`` columns and/or genes of ``source_adata``
    named in ``keys``. :meth:`inference` fits the model and :meth:`interpolate`
    predicts the targets at ``target_points``.
    """

    def __init__(
        self,
        source_adata: AnnData,
        target_points: Optional[ndarray] = None,
        keys: Union[str, list] = None,
        spatial_key: str = "spatial",
        layer: str = "X",
        device: str = "cpu",
        method: Literal["SVGP", "ExactGP"] = "SVGP",
        batch_size: int = 1024,
        shuffle: bool = True,
        inducing_num: int = 512,
        normalize_spatial: bool = True,
    ):
        """
        Build training tensors, the training loader and (for SVGP) inducing points.

        Args:
            source_adata: AnnData holding coordinates in ``.obsm[spatial_key]``.
            target_points: Coordinates at which :meth:`interpolate` predicts. Required.
            keys: ``.obs`` column name(s) and/or gene name(s) in ``.var_names`` to learn.
            spatial_key: Key of the coordinates in ``.obsm``.
            layer: ``"X"`` to use ``.X``, otherwise the name of a ``.layers`` entry.
            device: ``"cpu"`` or a GPU index such as ``"0"``; CPU is used when CUDA
                is unavailable.
            method: ``"SVGP"`` (minibatch training with inducing points) or ``"ExactGP"``.
            batch_size: Minibatch size for ``"SVGP"``.
            shuffle: Whether ``"SVGP"`` minibatches are shuffled.
            inducing_num: Maximum number of inducing points for ``"SVGP"``.
            normalize_spatial: Whether to centre and rescale coordinates before training.

        Raises:
            ValueError: If ``keys`` or ``target_points`` is None, ``method`` is not
                supported, or none of ``keys`` is found in ``.obs``/``.var_names``.
        """
        if keys is None:
            raise ValueError("`keys` cannot be None.")
        if target_points is None:
            raise ValueError("`target_points` cannot be None.")
        if method not in ("SVGP", "ExactGP"):
            raise ValueError(f"`method` must be 'SVGP' or 'ExactGP', got {method!r}.")

        # Source data (copied so the caller's AnnData is never modified).
        source_adata = source_adata.copy()
        source_adata.X = source_adata.X if layer == "X" else source_adata.layers[layer]
        source_spatial_data = np.asarray(source_adata.obsm[spatial_key])

        # Training targets: requested .obs columns first, then requested genes.
        keys = [keys] if isinstance(keys, str) else list(keys)
        obs_keys = [key for key in keys if key in source_adata.obs.keys()]
        gene_names = set(source_adata.var_names.tolist())
        var_keys = [key for key in keys if key in gene_names]

        columns = []
        if obs_keys:
            columns.append(np.asarray(source_adata.obs[obs_keys].values))
        if var_keys:
            var_data = source_adata[:, var_keys].X
            if issparse(var_data):
                # ``.toarray()`` works for SciPy sparse matrices and sparse arrays;
                # the ``.A`` shortcut is deprecated/removed in newer SciPy versions.
                var_data = var_data.toarray()
            columns.append(np.asarray(var_data))
        if not columns:
            raise ValueError("None of `keys` were found in `.obs` or `.var_names`.")
        info_data = np.asarray(np.hstack(columns), dtype=np.float64)

        self.device = (
            f"cuda:{device}" if torch.cuda.is_available() and device != "cpu" else "cpu"
        )

        self.train_x = self._to_device(torch.from_numpy(source_spatial_data).float())
        self.train_y = self._to_device(torch.from_numpy(info_data).float()).squeeze()

        self.nx = ot.backend.get_backend(self.train_x, self.train_y)

        self.normalize_spatial = normalize_spatial
        if self.normalize_spatial:
            self.train_x = self.normalize_coords(self.train_x)

        self.N = self.train_x.shape[0]
        # create training dataloader
        self.method = method
        from torch.utils.data import DataLoader, TensorDataset

        if method == "SVGP":
            train_dataset = TensorDataset(self.train_x, self.train_y)
            self.train_loader = DataLoader(
                train_dataset, batch_size=batch_size, shuffle=shuffle
            )
            # Sample distinct inducing points: duplicated points add no information
            # and make the inducing-point covariance rank-deficient.
            if self.N > inducing_num:
                inducing_idx = np.random.choice(self.N, inducing_num, replace=False)
            else:
                inducing_idx = np.arange(self.N)
            self.inducing_points = self.train_x[inducing_idx, :].clone()
        else:
            # ExactGP trains on the full data; gp_train reads these two keys.
            self.train_loader = {"train_x": self.train_x, "train_y": self.train_y}

        self.PCA_reduction = False
        self.info_keys = {"obs_keys": obs_keys, "var_keys": var_keys}

        # Target data
        self.target_points = self._to_device(
            torch.from_numpy(np.asarray(target_points)).float()
        )

    def _to_device(self, tensor):
        """Move ``tensor`` to the CPU, or to the default CUDA device when ``self.device`` is not ``"cpu"``."""
        return tensor.cpu() if self.device == "cpu" else tensor.cuda()

    def normalize_coords(
        self, data: Union[np.ndarray, torch.Tensor], given_normalize: bool = False
    ):
        """
        Centre ``data`` and divide it by its root-mean-square distance to the centre.

        Args:
            data: Coordinates, one row per point.
            given_normalize: If False, compute the centre and scale from ``data``
                and store them in ``self.mean_data`` / ``self.variance``. If True,
                reuse the stored values, e.g. to map target points into the
                training frame.

        Returns:
            The normalised coordinates.
        """
        if not given_normalize:
            self.mean_data = _unsqueeze(self.nx)(self.nx.mean(data, axis=0), 0)
        data = data - self.mean_data
        if not given_normalize:
            # Despite the name, this is a scale (RMS radius), not a variance.
            self.variance = self.nx.sqrt(self.nx.sum(data**2) / data.shape[0])
        data = data / self.variance
        return data

    def inference(
        self,
        training_iter: int = 50,
    ):
        """
        Build the GP model and Gaussian likelihood, then train them.

        Args:
            training_iter: Number of training epochs passed to ``gp_train``.
        """
        self.likelihood = GaussianLikelihood()
        if self.method == "SVGP":
            self.GPR_model = Approx_GPModel(inducing_points=self.inducing_points)
        else:
            self.GPR_model = Exact_GPModel(self.train_x, self.train_y, self.likelihood)
        # if to convert to GPU
        if self.device != "cpu":
            self.GPR_model = self.GPR_model.cuda()
            self.likelihood = self.likelihood.cuda()

        # Start training to find optimal model hyperparameters
        gp_train(
            model=self.GPR_model,
            likelihood=self.likelihood,
            train_loader=self.train_loader,
            train_epochs=training_iter,
            method=self.method,
            N=self.N,
            device=self.device,
        )

        self.GPR_model.eval()
        self.likelihood.eval()

    def interpolate(
        self,
        use_chunk: bool = False,
        chunk_num: int = 20,
    ):
        """
        Predict the posterior mean of the targets at ``self.target_points``.

        Args:
            use_chunk: Predict in ``chunk_num`` pieces to limit peak memory use.
            chunk_num: Number of chunks when ``use_chunk`` is True.

        Returns:
            A numpy array of predicted means, one entry (or row) per target point.
        """
        # Get into evaluation (predictive posterior) mode
        self.GPR_model.eval()
        self.likelihood.eval()

        target_points = self.target_points
        if self.normalize_spatial:
            target_points = self.normalize_coords(target_points, given_normalize=True)

        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            if use_chunk:
                chunk_means = [
                    self.likelihood(self.GPR_model(chunk)).mean
                    for chunk in _chunk(self.nx, target_points, chunk_num, 0)
                ]
                query_target = self.nx.concatenate(chunk_means, axis=0)
            else:
                query_target = self.likelihood(self.GPR_model(target_points)).mean

        return query_target.cpu().numpy()
270 |
271 |
def gp_interpolation(
    source_adata: AnnData,
    target_points: Optional[ndarray] = None,
    keys: Union[str, list] = None,
    spatial_key: str = "spatial",
    layer: str = "X",
    training_iter: int = 50,
    device: str = "cpu",
    method: Literal["SVGP", "ExactGP"] = "SVGP",
    batch_size: int = 1024,
    shuffle: bool = True,
    inducing_num: int = 512,
) -> AnnData:
    """
    Learn a continuous mapping from space to gene expression (or ``.obs`` values) with a Gaussian Process.

    Args:
        source_adata: AnnData object that contains spatial coordinates (numpy.ndarray) in ``.obsm[spatial_key]``.
        target_points: Spatial coordinates of the new points at which values are predicted. Required.
        keys: Gene names in ``.var_names`` and/or column names in ``.obs`` whose values are interpolated.
        spatial_key: The key in ``.obsm`` that corresponds to the spatial coordinate of each bucket.
        layer: If ``'X'``, uses ``.X``, otherwise uses the representation given by ``.layers[layer]``.
        training_iter: Number of training epochs.
        device: Equipment used to run the program. You can also set the specified GPU for running. ``E.g.: '0'``.
        method: ``'SVGP'`` (minibatch training with inducing points) or ``'ExactGP'``.
        batch_size: Minibatch size used when ``method='SVGP'``.
        shuffle: Whether minibatches are shuffled when ``method='SVGP'``.
        inducing_num: Maximum number of inducing points used when ``method='SVGP'``.

    Returns:
        interp_adata: an AnnData object whose ``.X`` holds the interpolated genes, whose ``.obs`` holds the
        interpolated ``.obs`` keys, and whose ``.obsm[spatial_key]`` holds ``target_points``.
    """

    # Inference
    GPR = Imputation_GPR(
        source_adata=source_adata,
        target_points=target_points,
        keys=keys,
        spatial_key=spatial_key,
        layer=layer,
        device=device,
        method=method,
        batch_size=batch_size,
        shuffle=shuffle,
        inducing_num=inducing_num,
    )
    GPR.inference(training_iter=training_iter)

    # Interpolation. A single target yields a 1-D vector; lift it to one column so
    # the column slicing below works. Leave 2-D predictions as they are.
    target_info_data = GPR.interpolate(use_chunk=True)
    if target_info_data.ndim == 1:
        target_info_data = target_info_data[:, None]

    # Output interpolated anndata: obs columns come first, genes after them.
    obs_keys = GPR.info_keys["obs_keys"]
    var_keys = GPR.info_keys["var_keys"]

    obs_data = None
    if obs_keys:
        obs_data = pd.DataFrame(target_info_data[:, : len(obs_keys)], columns=obs_keys)

    X, var_data = None, None
    if var_keys:
        X = target_info_data[:, len(obs_keys) :]
        var_data = pd.DataFrame(index=var_keys)

    interp_adata = AnnData(
        X=X,
        obs=obs_data,
        obsm={spatial_key: np.asarray(target_points)},
        var=var_data,
    )
    return interp_adata
338 |
--------------------------------------------------------------------------------
/stviewer/Explorer/pv_pipeline/pv_morphogenesis.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional
2 |
3 | import numpy as np
4 | from anndata import AnnData
5 | from pyvista import PolyData
6 | from scipy.integrate import odeint
7 |
8 |
def _transfer_obs_to_point_data(adata: AnnData, model, key_map: dict):
    """Copy ``adata.obs`` columns onto ``model.point_data``, matching rows via ``model.point_data["obs_index"]``.

    ``key_map`` maps each ``.obs`` column name to the ``point_data`` name it is stored under.
    """
    obs = adata[np.asarray(model.point_data["obs_index"])].obs
    for obs_key, point_key in key_map.items():
        model.point_data[point_key] = np.asarray(obs[obs_key])


def morphogenesis(
    source_adata: AnnData,
    source_pc_model: PolyData,
    target_adata: Optional[AnnData] = None,
    mapping_method: Literal["GP", "OT"] = "GP",
    mapping_factor: float = 0.2,
    mapping_device: str = "cpu",
    morphofield_factor: int = 3000,
    morphopath_t_end: int = 10000,
    morphopath_sampling: int = 500,
    morphopath_max_time: Optional[float] = 3000,
):
    """
    Build the morphofield, morphopath and animation stages for a point-cloud model.

    Args:
        source_adata: AnnData of the source stage; rows are selected by the model's ``obs_index``.
        source_pc_model: Point-cloud model with an ``obs_index`` point array; morphometric arrays are added to it.
        target_adata: Optional target-stage AnnData. Without it, the mapping results must already exist
            (``.obsm["V_cells_mapping"]`` for ``"OT"``, ``.uns["VecFld_morpho"]`` for ``"GP"``).
        mapping_method: ``"OT"`` uses ``st.tdr.cell_directions`` + ``st.tdr.morphofield_sparsevfc``;
            ``"GP"`` uses ``st.align.morpho_align_sparse`` + ``st.tdr.morphofield_gp``.
        mapping_factor: ``alpha`` for ``"OT"``; ``beta2_end`` for ``"GP"``.
        mapping_device: Device passed to the mapping routine.
        morphofield_factor: Arrow scale factor for ``st.tdr.construct_field``.
        morphopath_t_end: ``t_end`` passed to ``st.tdr.morphopath``.
        morphopath_sampling: Number of trajectories sampled by ``st.tdr.construct_trajectory``.
        morphopath_max_time: Largest fate time kept when building the animation time grid;
            ``None`` keeps all times.

    Returns:
        ``(source_pc_model, pc_vectors, trajectory_model, stages_X)``. ``stages_X[0]`` is the list of
        ``.obs`` names and each following entry holds the displaced cell positions at one of 100 time points.

    Raises:
        ImportError: If ``spateo`` or ``dynamo`` is not installed.
        ValueError: If mapping results are missing and no ``target_adata`` is given, or
            ``mapping_method`` is unsupported.
    """
    try:
        import spateo as st
    except ImportError as err:
        raise ImportError(
            "You need to install the package `spateo`. "
            "\nInstall spateo via `pip install spateo-release`."
        ) from err
    try:
        from dynamo.vectorfield import SvcVectorField
    except ImportError as err:
        raise ImportError(
            "You need to install the package `dynamo`. "
            "\nInstall dynamo via `pip install dynamo-release`."
        ) from err

    # Preprocess: keep only the cells present in the point-cloud model.
    _obs_index = source_pc_model.point_data["obs_index"]
    source_adata = source_adata[_obs_index, :]

    # 3D mapping and morphofield
    if mapping_method == "OT":
        if target_adata is not None:
            st.tdr.cell_directions(
                adataA=source_adata,
                adataB=target_adata,
                numItermaxEmd=2000000,
                spatial_key="spatial",
                key_added="cells_mapping",
                alpha=mapping_factor,
                device=mapping_device,
                inplace=True,
            )

        if "V_cells_mapping" not in source_adata.obsm.keys():
            raise ValueError("You need to add the target anndata object. ")

        st.tdr.morphofield_sparsevfc(
            adata=source_adata,
            spatial_key="spatial",
            V_key="V_cells_mapping",
            key_added="VecFld_morpho",
            NX=None,
            inplace=True,
        )
    elif mapping_method == "GP":
        if target_adata is not None:
            align_models, _, _ = st.align.morpho_align_sparse(
                models=[target_adata.copy(), source_adata.copy()],
                spatial_key="spatial",
                key_added="mapping_spatial",
                device=mapping_device,
                mode="SN-S",
                max_iter=200,
                partial_robust_level=1,
                beta=0.1,  # nonrigid,
                beta2_end=mapping_factor,  # low beta2_end, high expression similarity
                lambdaVF=1,
                K=200,
                SVI_mode=True,
                use_sparse=True,
            )
            source_adata = align_models[1].copy()

        if "VecFld_morpho" not in source_adata.uns.keys():
            raise ValueError("You need to add the target anndata object. ")

        st.tdr.morphofield_gp(
            adata=source_adata,
            spatial_key="spatial",
            vf_key="VecFld_morpho",
            NX=np.asarray(source_adata.obsm["spatial"]),
            inplace=True,
        )
    else:
        raise ValueError(
            f"`mapping_method` must be 'GP' or 'OT', got {mapping_method!r}."
        )

    # construct morphofield model
    source_adata.obs["V_z"] = source_adata.uns["VecFld_morpho"]["V"][:, 2].flatten()
    source_pc_model.point_data["vectors"] = source_adata.uns["VecFld_morpho"]["V"]
    source_pc_model.point_data["V_Z"] = source_pc_model.point_data["vectors"][
        :, 2
    ].flatten()

    pc_vectors, _ = st.tdr.construct_field(
        model=source_pc_model,
        vf_key="vectors",
        arrows_scale_key="vectors",
        n_sampling=None,
        factor=morphofield_factor,
        key_added="obs_index",
        label=source_pc_model.point_data["obs_index"],
    )

    # construct morphopath model
    st.tdr.morphopath(
        adata=source_adata,
        vf_key="VecFld_morpho",
        key_added="fate_morpho",
        t_end=morphopath_t_end,
        interpolation_num=20,
        cores=10,
    )
    trajectory_model, _ = st.tdr.construct_trajectory(
        adata=source_adata,
        fate_key="fate_morpho",
        n_sampling=morphopath_sampling,
        sampling_method="random",
        key_added="obs_index",
        label=np.asarray(source_adata.obs.index),
    )

    # morphometric features
    feature_keys = ["acceleration", "curvature", "curl", "torsion", "divergence"]
    feature_funcs = [
        st.tdr.morphofield_acceleration,
        st.tdr.morphofield_curvature,
        st.tdr.morphofield_curl,
        st.tdr.morphofield_torsion,
        st.tdr.morphofield_divergence,
    ]
    for feature_func, feature_key in zip(feature_funcs, feature_keys):
        feature_func(adata=source_adata, vf_key="VecFld_morpho", key_added=feature_key)

    # Copy the per-cell features onto every model; one AnnData slice per model.
    feature_map = {key: key for key in feature_keys}
    _transfer_obs_to_point_data(source_adata, source_pc_model, feature_map)
    _transfer_obs_to_point_data(source_adata, pc_vectors, {"V_z": "V_Z", **feature_map})
    _transfer_obs_to_point_data(
        source_adata, trajectory_model, {"V_z": "V_Z", **feature_map}
    )

    # cell stages of animation: fate times sorted by their integer keys
    fate_t = source_adata.uns["fate_morpho"]["t"]
    t_ind = np.asarray(list(fate_t.keys()), dtype=int)
    t_values = list(fate_t.values())
    t = [t_values[i] for i in np.argsort(t_ind)]
    flats = np.unique([int(item) for sublist in t for item in sublist])
    flats = np.hstack((0, flats))
    if morphopath_max_time is None:
        flats = np.sort(flats)
    else:
        flats = np.sort(flats[flats <= morphopath_max_time])
    # 100 absolute time points, log-spaced from 0 to the largest kept fate time.
    time_vec = np.logspace(0, np.log10(max(flats) + 1), 100) - 1
    vf = SvcVectorField()
    vf.from_adata(source_adata, basis="morpho")

    def _velocity(x, _t):
        return vf.func(x)

    init_states = source_adata.uns["fate_morpho"]["init_states"]
    pts = [state.tolist() for state in init_states]
    stages_X = [source_adata.obs.index.tolist()]
    prev_time = 0.0
    for cur_time in time_vec:
        # ``pts`` already sits at ``prev_time``, so integrate only the increment;
        # integrating by the absolute ``cur_time`` would overshoot.
        dt = cur_time - prev_time
        pts = [odeint(_velocity, cur_pts, [0, dt])[1].tolist() for cur_pts in pts]
        stages_X.append(pts)
        prev_time = cur_time

    return source_pc_model, pc_vectors, trajectory_model, stages_X
224 |
--------------------------------------------------------------------------------
/stviewer/Explorer/pv_pipeline/pv_plotter.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import pyvista as pv
4 | from pyvista import Plotter, PolyData, UnstructuredGrid
5 |
6 | try:
7 | from typing import Literal
8 | except ImportError:
9 | from typing_extensions import Literal
10 |
11 |
def create_plotter(
    window_size: tuple = (1024, 1024), background: str = "black", **kwargs
) -> Plotter:
    """
    Create an off-screen plotting object to display pyvista/vtk models.

    Args:
        window_size: Window size in pixels. The default window_size is ``(1024, 1024)``.
        background: The background color of the window.
        **kwargs: Additional keyword arguments forwarded to ``pyvista.Plotter``.
            ``off_screen`` defaults to ``True`` and ``lighting`` to ``"light_kit"``;
            both may be overridden here.

    Returns:
        plotter: The plotting object to display pyvista/vtk model.
    """

    # Defaults first so that caller-supplied kwargs override them instead of
    # raising "got multiple values for keyword argument".
    plotter_kwargs = {"off_screen": True, "lighting": "light_kit", **kwargs}
    plotter = pv.Plotter(window_size=window_size, **plotter_kwargs)

    # Set the background color of the active render window.
    plotter.background_color = background
    return plotter
34 |
35 |
def add_single_model(
    plotter: Plotter,
    model: Union[PolyData, UnstructuredGrid],
    key: Optional[str] = None,
    cmap: Optional[str] = "rainbow",
    color: Optional[str] = "gainsboro",
    ambient: float = 0.2,
    opacity: float = 1.0,
    model_style: Literal["points", "surface", "wireframe"] = "surface",
    model_size: float = 3.0,
    model_name: Optional[str] = None,
):
    """
    Add a single model to the plotter and return the resulting actor.

    Args:
        plotter: The plotting object to display pyvista/vtk model.
        model: A reconstructed model.
        key: Name of the model array used as scalars; ignored when the model has no such array.
        cmap: Name of the Matplotlib colormap to use when mapping the model.
        color: Name of the Matplotlib color to use when mapping the model.
        ambient: Amount of ambient light (0 to 1) that reaches the actor when lighting is enabled.
        opacity: Opacity of the model. A single float (0 to 1) applies uniformly; an array gives
            per-point opacity; a string such as 'linear', 'linear_r', 'geom' or 'geom_r' selects a
            predefined opacity transfer function.
        model_style: Visualization style: ``'surface'``, ``'wireframe'`` or ``'points'``.
        model_size: Point size when ``model_style='points'``; line thickness when ``'wireframe'``.
        model_name: Name to assign to the model. Defaults to the memory address.

    Returns:
        The actor returned by ``plotter.add_mesh``.
    """

    # (render_points_as_spheres, render_lines_as_tubes, smooth_shading) per style;
    # anything other than "points"/"wireframe" uses the surface settings.
    style_flags = {
        "points": (True, False, True),
        "wireframe": (False, True, False),
    }
    as_spheres, as_tubes, smooth = style_flags.get(model_style, (False, False, True))

    return plotter.add_mesh(
        model,
        scalars=key if key in model.array_names else None,
        style=model_style,
        render_points_as_spheres=as_spheres,
        render_lines_as_tubes=as_tubes,
        point_size=model_size,
        line_width=model_size,
        ambient=ambient,
        opacity=opacity,
        smooth_shading=smooth,
        show_scalar_bar=False,
        cmap=cmap,
        color=color,
        name=model_name,
    )
96 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/__init__.py:
--------------------------------------------------------------------------------
1 | from .container import ui_container
2 | from .drawer import ui_drawer
3 | from .layout import ui_layout
4 | from .toolbar import toolbar_switch_model, toolbar_widgets, ui_title, ui_toolbar
5 | from .utils import button, checkbox, switch
6 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/container.py:
--------------------------------------------------------------------------------
1 | try:
2 | from typing import Literal
3 | except ImportError:
4 | from typing_extensions import Literal
5 |
6 | from pyvista import BasePlotter
7 | from pyvista.trame import PyVistaLocalView, PyVistaRemoteLocalView, PyVistaRemoteView
8 | from trame.widgets import vuetify
9 |
10 | from .toolbar import Viewer
11 |
12 | # -----------------------------------------------------------------------------
13 | # GUI- standard Container
14 | # -----------------------------------------------------------------------------
15 |
16 |
def ui_container(
    server,
    layout,
    plotter: BasePlotter,
    mode: Literal["trame", "server", "client"] = "trame",
    default_server_rendering: bool = True,
    **kwargs,
):
    """
    Generate standard VContainer for Spateo UI.

    Args:
        server: The trame server.
        layout: The layout object.
        plotter: The PyVista plotter to connect with the UI.
        mode: The UI view mode. Options are:

            * ``'trame'``: Uses a view that can switch between client and server rendering modes.
            * ``'server'``: Uses a view that is purely server rendering.
            * ``'client'``: Uses a view that is purely client rendering (generally safe without a virtual frame buffer)
        default_server_rendering: Whether to use server-side or client-side rendering on-start when using the ``'trame'`` mode.
        kwargs: Additional parameters that will be passed to ``pyvista.trame.app.PyVistaXXXXView`` function.

    Raises:
        ValueError: If ``mode`` is not one of ``'trame'``, ``'server'`` or ``'client'``.
    """
    # Validate before building anything; an unknown mode would otherwise leave
    # ``view`` unbound below.
    if mode not in ("trame", "server", "client"):
        raise ValueError(
            f"`mode` must be one of 'trame', 'server' or 'client', got {mode!r}."
        )
    if mode != "trame":
        default_server_rendering = mode == "server"

    viewer = Viewer(plotter, server, suppress_rendering=mode == "client")
    ctrl = server.controller

    with layout.content:
        with vuetify.VContainer(
            fluid=True,
            classes="pa-0 fill-height",
        ):
            if mode == "trame":
                view = PyVistaRemoteLocalView(
                    plotter,
                    mode=(
                        # Must use single-quote string for JS here
                        f"{viewer.SERVER_RENDERING} ? 'remote' : 'local'",
                        "remote" if default_server_rendering else "local",
                    ),
                    **kwargs,
                )
                ctrl.view_update_image = view.update_image
            elif mode == "server":
                view = PyVistaRemoteView(plotter, **kwargs)
            else:
                view = PyVistaLocalView(plotter, **kwargs)

            ctrl.view_update = view.update
            ctrl.view_reset_camera = view.reset_camera
            ctrl.view_push_camera = view.push_camera
            ctrl.on_server_ready.add(view.update)
71 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/drawer/__init__.py:
--------------------------------------------------------------------------------
1 | from .main import ui_drawer
2 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/drawer/adata_obj.py:
--------------------------------------------------------------------------------
1 | from trame.widgets import vuetify
2 |
3 |
def anndata_object_content():
    """
    Render the AnnData card body: a structure textarea and an .h5ad upload field.

    Each widget gets its own dense row with a full-width (``cols="12"``) column.
    The textarea is bound to ``anndata_info.anndata_structure`` and the file
    input to ``uploaded_anndata_path``.
    """
    shared = dict(hide_details=True, dense=True, outlined=True, classes="pt-1")

    with vuetify.VRow(classes="pt-2", dense=True), vuetify.VCol(cols="12"):
        vuetify.VTextarea(
            v_model=("anndata_info.anndata_structure",),
            label="AnnData Structure",
            **shared,
        )

    with vuetify.VRow(classes="pt-2", dense=True), vuetify.VCol(cols="12"):
        vuetify.VFileInput(
            v_model=("uploaded_anndata_path", None),
            label="Upload AnnData (.h5ad)",
            show_size=True,
            small_chips=True,
            rounded=False,
            accept=".h5ad",
            __properties=["accept"],
            **shared,
        )
30 |
31 |
def anndata_panel():
    """
    Render the collapsible "AnnData Object" section of the drawer.

    The toolbar button flips the ``show_anndata_card`` state; the card text is
    only rendered while that flag is truthy.
    """
    flat = dict(classes="pa-0 ma-0", style="flex: none;")

    # Header toolbar with title and collapse/expand toggle.
    with vuetify.VToolbar(dense=True, outlined=True, **flat):
        vuetify.VIcon("mdi-apps")
        vuetify.VCardTitle(" AnnData Object", hide_details=True, dense=True, **flat)
        vuetify.VSpacer()
        with vuetify.VBtn(
            small=True, icon=True, click="show_anndata_card = !show_anndata_card"
        ):
            vuetify.VIcon("mdi-unfold-less-horizontal", v_if=("show_anndata_card",))
            vuetify.VIcon("mdi-unfold-more-horizontal", v_if=("!show_anndata_card",))

    # Main content
    with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"):
        with vuetify.VCardText(classes="py-2", v_if=("show_anndata_card",)):
            anndata_object_content()
58 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/drawer/custom_card.py:
--------------------------------------------------------------------------------
1 | from trame.widgets import vuetify
2 |
3 |
def custom_card_content():
    """
    Render the "Custom" tab: two toggles followed by ten placeholder parameters.

    Widgets are laid out two per row, each inside a half-width (``cols="6"``)
    column, and are bound to the ``custom_*`` state variables named below.
    """
    checkbox_style = dict(dense=True, hide_details=True, classes="pt-1")
    field_style = dict(hide_details=True, dense=True, outlined=True, classes="pt-1")

    def text_field(index, default):
        # Half-width text field bound to ``custom_parameter<index>``.
        with vuetify.VCol(cols="6"):
            vuetify.VTextField(
                v_model=(f"custom_parameter{index}", default),
                label=f"Custom parameter {index}",
                **field_style,
            )

    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="6"):
            vuetify.VCheckbox(
                v_model=("custom_analysis", False),
                label="Custom analysis calculation",
                on_icon="mdi-transit-detour",
                off_icon="mdi-transit-skip",
                **checkbox_style,
            )
        with vuetify.VCol(cols="6"):
            vuetify.VCheckbox(
                v_model=("custom_model_visible", False),
                label="Custom model visibility",
                on_icon="mdi-pyramid",
                off_icon="mdi-pyramid-off",
                **checkbox_style,
            )

    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="6"):
            vuetify.VSelect(
                v_model=("custom_parameter1", "X"),
                items=("matrices_list",),
                label="Custom parameter 1",
                **field_style,
            )
        with vuetify.VCol(cols="6"):
            vuetify.VSelect(
                v_model=("custom_parameter2", "recipe_monocle"),
                items=(["False", "recipe_monocle", "pearson_residuals"],),
                label="Custom parameter 2",
                **field_style,
            )

    with vuetify.VRow(classes="pt-2", dense=True):
        text_field(3, "pca")
        text_field(4, "umap")

    with vuetify.VRow(classes="pt-2", dense=True):
        # Wrapped in a half-width column like every other widget in this card,
        # so the row keeps the two-column grid.
        with vuetify.VCol(cols="6"):
            vuetify.VCheckbox(
                v_model=("custom_parameter5", False),
                label="Custom parameter 5",
                on_icon="mdi-plus-thick",
                off_icon="mdi-close-thick",
                **checkbox_style,
            )
        text_field(6, "None")

    with vuetify.VRow(classes="pt-2", dense=True):
        text_field(7, 30)
        text_field(8, 30)

    with vuetify.VRow(classes="pt-2", dense=True):
        text_field(9, "None")
        text_field(10, 1)
128 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/drawer/main.py:
--------------------------------------------------------------------------------
1 | try:
2 | from typing import Literal
3 | except ImportError:
4 | from typing_extensions import Literal
5 |
6 | from pyvista import BasePlotter
7 | from trame.widgets import vuetify
8 |
9 |
def _get_spateo_cmap():
    """Register the ``spateo_cmap`` colormap with matplotlib (only once) and return its name."""
    import matplotlib as mpl
    from matplotlib.colors import LinearSegmentedColormap

    cmap_name = "spateo_cmap"
    if cmap_name in mpl.colormaps():
        return cmap_name

    # (position, hex color) anchors interpolated linearly across [0, 1].
    anchors = [
        (0.0, "#4B0082"),
        (0.2, "#800080"),
        (0.4, "#F97306"),
        (0.6, "#FFA500"),
        (0.8, "#FFD700"),
        (1.0, "#FFFFCB"),
    ]
    mpl.colormaps.register(LinearSegmentedColormap.from_list(cmap_name, anchors))
    return cmap_name
22 |
23 |
def ui_drawer(
    server,
    layout,
    plotter: BasePlotter,
    mode: Literal["trame", "server", "client"] = "trame",
):
    """
    Generate standard Drawer for Spateo UI.

    Args:
        server: The trame server.
        layout: The layout object.
        plotter: The PyVista plotter to connect with the UI.
        mode: The UI view mode. Options are:

            * ``'trame'``: Uses a view that can switch between client and server rendering modes.
            * ``'server'``: Uses a view that is purely server rendering.
            * ``'client'``: Uses a view that is purely client rendering (generally safe without a virtual frame buffer)
    """

    _get_spateo_cmap()
    from stviewer.Explorer.pv_pipeline import PVCB

    PVCB(server=server, plotter=plotter, suppress_rendering=mode == "client")

    with layout.drawer:
        # Pipeline
        from .pipeline import pipeline_panel

        pipeline_panel(server=server, plotter=plotter)

        # AnnData object
        from .adata_obj import anndata_panel

        anndata_panel()

        # Active model
        with vuetify.VToolbar(
            dense=True, outlined=True, classes="pa-0 ma-0", style="flex: none;"
        ):
            vuetify.VIcon("mdi-format-paint")
            vuetify.VCardTitle(
                " Active Model",
                classes="pa-0 ma-0",
                style="flex: none;",
                hide_details=True,
                dense=True,
            )

            vuetify.VSpacer()
            with vuetify.VBtn(
                small=True,
                icon=True,
                click="show_model_card = !show_model_card",
            ):
                vuetify.VIcon("mdi-unfold-less-horizontal", v_if=("show_model_card",))
                vuetify.VIcon("mdi-unfold-more-horizontal", v_if=("!show_model_card",))
        with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"):
            # Point-cloud model controls
            with vuetify.VCardText(
                classes="py-2 ",
                v_show="active_model_type === 'PC'",
                v_if=("show_model_card",),
            ):
                # The "Custom" tab is only offered when custom functions are enabled.
                show_custom = server.state.custom_func is True
                items = ["Model", "Morphogenesis"] + (["Custom"] if show_custom else [])
                # Split the tab strip evenly ("width: 50%;" for the default two tabs).
                tab_style = f"width: {100 / len(items):g}%;"
                with vuetify.VTabs(v_model=("pc_active_tab", 0), left=True):
                    for item in items:
                        vuetify.VTab(
                            item,
                            style=tab_style,
                        )
                with vuetify.VTabsItems(
                    value=("pc_active_tab",),
                    style="width: 100%; height: 100%;",
                ):
                    with vuetify.VTabItem(value=(0,)):
                        from .model_point import pc_card_content

                        pc_card_content()
                    with vuetify.VTabItem(value=(1,)):
                        from .morphogenesis import morphogenesis_card_content

                        morphogenesis_card_content()
                    # Custom
                    if show_custom:
                        with vuetify.VTabItem(value=(2,)):
                            from .custom_card import custom_card_content

                            custom_card_content()
            # Mesh model controls
            with vuetify.VCardText(
                classes="py-2",
                v_show="active_model_type === 'Mesh'",
                v_if=("show_model_card",),
            ):
                from .model_mesh import mesh_card_content

                mesh_card_content()

        # Output
        from .output import output_panel

        output_panel()
128 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/drawer/model_mesh.py:
--------------------------------------------------------------------------------
1 | from trame.widgets import vuetify
2 |
3 |
def mesh_card_content():
    """Render the controls of the mesh tab of the active-model card.

    A two-column row holds the color and render-style pickers; below it come the
    opacity and ambient sliders and a toggle for morphological metrics. Every
    widget is bound to a ``mesh_*`` state key.
    """
    picker_style = {
        "hide_details": True,
        "dense": True,
        "outlined": True,
        "classes": "pt-1",
    }

    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="6"):
            vuetify.VCombobox(
                label="Color",
                v_model=("mesh_color_value", "None"),
                items=("mesh_colors_list",),
                type="str",
                **picker_style,
            )
        with vuetify.VCol(cols="6"):
            render_styles = ["surface", "points", "wireframe"]
            vuetify.VCombobox(
                label="Style",
                v_model=("mesh_style_value", "surface"),
                items=("styles", render_styles),
                **picker_style,
            )

    # Opacity and ambient share the same 0..1 range, step and default.
    for state_key, caption in (
        ("mesh_opacity_value", "Opacity"),
        ("mesh_ambient_value", "Ambient"),
    ):
        vuetify.VSlider(
            v_model=(state_key, 0.2),
            min=0,
            max=1,
            step=0.05,
            label=caption,
            classes="mt-1",
            hide_details=True,
            dense=True,
        )

    vuetify.VCheckbox(
        v_model=("mesh_morphology", False),
        label="Model Morphological Metrics",
        on_icon="mdi-pencil-ruler",
        off_icon="mdi-ruler",
        dense=True,
        hide_details=True,
        classes="mt-1",
    )
58 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/drawer/model_point.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from trame.widgets import vuetify
3 |
4 |
def pc_card_content():
    """Render the controls of the point-cloud (PC) tab of the active-model card.

    The card has six two-column rows:

    * annotation / reload
    * picking group / add picked group
    * genes / legend
    * interpolation device / GP interpolation
    * coordinates / matrix
    * color / colormap

    These are followed by opacity, ambient and point-size sliders.
    """
    boxed = {"dense": True, "outlined": True, "hide_details": True, "classes": "pt-1"}
    toggle = {"dense": True, "hide_details": True, "classes": "pt-1"}

    # Each entry is one row; each cell is (widget factory, keyword arguments).
    rows = (
        (
            (
                vuetify.VCombobox,
                {
                    "label": "Observation Annotation",
                    "v_model": ("pc_obs_value", None),
                    "items": ("available_obs",),
                    "type": "str",
                    "show_size": True,
                    **boxed,
                },
            ),
            (
                vuetify.VCheckbox,
                {
                    "v_model": ("pc_reload", False),
                    "label": "Reload Model",
                    "on_icon": "mdi-restore",
                    "off_icon": "mdi-restore",
                    **toggle,
                },
            ),
        ),
        (
            (
                vuetify.VCombobox,
                {
                    "v_model": ("pc_picking_group", None),
                    "items": ("Object.keys(pc_scalars_raw)",),
                    "label": "Picking Group",
                    "show_size": True,
                    **boxed,
                },
            ),
            (
                vuetify.VCheckbox,
                {
                    "v_model": ("pc_overwrite", False),
                    "label": "Add Picked Group",
                    "on_icon": "mdi-plus-thick",
                    "off_icon": "mdi-close-thick",
                    **toggle,
                },
            ),
        ),
        (
            (
                vuetify.VCombobox,
                {
                    "label": "Available Genes",
                    "v_model": ("pc_gene_value", None),
                    "items": ("available_genes",),
                    "type": "str",
                    "show_size": True,
                    **boxed,
                },
            ),
            (
                vuetify.VCheckbox,
                {
                    "v_model": ("pc_add_legend", True),
                    "label": "Add Legend",
                    "on_icon": "mdi-view-grid-plus",
                    "off_icon": "mdi-view-grid",
                    **toggle,
                },
            ),
        ),
        (
            (
                vuetify.VTextField,
                {
                    "v_model": ("interpolation_device", "cpu"),
                    "label": "Interpolation Device",
                    **boxed,
                },
            ),
            (
                vuetify.VCheckbox,
                {
                    "v_model": ("cal_interpolation", True),
                    "label": "GP Interpolation",
                    "on_icon": "mdi-smoke-detector",
                    "off_icon": "mdi-smoke-detector-off",
                    **toggle,
                },
            ),
        ),
        (
            (
                vuetify.VCombobox,
                {
                    "v_model": ("pc_coords_value", "spatial"),
                    "items": ("anndata_info.anndata_obsm_keys",),
                    "label": "Coords",
                    "type": "str",
                    **boxed,
                },
            ),
            (
                vuetify.VCombobox,
                {
                    "v_model": ("pc_matrix_value", "X"),
                    "items": ("anndata_info.anndata_matrices",),
                    "label": "Matrices",
                    "type": "str",
                    **boxed,
                },
            ),
        ),
        (
            (
                vuetify.VCombobox,
                {
                    "v_model": ("pc_color_value", "None"),
                    "items": ("pc_colors_list",),
                    "label": "Color",
                    "type": "str",
                    **boxed,
                },
            ),
            (
                vuetify.VCombobox,
                {
                    "v_model": ("pc_colormap_value", "Spectral"),
                    "items": ("pc_colormaps_list",),
                    "label": "Colormap",
                    "type": "str",
                    **boxed,
                },
            ),
        ),
    )

    for cells in rows:
        with vuetify.VRow(classes="pt-2", dense=True):
            for factory, options in cells:
                with vuetify.VCol(cols="6"):
                    factory(**options)

    # (state key, default, maximum, step, label); every slider starts at 0.
    for state_key, default, upper, step, caption in (
        ("pc_opacity_value", 1.0, 1, 0.05, "Opacity"),
        ("pc_ambient_value", 0.2, 1, 0.05, "Ambient"),
        ("pc_point_size_value", 2, 20, 1, "Point Size"),
    ):
        vuetify.VSlider(
            v_model=(state_key, default),
            min=0,
            max=upper,
            step=step,
            label=caption,
            classes="mt-1",
            hide_details=True,
            dense=True,
        )
179 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/drawer/morphogenesis.py:
--------------------------------------------------------------------------------
1 | from trame.widgets import vuetify
2 |
3 |
def morphogenesis_card_content():
    """Render the controls of the morphogenesis tab of the active-model card.

    The card offers:

    * a toggle to run the computation;
    * upload and selection of a target AnnData;
    * mapping method / device / factors;
    * morphopath length and sampling;
    * visibility toggles for the morphofield and morphopath;
    * the MP4 output path for the morphogenesis animation.
    """
    boxed = dict(hide_details=True, dense=True, outlined=True, classes="pt-1")
    toggle = dict(dense=True, hide_details=True, classes="pt-1")
    target_choices = ["uploaded_target_anndata"]

    with vuetify.VRow(classes="pt-2", dense=True), vuetify.VCol(cols="12"):
        vuetify.VCheckbox(
            v_model=("cal_morphogenesis", False),
            label="Calculate the Morphogenesis",
            on_icon="mdi-transit-detour",
            off_icon="mdi-transit-skip",
            **toggle,
        )

    with vuetify.VRow(classes="pt-2", dense=True), vuetify.VCol(cols="12"):
        vuetify.VFileInput(
            v_model=("morpho_uploaded_target_anndata_path", None),
            label="Upload Target AnnData (.h5ad)",
            dense=True,
            classes="pt-1",
            accept=".h5ad",
            hide_details=True,
            show_size=True,
            small_chips=True,
            truncate_length=0,
            outlined=True,
            __properties=["accept"],
        )

    with vuetify.VRow(classes="pt-2", dense=True), vuetify.VCol(cols="12"):
        vuetify.VSelect(
            label="Target AnnData",
            v_model=("morpho_target_anndata_path", None),
            items=(target_choices,),
            dense=True,
            outlined=True,
            hide_details=True,
            classes="pt-1",
        )

    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="6"):
            vuetify.VSelect(
                v_model=("morpho_mapping_method", "GP"),
                items=(["OT", "GP"],),
                label="Mapping Method",
                **boxed,
            )
        with vuetify.VCol(cols="6"):
            vuetify.VTextField(
                v_model=("morpho_mapping_device", "cpu"),
                label="Mapping Device",
                **boxed,
            )

    # Two rows of paired free-text numeric fields: (state key, default, label).
    numeric_rows = (
        (
            ("morpho_mapping_factor", 0.2, "Mapping Factor"),
            ("morphofield_factor", 3000, "Morphofield Factor"),
        ),
        (
            ("morphopath_t_end", 10000, "Morphopath Length"),
            ("morphopath_downsampling", 500, "Morphopath Sampling"),
        ),
    )
    for pair in numeric_rows:
        with vuetify.VRow(classes="pt-2", dense=True):
            for state_key, default, caption in pair:
                with vuetify.VCol(cols="6"):
                    vuetify.VTextField(
                        v_model=(state_key, default), label=caption, **boxed
                    )

    # The state keys keep their historical spelling ("visibile") on purpose.
    with vuetify.VRow(classes="pt-2", dense=True):
        for state_key, caption, icon_on, icon_off in (
            ("morphofield_visibile", "Morphofield Visibility", "mdi-pyramid", "mdi-pyramid-off"),
            ("morphopath_visibile", "Morphopath Visibility", "mdi-octahedron", "mdi-octahedron-off"),
        ):
            with vuetify.VCol(cols="6"):
                vuetify.VCheckbox(
                    v_model=(state_key, False),
                    label=caption,
                    on_icon=icon_on,
                    off_icon=icon_off,
                    **toggle,
                )

    with vuetify.VRow(classes="pt-2", dense=True), vuetify.VCol(cols="12"):
        vuetify.VTextField(
            v_model=("morphopath_animation_path", None),
            label="Morphogenesis Animation Output (MP4)",
            **boxed,
        )
138 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/drawer/output.py:
--------------------------------------------------------------------------------
1 | from trame.widgets import vuetify
2 |
3 |
def output_screenshot_content():
    """Render the screenshot tab: one full-width path field bound to ``screenshot_path``."""
    field_style = dict(hide_details=True, dense=True, outlined=True, classes="pt-1")
    with vuetify.VRow(classes="pt-2", dense=True), vuetify.VCol(cols="12"):
        vuetify.VTextField(
            v_model=("screenshot_path", None),
            label="Screenshot Output (PNG or PDF)",
            **field_style,
        )
15 |
16 |
def output_animation_content():
    """Render the animation tab.

    The tab holds a half-width pair for the number of animation points and the
    framerate, followed by a full-width MP4 output path field.
    """
    field_style = dict(hide_details=True, dense=True, outlined=True, classes="pt-1")

    with vuetify.VRow(classes="pt-2", dense=True):
        for state_key, default, caption in (
            ("animation_npoints", 50, "Animation N Points"),
            ("animation_framerate", 10, "Animation Framerate"),
        ):
            with vuetify.VCol(cols="6"):
                vuetify.VTextField(
                    v_model=(state_key, default), label=caption, **field_style
                )

    with vuetify.VRow(classes="pt-2", dense=True), vuetify.VCol(cols="12"):
        vuetify.VTextField(
            v_model=("animation_path", None),
            label="Animation Output (MP4)",
            **field_style,
        )
48 |
49 |
def output_panel():
    """Render the collapsible "Output Widgets" section of the drawer.

    A toolbar header carries a button that flips ``show_output_card``. While that
    flag is truthy, a card shows two tabs (Screenshot / Animation), with the
    selected tab index kept in ``output_active_tab``.
    """
    header_layout = dict(classes="pa-0 ma-0", style="flex: none;")
    with vuetify.VToolbar(dense=True, outlined=True, **header_layout):
        vuetify.VIcon("mdi-download")
        vuetify.VCardTitle(
            " Output Widgets", **header_layout, hide_details=True, dense=True
        )
        vuetify.VSpacer()
        with vuetify.VBtn(
            small=True, icon=True, click="show_output_card = !show_output_card"
        ):
            # Exactly one icon is shown, depending on whether the card is open.
            for icon_name, condition in (
                ("mdi-unfold-less-horizontal", "show_output_card"),
                ("mdi-unfold-more-horizontal", "!show_output_card"),
            ):
                vuetify.VIcon(icon_name, v_if=(condition,))

    # Tab title paired with the function that fills that tab.
    tabs = (
        ("Screenshot", output_screenshot_content),
        ("Animation", output_animation_content),
    )
    with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"):
        with vuetify.VCardText(classes="py-2", v_if=("show_output_card",)):
            with vuetify.VTabs(v_model=("output_active_tab", 0), left=True):
                for title, _builder in tabs:
                    vuetify.VTab(title, style="width: 50%;")
            with vuetify.VTabsItems(
                value=("output_active_tab",), style="width: 100%; height: 100%;"
            ):
                for position, (_title, builder) in enumerate(tabs):
                    with vuetify.VTabItem(value=(position,)):
                        builder()
90 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/drawer/pipeline.py:
--------------------------------------------------------------------------------
1 | from trame.widgets import html, trame, vuetify
2 |
3 | from stviewer.Explorer.pv_pipeline.init_parameters import (
4 | init_mesh_parameters,
5 | init_morphogenesis_parameters,
6 | init_pc_parameters,
7 | )
8 |
9 |
def pipeline_content(server, plotter):
    """Render the pipeline tree and register its selection/visibility callbacks.

    Two callbacks are registered on ``server.controller``:

    * ``actives_change``: makes the selected tree node the active model and
      reloads the matching parameter group (point cloud + morphogenesis, or
      mesh). An empty selection is ignored.
    * ``visibility_change``: shows/hides the corresponding plotter actor and
      keeps ``state.vis_ids`` (0-based actor indices) in sync.

    Args:
        server: The trame server whose state and controller are used.
        plotter: The PyVista plotter holding the actors listed in the tree.
    """
    state, ctrl = server.state, server.controller

    @ctrl.set("actives_change")
    def actives_change(ids):
        # Nothing to activate when the tree reports an empty selection.
        if not ids:
            return
        # Tree ids are 1-based relative to ``state.actor_ids`` (they are used as ``id - 1``).
        node_id = int(ids[0])
        active_ui = state.actor_ids[node_id - 1]
        state.active_ui = active_ui
        state.active_model_type = str(active_ui).split("_")[0]
        state.active_id = node_id

        # Reset the UI parameters of the newly active model type.
        if active_ui.startswith("PC"):
            state.update(init_pc_parameters)
            state.update(init_morphogenesis_parameters)
        elif active_ui.startswith("Mesh"):
            state.update(init_mesh_parameters)
        ctrl.view_update()

    @ctrl.set("visibility_change")
    def visibility_change(event):
        actor_index = int(event["id"]) - 1
        visible = event["visible"]
        actor = list(plotter.actors.values())[actor_index]
        actor.SetVisibility(visible)

        # Use a set so repeated show/hide events are idempotent:
        # hiding an index that is not listed is a no-op instead of a ValueError.
        # A fresh list is assigned instead of mutating ``state.vis_ids`` in place.
        visible_ids = set(state.vis_ids)
        if visible is True:
            visible_ids.add(actor_index)
        else:
            visible_ids.discard(actor_index)
        state.vis_ids = list(visible_ids)
        ctrl.view_update()

    # Main content
    trame.GitTree(
        sources=("pipeline",),
        actives_change=(ctrl.actives_change, "[$event]"),
        visibility_change=(ctrl.visibility_change, "[$event]"),
    )
48 |
49 |
def pipeline_panel(server, plotter):
    """Render the "Pipeline" section: a titled toolbar followed by the pipeline tree.

    Args:
        server: The trame server forwarded to ``pipeline_content``.
        plotter: The PyVista plotter forwarded to ``pipeline_content``.
    """
    compact = dict(classes="pa-0 ma-0", style="flex: none;")
    with vuetify.VToolbar(dense=True, outlined=True, **compact):
        # The branch icon is mirrored vertically so the tree reads top-down.
        vuetify.VIcon("mdi-source-branch", style="transform: scale(1, -1);")
        vuetify.VCardTitle(" Pipeline", **compact, hide_details=True, dense=True)

    pipeline_content(server=server, plotter=plotter)
67 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/layout.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | # -----------------------------------------------------------------------------
4 | # GUI layout
5 | # -----------------------------------------------------------------------------
6 |
7 |
def ui_layout(
    server, template_name: str = "main", drawer_width: Optional[int] = None, **kwargs
):
    """
    Define the user interface (UI) layout.
    Reference: https://trame.readthedocs.io/en/latest/trame.ui.vuetify.html#trame.ui.vuetify.SinglePageWithDrawerLayout

    Args:
        server: Server to bind the layout to.
        template_name: Name of the template.
        drawer_width: Drawer width in pixels. When ``None``, 15% of the screen
            width reported by ``pyautogui.size()`` is used. If pyautogui cannot be
            imported or queried (for example on a host without a display), a
            fixed 300-pixel drawer is used instead.
        **kwargs: Additional arguments forwarded to ``SinglePageWithDrawerLayout``.

    Returns:
        The SinglePageWithDrawerLayout layout object.
    """
    from trame.ui.vuetify import SinglePageWithDrawerLayout

    if drawer_width is None:
        try:
            import pyautogui

            screen_width = pyautogui.size()[0]
        except Exception:
            # Best effort only: the screen size is merely a sizing hint, so a
            # failing pyautogui must not prevent the UI from being built.
            screen_width = None
        drawer_width = int(screen_width * 0.15) if screen_width else 300

    return SinglePageWithDrawerLayout(
        server, template_name=template_name, width=drawer_width, **kwargs
    )
34 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/toolbar.py:
--------------------------------------------------------------------------------
1 | try:
2 | from typing import Literal
3 | except ImportError:
4 | from typing_extensions import Literal
5 |
6 | from typing import Optional
7 |
8 | from pyvista import BasePlotter
9 | from trame.widgets import html, vuetify
10 |
11 | from stviewer.assets import icon_manager, local_dataset_manager
12 | from stviewer.Explorer.pv_pipeline import SwitchModels, Viewer
13 |
14 | from .utils import button, checkbox
15 |
16 | # -----------------------------------------------------------------------------
17 | # GUI- UI title
18 | # -----------------------------------------------------------------------------
19 |
20 |
def ui_title(
    layout, title_name="SPATEO VIEWER", title_icon: Optional[str] = None, **kwargs
):
    """
    Set the title text, title style and (optionally) the logo of the UI.
    Reference: https://trame.readthedocs.io/en/latest/trame.ui.vuetify.html#trame.ui.vuetify.SinglePageWithDrawerLayout

    Args:
        layout: The layout object whose ``title`` and ``icon`` slots are filled.
        title_name: Text shown as the GUI title.
        title_icon: Image source of the GUI logo; when ``None`` no logo is added.
        **kwargs: Extra parameters passed on to ``html.Img``.

    Returns:
        None.
    """
    title = layout.title
    title.set_text(title_name)
    title.style = "font-family:arial; font-size:25px; font-weight: 550; color: gray;"

    if title_icon is None:
        return

    with layout.icon as icon:
        icon.style = "margin-left: 10px;"
        html.Img(src=title_icon, height=40, **kwargs)
49 |
50 |
51 | # -----------------------------------------------------------------------------
52 | # GUI- standard ToolBar
53 | # -----------------------------------------------------------------------------
54 |
55 |
def toolbar_widgets(
    server,
    plotter: BasePlotter,
    mode: Literal["trame", "server", "client"] = "trame",
    default_server_rendering: bool = True,
):
    """
    Build the standard ToolBar widgets: display toggles and camera-reset buttons.

    Args:
        server: The trame server.
        plotter: The PyVista plotter to connect with the UI.
        mode: The UI view mode. Options are:

            * ``'trame'``: Uses a view that can switch between client and server rendering modes.
            * ``'server'``: Uses a view that is purely server rendering.
            * ``'client'``: Uses a view that is purely client rendering (generally safe without a virtual frame buffer)
        default_server_rendering: Whether to use server-side or client-side rendering on-start when using the ``'trame'`` mode.
    """
    # Outside 'trame' mode the rendering side is fixed by the mode itself.
    if mode != "trame":
        default_server_rendering = mode == "server"

    viewer = Viewer(plotter=plotter, server=server, suppress_rendering=mode == "client")

    def state_checkbox(key, default, icons, label, when_on="on", when_off="off"):
        # The doubled braces emit a Vue mustache, e.g.
        # "Toggle axis ({{ <key> ? 'on' : 'off' }})", so the tooltip tracks the state.
        tip = f"{label} ({{{{ {key} ? '{when_on}' : '{when_off}' }}}})"
        checkbox(model=(key, default), icons=icons, tooltip=tip)

    vuetify.VSpacer()
    # Light/dark theme toggles Vuetify's client-side theme directly.
    checkbox(
        model="$vuetify.theme.dark",
        icons=("mdi-lightbulb-off-outline", "mdi-lightbulb-outline"),
        tooltip="Toggle theme",
    )
    state_checkbox(
        viewer.BACKGROUND,
        False,
        ("mdi-palette-swatch-outline", "mdi-palette-swatch"),
        "Toggle background",
        "white",
        "black",
    )
    if mode == "trame":
        state_checkbox(
            viewer.SERVER_RENDERING,
            default_server_rendering,
            ("mdi-lan-connect", "mdi-lan-disconnect"),
            "Toggle rendering mode",
            "remote",
            "local",
        )
    state_checkbox(viewer.MEMORY_USAGE, False, ("mdi-memory", "mdi-memory"), "Toggle memory usage")

    vuetify.VDivider(vertical=True, classes="mx-1")
    state_checkbox(
        viewer.SHOW_MAIN_MODEL,
        True,
        ("mdi-eye-outline", "mdi-eye-off-outline"),
        "Toggle main model visibility",
        "True",
        "False",
    )
    state_checkbox(viewer.OUTLINE, False, ("mdi-cube", "mdi-cube-off"), "Toggle bounding box")
    state_checkbox(viewer.GRID, False, ("mdi-ruler-square", "mdi-ruler-square"), "Toggle ruler")
    state_checkbox(
        viewer.AXIS, False, ("mdi-axis-arrow-info", "mdi-axis-arrow-info"), "Toggle axis"
    )

    # Camera presets: looking down X shows the YZ plane, and so on.
    vuetify.VDivider(vertical=True, classes="mx-1")
    for handler, icon_name, tip in (
        (viewer.view_isometric, "mdi-axis-arrow", "Perspective view"),
        (viewer.view_yz, "mdi-axis-x-arrow", "Reset camera X"),
        (viewer.view_xz, "mdi-axis-y-arrow", "Reset camera Y"),
        (viewer.view_xy, "mdi-axis-z-arrow", "Reset camera Z"),
    ):
        button(click=handler, icon=icon_name, tooltip=tip)
155 |
156 |
def toolbar_switch_model(
    server,
    plotter: BasePlotter,
):
    """Render the sample selector and the directory-picker button in the toolbar.

    The selectable samples are the local dataset assets whose key does not end
    with ``"anndata"``, followed by an ``"uploaded_sample"`` entry. Selecting one
    is handled by ``SwitchModels`` through its ``SELECT_SAMPLES`` state key.

    Args:
        server: The trame server.
        plotter: The PyVista plotter whose models are switched.
    """
    sample_names = [
        asset_key
        for asset_key in local_dataset_manager.get_assets().keys()
        if not str(asset_key).endswith("anndata")
    ] + ["uploaded_sample"]

    vuetify.VSpacer()
    switcher = SwitchModels(server=server, plotter=plotter)
    vuetify.VSelect(
        label="Select Samples",
        v_model=(switcher.SELECT_SAMPLES, None),
        items=("samples", sample_names),
        dense=True,
        outlined=True,
        hide_details=True,
        classes="ml-8",
        prepend_inner_icon="mdi-magnify",
        style="max-width: 300px;",
        rounded=True,
    )
    # Opens a local directory via the controller callback registered elsewhere.
    button(
        click=server.controller.open_directory,
        icon="mdi-file-document-outline",
        tooltip="Select directory",
    )
189 |
190 |
def ui_toolbar(
    server,
    layout,
    plotter: BasePlotter,
    mode: Literal["trame", "server", "client"] = "trame",
    default_server_rendering: bool = True,
    ui_name: str = "SPATEO VIEWER",
    ui_icon=icon_manager.spateo_logo,
):
    """
    Build the Spateo UI header: the title/logo and the ToolBar contents.

    Args:
        server: The trame server.
        layout: The layout object.
        plotter: The PyVista plotter to connect with the UI.
        mode: The UI view mode. Options are:

            * ``'trame'``: Uses a view that can switch between client and server rendering modes.
            * ``'server'``: Uses a view that is purely server rendering.
            * ``'client'``: Uses a view that is purely client rendering (generally safe without a virtual frame buffer)
        default_server_rendering: Whether to use server-side or client-side rendering on-start when using the ``'trame'`` mode.
        ui_name: Title name of the GUI.
        ui_icon: Title icon of the GUI.
    """
    ui_title(layout=layout, title_name=ui_name, title_icon=ui_icon)

    with layout.toolbar as toolbar:
        toolbar.height = 55
        toolbar.dense = True
        toolbar.clipped_right = True

        # Sample switcher first, then the display/camera widgets to its right.
        toolbar_switch_model(server=server, plotter=plotter)
        toolbar_widgets(
            server=server,
            plotter=plotter,
            mode=mode,
            default_server_rendering=default_server_rendering,
        )
237 |
--------------------------------------------------------------------------------
/stviewer/Explorer/ui/utils.py:
--------------------------------------------------------------------------------
1 | from trame.widgets import html, vuetify
2 |
3 | # -----------------------------------------------------------------------------
4 | # vuetify components
5 | # -----------------------------------------------------------------------------
6 |
7 |
def button(click, icon, tooltip):
    """Create an icon-only vuetify button with a tooltip shown below it.

    Args:
        click: Handler bound to the button's ``click`` event.
        icon: Icon name rendered inside the button.
        tooltip: Text displayed in the tooltip.
    """
    tooltip_box = vuetify.VTooltip(bottom=True)
    with tooltip_box:
        activator = vuetify.Template(v_slot_activator="{ on, attrs }")
        with activator:
            # Bind the activator's ``on``/``attrs`` so hovering opens the tooltip.
            icon_button = vuetify.VBtn(icon=True, v_bind="attrs", v_on="on", click=click)
            with icon_button:
                vuetify.VIcon(icon)
        html.Span(tooltip)
15 |
16 |
def checkbox(model, icons, tooltip, **kwargs):
    """Create an icon-style vuetify checkbox with a tooltip shown below it.

    Args:
        model: ``v_model`` binding of the checkbox.
        icons: Sequence whose first item is the checked icon and second item the
            unchecked icon.
        tooltip: Text displayed in the tooltip.
        **kwargs: Extra properties forwarded to ``vuetify.VCheckbox``.
    """
    with vuetify.VTooltip(bottom=True):
        with vuetify.Template(v_slot_activator="{ on, attrs }"):
            # A wrapper element carries the activator bindings so that the
            # forwarded kwargs never collide with ``v_on``/``v_bind``.
            with html.Div(v_on="on", v_bind="attrs"):
                checked_icon = icons[0]
                unchecked_icon = icons[1]
                vuetify.VCheckbox(
                    v_model=model,
                    on_icon=checked_icon,
                    off_icon=unchecked_icon,
                    dense=True,
                    hide_details=True,
                    classes="my-0 py-0 ml-1",
                    **kwargs
                )
        html.Span(tooltip)
32 |
33 |
def switch(model, tooltip, **kwargs):
    """Create a vuetify switch with a tooltip shown below it.

    Args:
        model: ``v_model`` binding of the switch.
        tooltip: Text displayed in the tooltip.
        **kwargs: Extra properties forwarded to ``vuetify.VSwitch``.
    """
    with vuetify.VTooltip(bottom=True):
        with vuetify.Template(v_slot_activator="{ on, attrs }"):
            # The activator slot's ``on``/``attrs`` must be bound to an element or
            # the tooltip never opens. They are bound on a wrapper Div (as in
            # ``checkbox``) so caller kwargs cannot collide with ``v_on``/``v_bind``.
            with html.Div(v_on="on", v_bind="attrs"):
                vuetify.VSwitch(v_model=model, hide_details=True, dense=True, **kwargs)
        html.Span(tooltip)
40 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/__init__.py:
--------------------------------------------------------------------------------
1 | from .pv_pipeline import *
2 | from .ui import ui_container, ui_drawer, ui_layout, ui_toolbar
3 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/pv_pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | from .init_parameters import *
2 | from .pv_callback import Viewer
3 | from .pv_models import init_models
4 | from .pv_plotter import add_single_model, create_plotter
5 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/pv_pipeline/init_parameters.py:
--------------------------------------------------------------------------------
1 | import pyautogui
2 |
3 | # Init parameters
# Default state values for the Reconstructor UI, grouped by panel.
init_active_parameters = dict(
    picking_group=None,
    overwrite=False,
    activeModelVisible=True,
    activeModel_output=None,
    anndata_output=None,
)
init_align_parameters = dict(
    slices_alignment=False,
    slices_key="slices",
    slices_align_device="CPU",
    slices_align_method="Paste",
    slices_align_factor=0.1,
    slices_align_max_iter=200,
)
init_mesh_parameters = dict(
    meshModel=None,
    meshModelVisible=False,
    reconstruct_mesh=False,
    mc_factor=1.0,
    mesh_voronoi=20000,
    mesh_smooth_factor=2000,
    mesh_scale_factor=1.0,
    clip_pc_with_mesh=False,
    mesh_output=None,
)
init_picking_parameters = dict(
    modes=[
        {"value": "hover", "icon": "mdi-magnify"},
        {"value": "click", "icon": "mdi-cursor-default-click-outline"},
        {"value": "select", "icon": "mdi-select-drag"},
    ],
    pickData=None,
    selectData=None,
    resetModel=False,
    tooltip="",
)
# Sizes derived from the screen width reported by pyautogui at import time.
init_setting_parameters = dict(
    show_active_card=True,
    show_align_card=True,
    show_mesh_card=True,
    background_color="[0, 0, 0]",
    pixel_ratio=pyautogui.size()[0] / 500,
)

# Custom init parameters
init_custom_parameters = dict(
    custom_func=False,
    custom_analysis=False,
    custom_model=None,
    custom_model_visible=False,
    custom_model_size=pyautogui.size()[0] / 200,
    custom_model_output=None,
    show_custom_card=True,
    custom_parameter1="ElPiGraph",
    custom_parameter2=50,
)
61 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/pv_pipeline/pv_custom.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | try:
4 | from typing import Literal
5 | except ImportError:
6 | from typing_extensions import Literal
7 |
8 | from typing import Optional, Tuple, Union
9 |
10 | import numpy as np
11 | from pyvista import PolyData, UnstructuredGrid
12 |
13 | try:
14 | from typing import Literal
15 | except ImportError:
16 | from typing_extensions import Literal
17 |
18 | #####################################################################
19 | # Principal curves algorithm #
20 | # ================================================================= #
21 | # Original Code Repository Author: Matthew Artuso. #
22 | # Adapted to Spateo by: spateo authors #
23 | # Created Date: 6/11/2022 #
24 | # Description: A principal curve is a smooth n-dimensional curve #
25 | # that passes through the middle of a dataset. #
26 | # Reference: https://doi.org/10.1016/j.cam.2015.11.041 #
27 | # ================================================================= #
28 | #####################################################################
29 |
30 |
class NLPCA(object):
    """Global principal-curve solver built on a nonlinear PCA (autoencoder) network.

    The network (see ``create_model``) maps the input through a sigmoid "mapping" layer into a
    single sigmoid bottleneck unit, then through a sigmoid "demapping" layer to a linear output
    with the original number of dimensions. The bottleneck activation is used as the projection
    index of each point along the learned curve.

    Attributes:
        fit_points: Network reconstructions from the most recent ``project`` call, or ``None``
            before ``project`` has been called.
        model: The trained Keras autoencoder, or ``None`` before ``fit``.
        intermediate_layer_model: Keras model that outputs the bottleneck activation, or
            ``None`` before ``fit``.
    """

    def __init__(self):
        self.fit_points = None
        self.model = None
        self.intermediate_layer_model = None

    def fit(
        self,
        data: np.ndarray,
        epochs: int = 500,
        nodes: int = 25,
        lr: float = 0.01,
        verbose: int = 0,
    ):
        """
        Create the autoencoder and train it to reproduce the given m x n data.

        Args:
            data: A numpy array of shape (m, n), where m is the number of points and n is the number of dimensions.
            epochs: Number of epochs to train the neural network. Defaults to 500.
            nodes: Number of nodes for the mapping/demapping layers. Defaults to 25. The more complex the curve,
                the higher this number should be.
            lr: Learning rate of the Adam optimizer. Defaults to 0.01.
            verbose: Keras verbosity; 0 mutes the training output. Defaults to 0.

        Raises:
            ImportError: If Keras/TensorFlow is not installed.
        """
        try:
            from keras.models import Model
        except ImportError as err:
            raise ImportError(
                "You need to install the package `tensorflow`."
                "\nInstall tensorflow via `pip install -U tensorflow`."
            ) from err

        num_dim = data.shape[1]  # number of dimensions of the points

        model = self.create_model(num_dim, nodes=nodes, lr=lr)
        # layers[0] is the Input layer, layers[1] the mapping layer, layers[2] the bottleneck.
        bottleneck_name = model.layers[2].name

        # The intermediate model reuses the layers of ``model`` and exposes the bottleneck
        # output, which acts as the projection layer.
        self.intermediate_layer_model = Model(
            inputs=model.input, outputs=model.get_layer(bottleneck_name).output
        )

        # Autoencoder training: the target is the input itself.
        model.fit(data, data, epochs=epochs, verbose=verbose)
        self.model = model

    def project(self, data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Project points onto the curve learned by ``fit``.

        Args:
            data: m x n array of points to project onto the curve.

        Returns:
            proj: The projection index (bottleneck output) of each point, one row per point.
            all_sorted: An m x (n + 1) array of the network's reconstructed curve points with the
                projection index appended as the last column, sorted by that index.
        """
        num_dim = data.shape[1]

        pts = self.model.predict(data)
        proj = self.intermediate_layer_model.predict(data)
        self.fit_points = pts

        combined = np.concatenate([pts, proj], axis=1)
        all_sorted = combined[combined[:, num_dim].argsort()]

        return proj, all_sorted

    def create_model(self, num_dim: int, nodes: int, lr: float):
        """
        Create and compile the NLPCA autoencoder.

        Args:
            num_dim: Number of dimensions of the input space.
            nodes: Number of nodes of the mapping and demapping layers.
            lr: Learning rate of the Adam optimizer.

        Returns:
            model (object): Compiled Keras Model.

        Raises:
            ImportError: If Keras/TensorFlow is not installed.
        """
        try:
            import tensorflow as tf
            from keras import optimizers
            from keras.layers import Dense, Input
            from keras.models import Model
        except ImportError as err:
            raise ImportError(
                "You need to install the package `tensorflow`."
                "\nInstall tensorflow via `pip install -U tensorflow`."
            ) from err

        def orth_dist(y_true, y_pred):
            """
            Loss function for the NLPCA NN: the sum of squared differences between the
            output tensor and the target tensor.
            """
            return tf.math.reduce_sum((y_true - y_pred) ** 2)

        # Function G: input -> mapping -> 1-unit bottleneck.
        inputs = Input(shape=(num_dim,))
        mapping = Dense(nodes, activation="sigmoid")(inputs)
        bottle = Dense(1, activation="sigmoid")(mapping)

        # Function H: bottleneck -> demapping -> linear output.
        demapping = Dense(nodes, activation="sigmoid")(bottle)
        output = Dense(num_dim)(demapping)

        model = Model(inputs=inputs, outputs=output)
        model.compile(loss=orth_dist, optimizer=optimizers.Adam(learning_rate=lr))

        return model
155 |
156 |
157 | ###########
158 | # Methods #
159 | ###########
160 |
161 |
162 | def _euclidean_distance(N1, N2):
163 | temp = np.asarray(N1) - np.asarray(N2)
164 | euclid_dist = np.sqrt(np.dot(temp.T, temp))
165 | return euclid_dist
166 |
167 |
def sort_nodes_of_curve(nodes, started_node):
    """
    Order curve nodes by a greedy nearest-neighbour walk.

    The walk starts at ``started_node``. At each step it moves to the closest node not yet
    visited, with ties resolved in favour of the earliest such node in ``nodes``. Every node in
    ``nodes`` is visited exactly once.

    Args:
        nodes: Iterable of node coordinates.
        started_node: Coordinates where the walk begins.

    Returns:
        An array of the nodes in visiting order.
    """
    pending = list(map(tuple, nodes))
    cursor = tuple(started_node)
    ordered = []
    while len(pending) > 0:
        nearest = min(
            pending, key=lambda candidate: _euclidean_distance(cursor, candidate)
        )
        pending.remove(nearest)
        ordered.append(nearest)
        cursor = nearest
    return np.asarray([list(node) for node in ordered])
182 |
183 |
def ElPiGraph_method(
    X: np.ndarray,
    NumNodes: int = 50,
    topology: Literal["tree", "circle", "curve"] = "curve",
    Lambda: float = 0.01,
    Mu: float = 0.1,
    alpha: float = 0.0,
    FinalEnergy: Literal["Base", "Penalized"] = "Penalized",
    **kwargs,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate a principal elastic graph (tree, circle or curve) with ElPiGraph.
    Reference: Albergante et al. (2020), Robust and Scalable Learning of Complex Intrinsic Dataset Geometry via ElPiGraph.

    Args:
        X: Data matrix of shape (n_points, n_dims); each row is one point.
        NumNodes: The number of nodes of the principal graph. Use a range of 10 to 100 for ElPiGraph approach.
        topology: The topology of the principal graph: ``'tree'``, ``'circle'`` or ``'curve'`` (case-insensitive).
        Lambda: The attractive strength of edges between nodes (constrains edge lengths).
        Mu: The repulsive strength of a node's neighboring nodes (constrains angles to be close to harmonic).
        alpha: Branching penalty (penalizes number of branches for the principal tree).
        FinalEnergy: The final elastic energy associated with the configuration. Currently it can be "Base" or "Penalized".
        **kwargs: Other parameters passed to the elpigraph ``computeElasticPrincipal*`` function. For details, please see:
                  https://elpigraph-python.readthedocs.io/en/latest/basics.html

    Returns:
        nodes: The nodes of the principal graph. For ``'curve'`` and ``'circle'`` they are re-ordered along the curve.
        edges: The edges between nodes. For ``'curve'`` they chain consecutive nodes; for ``'circle'`` the last
            node is also connected back to the first one.

    Raises:
        ValueError: If ``topology`` is not one of the supported values.
        ImportError: If ``elpigraph-python`` is not installed.
    """
    # Validate arguments before importing the optional dependency so bad input fails fast.
    topology = str(topology).lower()
    solver_names = {
        "tree": "computeElasticPrincipalTree",
        "circle": "computeElasticPrincipalCircle",
        "curve": "computeElasticPrincipalCurve",
    }
    if topology not in solver_names:
        raise ValueError(
            "`topology` value is wrong."
            "\nAvailable `topology` are: `'tree'`, `'circle'`, `'curve'`."
        )

    try:
        import elpigraph
    except ImportError as err:
        raise ImportError(
            "You need to install the package `elpigraph-python`."
            "\nInstall elpigraph-python via `pip install elpigraph-python`."
        ) from err

    ElPiGraph_kwargs = {
        "NumNodes": NumNodes,
        "Lambda": Lambda,
        "Mu": Mu,
        "alpha": alpha,
        "FinalEnergy": FinalEnergy,
    }
    ElPiGraph_kwargs.update(kwargs)

    solver = getattr(elpigraph, solver_names[topology])
    elpi_tree = solver(X=np.asarray(X), **ElPiGraph_kwargs)

    nodes = elpi_tree[0]["NodePositions"]
    edges = np.asarray(elpi_tree[0]["Edges"][0])

    if topology in ("curve", "circle"):
        # Start the walk at a node that appears in exactly one edge (an endpoint); if there is
        # none (e.g. a closed circle), start from the first node.
        unique_values, occurrence_count = np.unique(edges.flatten(), return_counts=True)
        endpoints = unique_values[occurrence_count == 1]
        started_node = nodes[endpoints[0]] if len(endpoints) != 0 else nodes[0]

        nodes = sort_nodes_of_curve(nodes, started_node)
        n_nodes = len(nodes)
        if topology == "curve":
            # (0, 1), (1, 2), ..., (n - 2, n - 1)
            edges = np.c_[np.arange(0, n_nodes - 1), np.arange(1, n_nodes)]
        else:
            # (0, 1), ..., (n - 2, n - 1), (n - 1, 0)
            edges = np.c_[np.arange(0, n_nodes), np.roll(np.arange(0, n_nodes), -1)]

    return nodes, edges
275 |
276 |
def SimplePPT_method(
    X: np.ndarray,
    NumNodes: int = 50,
    sigma: Optional[Union[float, int]] = 0.1,
    lam: Optional[Union[float, int]] = 1,
    metric: str = "euclidean",
    nsteps: int = 50,
    err_cut: float = 5e-3,
    seed: Optional[int] = 1,
    **kwargs,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate a simple principal tree with SimplePPT.
    Reference: Mao et al. (2015), SimplePPT: A simple principal tree algorithm, SIAM International Conference on Data Mining.

    Args:
        X: Data matrix of shape (n_points, n_dims); each row is one point.
        NumNodes: The number of nodes of the principal graph. Use a range of 100 to 2000 for PPT approach.
        sigma: Regularization parameter.
        lam: Penalty for the tree length.
        metric: The metric used to compute distances in high dimensional space. For compatible metrics, check the
                documentation of sklearn.metrics.pairwise_distances.
        nsteps: Number of steps for the optimisation process.
        err_cut: Stop the algorithm once the change of the principal points between iterations is below this value.
        seed: A numpy random seed.
        **kwargs: Other parameters passed to ``simpleppt.ppt``. For details, please see:
                  https://github.com/LouisFaure/simpleppt/blob/main/simpleppt/ppt.py

    Returns:
        nodes: The node positions, each the average of the data points weighted by the tree's soft
            assignment matrix ``R``.
        edges: Node-index pairs read from the tree's adjacency matrix ``B``.
    """
    try:
        import igraph
        import simpleppt
    except ImportError:
        raise ImportError(
            "You need to install the package `simpleppt` and `igraph`."
            "\nInstall simpleppt via `pip install -U simpleppt`."
            "\nInstall igraph via `pip install -U igraph`"
        )

    options = dict(
        seed=seed, lam=lam, sigma=sigma, metric=metric, nsteps=nsteps, err_cut=err_cut
    )
    options.update(kwargs)

    points = np.asarray(X)
    tree = simpleppt.ppt(X=points, Nodes=NumNodes, **options)

    # Weighted mean of the points per node; the weights are the columns of R.
    weights = tree.R
    nodes = (np.dot(points.T, weights) / weights.sum(axis=0)).T

    adjacency = (tree.B > 0).tolist()
    graph = igraph.Graph.Adjacency(adjacency, mode="undirected")
    edges = np.array(graph.get_edgelist())

    return nodes, edges
341 |
342 |
def PrinCurve_method(
    X: np.ndarray,
    NumNodes: int = 50,
    epochs: int = 500,
    lr: float = 0.01,
    scale_factor: Union[int, float] = 1,
    **kwargs,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit a global principal curve with nonlinear PCA (NLPCA) and down-sample it to a chain of nodes.
    Reference: Chen et al. (2016), Constraint local principal curve: Concept, algorithms and applications.

    Args:
        X: Data matrix of shape (n_points, n_dims); each row is one point.
        NumNodes: Number of nodes of the NLPCA construction layers, and also the number of curve points
                  requested from the down-sampling step. Defaults to 50. The more complex the curve, the higher
                  this number should be.
        epochs: Number of epochs to train the neural network. Defaults to 500.
        lr: Learning rate for backprop. Defaults to 0.01.
        scale_factor: The coordinates are divided by this value (and shifted so each axis starts at 0) before
                      training; the fitted curve is mapped back to the original coordinates afterwards.
        **kwargs: Other keyword arguments passed to ``NLPCA.fit`` (e.g. ``verbose``).

    Returns:
        nodes: The curve nodes, shape (n_nodes, n_dims), ordered along the curve.
        edges: The (n_nodes - 1, 2) edges connecting consecutive nodes.

    Raises:
        ImportError: If ``dynamo-release`` is not installed.
    """
    try:
        from dynamo.tools.sampling import sample
    except ImportError as err:
        raise ImportError(
            "You need to install the package `dynamo`."
            "\nInstall dynamo via `pip install -U dynamo-release`."
        ) from err

    PrinCurve_kwargs = {
        "epochs": epochs,
        "lr": lr,
        "verbose": 0,
    }
    PrinCurve_kwargs.update(kwargs)

    raw_X = np.asarray(X)
    dims = raw_X.shape[1]

    # Rescale, then shift every axis so its minimum is 0 before training the network.
    new_X = raw_X / scale_factor
    trans = new_X.min(axis=0)
    new_X = new_X - trans

    pca_project = NLPCA()
    pca_project.fit(new_X, nodes=NumNodes, **PrinCurve_kwargs)
    # curve_pts: reconstructed curve points with the projection index as the last column.
    _, curve_pts = pca_project.project(new_X)
    curve_pts = np.unique(curve_pts, axis=0)
    curve_pts = curve_pts[curve_pts[:, -1].argsort(), :]

    # Keep only the coordinate columns (drop the projection index) and undo shift and scaling.
    nodes = (curve_pts[:, :dims] + trans) * scale_factor

    sampling = sample(
        arr=np.arange(nodes.shape[0]), n=NumNodes, method="trn", X=nodes
    )
    sampling.sort()  # keep the sampled nodes in curve order
    nodes = nodes[sampling, :]

    # Chain consecutive nodes: (0, 1), (1, 2), ..., (n - 2, n - 1) -- no self-loop at the end.
    n_nodes = nodes.shape[0]
    edges = np.c_[np.arange(0, n_nodes - 1), np.arange(1, n_nodes)]
    return nodes, edges
422 |
423 |
def construct_backbone(
    model: Union[PolyData, UnstructuredGrid],
    spatial_key: Optional[str] = None,
    nodes_key: str = "nodes",
    rd_method: Literal["ElPiGraph", "SimplePPT", "PrinCurve"] = "ElPiGraph",
    num_nodes: int = 50,
    **kwargs,
) -> PolyData:
    """
    Build an organ backbone (a polyline through principal-graph nodes) from a 3D point cloud model.

    Args:
        model: A point cloud model.
        spatial_key: If None, the coordinates are taken from ``model.points``; otherwise from ``model[spatial_key]``.
        nodes_key: Point-data key under which each backbone node's index is stored.
        rd_method: The method used to fit the backbone nodes. Available ``rd_method`` are:

                * ``'ElPiGraph'``: Generate a principal elastic tree.
                * ``'SimplePPT'``: Generate a simple principal tree.
                * ``'PrinCurve'``: Fit a global principal curve with nonlinear principal component analysis.
        num_nodes: Number of nodes for the backbone model.
        **kwargs: Additional parameters passed to ``ElPiGraph_method``, ``SimplePPT_method`` or ``PrinCurve_method``.

    Returns:
        backbone_model: A ``pyvista.MultipleLines`` model through the fitted nodes, with node indices
            ``0..n-1`` stored in ``point_data[nodes_key]``.

    Raises:
        ValueError: If ``rd_method`` is not one of the available methods.
    """
    import pyvista as pv

    model = model.copy()
    if spatial_key is None:
        X = model.points
    else:
        X = model[spatial_key]

    backbone_builders = (
        ("ElPiGraph", ElPiGraph_method),
        ("SimplePPT", SimplePPT_method),
        ("PrinCurve", PrinCurve_method),
    )
    for method_name, builder in backbone_builders:
        if rd_method == method_name:
            nodes, edges = builder(X=X, NumNodes=num_nodes, **kwargs)
            break
    else:
        raise ValueError(
            "`rd_method` value is wrong."
            "\nAvailable `rd_method` are: `'ElPiGraph'`, `'SimplePPT'`, `'PrinCurve'`."
        )

    # Only the node order matters here; ``edges`` is not used to build the polyline.
    backbone_model = pv.MultipleLines(points=nodes)
    backbone_model.point_data[nodes_key] = np.arange(0, len(nodes), 1)
    return backbone_model
476 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/pv_pipeline/pv_models.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.filterwarnings("ignore")
4 |
5 | import anndata as ad
6 | import numpy as np
7 |
8 | from .pv_plotter import add_single_model
9 |
10 |
def _summarize_data_arrays(data, excluded_names):
    """
    Label-encode non-numeric 1-D arrays of ``data`` in place and build a metadata entry for each.

    Args:
        data: A mutable name -> array mapping (e.g. ``model.point_data`` or ``model.cell_data``).
        excluded_names: Names of 1-D arrays to skip, in addition to ``"obs_index"``.

    Returns:
        A dict keyed by array name. Each value holds ``name``, ``range`` ([min, max] of the stored
        array), ``value``, ``text``, ``scalarMode`` and ``raw_labels``. ``raw_labels`` maps each
        original label to its code, or is ``{"None": "None"}`` for numeric arrays.
    """
    summary = {}
    # Snapshot the items because non-numeric arrays are overwritten while walking them.
    for name, array in list(data.items()):
        if name == "obs_index":
            continue
        array = np.asarray(array)
        if array.ndim != 1 or name in excluded_names:
            continue

        raw_labels = {"None": "None"}  # sentinel: numeric array, no label encoding
        if not np.issubdtype(array.dtype, np.number):
            # Replace every label by its index among the sorted unique labels.
            labels, codes = np.unique(array, return_inverse=True)
            raw_labels = {label: code for code, label in enumerate(labels.tolist())}
            data[name] = np.asarray(codes, dtype=float).reshape(-1)

        values = np.asarray(data[name])
        summary[name] = {
            "name": name,
            "range": [values.min(), values.max()],
            "value": name,
            "text": name,
            "scalarMode": 3,
            "raw_labels": raw_labels,
        }
    return summary


def check_model_data(model, point_data: bool = True, cell_data: bool = True):
    """
    Label-encode the non-numeric scalar arrays of ``model`` and collect per-array metadata.

    Only 1-D arrays are considered; ``obs_index`` and VTK bookkeeping arrays are skipped.
    Non-numeric arrays are replaced (in ``model``) by float codes.

    Args:
        model: A model exposing ``point_data`` and ``cell_data`` mappings (e.g. a pyvista dataset).
        point_data: Whether to process ``model.point_data``.
        cell_data: Whether to process ``model.cell_data``.

    Returns:
        model: The same model object, with non-numeric 1-D arrays replaced by float codes.
        pdd: Metadata for the processed point arrays, keyed by array name (empty if skipped).
        cdd: Metadata for the processed cell arrays, keyed by array name (empty if skipped).
    """
    pdd = (
        _summarize_data_arrays(
            model.point_data,
            ("vtkOriginalPointIds", "SelectedPoints", "vtkInsidedness"),
        )
        if point_data
        else {}
    )
    cdd = (
        _summarize_data_arrays(
            model.cell_data,
            ("vtkOriginalCellIds", "orig_extract_id", "vtkInsidedness"),
        )
        if cell_data
        else {}
    )
    return model, pdd, cdd
67 |
68 |
def init_models(plotter, anndata_path):
    """
    Load an AnnData file, build its point-cloud model and register it with the plotter.

    The ``spatial`` coordinates are padded with a z-column of ones when they are 2D, and
    centered on the origin. Categorical ``obs`` columns are converted to strings and numeric
    ones to float. Every ``obs`` column is then copied onto the model's point data.

    Args:
        plotter: The plotting object the models are added to.
        anndata_path: Path of the ``.h5ad`` file to read.

    Returns:
        main_model: The point-cloud model, added to the plotter as ``"mainModel"``.
        active_model: A copy of ``main_model``, added to the plotter as ``"activeModel"``.
        init_scalar: Name of the initial scalar array (``"Default"``).
        pdd: Point-data metadata returned by ``check_model_data``.
        cdd: Cell-data metadata returned by ``check_model_data``.
    """
    init_adata = ad.read_h5ad(anndata_path)
    n_obs = init_adata.shape[0]
    init_adata.obs["Default"] = np.ones(shape=(n_obs, 1))

    spatial = init_adata.obsm["spatial"]
    if spatial.shape[1] == 2:
        spatial = np.c_[spatial, np.ones(shape=(n_obs, 1))]
    init_adata.obsm["spatial"] = spatial

    offset = init_adata.obsm["spatial"].mean(axis=0)
    if tuple(offset) != (0, 0, 0):
        init_adata.obsm["spatial"] = init_adata.obsm["spatial"] - offset

    for key in init_adata.obs_keys():
        if init_adata.obs[key].dtype == "category":
            init_adata.obs[key] = np.asarray(init_adata.obs[key], dtype=str)
        column = init_adata.obs[key]
        if np.issubdtype(column.dtype, np.number):
            init_adata.obs[key] = np.asarray(column, dtype=float)

    from .pv_tdr import construct_pc

    main_model = construct_pc(adata=init_adata, spatial_key="spatial")
    obs_index = main_model.point_data["obs_index"]
    for key in init_adata.obs_keys():
        main_model.point_data[key] = init_adata.obs.loc[obs_index, key]

    main_model, pdd, cdd = check_model_data(
        model=main_model, point_data=True, cell_data=True
    )
    add_single_model(plotter=plotter, model=main_model, model_name="mainModel")

    active_model = main_model.copy()
    add_single_model(plotter=plotter, model=active_model, model_name="activeModel")

    return main_model, active_model, "Default", pdd, cdd
111 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/pv_pipeline/pv_plotter.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import pyvista as pv
4 | from pyvista import Plotter, PolyData, UnstructuredGrid
5 |
6 | try:
7 | from typing import Literal
8 | except ImportError:
9 | from typing_extensions import Literal
10 |
11 |
def create_plotter(
    window_size: tuple = (1024, 1024), background: str = "black", **kwargs
) -> Plotter:
    """
    Create an off-screen plotting object to display pyvista/vtk models.

    Args:
        window_size: Window size in pixels. Defaults to ``(1024, 1024)``.
        background: The background color of the render window.
        **kwargs: Additional keyword arguments passed to ``pyvista.Plotter``.

    Returns:
        plotter: The plotting object, created with ``off_screen=True`` and the ``"light_kit"`` lighting.
    """
    plotter = pv.Plotter(
        window_size=window_size, off_screen=True, lighting="light_kit", **kwargs
    )
    plotter.background_color = background
    return plotter
34 |
35 |
def add_single_model(
    plotter: Plotter,
    model: Union[PolyData, UnstructuredGrid],
    key: Optional[str] = None,
    cmap: Optional[str] = "rainbow",
    color: Optional[str] = "gainsboro",
    ambient: float = 0.2,
    opacity: float = 1.0,
    model_style: Literal["points", "surface", "wireframe"] = "surface",
    model_size: float = 3.0,
    model_name: Optional[str] = None,
):
    """
    Add a model to the plotter.

    Args:
        plotter: The plotting object to display pyvista/vtk model.
        model: A reconstructed model.
        key: Name of the array used as scalars. If it is not in ``model.array_names``, no scalars are passed.
        cmap: Name of the Matplotlib colormap to use when mapping the model.
        color: Name of the Matplotlib color to use when mapping the model.
        ambient: When lighting is enabled, the amount of light in the range 0 to 1 (default 0.2) that reaches
                 the actor when not directed at the light source emitted from the viewer.
        opacity: Opacity of the model, passed unchanged to ``Plotter.add_mesh``. A single float between 0 and 1
                 is applied uniformly; pyvista also accepts a per-point array or an opacity transfer-function
                 name (e.g. 'linear', 'linear_r', 'geom', 'geom_r').
        model_style: Visualization style of the model. One of the following:
                * ``model_style = 'surface'``,
                * ``model_style = 'wireframe'``,
                * ``model_style = 'points'``.
        model_size: If ``model_style = 'points'``, point size of any nodes in the dataset plotted.
                    If ``model_style = 'wireframe'``, thickness of lines.
        model_name: Name to assign to the model. Defaults to the memory address.

    Returns:
        The actor returned by ``plotter.add_mesh``.
    """
    # (render_points_as_spheres, render_lines_as_tubes, smooth_shading) per style;
    # every other style (e.g. "surface") uses the fallback.
    style_flags = {
        "points": (True, False, True),
        "wireframe": (False, True, False),
    }
    render_spheres, render_tubes, smooth_shading = style_flags.get(
        model_style, (False, False, True)
    )

    mesh_kwargs = dict(
        scalars=key if key in model.array_names else None,
        style=model_style,
        render_points_as_spheres=render_spheres,
        render_lines_as_tubes=render_tubes,
        point_size=model_size,
        line_width=model_size,
        ambient=ambient,
        opacity=opacity,
        smooth_shading=smooth_shading,
        show_scalar_bar=False,
        cmap=cmap,
        color=color,
        name=model_name,
    )
    return plotter.add_mesh(model, **mesh_kwargs)
96 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/pv_pipeline/pv_tdr.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | import numpy as np
4 | import pyvista as pv
5 | from anndata import AnnData
6 | from pandas.core.frame import DataFrame
7 | from pyvista import DataSet, PolyData, UnstructuredGrid
8 | from scipy.spatial.distance import cdist
9 |
10 | try:
11 | from typing import Literal
12 | except ImportError:
13 | from typing_extensions import Literal
14 |
15 |
16 | ####################################
17 | # point cloud model reconstruction #
18 | ####################################
19 |
20 |
def construct_pc(
    adata: AnnData,
    spatial_key: str = "spatial",
) -> PolyData:
    """
    Construct a point cloud model from the spatial coordinates stored in an AnnData object.

    Args:
        adata: AnnData object. It is not modified.
        spatial_key: The key in ``.obsm`` that corresponds to the spatial coordinate of each bucket.

    Returns:
        pc: A point cloud whose points are ``adata.obsm[spatial_key]`` cast to float64, and whose
            ``pc.point_data['obs_index']`` holds the obs name of each point in the original adata.
    """
    # ``astype`` returns a new array, so the point cloud never aliases adata's storage and
    # copying the whole AnnData object first is unnecessary.
    bucket_xyz = adata.obsm[spatial_key].astype(np.float64)
    if isinstance(bucket_xyz, DataFrame):
        bucket_xyz = bucket_xyz.values
    pc = pv.PolyData(bucket_xyz)
    pc.point_data["obs_index"] = np.array(adata.obs_names.tolist())
    return pc
48 |
49 |
50 | #######################
51 | # Mesh reconstruction #
52 | #######################
53 |
54 |
def merge_models(
    models: List[Union[PolyData, UnstructuredGrid, DataSet]],
) -> Union[PolyData, UnstructuredGrid]:
    """
    Merge all models in the ``models`` list. The format of all models must be the same.

    Args:
        models: A non-empty list of models, merged in list order with ``.merge``.

    Returns:
        The merged model; the first model itself when the list has a single element.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if len(models) == 0:
        raise ValueError("`models` must contain at least one model.")

    merged_model = models[0]
    for model in models[1:]:
        merged_model = merged_model.merge(model)

    return merged_model
65 |
66 |
def rigid_transform(
    coords: np.ndarray,
    coords_refA: np.ndarray,
    coords_refB: np.ndarray,
) -> np.ndarray:
    """
    Compute the optimal rigid transformation (rotation + translation) mapping ``coords_refA`` onto
    ``coords_refB`` via SVD, and apply it to ``coords``.

    2D inputs are padded with a zero z-column for the computation and returned as 2D.

    Args:
        coords: (n, d) coordinate matrix to be transformed.
        coords_refA: (m, d) referential coordinate matrix before transformation.
        coords_refB: (m, d) referential coordinate matrix after transformation, row-aligned with ``coords_refA``.

    Returns:
        The (n, d) coordinate matrix after transformation.

    Raises:
        ValueError: If the three inputs do not have the same number of columns, or the two
            reference sets do not have the same number of points.
    """
    coords = np.asarray(coords)
    coords_refA = np.asarray(coords_refA)
    coords_refB = np.asarray(coords_refB)

    # Explicit checks (not ``assert``) so validation survives ``python -O``.
    if not (coords.shape[1] == coords_refA.shape[1] == coords_refB.shape[1]):
        raise ValueError(
            "The dimensions of the input coordinates must be uniform, 2D or 3D."
        )
    if coords_refA.shape[0] != coords_refB.shape[0]:
        raise ValueError(
            "`coords_refA` and `coords_refB` must contain the same number of points."
        )

    coords_dim = coords.shape[1]
    if coords_dim == 2:
        coords = np.c_[coords, np.zeros(shape=(coords.shape[0], 1))]
        coords_refA = np.c_[coords_refA, np.zeros(shape=(coords_refA.shape[0], 1))]
        coords_refB = np.c_[coords_refB, np.zeros(shape=(coords_refB.shape[0], 1))]

    # Work with points as columns: shape (d, m).
    coords_refA = coords_refA.T
    coords_refB = coords_refB.T

    centroid_A = np.mean(coords_refA, axis=1).reshape(-1, 1)
    centroid_B = np.mean(coords_refB, axis=1).reshape(-1, 1)

    # Cross-covariance of the centred reference sets.
    H = (coords_refA - centroid_A) @ (coords_refB - centroid_B).T

    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # A negative determinant means R is a reflection: flip the singular vector belonging to the
    # smallest singular value (the last row of Vt, as svd sorts them descending).
    if np.linalg.det(R) < 0:
        Vt[-1, :] *= -1
        R = Vt.T @ U.T

    t = centroid_B - R @ centroid_A

    new_coords = np.asarray((R @ coords.T + t).T)
    return new_coords[:, :2] if coords_dim == 2 else new_coords
123 |
124 |
def _scale_model_by_scale_factor(
    model: DataSet,
    scale_factor: Union[int, float, list, tuple] = 1,
    scale_center: Union[list, tuple] = None,
) -> DataSet:
    """
    Scale ``model.points`` in place about ``scale_center``, one axis at a time, and return ``model``.

    Args:
        model: The model whose points are modified.
        scale_factor: A single number (applied to x, y and z) or a list/tuple of three per-axis factors.
        scale_center: Three center coordinates; ``model.center`` when None.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: If ``scale_factor`` (as a list/tuple) or ``scale_center`` does not have three elements.
    """
    if not isinstance(scale_factor, (tuple, list)):
        scale_factor = [scale_factor] * 3
    if len(scale_factor) != 3:
        raise ValueError(
            "`scale_factor` value is wrong."
            "\nWhen `scale_factor` is a list or tuple, it can only contain three elements."
        )

    if scale_center is None:
        scale_center = model.center
    if len(scale_center) != 3:
        raise ValueError(
            "`scale_center` value is wrong."
            "\n`scale_center` can only contain three elements."
        )

    # p' = (p - c) * f + c, applied per axis.
    for axis in range(3):
        factor, center = scale_factor[axis], scale_center[axis]
        model.points[:, axis] = (model.points[:, axis] - center) * factor + center

    return model
153 |
154 |
def scale_model(
    model: Union[PolyData, UnstructuredGrid],
    scale_factor: Union[float, int, list, tuple] = 1,
    scale_center: Union[list, tuple] = None,
    inplace: bool = False,
) -> Union[PolyData, UnstructuredGrid, None]:
    """
    Scale the model around a center point and triangulate it.

    Args:
        model: A 3D reconstructed model.
        scale_factor: The scale by which the model is scaled. A single number scales the x, y and z axes
                      equally; a list/tuple of three numbers scales each axis separately.
        scale_center: Scaling center. If None, the center of the model (``model.center``) is used.
        inplace: If True, scale and triangulate ``model`` itself and return None.

    Returns:
        model_s: The scaled and triangulated model, or None when ``inplace`` is True.

    Raises:
        ValueError: If ``scale_factor`` (as a list/tuple) or ``scale_center`` does not have three elements.
    """
    model_s = model if inplace else model.copy()
    model_s = _scale_model_by_scale_factor(
        model=model_s, scale_factor=scale_factor, scale_center=scale_center
    )
    if inplace:
        # ``triangulate`` returns a new object by default; apply it to the model itself so the
        # in-place result matches what the non-inplace path returns.
        model_s.triangulate(inplace=True)
        return None
    return model_s.triangulate()
182 |
183 |
def marching_cube_mesh(
    pc: PolyData,
    levelset: Union[int, float] = 0,
    mc_scale_factor: Union[int, float] = 1.0,
):
    """
    Computes a triangle mesh from a point cloud based on the marching cube algorithm.
    Algorithm Overview:
        The algorithm proceeds through the scalar field, taking eight neighbor locations at a time (thus forming an
        imaginary cube), then determining the polygon(s) needed to represent the part of the iso-surface that passes
        through this cube. The individual polygons are then fused into the desired surface.

    Args:
        pc: A point cloud model with at least two points.
        levelset: The levelset of iso-surface. It is recommended to set levelset to 0 or 0.5.
        mc_scale_factor: Multiplier on the voxel size. The voxel edge length is the largest
                         nearest-neighbour distance in the point cloud times this factor.

    Returns:
        A mesh model.

    Raises:
        ImportError: If PyMCubes is not installed.
        ValueError: If the point cloud has fewer than two points, or no surface could be extracted.
    """
    try:
        import mcubes
    except ImportError:
        raise ImportError(
            "You need to install the package `mcubes`."
            "\nInstall mcubes via `pip install --upgrade PyMCubes`"
        )
    from scipy.spatial import cKDTree

    pc = pc.copy()

    raw_points = np.asarray(pc.points)
    if raw_points.shape[0] < 2:
        raise ValueError(
            "The point cloud must contain at least two points to generate a mesh."
        )

    # Move the model so that the coordinate minimum is at (0, 0, 0).
    pc.points = new_points = raw_points - np.min(raw_points, axis=0)

    # Voxel size = largest nearest-neighbour distance (times mc_scale_factor). A KD-tree query
    # with k=2 returns each point itself plus its closest other point, avoiding the O(n^2)
    # pairwise distance matrix.
    nn_dist, _ = cKDTree(new_points).query(new_points, k=2)
    max_dist = nn_dist[:, 1].max()
    mc_sf = max_dist * mc_scale_factor

    scale_pc = scale_model(model=pc, scale_factor=1 / mc_sf)
    scale_pc_points = scale_pc.points = np.ceil(np.asarray(scale_pc.points)).astype(
        np.int64
    )

    # Occupancy grid: 1 in every voxel containing a point, padded by two empty voxels beyond
    # the maximum index along each axis.
    volume_array = np.zeros(
        shape=[
            scale_pc_points[:, 0].max() + 3,
            scale_pc_points[:, 1].max() + 3,
            scale_pc_points[:, 2].max() + 3,
        ]
    )
    volume_array[
        scale_pc_points[:, 0], scale_pc_points[:, 1], scale_pc_points[:, 2]
    ] = 1

    # Extract the iso-surface based on marching cubes algorithm.
    vertices, triangles = mcubes.marching_cubes(volume_array, levelset)

    if len(vertices) == 0:
        raise ValueError(
            "The point cloud cannot generate a surface mesh with `marching_cube` method."
        )

    v = np.asarray(vertices).astype(np.float64)
    f = np.asarray(triangles).astype(np.int64)
    # VTK face layout: every face is prefixed by its vertex count (3).
    f = np.c_[np.full(len(f), 3), f]

    # Generate the mesh model and bring it back to the voxel-size scale.
    mesh = pv.PolyData(v, f.ravel()).extract_surface().triangulate()
    mesh.clean(inplace=True)
    mesh = scale_model(model=mesh, scale_factor=mc_sf)

    # Align the mesh to the original coordinates using the point cloud as reference.
    scale_pc = scale_model(model=scale_pc, scale_factor=mc_sf)
    mesh.points = rigid_transform(
        coords=np.asarray(mesh.points),
        coords_refA=np.asarray(scale_pc.points),
        coords_refB=raw_points,
    )
    return mesh
268 |
269 |
def smooth_mesh(mesh: PolyData, n_iter: int = 100, **kwargs) -> PolyData:
    """
    Smooth a mesh by Laplacian smoothing of its point coordinates.
    https://docs.pyvista.org/api/core/_autosummary/pyvista.PolyData.smooth.html#pyvista.PolyData.smooth

    Args:
        mesh: A mesh model.
        n_iter: Number of Laplacian smoothing iterations.
        **kwargs: Remaining keyword arguments forwarded to ``pyvista.PolyData.smooth``.

    Returns:
        smoothed_mesh: The smoothed mesh returned by ``mesh.smooth``.
    """
    return mesh.smooth(n_iter=n_iter, **kwargs)
287 |
288 |
def fix_mesh(mesh: PolyData) -> PolyData:
    """
    Repair a surface mesh with ``pymeshfix``.

    Repairs the mesh where it was extracted and the subtle holes along complex parts of the
    mesh, then triangulates and cleans the repaired surface.

    Args:
        mesh: The surface mesh to repair.

    Returns:
        fixed_mesh: The repaired, triangulated and cleaned surface mesh.

    Raises:
        ImportError: If ``pymeshfix`` is not installed (chained to the original import error).
        ValueError: If the repaired mesh contains no points.
    """
    # ``pymeshfix`` is an optional dependency, so it is only imported when needed.
    try:
        import pymeshfix as mf
    except ImportError as err:
        raise ImportError(
            "You need to install the package `pymeshfix`. \nInstall pymeshfix via `pip install pymeshfix`"
        ) from err

    meshfix = mf.MeshFix(mesh)
    meshfix.repair(verbose=False)
    fixed_mesh = meshfix.mesh.triangulate().clean()

    # Guard against a repair that leaves nothing behind.
    if fixed_mesh.n_points == 0:
        raise ValueError(
            "The surface cannot be Repaired. "
            "\nPlease change the method or parameters of surface reconstruction."
        )

    return fixed_mesh
311 |
312 |
def clean_mesh(mesh: PolyData) -> PolyData:
    """
    Remove the sub-surfaces of a mesh that are fully enclosed by another sub-surface.

    The mesh is split into its connected bodies. Every body whose points all lie inside
    another body is dropped, and the remaining bodies are merged back together.

    Args:
        mesh: A surface mesh model, possibly made of several disconnected bodies.

    Returns:
        The input ``mesh`` itself if it has at most one body; otherwise a new ``PolyData``
        holding only the bodies that are not enclosed by another body.
    """
    sub_meshes = mesh.split_bodies()
    n_mesh = len(sub_meshes)

    # Nothing to compare: a single body (or an empty mesh) is returned unchanged.
    if n_mesh <= 1:
        return mesh

    # Convert every body to PolyData once instead of once per pairwise comparison.
    bodies = [pv.PolyData(sub.points, sub.cells) for sub in sub_meshes]

    enclosed = set()
    for i, main_mesh in enumerate(bodies):
        # A body already known to be enclosed is not used as a container. This also stops
        # two identical bodies from eliminating each other.
        if i in enclosed:
            continue
        for j, check_mesh in enumerate(bodies):
            if j == i or j in enclosed:
                continue
            # ``split_bodies`` does not order bodies by nesting, so every ordered pair is
            # checked, not only later bodies against earlier ones.
            inside = check_mesh.select_enclosed_points(
                main_mesh, check_surface=False
            ).threshold(0.5)
            inside = pv.PolyData(inside.points, inside.cells)
            if check_mesh == inside:
                enclosed.add(j)

    kept = [i for i in range(n_mesh) if i not in enclosed]
    if len(kept) == 1:
        cmesh = sub_meshes[kept[0]]
    else:
        cmesh = merge_models([sub_meshes[i] for i in kept])

    return pv.PolyData(cmesh.points, cmesh.cells)
341 |
342 |
def uniform_mesh(
    mesh: PolyData, nsub: Optional[int] = 3, nclus: int = 20000
) -> PolyData:
    """
    Generate a uniformly meshed surface using voronoi clustering.

    Args:
        mesh: A mesh model. It is not modified: subdivision produces a new mesh.
        nsub: Number of subdivisions. Each subdivision creates 4 new triangles, so the number of resulting triangles is
            nface*4**nsub where nface is the current number of faces. If None, no subdivision is applied.
        nclus: Number of voronoi clustering.

    Returns:
        new_mesh: A uniform mesh model.

    Raises:
        ImportError: If ``pyacvd`` is not installed (chained to the original import error).
    """
    # ``pyacvd`` is an optional dependency, so it is only imported when needed.
    try:
        import pyacvd
    except ImportError as err:
        raise ImportError(
            "You need to install the package `pyacvd`. \nInstall pyacvd via `pip install pyacvd`"
        ) from err

    # If the mesh is not dense enough for uniform remeshing, increase its number of triangles.
    # ``subdivide`` without ``inplace`` returns a new mesh, so the caller's mesh is untouched.
    if nsub is not None:
        mesh = mesh.subdivide(nsub=nsub, subfilter="butterfly")

    # Uniformly remeshing.
    clustered = pyacvd.Clustering(mesh)
    clustered.cluster(nclus)

    new_mesh = clustered.create_mesh().triangulate().clean()
    return new_mesh
376 |
377 |
def construct_surface(
    pc: PolyData,
    levelset: Union[int, float] = 0,
    mc_scale_factor: Union[int, float] = 1.0,
    nsub: Optional[int] = 3,
    nclus: int = 20000,
    smooth: Optional[int] = 1000,
    scale_factor: Union[float, int, list, tuple] = None,
) -> Union[PolyData, UnstructuredGrid, None]:
    """
    Surface mesh reconstruction based on 3D point cloud model.

    Pipeline:
        1. Marching cubes.
        2. Drop sub-surfaces enclosed by other sub-surfaces.
        3. Per-body repair and uniform remeshing.
        4. Merge the bodies.
        5. Optional Laplacian smoothing.
        6. Scaling.

    Args:
        pc: A point cloud model.
        levelset: The levelset of iso-surface. It is recommended to set levelset to 0 or 0.5.
        mc_scale_factor: The scale of the model. The scaled model is used to construct the mesh model.
        nsub: Number of subdivisions. Each subdivision creates 4 new triangles, so the number of resulting triangles is
            nface*4**nsub where nface is the current number of faces. If None, no subdivision is applied.
        nclus: Number of voronoi clustering.
        smooth: Number of iterations for Laplacian smoothing. If None, no smoothing is applied.
        scale_factor: The scale by which the model is scaled. If ``scale factor`` is float, the model is scaled along the
            xyz axis at the same scale; when the ``scale factor`` is list, the model is scaled along the xyz
            axis at different scales. If ``scale_factor`` is None, there will be no scaling based on scale factor.

    Returns:
        uniform_surf: The reconstructed, triangulated and cleaned surface mesh, merged from every retained body.
    """
    # Reconstruct surface mesh.
    surf = marching_cube_mesh(pc=pc, levelset=levelset, mc_scale_factor=mc_scale_factor)

    # Remove sub-surfaces that are enclosed by other sub-surfaces.
    csurf = clean_mesh(mesh=surf)

    uniform_surfs = []
    for sub_surf in csurf.split_bodies():
        # Repair the surface mesh where it was extracted and subtle holes along complex parts of the mesh.
        sub_fix_surf = fix_mesh(mesh=sub_surf.extract_surface())

        # Get a uniformly meshed surface using voronoi clustering.
        sub_uniform_surf = uniform_mesh(mesh=sub_fix_surf, nsub=nsub, nclus=nclus)
        uniform_surfs.append(sub_uniform_surf)
    uniform_surf = merge_models(models=uniform_surfs)
    uniform_surf = uniform_surf.extract_surface().triangulate().clean()

    # Adjust point coordinates using Laplacian smoothing.
    if smooth is not None:
        uniform_surf = smooth_mesh(mesh=uniform_surf, n_iter=smooth)

    # Scale the surface mesh.
    uniform_surf = scale_model(model=uniform_surf, scale_factor=scale_factor)
    return uniform_surf
431 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/ui/__init__.py:
--------------------------------------------------------------------------------
1 | from .container import ui_container
2 | from .drawer import ui_drawer
3 | from .layout import ui_layout
4 | from .toolbar import ui_toolbar
5 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/ui/container.py:
--------------------------------------------------------------------------------
1 | try:
2 | from typing import Literal
3 | except ImportError:
4 | from typing_extensions import Literal
5 |
6 | from trame.widgets import html
7 | from trame.widgets import vtk as vtk_widgets
8 | from trame.widgets import vuetify
9 |
# Mouse-interaction preset for the ``VtkView`` (used below as the default value of the
# ``interactorSettings`` state in ``ui_container``). Each entry binds a mouse button,
# optionally combined with modifier keys (``alt``/``control``/``shift``), to a camera action.
# NOTE(review): button numbers presumably follow vtk.js manipulator conventions
# (1 = left, 2 = middle, 3 = right) — confirm.
VIEW_INTERACT = [
    {"button": 1, "action": "Rotate"},
    {"button": 2, "action": "Pan"},
    {"button": 3, "action": "Zoom", "scrollEnabled": True},
    {"button": 1, "action": "Pan", "alt": True},
    {"button": 1, "action": "Zoom", "control": True},
    {"button": 1, "action": "Pan", "shift": True},
    {"button": 1, "action": "Roll", "alt": True, "shift": True},
]

# Alternative preset in which button 1 performs selection.
# NOTE(review): not referenced in this module; presumably swapped into
# ``interactorSettings`` by other code — verify.
VIEW_SELECT = [{"button": 1, "action": "Select"}]
21 |
22 |
23 | # -----------------------------------------------------------------------------
24 | # GUI- standard Container
25 | # -----------------------------------------------------------------------------
26 |
27 |
def ui_container(
    server,
    layout,
):
    """
    Build the standard VContainer (main content area) of the Spateo UI.

    The container holds:
        - a tooltip card styled by the ``tooltipStyle`` state (hidden by default);
        - a ``VtkView`` with one geometry representation per model: ``activeModel``,
          ``meshModel`` and, only when ``state.custom_func`` is ``True``, ``custom_model``.

    Args:
        server: The trame server.
        layout: The layout object whose ``content`` slot is filled.
    """
    state = server.state
    ctrl = server.controller

    # Mapper settings for the active model: color by the selected scalar array,
    # unless the selection is 'Default'.
    active_mapper = (
        "{ colorByArrayName: scalar, scalarMode: scalarParameters[scalar].scalarMode,"
        " interpolateScalarsBeforeMapping: true, scalarVisibility: scalar !== 'Default' }"
    )

    with layout.content, vuetify.VContainer(
        fluid=True, classes="pa-0 fill-height", style="position: relative;"
    ):
        # Tooltip card rendering the ``tooltip`` state as preformatted text.
        with vuetify.VCard(
            style=("tooltipStyle", {"display": "none"}), elevation=2, outlined=True
        ), vuetify.VCardText():
            html.Pre("{{ tooltip }}")

        with vtk_widgets.VtkView(
            ref="render",
            background=(state.background_color,),
            picking_modes=("[pickingMode]",),
            interactor_settings=("interactorSettings", VIEW_INTERACT),
            click="pickData = $event",
            hover="pickData = $event",
            select="selectData = $event",
        ) as view:
            # Let other components reset the camera through the controller.
            ctrl.view_reset_camera = view.reset_camera

            # Active model, colored by the selected scalar.
            with vtk_widgets.VtkGeometryRepresentation(
                id="activeModel",
                v_if="activeModel",
                actor=("{ visibility: activeModelVisible }",),
                color_map_preset=("colorMap",),
                color_data_range=("scalarParameters[scalar].range",),
                mapper=(active_mapper,),
                property=(
                    {
                        "pointSize": state.pixel_ratio,
                        "representation": 1,
                        "opacity": 1,
                        "ambient": 0.3,
                    },
                ),
            ):
                vtk_widgets.VtkMesh("activeModel", state=("activeModel",))

            # Mesh model, drawn with opacity 0.6.
            with vtk_widgets.VtkGeometryRepresentation(
                id="meshModel",
                v_if="meshModel",
                actor=("{ visibility: meshModelVisible }",),
                property=(
                    {
                        "representation": 1,
                        "opacity": 0.6,
                        "ambient": 0.1,
                    },
                ),
            ):
                vtk_widgets.VtkMesh("meshModel", state=("meshModel",))

            # Custom model, only added when custom functionality is enabled.
            if state.custom_func is True:
                custom_size = state.custom_model_size
                with vtk_widgets.VtkGeometryRepresentation(
                    id="custom_model",
                    v_if="custom_model",
                    actor=("{ visibility: custom_model_visible }",),
                    property=(
                        {
                            "lineWidth": custom_size,
                            "pointSize": custom_size,
                            "representation": 1,
                            "opacity": 1,
                            "ambient": 0.1,
                        },
                    ),
                ):
                    vtk_widgets.VtkMesh("custom_model", state=("custom_model",))
112 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/ui/drawer/__init__.py:
--------------------------------------------------------------------------------
1 | from .main import ui_drawer
2 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/ui/drawer/alignment.py:
--------------------------------------------------------------------------------
1 | from trame.widgets import vuetify
2 |
3 |
def align_card_content():
    """Populate the body of the "Slices Alignment" card with its input widgets."""
    field = dict(dense=True, outlined=True, hide_details=True, classes="pt-1")

    # Row 1: enable alignment and choose the alignment method.
    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="6"):
            vuetify.VCheckbox(
                v_model=("slices_alignment", False),
                label="Align slices",
                on_icon="mdi-layers-outline",
                off_icon="mdi-layers-off-outline",
                dense=True,
                hide_details=True,
                classes="pt-1",
            )
        with vuetify.VCol(cols="6"):
            vuetify.VSelect(
                v_model=("slices_align_method", "Paste"),
                items=(["Paste", "Morpho"],),
                show_size=True,
                label="Method of Alignment",
                **field,
            )

    # Row 2: key of the slice labels and the alignment factor.
    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="6"):
            vuetify.VTextField(
                v_model=("slices_key", "slices"), label="Slices Key", **field
            )
        with vuetify.VCol(cols="6"):
            vuetify.VTextField(
                v_model=("slices_align_factor", 0.1), label="Align Factor", **field
            )

    # Row 3: iteration budget and compute device.
    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="6"):
            vuetify.VTextField(
                v_model=("slices_align_max_iter", 200), label="Max Iterations", **field
            )
        with vuetify.VCol(cols="6"):
            vuetify.VTextField(
                v_model=("slices_align_device", "CPU"),
                label="Device",
                show_size=True,
                **field,
            )
66 |
67 |
def align_card_panel():
    """Render the collapsible "Slices Alignment" card: a toolbar header followed by its body."""
    flag = "show_align_card"
    compact = dict(classes="pa-0 ma-0", style="flex: none;")

    with vuetify.VToolbar(dense=True, outlined=True, **compact):
        vuetify.VIcon("mdi-target")
        vuetify.VCardTitle(
            " Slices Alignment", hide_details=True, dense=True, **compact
        )

        vuetify.VSpacer()
        # Button toggling the visibility flag; the icon reflects the current state.
        with vuetify.VBtn(small=True, icon=True, click=f"{flag} = !{flag}"):
            vuetify.VIcon("mdi-unfold-less-horizontal", v_if=(flag,))
            vuetify.VIcon("mdi-unfold-more-horizontal", v_if=(f"!{flag}",))

    # Card body, rendered only while the flag is truthy.
    with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"):
        with vuetify.VCardText(classes="py-2", v_if=(flag,)):
            align_card_content()
94 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/ui/drawer/custom_card.py:
--------------------------------------------------------------------------------
1 | from trame.widgets import vuetify
2 |
3 |
def custom_card_content():
    """Populate the body of the "Custom Card" with its input widgets."""
    toggle = dict(dense=True, hide_details=True, classes="pt-1")
    field = dict(hide_details=True, dense=True, outlined=True, classes="pt-1")

    # Row 1: run the custom analysis / show the custom model.
    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="6"):
            vuetify.VCheckbox(
                v_model=("custom_analysis", False),
                label="Custom analysis calculation",
                on_icon="mdi-billiards-rack",
                off_icon="mdi-dots-triangle",
                **toggle,
            )
        with vuetify.VCol(cols="6"):
            vuetify.VCheckbox(
                v_model=("custom_model_visible", False),
                label="Custom model visibility",
                on_icon="mdi-eye-outline",
                off_icon="mdi-eye-off-outline",
                **toggle,
            )

    # Row 2: the two custom parameters.
    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="6"):
            vuetify.VSelect(
                label="Custom Parameter 1",
                v_model=("custom_parameter1", "ElPiGraph"),
                items=(["ElPiGraph", "SimplePPT", "PrinCurve"],),
                **field,
            )

        with vuetify.VCol(cols="6"):
            vuetify.VTextField(
                label="Custom Parameter 2",
                v_model=("custom_parameter2", 50),
                **field,
            )

    # Row 3: output field for the custom model.
    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="12"):
            vuetify.VTextField(
                v_model=("custom_model_output", None),
                label="Custom model Output",
                show_size=True,
                **field,
            )
60 |
61 |
def custom_card_panel():
    """Render the collapsible "Custom Card": a toolbar header followed by its body."""
    flag = "show_custom_card"
    compact = dict(classes="pa-0 ma-0", style="flex: none;")

    with vuetify.VToolbar(dense=True, outlined=True, **compact):
        vuetify.VIcon("mdi-billiards-rack")
        vuetify.VCardTitle(" Custom Card", hide_details=True, dense=True, **compact)

        vuetify.VSpacer()
        # Button toggling the visibility flag; the icon reflects the current state.
        with vuetify.VBtn(small=True, icon=True, click=f"{flag} = !{flag}"):
            vuetify.VIcon("mdi-unfold-less-horizontal", v_if=(flag,))
            vuetify.VIcon("mdi-unfold-more-horizontal", v_if=(f"!{flag}",))

    # Card body, rendered only while the flag is truthy.
    with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"):
        with vuetify.VCardText(classes="py-2", v_if=(flag,)):
            custom_card_content()
88 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/ui/drawer/main.py:
--------------------------------------------------------------------------------
1 | def _get_spateo_cmap():
2 | import matplotlib as mpl
3 | from matplotlib.colors import LinearSegmentedColormap
4 |
5 | if "spateo_cmap" not in mpl.colormaps():
6 | colors = ["#4B0082", "#800080", "#F97306", "#FFA500", "#FFD700", "#FFFFCB"]
7 | nodes = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
8 |
9 | mpl.colormaps.register(
10 | LinearSegmentedColormap.from_list("spateo_cmap", list(zip(nodes, colors)))
11 | )
12 | return "spateo_cmap"
13 |
14 |
def ui_drawer(server, layout):
    """
    Fill the drawer of the Spateo UI with its cards.

    Cards are added in this order:
        1. active model;
        2. slices alignment;
        3. mesh reconstruction;
        4. the custom card, only when ``server.state.custom_func`` is ``True``.

    Args:
        server: The trame server.
        layout: The layout object whose ``drawer`` slot is filled.
    """
    # Register the custom colormap (no-op if it is already registered).
    _get_spateo_cmap()

    with layout.drawer:
        from .alignment import align_card_panel
        from .model_point import pc_card_panel
        from .reconstruction import mesh_card_panel

        pc_card_panel()
        align_card_panel()
        mesh_card_panel()

        if server.state.custom_func is True:
            from .custom_card import custom_card_panel

            custom_card_panel()
43 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/ui/drawer/model_point.py:
--------------------------------------------------------------------------------
1 | from trame.widgets import vuetify
2 |
3 |
def pc_card_content():
    """Populate the body of the "Active Model" card with its input widgets."""
    boxed = dict(dense=True, outlined=True, hide_details=True)
    flat = dict(dense=True, hide_details=True, classes="pt-1")

    # Row 1: scalar array used for coloring and the colormap preset.
    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="6"):
            vuetify.VSelect(
                label="Scalars",
                v_model=("scalar", "Default"),
                items=("Object.values(scalarParameters)",),
                show_size=True,
                classes="pt-1",
                **boxed,
            )
        with vuetify.VCol(cols="6"):
            vuetify.VSelect(
                label="Colormap",
                v_model=("colorMap", "erdc_rainbow_bright"),
                items=("trame.utils.vtk.vtkColorPresetItems('')",),
                show_size=True,
                classes="pt-1",
                **boxed,
            )

    # Row 2: label group to pick and whether picking overwrites the active model.
    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="6"):
            vuetify.VSelect(
                label="Picking Group",
                v_model=("picking_group", None),
                items=("Object.keys(scalarParameters[scalar].raw_labels)",),
                show_size=True,
                classes="pt-2",
                **boxed,
            )
        with vuetify.VCol(cols="6"):
            vuetify.VCheckbox(
                label="Overwrite the Active Model",
                v_model=("overwrite", False),
                on_icon="mdi-plus-thick",
                off_icon="mdi-close-thick",
                **flat,
            )

    # Rows 3-4: output fields, one per row.
    for state_name, label in (
        ("activeModel_output", "Active Model Output"),
        ("anndata_output", "Anndata Output"),
    ):
        with vuetify.VRow(classes="pt-2", dense=True):
            with vuetify.VCol(cols="12"):
                vuetify.VTextField(
                    v_model=(state_name, None),
                    label=label,
                    classes="pt-1",
                    **boxed,
                )

    # Row 5: visibility of the active model.
    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="12"):
            vuetify.VCheckbox(
                label="Visibility of Active Model",
                v_model=("activeModelVisible", True),
                on_icon="mdi-eye-outline",
                off_icon="mdi-eye-off-outline",
                **flat,
            )
84 |
85 |
def pc_card_panel():
    """Render the collapsible "Active Model" card: a toolbar header followed by its body."""
    flag = "show_active_card"
    compact = dict(classes="pa-0 ma-0", style="flex: none;")

    with vuetify.VToolbar(dense=True, outlined=True, **compact):
        vuetify.VIcon("mdi-format-paint")
        vuetify.VCardTitle(" Active Model", hide_details=True, dense=True, **compact)

        vuetify.VSpacer()
        # Button toggling the visibility flag; the icon reflects the current state.
        with vuetify.VBtn(small=True, icon=True, click=f"{flag} = !{flag}"):
            vuetify.VIcon("mdi-unfold-less-horizontal", v_if=(flag,))
            vuetify.VIcon("mdi-unfold-more-horizontal", v_if=(f"!{flag}",))

    # Card body, rendered only while the flag is truthy.
    with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"):
        with vuetify.VCardText(classes="py-2", v_if=(flag,)):
            pc_card_content()
112 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/ui/drawer/reconstruction.py:
--------------------------------------------------------------------------------
1 | from trame.widgets import vuetify
2 |
3 |
def mesh_card_content():
    """Populate the body of the "Mesh Reconstruction" card with its input widgets."""
    field = dict(hide_details=True, dense=True, outlined=True, classes="pt-1")
    toggle = dict(dense=True, hide_details=True, classes="pt-1")

    # Row 1: trigger the reconstruction / clip with the mesh model.
    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="6"):
            vuetify.VCheckbox(
                v_model=("reconstruct_mesh", False),
                label="Reconstruct Mesh Model",
                on_icon="mdi-billiards-rack",
                off_icon="mdi-dots-triangle",
                **toggle,
            )
        with vuetify.VCol(cols="6"):
            vuetify.VCheckbox(
                v_model=("clip_pc_with_mesh", False),
                label="Clip with Mesh Model",
                on_icon="mdi-box-cutter",
                off_icon="mdi-box-cutter-off",
                **toggle,
            )

    # Rows 2-3: numeric reconstruction parameters, two per row.
    numeric_rows = (
        (
            ("mc_factor", 1.0, "MC Factor"),
            ("mesh_voronoi", 20000, "Voronoi Clustering"),
        ),
        (
            ("mesh_smooth_factor", 1000, "Smooth Factor"),
            ("mesh_scale_factor", 1.0, "Scale Factor"),
        ),
    )
    for row in numeric_rows:
        with vuetify.VRow(classes="pt-2", dense=True):
            for state_name, default, label in row:
                with vuetify.VCol(cols="6"):
                    vuetify.VTextField(
                        v_model=(state_name, default), label=label, **field
                    )

    # Row 4: output field for the reconstructed mesh.
    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="12"):
            vuetify.VTextField(
                v_model=("mesh_output", None),
                label="Reconstructed Mesh Output",
                show_size=True,
                **field,
            )

    # Row 5: visibility of the mesh model.
    with vuetify.VRow(classes="pt-2", dense=True):
        with vuetify.VCol(cols="12"):
            vuetify.VCheckbox(
                v_model=("meshModelVisible", False),
                label="Visibility of Mesh Model",
                on_icon="mdi-eye-outline",
                off_icon="mdi-eye-off-outline",
                **toggle,
            )
89 |
90 |
def mesh_card_panel():
    """Render the collapsible "Mesh Reconstruction" card: a toolbar header followed by its body."""
    flag = "show_mesh_card"
    compact = dict(classes="pa-0 ma-0", style="flex: none;")

    with vuetify.VToolbar(dense=True, outlined=True, **compact):
        vuetify.VIcon("mdi-billiards-rack")
        vuetify.VCardTitle(
            " Mesh Reconstruction", hide_details=True, dense=True, **compact
        )

        vuetify.VSpacer()
        # Button toggling the visibility flag; the icon reflects the current state.
        with vuetify.VBtn(small=True, icon=True, click=f"{flag} = !{flag}"):
            vuetify.VIcon("mdi-unfold-less-horizontal", v_if=(flag,))
            vuetify.VIcon("mdi-unfold-more-horizontal", v_if=(f"!{flag}",))

    # Card body, rendered only while the flag is truthy.
    with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"):
        with vuetify.VCardText(classes="py-2", v_if=(flag,)):
            mesh_card_content()
117 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/ui/layout.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | # -----------------------------------------------------------------------------
4 | # GUI layout
5 | # -----------------------------------------------------------------------------
6 |
7 |
def ui_layout(
    server, template_name: str = "main", drawer_width: Optional[int] = None, **kwargs
):
    """
    Define the user interface (UI) layout.
    Reference: https://trame.readthedocs.io/en/latest/trame.ui.vuetify.html#trame.ui.vuetify.SinglePageWithDrawerLayout

    Args:
        server: Server to bound the layout to.
        template_name: Name of the template.
        drawer_width: Drawer width in pixel. If None, 15% of the screen width reported by
            ``pyautogui`` is used. If that query fails (for example ``pyautogui`` is not
            installed, or there is no display on a headless server), 300 pixels is used.
        **kwargs: Additional parameters passed to ``SinglePageWithDrawerLayout``.

    Returns:
        The SinglePageWithDrawerLayout layout object.
    """
    from trame.ui.vuetify import SinglePageWithDrawerLayout

    if drawer_width is None:
        try:
            import pyautogui

            screen_width, _ = pyautogui.size()
            drawer_width = int(screen_width * 0.15)
        except Exception:
            # Importing/querying pyautogui can fail in several backend-specific ways
            # (missing package, no display, ...). The drawer width is purely cosmetic,
            # so fall back to a fixed width instead of aborting the app.
            drawer_width = 300

    return SinglePageWithDrawerLayout(
        server, template_name=template_name, width=drawer_width, **kwargs
    )
34 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/ui/toolbar.py:
--------------------------------------------------------------------------------
1 | try:
2 | from typing import Literal
3 | except ImportError:
4 | from typing_extensions import Literal
5 |
6 | from typing import Optional
7 |
8 | from pyvista import BasePlotter
9 | from trame.widgets import html, vuetify
10 |
11 | from stviewer.assets import icon_manager
12 | from stviewer.Reconstructor.pv_pipeline import Viewer
13 |
14 | from .utils import button, checkbox
15 |
16 | # -----------------------------------------------------------------------------
17 | # GUI- UI title
18 | # -----------------------------------------------------------------------------
19 |
20 |
def ui_title(
    layout, title_name="SPATEO VIEWER", title_icon: Optional[str] = None, **kwargs
):
    """
    Set the title text and, optionally, the logo shown in the UI toolbar.
    Reference: https://trame.readthedocs.io/en/latest/trame.ui.vuetify.html#trame.ui.vuetify.SinglePageWithDrawerLayout

    Args:
        layout: The layout object.
        title_name: Text displayed as the GUI title.
        title_icon: Image source for the GUI logo; when None, no logo is added.
        **kwargs: Extra keyword arguments forwarded to ``html.Img``.

    Returns:
        None.
    """
    title = layout.title
    title.set_text(title_name)
    title.style = "font-family:arial; font-size:25px; font-weight: 550; color: gray;"

    # Without an icon there is nothing more to add.
    if title_icon is None:
        return

    with layout.icon as icon:
        icon.style = "margin-left: 10px;"
        html.Img(src=title_icon, height=40, **kwargs)
49 |
50 |
51 | # -----------------------------------------------------------------------------
52 | # GUI- standard ToolBar
53 | # -----------------------------------------------------------------------------
54 |
55 |
def toolbar_widgets(server, plotter: BasePlotter):
    """
    Add the standard widgets to the toolbar.

    The widgets are, in order:
        - an ``.h5ad`` file picker;
        - the picking-mode toggle;
        - a "reload main model" button;
        - a busy indicator.

    Args:
        server: The trame server.
        plotter: The PyVista plotter to connect with the UI.
    """
    viewer = Viewer(server=server, plotter=plotter)

    # File picker bound to the viewer's upload state; only ``.h5ad`` files are accepted.
    vuetify.VSpacer()
    vuetify.VFileInput(
        v_model=(viewer.UPLOAD_ANNDATA, None),
        label="Select Sample",
        show_size=True,
        small_chips=True,
        truncate_length=25,
        dense=True,
        outlined=True,
        hide_details=True,
        classes="ml-8",
        style="max-width: 300px;",
        rounded=True,
        accept=".h5ad",
        __properties=["accept"],
    )
    vuetify.VSpacer()

    # One toggle button per entry of the ``modes`` state list.
    with vuetify.VBtnToggle(v_model=(viewer.PICKING_MODE, "hover"), dense=True):
        with vuetify.VBtn(value=("item.value",), v_for="item, idx in modes"):
            vuetify.VIcon("{{item.icon}}")

    button(
        click=viewer.on_reload_main_model,
        icon="mdi-restore",
        tooltip="Reload main model",
    )

    # Indeterminate progress bar shown while ``trame__busy`` is truthy.
    vuetify.VProgressLinear(
        indeterminate=True, absolute=True, bottom=True, active=("trame__busy",)
    )
98 | )
99 |
100 |
def ui_toolbar(
    server,
    layout,
    plotter: BasePlotter,
    ui_name: str = "SPATEO VIEWER",
    ui_icon=icon_manager.spateo_logo,
):
    """
    Build the standard toolbar of the Spateo UI: title/logo first, then the toolbar widgets.

    Args:
        server: The trame server.
        layout: The layout object.
        plotter: The PyVista plotter to connect with the UI.
        ui_name: Title name of the GUI.
        ui_icon: Title icon of the GUI.
    """
    ui_title(layout=layout, title_name=ui_name, title_icon=ui_icon)

    with layout.toolbar as toolbar:
        # Toolbar appearance: 55px high, dense, clipped_right.
        toolbar.height, toolbar.dense, toolbar.clipped_right = 55, True, True
        toolbar_widgets(server=server, plotter=plotter)
132 |
--------------------------------------------------------------------------------
/stviewer/Reconstructor/ui/utils.py:
--------------------------------------------------------------------------------
1 | from trame.widgets import html, vuetify
2 |
3 | # -----------------------------------------------------------------------------
4 | # vuetify components
5 | # -----------------------------------------------------------------------------
6 |
7 |
def button(click, icon, tooltip):
    """
    Create a vuetify icon button with a tooltip shown below it.

    Args:
        click: Handler passed to the button's ``click`` event.
        icon: Name of the icon displayed on the button.
        tooltip: Text displayed in the tooltip.
    """
    with vuetify.VTooltip(bottom=True):
        # The activator slot's ``on``/``attrs`` are bound to the button so hovering it opens the tooltip.
        with vuetify.Template(v_slot_activator="{ on, attrs }"), vuetify.VBtn(
            icon=True, v_bind="attrs", v_on="on", click=click
        ):
            vuetify.VIcon(icon)
        html.Span(tooltip)
15 |
16 |
def checkbox(model, icons, tooltip, **kwargs):
    """
    Create a vuetify checkbox with a tooltip shown below it.

    Args:
        model: Value for the checkbox's ``v_model``.
        icons: Sequence whose first item is the "on" icon and second item the "off" icon.
        tooltip: Text displayed in the tooltip.
        **kwargs: Extra keyword arguments forwarded to ``vuetify.VCheckbox``.
    """
    with vuetify.VTooltip(bottom=True):
        # The wrapper div receives the activator slot's ``on``/``attrs`` bindings.
        with vuetify.Template(v_slot_activator="{ on, attrs }"), html.Div(
            v_on="on", v_bind="attrs"
        ):
            vuetify.VCheckbox(
                v_model=model,
                on_icon=icons[0],
                off_icon=icons[1],
                dense=True,
                hide_details=True,
                classes="my-0 py-0 ml-1",
                **kwargs
            )
        html.Span(tooltip)
32 |
33 |
def switch(model, tooltip, **kwargs):
    """
    Create a vuetify switch with a tooltip shown below it.

    The switch is wrapped in a ``div`` bound to the activator slot's ``on``/``attrs``. Without
    that binding, hovering the switch never opens the tooltip.

    Args:
        model: Value for the switch's ``v_model``.
        tooltip: Text displayed in the tooltip.
        **kwargs: Extra keyword arguments forwarded to ``vuetify.VSwitch``.
    """
    with vuetify.VTooltip(bottom=True):
        with vuetify.Template(v_slot_activator="{ on, attrs }"):
            with html.Div(v_on="on", v_bind="attrs"):
                vuetify.VSwitch(v_model=model, hide_details=True, dense=True, **kwargs)
        html.Span(tooltip)
40 |
--------------------------------------------------------------------------------
/stviewer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/__init__.py
--------------------------------------------------------------------------------
/stviewer/assets/__init__.py:
--------------------------------------------------------------------------------
1 | from .anndata_preprocess import anndata_preprocess
2 | from .dataset_acquisition import abstract_anndata, abstract_models, sample_dataset
3 | from .dataset_manager import local_dataset_manager
4 | from .image_manager import icon_manager
5 |
--------------------------------------------------------------------------------
/stviewer/assets/anndata_preprocess.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import anndata as ad
4 | from scipy.sparse import csr_matrix, issparse
5 |
6 |
def anndata_preprocess(
    path: str,
    output_path: str,
    X_counts: str = "X_counts",
    X_log1p: Optional[str] = "X_log1p",
    spatial_key: str = "3d_align_spatial",
):
    """
    Reduce an ``.h5ad`` file to the data the viewer needs and write it to ``output_path``.

    The written object keeps only:
        - ``X``: raw counts;
        - ``layers["X_log1p"]``: log-normalized expression;
        - ``obsm["spatial"]``: spatial coordinates.

    Both matrices are sparse; dense inputs are converted to CSR. ``uns``, ``layers``,
    ``obsm``, ``obsp`` and ``varm`` are cleared before these entries are re-added.

    Args:
        path: Path of the input ``.h5ad`` file.
        output_path: Where the processed object is written (gzip-compressed).
        X_counts: Key in ``adata.layers`` holding the raw counts.
        X_log1p: Key in ``adata.layers`` holding log-normalized expression. If None, it is
            computed from the counts with ``dynamo``'s size-factor normalization.
        spatial_key: Key in ``adata.obsm`` holding the spatial coordinates.

    Returns:
        The processed AnnData object.
    """

    def _as_sparse(matrix):
        # Copy sparse matrices as they are; convert dense ones to CSR.
        return matrix.copy() if issparse(matrix) else csr_matrix(matrix)

    adata = ad.read_h5ad(filename=path)

    counts = _as_sparse(adata.layers[X_counts])
    if X_log1p is None:
        import dynamo as dyn

        adata.X = counts.copy()
        dyn.pp.normalize_cell_expr_by_size_factors(
            adata=adata, layers="X", skip_log=False
        )
        log1p = csr_matrix(adata.X.copy())
    else:
        log1p = _as_sparse(adata.layers[X_log1p])

    coords = adata.obsm[spatial_key]

    # Drop everything the viewer does not use, then put back the three kept entries.
    del adata.uns, adata.layers, adata.obsm, adata.obsp, adata.varm
    adata.X = counts
    adata.layers["X_log1p"] = log1p
    adata.obsm["spatial"] = coords

    adata.write_h5ad(output_path, compression="gzip")
    return adata
48 |
49 |
50 | """
51 | import os
52 | os.chdir(f"spateo-viewer/stviewer/assets/dataset")
53 | adata = anndata_preprocess(
54 | path=r"drosophila_E8_9h/h5ad/E8_9h_cellbin_v3.h5ad",
55 | output_path=r"drosophila_E8_9h/h5ad/E8_9h_cellbin_v3_new.h5ad"
56 | )
57 | print(adata)
58 | """
59 |
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/h5ad/S11_cellbin_demo.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/h5ad/S11_cellbin_demo.h5ad
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/mesh_models/0_Embryo_S11_aligned_mesh_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/0_Embryo_S11_aligned_mesh_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/mesh_models/1_CNS_S11_aligned_mesh_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/1_CNS_S11_aligned_mesh_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/mesh_models/2_Midgut_S11_aligned_mesh_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/2_Midgut_S11_aligned_mesh_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/mesh_models/3_Hindgut_S11_aligned_mesh_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/3_Hindgut_S11_aligned_mesh_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/mesh_models/4_Muscle_S11_aligned_mesh_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/4_Muscle_S11_aligned_mesh_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/mesh_models/5_SalivaryGland_S11_aligned_mesh_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/5_SalivaryGland_S11_aligned_mesh_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/mesh_models/6_Amnioserosa_S11_aligned_mesh_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/6_Amnioserosa_S11_aligned_mesh_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/0_Embryo_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/0_Embryo_S11_aligned_pc_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/1_CNS_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/1_CNS_S11_aligned_pc_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/2_Midgut_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/2_Midgut_S11_aligned_pc_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/3_Hindgut_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/3_Hindgut_S11_aligned_pc_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/4_Muscle_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/4_Muscle_S11_aligned_pc_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/5_SalivaryGland_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/5_SalivaryGland_S11_aligned_pc_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/6_Amnioserosa_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/6_Amnioserosa_S11_aligned_pc_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/mouse_E95/h5ad/mouse_E95_demo.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/mouse_E95/h5ad/mouse_E95_demo.h5ad
--------------------------------------------------------------------------------
/stviewer/assets/dataset/mouse_E95/matrices/X_sparse_matrix.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/mouse_E95/matrices/X_sparse_matrix.npz
--------------------------------------------------------------------------------
/stviewer/assets/dataset/mouse_E95/mesh_models/0_Embryo_mouse_E95_mesh_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/mouse_E95/mesh_models/0_Embryo_mouse_E95_mesh_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset/mouse_E95/pc_models/0_Embryo_mouse_E95_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/mouse_E95/pc_models/0_Embryo_mouse_E95_pc_model.vtk
--------------------------------------------------------------------------------
/stviewer/assets/dataset_acquisition.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | from pathlib import Path
4 | from typing import Optional, Tuple
5 |
6 | import anndata as ad
7 | import matplotlib as mpl
8 | import numpy as np
9 | import pyvista as pv
10 | from anndata import AnnData
11 | from matplotlib.colors import LinearSegmentedColormap
12 | from pandas import DataFrame
13 | from scipy import sparse
14 |
15 | try:
16 | from typing import Literal
17 | except ImportError:
18 | from typing_extensions import Literal
19 |
20 |
def extract_anndata_structure(adata: "AnnData"):
    """Build a short, human-readable summary of an AnnData object.

    The first line is ``AnnData object with n_obs × n_vars = <n_obs> × <n_vars>``.
    It is followed by one line for each non-empty container among ``obs``, ``var``,
    ``uns``, ``obsm`` and ``layers``, listing its keys (e.g. `` obs: 'a', 'b'``).
    Empty containers are omitted. Every line is newline-terminated.

    Args:
        adata: The AnnData object to describe.

    Returns:
        The multi-line summary string.
    """
    sections = (
        (" obs:", adata.obs),
        (" var:", adata.var),
        (" uns:", adata.uns),
        (" obsm:", adata.obsm),
        (" layers:", adata.layers),
    )

    lines = [
        f"AnnData object with n_obs × n_vars = {adata.shape[0]} × {adata.shape[1]}\n"
    ]
    for label, container in sections:
        keys = list(container.keys())
        if keys:
            # One join per section instead of repeated string concatenation.
            lines.append(label + ",".join(f" '{key}'" for key in keys) + "\n")
    return "".join(lines)
54 |
55 |
def abstract_anndata(path: str, X_layer: str = "X") -> Tuple[AnnData, str]:
    """Read an ``.h5ad`` file and summarize its structure.

    Args:
        path: Path of the ``.h5ad`` file.
        X_layer: Matrix to expose as ``adata.X``. ``"X"`` keeps the stored matrix;
            any other value must name an entry of ``adata.layers``, which is then
            assigned to ``adata.X``.

    Returns:
        ``(adata, anndata_structure)``. The structure string is computed before
        ``adata.X`` is replaced.
    """
    loaded = ad.read_h5ad(filename=path)
    summary = extract_anndata_structure(adata=loaded)

    if X_layer == "X":
        return loaded, summary

    assert (
        X_layer in loaded.layers.keys()
    ), f"``{X_layer}`` does not exist in `adata.layers`."
    loaded.X = loaded.layers[X_layer]
    return loaded, summary
66 |
67 |
def abstract_models(path: str, model_ids: Optional[list] = None):
    """Read the model files of a directory with ``pyvista``.

    Files are read in sorted filename order. Sub-directories and hidden files (names
    starting with ``.``, e.g. ``.DS_Store``) are skipped. They are not model files,
    and including them would also shift the pairing between models and ``model_ids``.

    Args:
        path: Directory containing the model files.
        model_ids: One identifier per model, in sorted-filename order. When None,
            ``["Model0", "Model1", ...]`` is generated. Identifiers cannot contain
            ``-`` or spaces.

    Returns:
        ``(models, model_ids)``.

    Raises:
        AssertionError: If the directory holds no model files, or if ``model_ids``
            does not have exactly one entry per model.
    """
    model_files = sorted(
        f
        for f in os.listdir(path=path)
        if not f.startswith(".") and os.path.isfile(os.path.join(path, f))
    )
    assert len(model_files) != 0, "There is no file under this path."

    models = [pv.read(filename=os.path.join(path, f)) for f in model_files]
    if model_ids is None:
        model_ids = [f"Model{i}" for i in range(len(models))]
    assert len(model_ids) == len(
        models
    ), "The number of model_ids does not equal to that of models."

    return models, model_ids
81 |
82 |
def sample_dataset(
    path: str,
    X_layer: str = "X",
    pc_model_ids: Optional[list] = None,
    mesh_model_ids: Optional[list] = None,
):
    """Load a dataset (AnnData plus point-cloud / mesh models) for the viewer.

    ``path`` may be either of the following.

    * An ``.h5ad`` file. A single point cloud is built from ``adata.obsm["spatial"]``
      and matrices are cached under ``./temp/matrices_<file name>``. No mesh models
      are loaded.
    * A dataset directory. The first ``.h5ad`` file (in sorted order) under ``h5ad/``
      is loaded and matrices are cached under ``matrices/``. Models are read from the
      optional ``pc_models/`` and ``mesh_models/`` sub-directories.

    Side effects:

    * When the matrices directory does not exist yet, it is created. ``X`` and every
      layer are then saved as ``<matrix_id>_sparse_matrix.npz``, with dense matrices
      converted to CSR first. An existing directory is left as it is.
    * For every ``adata.uns`` key ending in ``colors``, a matplotlib colormap of that
      name is registered if it is not registered yet.

    Args:
        path: ``.h5ad`` file or dataset directory.
        X_layer: Matrix exposed as ``adata.X`` (see ``abstract_anndata``).
        pc_model_ids: Ids for the ``pc_models/`` files. When None they are derived
            from the file names as ``PC_<second "_"-separated token>``.
        mesh_model_ids: Ids for the ``mesh_models/`` files. When None they are
            derived likewise as ``Mesh_<token>``.

    Returns:
        ``(anndata_info, pc_models, pc_model_ids, mesh_models, mesh_model_ids,
        custom_colors)``. ``mesh_models`` and ``mesh_model_ids`` are None when there
        is no ``mesh_models/`` directory.

    Raises:
        ValueError: If ``path`` is neither an ``.h5ad`` file nor a directory, or if
            the directory's ``h5ad/`` folder contains no ``.h5ad`` file.
    """
    # Locate the AnnData file and the matrices cache directory
    if os.path.isfile(path) and path.endswith(".h5ad"):
        anndata_path = path
        # basename rather than split("/") so Windows-style paths work as well
        matrices_npz_path = f"./temp/matrices_{os.path.basename(path)}"
    elif os.path.isdir(path):
        anndata_dir = os.path.join(path, "h5ad")
        # Sorted so the chosen file does not depend on directory listing order
        anndata_list = sorted(
            f for f in os.listdir(path=anndata_dir) if str(f).endswith(".h5ad")
        )
        if not anndata_list:
            raise ValueError(f"No `.h5ad` file found in `{anndata_dir}`.")
        anndata_path = os.path.join(anndata_dir, anndata_list[0])
        matrices_npz_path = os.path.join(path, "matrices")
    else:
        raise ValueError(f"`{path}` is not available for spateo-viewer.")

    adata, anndata_structure = abstract_anndata(path=anndata_path, X_layer=X_layer)

    ## Generate info-dict of anndata object
    anndata_info = {
        "anndata_path": anndata_path,
        "anndata_structure": anndata_structure,
        "anndata_obs_keys": list(adata.obs_keys()),
        "anndata_obs_index": list(adata.obs.index.to_list()),
        "anndata_var_index": list(adata.var.index.to_list()),
        "anndata_obsm_keys": [
            key for key in ["spatial", "X_umap"] if key in adata.obsm.keys()
        ],
        "anndata_matrices": ["X"] + [i for i in adata.layers.keys()],
        "matrices_npz_path": matrices_npz_path,
    }

    # Cache the matrices; skipped entirely when the directory already exists
    matrices_dir = anndata_info["matrices_npz_path"]
    if not os.path.exists(matrices_dir):
        Path(matrices_dir).mkdir(parents=True, exist_ok=True)
        for matrix_id in anndata_info["anndata_matrices"]:
            matrix = adata.X if matrix_id == "X" else adata.layers[matrix_id]
            # save_npz only accepts scipy sparse matrices
            if not sparse.issparse(matrix):
                matrix = sparse.csr_matrix(matrix)
            sparse.save_npz(f"{matrices_dir}/{matrix_id}_sparse_matrix.npz", matrix)

    # Generate point cloud models
    pc_dir = os.path.join(path, "pc_models")
    if os.path.isdir(path) and os.path.exists(pc_dir):
        if pc_model_ids is None:
            pc_model_files = sorted(
                f for f in os.listdir(path=pc_dir) if str(f).endswith((".vtk", ".vtm"))
            )
            pc_model_ids = [f"PC_{str(i).split('_')[1]}" for i in pc_model_files]
        _pc_models, pc_model_ids = abstract_models(path=pc_dir, model_ids=pc_model_ids)
    else:
        bucket_xyz = adata.obsm["spatial"].astype(np.float64)
        if isinstance(bucket_xyz, DataFrame):
            bucket_xyz = bucket_xyz.values
        pc_model = pv.PolyData(bucket_xyz)
        pc_model.point_data["obs_index"] = np.array(adata.obs_names.tolist())
        _pc_models, pc_model_ids = [pc_model], ["PC_Model"]

    # Copy obsm coordinates and obs annotations onto every point cloud
    pc_models = []
    for pc_model in _pc_models:
        # Subset once per model instead of once per obsm/obs key
        sub_adata = adata[pc_model.point_data["obs_index"], :]
        for obsm_key in anndata_info["anndata_obsm_keys"]:
            coords = np.asarray(sub_adata.obsm[obsm_key])
            pc_model.point_data[f"{obsm_key}_X"] = coords[:, 0]
            pc_model.point_data[f"{obsm_key}_Y"] = coords[:, 1]
            pc_model.point_data[f"{obsm_key}_Z"] = (
                0 if coords.shape[1] == 2 else coords[:, 2]
            )

        for obs_key in adata.obs_keys():
            array = np.asarray(sub_adata.obs[obs_key])
            array = (
                np.asarray(array, dtype=float)
                if np.issubdtype(array.dtype, np.number)
                else np.asarray(array, dtype=str)
            )
            pc_model.point_data[obs_key] = array
        pc_models.append(pc_model)

    # Generate mesh models
    mesh_dir = os.path.join(path, "mesh_models")
    if os.path.isdir(path) and os.path.exists(mesh_dir):
        if mesh_model_ids is None:
            mesh_model_files = sorted(
                f
                for f in os.listdir(path=mesh_dir)
                if str(f).endswith((".vtk", ".vtm"))
            )
            mesh_model_ids = [f"Mesh_{str(i).split('_')[1]}" for i in mesh_model_files]
        mesh_models, mesh_model_ids = abstract_models(
            path=mesh_dir, model_ids=mesh_model_ids
        )
    else:
        mesh_models, mesh_model_ids = None, None

    # Custom colors: register a colormap for every non-empty `*colors` entry in uns
    custom_colors = []
    for key in adata.uns.keys():
        if not str(key).endswith("colors"):
            continue
        colors = adata.uns[key]
        if isinstance(colors, dict):
            colors = np.asarray([i for i in colors.values()])
        if not isinstance(colors, (np.ndarray, list)) or len(colors) == 0:
            continue
        custom_colors.append(key)
        if key not in mpl.colormaps():
            if len(colors) == 1:
                # from_list needs nodes spanning 0..1, so repeat a lone colour
                colors = [colors[0], colors[0]]
            nodes = np.linspace(0, 1, num=len(colors))
            mpl.colormaps.register(
                LinearSegmentedColormap.from_list(key, list(zip(nodes, colors)))
            )

    # Delete anndata object
    del adata
    gc.collect()

    return (
        anndata_info,
        pc_models,
        pc_model_ids,
        mesh_models,
        mesh_model_ids,
        custom_colors,
    )
223 |
--------------------------------------------------------------------------------
/stviewer/assets/dataset_manager.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import os
3 | from pathlib import Path
4 |
5 | from trame.app.mimetypes import to_mime
6 |
7 |
def to_base64(file_path):
    """
    Read a file in binary mode and encode its bytes as base64.

    Args:
        file_path: Path of the file to encode.

    Return:
        The base64 text of the file content, as an ASCII string.
    """
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("ascii")
21 |
22 |
def to_url(file_path: str):
    """
    Build an inline ``data:`` URL embedding a file's content.

    Args:
        file_path: Path of the file to embed.

    Return:
        A string of the form ``data:{mime};base64,{content}``. ``content`` comes from
        ``to_base64`` and ``mime`` from trame's ``to_mime``.
    """
    payload = to_base64(file_path)
    return "data:{};base64,{}".format(to_mime(file_path), payload)
36 |
37 |
class LocalFileManager:
    """Named registry of local file and directory paths.

    Entries are kept in a plain dict and can be read back with ``manager["key"]``,
    ``manager.key`` or :meth:`get_assets`. A key that was never registered reads as
    None.
    """

    def __init__(self, base_path):
        """
        Provide the base directory against which relative directory paths are resolved.

        Args:
            base_path: A file or directory path. For a file, its parent directory is used.
        """
        _base = Path(base_path)

        # Ensure directory
        if _base.is_file():
            _base = _base.parent

        self._root = Path(str(_base.resolve().absolute()))
        # key -> registered path (attribute name kept as-is for compatibility)
        self._assests = {}

    def __getitem__(self, name):
        return self._assests.get(name)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        # Special (dunder) names must raise AttributeError so that protocols probing
        # for them (copy, pickle, ...) keep working.
        # ``_assests`` is read from ``__dict__`` directly: on an instance whose
        # ``__init__`` has not run (e.g. during unpickling), ``self._assests`` would
        # re-enter ``__getattr__`` and recurse without end.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        assets = self.__dict__.get("_assests")
        if assets is None:
            raise AttributeError(name)
        return assets.get(name)

    def _to_path(self, file_path):
        """Return ``file_path`` as a resolved absolute path string.

        Relative paths are taken relative to the base directory.
        """
        _input_file = Path(file_path)
        if _input_file.is_absolute():
            return str(_input_file.resolve().absolute())

        return str(self._root.joinpath(file_path).resolve().absolute())

    def file_url(self, key: str, file_path: str):
        """
        Register ``file_path`` under ``key`` exactly as given.

        The path is neither resolved nor encoded. Nothing is stored when ``key`` is
        None.

        Args:
            key: The name for that entry, which can then be accessed with the [] or . notation.
            file_path: A file path.

        Returns:
            ``file_path``, unchanged.
        """
        if key is not None:
            self._assests[key] = file_path

        return file_path

    def dir_url(self, key: str, dir_path: str):
        """
        Register a directory under ``key`` as a resolved absolute path.

        A relative ``dir_path`` is resolved against the base directory. If
        ``dir_path`` is None, the two arguments are swapped: ``key`` is then used as
        the path, and nothing is stored.

        Args:
            key: The name for that entry, which can then be accessed with the [] or . notation.
            dir_path: A directory path.

        Returns:
            The resolved absolute path string.
        """
        if dir_path is None:
            dir_path, key = key, dir_path

        data = self._to_path(dir_path)

        if key is not None:
            self._assests[key] = data

        return data

    @property
    def assets(self):
        """Return the full set of assets as a dict (the live internal dict)."""
        return self._assests

    def get_assets(self, *keys):
        """Return a dict restricted to ``keys`` (missing keys map to None).

        Without keys, the full assets dict is returned.
        """
        if len(keys) == 0:
            return self.assets

        return {key: self._assests.get(key) for key in keys}
119 |
120 |
local_dataset_manager = LocalFileManager(__file__)

# Demo dataset directories.
# Existence is checked relative to the current working directory, i.e. this assumes
# the app is launched from the repository root. The registered path, in contrast, is
# resolved by `dir_url` relative to this module's directory.
for _dataset_name in ("mouse_E95", "mouse_E115", "drosophila_S11"):
    if os.path.exists(f"./stviewer/assets/dataset/{_dataset_name}"):
        local_dataset_manager.dir_url(_dataset_name, f"./dataset/{_dataset_name}")

# Initial AnnData for the reconstructor.
# Prefer the full `S11_cellbin.h5ad` when present; otherwise fall back to
# `S11_cellbin_demo.h5ad`, the file shipped with the repository. Without the
# fallback, `drosophila_S11_anndata` would be unregistered and read as None.
for _anndata_file in (
    r"./stviewer/assets/dataset/drosophila_S11/h5ad/S11_cellbin.h5ad",
    r"./stviewer/assets/dataset/drosophila_S11/h5ad/S11_cellbin_demo.h5ad",
):
    if os.path.exists(_anndata_file):
        local_dataset_manager.file_url("drosophila_S11_anndata", _anndata_file)
        break
134 |
135 | """
136 | # -----------------------------------------------------------------------------
137 | # Data file information
138 | # -----------------------------------------------------------------------------
139 | from trame.assets.remote import HttpFile
140 | dataset_file = HttpFile(
141 | "./data/disk_out_ref.vtu",
142 | "https://github.com/Kitware/trame/raw/master/examples/data/disk_out_ref.vtu",
143 | __file__,
144 | )
145 | """
146 |
--------------------------------------------------------------------------------
/stviewer/assets/image/interactive_viewer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/image/interactive_viewer.png
--------------------------------------------------------------------------------
/stviewer/assets/image/spateo_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/image/spateo_logo.png
--------------------------------------------------------------------------------
/stviewer/assets/image/spateoviewer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/image/spateoviewer.png
--------------------------------------------------------------------------------
/stviewer/assets/image/static_viewer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/image/static_viewer.png
--------------------------------------------------------------------------------
/stviewer/assets/image/upload_file.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/image/upload_file.png
--------------------------------------------------------------------------------
/stviewer/assets/image/upload_folder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/image/upload_folder.png
--------------------------------------------------------------------------------
/stviewer/assets/image_manager.py:
--------------------------------------------------------------------------------
1 | from trame.assets.local import LocalFileManager
2 |
# Registry of UI image assets; the explorer app uses `spateo_logo` as its favicon.
_SPATEO_LOGO_PATH = "./image/spateo_logo.png"
icon_manager = LocalFileManager(__file__)
icon_manager.url("spateo_logo", _SPATEO_LOGO_PATH)
5 |
--------------------------------------------------------------------------------
/stviewer/explorer_app.py:
--------------------------------------------------------------------------------
1 | try:
2 | from typing import Literal
3 | except ImportError:
4 | from typing_extensions import Literal
5 |
6 | from tkinter import Tk, filedialog
7 |
8 | import matplotlib.pyplot as plt
9 | from trame.widgets import trame as trame_widgets
10 |
11 | from .assets import icon_manager, local_dataset_manager
12 | from .Explorer import (
13 | create_plotter,
14 | init_actors,
15 | init_adata_parameters,
16 | init_card_parameters,
17 | init_custom_parameters,
18 | init_interpolation_parameters,
19 | init_mesh_parameters,
20 | init_morphogenesis_parameters,
21 | init_output_parameters,
22 | init_pc_parameters,
23 | ui_container,
24 | ui_drawer,
25 | ui_layout,
26 | ui_toolbar,
27 | )
28 | from .server import get_trame_server
29 |
30 | # export WSLINK_MAX_MSG_SIZE=1000000000 # 1GB
31 |
32 | # Get a Server to work with
# -----------------------------------------------------------------------------
# Server, shared state and controller
# -----------------------------------------------------------------------------
static_server = get_trame_server(name="spateo_explorer")
state = static_server.state
ctrl = static_server.controller
state.trame__title = "SPATEO VIEWER"
state.trame__favicon = icon_manager.spateo_logo
state.setdefault("active_ui", None)

# Create the plotter and load the default dataset (`mouse_E95`) into it.
plotter = create_plotter()
anndata_info, actors, actor_names, actor_tree, custom_colors = init_actors(
    plotter=plotter,
    path=local_dataset_manager.mouse_E95,
)

# Default parameter groups, applied in this order.
for _default_params in (
    init_card_parameters,
    init_adata_parameters,
    init_pc_parameters,
    init_mesh_parameters,
    init_morphogenesis_parameters,
    init_interpolation_parameters,
    init_output_parameters,
):
    state.update(_default_params)

# Dataset-dependent state; the first actor starts out as the active one.
_first_actor_name = actor_names[0]
state.update(
    {
        "init_dataset": True,
        "anndata_info": anndata_info,
        "pc_point_size_value": 4,
        "pc_obs_value": "mapped_celltype",
        "available_obs": ["None"] + anndata_info["anndata_obs_keys"],
        "available_genes": ["None"] + anndata_info["anndata_var_index"],
        "pc_colormaps_list": ["spateo_cmap"] + custom_colors + plt.colormaps(),
        # setting
        "actor_ids": actor_names,
        "pipeline": actor_tree,
        "active_id": 1,
        "active_ui": _first_actor_name,
        "active_model_type": str(_first_actor_name).split("_")[0],
        "vis_ids": [
            idx
            for idx, actor in enumerate(plotter.actors.values())
            if actor.visibility
        ],
    }
)

# User-defined parameters are applied only when `custom_func` is enabled.
if init_custom_parameters["custom_func"] is True:
    state.update(init_custom_parameters)
else:
    state.update({"custom_func": False})
88 |
89 | # Upload directory
def open_directory():
    """Let the user pick a directory and record it in `state.selected_dir`.

    A cancelled dialog (empty result) leaves the state untouched. Otherwise the view
    is refreshed through `ctrl.view_update`.
    """
    chosen_dir = filedialog.askdirectory(title="Select Directory")
    if chosen_dir:
        state.selected_dir = chosen_dir
        ctrl.view_update()


# Hidden Tk root window marked topmost, presumably so the native dialog opened by
# `open_directory` appears in front of other windows.
root = Tk()
root.withdraw()
root.wm_attributes("-topmost", 1)
state.selected_dir = "None"  # the literal string "None", not the None object
ctrl.open_directory = open_directory
103 |
104 |
105 | # GUI
ui_standard_layout = ui_layout(server=static_server, template_name="main")
with ui_standard_layout as layout:
    # Report the browser pixel ratio to the server and start in the dark theme.
    trame_widgets.ClientTriggers(
        mounted="pixel_ratio = window.devicePixelRatio, $vuetify.theme.dark = true"
    )

    # Keyword arguments shared by the toolbar, the drawer and the main content.
    _ui_kwargs = dict(server=static_server, layout=layout, plotter=plotter, mode="trame")

    ui_toolbar(**_ui_kwargs, ui_name="SPATEO VIEWER (EXPLORER)")  # toolbar
    ui_drawer(**_ui_kwargs)  # drawer
    ui_container(**_ui_kwargs)  # main content

    # The footer is not used.
    layout.footer.hide()
138 | # layout.flush_content()
139 |
--------------------------------------------------------------------------------
/stviewer/reconstructor_app.py:
--------------------------------------------------------------------------------
1 | try:
2 | from typing import Literal
3 | except ImportError:
4 | from typing_extensions import Literal
5 |
6 | from trame.widgets import trame as trame_widgets
7 | from vtkmodules.web.utils import mesh as vtk_mesh
8 |
9 | from .assets import icon_manager, local_dataset_manager
10 | from .Reconstructor import (
11 | create_plotter,
12 | init_active_parameters,
13 | init_align_parameters,
14 | init_custom_parameters,
15 | init_mesh_parameters,
16 | init_models,
17 | init_picking_parameters,
18 | init_setting_parameters,
19 | ui_container,
20 | ui_drawer,
21 | ui_layout,
22 | ui_toolbar,
23 | )
24 | from .server import get_trame_server
25 |
26 | # export WSLINK_MAX_MSG_SIZE=1000000000 # 1GB
27 |
28 | # Get a Server to work with
# -----------------------------------------------------------------------------
# Server, shared state and controller
# -----------------------------------------------------------------------------
interactive_server = get_trame_server(name="spateo_reconstructor")
state = interactive_server.state
ctrl = interactive_server.controller
state.trame__title = "SPATEO VIEWER"
state.trame__favicon = icon_manager.spateo_logo
state.setdefault("active_ui", None)

# Create the plotter and build the initial models from the default AnnData file.
plotter = create_plotter()
init_anndata_path = local_dataset_manager.drosophila_S11_anndata
main_model, active_model, init_scalar, pdd, cdd = init_models(
    plotter=plotter, anndata_path=init_anndata_path
)

# Default parameter groups, applied in this order.
for _default_params in (
    init_active_parameters,
    init_picking_parameters,
    init_align_parameters,
    init_mesh_parameters,
    init_setting_parameters,
):
    state.update(_default_params)

# Both models are serialized with the same point/cell array names.
# NOTE(review): `init_scalar` returned by `init_models` is unused, and "scalar" is
# hard-coded to "anno_tissue" - confirm this is intended.
_main_mesh = vtk_mesh(
    main_model,
    point_arrays=list(pdd.keys()),
    cell_arrays=list(cdd.keys()),
)
_active_mesh = vtk_mesh(
    active_model,
    point_arrays=list(pdd.keys()),
    cell_arrays=list(cdd.keys()),
)
state.update(
    {
        "init_anndata": init_anndata_path,
        "upload_anndata": None,
        # main model
        "mainModel": _main_mesh,
        "activeModel": _active_mesh,
        "scalar": "anno_tissue",
        "scalarParameters": {**pdd, **cdd},
    }
)

# User-defined parameters are applied only when `custom_func` is enabled.
if init_custom_parameters["custom_func"] is True:
    state.update(init_custom_parameters)
else:
    state.update({"custom_func": False})
72 |
73 |
74 | # GUI
ui_standard_layout = ui_layout(server=interactive_server, template_name="main")
with ui_standard_layout as layout:
    # Report the browser pixel ratio to the server and start in the dark theme.
    trame_widgets.ClientTriggers(
        mounted="pixel_ratio = window.devicePixelRatio, $vuetify.theme.dark = true"
    )

    # Keyword arguments shared by the toolbar, the drawer and the main content.
    _ui_kwargs = {"server": interactive_server, "layout": layout}

    ui_toolbar(**_ui_kwargs, plotter=plotter, ui_name="SPATEO VIEWER (RECONSTRUCTOR)")
    # Presumably calls `ctrl.view_reset_camera` when the client-side `activeModel` changes.
    trame_widgets.ClientStateChange(name="activeModel", change=ctrl.view_reset_camera)
    ui_drawer(**_ui_kwargs)  # drawer
    ui_container(**_ui_kwargs)  # main content

    # The footer is not used.
    layout.footer.hide()
106 | # layout.flush_content()
107 |
--------------------------------------------------------------------------------
/stviewer/server.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from trame.app import get_server
4 |
5 | # -----------------------------------------------------------------------------
6 | # Server
7 | # -----------------------------------------------------------------------------
8 |
9 |
def get_trame_server(name: Optional[str] = None, **kwargs):
    """
    Look up the trame server registered under ``name``, creating it on first use.

    This simply delegates to ``trame.app.get_server``: when no server with the
    given name exists yet a new one is created, otherwise the instance created
    earlier is handed back, so repeated calls with the same name share one server.

    Args:
        name: Identifier of the server; useful when an application needs more than one server.
        kwargs: Extra keyword arguments forwarded unchanged to ``trame.app.get_server``.

    Returns:
        The single Server instance associated with ``name``.
    """
    return get_server(name=name, **kwargs)
25 |
--------------------------------------------------------------------------------
/usage/ExplorerUsage.md:
--------------------------------------------------------------------------------
1 |
2 | ## 10 minutes to static-viewer
3 |
4 | Welcome to static-viewer!
5 |
6 | static-viewer is a web application for visualizing various models created from spatial transcriptomics data in 3D space,
7 | including point cloud models, mesh models, trajectory models, vector field models, etc.
8 |
9 |
10 |
11 |
12 |
13 | ## How to use
14 |
15 | ### Installation
16 | You can clone the [**Spateo-Viewer**](https://github.com/aristoteleo/spateo-viewer) with ``git`` and install dependencies with ``pip``:
17 |
18 | git clone https://github.com/aristoteleo/spateo-viewer.git
19 | cd spateo-viewer
20 | pip install -r requirements.txt
21 |
22 | ### Running (users can change the port)
23 |
24 | python stv_static_app.py --port 1234
25 |
26 | ### Folder Structure of the upload data
27 |
28 | ```
29 | drosophila_E7_8h                                # The folder name
30 | ├── h5ad                                        # (Required) The folder containing the anndata object (.h5ad file)
31 | │   └── E7_8h_cellbin.h5ad                      # (Required) Exactly one anndata object (.h5ad file)
32 | ├── mesh_models                                 # (Optional) The folder containing mesh models (.vtk files)
33 | │   ├── 0_Embryo_E7_8h_aligned_mesh_model.vtk   # (Optional) The filename starts with an ordinal and ends with "_mesh_model.vtk"
34 | │   └── 1_CNS_E7_8h_aligned_mesh_model.vtk      # (Optional) The filename starts with an ordinal and ends with "_mesh_model.vtk"
35 | └── pc_models                                   # (Optional) The folder containing point cloud models (.vtk files)
36 |     ├── 0_Embryo_E7_8h_aligned_pc_model.vtk     # (Optional) The filename starts with an ordinal and ends with "_pc_model.vtk"
37 |     └── 1_CNS_E7_8h_aligned_pc_model.vtk        # (Optional) The filename starts with an ordinal and ends with "_pc_model.vtk"
38 | ```
39 |
40 | You can refer to the folder structure of the data we include by default in the [**dataset**](https://github.com/aristoteleo/spateo-viewer/blob/main/stviewer/assets/dataset).
41 |
42 | ### How to upload data
43 |
44 | 1. Upload a folder via the tools in the web application's toolbar:
45 |
46 | 
47 |
48 | 2. Upload a folder via ``stv_static_app.py``:
49 |
50 | ```
51 | from stviewer.static_app import static_server, state
52 |
53 | if __name__ == "__main__":
54 | state.selected_dir = None
55 | static_server.start()
56 | ```
57 |
58 | Replace ``None`` in ``state.selected_dir = None`` with the absolute path of the folder you want to upload. (This method is recommended when running on a remote server.)
59 |
60 |
61 | ### How to upload custom colors
62 | The colors to be uploaded should be saved in the uploaded anndata object, as a list stored in ``anndata.uns``.
63 | For example:
64 |
65 | Uploaded colors list (the ``anndata.uns`` key of the colors must end with "colors"):
66 |
67 | ```
68 | adata.uns["anno_tissue_list_colors"] = ["#ef9b20", "#f46a9b", "#ece05a", "#ede15b", "#ea5545", "#9a82e0", "#87bc45", "#ec5646", "#bdcf32"]
69 | ```
70 |
--------------------------------------------------------------------------------
/usage/ReconstructorUsage.md:
--------------------------------------------------------------------------------
1 |
2 | ## 10 minutes to interactive-viewer
3 |
4 | Welcome to interactive-viewer!
5 |
6 | interactive-viewer is a web application for the interactive 3D reconstruction of spatial transcriptomics data.
7 |
8 |
9 |
10 |
11 |
12 | ## How to use
13 |
14 | ### Installation
15 | You can clone the [**Spateo-Viewer**](https://github.com/aristoteleo/spateo-viewer) with ``git`` and install dependencies with ``pip``:
16 |
17 | git clone https://github.com/aristoteleo/spateo-viewer.git
18 | cd spateo-viewer
19 | pip install -r requirements.txt
20 |
21 | ### Running (users can change the port)
22 |
23 | python stv_interactive_app.py --port 1234
24 |
25 | ### How to generate the anndata object to upload
26 |
27 | ```
28 | import spateo as st
29 |
30 | # Load the anndata object
31 | adata = st.read_h5ad("E7_8h_cellbin.h5ad")
32 |
33 | # Make sure adata.obsm contains 'spatial' to save the coordinates
34 | adata.obsm['spatial'] = adata.obsm['3d_align_spatial']
35 |
36 | # Interactive-viewer will read all info contained in anndata.obs, so please make sure the info you need has been saved in anndata.obs
37 |
38 | # Save the anndata object
39 | adata.write_h5ad("E7_8h_cellbin.h5ad", compression="gzip")
40 |
41 | ```
42 |
43 | You can refer to the data structure we include by default in the [**dataset**](https://github.com/aristoteleo/spateo-viewer/blob/main/stviewer/assets/dataset/drosophila_E7_8h/pc_models/0_Embryo_E7_8h_aligned_pc_model.vtk).
44 |
45 | ### How to upload data
46 |
47 | 1. Upload a file via the tool in the web application's toolbar:
48 |
49 | 
50 |
51 | 2. Upload a file via ``stv_interactive_app.py``:
52 |
53 | ```
54 | from stviewer.interactive_app import interactive_server, state
55 |
56 | if __name__ == "__main__":
57 | state.upload_anndata = None
58 | interactive_server.start()
59 | ```
60 |
61 | Replace ``None`` in ``state.upload_anndata = None`` with the absolute path of the file you want to upload. (This method is recommended when running on a remote server.)
62 |
63 |
--------------------------------------------------------------------------------
/usage/spateo-viewer.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/usage/spateo-viewer.pdf
--------------------------------------------------------------------------------