├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── spateo-viewer.iml └── vcs.xml ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── datasets_website └── index.html ├── requirements.txt ├── stv_explorer.py ├── stv_reconstructor.py ├── stviewer ├── Explorer │ ├── __init__.py │ ├── pv_pipeline │ │ ├── __init__.py │ │ ├── init_parameters.py │ │ ├── pv_actors.py │ │ ├── pv_callback.py │ │ ├── pv_custom.py │ │ ├── pv_interpolation.py │ │ ├── pv_morphogenesis.py │ │ └── pv_plotter.py │ └── ui │ │ ├── __init__.py │ │ ├── container.py │ │ ├── drawer │ │ ├── __init__.py │ │ ├── adata_obj.py │ │ ├── custom_card.py │ │ ├── main.py │ │ ├── model_mesh.py │ │ ├── model_point.py │ │ ├── morphogenesis.py │ │ ├── output.py │ │ └── pipeline.py │ │ ├── layout.py │ │ ├── toolbar.py │ │ └── utils.py ├── Reconstructor │ ├── __init__.py │ ├── pv_pipeline │ │ ├── __init__.py │ │ ├── alignment_utils.py │ │ ├── init_parameters.py │ │ ├── pv_alignment.py │ │ ├── pv_callback.py │ │ ├── pv_custom.py │ │ ├── pv_models.py │ │ ├── pv_plotter.py │ │ └── pv_tdr.py │ └── ui │ │ ├── __init__.py │ │ ├── container.py │ │ ├── drawer │ │ ├── __init__.py │ │ ├── alignment.py │ │ ├── custom_card.py │ │ ├── main.py │ │ ├── model_point.py │ │ └── reconstruction.py │ │ ├── layout.py │ │ ├── toolbar.py │ │ └── utils.py ├── __init__.py ├── assets │ ├── __init__.py │ ├── anndata_preprocess.py │ ├── dataset │ │ ├── drosophila_S11 │ │ │ ├── h5ad │ │ │ │ └── S11_cellbin_demo.h5ad │ │ │ ├── mesh_models │ │ │ │ ├── 0_Embryo_S11_aligned_mesh_model.vtk │ │ │ │ ├── 1_CNS_S11_aligned_mesh_model.vtk │ │ │ │ ├── 2_Midgut_S11_aligned_mesh_model.vtk │ │ │ │ ├── 3_Hindgut_S11_aligned_mesh_model.vtk │ │ │ │ ├── 4_Muscle_S11_aligned_mesh_model.vtk │ │ │ │ ├── 5_SalivaryGland_S11_aligned_mesh_model.vtk │ │ │ │ └── 6_Amnioserosa_S11_aligned_mesh_model.vtk │ │ │ └── pc_models │ │ │ │ ├── 0_Embryo_S11_aligned_pc_model.vtk │ │ │ │ ├── 1_CNS_S11_aligned_pc_model.vtk │ │ 
│ │ ├── 2_Midgut_S11_aligned_pc_model.vtk │ │ │ │ ├── 3_Hindgut_S11_aligned_pc_model.vtk │ │ │ │ ├── 4_Muscle_S11_aligned_pc_model.vtk │ │ │ │ ├── 5_SalivaryGland_S11_aligned_pc_model.vtk │ │ │ │ └── 6_Amnioserosa_S11_aligned_pc_model.vtk │ │ └── mouse_E95 │ │ │ ├── h5ad │ │ │ └── mouse_E95_demo.h5ad │ │ │ ├── matrices │ │ │ └── X_sparse_matrix.npz │ │ │ ├── mesh_models │ │ │ └── 0_Embryo_mouse_E95_mesh_model.vtk │ │ │ └── pc_models │ │ │ └── 0_Embryo_mouse_E95_pc_model.vtk │ ├── dataset_acquisition.py │ ├── dataset_manager.py │ ├── image │ │ ├── interactive_viewer.png │ │ ├── spateo_logo.png │ │ ├── spateoviewer.png │ │ ├── static_viewer.png │ │ ├── upload_file.png │ │ └── upload_folder.png │ └── image_manager.py ├── explorer_app.py ├── reconstructor_app.py └── server.py └── usage ├── ExplorerUsage.md ├── ReconstructorUsage.md └── spateo-viewer.pdf /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/spateo-viewer.iml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/python/black 3 | rev: 23.3.0 4 | hooks: 5 | - id: black 6 | 7 | - repo: https://github.com/pycqa/isort 8 | rev: 5.12.0 9 | hooks: 10 | - id: isort 11 | args: ["--profile", "black", "--filter-files"] 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2023, Aristotle 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | ## Spateo-viewer: the "Google earth" browser of spatial transcriptomics 6 | 7 | [**Spateo-viewer**](https://github.com/aristoteleo/spateo-viewer) is the “Google earth” of spatial transcriptomics. 8 | Relying on a set of powerful libraries and tools in the Python ecosystem, such as Trame, PyVista, and VTK, it delivers 9 | a complete web application for creating a convenient, vivid, and lightweight interface for the 3D reconstruction and 10 | visualization of [**Spateo**](https://github.com/aristoteleo/spateo-release) downstream analysis results. Currently, 11 | Spateo-viewer includes two major applications, ***Explorer*** and ***Reconstructor***, which are respectively 12 | dedicated to the visualization of spatial transcriptomics analysis results and the 3D reconstruction of spatial transcriptomics data. 13 | 14 | Please download and read the corresponding [**Slides**](https://github.com/aristoteleo/spateo-viewer/blob/main/usage/spateo-viewer.pdf) to learn more about Spateo-viewer. 15 | 16 | ## Highlights 17 | 18 | * In the ***Reconstructor***, serial slices of spatial transcriptomics datasets can be aligned to create 3D models. The 3D models can also be cleaned up by freely clipping and editing them. 19 | * In the ***Explorer***, users can not only visualize gene expression, but also easily switch between raw data, different types of normalized data, and other data layers. Users can also visualize all cell annotation information, such as cell size, cell type, and tissue type. All in 3D space! 20 | * The static viewer (***Explorer***) allows users to visualize the point cloud and mesh models not only of the whole embryo but also of individual organs or tissue types at the same time. It can even visualize a morphogenesis vector field model to animate how cells move in physical 3D space. 21 | * Spateo-viewer can run not only on a local computer but also on a remote server.
22 | * Users can upload custom files in the web application, or access custom files in local folders when running Spateo-viewer as a standalone app (see [**ExplorerUsage**](https://github.com/aristoteleo/spateo-viewer/blob/main/usage/ExplorerUsage.md) or [**ReconstructorUsage**](https://github.com/aristoteleo/spateo-viewer/blob/main/usage/ReconstructorUsage.md)). 23 | 24 | ## Installation 25 | 26 | You can clone [**Spateo-viewer**](https://github.com/aristoteleo/spateo-viewer) with ``git`` and install its dependencies with ``pip``: 27 | 28 | git clone https://github.com/aristoteleo/spateo-viewer.git 29 | cd spateo-viewer 30 | pip install -r requirements.txt 31 | 32 | ## Usage 33 | 34 | #### Run the *Explorer* application: 35 | 36 | python stv_explorer.py --port 1234 37 | 38 | See the [**ExplorerUsage**](https://github.com/aristoteleo/spateo-viewer/blob/main/usage/ExplorerUsage.md) for more details. 39 | 40 | #### Run the *Reconstructor* application: 41 | 42 | python stv_reconstructor.py --port 1234 43 | 44 | See the [**ReconstructorUsage**](https://github.com/aristoteleo/spateo-viewer/blob/main/usage/ReconstructorUsage.md) for more details. 45 | 46 | ## Sample Datasets 47 | 48 | #### [**Mouse E9.5 dataset**](https://github.com/aristoteleo/spateo-viewer/tree/main/stviewer/assets/dataset/mouse_E95): 49 | - **h5ad/mouse_E95_demo.h5ad**: Single-cell resolution Stereo-seq data with alignment and cell annotation by the Spateo team. This dataset contains only 1000 highly variable genes. If you need the raw data, please visit the [CNGB website](https://db.cngb.org/stomics/mosta/download/). 50 | - **matrices**: Contains various gene expression matrices for the .h5ad data. 51 | - **mesh_models**: Contains the mesh model of the mouse embryo. 52 | - **pc_models**: Contains the point cloud model of the mouse embryo.
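Both sample datasets share one folder layout: a single `.h5ad` file under `h5ad/`, plus `.vtk` files under `pc_models/` and `mesh_models/` whose names begin with a numeric index (`0_Embryo_...`, `1_CNS_...`). As a minimal sketch that assumes only this layout, the hypothetical helper `collect_sample_dataset` below gathers those paths and orders the models by their numeric prefix. It is not the loader Spateo-viewer itself uses (the app calls `sample_dataset` from `stviewer.assets`, whose code is not shown here).

```python
import re
from pathlib import Path

# Matches the numeric index at the start of a model file name,
# e.g. "3_" in "3_Hindgut_S11_aligned_pc_model.vtk".
_PREFIX = re.compile(r"^(\d+)_")


def _model_sort_key(path):
    """Sort by the leading integer; files without a prefix go last, then by name."""
    match = _PREFIX.match(path.name)
    index = int(match.group(1)) if match else float("inf")
    return (index, path.name)


def _collect_models(folder):
    """Return (paths, ids) for the .vtk files in `folder`, or empty lists if it is missing."""
    if not folder.is_dir():
        return [], []
    paths = sorted((p for p in folder.iterdir() if p.suffix == ".vtk"), key=_model_sort_key)
    # The model id is the file name without its extension.
    return paths, [p.stem for p in paths]


def collect_sample_dataset(path):
    """Hypothetical helper: gather the files of one sample dataset folder.

    Assumes the h5ad/, pc_models/ and mesh_models/ layout described above.
    """
    root = Path(path)
    h5ad_dir = root / "h5ad"
    h5ad_files = sorted(h5ad_dir.glob("*.h5ad")) if h5ad_dir.is_dir() else []
    pc_paths, pc_ids = _collect_models(root / "pc_models")
    mesh_paths, mesh_ids = _collect_models(root / "mesh_models")
    return {
        "h5ad": h5ad_files[0] if h5ad_files else None,
        "pc_models": pc_paths,
        "pc_model_ids": pc_ids,
        "mesh_models": mesh_paths,
        "mesh_model_ids": mesh_ids,
    }
```

The returned paths could then be opened with `anndata.read_h5ad` and `pyvista.read`, e.g. after calling `collect_sample_dataset("stviewer/assets/dataset/drosophila_S11")`. Sorting by the integer prefix rather than alphabetically keeps a model named `10_...` after `2_...`.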
53 | 54 | #### [**Drosophila S11 dataset**](https://github.com/aristoteleo/spateo-viewer/tree/main/stviewer/assets/dataset/drosophila_S11): 55 | - **h5ad/S11_cellbin_demo.h5ad**: Single-cell resolution Stereo-seq data with alignment and cell annotation by the Spateo team. This dataset contains only 1000 highly variable genes. 56 | - **mesh_models**: Contains mesh models of the *Drosophila* embryo and various organs. 57 | - **pc_models**: Contains point cloud models of the *Drosophila* embryo and various organs. 58 | 59 | ## Citation 60 | 61 | [Spatiotemporal modeling of molecular holograms](https://www.cell.com/cell/fulltext/S0092-8674(24)01159-0) 62 | 63 | Xiaojie Qiu1, 7, 8\$\*, Daniel Y. Zhu3\$, Yifan Lu1, 7, 8, 9\$, Jiajun Yao2, 4, 10\$, Zehua Jing2, 4, 11\$, Kyung Hoi (Joseph) Min12\$, Mengnan Cheng2,6\$, Hailin Pan6, Lulu Zuo6, Samuel King13, Qi Fang2, 6, Huiwen Zheng2, 11, Mingyue Wang2, 14, Shuai Wang2, 11, Qingquan Zhang25, Sichao Yu5, Sha Liao6, 17, 18, Chao Liu15, Xinchao Wu2, 4, 16, Yiwei Lai6, Shijie Hao2, Zhewei Zhang2, 4, 16, Liang Wu18, Yong Zhang15, Mei Li17, Zhencheng Tu2, 11, Jinpei Lin2, 4, Zhuoxuan Yang2, 16, Yuxiang Li15, Ying Gu2, 6, 11, Ao Chen6, 17, 18, Longqi Liu2, 19, 20, Jonathan S. Weissman5, 22, 23, Jiayi Ma9*, Xun Xu2, 11, 21*, Shiping Liu2, 19, 20, 24*, Yinqi Bai4, 26* 64 | 65 | \$: Co-first authors; \*: Corresponding authors 66 | -------------------------------------------------------------------------------- /datasets_website/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Spateo Datasets 6 | 7 | 44 | 45 | 46 | 56 |
57 | logo 58 |
59 |
60 |
61 | 74 |
75 |

Stereo-seq E9.5

76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 90 | 91 | 92 | 93 | 98 | 99 | 100 | 101 | 105 | 106 | 107 | 108 | 111 | 112 | 113 | 114 |
NameSize
86 | 87 | E9.5_metadata.csv 88 | 89 | 63.9 MB
94 | 95 | Shendure_E9.5_ref.h5ad 96 | 97 | 15.77 GB
102 | E9.5_full_final.h5ad 103 | 104 | 3.01 GB
109 | E9.5_cell_types_to_colors_final.json 110 | 2 KB
115 |

Stereo-seq E11.5

116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 130 | 131 | 132 | 133 | 138 | 139 | 140 | 141 | 146 | 147 | 148 | 149 | 152 | 153 | 154 | 155 | 158 | 159 | 160 | 161 | 164 | 165 | 166 | 167 |
NameSize
126 | 127 | E11.5_metadata_full.csv 128 | 129 | 971.2 MB
134 | 135 | E11.5_ZLI_final.h5ad 136 | 137 | 3.93 GB
142 | 143 | E11.5_spinal_cord_final.h5ad 144 | 145 | 5.58 GB
150 | E11.5_full_final.h5ad 151 | 64.46 GB
156 | E11.5_diencephalic_ring_final.h5ad 157 | 1.53 GB
162 | E11.5_cell_types_to_colors_final.json 163 | 3 KB
168 |
169 |
170 | 171 | 172 | 173 |
174 | 177 | 178 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anndata>=0.8.0 2 | av>=11.0.0 3 | imageio>=2.33.1 4 | imageio-ffmpeg>=0.4.9 5 | gpytorch>=1.11 6 | numpy>=1.18.1 7 | matplotlib>=3.5.3 8 | POT>=0.8.1 9 | pyacvd>=0.2.9 10 | pyautogui>=0.9.54 11 | PyMCubes>=0.1.4 12 | pymeshfix>=0.16.2 13 | pyvista==0.40.0 14 | trame==2.5.2 15 | trame-server>=2.15.0 -------------------------------------------------------------------------------- /stv_explorer.py: -------------------------------------------------------------------------------- 1 | import getopt 2 | import sys 3 | 4 | from stviewer.explorer_app import state, static_server 5 | 6 | if __name__ == "__main__": 7 | # no local data directory is selected at startup 8 | state.selected_dir = None 9 | 10 | opts, args = getopt.getopt(sys.argv[1:], "p:", ["port="]) 11 | port = 1234 if len(opts) == 0 else int(opts[0][1]) 12 | static_server.start(port=port) 13 | -------------------------------------------------------------------------------- /stv_reconstructor.py: -------------------------------------------------------------------------------- 1 | import getopt 2 | import sys 3 | 4 | from stviewer.reconstructor_app import interactive_server, state 5 | 6 | if __name__ == "__main__": 7 | # no AnnData is uploaded at startup 8 | state.upload_anndata = None 9 | opts, args = getopt.getopt(sys.argv[1:], "p:", ["port="]) 10 | port = 8888 if len(opts) == 0 else int(opts[0][1]) 11 | interactive_server.start(port=port) 12 | -------------------------------------------------------------------------------- /stviewer/Explorer/__init__.py: -------------------------------------------------------------------------------- 1 | from .pv_pipeline import * 2 | from .ui import * 3 | -------------------------------------------------------------------------------- /stviewer/Explorer/pv_pipeline/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .init_parameters import * 2 | from .pv_actors import generate_actors, generate_actors_tree, init_actors 3 | from .pv_callback import PVCB, SwitchModels, Viewer, vuwrap 4 | from .pv_plotter import add_single_model, create_plotter 5 | -------------------------------------------------------------------------------- /stviewer/Explorer/pv_pipeline/init_parameters.py: -------------------------------------------------------------------------------- 1 | import matplotlib.colors as mcolors 2 | 3 | # Init parameters 4 | init_card_parameters = { 5 | "show_anndata_card": False, 6 | "show_model_card": True, 7 | "show_output_card": True, 8 | } 9 | init_adata_parameters = { 10 | "uploaded_anndata_path": None, 11 | } 12 | init_pc_parameters = { 13 | "pc_obs_value": None, 14 | "pc_gene_value": None, 15 | "pc_scalars_raw": {"None": "None"}, 16 | "pc_matrix_value": "X", 17 | "pc_coords_value": "spatial", 18 | "pc_opacity_value": 1.0, 19 | "pc_ambient_value": 0.2, 20 | "pc_color_value": None, 21 | "pc_colormap_value": "Spectral", 22 | "pc_point_size_value": 4, 23 | "pc_add_legend": False, 24 | "pc_picking_group": None, 25 | "pc_overwrite": False, 26 | "pc_reload": False, 27 | "pc_colors_list": [c for c in mcolors.CSS4_COLORS.keys()], 28 | } 29 | init_mesh_parameters = { 30 | "mesh_opacity_value": 0.2, 31 | "mesh_ambient_value": 0.2, 32 | "mesh_color_value": "gainsboro", 33 | "mesh_style_value": "surface", 34 | "mesh_morphology": False, 35 | "mesh_colors_list": [c for c in mcolors.CSS4_COLORS.keys()], 36 | } 37 | init_morphogenesis_parameters = { 38 | "cal_morphogenesis": False, 39 | "morpho_target_anndata_path": None, 40 | "morpho_uploaded_target_anndata_path": None, 41 | "morpho_mapping_method": "GP", 42 | "morpho_mapping_device": "cpu", 43 | "morpho_mapping_factor": 0.2, 44 | "morphofield_factor": 3000, 45 | "morphopath_t_end": 10000, 46 | 
"morphopath_downsampling": 500, 47 | 
"morphofield_visibile": False, 48 | "morphopath_visibile": False, 49 | "morphopath_predicted_models": None, 50 | "morphopath_animation_path": None, 51 | } 52 | init_interpolation_parameters = { 53 | "cal_interpolation": False, 54 | "interpolation_device": "cpu", 55 | } 56 | init_output_parameters = { 57 | "screenshot_path": None, 58 | "animation_path": None, 59 | "animation_npoints": 50, 60 | "animation_framerate": 10, 61 | } 62 | 63 | # Custom init parameters 64 | init_custom_parameters = { 65 | "custom_func": False, 66 | "custom_analysis": False, 67 | "custom_model": None, 68 | "custom_model_visible": False, 69 | "custom_parameter1": "X", 70 | "custom_parameter2": "recipe_monocle", 71 | "custom_parameter3": "pca", 72 | "custom_parameter4": "umap", 73 | "custom_parameter5": False, 74 | "custom_parameter6": "None", 75 | "custom_parameter7": 30, 76 | "custom_parameter8": 30, 77 | "custom_parameter9": "None", 78 | "custom_parameter10": 1, 79 | } 80 | -------------------------------------------------------------------------------- /stviewer/Explorer/pv_pipeline/pv_actors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | warnings.filterwarnings("ignore") 5 | 6 | from .pv_plotter import add_single_model 7 | 8 | try: 9 | from typing import Literal 10 | except ImportError: 11 | from typing_extensions import Literal 12 | 13 | from typing import Optional 14 | 15 | from stviewer.assets import sample_dataset 16 | 17 | 18 | def generate_actors( 19 | plotter, 20 | pc_models: Optional[list] = None, 21 | mesh_models: Optional[list] = None, 22 | pc_model_names: Optional[list] = None, 23 | mesh_model_names: Optional[list] = None, 24 | ): 25 | # Generate actors for pc models 26 | pc_kwargs = dict(model_style="points", model_size=5) 27 | if not (pc_models is None): 28 | pc_actors = [ 29 | add_single_model( 30 | plotter=plotter, model=model, model_name=model_name, **pc_kwargs 31 | ) 32 | for model, model_name
in zip(pc_models, pc_model_names) 33 | ] 34 | else: 35 | pc_actors = None 36 | 37 | # Generate actors for mesh models 38 | mesh_kwargs = dict(opacity=0.2, model_style="surface", color="gainsboro") 39 | if not (mesh_models is None): 40 | mesh_actors = [ 41 | add_single_model( 42 | plotter=plotter, model=model, model_name=model_name, **mesh_kwargs 43 | ) 44 | for model, model_name in zip(mesh_models, mesh_model_names) 45 | ] 46 | else: 47 | mesh_actors = None 48 | return pc_actors, mesh_actors 49 | 50 | 51 | def standard_tree(actors: list, base_id: int = 0): 52 | actor_tree, actor_names = [], [] 53 | for i, actor in enumerate(actors): 54 | if i == 0: 55 | actor.SetVisibility(True) 56 | else: 57 | actor.SetVisibility(False) 58 | actor_names.append(str(actor.name)) 59 | actor_tree.append( 60 | { 61 | "id": str(base_id + 1 + i), 62 | "parent": str(0) if i == 0 else str(base_id + 1), 63 | "visible": True if i == 0 else False, 64 | "name": str(actor.name), 65 | } 66 | ) 67 | 68 | return actors, actor_names, actor_tree 69 | 70 | 71 | def generate_actors_tree( 72 | pc_actors: Optional[list] = None, 73 | mesh_actors: Optional[list] = None, 74 | ): 75 | if not (pc_actors is None): 76 | pc_actors, pc_actor_names, pc_tree = standard_tree(actors=pc_actors, base_id=0) 77 | else: 78 | pc_actors, pc_actor_names, pc_tree = [], [], [] 79 | 80 | if not (mesh_actors is None): 81 | mesh_actors, mesh_actor_names, mesh_tree = standard_tree( 82 | actors=mesh_actors, 83 | base_id=0 if pc_actors is None else len(pc_actors), 84 | ) 85 | else: 86 | mesh_actors, mesh_actor_names, mesh_tree = [], [], [] 87 | 88 | actors = pc_actors + mesh_actors 89 | actor_names = pc_actor_names + mesh_actor_names 90 | actor_tree = pc_tree + mesh_tree 91 | return actors, actor_names, actor_tree 92 | 93 | 94 | def init_actors(plotter, path): 95 | ( 96 | anndata_info, 97 | pc_models, 98 | pc_model_ids, 99 | mesh_models, 100 | mesh_model_ids, 101 | custom_colors, 102 | ) = sample_dataset(path=path) 103 | 104 | # 
Generate actors 105 | pc_actors, mesh_actors = generate_actors( 106 | plotter=plotter, 107 | pc_models=pc_models, 108 | pc_model_names=pc_model_ids, 109 | mesh_models=mesh_models, 110 | mesh_model_names=mesh_model_ids, 111 | ) 112 | 113 | # Generate the relationship tree of actors 114 | actors, actor_names, actor_tree = generate_actors_tree( 115 | pc_actors=pc_actors, 116 | mesh_actors=mesh_actors, 117 | ) 118 | 119 | return ( 120 | anndata_info, 121 | actors, 122 | actor_names, 123 | actor_tree, 124 | custom_colors, 125 | ) 126 | -------------------------------------------------------------------------------- /stviewer/Explorer/pv_pipeline/pv_custom.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional, Union 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pyvista as pv 6 | from anndata import AnnData 7 | from pyvista import PolyData 8 | 9 | 10 | def RNAvelocity( 11 | adata: AnnData, 12 | pc_model: PolyData, 13 | layer: str = "X", 14 | basis_pca: str = "pca", 15 | basis_umap: str = "umap", 16 | data_preprocess: Literal[ 17 | "False", "recipe_monocle", "pearson_residuals" 18 | ] = "recipe_monocle", 19 | harmony_debatch: bool = False, 20 | group_key: Optional[str] = None, 21 | n_neighbors: int = 30, 22 | n_pca_components: int = 30, 23 | n_vectors_downsampling: Optional[int] = None, 24 | vectors_size: Union[float, int] = 1, 25 | ): 26 | try: 27 | import dynamo as dyn 28 | from dynamo.preprocessing import Preprocessor 29 | from dynamo.tools.Markov import velocity_on_grid 30 | except ImportError: 31 | raise ImportError( 32 | "You need to install the package `dynamo`. " 33 | "\nInstall dynamo via `pip install dynamo-release`." 34 | ) 35 | try: 36 | import harmonypy 37 | except ImportError: 38 | raise ImportError( 39 | "You need to install the package `harmonypy`. " 40 | "\nInstall harmonypy via `pip install harmonypy`." 
41 | ) 42 | 43 | # Preprocess 44 | _obs_index = pc_model.point_data["obs_index"] 45 | dyn_adata = adata[_obs_index, :].copy() 46 | dyn_adata.X = dyn_adata.X if layer == "X" else dyn_adata.layers[layer] 47 | 48 | # Data preprocess 49 | if basis_pca in dyn_adata.obsm.keys(): 50 | if data_preprocess != "False": 51 | preprocessor = Preprocessor() 52 | preprocessor.preprocess_adata(dyn_adata, recipe="monocle" if data_preprocess == "recipe_monocle" else "pearson_residuals") 53 | dyn.tl.reduceDimension( 54 | dyn_adata, 55 | basis=basis_pca, 56 | n_neighbors=n_neighbors, 57 | n_pca_components=n_pca_components, 58 | ) 59 | else: 60 | if data_preprocess != "False": 61 | if data_preprocess == "recipe_monocle": 62 | preprocessor = Preprocessor() 63 | preprocessor.preprocess_adata(dyn_adata, recipe="monocle") 64 | elif data_preprocess == "pearson_residuals": 65 | preprocessor = Preprocessor() 66 | preprocessor.preprocess_adata(dyn_adata, recipe="pearson_residuals") 67 | dyn.tl.reduceDimension( 68 | dyn_adata, 69 | basis=basis_pca, 70 | n_neighbors=n_neighbors, 71 | n_pca_components=n_pca_components, 72 | ) 73 | if harmony_debatch: 74 | harmony_out = harmonypy.run_harmony( 75 | dyn_adata.obsm[basis_pca], dyn_adata.obs, group_key, max_iter_harmony=20 76 | ) 77 | dyn_adata.obsm[basis_pca] = harmony_out.Z_corr.T 78 | dyn.tl.reduceDimension( 79 | dyn_adata, 80 | X_data=dyn_adata.obsm[basis_pca], 81 | enforce=True, 82 | n_neighbors=n_neighbors, 83 | n_pca_components=n_pca_components, 84 | ) 85 | 86 | # RNA velocity 87 | if basis_umap in dyn_adata.obsm.keys(): 88 | if f"X_{basis_umap}" not in dyn_adata.obsm.keys(): 89 | dyn_adata.obsm[f"X_{basis_umap}"] = dyn_adata.obsm[basis_umap] 90 | 91 | dyn.tl.dynamics(dyn_adata, model="stochastic", cores=3) 92 | dyn.tl.cell_velocities( 93 | dyn_adata, 94 | basis=basis_umap, 95 | method="pearson", 96 | other_kernels_dict={"transform": "sqrt"}, 97 | ) 98 | dyn.tl.cell_velocities(dyn_adata, basis=basis_pca) 99 | 100 | # 
Vectorfield 101 | dyn.vf.VectorField(dyn_adata, basis=basis_umap) 102 | dyn.vf.VectorField(dyn_adata,
basis=basis_pca) 103 | 104 | # Pseudotime 105 | dyn.ext.ddhodge(dyn_adata, basis=basis_umap) 106 | dyn.ext.ddhodge(dyn_adata, basis=basis_pca) 107 | 108 | # Differential geometry 109 | dyn.vf.speed(dyn_adata, basis=basis_pca) 110 | dyn.vf.acceleration(dyn_adata, basis=basis_pca) 111 | dyn.vf.curvature(dyn_adata, basis=basis_pca) 112 | dyn.vf.curl(dyn_adata, basis=basis_pca) 113 | dyn.vf.divergence(dyn_adata, basis=basis_pca) 114 | 115 | # RNA velocity vectors model 116 | if n_vectors_downsampling in [None, "None", "none"]: 117 | ix_choice = np.arange(dyn_adata.shape[0]) 118 | else: 119 | ix_choice = np.random.choice( 120 | np.arange(dyn_adata.shape[0]), size=int(n_vectors_downsampling), replace=False 121 | ) 122 | 123 | X = dyn_adata.obsm[f"X_{basis_umap}"][:, [0, 1]] 124 | V = dyn_adata.obsm[f"velocity_{basis_umap}"][:, [0, 1]] 125 | X, V = X[ix_choice, :], V[ix_choice, :] 126 | if X.shape[1] == 2: 127 | df = pd.DataFrame( 128 | { 129 | "x": X[:, 0], 130 | "y": X[:, 1], 131 | "z": np.zeros(shape=(X.shape[0])), 132 | "u": V[:, 0], 133 | "v": V[:, 1], 134 | "w": np.zeros(shape=(V.shape[0])), 135 | } 136 | ) 137 | else: 138 | df = pd.DataFrame( 139 | { 140 | "x": X[:, 0], 141 | "y": X[:, 1], 142 | "z": X[:, 2], 143 | "u": V[:, 0], 144 | "v": V[:, 1], 145 | "w": V[:, 2], 146 | } 147 | ) 148 | # X and V were already subset by ix_choice above, so df is not indexed again 149 | 150 | x0, x1, x2 = df.iloc[:, 0], df.iloc[:, 1], df.iloc[:, 2] 151 | v0, v1, v2 = df.iloc[:, 3], df.iloc[:, 4], df.iloc[:, 5] 152 | 153 | point_cloud = pv.PolyData(np.column_stack((x0.values, x1.values, x2.values))) 154 | point_cloud["vectors"] = np.column_stack((v0.values, v1.values, v2.values)) 155 | point_cloud.point_data["obs_index"] = dyn_adata.obs.index[ix_choice].tolist() 156 | vectors = point_cloud.glyph(orient="vectors", factor=vectors_size) 157 | vectors.point_data[f"{basis_umap}_ddhodge_potential"] = np.asarray( 158 | dyn_adata[np.asarray(vectors.point_data["obs_index"])].obs[ 159 | 
f"{basis_umap}_ddhodge_potential" 160 | ] 161 | ) 162 |
vectors.point_data[f"{basis_pca}_ddhodge_potential"] = np.asarray( 163 | dyn_adata[np.asarray(vectors.point_data["obs_index"])].obs[ 164 | f"{basis_pca}_ddhodge_potential" 165 | ] 166 | ) 167 | vectors.point_data[f"speed_{basis_pca}"] = np.asarray( 168 | dyn_adata[np.asarray(vectors.point_data["obs_index"])].obs[f"speed_{basis_pca}"] 169 | ) 170 | vectors.point_data[f"acceleration_{basis_pca}"] = np.asarray( 171 | dyn_adata[np.asarray(vectors.point_data["obs_index"])].obs[ 172 | f"acceleration_{basis_pca}" 173 | ] 174 | ) 175 | vectors.point_data[f"divergence_{basis_pca}"] = np.asarray( 176 | dyn_adata[np.asarray(vectors.point_data["obs_index"])].obs[ 177 | f"divergence_{basis_pca}" 178 | ] 179 | ) 180 | vectors.point_data[f"curvature_{basis_pca}"] = np.asarray( 181 | dyn_adata[np.asarray(vectors.point_data["obs_index"])].obs[ 182 | f"curvature_{basis_pca}" 183 | ] 184 | ) 185 | vectors.point_data[f"curl_{basis_pca}"] = np.asarray( 186 | dyn_adata[np.asarray(vectors.point_data["obs_index"])].obs[f"curl_{basis_pca}"] 187 | ) 188 | 189 | pc_model.point_data[f"{basis_umap}_ddhodge_potential"] = np.asarray( 190 | dyn_adata[np.asarray(pc_model.point_data["obs_index"])].obs[ 191 | f"{basis_umap}_ddhodge_potential" 192 | ] 193 | ) 194 | pc_model.point_data[f"{basis_pca}_ddhodge_potential"] = np.asarray( 195 | dyn_adata[np.asarray(pc_model.point_data["obs_index"])].obs[ 196 | f"{basis_pca}_ddhodge_potential" 197 | ] 198 | ) 199 | pc_model.point_data[f"speed_{basis_pca}"] = np.asarray( 200 | dyn_adata[np.asarray(pc_model.point_data["obs_index"])].obs[ 201 | f"speed_{basis_pca}" 202 | ] 203 | ) 204 | pc_model.point_data[f"acceleration_{basis_pca}"] = np.asarray( 205 | dyn_adata[np.asarray(pc_model.point_data["obs_index"])].obs[ 206 | f"acceleration_{basis_pca}" 207 | ] 208 | ) 209 | pc_model.point_data[f"divergence_{basis_pca}"] = np.asarray( 210 | dyn_adata[np.asarray(pc_model.point_data["obs_index"])].obs[ 211 | f"divergence_{basis_pca}" 212 | ] 213 | ) 214 | 
pc_model.point_data[f"curvature_{basis_pca}"] = np.asarray( 215 | dyn_adata[np.asarray(pc_model.point_data["obs_index"])].obs[ 216 | f"curvature_{basis_pca}" 217 | ] 218 | ) 219 | pc_model.point_data[f"curl_{basis_pca}"] = np.asarray( 220 | dyn_adata[np.asarray(pc_model.point_data["obs_index"])].obs[f"curl_{basis_pca}"] 221 | ) 222 | 223 | return pc_model, vectors 224 | -------------------------------------------------------------------------------- /stviewer/Explorer/pv_pipeline/pv_interpolation.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from tqdm import tqdm 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | # GP model 11 | import gpytorch 12 | import numpy as np 13 | import ot 14 | import pandas as pd 15 | import torch 16 | from anndata import AnnData 17 | from gpytorch.likelihoods import GaussianLikelihood 18 | from gpytorch.models import ApproximateGP, ExactGP 19 | from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy 20 | from numpy import ndarray 21 | from scipy.sparse import issparse 22 | 23 | 24 | class Approx_GPModel(ApproximateGP): 25 | def __init__(self, inducing_points): 26 | variational_distribution = CholeskyVariationalDistribution( 27 | inducing_points.size(0) 28 | ) 29 | variational_strategy = VariationalStrategy( 30 | self, 31 | inducing_points, 32 | variational_distribution, 33 | learn_inducing_locations=True, 34 | ) 35 | super(Approx_GPModel, self).__init__(variational_strategy) 36 | self.mean_module = gpytorch.means.ConstantMean() 37 | self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) 38 | 39 | def forward(self, x): 40 | mean_x = self.mean_module(x) 41 | covar_x = self.covar_module(x) 42 | return gpytorch.distributions.MultivariateNormal(mean_x, covar_x) 43 | 44 | 45 | class Exact_GPModel(ExactGP): 46 | def __init__(self, train_x, 
train_y, likelihood): 47 | super().__init__(train_x, train_y, likelihood) 48 | self.mean_module = gpytorch.means.ZeroMean() 49 | self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) 50 | 51 | def forward(self, x): 52 | mean_x = self.mean_module(x) 53 | covar_x = self.covar_module(x) 54 | return gpytorch.distributions.MultivariateNormal(mean_x, covar_x) 55 | 56 | 57 | def gp_train(model, likelihood, train_loader, train_epochs, method, N, device): 58 | if torch.cuda.is_available() and device != "cpu": 59 | model = model.cuda() 60 | likelihood = likelihood.cuda() 61 | 62 | model.train() 63 | likelihood.train() 64 | # define the mll (loss) 65 | if method == "SVGP": 66 | mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=N) 67 | else: 68 | mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model) 69 | 70 | optimizer = torch.optim.Adam( 71 | [ 72 | {"params": model.parameters()}, 73 | {"params": likelihood.parameters()}, 74 | ], 75 | lr=0.01, 76 | ) 77 | 78 | epochs_iter = tqdm(range(train_epochs), desc="Epoch") 79 | for i in epochs_iter: 80 | if method == "SVGP": 81 | # Within each iteration, we will go over each minibatch of data 82 | minibatch_iter = tqdm(train_loader, desc="Minibatch", leave=True) 83 | for x_batch, y_batch in minibatch_iter: 84 | optimizer.zero_grad() 85 | output = model(x_batch) 86 | loss = -mll(output, y_batch) 87 | minibatch_iter.set_postfix(loss=loss.item()) 88 | loss.backward() 89 | optimizer.step() 90 | else: 91 | # Zero gradients from previous iteration 92 | optimizer.zero_grad() 93 | # Output from model 94 | output = model(train_loader["train_x"]) 95 | # Calc loss and backprop gradients 96 | loss = -mll(output, train_loader["train_y"]) 97 | loss.backward() 98 | optimizer.step() 99 | 100 | 101 | nx_torch = lambda nx: True if isinstance(nx, ot.backend.TorchBackend) else False 102 | _chunk = ( 103 | lambda nx, x, chunk_num, dim: torch.chunk(x, chunk_num, dim=dim) 104 | if nx_torch(nx) 105 | else 
np.array_split(x, chunk_num, axis=dim) 106 | ) 107 | _unsqueeze = lambda nx: torch.unsqueeze if nx_torch(nx) else np.expand_dims 108 | 109 | 110 | class Imputation_GPR: 111 | def __init__( 112 | self, 113 | source_adata: AnnData, 114 | target_points: Optional[ndarray] = None, 115 | keys: Union[str, list] = None, 116 | spatial_key: str = "spatial", 117 | layer: str = "X", 118 | device: str = "cpu", 119 | method: Literal["SVGP", "ExactGP"] = "SVGP", 120 | batch_size: int = 1024, 121 | shuffle: bool = True, 122 | inducing_num: int = 512, 123 | normalize_spatial: bool = True, 124 | ): 125 | # Source data 126 | source_adata = source_adata.copy() 127 | source_adata.X = source_adata.X if layer == "X" else source_adata.layers[layer] 128 | 129 | source_spatial_data = source_adata.obsm[spatial_key] 130 | 131 | info_data = np.ones(shape=(source_spatial_data.shape[0], 1)) 132 | assert keys is not None, "`keys` cannot be None." 133 | keys = [keys] if isinstance(keys, str) else keys 134 | obs_keys = [key for key in keys if key in source_adata.obs.keys()] 135 | if len(obs_keys) != 0: 136 | obs_data = np.asarray(source_adata.obs[obs_keys].values) 137 | info_data = np.c_[info_data, obs_data] 138 | var_keys = [key for key in keys if key in source_adata.var_names.tolist()] 139 | if len(var_keys) != 0: 140 | var_data = source_adata[:, var_keys].X 141 | if issparse(var_data): 142 | var_data = var_data.toarray() 143 | info_data = np.c_[info_data, var_data] 144 | info_data = info_data[:, 1:] 145 | 146 | self.device = ( 147 | f"cuda:{device}" if torch.cuda.is_available() and device != "cpu" else "cpu" 148 | ) 149 | torch.device(self.device) 150 | 151 | self.train_x = torch.from_numpy(source_spatial_data).float() 152 | self.train_y = torch.from_numpy(info_data).float() 153 | if self.device == "cpu": 154 | self.train_x = self.train_x.cpu() 155 | self.train_y = self.train_y.cpu() 156 | else: 157 | self.train_x = self.train_x.cuda() 158 | self.train_y = self.train_y.cuda() 159 | self.train_y =
self.train_y.squeeze() 160 | 161 | self.nx = ot.backend.get_backend(self.train_x, self.train_y) 162 | 163 | self.normalize_spatial = normalize_spatial 164 | if self.normalize_spatial: 165 | self.train_x = self.normalize_coords(self.train_x) 166 | 167 | self.N = self.train_x.shape[0] 168 | # create training dataloader 169 | self.method = method 170 | from torch.utils.data import DataLoader, TensorDataset 171 | 172 | if method == "SVGP": 173 | train_dataset = TensorDataset(self.train_x, self.train_y) 174 | self.train_loader = DataLoader( 175 | train_dataset, batch_size=batch_size, shuffle=shuffle 176 | ) 177 | inducing_idx = ( 178 | np.random.choice(self.train_x.shape[0], inducing_num, replace=False) 179 | if self.train_x.shape[0] > inducing_num 180 | else np.arange(self.train_x.shape[0]) 181 | ) 182 | self.inducing_points = self.train_x[inducing_idx, :].clone() 183 | else: 184 | # ExactGP trains on the full dataset in one batch (see `gp_train`) 185 | self.train_loader = {"train_x": self.train_x, "train_y": self.train_y} 186 | 187 | 188 | self.PCA_reduction = False 189 | self.info_keys = {"obs_keys": obs_keys, "var_keys": var_keys} 190 | 191 | # Target data 192 | self.target_points = torch.from_numpy(target_points).float() 193 | self.target_points = ( 194 | self.target_points.cpu() 195 | if self.device == "cpu" 196 | else self.target_points.cuda() 197 | ) 198 | 199 | def normalize_coords( 200 | self, data: Union[np.ndarray, torch.Tensor], given_normalize: bool = False 201 | ): 202 | if not given_normalize: 203 | self.mean_data = _unsqueeze(self.nx)(self.nx.mean(data, axis=0), 0) 204 | data = data - self.mean_data 205 | if not given_normalize: 206 | self.variance = self.nx.sqrt(self.nx.sum(data**2) / data.shape[0]) 207 | data = data / self.variance 208 | return data 209 | 210 | def inference( 211 | self, 212 | training_iter: int = 50, 213 | ): 214 | self.likelihood = GaussianLikelihood() 215 | if self.method == "SVGP": 216 | self.GPR_model =
Approx_GPModel(inducing_points=self.inducing_points) 217 | elif self.method == "ExactGP": 218 | self.GPR_model = Exact_GPModel(self.train_x, self.train_y, self.likelihood) 219 | # if to convert to GPU 220 | if self.device != "cpu": 221 | self.GPR_model = self.GPR_model.cuda() 222 | self.likelihood = self.likelihood.cuda() 223 | 224 | # Start training to find optimal model hyperparameters 225 | gp_train( 226 | model=self.GPR_model, 227 | likelihood=self.likelihood, 228 | train_loader=self.train_loader, 229 | train_epochs=training_iter, 230 | method=self.method, 231 | N=self.N, 232 | device=self.device, 233 | ) 234 | 235 | self.GPR_model.eval() 236 | self.likelihood.eval() 237 | 238 | def interpolate( 239 | self, 240 | use_chunk: bool = False, 241 | chunk_num: int = 20, 242 | ): 243 | # Get into evaluation (predictive posterior) mode 244 | self.GPR_model.eval() 245 | self.likelihood.eval() 246 | 247 | target_points = self.target_points 248 | if self.normalize_spatial: 249 | target_points = self.normalize_coords(target_points, given_normalize=True) 250 | 251 | if use_chunk: 252 | target_points_s = _chunk(self.nx, target_points, chunk_num, 0) 253 | arr = [] 254 | with torch.no_grad(), gpytorch.settings.fast_pred_var(): 255 | for target_points_ss in target_points_s: 256 | predictions = self.likelihood(self.GPR_model(target_points_ss)).mean 257 | arr.append(predictions) 258 | quary_target = self.nx.concatenate(arr, axis=0) 259 | else: 260 | with torch.no_grad(), gpytorch.settings.fast_pred_var(): 261 | predictions = self.likelihood(self.GPR_model(target_points)) 262 | quary_target = predictions.mean 263 | 264 | quary_target = ( 265 | np.asarray(quary_target.cpu()) 266 | if self.device != "cpu" 267 | else np.asarray(quary_target) 268 | ) 269 | return quary_target 270 | 271 | 272 | def gp_interpolation( 273 | source_adata: AnnData, 274 | target_points: Optional[ndarray] = None, 275 | keys: Union[str, list] = None, 276 | spatial_key: str = "spatial", 277 | layer: str = "X", 
278 | training_iter: int = 50, 279 | device: str = "cpu", 280 | method: Literal["SVGP", "ExactGP"] = "SVGP", 281 | batch_size: int = 1024, 282 | shuffle: bool = True, 283 | inducing_num: int = 512, 284 | ) -> AnnData: 285 | """ 286 | Learn a continuous mapping from spatial coordinates to gene expression (or other annotations) with Gaussian process regression. 287 | 288 | Args: 289 | source_adata: AnnData object that contains spatial coordinates (numpy.ndarray) in the ``.obsm`` attribute. 290 | target_points: The spatial coordinates of the new data points at which values are interpolated. Required despite the ``None`` default. 291 | keys: Gene names in ``.var_names`` and/or column names in ``.obs`` whose spatial pattern is to be learned and interpolated. 292 | spatial_key: The key in ``.obsm`` that corresponds to the spatial coordinates of each bucket. 293 | layer: If ``'X'``, uses ``.X``, otherwise uses the representation given by ``.layers[layer]``. 294 | training_iter: Number of training epochs. 295 | device: Device used for training. Use ``'cpu'``, or a GPU index such as ``'0'`` to run on ``cuda:0`` when CUDA is available. 296 | 297 | Returns: 298 | interp_adata: An AnnData object containing the interpolated values at ``target_points``.
299 | """ 300 | 301 | # Inference 302 | GPR = Imputation_GPR( 303 | source_adata=source_adata, 304 | target_points=target_points, 305 | keys=keys, 306 | spatial_key=spatial_key, 307 | layer=layer, 308 | device=device, 309 | method=method, 310 | batch_size=batch_size, 311 | shuffle=shuffle, 312 | inducing_num=inducing_num, 313 | ) 314 | GPR.inference(training_iter=training_iter) 315 | 316 | # Interpolation 317 | target_info_data = GPR.interpolate(use_chunk=True) 318 | target_info_data = target_info_data[:, None] 319 | 320 | # Output interpolated anndata 321 | obs_keys = GPR.info_keys["obs_keys"] 322 | if len(obs_keys) != 0: 323 | obs_data = target_info_data[:, : len(obs_keys)] 324 | obs_data = pd.DataFrame(obs_data, columns=obs_keys) 325 | 326 | var_keys = GPR.info_keys["var_keys"] 327 | if len(var_keys) != 0: 328 | X = target_info_data[:, len(obs_keys) :] 329 | var_data = pd.DataFrame(index=var_keys) 330 | 331 | interp_adata = AnnData( 332 | X=X if len(var_keys) != 0 else None, 333 | obs=obs_data if len(obs_keys) != 0 else None, 334 | obsm={spatial_key: np.asarray(target_points)}, 335 | var=var_data if len(var_keys) != 0 else None, 336 | ) 337 | return interp_adata 338 | -------------------------------------------------------------------------------- /stviewer/Explorer/pv_pipeline/pv_morphogenesis.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import numpy as np 4 | from anndata import AnnData 5 | from pyvista import PolyData 6 | from scipy.integrate import odeint 7 | 8 | 9 | def morphogenesis( 10 | source_adata: AnnData, 11 | source_pc_model: PolyData, 12 | target_adata: Optional[AnnData] = None, 13 | mapping_method: Literal["GP", "OT"] = "GP", 14 | mapping_factor: float = 0.2, 15 | mapping_device: str = "cpu", 16 | morphofield_factor: int = 3000, 17 | morphopath_t_end: int = 10000, 18 | morphopath_sampling: int = 500, 19 | ): 20 | try: 21 | import spateo as st 22 | except 
ImportError: 23 | raise ImportError( 24 | "You need to install the package `spateo`. " 25 | "\nInstall spateo via `pip install spateo-release`." 26 | ) 27 | try: 28 | from dynamo.vectorfield import SvcVectorField 29 | except ImportError: 30 | raise ImportError( 31 | "You need to install the package `dynamo`. " 32 | "\nInstall dynamo via `pip install dynamo-release`." 33 | ) 34 | 35 | # Preprocess 36 | _obs_index = source_pc_model.point_data["obs_index"] 37 | source_adata = source_adata[_obs_index, :] 38 | 39 | # 3D mapping and morphofield 40 | if mapping_method == "OT": 41 | if not (target_adata is None): 42 | _ = st.tdr.cell_directions( 43 | adataA=source_adata, 44 | adataB=target_adata, 45 | numItermaxEmd=2000000, 46 | spatial_key="spatial", 47 | key_added="cells_mapping", 48 | alpha=mapping_factor, 49 | device=mapping_device, 50 | inplace=True, 51 | ) 52 | 53 | if "V_cells_mapping" not in source_adata.obsm.keys(): 54 | raise ValueError("You need to add the target anndata object. ") 55 | 56 | st.tdr.morphofield_sparsevfc( 57 | adata=source_adata, 58 | spatial_key="spatial", 59 | V_key="V_cells_mapping", 60 | key_added="VecFld_morpho", 61 | NX=None, 62 | inplace=True, 63 | ) 64 | elif mapping_method == "GP": 65 | if not (target_adata is None): 66 | align_models, _, _ = st.align.morpho_align_sparse( 67 | models=[target_adata.copy(), source_adata.copy()], 68 | spatial_key="spatial", 69 | key_added="mapping_spatial", 70 | device=mapping_device, 71 | mode="SN-S", 72 | max_iter=200, 73 | partial_robust_level=1, 74 | beta=0.1, # nonrigid, 75 | beta2_end=mapping_factor, # low beta2_end, high expression similarity 76 | lambdaVF=1, 77 | K=200, 78 | SVI_mode=True, 79 | use_sparse=True, 80 | ) 81 | source_adata = align_models[1].copy() 82 | 83 | if "VecFld_morpho" not in source_adata.uns.keys(): 84 | raise ValueError("You need to add the target anndata object. 
") 85 | 86 | st.tdr.morphofield_gp( 87 | adata=source_adata, 88 | spatial_key="spatial", 89 | vf_key="VecFld_morpho", 90 | NX=np.asarray(source_adata.obsm["spatial"]), 91 | inplace=True, 92 | ) 93 | 94 | # construct morphofield model 95 | source_adata.obs["V_z"] = source_adata.uns["VecFld_morpho"]["V"][:, 2].flatten() 96 | source_pc_model.point_data["vectors"] = source_adata.uns["VecFld_morpho"]["V"] 97 | source_pc_model.point_data["V_Z"] = source_pc_model.point_data["vectors"][ 98 | :, 2 99 | ].flatten() 100 | 101 | pc_vectors, _ = st.tdr.construct_field( 102 | model=source_pc_model, 103 | vf_key="vectors", 104 | arrows_scale_key="vectors", 105 | n_sampling=None, 106 | factor=morphofield_factor, 107 | key_added="obs_index", 108 | label=source_pc_model.point_data["obs_index"], 109 | ) 110 | 111 | # construct morphopath model 112 | st.tdr.morphopath( 113 | adata=source_adata, 114 | vf_key="VecFld_morpho", 115 | key_added="fate_morpho", 116 | t_end=morphopath_t_end, 117 | interpolation_num=20, 118 | cores=10, 119 | ) 120 | trajectory_model, _ = st.tdr.construct_trajectory( 121 | adata=source_adata, 122 | fate_key="fate_morpho", 123 | n_sampling=morphopath_sampling, 124 | sampling_method="random", 125 | key_added="obs_index", 126 | label=np.asarray(source_adata.obs.index), 127 | ) 128 | 129 | # morphometric features 130 | st.tdr.morphofield_acceleration( 131 | adata=source_adata, vf_key="VecFld_morpho", key_added="acceleration" 132 | ) 133 | st.tdr.morphofield_curvature( 134 | adata=source_adata, vf_key="VecFld_morpho", key_added="curvature" 135 | ) 136 | st.tdr.morphofield_curl( 137 | adata=source_adata, vf_key="VecFld_morpho", key_added="curl" 138 | ) 139 | st.tdr.morphofield_torsion( 140 | adata=source_adata, vf_key="VecFld_morpho", key_added="torsion" 141 | ) 142 | st.tdr.morphofield_divergence( 143 | adata=source_adata, vf_key="VecFld_morpho", key_added="divergence" 144 | ) 145 | 146 | source_pc_index = source_pc_model.point_data["obs_index"] 147 | 
source_pc_model.point_data["acceleration"] = np.asarray( 148 | source_adata[np.asarray(source_pc_index)].obs["acceleration"] 149 | ) 150 | source_pc_model.point_data["curvature"] = np.asarray( 151 | source_adata[np.asarray(source_pc_index)].obs["curvature"] 152 | ) 153 | source_pc_model.point_data["curl"] = np.asarray( 154 | source_adata[np.asarray(source_pc_index)].obs["curl"] 155 | ) 156 | source_pc_model.point_data["torsion"] = np.asarray( 157 | source_adata[np.asarray(source_pc_index)].obs["torsion"] 158 | ) 159 | source_pc_model.point_data["divergence"] = np.asarray( 160 | source_adata[np.asarray(source_pc_index)].obs["divergence"] 161 | ) 162 | 163 | pc_vectors_index = pc_vectors.point_data["obs_index"] 164 | pc_vectors.point_data["V_Z"] = np.asarray( 165 | source_adata[np.asarray(pc_vectors_index)].obs["V_z"] 166 | ) 167 | pc_vectors.point_data["acceleration"] = np.asarray( 168 | source_adata[np.asarray(pc_vectors_index)].obs["acceleration"] 169 | ) 170 | pc_vectors.point_data["curvature"] = np.asarray( 171 | source_adata[np.asarray(pc_vectors_index)].obs["curvature"] 172 | ) 173 | pc_vectors.point_data["curl"] = np.asarray( 174 | source_adata[np.asarray(pc_vectors_index)].obs["curl"] 175 | ) 176 | pc_vectors.point_data["torsion"] = np.asarray( 177 | source_adata[np.asarray(pc_vectors_index)].obs["torsion"] 178 | ) 179 | pc_vectors.point_data["divergence"] = np.asarray( 180 | source_adata[np.asarray(pc_vectors_index)].obs["divergence"] 181 | ) 182 | 183 | trajectory_index = trajectory_model.point_data["obs_index"] 184 | trajectory_model.point_data["V_Z"] = np.asarray( 185 | source_adata[np.asarray(trajectory_index)].obs["V_z"] 186 | ) 187 | trajectory_model.point_data["acceleration"] = np.asarray( 188 | source_adata[np.asarray(trajectory_index)].obs["acceleration"] 189 | ) 190 | trajectory_model.point_data["curvature"] = np.asarray( 191 | source_adata[np.asarray(trajectory_index)].obs["curvature"] 192 | ) 193 | trajectory_model.point_data["curl"] = 
np.asarray( 194 | source_adata[np.asarray(trajectory_index)].obs["curl"] 195 | ) 196 | trajectory_model.point_data["torsion"] = np.asarray( 197 | source_adata[np.asarray(trajectory_index)].obs["torsion"] 198 | ) 199 | trajectory_model.point_data["divergence"] = np.asarray( 200 | source_adata[np.asarray(trajectory_index)].obs["divergence"] 201 | ) 202 | 203 | # cell stages of animation 204 | t_ind = np.asarray(list(source_adata.uns["fate_morpho"]["t"].keys()), dtype=int) 205 | t_sort_ind = np.argsort(t_ind) 206 | t = np.asarray(list(source_adata.uns["fate_morpho"]["t"].values()))[t_sort_ind] 207 | flats = np.unique([int(item) for sublist in t for item in sublist]) 208 | flats = np.hstack((0, flats)) 209 | flats = np.sort(flats[flats <= 3000]) # keep only time points <= 3000 210 | time_vec = np.logspace(0, np.log10(max(flats) + 1), 100) - 1 211 | vf = SvcVectorField() 212 | vf.from_adata(source_adata, basis="morpho") 213 | f = lambda x, _: vf.func(x) 214 | displace = lambda x, dt: odeint(f, x, [0, dt]) 215 | 216 | init_states = source_adata.uns["fate_morpho"]["init_states"] 217 | pts = [i.tolist() for i in init_states] 218 | stages_X = [source_adata.obs.index.tolist()] 219 | for i in range(100): 220 | pts = [displace(cur_pts, time_vec[i])[1].tolist() for cur_pts in pts] 221 | stages_X.append(pts) 222 | 223 | return source_pc_model, pc_vectors, trajectory_model, stages_X 224 | -------------------------------------------------------------------------------- /stviewer/Explorer/pv_pipeline/pv_plotter.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import pyvista as pv 4 | from pyvista import Plotter, PolyData, UnstructuredGrid 5 | 6 | try: 7 | from typing import Literal 8 | except ImportError: 9 | from typing_extensions import Literal 10 | 11 | 12 | def create_plotter( 13 | window_size: tuple = (1024, 1024), background: str = "black", **kwargs 14 | ) -> Plotter: 15 | """ 
16 | Create a
plotting object to display pyvista/vtk model. 17 | 18 | Args: 19 | window_size: Window size in pixels as ``(width, height)``. Defaults to ``(1024, 1024)``. 20 | background: The background color of the window. 21 | 22 | Returns: 23 | plotter: The plotting object to display pyvista/vtk model. 24 | """ 25 | 26 | # Create an initial plotting object. 27 | plotter = pv.Plotter( 28 | window_size=window_size, off_screen=True, lighting="light_kit", **kwargs 29 | ) 30 | 31 | # Set the background color of the active render window. 32 | plotter.background_color = background 33 | return plotter 34 | 35 | 36 | def add_single_model( 37 | plotter: Plotter, 38 | model: Union[PolyData, UnstructuredGrid], 39 | key: Optional[str] = None, 40 | cmap: Optional[str] = "rainbow", 41 | color: Optional[str] = "gainsboro", 42 | ambient: float = 0.2, 43 | opacity: float = 1.0, 44 | model_style: Literal["points", "surface", "wireframe"] = "surface", 45 | model_size: float = 3.0, 46 | model_name: Optional[str] = None, 47 | ): 48 | """ 49 | Add a single model to the plotter. 50 | Args: 51 | plotter: The plotting object to display pyvista/vtk model. 52 | model: A reconstructed model. 53 | key: Name of the array in ``model`` used for coloring. Ignored if the model has no array with this name. 54 | cmap: Name of the Matplotlib colormap to use when mapping the model. 55 | color: Single color used to render the model when no scalars are mapped. 56 | ambient: When lighting is enabled, the amount of light in the range of 0 to 1 (default 0.2) that reaches 57 | the actor when not directed at the light source emitted from the viewer. 58 | opacity: Opacity of the model. 59 | If a single float is given, it is applied uniformly as the global opacity of the model; 60 | if a numpy.ndarray of floats is given, each value is the opacity of the corresponding point. 61 | Values should be between 0 and 1. 
62 | A string can also be specified to map the scalars range to a predefined opacity transfer function 63 | (options include: 'linear', 'linear_r', 'geom', 'geom_r').
64 | model_style: Visualization style of the model. One of the following: 65 | * ``model_style = 'surface'``, 66 | * ``model_style = 'wireframe'``, 67 | * ``model_style = 'points'``. 68 | model_size: If ``model_style = 'points'``, point size of any nodes in the dataset plotted. 69 | If ``model_style = 'wireframe'``, thickness of lines. 70 | model_name: Name to assign to the model. Defaults to the memory address. 71 | """ 72 | 73 | if model_style == "points": 74 | render_spheres, render_tubes, smooth_shading = True, False, True 75 | elif model_style == "wireframe": 76 | render_spheres, render_tubes, smooth_shading = False, True, False 77 | else: 78 | render_spheres, render_tubes, smooth_shading = False, False, True 79 | mesh_kwargs = dict( 80 | scalars=key if key in model.array_names else None, 81 | style=model_style, 82 | render_points_as_spheres=render_spheres, 83 | render_lines_as_tubes=render_tubes, 84 | point_size=model_size, 85 | line_width=model_size, 86 | ambient=ambient, 87 | opacity=opacity, 88 | smooth_shading=smooth_shading, 89 | show_scalar_bar=False, 90 | cmap=cmap, 91 | color=color, 92 | name=model_name, 93 | ) 94 | actor = plotter.add_mesh(model, **mesh_kwargs) 95 | return actor 96 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/__init__.py: -------------------------------------------------------------------------------- 1 | from .container import ui_container 2 | from .drawer import ui_drawer 3 | from .layout import ui_layout 4 | from .toolbar import toolbar_switch_model, toolbar_widgets, ui_title, ui_toolbar 5 | from .utils import button, checkbox, switch 6 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/container.py: -------------------------------------------------------------------------------- 1 | try: 2 | from typing import Literal 3 | except ImportError: 4 | from typing_extensions import Literal 5 | 6 | from pyvista import 
BasePlotter 7 | from pyvista.trame import PyVistaLocalView, PyVistaRemoteLocalView, PyVistaRemoteView 8 | from trame.widgets import vuetify 9 | 10 | from .toolbar import Viewer 11 | 12 | # ----------------------------------------------------------------------------- 13 | # GUI- standard Container 14 | # ----------------------------------------------------------------------------- 15 | 16 | 17 | def ui_container( 18 | server, 19 | layout, 20 | plotter: BasePlotter, 21 | mode: Literal["trame", "server", "client"] = "trame", 22 | default_server_rendering: bool = True, 23 | **kwargs, 24 | ): 25 | """ 26 | Generate standard VContainer for Spateo UI. 27 | 28 | Args: 29 | server: The trame server. 30 | layout: The layout object. 31 | plotter: The PyVista plotter to connect with the UI. 32 | mode: The UI view mode. Options are: 33 | 34 | * ``'trame'``: Uses a view that can switch between client and server rendering modes. 35 | * ``'server'``: Uses a view that is purely server rendering. 36 | * ``'client'``: Uses a view that is purely client rendering (generally safe without a virtual frame buffer) 37 | default_server_rendering: Whether to use server-side or client-side rendering on-start when using the ``'trame'`` mode. 38 | kwargs: Additional parameters that will be passed to ``pyvista.trame.app.PyVistaXXXXView`` function. 39 | """ 40 | if mode != "trame": 41 | default_server_rendering = mode == "server" 42 | 43 | viewer = Viewer(plotter, server, suppress_rendering=mode == "client") 44 | ctrl = server.controller 45 | 46 | with layout.content as con: 47 | with vuetify.VContainer( 48 | fluid=True, 49 | classes="pa-0 fill-height", 50 | ): 51 | if mode == "trame": 52 | view = PyVistaRemoteLocalView( 53 | plotter, 54 | mode=( 55 | # Must use single-quote string for JS here 56 | f"{viewer.SERVER_RENDERING} ? 
'remote' : 'local'", 57 | "remote" if default_server_rendering else "local", 58 | ), 59 | **kwargs, 60 | ) 61 | ctrl.view_update_image = view.update_image 62 | elif mode == "server": 63 | view = PyVistaRemoteView(plotter, **kwargs) 64 | elif mode == "client": 65 | view = PyVistaLocalView(plotter, **kwargs) 66 | 67 | ctrl.view_update = view.update 68 | ctrl.view_reset_camera = view.reset_camera 69 | ctrl.view_push_camera = view.push_camera 70 | ctrl.on_server_ready.add(view.update) 71 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/drawer/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import ui_drawer 2 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/drawer/adata_obj.py: -------------------------------------------------------------------------------- 1 | from trame.widgets import vuetify 2 | 3 | 4 | def anndata_object_content(): 5 | with vuetify.VRow(classes="pt-2", dense=True): 6 | with vuetify.VCol(cols="12"): 7 | vuetify.VTextarea( 8 | v_model=("anndata_info.anndata_structure",), 9 | label="AnnData Structure", 10 | hide_details=True, 11 | dense=True, 12 | outlined=True, 13 | classes="pt-1", 14 | ) 15 | with vuetify.VRow(classes="pt-2", dense=True): 16 | with vuetify.VCol(cols="12"): 17 | vuetify.VFileInput( 18 | v_model=("uploaded_anndata_path", None), 19 | label="Upload AnnData (.h5ad)", 20 | show_size=True, 21 | small_chips=True, 22 | dense=True, 23 | outlined=True, 24 | hide_details=True, 25 | classes="pt-1", 26 | rounded=False, 27 | accept=".h5ad", 28 | __properties=["accept"], 29 | ) 30 | 31 | 32 | def anndata_panel(): 33 | with vuetify.VToolbar( 34 | dense=True, outlined=True, classes="pa-0 ma-0", style="flex: none;" 35 | ): 36 | vuetify.VIcon("mdi-apps") 37 | vuetify.VCardTitle( 38 | " AnnData Object", 39 | classes="pa-0 ma-0", 40 | style="flex: none;", 41 | hide_details=True, 42 | 
dense=True, 43 | ) 44 | 45 | vuetify.VSpacer() 46 | with vuetify.VBtn( 47 | small=True, 48 | icon=True, 49 | click="show_anndata_card = !show_anndata_card", 50 | ): 51 | vuetify.VIcon("mdi-unfold-less-horizontal", v_if=("show_anndata_card",)) 52 | vuetify.VIcon("mdi-unfold-more-horizontal", v_if=("!show_anndata_card",)) 53 | 54 | # Main content 55 | with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"): 56 | with vuetify.VCardText(classes="py-2", v_if=("show_anndata_card",)): 57 | anndata_object_content() 58 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/drawer/custom_card.py: -------------------------------------------------------------------------------- 1 | from trame.widgets import vuetify 2 | 3 | 4 | def custom_card_content(): 5 | with vuetify.VRow(classes="pt-2", dense=True): 6 | with vuetify.VCol(cols="6"): 7 | vuetify.VCheckbox( 8 | v_model=("custom_analysis", False), 9 | label="Custom analysis calculation", 10 | on_icon="mdi-transit-detour", 11 | off_icon="mdi-transit-skip", 12 | dense=True, 13 | hide_details=True, 14 | classes="pt-1", 15 | ) 16 | with vuetify.VCol(cols="6"): 17 | vuetify.VCheckbox( 18 | v_model=("custom_model_visible", False), 19 | label="Custom model visibility", 20 | on_icon="mdi-pyramid", 21 | off_icon="mdi-pyramid-off", 22 | dense=True, 23 | hide_details=True, 24 | classes="pt-1", 25 | ) 26 | 27 | with vuetify.VRow(classes="pt-2", dense=True): 28 | with vuetify.VCol(cols="6"): 29 | vuetify.VSelect( 30 | v_model=("custom_parameter1", "X"), 31 | items=("matrices_list",), 32 | label="Custom parameter 1", 33 | hide_details=True, 34 | dense=True, 35 | outlined=True, 36 | classes="pt-1", 37 | ) 38 | with vuetify.VCol(cols="6"): 39 | vuetify.VSelect( 40 | v_model=("custom_parameter2", "recipe_monocle"), 41 | items=(["False", "recipe_monocle", "pearson_residuals"],), 42 | label="Custom parameter 2", 43 | hide_details=True, 44 | dense=True, 45 | outlined=True, 46 | 
classes="pt-1", 47 | ) 48 | 49 | with vuetify.VRow(classes="pt-2", dense=True): 50 | with vuetify.VCol(cols="6"): 51 | vuetify.VTextField( 52 | v_model=("custom_parameter3", "pca"), 53 | label="Custom parameter 3", 54 | hide_details=True, 55 | dense=True, 56 | outlined=True, 57 | classes="pt-1", 58 | ) 59 | with vuetify.VCol(cols="6"): 60 | vuetify.VTextField( 61 | v_model=("custom_parameter4", "umap"), 62 | label="Custom parameter 4", 63 | hide_details=True, 64 | dense=True, 65 | outlined=True, 66 | classes="pt-1", 67 | ) 68 | 69 | with vuetify.VRow(classes="pt-2", dense=True): 70 | vuetify.VCheckbox( 71 | v_model=("custom_parameter5", False), 72 | label="Custom parameter 5", 73 | on_icon="mdi-plus-thick", 74 | off_icon="mdi-close-thick", 75 | dense=True, 76 | hide_details=True, 77 | classes="pt-1", 78 | ) 79 | with vuetify.VCol(cols="6"): 80 | vuetify.VTextField( 81 | v_model=("custom_parameter6", "None"), 82 | label="Custom parameter 6", 83 | hide_details=True, 84 | dense=True, 85 | outlined=True, 86 | classes="pt-1", 87 | ) 88 | 89 | with vuetify.VRow(classes="pt-2", dense=True): 90 | with vuetify.VCol(cols="6"): 91 | vuetify.VTextField( 92 | v_model=("custom_parameter7", 30), 93 | label="Custom parameter 7", 94 | hide_details=True, 95 | dense=True, 96 | outlined=True, 97 | classes="pt-1", 98 | ) 99 | with vuetify.VCol(cols="6"): 100 | vuetify.VTextField( 101 | v_model=("custom_parameter8", 30), 102 | label="Custom parameter 8", 103 | hide_details=True, 104 | dense=True, 105 | outlined=True, 106 | classes="pt-1", 107 | ) 108 | 109 | with vuetify.VRow(classes="pt-2", dense=True): 110 | with vuetify.VCol(cols="6"): 111 | vuetify.VTextField( 112 | v_model=("custom_parameter9", "None"), 113 | label="Custom parameter 9", 114 | hide_details=True, 115 | dense=True, 116 | outlined=True, 117 | classes="pt-1", 118 | ) 119 | with vuetify.VCol(cols="6"): 120 | vuetify.VTextField( 121 | v_model=("custom_parameter10", 1), 122 | label="Custom parameter 10", 123 | 
hide_details=True, 124 | dense=True, 125 | outlined=True, 126 | classes="pt-1", 127 | ) 128 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/drawer/main.py: -------------------------------------------------------------------------------- 1 | try: 2 | from typing import Literal 3 | except ImportError: 4 | from typing_extensions import Literal 5 | 6 | from pyvista import BasePlotter 7 | from trame.widgets import vuetify 8 | 9 | 10 | def _get_spateo_cmap(): 11 | import matplotlib as mpl 12 | from matplotlib.colors import LinearSegmentedColormap 13 | 14 | if "spateo_cmap" not in mpl.colormaps(): 15 | colors = ["#4B0082", "#800080", "#F97306", "#FFA500", "#FFD700", "#FFFFCB"] 16 | nodes = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] 17 | 18 | mpl.colormaps.register( 19 | LinearSegmentedColormap.from_list("spateo_cmap", list(zip(nodes, colors))) 20 | ) 21 | return "spateo_cmap" 22 | 23 | 24 | def ui_drawer( 25 | server, 26 | layout, 27 | plotter: BasePlotter, 28 | mode: Literal["trame", "server", "client"] = "trame", 29 | ): 30 | """ 31 | Generate standard Drawer for Spateo UI. 32 | 33 | Args: 34 | server: The trame server. 35 | layout: The layout object. 36 | plotter: The PyVista plotter to connect with the UI. 37 | mode: The UI view mode. Options are: 38 | 39 | * ``'trame'``: Uses a view that can switch between client and server rendering modes. 40 | * ``'server'``: Uses a view that is purely server rendering. 
41 | * ``'client'``: Uses a view that is purely client rendering (generally safe without a virtual frame buffer) 42 | """ 43 | 44 | _get_spateo_cmap() 45 | from stviewer.Explorer.pv_pipeline import PVCB 46 | 47 | PVCB(server=server, plotter=plotter, suppress_rendering=mode == "client") 48 | 49 | with layout.drawer as dr: 50 | # Pipeline 51 | from .pipeline import pipeline_panel 52 | 53 | pipeline_panel(server=server, plotter=plotter) 54 | 55 | # AnnData object 56 | from .adata_obj import anndata_panel 57 | 58 | anndata_panel() 59 | 60 | # Active model 61 | with vuetify.VToolbar( 62 | dense=True, outlined=True, classes="pa-0 ma-0", style="flex: none;" 63 | ): 64 | vuetify.VIcon("mdi-format-paint") 65 | vuetify.VCardTitle( 66 | " Active Model", 67 | classes="pa-0 ma-0", 68 | style="flex: none;", 69 | hide_details=True, 70 | dense=True, 71 | ) 72 | 73 | vuetify.VSpacer() 74 | with vuetify.VBtn( 75 | small=True, 76 | icon=True, 77 | click="show_model_card = !show_model_card", 78 | ): 79 | vuetify.VIcon("mdi-unfold-less-horizontal", v_if=("show_model_card",)) 80 | vuetify.VIcon("mdi-unfold-more-horizontal", v_if=("!show_model_card",)) 81 | with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"): 82 | with vuetify.VCardText( 83 | classes="py-2 ", 84 | v_show=f"active_model_type === 'PC'", 85 | v_if=("show_model_card",), 86 | ): 87 | items = ["Model", "Morphogenesis"] 88 | items = ( 89 | items + ["Custom"] if server.state.custom_func is True else items 90 | ) 91 | with vuetify.VTabs(v_model=("pc_active_tab", 0), left=True): 92 | for item in items: 93 | vuetify.VTab( 94 | item, 95 | style="width: 50%;", 96 | ) 97 | with vuetify.VTabsItems( 98 | value=("pc_active_tab",), 99 | style="width: 100%; height: 100%;", 100 | ): 101 | with vuetify.VTabItem(value=(0,)): 102 | from .model_point import pc_card_content 103 | 104 | pc_card_content() 105 | with vuetify.VTabItem(value=(1,)): 106 | from .morphogenesis import morphogenesis_card_content 107 | 108 | 
morphogenesis_card_content() 109 | # Custom 110 | if server.state.custom_func is True: 111 | with vuetify.VTabItem(value=(2,)): 112 | from .custom_card import custom_card_content 113 | 114 | custom_card_content() 115 | with vuetify.VCardText( 116 | classes="py-2", 117 | v_show=f"active_model_type === 'Mesh'", 118 | v_if=("show_model_card",), 119 | ): 120 | from .model_mesh import mesh_card_content 121 | 122 | mesh_card_content() 123 | 124 | # Output 125 | from .output import output_panel 126 | 127 | output_panel() 128 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/drawer/model_mesh.py: -------------------------------------------------------------------------------- 1 | from trame.widgets import vuetify 2 | 3 | 4 | def mesh_card_content(): 5 | with vuetify.VRow(classes="pt-2", dense=True): 6 | with vuetify.VCol(cols="6"): 7 | vuetify.VCombobox( 8 | label="Color", 9 | v_model=("mesh_color_value", "None"), 10 | items=("mesh_colors_list",), 11 | type="str", 12 | hide_details=True, 13 | dense=True, 14 | outlined=True, 15 | classes="pt-1", 16 | ) 17 | with vuetify.VCol(cols="6"): 18 | vuetify.VCombobox( 19 | label="Style", 20 | v_model=("mesh_style_value", "surface"), 21 | items=(f"styles", ["surface", "points", "wireframe"]), 22 | hide_details=True, 23 | dense=True, 24 | outlined=True, 25 | classes="pt-1", 26 | ) 27 | # Opacity 28 | vuetify.VSlider( 29 | v_model=("mesh_opacity_value", 0.2), 30 | min=0, 31 | max=1, 32 | step=0.05, 33 | label="Opacity", 34 | classes="mt-1", 35 | hide_details=True, 36 | dense=True, 37 | ) 38 | # Ambient 39 | vuetify.VSlider( 40 | v_model=("mesh_ambient_value", 0.2), 41 | min=0, 42 | max=1, 43 | step=0.05, 44 | label="Ambient", 45 | classes="mt-1", 46 | hide_details=True, 47 | dense=True, 48 | ) 49 | vuetify.VCheckbox( 50 | v_model=("mesh_morphology", False), 51 | label="Model Morphological Metrics", 52 | on_icon="mdi-pencil-ruler", 53 | off_icon="mdi-ruler", 54 | dense=True, 55 | 
hide_details=True, 56 | classes="mt-1", 57 | ) 58 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/drawer/model_point.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from trame.widgets import vuetify 3 | 4 | 5 | def pc_card_content(): 6 | with vuetify.VRow(classes="pt-2", dense=True): 7 | with vuetify.VCol(cols="6"): 8 | vuetify.VCombobox( 9 | label="Observation Annotation", 10 | v_model=("pc_obs_value", None), 11 | items=("available_obs",), 12 | type="str", 13 | show_size=True, 14 | dense=True, 15 | outlined=True, 16 | hide_details=True, 17 | classes="pt-1", 18 | ) 19 | with vuetify.VCol(cols="6"): 20 | vuetify.VCheckbox( 21 | v_model=("pc_reload", False), 22 | label="Reload Model", 23 | on_icon="mdi-restore", 24 | off_icon="mdi-restore", 25 | dense=True, 26 | hide_details=True, 27 | classes="pt-1", 28 | ) 29 | 30 | with vuetify.VRow(classes="pt-2", dense=True): 31 | with vuetify.VCol(cols="6"): 32 | vuetify.VCombobox( 33 | v_model=("pc_picking_group", None), 34 | items=("Object.keys(pc_scalars_raw)",), 35 | label="Picking Group", 36 | show_size=True, 37 | dense=True, 38 | outlined=True, 39 | hide_details=True, 40 | classes="pt-1", 41 | ) 42 | with vuetify.VCol(cols="6"): 43 | vuetify.VCheckbox( 44 | v_model=("pc_overwrite", False), 45 | label="Add Picked Group", 46 | on_icon="mdi-plus-thick", 47 | off_icon="mdi-close-thick", 48 | dense=True, 49 | hide_details=True, 50 | classes="pt-1", 51 | ) 52 | 53 | with vuetify.VRow(classes="pt-2", dense=True): 54 | with vuetify.VCol(cols="6"): 55 | vuetify.VCombobox( 56 | label="Available Genes", 57 | v_model=("pc_gene_value", None), 58 | items=("available_genes",), 59 | type="str", 60 | show_size=True, 61 | dense=True, 62 | outlined=True, 63 | hide_details=True, 64 | classes="pt-1", 65 | ) 66 | with vuetify.VCol(cols="6"): 67 | vuetify.VCheckbox( 68 | v_model=("pc_add_legend", True), 69 | 
label="Add Legend", 70 | on_icon="mdi-view-grid-plus", 71 | off_icon="mdi-view-grid", 72 | dense=True, 73 | hide_details=True, 74 | classes="pt-1", 75 | ) 76 | with vuetify.VRow(classes="pt-2", dense=True): 77 | with vuetify.VCol(cols="6"): 78 | vuetify.VTextField( 79 | v_model=("interpolation_device", "cpu"), 80 | label="Interpolation Device", 81 | hide_details=True, 82 | dense=True, 83 | outlined=True, 84 | classes="pt-1", 85 | ) 86 | with vuetify.VCol(cols="6"): 87 | vuetify.VCheckbox( 88 | v_model=("cal_interpolation", True), 89 | label="GP Interpolation", 90 | on_icon="mdi-smoke-detector", 91 | off_icon="mdi-smoke-detector-off", 92 | dense=True, 93 | hide_details=True, 94 | classes="pt-1", 95 | ) 96 | 97 | with vuetify.VRow(classes="pt-2", dense=True): 98 | with vuetify.VCol(cols="6"): 99 | vuetify.VCombobox( 100 | v_model=("pc_coords_value", "spatial"), 101 | items=("anndata_info.anndata_obsm_keys",), 102 | label="Coords", 103 | type="str", 104 | hide_details=True, 105 | dense=True, 106 | outlined=True, 107 | classes="pt-1", 108 | ) 109 | with vuetify.VCol(cols="6"): 110 | vuetify.VCombobox( 111 | v_model=("pc_matrix_value", "X"), 112 | items=("anndata_info.anndata_matrices",), 113 | label="Matrices", 114 | type="str", 115 | hide_details=True, 116 | dense=True, 117 | outlined=True, 118 | classes="pt-1", 119 | ) 120 | 121 | with vuetify.VRow(classes="pt-2", dense=True): 122 | with vuetify.VCol(cols="6"): 123 | vuetify.VCombobox( 124 | v_model=("pc_color_value", "None"), 125 | items=("pc_colors_list",), 126 | label="Color", 127 | type="str", 128 | hide_details=True, 129 | dense=True, 130 | outlined=True, 131 | classes="pt-1", 132 | ) 133 | # Colormap 134 | with vuetify.VCol(cols="6"): 135 | vuetify.VCombobox( 136 | v_model=("pc_colormap_value", "Spectral"), 137 | items=("pc_colormaps_list",), 138 | label="Colormap", 139 | type="str", 140 | hide_details=True, 141 | dense=True, 142 | outlined=True, 143 | classes="pt-1", 144 | ) 145 | 146 | # Opacity 147 | 
vuetify.VSlider( 148 | v_model=("pc_opacity_value", 1.0), 149 | min=0, 150 | max=1, 151 | step=0.05, 152 | label="Opacity", 153 | classes="mt-1", 154 | hide_details=True, 155 | dense=True, 156 | ) 157 | # Ambient 158 | vuetify.VSlider( 159 | v_model=("pc_ambient_value", 0.2), 160 | min=0, 161 | max=1, 162 | step=0.05, 163 | label="Ambient", 164 | classes="mt-1", 165 | hide_details=True, 166 | dense=True, 167 | ) 168 | # Point size 169 | vuetify.VSlider( 170 | v_model=("pc_point_size_value", 2), 171 | min=0, 172 | max=20, 173 | step=1, 174 | label="Point Size", 175 | classes="mt-1", 176 | hide_details=True, 177 | dense=True, 178 | ) 179 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/drawer/morphogenesis.py: -------------------------------------------------------------------------------- 1 | from trame.widgets import vuetify 2 | 3 | 4 | def morphogenesis_card_content(): 5 | with vuetify.VRow(classes="pt-2", dense=True): 6 | with vuetify.VCol(cols="12"): 7 | vuetify.VCheckbox( 8 | v_model=("cal_morphogenesis", False), 9 | label="Calculate the Morphogenesis", 10 | on_icon="mdi-transit-detour", 11 | off_icon="mdi-transit-skip", 12 | dense=True, 13 | hide_details=True, 14 | classes="pt-1", 15 | ) 16 | 17 | with vuetify.VRow(classes="pt-2", dense=True): 18 | with vuetify.VCol(cols="12"): 19 | vuetify.VFileInput( 20 | v_model=("morpho_uploaded_target_anndata_path", None), 21 | label="Upload Target AnnData (.h5ad)", 22 | dense=True, 23 | classes="pt-1", 24 | accept=".h5ad", 25 | hide_details=True, 26 | show_size=True, 27 | small_chips=True, 28 | truncate_length=0, 29 | outlined=True, 30 | __properties=["accept"], 31 | ) 32 | with vuetify.VRow(classes="pt-2", dense=True): 33 | avaliable_samples = [ 34 | "uploaded_target_anndata", 35 | ] 36 | with vuetify.VCol(cols="12"): 37 | vuetify.VSelect( 38 | label="Target AnnData", 39 | v_model=("morpho_target_anndata_path", None), 40 | items=(avaliable_samples,), 41 | 
dense=True, 42 | outlined=True, 43 | hide_details=True, 44 | classes="pt-1", 45 | ) 46 | 47 | with vuetify.VRow(classes="pt-2", dense=True): 48 | with vuetify.VCol(cols="6"): 49 | vuetify.VSelect( 50 | v_model=("morpho_mapping_method", "GP"), 51 | items=(["OT", "GP"],), 52 | label="Mapping Method", 53 | hide_details=True, 54 | dense=True, 55 | outlined=True, 56 | classes="pt-1", 57 | ) 58 | with vuetify.VCol(cols="6"): 59 | vuetify.VTextField( 60 | v_model=("morpho_mapping_device", "cpu"), 61 | label="Mapping Device", 62 | hide_details=True, 63 | dense=True, 64 | outlined=True, 65 | classes="pt-1", 66 | ) 67 | 68 | with vuetify.VRow(classes="pt-2", dense=True): 69 | with vuetify.VCol(cols="6"): 70 | vuetify.VTextField( 71 | v_model=("morpho_mapping_factor", 0.2), 72 | label="Mapping Factor", 73 | hide_details=True, 74 | dense=True, 75 | outlined=True, 76 | classes="pt-1", 77 | ) 78 | with vuetify.VCol(cols="6"): 79 | vuetify.VTextField( 80 | v_model=("morphofield_factor", 3000), 81 | label="Morphofield Factor", 82 | hide_details=True, 83 | dense=True, 84 | outlined=True, 85 | classes="pt-1", 86 | ) 87 | with vuetify.VRow(classes="pt-2", dense=True): 88 | with vuetify.VCol(cols="6"): 89 | vuetify.VTextField( 90 | v_model=("morphopath_t_end", 10000), 91 | label="Morphopath Length", 92 | hide_details=True, 93 | dense=True, 94 | outlined=True, 95 | classes="pt-1", 96 | ) 97 | with vuetify.VCol(cols="6"): 98 | vuetify.VTextField( 99 | v_model=("morphopath_downsampling", 500), 100 | label="Morphopath Sampling", 101 | hide_details=True, 102 | dense=True, 103 | outlined=True, 104 | classes="pt-1", 105 | ) 106 | with vuetify.VRow(classes="pt-2", dense=True): 107 | with vuetify.VCol(cols="6"): 108 | vuetify.VCheckbox( 109 | v_model=("morphofield_visibile", False), 110 | label="Morphofield Visibility", 111 | on_icon="mdi-pyramid", 112 | off_icon="mdi-pyramid-off", 113 | dense=True, 114 | hide_details=True, 115 | classes="pt-1", 116 | ) 117 | with vuetify.VCol(cols="6"): 118 | 
vuetify.VCheckbox( 119 | v_model=("morphopath_visibile", False), 120 | label="Morphopath Visibility", 121 | on_icon="mdi-octahedron", 122 | off_icon="mdi-octahedron-off", 123 | dense=True, 124 | hide_details=True, 125 | classes="pt-1", 126 | ) 127 | 128 | with vuetify.VRow(classes="pt-2", dense=True): 129 | with vuetify.VCol(cols="12"): 130 | vuetify.VTextField( 131 | v_model=("morphopath_animation_path", None), 132 | label="Morphogenesis Animation Output (MP4)", 133 | hide_details=True, 134 | dense=True, 135 | outlined=True, 136 | classes="pt-1", 137 | ) 138 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/drawer/output.py: -------------------------------------------------------------------------------- 1 | from trame.widgets import vuetify 2 | 3 | 4 | def output_screenshot_content(): 5 | with vuetify.VRow(classes="pt-2", dense=True): 6 | with vuetify.VCol(cols="12"): 7 | vuetify.VTextField( 8 | v_model=("screenshot_path", None), 9 | label="Screenshot Output (PNG or PDF)", 10 | hide_details=True, 11 | dense=True, 12 | outlined=True, 13 | classes="pt-1", 14 | ) 15 | 16 | 17 | def output_animation_content(): 18 | with vuetify.VRow(classes="pt-2", dense=True): 19 | with vuetify.VCol(cols="6"): 20 | vuetify.VTextField( 21 | v_model=("animation_npoints", 50), 22 | label="Animation N Points", 23 | hide_details=True, 24 | dense=True, 25 | outlined=True, 26 | classes="pt-1", 27 | ) 28 | with vuetify.VCol(cols="6"): 29 | vuetify.VTextField( 30 | v_model=("animation_framerate", 10), 31 | label="Animation Framerate", 32 | hide_details=True, 33 | dense=True, 34 | outlined=True, 35 | classes="pt-1", 36 | ) 37 | 38 | with vuetify.VRow(classes="pt-2", dense=True): 39 | with vuetify.VCol(cols="12"): 40 | vuetify.VTextField( 41 | v_model=("animation_path", None), 42 | label="Animation Output (MP4)", 43 | hide_details=True, 44 | dense=True, 45 | outlined=True, 46 | classes="pt-1", 47 | ) 48 | 49 | 50 | def output_panel(): 51 | 
with vuetify.VToolbar( 52 | dense=True, outlined=True, classes="pa-0 ma-0", style="flex: none;" 53 | ): 54 | vuetify.VIcon("mdi-download") 55 | vuetify.VCardTitle( 56 | " Output Widgets", 57 | classes="pa-0 ma-0", 58 | style="flex: none;", 59 | hide_details=True, 60 | dense=True, 61 | ) 62 | 63 | vuetify.VSpacer() 64 | with vuetify.VBtn( 65 | small=True, 66 | icon=True, 67 | click="show_output_card = !show_output_card", 68 | ): 69 | vuetify.VIcon("mdi-unfold-less-horizontal", v_if=("show_output_card",)) 70 | vuetify.VIcon("mdi-unfold-more-horizontal", v_if=("!show_output_card",)) 71 | 72 | # Main content 73 | with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"): 74 | with vuetify.VCardText(classes="py-2", v_if=("show_output_card",)): 75 | items = ["Screenshot", "Animation"] 76 | with vuetify.VTabs(v_model=("output_active_tab", 0), left=True): 77 | for item in items: 78 | vuetify.VTab( 79 | item, 80 | style="width: 50%;", 81 | ) 82 | with vuetify.VTabsItems( 83 | value=("output_active_tab",), 84 | style="width: 100%; height: 100%;", 85 | ): 86 | with vuetify.VTabItem(value=(0,)): 87 | output_screenshot_content() 88 | with vuetify.VTabItem(value=(1,)): 89 | output_animation_content() 90 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/drawer/pipeline.py: -------------------------------------------------------------------------------- 1 | from trame.widgets import html, trame, vuetify 2 | 3 | from stviewer.Explorer.pv_pipeline.init_parameters import ( 4 | init_mesh_parameters, 5 | init_morphogenesis_parameters, 6 | init_pc_parameters, 7 | ) 8 | 9 | 10 | def pipeline_content(server, plotter): 11 | # server 12 | state, ctrl = server.state, server.controller 13 | 14 | @ctrl.set("actives_change") 15 | def actives_change(ids): 16 | _id = ids[0] 17 | active_actor_id = state.actor_ids[int(_id) - 1] 18 | state.active_ui = active_actor_id 19 | state.active_model_type = str(state.active_ui).split("_")[0] 20 | 
state.active_id = int(_id) 21 | 22 | if state.active_ui.startswith("PC"): 23 | state.update(init_pc_parameters) 24 | state.update(init_morphogenesis_parameters) 25 | elif state.active_ui.startswith("Mesh"): 26 | state.update(init_mesh_parameters) 27 | ctrl.view_update() 28 | 29 | @ctrl.set("visibility_change") 30 | def visibility_change(event): 31 | _id = event["id"] 32 | _visibility = event["visible"] 33 | active_actor = [value for value in plotter.actors.values()][int(_id) - 1] 34 | active_actor.SetVisibility(_visibility) 35 | if _visibility is True: 36 | state.vis_ids.append(int(_id) - 1) 37 | else: 38 | state.vis_ids.remove(int(_id) - 1) 39 | state.vis_ids = list(set(state.vis_ids)) 40 | ctrl.view_update() 41 | 42 | # main content 43 | trame.GitTree( 44 | sources=("pipeline",), 45 | actives_change=(ctrl.actives_change, "[$event]"), 46 | visibility_change=(ctrl.visibility_change, "[$event]"), 47 | ) 48 | 49 | 50 | def pipeline_panel(server, plotter): 51 | # Logo and title 52 | with vuetify.VToolbar( 53 | dense=True, outlined=True, classes="pa-0 ma-0", style="flex: none;" 54 | ): 55 | # Logo and title 56 | vuetify.VIcon("mdi-source-branch", style="transform: scale(1, -1);") 57 | vuetify.VCardTitle( 58 | " Pipeline", 59 | classes="pa-0 ma-0", 60 | style="flex: none;", 61 | hide_details=True, 62 | dense=True, 63 | ) 64 | 65 | # Main content 66 | pipeline_content(server=server, plotter=plotter) 67 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/layout.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | # ----------------------------------------------------------------------------- 4 | # GUI layout 5 | # ----------------------------------------------------------------------------- 6 | 7 | 8 | def ui_layout( 9 | server, template_name: str = "main", drawer_width: Optional[int] = None, **kwargs 10 | ): 11 | """ 12 | Define the user interface 
(UI) layout. 13 | Reference: https://trame.readthedocs.io/en/latest/trame.ui.vuetify.html#trame.ui.vuetify.SinglePageWithDrawerLayout 14 | 15 | Args: 16 | server: Server to bind the layout to. 17 | template_name: Name of the template. 18 | drawer_width: Drawer width in pixels. Defaults to 15% of the screen width. 19 | 20 | Returns: 21 | The SinglePageWithDrawerLayout layout object. 22 | """ 23 | from trame.ui.vuetify import SinglePageWithDrawerLayout 24 | 25 | if drawer_width is None: 26 | import pyautogui 27 | 28 | screen_width, _ = pyautogui.size() 29 | drawer_width = int(screen_width * 0.15) 30 | 31 | return SinglePageWithDrawerLayout( 32 | server, template_name=template_name, width=drawer_width, **kwargs 33 | ) 34 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/toolbar.py: -------------------------------------------------------------------------------- 1 | try: 2 | from typing import Literal 3 | except ImportError: 4 | from typing_extensions import Literal 5 | 6 | from typing import Optional 7 | 8 | from pyvista import BasePlotter 9 | from trame.widgets import html, vuetify 10 | 11 | from stviewer.assets import icon_manager, local_dataset_manager 12 | from stviewer.Explorer.pv_pipeline import SwitchModels, Viewer 13 | 14 | from .utils import button, checkbox 15 | 16 | # ----------------------------------------------------------------------------- 17 | # GUI- UI title 18 | # ----------------------------------------------------------------------------- 19 | 20 | 21 | def ui_title( 22 | layout, title_name="SPATEO VIEWER", title_icon: Optional[str] = None, **kwargs 23 | ): 24 | """ 25 | Define the title name and logo of the UI. 26 | Reference: https://trame.readthedocs.io/en/latest/trame.ui.vuetify.html#trame.ui.vuetify.SinglePageWithDrawerLayout 27 | 28 | Args: 29 | layout: The layout object. 30 | title_name: Title name of the GUI. 31 | title_icon: Title icon of the GUI.
32 | **kwargs: Additional parameters that will be passed to the ``html.Img`` function. 33 | 34 | Returns: 35 | None. 36 | """ 37 | 38 | # Update the toolbar's name 39 | layout.title.set_text(title_name) 40 | layout.title.style = ( 41 | "font-family:arial; font-size:25px; font-weight: 550; color: gray;" 42 | ) 43 | 44 | # Update the toolbar's icon 45 | if title_icon is not None: 46 | with layout.icon as icon: 47 | icon.style = "margin-left: 10px;" # "width: 7vw; height: 7vh;" 48 | html.Img(src=title_icon, height=40, **kwargs) 49 | 50 | 51 | # ----------------------------------------------------------------------------- 52 | # GUI- standard ToolBar 53 | # ----------------------------------------------------------------------------- 54 | 55 | 56 | def toolbar_widgets( 57 | server, 58 | plotter: BasePlotter, 59 | mode: Literal["trame", "server", "client"] = "trame", 60 | default_server_rendering: bool = True, 61 | ): 62 | """ 63 | Generate standard widgets for ToolBar. 64 | 65 | Args: 66 | server: The trame server. 67 | plotter: The PyVista plotter to connect with the UI. 68 | mode: The UI view mode. Options are: 69 | 70 | * ``'trame'``: Uses a view that can switch between client and server rendering modes. 71 | * ``'server'``: Uses a view that is purely server rendering. 72 | * ``'client'``: Uses a view that is purely client rendering (generally safe without a virtual frame buffer) 73 | default_server_rendering: Whether to use server-side or client-side rendering on start when using the ``'trame'`` mode.
74 | """ 75 | if mode != "trame": 76 | default_server_rendering = mode == "server" 77 | 78 | viewer = Viewer(plotter=plotter, server=server, suppress_rendering=mode == "client") 79 | 80 | vuetify.VSpacer() 81 | # Whether to toggle the theme between light and dark 82 | checkbox( 83 | model="$vuetify.theme.dark", 84 | icons=("mdi-lightbulb-off-outline", "mdi-lightbulb-outline"), 85 | tooltip=f"Toggle theme", 86 | ) 87 | # Whether to toggle the background color between white and black 88 | checkbox( 89 | model=(viewer.BACKGROUND, False), 90 | icons=("mdi-palette-swatch-outline", "mdi-palette-swatch"), 91 | tooltip=f"Toggle background ({{{{ {viewer.BACKGROUND} ? 'white' : 'black' }}}})", 92 | ) 93 | # Server rendering options 94 | if mode == "trame": 95 | checkbox( 96 | model=(viewer.SERVER_RENDERING, default_server_rendering), 97 | icons=("mdi-lan-connect", "mdi-lan-disconnect"), 98 | tooltip=f"Toggle rendering mode ({{{{ {viewer.SERVER_RENDERING} ? 'remote' : 'local' }}}})", 99 | ) 100 | # Whether to visualize the memory usage 101 | checkbox( 102 | model=(viewer.MEMORY_USAGE, False), 103 | icons=("mdi-memory", "mdi-memory"), 104 | tooltip=f"Toggle memory usage ({{{{ {viewer.MEMORY_USAGE} ? 'on' : 'off' }}}})", 105 | ) 106 | 107 | vuetify.VDivider(vertical=True, classes="mx-1") 108 | # Whether to show the main model 109 | checkbox( 110 | model=(viewer.SHOW_MAIN_MODEL, True), 111 | icons=("mdi-eye-outline", "mdi-eye-off-outline"), 112 | tooltip=f"Toggle main model visibility ({{{{ {viewer.SHOW_MAIN_MODEL} ? 'True' : 'False' }}}})", 113 | ) 114 | # Whether to add outline 115 | checkbox( 116 | model=(viewer.OUTLINE, False), 117 | icons=("mdi-cube", "mdi-cube-off"), 118 | tooltip=f"Toggle bounding box ({{{{ {viewer.OUTLINE} ? 'on' : 'off' }}}})", 119 | ) 120 | # Whether to add grid 121 | checkbox( 122 | model=(viewer.GRID, False), 123 | icons=("mdi-ruler-square", "mdi-ruler-square"), 124 | tooltip=f"Toggle ruler ({{{{ {viewer.GRID} ? 
'on' : 'off' }}}})", 125 | ) 126 | # Whether to add axis legend 127 | checkbox( 128 | model=(viewer.AXIS, False), 129 | icons=("mdi-axis-arrow-info", "mdi-axis-arrow-info"), 130 | tooltip=f"Toggle axis ({{{{ {viewer.AXIS} ? 'on' : 'off' }}}})", 131 | ) 132 | 133 | # Reset camera 134 | vuetify.VDivider(vertical=True, classes="mx-1") 135 | button( 136 | click=viewer.view_isometric, 137 | icon="mdi-axis-arrow", 138 | tooltip="Perspective view", 139 | ) 140 | button( 141 | click=viewer.view_yz, 142 | icon="mdi-axis-x-arrow", 143 | tooltip="Reset camera X", 144 | ) 145 | button( 146 | click=viewer.view_xz, 147 | icon="mdi-axis-y-arrow", 148 | tooltip="Reset camera Y", 149 | ) 150 | button( 151 | click=viewer.view_xy, 152 | icon="mdi-axis-z-arrow", 153 | tooltip="Reset camera Z", 154 | ) 155 | 156 | 157 | def toolbar_switch_model( 158 | server, 159 | plotter: BasePlotter, 160 | ): 161 | available_samples = [ 162 | key 163 | for key in local_dataset_manager.get_assets().keys() 164 | if not str(key).endswith("anndata") 165 | ] 166 | available_samples.append("uploaded_sample") 167 | 168 | vuetify.VSpacer() 169 | SM = SwitchModels(server=server, plotter=plotter) 170 | vuetify.VSelect( 171 | label="Select Samples", 172 | v_model=(SM.SELECT_SAMPLES, None), 173 | items=("samples", available_samples), 174 | dense=True, 175 | outlined=True, 176 | hide_details=True, 177 | classes="ml-8", 178 | prepend_inner_icon="mdi-magnify", 179 | style="max-width: 300px;", 180 | rounded=True, 181 | ) 182 | # Select local directory 183 | button( 184 | # Bind the click to the server-side ``open_directory`` controller callback 185 | click=server.controller.open_directory, 186 | icon="mdi-file-document-outline", 187 | tooltip="Select directory", 188 | ) 189 | 190 | 191 | def ui_toolbar( 192 | server, 193 | layout, 194 | plotter: BasePlotter, 195 | mode: Literal["trame", "server", "client"] = "trame", 196 | default_server_rendering: bool = True, 197 | ui_name: str = "SPATEO VIEWER", 198 | ui_icon=icon_manager.spateo_logo, 199 | ):
200 | """ 201 | Generate standard ToolBar for Spateo UI. 202 | 203 | Args: 204 | server: The trame server. 205 | layout: The layout object. 206 | plotter: The PyVista plotter to connect with the UI. 207 | mode: The UI view mode. Options are: 208 | 209 | * ``'trame'``: Uses a view that can switch between client and server rendering modes. 210 | * ``'server'``: Uses a view that is purely server rendering. 211 | * ``'client'``: Uses a view that is purely client rendering (generally safe without a virtual frame buffer) 212 | default_server_rendering: Whether to use server-side or client-side rendering on-start when using the ``'trame'`` mode. 213 | ui_name: Title name of the GUI. 214 | ui_icon: Title icon of the GUI. 215 | """ 216 | 217 | # ----------------------------------------------------------------------------- 218 | # Title 219 | # ----------------------------------------------------------------------------- 220 | ui_title(layout=layout, title_name=ui_name, title_icon=ui_icon) 221 | 222 | # ----------------------------------------------------------------------------- 223 | # ToolBar 224 | # ----------------------------------------------------------------------------- 225 | with layout.toolbar as tb: 226 | tb.height = 55 227 | tb.dense = True 228 | tb.clipped_right = True 229 | 230 | toolbar_switch_model(server=server, plotter=plotter) 231 | toolbar_widgets( 232 | server=server, 233 | plotter=plotter, 234 | mode=mode, 235 | default_server_rendering=default_server_rendering, 236 | ) 237 | -------------------------------------------------------------------------------- /stviewer/Explorer/ui/utils.py: -------------------------------------------------------------------------------- 1 | from trame.widgets import html, vuetify 2 | 3 | # ----------------------------------------------------------------------------- 4 | # vuetify components 5 | # ----------------------------------------------------------------------------- 6 | 7 | 8 | def button(click, icon, tooltip): 9 | 
"""Create a vuetify button.""" 10 | with vuetify.VTooltip(bottom=True): 11 | with vuetify.Template(v_slot_activator="{ on, attrs }"): 12 | with vuetify.VBtn(icon=True, v_bind="attrs", v_on="on", click=click): 13 | vuetify.VIcon(icon) 14 | html.Span(tooltip) 15 | 16 | 17 | def checkbox(model, icons, tooltip, **kwargs): 18 | """Create a vuetify checkbox.""" 19 | with vuetify.VTooltip(bottom=True): 20 | with vuetify.Template(v_slot_activator="{ on, attrs }"): 21 | with html.Div(v_on="on", v_bind="attrs"): 22 | vuetify.VCheckbox( 23 | v_model=model, 24 | on_icon=icons[0], 25 | off_icon=icons[1], 26 | dense=True, 27 | hide_details=True, 28 | classes="my-0 py-0 ml-1", 29 | **kwargs 30 | ) 31 | html.Span(tooltip) 32 | 33 | 34 | def switch(model, tooltip, **kwargs): 35 | """Create a vuetify switch.""" 36 | with vuetify.VTooltip(bottom=True): 37 | with vuetify.Template(v_slot_activator="{ on, attrs }"): 38 | vuetify.VSwitch(v_model=model, hide_details=True, dense=True, **kwargs) 39 | html.Span(tooltip) 40 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/__init__.py: -------------------------------------------------------------------------------- 1 | from .pv_pipeline import * 2 | from .ui import ui_container, ui_drawer, ui_layout, ui_toolbar 3 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/pv_pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .init_parameters import * 2 | from .pv_callback import Viewer 3 | from .pv_models import init_models 4 | from .pv_plotter import add_single_model, create_plotter 5 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/pv_pipeline/init_parameters.py: -------------------------------------------------------------------------------- 1 | import pyautogui 2 | 3 | # Init parameters 4 | init_active_parameters = { 5 
| "picking_group": None, 6 | "overwrite": False, 7 | "activeModelVisible": True, 8 | "activeModel_output": None, 9 | "anndata_output": None, 10 | } 11 | init_align_parameters = { 12 | "slices_alignment": False, 13 | "slices_key": "slices", 14 | "slices_align_device": "CPU", 15 | "slices_align_method": "Paste", 16 | "slices_align_factor": 0.1, 17 | "slices_align_max_iter": 200, 18 | } 19 | init_mesh_parameters = { 20 | "meshModel": None, 21 | "meshModelVisible": False, 22 | "reconstruct_mesh": False, 23 | "mc_factor": 1.0, 24 | "mesh_voronoi": 20000, 25 | "mesh_smooth_factor": 2000, 26 | "mesh_scale_factor": 1.0, 27 | "clip_pc_with_mesh": False, 28 | "mesh_output": None, 29 | } 30 | init_picking_parameters = { 31 | "modes": [ 32 | {"value": "hover", "icon": "mdi-magnify"}, 33 | {"value": "click", "icon": "mdi-cursor-default-click-outline"}, 34 | {"value": "select", "icon": "mdi-select-drag"}, 35 | ], 36 | "pickData": None, 37 | "selectData": None, 38 | "resetModel": False, 39 | "tooltip": "", 40 | } 41 | init_setting_parameters = { 42 | "show_active_card": True, 43 | "show_align_card": True, 44 | "show_mesh_card": True, 45 | "background_color": "[0, 0, 0]", 46 | "pixel_ratio": pyautogui.size()[0] / 500, 47 | } 48 | 49 | # custom init parameters 50 | init_custom_parameters = { 51 | "custom_func": False, 52 | "custom_analysis": False, 53 | "custom_model": None, 54 | "custom_model_visible": False, 55 | "custom_model_size": pyautogui.size()[0] / 200, 56 | "custom_model_output": None, 57 | "show_custom_card": True, 58 | "custom_parameter1": "ElPiGraph", 59 | "custom_parameter2": 50, 60 | } 61 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/pv_pipeline/pv_custom.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | try: 4 | from typing import Literal 5 | except ImportError: 6 | from typing_extensions import Literal 7 | 8 | from
typing import Optional, Tuple, Union 9 | 10 | import numpy as np 11 | from pyvista import PolyData, UnstructuredGrid 12 | 13 | try: 14 | from typing import Literal 15 | except ImportError: 16 | from typing_extensions import Literal 17 | 18 | ##################################################################### 19 | # Principal curves algorithm # 20 | # ================================================================= # 21 | # Original Code Repository Author: Matthew Artuso. # 22 | # Adapted to Spateo by: spateo authors # 23 | # Created Date: 6/11/2022 # 24 | # Description: A principal curve is a smooth n-dimensional curve # 25 | # that passes through the middle of a dataset. # 26 | # Reference: https://doi.org/10.1016/j.cam.2015.11.041 # 27 | # ================================================================= # 28 | ##################################################################### 29 | 30 | 31 | class NLPCA(object): 32 | """This is a global solver for principal curves that uses neural networks. 33 | Attributes: 34 | None 35 | """ 36 | 37 | def __init__(self): 38 | self.fit_points = None 39 | self.model = None 40 | self.intermediate_layer_model = None 41 | 42 | def fit( 43 | self, 44 | data: np.ndarray, 45 | epochs: int = 500, 46 | nodes: int = 25, 47 | lr: float = 0.01, 48 | verbose: int = 0, 49 | ): 50 | """ 51 | This method creates a model and will fit it to the given m x n dimensional data. 52 | 53 | Args: 54 | data: A numpy array of shape (m,n), where m is the number of points and n is the number of dimensions. 55 | epochs: Number of epochs to train neural network, defaults to 500. 56 | nodes: Number of nodes for the construction layers. Defaults to 25. The more complex the curve, the higher 57 | this number should be. 58 | lr: Learning rate for backprop. Defaults to .01 59 | verbose: Verbose = 0 mutes the training text from Keras. Defaults to 0. 
60 | """ 61 | try: 62 | from keras.models import Model 63 | except ImportError: 64 | raise ImportError( 65 | "You need to install the package `tensorflow`." 66 | "\nInstall tensorflow via `pip install -U tensorflow`." 67 | ) 68 | num_dim = data.shape[1] # get number of dimensions for pts 69 | 70 | # create models, base and intermediate 71 | model = self.create_model(num_dim, nodes=nodes, lr=lr) 72 | bname = model.layers[2].name # bottle-neck layer name 73 | 74 | # The intermediate model gets the output of the bottleneck layer, 75 | # which acts as the projection layer. 76 | self.intermediate_layer_model = Model( 77 | inputs=model.input, outputs=model.get_layer(bname).output 78 | ) 79 | 80 | # Fit the model and store it on self.model 81 | model.fit(data, data, epochs=epochs, verbose=verbose) 82 | self.model = model 83 | 84 | return 85 | 86 | def project(self, data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 87 | """ 88 | Project the points onto the curve generated by the fit function. Returns the projection index of each 89 | point in the original data and the projected points sorted by that index. 90 | 91 | Args: 92 | data: m x n array to project onto the curve. 93 | 94 | Returns: 95 | proj: An m x 1 array that contains the projection index for each point in data. 96 | all_sorted: An m x (n+1) array that contains the projected points sorted by their projection index, with the index as the last column. 97 | """ 98 | num_dim = data.shape[1] # get number of dimensions for pts 99 | 100 | pts = self.model.predict(data) 101 | proj = self.intermediate_layer_model.predict(data) 102 | self.fit_points = pts 103 | 104 | all_pts = np.concatenate([pts, proj], axis=1) 105 | all_sorted = all_pts[all_pts[:, num_dim].argsort()] 106 | 107 | return proj, all_sorted 108 | 109 | def create_model(self, num_dim: int, nodes: int, lr: float): 110 | """ 111 | Creates a tf model.
112 | 113 | Args: 114 | num_dim: Number of dimensions of the input space. 115 | nodes: Number of nodes in the mapping and demapping layers. 116 | lr: Learning rate for backpropagation. 117 | 118 | Returns: 119 | model (object): Keras Model 120 | """ 121 | # Create layers: 122 | # Function G 123 | try: 124 | import tensorflow as tf 125 | from keras import optimizers 126 | from keras.layers import Dense, Input 127 | from keras.models import Model 128 | except ImportError: 129 | raise ImportError( 130 | "You need to install the package `tensorflow`." 131 | "\nInstall tensorflow via `pip install -U tensorflow`." 132 | ) 133 | 134 | def orth_dist(y_true, y_pred): 135 | """ 136 | Loss function for the NLPCA network. Returns the sum of squared 137 | distances between the output tensor and the target tensor. 138 | """ 139 | return tf.math.reduce_sum((y_true - y_pred) ** 2) 140 | 141 | input = Input(shape=(num_dim,)) # input layer 142 | mapping = Dense(nodes, activation="sigmoid")(input) # mapping layer 143 | bottle = Dense(1, activation="sigmoid")(mapping) # bottle-neck layer 144 | 145 | # Function H 146 | demapping = Dense(nodes, activation="sigmoid")(bottle) # demapping layer 147 | output = Dense(num_dim)(demapping) # output layer 148 | 149 | # Connect and compile model: 150 | model = Model(inputs=input, outputs=output) 151 | gradient_descent = optimizers.Adam(learning_rate=lr) 152 | model.compile(loss=orth_dist, optimizer=gradient_descent) 153 | 154 | return model 155 | 156 | 157 | ########### 158 | # Methods # 159 | ########### 160 | 161 | 162 | def _euclidean_distance(N1, N2): 163 | temp = np.asarray(N1) - np.asarray(N2) 164 | euclid_dist = np.sqrt(np.dot(temp.T, temp)) 165 | return euclid_dist 166 | 167 | 168 | def sort_nodes_of_curve(nodes, started_node): 169 | current_node = tuple(started_node) 170 | remaining_nodes = [tuple(node) for node in nodes] 171 | 172 | sorted_nodes = [] 173 | while remaining_nodes: 174 | closest_node = min( 175 | remaining_nodes, key=lambda x:
_euclidean_distance(current_node, x) 176 | ) 177 | sorted_nodes.append(closest_node) 178 | remaining_nodes.remove(closest_node) 179 | current_node = closest_node 180 | sorted_nodes = np.asarray([list(sn) for sn in sorted_nodes]) 181 | return sorted_nodes 182 | 183 | 184 | def ElPiGraph_method( 185 | X: np.ndarray, 186 | NumNodes: int = 50, 187 | topology: Literal["tree", "circle", "curve"] = "curve", 188 | Lambda: float = 0.01, 189 | Mu: float = 0.1, 190 | alpha: float = 0.0, 191 | FinalEnergy: Literal["Base", "Penalized"] = "Penalized", 192 | **kwargs, 193 | ) -> Tuple[np.ndarray, np.ndarray]: 194 | """ 195 | Generate a principal elastic graph (tree, circle or curve). 196 | Reference: Albergante et al. (2020), Robust and Scalable Learning of Complex Intrinsic Dataset Geometry via ElPiGraph. 197 | 198 | Args: 199 | X: An N x D data matrix (one row per point). 200 | NumNodes: The number of nodes of the principal graph. Values between 10 and 100 are recommended for ElPiGraph. 201 | topology: The topology of the principal graph to fit: ``'tree'``, ``'circle'`` or ``'curve'``. 202 | Lambda: The attractive strength of edges between nodes (constrains edge lengths) 203 | Mu: The repulsive strength of a node’s neighboring nodes (constrains angles to be close to harmonic) 204 | alpha: Branching penalty (penalizes number of branches for the principal tree) 205 | FinalEnergy: The final elastic energy associated with the configuration; either “Base” or “Penalized”. 206 | **kwargs: Other parameters passed to elpigraph.computeElasticPrincipalTree, computeElasticPrincipalCircle or computeElasticPrincipalCurve (depending on ``topology``). For details, please see: 207 | https://elpigraph-python.readthedocs.io/en/latest/basics.html 208 | 209 | Returns: 210 | nodes: The nodes of the principal graph. 211 | edges: The edges between nodes of the principal graph. 212 | """ 213 | try: 214 | import elpigraph 215 | except ImportError: 216 | raise ImportError( 217 | "You need to install the package `elpigraph-python`." 218 | "\nInstall elpigraph-python via `pip install elpigraph-python`."
219 | ) 220 | 221 | ElPiGraph_kwargs = { 222 | "NumNodes": NumNodes, 223 | "Lambda": Lambda, 224 | "Mu": Mu, 225 | "alpha": alpha, 226 | "FinalEnergy": FinalEnergy, 227 | } 228 | ElPiGraph_kwargs.update(kwargs) 229 | 230 | if str(topology).lower() == "tree": 231 | elpi_tree = elpigraph.computeElasticPrincipalTree( 232 | X=np.asarray(X), **ElPiGraph_kwargs 233 | ) 234 | elif str(topology).lower() == "circle": 235 | elpi_tree = elpigraph.computeElasticPrincipalCircle( 236 | X=np.asarray(X), **ElPiGraph_kwargs 237 | ) 238 | elif str(topology).lower() == "curve": 239 | elpi_tree = elpigraph.computeElasticPrincipalCurve( 240 | X=np.asarray(X), **ElPiGraph_kwargs 241 | ) 242 | else: 243 | raise ValueError( 244 | "`topology` value is wrong." 245 | "\nAvailable `topology` are: `'tree'`, `'circle'`, `'curve'`." 246 | ) 247 | 248 | nodes = elpi_tree[0]["NodePositions"] 249 | edges = np.asarray(elpi_tree[0]["Edges"][0]) 250 | 251 | if str(topology).lower() in ["curve", "circle"]: 252 | unique_values, occurrence_count = np.unique(edges.flatten(), return_counts=True) 253 | started_node_indices = [ 254 | v for c, v in zip(occurrence_count, unique_values) if c == 1 255 | ] 256 | started_node = ( 257 | nodes[started_node_indices[0]] 258 | if len(started_node_indices) != 0 259 | else nodes[0] 260 | ) 261 | 262 | nodes = sort_nodes_of_curve(nodes, started_node) 263 | if str(topology).lower() == "curve": 264 | edges = np.c_[ 265 | np.arange(0, len(nodes) - 1, 1).reshape(-1, 1), 266 | np.arange(1, len(nodes), 1).reshape(-1, 1), 267 | ] 268 | else: 269 | edges = np.c_[ 270 | np.arange(0, len(nodes), 1).reshape(-1, 1), 271 | np.asarray(list(range(1, len(nodes))) + [0]).reshape(-1, 1), 272 | ] 273 | 274 | return nodes, edges 275 | 276 | 277 | def SimplePPT_method( 278 | X: np.ndarray, 279 | NumNodes: int = 50, 280 | sigma: Optional[Union[float, int]] = 0.1, 281 | lam: Optional[Union[float, int]] = 1, 282 | metric: str = "euclidean", 283 | nsteps: int = 50, 284 | err_cut: float = 5e-3, 
285 | seed: Optional[int] = 1, 286 | **kwargs, 287 | ) -> Tuple[np.ndarray, np.ndarray]: 288 | """ 289 | Generate a simple principal tree. 290 | Reference: Mao et al. (2015), SimplePPT: A simple principal tree algorithm, SIAM International Conference on Data Mining. 291 | 292 | Args: 293 | X: An N x D data matrix (one row per point). 294 | NumNodes: The number of nodes of the principal graph. Values between 100 and 2000 are recommended for SimplePPT. 295 | sigma: Regularization parameter. 296 | lam: Penalty for the tree length. 297 | metric: The metric to use to compute distances in high dimensional space. For compatible metrics, check the 298 | documentation of sklearn.metrics.pairwise_distances. 299 | nsteps: Number of steps for the optimisation process. 300 | err_cut: Stop the algorithm when the change in the principal points between iterations is less than this value. 301 | seed: A numpy random seed. 302 | **kwargs: Other parameters used in simpleppt.ppt. For details, please see: 303 | https://github.com/LouisFaure/simpleppt/blob/main/simpleppt/ppt.py 304 | 305 | Returns: 306 | nodes: The nodes in the principal tree. 307 | edges: The edges between nodes in the principal tree. 308 | """ 309 | try: 310 | import igraph 311 | import simpleppt 312 | except ImportError: 313 | raise ImportError( 314 | "You need to install the packages `simpleppt` and `igraph`." 315 | "\nInstall simpleppt via `pip install -U simpleppt`."
"\nInstall igraph via `pip install -U igraph`." 317 | ) 318 | 319 | SimplePPT_kwargs = { 320 | "seed": seed, 321 | "lam": lam, 322 | "sigma": sigma, 323 | "metric": metric, 324 | "nsteps": nsteps, 325 | "err_cut": err_cut, 326 | } 327 | SimplePPT_kwargs.update(kwargs) 328 | 329 | X = np.asarray(X) 330 | ppt_tree = simpleppt.ppt(X=X, Nodes=NumNodes, **SimplePPT_kwargs) 331 | 332 | R = ppt_tree.R 333 | nodes = (np.dot(X.T, R) / R.sum(axis=0)).T 334 | 335 | B = ppt_tree.B 336 | edges = np.array( 337 | igraph.Graph.Adjacency((B > 0).tolist(), mode="undirected").get_edgelist() 338 | ) 339 | 340 | return nodes, edges 341 | 342 | 343 | def PrinCurve_method( 344 | X: np.ndarray, 345 | NumNodes: int = 50, 346 | epochs: int = 500, 347 | lr: float = 0.01, 348 | scale_factor: Union[int, float] = 1, 349 | **kwargs, 350 | ) -> Tuple[np.ndarray, np.ndarray]: 351 | """ 352 | Fit a principal curve with the global nonlinear principal component analysis (NLPCA) solver, which optimizes 353 | a single curve over the entire dataset. 354 | Reference: Chen et al. (2016), Constraint local principal curve: Concept, algorithms and applications. 355 | 356 | Args: 357 | X: An N x D data matrix (one row per point). 358 | NumNodes: Number of nodes for the construction layers and number of nodes sampled from the fitted curve. Defaults to 50. The more complex the curve, the higher this number should be. 359 | epochs: Number of epochs to train the neural network. Defaults to 500. 360 | lr: Learning rate for backpropagation. Defaults to 0.01. 361 | scale_factor: The coordinates are divided by ``scale_factor`` before training, and the resulting nodes are multiplied back by it. Defaults to 1. 362 | **kwargs: Other parameters passed to ``NLPCA.fit``. The original implementation is at: 363 | https://github.com/artusoma/prinPy/blob/master/prinpy/glob.py 364 | 365 | Returns: 366 | nodes: The nodes of the principal curve. 367 | edges: The edges between consecutive nodes of the principal curve. 368 | """ 369 | try: 370 | from dynamo.tools.sampling import sample 371 | except ImportError: 372 | raise ImportError( 373 | "You need to install the package `dynamo`."
"\nInstall dynamo via `pip install -U dynamo-release`." 375 | ) 376 | PrinCurve_kwargs = { 377 | "epochs": epochs, 378 | "lr": lr, 379 | "verbose": 0, 380 | } 381 | PrinCurve_kwargs.update(kwargs) 382 | 383 | raw_X = np.asarray(X) 384 | dims = raw_X.shape[1] 385 | 386 | new_X = raw_X.copy() / scale_factor 387 | trans = [] 388 | for i in range(dims): 389 | sub_trans = new_X[:, i].min() 390 | new_X[:, i] = new_X[:, i] - sub_trans 391 | trans.append(sub_trans) 392 | # create solver 393 | pca_project = NLPCA() 394 | # (the coordinates were shifted above so that each axis starts at 0) 395 | # fit the data 396 | pca_project.fit(new_X, nodes=NumNodes, **PrinCurve_kwargs) 397 | # project the data. This returns a projection index for each point and the curve points sorted by that index. 398 | _, curve_pts = pca_project.project(new_X) 399 | curve_pts = np.unique(curve_pts, axis=0) 400 | curve_pts = np.einsum("ij->ij", curve_pts[curve_pts[:, -1].argsort(), :]) 401 | for i in range(dims): 402 | curve_pts[:, i] = curve_pts[:, i] + trans[i] 403 | 404 | nodes = curve_pts[:, :dims] * scale_factor 405 | sampling = sample( 406 | arr=np.asarray(range(nodes.shape[0])), n=NumNodes, method="trn", X=nodes 407 | ) 408 | sampling.sort() 409 | nodes = nodes[sampling, :] 410 | n_nodes = nodes.shape[0] 411 | edges = np.c_[np.arange(0, n_nodes - 1, 1), np.arange(1, n_nodes, 1)] 412 | # connect consecutive nodes (no self-loop on the last node) 413 | """ 414 | unique_values, occurrence_count = np.unique(edges.flatten(), return_counts=True) 415 | started_node_indices = [v for c, v in zip(occurrence_count, unique_values) if c == 1] 416 | started_node = nodes[started_node_indices[0]] if len(started_node_indices) != 0 else nodes[0] 417 | 418 | sorted_nodes = sort_nodes_of_curve(nodes, started_node) 419 | sorted_edges = np.c_[np.arange(0, len(nodes) - 1, 1).reshape(-1, 1), np.arange(1, len(nodes), 1).reshape(-1, 1)] 420 | """ 421 | return nodes, edges 422 | 423 | 424 | def construct_backbone( 425 | model:
Union[PolyData, UnstructuredGrid], 426 | spatial_key: Optional[str] = None, 427 | nodes_key: str = "nodes", 428 | rd_method: Literal["ElPiGraph", "SimplePPT", "PrinCurve"] = "ElPiGraph", 429 | num_nodes: int = 50, 430 | **kwargs, 431 | ) -> PolyData: 432 | """ 433 | Construct an organ's backbone from a 3D point cloud model. 434 | 435 | Args: 436 | model: A point cloud model. 437 | spatial_key: If spatial_key is None, the spatial coordinates are in model.points, otherwise in model[spatial_key]. 438 | nodes_key: The key in ``backbone_model.point_data`` under which the index of each backbone node is stored. 439 | rd_method: The method of constructing a backbone model. Available ``rd_method`` are: 440 | 441 | * ``'ElPiGraph'``: Generate a principal elastic graph (tree, circle or curve). 442 | * ``'SimplePPT'``: Generate a simple principal tree. 443 | * ``'PrinCurve'``: Fit a principal curve with the global nonlinear principal 444 | component analysis (NLPCA) solver, which optimizes a single curve over the entire dataset. 445 | num_nodes: Number of nodes for the backbone model. 446 | **kwargs: Additional parameters that will be passed to ``ElPiGraph_method``, ``SimplePPT_method`` or ``PrinCurve_method`` function. 447 | 448 | Returns: 449 | backbone_model: A 3D backbone model whose points are the nodes, connected in order by line segments. 450 | """ 451 | import pyvista as pv 452 | 453 | model = model.copy() 454 | X = model.points if spatial_key is None else model[spatial_key] 455 | 456 | if rd_method == "ElPiGraph": 457 | nodes, edges = ElPiGraph_method(X=X, NumNodes=num_nodes, **kwargs) 458 | elif rd_method == "SimplePPT": 459 | nodes, edges = SimplePPT_method(X=X, NumNodes=num_nodes, **kwargs) 460 | elif rd_method == "PrinCurve": 461 | nodes, edges = PrinCurve_method(X=X, NumNodes=num_nodes, **kwargs) 462 | else: 463 | raise ValueError( 464 | "`rd_method` value is wrong." 465 | "\nAvailable `rd_method` are: `'ElPiGraph'`, `'SimplePPT'`, `'PrinCurve'`."
466 | ) 467 | 468 | # Construct the backbone model 469 | # padding = np.array([2] * edges.shape[0], int) 470 | # edges_w_padding = np.vstack((padding, edges.T)).T 471 | # backbone_model = pv.PolyData(nodes, edges_w_padding) 472 | 473 | backbone_model = pv.MultipleLines(points=nodes) 474 | backbone_model.point_data[nodes_key] = np.arange(0, len(nodes), 1) 475 | return backbone_model 476 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/pv_pipeline/pv_models.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore") 4 | 5 | import anndata as ad 6 | import numpy as np 7 | 8 | from .pv_plotter import add_single_model 9 | 10 | 11 | def check_model_data(model, point_data: bool = True, cell_data: bool = True): 12 | # obtain the data of points 13 | pdd = {} 14 | if point_data: 15 | for name, array in model.point_data.items(): 16 | if name != "obs_index": 17 | array = np.asarray(array) 18 | if len(array.shape) == 1 and name not in [ 19 | "vtkOriginalPointIds", 20 | "SelectedPoints", 21 | "vtkInsidedness", 22 | ]: 23 | od = {"None": "None"} 24 | if not np.issubdtype(array.dtype, np.number): 25 | od = {o: i for i, o in enumerate(np.unique(array).tolist())} 26 | model.point_data[name] = np.asarray( 27 | list(map(lambda x: od[x], array)), dtype=float 28 | ) 29 | array = np.asarray(model.point_data[name]) 30 | pdd[name] = { 31 | "name": name, 32 | "range": [array.min(), array.max()], 33 | "value": name, 34 | "text": name, 35 | "scalarMode": 3, 36 | "raw_labels": od, 37 | } 38 | 39 | # obtain the data of cells 40 | cdd = {} 41 | if cell_data: 42 | for name, array in model.cell_data.items(): 43 | if name != "obs_index": 44 | array = np.asarray(array) 45 | if len(array.shape) == 1 and name not in [ 46 | "vtkOriginalCellIds", 47 | "orig_extract_id", 48 | "vtkInsidedness", 49 | ]: 50 | od = {"None": "None"} 51 | if not 
np.issubdtype(array.dtype, np.number): 52 | od = {o: i for i, o in enumerate(np.unique(array).tolist())} 53 | model.cell_data[name] = np.asarray( 54 | list(map(lambda x: od[x], array)), dtype=float 55 | ) 56 | array = np.asarray(model.cell_data[name]) 57 | cdd[name] = { 58 | "name": name, 59 | "range": [array.min(), array.max()], 60 | "value": name, 61 | "text": name, 62 | "scalarMode": 3, 63 | "raw_labels": od, 64 | } 65 | 66 | return model, pdd, cdd 67 | 68 | 69 | def init_models(plotter, anndata_path): 70 | # Generate init anndata object 71 | init_adata = ad.read_h5ad(anndata_path) 72 | init_adata.obs["Default"] = np.ones(shape=(init_adata.shape[0], 1)) 73 | init_adata.obsm["spatial"] = ( 74 | np.c_[init_adata.obsm["spatial"], np.ones(shape=(init_adata.shape[0], 1))] 75 | if init_adata.obsm["spatial"].shape[1] == 2 76 | else init_adata.obsm["spatial"] 77 | ) 78 | spatial_center = init_adata.obsm["spatial"].mean(axis=0) 79 | if tuple(spatial_center) != (0, 0, 0): 80 | init_adata.obsm["spatial"] = init_adata.obsm["spatial"] - spatial_center 81 | for key in init_adata.obs_keys(): 82 | if init_adata.obs[key].dtype == "category": 83 | init_adata.obs[key] = np.asarray(init_adata.obs[key], dtype=str) 84 | if np.issubdtype(init_adata.obs[key].dtype, np.number): 85 | init_adata.obs[key] = np.asarray(init_adata.obs[key], dtype=float) 86 | 87 | # Construct init pc model 88 | from .pv_tdr import construct_pc 89 | 90 | main_model = construct_pc(adata=init_adata, spatial_key="spatial") 91 | _obs_index = main_model.point_data["obs_index"] 92 | for key in init_adata.obs_keys(): 93 | main_model.point_data[key] = init_adata.obs.loc[_obs_index, key] 94 | 95 | main_model, pdd, cdd = check_model_data( 96 | model=main_model, point_data=True, cell_data=True 97 | ) 98 | _ = add_single_model(plotter=plotter, model=main_model, model_name="mainModel") 99 | 100 | # Generate active model 101 | active_model = main_model.copy() 102 | _ = add_single_model( 103 | plotter=plotter, 104 | 
model=active_model, 105 | model_name="activeModel", 106 | ) 107 | 108 | # Init parameters 109 | init_scalar = "Default" 110 | return main_model, active_model, init_scalar, pdd, cdd 111 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/pv_pipeline/pv_plotter.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import pyvista as pv 4 | from pyvista import Plotter, PolyData, UnstructuredGrid 5 | 6 | try: 7 | from typing import Literal 8 | except ImportError: 9 | from typing_extensions import Literal 10 | 11 | 12 | def create_plotter( 13 | window_size: tuple = (1024, 1024), background: str = "black", **kwargs 14 | ) -> Plotter: 15 | """ 16 | Create a plotting object to display pyvista/vtk models. 17 | 18 | Args: 19 | window_size: Window size in pixels. The default window_size is ``(1024, 1024)``. 20 | background: The background color of the window. 21 | 22 | Returns: 23 | plotter: The plotting object to display pyvista/vtk models. 24 | """ 25 | 26 | # Create an initial plotting object. 27 | plotter = pv.Plotter( 28 | window_size=window_size, off_screen=True, lighting="light_kit", **kwargs 29 | ) 30 | 31 | # Set the background color of the active render window. 32 | plotter.background_color = background 33 | return plotter 34 | 35 | 36 | def add_single_model( 37 | plotter: Plotter, 38 | model: Union[PolyData, UnstructuredGrid], 39 | key: Optional[str] = None, 40 | cmap: Optional[str] = "rainbow", 41 | color: Optional[str] = "gainsboro", 42 | ambient: float = 0.2, 43 | opacity: float = 1.0, 44 | model_style: Literal["points", "surface", "wireframe"] = "surface", 45 | model_size: float = 3.0, 46 | model_name: Optional[str] = None, 47 | ): 48 | """ 49 | Add a single model to the plotter. 50 | Args: 51 | plotter: The plotting object to display pyvista/vtk models. 52 | model: A reconstructed model. 53 | key: The name of the array (e.g. the labels) used to color the model; it is ignored if it is not in ``model.array_names``.
54 | cmap: Name of the Matplotlib colormap to use when mapping the model. 55 | color: Name of the Matplotlib color to use when mapping the model. 56 | ambient: When lighting is enabled, this is the amount of light in the range of 0 to 1 (default 0.2) that reaches 57 | the actor when not directed at the light source emitted from the viewer. 58 | opacity: Opacity of the model. 59 | If a single float is given, it is applied uniformly to the whole model; 60 | if a numpy.ndarray of floats is given, each value is 61 | the opacity of the corresponding point. Values should be between 0 and 1. 62 | A string can also be specified to map the scalars range to a predefined opacity transfer function 63 | (options include: 'linear', 'linear_r', 'geom', 'geom_r'). 64 | model_style: Visualization style of the model. One of the following: 65 | * ``model_style = 'surface'``, 66 | * ``model_style = 'wireframe'``, 67 | * ``model_style = 'points'``. 68 | model_size: If ``model_style = 'points'``, point size of any nodes in the dataset plotted. 69 | If ``model_style = 'wireframe'``, thickness of lines. 70 | model_name: Name to assign to the model. Defaults to the memory address.
71 | """ 72 | 73 | if model_style == "points": 74 | render_spheres, render_tubes, smooth_shading = True, False, True 75 | elif model_style == "wireframe": 76 | render_spheres, render_tubes, smooth_shading = False, True, False 77 | else: 78 | render_spheres, render_tubes, smooth_shading = False, False, True 79 | mesh_kwargs = dict( 80 | scalars=key if key in model.array_names else None, 81 | style=model_style, 82 | render_points_as_spheres=render_spheres, 83 | render_lines_as_tubes=render_tubes, 84 | point_size=model_size, 85 | line_width=model_size, 86 | ambient=ambient, 87 | opacity=opacity, 88 | smooth_shading=smooth_shading, 89 | show_scalar_bar=False, 90 | cmap=cmap, 91 | color=color, 92 | name=model_name, 93 | ) 94 | actor = plotter.add_mesh(model, **mesh_kwargs) 95 | return actor 96 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/pv_pipeline/pv_tdr.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import numpy as np 4 | import pyvista as pv 5 | from anndata import AnnData 6 | from pandas.core.frame import DataFrame 7 | from pyvista import DataSet, PolyData, UnstructuredGrid 8 | from scipy.spatial.distance import cdist 9 | 10 | try: 11 | from typing import Literal 12 | except ImportError: 13 | from typing_extensions import Literal 14 | 15 | 16 | #################################### 17 | # point cloud model reconstruction # 18 | #################################### 19 | 20 | 21 | def construct_pc( 22 | adata: AnnData, 23 | spatial_key: str = "spatial", 24 | ) -> PolyData: 25 | """ 26 | Construct a point cloud model based on 3D coordinate information. 27 | 28 | Args: 29 | adata: AnnData object. 30 | spatial_key: The key in ``.obsm`` that corresponds to the spatial coordinate of each bucket. 
31 | 32 | Returns: 33 | pc: A point cloud whose ``pc.point_data['obs_index']`` stores the obs_index of each coordinate in the original adata. 38 | """ 39 | 40 | # create an initial pc. 41 | adata = adata.copy() 42 | bucket_xyz = adata.obsm[spatial_key].astype(np.float64) 43 | if isinstance(bucket_xyz, DataFrame): 44 | bucket_xyz = bucket_xyz.values 45 | pc = pv.PolyData(bucket_xyz) 46 | pc.point_data["obs_index"] = np.array(adata.obs_names.tolist()) 47 | return pc 48 | 49 | 50 | ####################### 51 | # Mesh reconstruction # 52 | ####################### 53 | 54 | 55 | def merge_models( 56 | models: List[Union[PolyData, UnstructuredGrid, DataSet]], 57 | ) -> Union[PolyData, UnstructuredGrid]: 58 | """Merge all models in the `models` list. The format of all models must be the same.""" 59 | 60 | merged_model = models[0] 61 | for model in models[1:]: 62 | merged_model = merged_model.merge(model) 63 | 64 | return merged_model 65 | 66 | 67 | def rigid_transform( 68 | coords: np.ndarray, 69 | coords_refA: np.ndarray, 70 | coords_refB: np.ndarray, 71 | ) -> np.ndarray: 72 | """ 73 | Compute the optimal rigid transformation (rotation and translation) that maps ``coords_refA`` onto ``coords_refB`` and apply it to ``coords``. 74 | 75 | Args: 76 | coords: Coordinate matrix to be transformed. 77 | coords_refA: Referential coordinate matrix before transformation. 78 | coords_refB: Referential coordinate matrix after transformation.
79 | 80 | Returns: 81 | The coordinate matrix after transformation. 82 | """ 83 | # Check the spatial coordinates. 84 | 85 | coords, coords_refA, coords_refB = ( 86 | coords.copy(), 87 | coords_refA.copy(), 88 | coords_refB.copy(), 89 | ) 90 | assert ( 91 | coords.shape[1] == coords_refA.shape[1] == coords_refB.shape[1] 92 | ), "The dimensions of the input coordinates must be uniform, 2D or 3D." 93 | coords_dim = coords.shape[1] 94 | if coords_dim == 2: 95 | coords = np.c_[coords, np.zeros(shape=(coords.shape[0], 1))] 96 | coords_refA = np.c_[coords_refA, np.zeros(shape=(coords_refA.shape[0], 1))] 97 | coords_refB = np.c_[coords_refB, np.zeros(shape=(coords_refB.shape[0], 1))] 98 | 99 | # Compute the optimal transformation from the two sets of points (Kabsch algorithm via SVD). 100 | coords_refA = coords_refA.T 101 | coords_refB = coords_refB.T 102 | 103 | centroid_A = np.mean(coords_refA, axis=1).reshape(-1, 1) 104 | centroid_B = np.mean(coords_refB, axis=1).reshape(-1, 1) 105 | 106 | Am = coords_refA - centroid_A 107 | Bm = coords_refB - centroid_B 108 | H = Am @ np.transpose(Bm) 109 | 110 | U, S, Vt = np.linalg.svd(H) 111 | R = Vt.T @ U.T 112 | 113 | if np.linalg.det(R) < 0: 114 | Vt[2, :] *= -1 115 | R = Vt.T @ U.T 116 | 117 | t = -R @ centroid_A + centroid_B 118 | 119 | # Apply the transformation to the other points. 120 | new_coords = (R @ coords.T) + t 121 | new_coords = np.asarray(new_coords.T) 122 | return new_coords[:, :2] if coords_dim == 2 else new_coords 123 | 124 | 125 | def _scale_model_by_scale_factor( 126 | model: DataSet, 127 | scale_factor: Union[int, float, list, tuple] = 1, 128 | scale_center: Union[list, tuple] = None, 129 | ) -> DataSet: 130 | # Check the scaling factor. 131 | scale_factor = ( 132 | scale_factor if isinstance(scale_factor, (tuple, list)) else [scale_factor] * 3 133 | ) 134 | if len(scale_factor) != 3: 135 | raise ValueError( 136 | "`scale_factor` value is wrong." 
137 | "\nWhen `scale_factor` is a list or tuple, it can only contain three elements."
138 | ) 139 | 140 | # Check the scaling center. 141 | scale_center = model.center if scale_center is None else scale_center 142 | if len(scale_center) != 3: 143 | raise ValueError( 144 | "`scale_center` value is wrong." 145 | "\n`scale_center` can only contain three elements." 146 | ) 147 | 148 | # Scale the model based on the scale center. 149 | for i, (f, c) in enumerate(zip(scale_factor, scale_center)): 150 | model.points[:, i] = (model.points[:, i] - c) * f + c 151 | 152 | return model 153 | 154 | 155 | def scale_model( 156 | model: Union[PolyData, UnstructuredGrid], 157 | scale_factor: Union[float, int, list, tuple] = 1, 158 | scale_center: Union[list, tuple] = None, 159 | inplace: bool = False, 160 | ) -> Union[PolyData, UnstructuredGrid, None]: 161 | """ 162 | Scale the model around a center (by default, the center of the model). 163 | 164 | Args: 165 | model: A 3D reconstructed model. 166 | scale_factor: The scale by which the model is scaled. If `scale_factor` is a float, the model is scaled along the 167 | xyz axes by the same factor; if it is a list or tuple, the model is scaled along the xyz 168 | axes by different factors (exactly three values are required). 169 | scale_center: Scaling center. If `scale_center` is None, it defaults to the center of the model. 170 | inplace: If True, scale the point coordinates of ``model`` in place and return None. 171 | 172 | Returns: 173 | model_s: The scaled (triangulated) model, or None if ``inplace`` is True. 174 | """ 175 | 176 | model_s = model.copy() if not inplace else model 177 | model_s = _scale_model_by_scale_factor( 178 | model=model_s, scale_factor=scale_factor, scale_center=scale_center 179 | ) 180 | model_s = model_s.triangulate() 181 | return model_s if not inplace else None 182 | 183 | 184 | def marching_cube_mesh( 185 | pc: PolyData, 186 | levelset: Union[int, float] = 0, 187 | mc_scale_factor: Union[int, float] = 1.0, 188 | ): 189 | """ 190 | Compute a triangle mesh from a point cloud using the marching cubes algorithm.
191 | Algorithm Overview: 192 | The algorithm proceeds through the scalar field, taking eight neighbor locations at a time (thus forming an 193 | imaginary cube), then determining the polygon(s) needed to represent the part of the iso-surface that passes 194 | through this cube. The individual polygons are then fused into the desired surface. 195 | 196 | Args: 197 | pc: A point cloud model. 198 | levelset: The levelset of iso-surface. It is recommended to set levelset to 0 or 0.5. 199 | mc_scale_factor: Multiplier applied to the largest nearest-neighbor distance in the point cloud; the result sets the voxel size of the grid used by marching cubes. 200 | 201 | Returns: 202 | A triangulated surface mesh. 203 | """ 204 | try: 205 | import mcubes 206 | except ImportError: 207 | raise ImportError( 208 | "You need to install the package `mcubes`." 209 | "\nInstall mcubes via `pip install --upgrade PyMCubes`" 210 | ) 211 | 212 | pc = pc.copy() 213 | 214 | # Move the model so that the coordinate minimum is at (0, 0, 0). 215 | raw_points = np.asarray(pc.points) 216 | pc.points = new_points = raw_points - np.min(raw_points, axis=0) 217 | 218 | # Compute the voxel size from the largest nearest-neighbor distance (uses an N x N distance matrix). 219 | dist = cdist(XA=new_points, XB=new_points, metric="euclidean") 220 | row, col = np.diag_indices_from(dist) 221 | dist[row, col] = None 222 | max_dist = np.nanmin(dist, axis=1).max() 223 | mc_sf = max_dist * mc_scale_factor 224 | 225 | scale_pc = scale_model(model=pc, scale_factor=1 / mc_sf) 226 | scale_pc_points = scale_pc.points = np.ceil(np.asarray(scale_pc.points)).astype( 227 | np.int64 228 | ) 229 | 230 | # Build a binary voxel grid from the scaled point cloud. 231 | volume_array = np.zeros( 232 | shape=[ 233 | scale_pc_points[:, 0].max() + 3, 234 | scale_pc_points[:, 1].max() + 3, 235 | scale_pc_points[:, 2].max() + 3, 236 | ] 237 | ) 238 | volume_array[ 239 | scale_pc_points[:, 0], scale_pc_points[:, 1], scale_pc_points[:, 2] 240 | ] = 1 241 | 242 | # Extract the iso-surface with the marching cubes algorithm.
243 | # volume_array = mcubes.smooth(volume_array) 244 | vertices, triangles = mcubes.marching_cubes(volume_array, levelset) 245 | 246 | if len(vertices) == 0: 247 | raise ValueError( 248 | "The point cloud cannot generate a surface mesh with the `marching_cube` method." 249 | ) 250 | 251 | v = np.asarray(vertices).astype(np.float64) 252 | f = np.asarray(triangles).astype(np.int64) 253 | f = np.c_[np.full(len(f), 3), f] 254 | 255 | # Generate mesh model. 256 | mesh = pv.PolyData(v, f.ravel()).extract_surface().triangulate() 257 | mesh.clean(inplace=True) 258 | mesh = scale_model(model=mesh, scale_factor=mc_sf) 259 | 260 | # Map the mesh back to the original coordinates with a rigid transformation. 261 | scale_pc = scale_model(model=scale_pc, scale_factor=mc_sf) 262 | mesh.points = rigid_transform( 263 | coords=np.asarray(mesh.points), 264 | coords_refA=np.asarray(scale_pc.points), 265 | coords_refB=raw_points, 266 | ) 267 | return mesh 268 | 269 | 270 | def smooth_mesh(mesh: PolyData, n_iter: int = 100, **kwargs) -> PolyData: 271 | """ 272 | Adjust point coordinates using Laplacian smoothing. 273 | https://docs.pyvista.org/api/core/_autosummary/pyvista.PolyData.smooth.html#pyvista.PolyData.smooth 274 | 275 | Args: 276 | mesh: A mesh model. 277 | n_iter: Number of iterations for Laplacian smoothing. 278 | **kwargs: Other parameters passed to pyvista.PolyData.smooth. 279 | 280 | Returns: 281 | smoothed_mesh: A smoothed mesh model. 282 | """ 283 | 284 | smoothed_mesh = mesh.smooth(n_iter=n_iter, **kwargs) 285 | 286 | return smoothed_mesh 287 | 288 | 289 | def fix_mesh(mesh: PolyData) -> PolyData: 290 | """Repair holes and defects in the extracted mesh, including subtle holes along complex parts of the mesh.""" 291 | 292 | # Check pymeshfix package 293 | try: 294 | import pymeshfix as mf 295 | except ImportError: 296 | raise ImportError( 297 | "You need to install the package `pymeshfix`.
\nInstall pymeshfix via `pip install pymeshfix`" 298 | ) 299 | 300 | meshfix = mf.MeshFix(mesh) 301 | meshfix.repair(verbose=False) 302 | fixed_mesh = meshfix.mesh.triangulate().clean() 303 | 304 | if fixed_mesh.n_points == 0: 305 | raise ValueError( 306 | "The surface cannot be repaired. " 307 | "\nPlease change the method or parameters of surface reconstruction." 308 | ) 309 | 310 | return fixed_mesh 311 | 312 | 313 | def clean_mesh(mesh: PolyData) -> PolyData: 314 | """Split the mesh into connected bodies and drop every body that lies entirely inside a preceding body.""" 315 | 316 | sub_meshes = mesh.split_bodies() 317 | n_mesh = len(sub_meshes) 318 | 319 | if n_mesh == 1: 320 | return mesh 321 | else: 322 | inside_number = [] 323 | for i, main_mesh in enumerate(sub_meshes[:-1]): 324 | main_mesh = pv.PolyData(main_mesh.points, main_mesh.cells) 325 | for j, check_mesh in enumerate(sub_meshes[i + 1 :]): 326 | check_mesh = pv.PolyData(check_mesh.points, check_mesh.cells) 327 | inside = check_mesh.select_enclosed_points( 328 | main_mesh, check_surface=False 329 | ).threshold(0.5) 330 | inside = pv.PolyData(inside.points, inside.cells) 331 | if check_mesh == inside: 332 | inside_number.append(i + 1 + j) 333 | 334 | cm_number = list(set([i for i in range(n_mesh)]).difference(set(inside_number))) 335 | if len(cm_number) == 1: 336 | cmesh = sub_meshes[cm_number[0]] 337 | else: 338 | cmesh = merge_models([sub_meshes[i] for i in cm_number]) 339 | 340 | return pv.PolyData(cmesh.points, cmesh.cells) 341 | 342 | 343 | def uniform_mesh( 344 | mesh: PolyData, nsub: Optional[int] = 3, nclus: int = 20000 345 | ) -> PolyData: 346 | """ 347 | Generate a uniformly meshed surface using Voronoi clustering. 348 | 349 | Args: 350 | mesh: A mesh model. 351 | nsub: Number of subdivisions. Each subdivision splits every triangle into 4, so the number of resulting triangles is 352 | nface*4**nsub, where nface is the current number of faces. 353 | nclus: Number of Voronoi clusters. 354 | 355 | Returns: 356 | new_mesh: A uniform mesh model.
357 | """ 358 | # Check pyacvd package 359 | try: 360 | import pyacvd 361 | except ImportError: 362 | raise ImportError( 363 | "You need to install the package `pyacvd`. \nInstall pyacvd via `pip install pyacvd`" 364 | ) 365 | 366 | # If the mesh is not dense enough for uniform remeshing, subdivide it to increase the number of triangles. 367 | if not (nsub is None): 368 | mesh.subdivide(nsub=nsub, subfilter="butterfly", inplace=True) 369 | 370 | # Uniform remeshing. 371 | clustered = pyacvd.Clustering(mesh) 372 | clustered.cluster(nclus) 373 | 374 | new_mesh = clustered.create_mesh().triangulate().clean() 375 | return new_mesh 376 | 377 | 378 | def construct_surface( 379 | pc: PolyData, 380 | levelset: Union[int, float] = 0, 381 | mc_scale_factor: Union[int, float] = 1.0, 382 | nsub: Optional[int] = 3, 383 | nclus: int = 20000, 384 | smooth: Optional[int] = 1000, 385 | scale_factor: Union[float, int, list, tuple] = None, 386 | ) -> Union[PolyData, UnstructuredGrid, None]: 387 | """ 388 | Reconstruct a surface mesh from a 3D point cloud model. 389 | 390 | Args: 391 | pc: A point cloud model. 392 | levelset: The level set (iso-value) of the iso-surface. A value of 0 or 0.5 is recommended. 393 | mc_scale_factor: The scale factor of the model. The scaled model is used to construct the mesh model. 394 | nsub: Number of subdivisions. Each subdivision splits every triangle into 4 new triangles, so the number of resulting triangles is 395 | nface*4**nsub where nface is the current number of faces. 396 | nclus: Number of Voronoi clusters. 397 | smooth: Number of iterations for Laplacian smoothing. If None, no smoothing is applied. 398 | scale_factor: The scale by which the model is scaled. If ``scale_factor`` is a float, the model is scaled along the 399 | xyz axes by the same factor; if ``scale_factor`` is a list, each of the xyz axes is scaled by its own 400 | factor. If ``scale_factor`` is None, no scaling is applied.
401 | 402 | Returns: 403 | uniform_surf: A reconstructed surface mesh. Each remaining connected body of the marching-cubes surface is repaired 404 | with pymeshfix and uniformly remeshed; the merged result is then optionally smoothed (``smooth``) and 405 | scaled (``scale_factor``). 406 | """ 407 | # Reconstruct surface mesh. 408 | surf = marching_cube_mesh(pc=pc, levelset=levelset, mc_scale_factor=mc_scale_factor) 409 | 410 | # Remove sub-meshes that are enclosed within other sub-meshes. 411 | csurf = clean_mesh(mesh=surf) 412 | 413 | uniform_surfs = [] 414 | for sub_surf in csurf.split_bodies(): 415 | # Repair the surface mesh where it was extracted and subtle holes along complex parts of the mesh 416 | sub_fix_surf = fix_mesh(mesh=sub_surf.extract_surface()) 417 | 418 | # Get a uniformly meshed surface using Voronoi clustering. 419 | sub_uniform_surf = uniform_mesh(mesh=sub_fix_surf, nsub=nsub, nclus=nclus) 420 | uniform_surfs.append(sub_uniform_surf) 421 | uniform_surf = merge_models(models=uniform_surfs) 422 | uniform_surf = uniform_surf.extract_surface().triangulate().clean() 423 | 424 | # Adjust point coordinates using Laplacian smoothing. 425 | if not (smooth is None): 426 | uniform_surf = smooth_mesh(mesh=uniform_surf, n_iter=smooth) 427 | 428 | # Scale the surface mesh.
429 | uniform_surf = scale_model(model=uniform_surf, scale_factor=scale_factor) 430 | return uniform_surf 431 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/ui/__init__.py: -------------------------------------------------------------------------------- 1 | from .container import ui_container 2 | from .drawer import ui_drawer 3 | from .layout import ui_layout 4 | from .toolbar import ui_toolbar 5 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/ui/container.py: -------------------------------------------------------------------------------- 1 | try: 2 | from typing import Literal 3 | except ImportError: 4 | from typing_extensions import Literal 5 | 6 | from trame.widgets import html 7 | from trame.widgets import vtk as vtk_widgets 8 | from trame.widgets import vuetify 9 | 10 | VIEW_INTERACT = [ 11 | {"button": 1, "action": "Rotate"}, 12 | {"button": 2, "action": "Pan"}, 13 | {"button": 3, "action": "Zoom", "scrollEnabled": True}, 14 | {"button": 1, "action": "Pan", "alt": True}, 15 | {"button": 1, "action": "Zoom", "control": True}, 16 | {"button": 1, "action": "Pan", "shift": True}, 17 | {"button": 1, "action": "Roll", "alt": True, "shift": True}, 18 | ] 19 | 20 | VIEW_SELECT = [{"button": 1, "action": "Select"}] 21 | 22 | 23 | # ----------------------------------------------------------------------------- 24 | # GUI- standard Container 25 | # ----------------------------------------------------------------------------- 26 | 27 | 28 | def ui_container( 29 | server, 30 | layout, 31 | ): 32 | """ 33 | Generate standard VContainer for Spateo UI. 34 | 35 | Args: 36 | server: The trame server. 37 | layout: The layout object. 
38 | """ 39 | 40 | state, ctrl = server.state, server.controller 41 | with layout.content: 42 | with vuetify.VContainer( 43 | fluid=True, classes="pa-0 fill-height", style="position: relative;" 44 | ): 45 | with vuetify.VCard( 46 | style=("tooltipStyle", {"display": "none"}), elevation=2, outlined=True 47 | ): 48 | with vuetify.VCardText(): 49 | html.Pre("{{ tooltip }}") 50 | 51 | with vtk_widgets.VtkView( 52 | ref="render", 53 | background=(state.background_color,), 54 | picking_modes=("[pickingMode]",), 55 | interactor_settings=("interactorSettings", VIEW_INTERACT), 56 | click="pickData = $event", 57 | hover="pickData = $event", 58 | select="selectData = $event", 59 | ) as view: 60 | ctrl.view_reset_camera = view.reset_camera 61 | with vtk_widgets.VtkGeometryRepresentation( 62 | id="activeModel", 63 | v_if="activeModel", 64 | actor=("{ visibility: activeModelVisible }",), 65 | color_map_preset=("colorMap",), 66 | color_data_range=("scalarParameters[scalar].range",), 67 | mapper=( 68 | "{ colorByArrayName: scalar, scalarMode: scalarParameters[scalar].scalarMode," 69 | " interpolateScalarsBeforeMapping: true, scalarVisibility: scalar !== 'Default' }", 70 | ), 71 | property=( 72 | { 73 | "pointSize": state.pixel_ratio, 74 | "representation": 1, 75 | "opacity": 1, 76 | "ambient": 0.3, 77 | }, 78 | ), 79 | ): 80 | vtk_widgets.VtkMesh("activeModel", state=("activeModel",)) 81 | with vtk_widgets.VtkGeometryRepresentation( 82 | id="meshModel", 83 | v_if="meshModel", 84 | actor=("{ visibility: meshModelVisible }",), 85 | property=( 86 | { 87 | "representation": 1, 88 | "opacity": 0.6, 89 | "ambient": 0.1, 90 | }, 91 | ), 92 | ): 93 | vtk_widgets.VtkMesh("meshModel", state=("meshModel",)) 94 | 95 | # Custom model visibility 96 | if state.custom_func is True: 97 | with vtk_widgets.VtkGeometryRepresentation( 98 | id="custom_model", 99 | v_if="custom_model", 100 | actor=("{ visibility: custom_model_visible }",), 101 | property=( 102 | { 103 | "lineWidth": 
state.custom_model_size, 104 | "pointSize": state.custom_model_size, 105 | "representation": 1, 106 | "opacity": 1, 107 | "ambient": 0.1, 108 | }, 109 | ), 110 | ): 111 | vtk_widgets.VtkMesh("custom_model", state=("custom_model",)) 112 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/ui/drawer/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import ui_drawer 2 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/ui/drawer/alignment.py: -------------------------------------------------------------------------------- 1 | from trame.widgets import vuetify 2 | 3 | 4 | def align_card_content(): 5 | with vuetify.VRow(classes="pt-2", dense=True): 6 | with vuetify.VCol(cols="6"): 7 | vuetify.VCheckbox( 8 | v_model=("slices_alignment", False), 9 | label="Align slices", 10 | on_icon="mdi-layers-outline", 11 | off_icon="mdi-layers-off-outline", 12 | dense=True, 13 | hide_details=True, 14 | classes="pt-1", 15 | ) 16 | with vuetify.VCol(cols="6"): 17 | vuetify.VSelect( 18 | v_model=("slices_align_method", "Paste"), 19 | items=(["Paste", "Morpho"],), 20 | show_size=True, 21 | dense=True, 22 | outlined=True, 23 | hide_details=True, 24 | classes="pt-1", 25 | label="Method of Alignment", 26 | ) 27 | with vuetify.VRow(classes="pt-2", dense=True): 28 | with vuetify.VCol(cols="6"): 29 | vuetify.VTextField( 30 | v_model=("slices_key", "slices"), 31 | label="Slices Key", 32 | hide_details=True, 33 | dense=True, 34 | outlined=True, 35 | classes="pt-1", 36 | ) 37 | with vuetify.VCol(cols="6"): 38 | vuetify.VTextField( 39 | v_model=("slices_align_factor", 0.1), 40 | label="Align Factor", 41 | hide_details=True, 42 | dense=True, 43 | outlined=True, 44 | classes="pt-1", 45 | ) 46 | with vuetify.VRow(classes="pt-2", dense=True): 47 | with vuetify.VCol(cols="6"): 48 | vuetify.VTextField( 49 | 
v_model=("slices_align_max_iter", 200), 50 | label="Max Iterations", 51 | hide_details=True, 52 | dense=True, 53 | outlined=True, 54 | classes="pt-1", 55 | ) 56 | with vuetify.VCol(cols="6"): 57 | vuetify.VTextField( 58 | v_model=("slices_align_device", "CPU"), 59 | label="Device", 60 | show_size=True, 61 | dense=True, 62 | outlined=True, 63 | hide_details=True, 64 | classes="pt-1", 65 | ) 66 | 67 | 68 | def align_card_panel(): 69 | with vuetify.VToolbar( 70 | dense=True, outlined=True, classes="pa-0 ma-0", style="flex: none;" 71 | ): 72 | vuetify.VIcon("mdi-target") 73 | vuetify.VCardTitle( 74 | " Slices Alignment", 75 | classes="pa-0 ma-0", 76 | style="flex: none;", 77 | hide_details=True, 78 | dense=True, 79 | ) 80 | 81 | vuetify.VSpacer() 82 | with vuetify.VBtn( 83 | small=True, 84 | icon=True, 85 | click="show_align_card = !show_align_card", 86 | ): 87 | vuetify.VIcon("mdi-unfold-less-horizontal", v_if=("show_align_card",)) 88 | vuetify.VIcon("mdi-unfold-more-horizontal", v_if=("!show_align_card",)) 89 | 90 | # Main content 91 | with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"): 92 | with vuetify.VCardText(classes="py-2", v_if=("show_align_card",)): 93 | align_card_content() 94 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/ui/drawer/custom_card.py: -------------------------------------------------------------------------------- 1 | from trame.widgets import vuetify 2 | 3 | 4 | def custom_card_content(): 5 | with vuetify.VRow(classes="pt-2", dense=True): 6 | with vuetify.VCol(cols="6"): 7 | vuetify.VCheckbox( 8 | v_model=("custom_analysis", False), 9 | label="Custom analysis calculation", 10 | on_icon="mdi-billiards-rack", 11 | off_icon="mdi-dots-triangle", 12 | dense=True, 13 | hide_details=True, 14 | classes="pt-1", 15 | ) 16 | with vuetify.VCol(cols="6"): 17 | vuetify.VCheckbox( 18 | v_model=("custom_model_visible", False), 19 | label="Custom model visibility", 20 | 
on_icon="mdi-eye-outline", 21 | off_icon="mdi-eye-off-outline", 22 | dense=True, 23 | hide_details=True, 24 | classes="pt-1", 25 | ) 26 | 27 | with vuetify.VRow(classes="pt-2", dense=True): 28 | with vuetify.VCol(cols="6"): 29 | vuetify.VSelect( 30 | label="Custom Parameter 1", 31 | v_model=("custom_parameter1", "ElPiGraph"), 32 | items=(["ElPiGraph", "SimplePPT", "PrinCurve"],), 33 | hide_details=True, 34 | dense=True, 35 | outlined=True, 36 | classes="pt-1", 37 | ) 38 | 39 | with vuetify.VCol(cols="6"): 40 | vuetify.VTextField( 41 | label="Custom Parameter 2", 42 | v_model=("custom_parameter2", 50), 43 | hide_details=True, 44 | dense=True, 45 | outlined=True, 46 | classes="pt-1", 47 | ) 48 | 49 | with vuetify.VRow(classes="pt-2", dense=True): 50 | with vuetify.VCol(cols="12"): 51 | vuetify.VTextField( 52 | v_model=("custom_model_output", None), 53 | label="Custom model Output", 54 | show_size=True, 55 | hide_details=True, 56 | dense=True, 57 | outlined=True, 58 | classes="pt-1", 59 | ) 60 | 61 | 62 | def custom_card_panel(): 63 | with vuetify.VToolbar( 64 | dense=True, outlined=True, classes="pa-0 ma-0", style="flex: none;" 65 | ): 66 | vuetify.VIcon("mdi-billiards-rack") 67 | vuetify.VCardTitle( 68 | " Custom Card", 69 | classes="pa-0 ma-0", 70 | style="flex: none;", 71 | hide_details=True, 72 | dense=True, 73 | ) 74 | 75 | vuetify.VSpacer() 76 | with vuetify.VBtn( 77 | small=True, 78 | icon=True, 79 | click="show_custom_card = !show_custom_card", 80 | ): 81 | vuetify.VIcon("mdi-unfold-less-horizontal", v_if=("show_custom_card",)) 82 | vuetify.VIcon("mdi-unfold-more-horizontal", v_if=("!show_custom_card",)) 83 | 84 | # Main content 85 | with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"): 86 | with vuetify.VCardText(classes="py-2", v_if=("show_custom_card",)): 87 | custom_card_content() 88 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/ui/drawer/main.py: 
-------------------------------------------------------------------------------- 1 | def _get_spateo_cmap(): 2 | import matplotlib as mpl 3 | from matplotlib.colors import LinearSegmentedColormap 4 | 5 | if "spateo_cmap" not in mpl.colormaps(): 6 | colors = ["#4B0082", "#800080", "#F97306", "#FFA500", "#FFD700", "#FFFFCB"] 7 | nodes = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] 8 | 9 | mpl.colormaps.register( 10 | LinearSegmentedColormap.from_list("spateo_cmap", list(zip(nodes, colors))) 11 | ) 12 | return "spateo_cmap" 13 | 14 | 15 | def ui_drawer(server, layout): 16 | """ 17 | Generate standard Drawer for Spateo UI. 18 | 19 | Args: 20 | server: The trame server. 21 | layout: The layout object. 22 | """ 23 | 24 | _get_spateo_cmap() 25 | with layout.drawer as dr: 26 | # Active model 27 | from .model_point import pc_card_panel 28 | 29 | pc_card_panel() 30 | # Slices alignment 31 | from .alignment import align_card_panel 32 | 33 | align_card_panel() 34 | # Mesh reconstruction 35 | from .reconstruction import mesh_card_panel 36 | 37 | mesh_card_panel() 38 | # Custom 39 | if server.state.custom_func is True: 40 | from .custom_card import custom_card_panel 41 | 42 | custom_card_panel() 43 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/ui/drawer/model_point.py: -------------------------------------------------------------------------------- 1 | from trame.widgets import vuetify 2 | 3 | 4 | def pc_card_content(): 5 | with vuetify.VRow(classes="pt-2", dense=True): 6 | with vuetify.VCol(cols="6"): 7 | vuetify.VSelect( 8 | v_model=("scalar", "Default"), 9 | items=("Object.values(scalarParameters)",), 10 | show_size=True, 11 | dense=True, 12 | outlined=True, 13 | hide_details=True, 14 | classes="pt-1", 15 | label="Scalars", 16 | ) 17 | with vuetify.VCol(cols="6"): 18 | vuetify.VSelect( 19 | v_model=("colorMap", "erdc_rainbow_bright"), 20 | items=("trame.utils.vtk.vtkColorPresetItems('')",), 21 | show_size=True, 22 | # 
truncate_length=25, 23 | dense=True, 24 | outlined=True, 25 | hide_details=True, 26 | classes="pt-1", 27 | # style="max-width: 150px", 28 | label="Colormap", 29 | ) 30 | with vuetify.VRow(classes="pt-2", dense=True): 31 | with vuetify.VCol(cols="6"): 32 | vuetify.VSelect( 33 | v_model=("picking_group", None), 34 | items=("Object.keys(scalarParameters[scalar].raw_labels)",), 35 | show_size=True, 36 | dense=True, 37 | outlined=True, 38 | hide_details=True, 39 | classes="pt-2", 40 | label="Picking Group", 41 | ) 42 | with vuetify.VCol(cols="6"): 43 | vuetify.VCheckbox( 44 | v_model=("overwrite", False), 45 | label="Overwrite the Active Model", 46 | on_icon="mdi-plus-thick", 47 | off_icon="mdi-close-thick", 48 | dense=True, 49 | hide_details=True, 50 | classes="pt-1", 51 | ) 52 | 53 | with vuetify.VRow(classes="pt-2", dense=True): 54 | with vuetify.VCol(cols="12"): 55 | vuetify.VTextField( 56 | v_model=("activeModel_output", None), 57 | label="Active Model Output", 58 | hide_details=True, 59 | dense=True, 60 | outlined=True, 61 | classes="pt-1", 62 | ) 63 | with vuetify.VRow(classes="pt-2", dense=True): 64 | with vuetify.VCol(cols="12"): 65 | vuetify.VTextField( 66 | v_model=("anndata_output", None), 67 | label="Anndata Output", 68 | hide_details=True, 69 | dense=True, 70 | outlined=True, 71 | classes="pt-1", 72 | ) 73 | with vuetify.VRow(classes="pt-2", dense=True): 74 | with vuetify.VCol(cols="12"): 75 | vuetify.VCheckbox( 76 | v_model=("activeModelVisible", True), 77 | label="Visibility of Active Model", 78 | on_icon="mdi-eye-outline", 79 | off_icon="mdi-eye-off-outline", 80 | dense=True, 81 | hide_details=True, 82 | classes="pt-1", 83 | ) 84 | 85 | 86 | def pc_card_panel(): 87 | with vuetify.VToolbar( 88 | dense=True, outlined=True, classes="pa-0 ma-0", style="flex: none;" 89 | ): 90 | vuetify.VIcon("mdi-format-paint") 91 | vuetify.VCardTitle( 92 | " Active Model", 93 | classes="pa-0 ma-0", 94 | style="flex: none;", 95 | hide_details=True, 96 | dense=True, 97 | ) 
98 | 99 | vuetify.VSpacer() 100 | with vuetify.VBtn( 101 | small=True, 102 | icon=True, 103 | click="show_active_card = !show_active_card", 104 | ): 105 | vuetify.VIcon("mdi-unfold-less-horizontal", v_if=("show_active_card",)) 106 | vuetify.VIcon("mdi-unfold-more-horizontal", v_if=("!show_active_card",)) 107 | 108 | # Main content 109 | with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"): 110 | with vuetify.VCardText(classes="py-2", v_if=("show_active_card",)): 111 | pc_card_content() 112 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/ui/drawer/reconstruction.py: -------------------------------------------------------------------------------- 1 | from trame.widgets import vuetify 2 | 3 | 4 | def mesh_card_content(): 5 | with vuetify.VRow(classes="pt-2", dense=True): 6 | with vuetify.VCol(cols="6"): 7 | vuetify.VCheckbox( 8 | v_model=("reconstruct_mesh", False), 9 | label="Reconstruct Mesh Model", 10 | on_icon="mdi-billiards-rack", 11 | off_icon="mdi-dots-triangle", 12 | dense=True, 13 | hide_details=True, 14 | classes="pt-1", 15 | ) 16 | with vuetify.VCol(cols="6"): 17 | vuetify.VCheckbox( 18 | v_model=("clip_pc_with_mesh", False), 19 | label="Clip with Mesh Model", 20 | on_icon="mdi-box-cutter", 21 | off_icon="mdi-box-cutter-off", 22 | dense=True, 23 | hide_details=True, 24 | classes="pt-1", 25 | ) 26 | 27 | with vuetify.VRow(classes="pt-2", dense=True): 28 | with vuetify.VCol(cols="6"): 29 | vuetify.VTextField( 30 | v_model=("mc_factor", 1.0), 31 | label="MC Factor", 32 | hide_details=True, 33 | dense=True, 34 | outlined=True, 35 | classes="pt-1", 36 | ) 37 | with vuetify.VCol(cols="6"): 38 | vuetify.VTextField( 39 | v_model=("mesh_voronoi", 20000), 40 | label="Voronoi Clustering", 41 | hide_details=True, 42 | dense=True, 43 | outlined=True, 44 | classes="pt-1", 45 | ) 46 | 47 | with vuetify.VRow(classes="pt-2", dense=True): 48 | with vuetify.VCol(cols="6"): 49 | vuetify.VTextField( 50 | 
v_model=("mesh_smooth_factor", 1000), 51 | label="Smooth Factor", 52 | hide_details=True, 53 | dense=True, 54 | outlined=True, 55 | classes="pt-1", 56 | ) 57 | with vuetify.VCol(cols="6"): 58 | vuetify.VTextField( 59 | v_model=("mesh_scale_factor", 1.0), 60 | label="Scale Factor", 61 | hide_details=True, 62 | dense=True, 63 | outlined=True, 64 | classes="pt-1", 65 | ) 66 | 67 | with vuetify.VRow(classes="pt-2", dense=True): 68 | with vuetify.VCol(cols="12"): 69 | vuetify.VTextField( 70 | v_model=("mesh_output", None), 71 | label="Reconstructed Mesh Output", 72 | show_size=True, 73 | hide_details=True, 74 | dense=True, 75 | outlined=True, 76 | classes="pt-1", 77 | ) 78 | with vuetify.VRow(classes="pt-2", dense=True): 79 | with vuetify.VCol(cols="12"): 80 | vuetify.VCheckbox( 81 | v_model=("meshModelVisible", False), 82 | label="Visibility of Mesh Model", 83 | on_icon="mdi-eye-outline", 84 | off_icon="mdi-eye-off-outline", 85 | dense=True, 86 | hide_details=True, 87 | classes="pt-1", 88 | ) 89 | 90 | 91 | def mesh_card_panel(): 92 | with vuetify.VToolbar( 93 | dense=True, outlined=True, classes="pa-0 ma-0", style="flex: none;" 94 | ): 95 | vuetify.VIcon("mdi-billiards-rack") 96 | vuetify.VCardTitle( 97 | " Mesh Reconstruction", 98 | classes="pa-0 ma-0", 99 | style="flex: none;", 100 | hide_details=True, 101 | dense=True, 102 | ) 103 | 104 | vuetify.VSpacer() 105 | with vuetify.VBtn( 106 | small=True, 107 | icon=True, 108 | click="show_mesh_card = !show_mesh_card", 109 | ): 110 | vuetify.VIcon("mdi-unfold-less-horizontal", v_if=("show_mesh_card",)) 111 | vuetify.VIcon("mdi-unfold-more-horizontal", v_if=("!show_mesh_card",)) 112 | 113 | # Main content 114 | with vuetify.VCard(style="flex: none;", classes="pa-0 ma-0"): 115 | with vuetify.VCardText(classes="py-2", v_if=("show_mesh_card",)): 116 | mesh_card_content() 117 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/ui/layout.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | # ----------------------------------------------------------------------------- 4 | # GUI layout 5 | # ----------------------------------------------------------------------------- 6 | 7 | 8 | def ui_layout( 9 | server, template_name: str = "main", drawer_width: Optional[int] = None, **kwargs 10 | ): 11 | """ 12 | Define the user interface (UI) layout. 13 | Reference: https://trame.readthedocs.io/en/latest/trame.ui.vuetify.html#trame.ui.vuetify.SinglePageWithDrawerLayout 14 | 15 | Args: 16 | server: Server to bind the layout to. 17 | template_name: Name of the template. 18 | drawer_width: Drawer width in pixels. If None, 15% of the screen width (queried via ``pyautogui``) is used. 19 | 20 | Returns: 21 | The SinglePageWithDrawerLayout layout object. 22 | """ 23 | from trame.ui.vuetify import SinglePageWithDrawerLayout 24 | 25 | if drawer_width is None: 26 | import pyautogui 27 | 28 | screen_width, screen_height = pyautogui.size() 29 | drawer_width = int(screen_width * 0.15) 30 | 31 | return SinglePageWithDrawerLayout( 32 | server, template_name=template_name, width=drawer_width, **kwargs 33 | ) 34 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/ui/toolbar.py: -------------------------------------------------------------------------------- 1 | try: 2 | from typing import Literal 3 | except ImportError: 4 | from typing_extensions import Literal 5 | 6 | from typing import Optional 7 | 8 | from pyvista import BasePlotter 9 | from trame.widgets import html, vuetify 10 | 11 | from stviewer.assets import icon_manager 12 | from stviewer.Reconstructor.pv_pipeline import Viewer 13 | 14 | from .utils import button, checkbox 15 | 16 | # ----------------------------------------------------------------------------- 17 | # GUI- UI title 18 | # ----------------------------------------------------------------------------- 19 | 20 | 21 | def ui_title( 22 | layout,
title_name="SPATEO VIEWER", title_icon: Optional[str] = None, **kwargs 23 | ): 24 | """ 25 | Define the title name and logo of the UI. 26 | Reference: https://trame.readthedocs.io/en/latest/trame.ui.vuetify.html#trame.ui.vuetify.SinglePageWithDrawerLayout 27 | 28 | Args: 29 | layout: The layout object. 30 | title_name: Title name of the GUI. 31 | title_icon: Title icon of the GUI. 32 | **kwargs: Additional parameters that will be passed to ``html.Img`` function. 33 | 34 | Returns: 35 | None. 36 | """ 37 | 38 | # Update the toolbar's name 39 | layout.title.set_text(title_name) 40 | layout.title.style = ( 41 | "font-family:arial; font-size:25px; font-weight: 550; color: gray;" 42 | ) 43 | 44 | # Update the toolbar's icon 45 | if not (title_icon is None): 46 | with layout.icon as icon: 47 | icon.style = "margin-left: 10px;" # "width: 7vw; height: 7vh;" 48 | html.Img(src=title_icon, height=40, **kwargs) 49 | 50 | 51 | # ----------------------------------------------------------------------------- 52 | # GUI- standard ToolBar 53 | # ----------------------------------------------------------------------------- 54 | 55 | 56 | def toolbar_widgets(server, plotter: BasePlotter): 57 | """ 58 | Generate standard widgets for ToolBar. 59 | 60 | Args: 61 | server: The trame server. 62 | plotter: The PyVista plotter to connect with the UI. 
63 | """ 64 | viewer = Viewer(server=server, plotter=plotter) 65 | 66 | vuetify.VSpacer() 67 | # Upload file 68 | vuetify.VFileInput( 69 | v_model=(viewer.UPLOAD_ANNDATA, None), 70 | label="Select Sample", 71 | show_size=True, 72 | small_chips=True, 73 | truncate_length=25, 74 | dense=True, 75 | outlined=True, 76 | hide_details=True, 77 | classes="ml-8", 78 | style="max-width: 300px;", 79 | rounded=True, 80 | accept=".h5ad", 81 | __properties=["accept"], 82 | ) 83 | 84 | vuetify.VSpacer() 85 | # Change the selection mode 86 | with vuetify.VBtnToggle(v_model=(viewer.PICKING_MODE, "hover"), dense=True): 87 | with vuetify.VBtn(value=("item.value",), v_for="item, idx in modes"): 88 | vuetify.VIcon("{{item.icon}}") 89 | # Whether to reload the main model 90 | button( 91 | click=viewer.on_reload_main_model, 92 | icon="mdi-restore", 93 | tooltip="Reload main model", 94 | ) 95 | 96 | vuetify.VProgressLinear( 97 | indeterminate=True, absolute=True, bottom=True, active=("trame__busy",) 98 | ) 99 | 100 | 101 | def ui_toolbar( 102 | server, 103 | layout, 104 | plotter: BasePlotter, 105 | ui_name: str = "SPATEO VIEWER", 106 | ui_icon=icon_manager.spateo_logo, 107 | ): 108 | """ 109 | Generate standard ToolBar for Spateo UI. 110 | 111 | Args: 112 | server: The trame server. 113 | layout: The layout object. 114 | plotter: The PyVista plotter to connect with the UI. 115 | ui_name: Title name of the GUI. 116 | ui_icon: Title icon of the GUI. 
117 | """ 118 | 119 | # ----------------------------------------------------------------------------- 120 | # Title 121 | # ----------------------------------------------------------------------------- 122 | ui_title(layout=layout, title_name=ui_name, title_icon=ui_icon) 123 | 124 | # ----------------------------------------------------------------------------- 125 | # ToolBar 126 | # ----------------------------------------------------------------------------- 127 | with layout.toolbar as tb: 128 | tb.height = 55 129 | tb.dense = True 130 | tb.clipped_right = True 131 | toolbar_widgets(server=server, plotter=plotter) 132 | -------------------------------------------------------------------------------- /stviewer/Reconstructor/ui/utils.py: -------------------------------------------------------------------------------- 1 | from trame.widgets import html, vuetify 2 | 3 | # ----------------------------------------------------------------------------- 4 | # vuetify components 5 | # ----------------------------------------------------------------------------- 6 | 7 | 8 | def button(click, icon, tooltip): 9 | """Create a vuetify button.""" 10 | with vuetify.VTooltip(bottom=True): 11 | with vuetify.Template(v_slot_activator="{ on, attrs }"): 12 | with vuetify.VBtn(icon=True, v_bind="attrs", v_on="on", click=click): 13 | vuetify.VIcon(icon) 14 | html.Span(tooltip) 15 | 16 | 17 | def checkbox(model, icons, tooltip, **kwargs): 18 | """Create a vuetify checkbox.""" 19 | with vuetify.VTooltip(bottom=True): 20 | with vuetify.Template(v_slot_activator="{ on, attrs }"): 21 | with html.Div(v_on="on", v_bind="attrs"): 22 | vuetify.VCheckbox( 23 | v_model=model, 24 | on_icon=icons[0], 25 | off_icon=icons[1], 26 | dense=True, 27 | hide_details=True, 28 | classes="my-0 py-0 ml-1", 29 | **kwargs 30 | ) 31 | html.Span(tooltip) 32 | 33 | 34 | def switch(model, tooltip, **kwargs): 35 | """Create a vuetify switch.""" 36 | with vuetify.VTooltip(bottom=True): 37 | with 
vuetify.Template(v_slot_activator="{ on, attrs }"): 38 | vuetify.VSwitch(v_model=model, hide_details=True, dense=True, **kwargs) 39 | html.Span(tooltip) 40 | -------------------------------------------------------------------------------- /stviewer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/__init__.py -------------------------------------------------------------------------------- /stviewer/assets/__init__.py: -------------------------------------------------------------------------------- 1 | from .anndata_preprocess import anndata_preprocess 2 | from .dataset_acquisition import abstract_anndata, abstract_models, sample_dataset 3 | from .dataset_manager import local_dataset_manager 4 | from .image_manager import icon_manager 5 | -------------------------------------------------------------------------------- /stviewer/assets/anndata_preprocess.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import anndata as ad 4 | from scipy.sparse import csr_matrix, issparse 5 | 6 | 7 | def anndata_preprocess( 8 | path: str, 9 | output_path: str, 10 | X_counts: str = "X_counts", 11 | X_log1p: Optional[str] = "X_log1p", 12 | spatial_key: str = "3d_align_spatial", 13 | ): 14 | adata = ad.read_h5ad(filename=path) 15 | 16 | # matrices 17 | X_counts = ( 18 | adata.layers[X_counts].copy() 19 | if issparse(adata.layers[X_counts]) 20 | else csr_matrix(adata.layers[X_counts]) 21 | ) 22 | if not (X_log1p is None): 23 | X_log1p = ( 24 | adata.layers[X_log1p].copy() 25 | if issparse(adata.layers[X_log1p]) 26 | else csr_matrix(adata.layers[X_log1p]) 27 | ) 28 | else: 29 | import dynamo as dyn 30 | 31 | adata.X = X_counts.copy() 32 | dyn.pp.normalize_cell_expr_by_size_factors( 33 | adata=adata, layers="X", skip_log=False 34 | ) 35 | X_log1p = 
csr_matrix(adata.X.copy()) 36 | 37 | # spatial coordinates 38 | spatial_coords = adata.obsm[spatial_key] 39 | 40 | # preprocess 41 | del adata.uns, adata.layers, adata.obsm, adata.obsp, adata.varm 42 | adata.X = X_counts 43 | adata.layers["X_log1p"] = X_log1p 44 | adata.obsm["spatial"] = spatial_coords 45 | 46 | adata.write_h5ad(output_path, compression="gzip") 47 | return adata 48 | 49 | 50 | """ 51 | import os 52 | os.chdir(f"spateo-viewer/stviewer/assets/dataset") 53 | adata = anndata_preprocess( 54 | path=r"drosophila_E8_9h/h5ad/E8_9h_cellbin_v3.h5ad", 55 | output_path=r"drosophila_E8_9h/h5ad/E8_9h_cellbin_v3_new.h5ad" 56 | ) 57 | print(adata) 58 | """ 59 | -------------------------------------------------------------------------------- /stviewer/assets/dataset/drosophila_S11/h5ad/S11_cellbin_demo.h5ad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/h5ad/S11_cellbin_demo.h5ad -------------------------------------------------------------------------------- /stviewer/assets/dataset/drosophila_S11/mesh_models/0_Embryo_S11_aligned_mesh_model.vtk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/0_Embryo_S11_aligned_mesh_model.vtk -------------------------------------------------------------------------------- /stviewer/assets/dataset/drosophila_S11/mesh_models/1_CNS_S11_aligned_mesh_model.vtk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/1_CNS_S11_aligned_mesh_model.vtk 
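`anndata_preprocess` in `stviewer/assets/anndata_preprocess.py` stores both `X_counts` and `X_log1p` as CSR matrices: it copies a layer that is already sparse and wraps a dense layer in `csr_matrix`. The sketch below factors that conversion into a standalone helper. The name `to_csr` is hypothetical and not part of spateo-viewer. Unlike the original `.copy()`, it also converts other sparse formats (such as CSC) to CSR.

```python
import numpy as np
from scipy.sparse import csr_matrix, issparse


def to_csr(matrix):
    """Return a CSR copy of `matrix` (hypothetical helper).

    Sparse inputs are converted/copied with `.tocsr(copy=True)`, so the
    caller never shares buffers with the source layer; dense inputs are
    wrapped in a new `csr_matrix`.
    """
    if issparse(matrix):
        # Also normalizes other sparse formats (e.g. CSC, COO) to CSR.
        return matrix.tocsr(copy=True)
    return csr_matrix(matrix)


if __name__ == "__main__":
    dense = np.array([[0, 1], [2, 0]])
    print(to_csr(dense).format)  # csr
```

Inside `anndata_preprocess`, the two conditional expressions could then be written as `to_csr(adata.layers[X_counts])` and `to_csr(adata.layers[X_log1p])`.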
-------------------------------------------------------------------------------- /stviewer/assets/dataset/drosophila_S11/mesh_models/2_Midgut_S11_aligned_mesh_model.vtk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/2_Midgut_S11_aligned_mesh_model.vtk -------------------------------------------------------------------------------- /stviewer/assets/dataset/drosophila_S11/mesh_models/3_Hindgut_S11_aligned_mesh_model.vtk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/3_Hindgut_S11_aligned_mesh_model.vtk -------------------------------------------------------------------------------- /stviewer/assets/dataset/drosophila_S11/mesh_models/4_Muscle_S11_aligned_mesh_model.vtk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/4_Muscle_S11_aligned_mesh_model.vtk -------------------------------------------------------------------------------- /stviewer/assets/dataset/drosophila_S11/mesh_models/5_SalivaryGland_S11_aligned_mesh_model.vtk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/5_SalivaryGland_S11_aligned_mesh_model.vtk -------------------------------------------------------------------------------- /stviewer/assets/dataset/drosophila_S11/mesh_models/6_Amnioserosa_S11_aligned_mesh_model.vtk: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/mesh_models/6_Amnioserosa_S11_aligned_mesh_model.vtk

--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/0_Embryo_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/0_Embryo_S11_aligned_pc_model.vtk

--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/1_CNS_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/1_CNS_S11_aligned_pc_model.vtk

--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/2_Midgut_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/2_Midgut_S11_aligned_pc_model.vtk

--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/3_Hindgut_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/3_Hindgut_S11_aligned_pc_model.vtk

--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/4_Muscle_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/4_Muscle_S11_aligned_pc_model.vtk

--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/5_SalivaryGland_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/5_SalivaryGland_S11_aligned_pc_model.vtk

--------------------------------------------------------------------------------
/stviewer/assets/dataset/drosophila_S11/pc_models/6_Amnioserosa_S11_aligned_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/drosophila_S11/pc_models/6_Amnioserosa_S11_aligned_pc_model.vtk

--------------------------------------------------------------------------------
/stviewer/assets/dataset/mouse_E95/h5ad/mouse_E95_demo.h5ad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/mouse_E95/h5ad/mouse_E95_demo.h5ad

--------------------------------------------------------------------------------
/stviewer/assets/dataset/mouse_E95/matrices/X_sparse_matrix.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/mouse_E95/matrices/X_sparse_matrix.npz

--------------------------------------------------------------------------------
/stviewer/assets/dataset/mouse_E95/mesh_models/0_Embryo_mouse_E95_mesh_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/mouse_E95/mesh_models/0_Embryo_mouse_E95_mesh_model.vtk

--------------------------------------------------------------------------------
/stviewer/assets/dataset/mouse_E95/pc_models/0_Embryo_mouse_E95_pc_model.vtk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/dataset/mouse_E95/pc_models/0_Embryo_mouse_E95_pc_model.vtk

--------------------------------------------------------------------------------
/stviewer/assets/dataset_acquisition.py:
--------------------------------------------------------------------------------
import gc
import os
from pathlib import Path
from typing import Optional, Tuple

import anndata as ad
import matplotlib as mpl
import numpy as np
import pyvista as pv
from anndata import AnnData
from matplotlib.colors import LinearSegmentedColormap
from pandas import DataFrame
from scipy import sparse

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal


def extract_anndata_structure(adata: AnnData):
    # Anndata basic info
    obs_str, var_str, uns_str, obsm_str, layers_str = (
        f" obs:",
        f" var:",
        f" uns:",
        f" obsm:",
        f" layers:",
    )

    if len(list(adata.obs.keys())) != 0:
        for key in list(adata.obs.keys()):
            obs_str = obs_str + f" '{key}',"
    if len(list(adata.var.keys())) != 0:
        for key in list(adata.var.keys()):
            var_str = var_str + f" '{key}',"
    if len(list(adata.uns.keys())) != 0:
        for key in list(adata.uns.keys()):
            uns_str = uns_str + f" '{key}',"
    if len(list(adata.obsm.keys())) != 0:
        for key in list(adata.obsm.keys()):
            obsm_str = obsm_str + f" '{key}',"
    if len(list(adata.layers.keys())) != 0:
        for key in list(adata.layers.keys()):
            layers_str = layers_str + f" '{key}',"

    anndata_structure = (
        f"AnnData object with n_obs × n_vars = {adata.shape[0]} × {adata.shape[1]}\n"
    )
    for ad_str in [obs_str, var_str, uns_str, obsm_str, layers_str]:
        if ad_str.endswith(","):
            anndata_structure = anndata_structure + f"{ad_str[:-1]}\n"
    return anndata_structure


def abstract_anndata(path: str, X_layer: str = "X") -> Tuple[AnnData, str]:
    adata = ad.read_h5ad(filename=path)
    anndata_structure = extract_anndata_structure(adata=adata)
    if X_layer != "X":
        assert (
            X_layer in adata.layers.keys()
        ), f"``{X_layer}`` does not exist in `adata.layers`."
        adata.X = adata.layers[X_layer]

    return adata, anndata_structure


def abstract_models(path: str, model_ids: Optional[list] = None):
    # Read only model files so that stray files (e.g. `.DS_Store`) are ignored and the
    # order matches the IDs that `sample_dataset` builds from the same filtered list.
    model_files = [
        f for f in os.listdir(path=path) if f.endswith(".vtk") or f.endswith(".vtm")
    ]
    model_files.sort()
    assert len(model_files) != 0, "There is no model file under this path."

    models = [pv.read(filename=os.path.join(path, f)) for f in model_files]
    if model_ids is None:  # Cannot contain `-` and ` `.
        model_ids = [f"Model{i}" for i in range(len(models))]
    assert len(model_ids) == len(
        models
    ), "The number of model_ids does not equal the number of models."

    return models, model_ids


def sample_dataset(
    path: str,
    X_layer: str = "X",
    pc_model_ids: Optional[list] = None,
    mesh_model_ids: Optional[list] = None,
):
    # Generate anndata object
    if os.path.isfile(path) and path.endswith(".h5ad"):
        anndata_path = path
        matrices_npz_path = f"./temp/matrices_{path.split('/')[-1]}"

        adata, anndata_structure = abstract_anndata(path=path, X_layer=X_layer)
    elif os.path.isdir(path):
        anndata_dir = os.path.join(path, "h5ad")
        anndata_list = [
            f for f in os.listdir(path=anndata_dir) if str(f).endswith(".h5ad")
        ]
        anndata_path = os.path.join(anndata_dir, anndata_list[0])
        matrices_npz_path = os.path.join(path, "matrices")

        adata, anndata_structure = abstract_anndata(
            path=os.path.join(anndata_dir, anndata_list[0]),
            X_layer=X_layer,
        )
    else:
        raise ValueError(f"`{path}` is not available for spateo-viewer.")

    ## Generate info-dict of anndata object
    anndata_info = {
        "anndata_path": anndata_path,
        "anndata_structure": anndata_structure,
        "anndata_obs_keys": list(adata.obs_keys()),
        "anndata_obs_index": list(adata.obs.index.to_list()),
        "anndata_var_index": list(adata.var.index.to_list()),
        "anndata_obsm_keys": [
            key for key in ["spatial", "X_umap"] if key in adata.obsm.keys()
        ],
        "anndata_matrices": ["X"] + [i for i in adata.layers.keys()],
        "matrices_npz_path": matrices_npz_path,
    }

    # Check matrices
    if not os.path.exists(anndata_info["matrices_npz_path"]):
        Path(anndata_info["matrices_npz_path"]).mkdir(parents=True, exist_ok=True)
        for matrix_id in anndata_info["anndata_matrices"]:
            matrix = adata.X if matrix_id == "X" else adata.layers[matrix_id]
            # `save_npz` only accepts sparse matrices, so convert dense arrays first.
            if not sparse.issparse(matrix):
                matrix = sparse.csr_matrix(matrix)
            sparse.save_npz(
                f"{anndata_info['matrices_npz_path']}/{matrix_id}_sparse_matrix.npz",
                matrix,
            )
    else:
        pass

    # Generate point cloud models
    if os.path.isdir(path) and os.path.exists(os.path.join(path, "pc_models")):
        pc_model_files = [
            f
            for f in os.listdir(path=os.path.join(path, "pc_models"))
            if str(f).endswith(".vtk") or str(f).endswith(".vtm")
        ]
        pc_model_files.sort()

        if pc_model_ids is None:
            pc_model_ids = [f"PC_{str(i).split('_')[1]}" for i in pc_model_files]
        _pc_models, pc_model_ids = abstract_models(
            path=os.path.join(path, "pc_models"), model_ids=pc_model_ids
        )
    else:
        bucket_xyz = adata.obsm["spatial"].astype(np.float64)
        if isinstance(bucket_xyz, DataFrame):
            bucket_xyz = bucket_xyz.values
        pc_model = pv.PolyData(bucket_xyz)
        pc_model.point_data["obs_index"] = np.array(adata.obs_names.tolist())
        _pc_models, pc_model_ids = [pc_model], ["PC_Model"]

    pc_models = []
    for pc_model in _pc_models:
        _obs_index = pc_model.point_data["obs_index"]
        for obsm_key in anndata_info["anndata_obsm_keys"]:
            coords = np.asarray(adata[_obs_index, :].obsm[obsm_key])
            pc_model.point_data[f"{obsm_key}_X"] = coords[:, 0]
            pc_model.point_data[f"{obsm_key}_Y"] = coords[:, 1]
            # 2D coordinates get a zero Z array with one value per point.
            pc_model.point_data[f"{obsm_key}_Z"] = (
                np.zeros(coords.shape[0]) if coords.shape[1] == 2 else coords[:, 2]
            )

        for obs_key in adata.obs_keys():
            array = np.asarray(adata[_obs_index, :].obs[obs_key])
            array = (
                np.asarray(array, dtype=float)
                if np.issubdtype(array.dtype, np.number)
                else np.asarray(array, dtype=str)
            )
            pc_model.point_data[obs_key] = array
        pc_models.append(pc_model)

    # Generate mesh models
    if os.path.isdir(path) and os.path.exists(os.path.join(path, "mesh_models")):
        mesh_model_files = [
            f
            for f in os.listdir(path=os.path.join(path, "mesh_models"))
            if str(f).endswith(".vtk") or str(f).endswith(".vtm")
        ]
        mesh_model_files.sort()

        if mesh_model_ids is None:
            mesh_model_ids = [f"Mesh_{str(i).split('_')[1]}" for i in mesh_model_files]
        mesh_models, mesh_model_ids = abstract_models(
            path=os.path.join(path, "mesh_models"), model_ids=mesh_model_ids
        )
    else:
        mesh_models, mesh_model_ids = None, None

    # Custom colors
    custom_colors = []
    for key in adata.uns.keys():
        if str(key).endswith("colors"):
            colors = adata.uns[key]
            if isinstance(colors, dict):
                colors = np.asarray([i for i in colors.values()])
            if isinstance(colors, (np.ndarray, list)):
                custom_colors.append(key)
                nodes = np.linspace(0, 1, num=len(colors))
                if key not in mpl.colormaps():
                    mpl.colormaps.register(
                        LinearSegmentedColormap.from_list(key, list(zip(nodes, colors)))
                    )

    # Delete anndata object
    del adata
    gc.collect()

    return (
        anndata_info,
        pc_models,
        pc_model_ids,
        mesh_models,
        mesh_model_ids,
        custom_colors,
    )

--------------------------------------------------------------------------------
/stviewer/assets/dataset_manager.py:
--------------------------------------------------------------------------------
import base64
import os
from pathlib import Path

from trame.app.mimetypes import to_mime


def to_base64(file_path):
    """
    Return the base64 content of the file path.

    Args:
        file_path: Path to the file to read.

    Return:
        File content encoded in base64

    """
    with open(file_path, "rb") as bin_file:
        return base64.b64encode(bin_file.read()).decode("ascii")


def to_url(file_path: str):
    """
    Return the base64 encoded URL of the file path.

    Args:
        file_path: Path to the file to read.

    Return:
        Inlined base64-encoded URL (data:{mime};base64,{content})
    """
    encoded = to_base64(file_path)
    mime = to_mime(file_path)
    return f"data:{mime};base64,{encoded}"


class LocalFileManager:
    """LocalFileManager provides convenient methods for handling local files"""

    def __init__(self, base_path):
        """
        Provide the base path that relative paths are resolved against.
        base_path: A file or directory path
        """
        _base = Path(base_path)

        # Ensure directory
        if _base.is_file():
            _base = _base.parent

        self._root = Path(str(_base.resolve().absolute()))
        self._assests = {}

    def __getitem__(self, name):
        return self._assests.get(name)

    def __getattr__(self, name):
        return self._assests.get(name)

    def _to_path(self, file_path):
        _input_file = Path(file_path)
        if _input_file.is_absolute():
            return str(_input_file.resolve().absolute())

        return str(self._root.joinpath(file_path).resolve().absolute())

    def file_url(self, key: str, file_path: str):
        """
        Store the file path under the provided key name (the base64 URL encoding below is disabled).

        Args:
            key: The name for that content which can then be accessed by the [] or . notation
            file_path: A file path
        """
        # if file_path is None:
        #     file_path, key = key, file_path

        # data = to_url(self._to_path(file_path))

        if key is not None:
            self._assests[key] = file_path

        return file_path

    def dir_url(self, key: str, dir_path: str):
        """
        Store the resolved absolute directory path under the provided key name.

        Args:
            key: The name for that content which can then be accessed by the [] or . notation
            dir_path: A directory path
        """
        if dir_path is None:
            dir_path, key = key, dir_path

        data = self._to_path(dir_path)

        if key is not None:
            self._assests[key] = data

        return data

    @property
    def assets(self):
        """Return the full set of assets as a dict"""
        return self._assests

    def get_assets(self, *keys):
        """Return a dict filtered to the provided set of keys"""
        if len(keys) == 0:
            return self.assets

        _assets = {}
        for key in keys:
            _assets[key] = self._assests.get(key)

        return _assets


local_dataset_manager = LocalFileManager(__file__)
if os.path.exists(r"./stviewer/assets/dataset/mouse_E95"):
    local_dataset_manager.dir_url("mouse_E95", r"./dataset/mouse_E95")
if os.path.exists(r"./stviewer/assets/dataset/mouse_E115"):
    local_dataset_manager.dir_url("mouse_E115", r"./dataset/mouse_E115")
if os.path.exists(r"./stviewer/assets/dataset/drosophila_S11"):
    local_dataset_manager.dir_url("drosophila_S11", r"./dataset/drosophila_S11")

if os.path.exists(r"./stviewer/assets/dataset/drosophila_S11/h5ad/S11_cellbin.h5ad"):
    local_dataset_manager.file_url(
        "drosophila_S11_anndata",
        r"./stviewer/assets/dataset/drosophila_S11/h5ad/S11_cellbin.h5ad",
    )

"""
# -----------------------------------------------------------------------------
# Data file information
# -----------------------------------------------------------------------------
from trame.assets.remote import HttpFile
dataset_file = HttpFile(
    "./data/disk_out_ref.vtu",
    "https://github.com/Kitware/trame/raw/master/examples/data/disk_out_ref.vtu",
    __file__,
)
"""

--------------------------------------------------------------------------------
/stviewer/assets/image/interactive_viewer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/image/interactive_viewer.png

--------------------------------------------------------------------------------
/stviewer/assets/image/spateo_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/image/spateo_logo.png

--------------------------------------------------------------------------------
/stviewer/assets/image/spateoviewer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/image/spateoviewer.png

--------------------------------------------------------------------------------
/stviewer/assets/image/static_viewer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/image/static_viewer.png

--------------------------------------------------------------------------------
/stviewer/assets/image/upload_file.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/image/upload_file.png

--------------------------------------------------------------------------------
/stviewer/assets/image/upload_folder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/stviewer/assets/image/upload_folder.png

--------------------------------------------------------------------------------
/stviewer/assets/image_manager.py:
--------------------------------------------------------------------------------
from trame.assets.local import LocalFileManager

icon_manager = LocalFileManager(__file__)
icon_manager.url("spateo_logo", "./image/spateo_logo.png")

--------------------------------------------------------------------------------
/stviewer/explorer_app.py:
--------------------------------------------------------------------------------
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from tkinter import Tk, filedialog

import matplotlib.pyplot as plt
from trame.widgets import trame as trame_widgets

from .assets import icon_manager, local_dataset_manager
from .Explorer import (
    create_plotter,
    init_actors,
    init_adata_parameters,
    init_card_parameters,
    init_custom_parameters,
    init_interpolation_parameters,
    init_mesh_parameters,
    init_morphogenesis_parameters,
    init_output_parameters,
    init_pc_parameters,
    ui_container,
    ui_drawer,
    ui_layout,
    ui_toolbar,
)
from .server import get_trame_server

# export WSLINK_MAX_MSG_SIZE=1000000000 # 1GB

# Get a Server to work with
static_server = get_trame_server(name="spateo_explorer")
state, ctrl = static_server.state, static_server.controller
state.trame__title = "SPATEO VIEWER"
state.trame__favicon = icon_manager.spateo_logo
state.setdefault("active_ui", None)

# Generate a new plotter
plotter = create_plotter()
# Init model
(
    anndata_info,
    actors,
    actor_names,
    actor_tree,
    custom_colors,
) = init_actors(
    plotter=plotter,
    path=local_dataset_manager.mouse_E95,
)

# Init parameters
state.update(init_card_parameters)
state.update(init_adata_parameters)
state.update(init_pc_parameters)
state.update(init_mesh_parameters)
state.update(init_morphogenesis_parameters)
state.update(init_interpolation_parameters)
state.update(init_output_parameters)
state.update(
    {
        "init_dataset": True,
        "anndata_info": anndata_info,
        "pc_point_size_value": 4,
        "pc_obs_value": "mapped_celltype",
        "available_obs": ["None"] + anndata_info["anndata_obs_keys"],
        "available_genes": ["None"] + anndata_info["anndata_var_index"],
        "pc_colormaps_list": ["spateo_cmap"] + custom_colors + plt.colormaps(),
        # setting
        "actor_ids": actor_names,
        "pipeline": actor_tree,
        "active_id": 1,
        "active_ui": actor_names[0],
        "active_model_type": str(actor_names[0]).split("_")[0],
        "vis_ids": [
            i for i, actor in enumerate(plotter.actors.values()) if actor.visibility
        ],
    }
)

# Custom init parameters
if init_custom_parameters["custom_func"] is True:
    state.update(init_custom_parameters)
else:
    state.update({"custom_func": False})


# Upload directory
def open_directory():
    dirpath = filedialog.askdirectory(title="Select Directory")
    if not dirpath:
        return
    state.selected_dir = dirpath
    ctrl.view_update()


root = Tk()
root.withdraw()
root.wm_attributes("-topmost", 1)
state.selected_dir = "None"
ctrl.open_directory = open_directory


# GUI
ui_standard_layout = ui_layout(server=static_server, template_name="main")
with ui_standard_layout as layout:
    # Let the server know the browser pixel ratio and the default theme
    trame_widgets.ClientTriggers(
        mounted="pixel_ratio = window.devicePixelRatio, $vuetify.theme.dark = true"
    )

    # -----------------------------------------------------------------------------
    # ToolBar
    # -----------------------------------------------------------------------------
    ui_toolbar(
        server=static_server,
        layout=layout,
        plotter=plotter,
        mode="trame",
        ui_name="SPATEO VIEWER (EXPLORER)",
    )

    # -----------------------------------------------------------------------------
    # Drawer
    # -----------------------------------------------------------------------------
    ui_drawer(server=static_server, layout=layout, plotter=plotter, mode="trame")

    # -----------------------------------------------------------------------------
    # Main Content
    # -----------------------------------------------------------------------------
    ui_container(server=static_server, layout=layout, plotter=plotter, mode="trame")

    # -----------------------------------------------------------------------------
    # Footer
    # -----------------------------------------------------------------------------
    layout.footer.hide()
    # layout.flush_content()

--------------------------------------------------------------------------------
/stviewer/reconstructor_app.py:
--------------------------------------------------------------------------------
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from trame.widgets import trame as trame_widgets
from vtkmodules.web.utils import mesh as vtk_mesh

from .assets import icon_manager, local_dataset_manager
from .Reconstructor import (
    create_plotter,
    init_active_parameters,
    init_align_parameters,
    init_custom_parameters,
    init_mesh_parameters,
    init_models,
    init_picking_parameters,
    init_setting_parameters,
    ui_container,
    ui_drawer,
    ui_layout,
    ui_toolbar,
)
from .server import get_trame_server

# export WSLINK_MAX_MSG_SIZE=1000000000 # 1GB

# Get a Server to work with
interactive_server = get_trame_server(name="spateo_reconstructor")
state, ctrl = interactive_server.state, interactive_server.controller
state.trame__title = "SPATEO VIEWER"
state.trame__favicon = icon_manager.spateo_logo
state.setdefault("active_ui", None)

# Generate anndata object
plotter = create_plotter()
init_anndata_path = local_dataset_manager.drosophila_S11_anndata
main_model, active_model, init_scalar, pdd, cdd = init_models(
    plotter=plotter, anndata_path=init_anndata_path
)

# Init parameters
state.update(init_active_parameters)
state.update(init_picking_parameters)
state.update(init_align_parameters)
state.update(init_mesh_parameters)
state.update(init_setting_parameters)
state.update(
    {
        "init_anndata": init_anndata_path,
        "upload_anndata": None,
        # main model
        "mainModel": vtk_mesh(
            main_model,
            point_arrays=[key for key in pdd.keys()],
            cell_arrays=[key for key in cdd.keys()],
        ),
        "activeModel": vtk_mesh(
            active_model,
            point_arrays=[key for key in pdd.keys()],
            cell_arrays=[key for key in cdd.keys()],
        ),
        "scalar": "anno_tissue",
        "scalarParameters": {**pdd, **cdd},
    }
)
# Custom init parameters
if init_custom_parameters["custom_func"] is True:
    state.update(init_custom_parameters)
else:
    state.update({"custom_func": False})


# GUI
ui_standard_layout = ui_layout(server=interactive_server, template_name="main")
with ui_standard_layout as layout:
    # Let the server know the browser pixel ratio and the default theme
    trame_widgets.ClientTriggers(
        mounted="pixel_ratio = window.devicePixelRatio, $vuetify.theme.dark = true"
    )

    # -----------------------------------------------------------------------------
    # ToolBar
    # -----------------------------------------------------------------------------
    ui_toolbar(
        server=interactive_server,
        layout=layout,
        plotter=plotter,
        ui_name="SPATEO VIEWER (RECONSTRUCTOR)",
    )
    trame_widgets.ClientStateChange(name="activeModel", change=ctrl.view_reset_camera)
    # -----------------------------------------------------------------------------
    # Drawer
    # -----------------------------------------------------------------------------
    ui_drawer(server=interactive_server, layout=layout)

    # -----------------------------------------------------------------------------
    # Main Content
    # -----------------------------------------------------------------------------
    ui_container(server=interactive_server, layout=layout)

    # -----------------------------------------------------------------------------
    # Footer
    # -----------------------------------------------------------------------------
    layout.footer.hide()
    # layout.flush_content()

--------------------------------------------------------------------------------
/stviewer/server.py:
--------------------------------------------------------------------------------
from typing import Optional

from trame.app import get_server

# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------


def get_trame_server(name: Optional[str] = None, **kwargs):
    """
    Return a server for serving trame applications. If a name is given and no such server exists yet, it is
    created; otherwise, the previously created instance is returned.

    Args:
        name: A server name identifier which can be useful when several servers are expected to be created.
        kwargs: Additional parameters that will be passed to ``trame.app.get_server`` function.

    Returns:
        Return a unique Server instance per given name.
    """

    server = get_server(name=name, **kwargs)
    return server

--------------------------------------------------------------------------------
/usage/ExplorerUsage.md:
--------------------------------------------------------------------------------

## 10 minutes to static-viewer

Welcome to static-viewer!

static-viewer is a web application for visualizing various models created from spatial transcriptomics data in 3D space,
including point cloud models, mesh models, trajectory models, vector field models, etc.
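When a folder is loaded, `sample_dataset` in `stviewer/assets/dataset_acquisition.py` gives every point cloud and mesh model an ID taken from its filename. This is why model filenames begin with an ordinal followed by the tissue name. Below is a rough, standalone sketch of that rule. `model_ids_from_files` is a hypothetical helper written only for illustration and does not exist in spateo-viewer.

```python
# Sketch of the model-ID rule that `sample_dataset` applies to files in
# `pc_models/` and `mesh_models/`. `model_ids_from_files` is a hypothetical
# helper for illustration only; it is not part of spateo-viewer.


def model_ids_from_files(filenames, prefix):
    # Keep only model files (.vtk / .vtm) and sort them by name, as sample_dataset does.
    model_files = sorted(f for f in filenames if f.endswith((".vtk", ".vtm")))
    # The ID is the prefix plus the second "_"-separated token, e.g. "Embryo".
    return [f"{prefix}_{f.split('_')[1]}" for f in model_files]


pc_files = [
    "1_CNS_E7_8h_aligned_pc_model.vtk",
    "0_Embryo_E7_8h_aligned_pc_model.vtk",
]
print(model_ids_from_files(pc_files, "PC"))  # ['PC_Embryo', 'PC_CNS']
```

Sorting is lexicographic. With ten or more models, `10_...` therefore sorts before `2_...`, and zero-padded ordinals such as `02_` keep the intended order. Put only model files in these folders, since other files could break the match between the model list and the ID list.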

## How to use

### Installation
You can clone the [**Spateo-Viewer**](https://github.com/aristoteleo/spateo-viewer) with ``git`` and install dependencies with ``pip``:

    git clone https://github.com/aristoteleo/spateo-viewer.git
    cd spateo-viewer
    pip install -r requirements.txt

### Running (the port can be changed)

    python stv_explorer.py --port 1234

### Folder structure of the uploaded data

```
drosophila_E7_8h                                  # The folder name
├── h5ad                                          # (Required) Folder containing the AnnData object (.h5ad file)
│   └── E7_8h_cellbin.h5ad                        # (Required) A single AnnData object (.h5ad file)
├── mesh_models                                   # (Optional) Folder containing mesh models (.vtk files)
│   ├── 0_Embryo_E7_8h_aligned_mesh_model.vtk     # (Optional) Filename starts with an ordinal and ends with "_mesh_model.vtk"
│   └── 1_CNS_E7_8h_aligned_mesh_model.vtk        # (Optional) Filename starts with an ordinal and ends with "_mesh_model.vtk"
└── pc_models                                     # (Optional) Folder containing point cloud models (.vtk files)
    ├── 0_Embryo_E7_8h_aligned_pc_model.vtk       # (Optional) Filename starts with an ordinal and ends with "_pc_model.vtk"
    └── 1_CNS_E7_8h_aligned_pc_model.vtk          # (Optional) Filename starts with an ordinal and ends with "_pc_model.vtk"
```

You can refer to the folder structure of the data we include by default in the [**dataset**](https://github.com/aristoteleo/spateo-viewer/blob/main/stviewer/assets/dataset).

### How to upload data

1. Upload a folder with the upload tool in the toolbar of the web application:

![UploadFolder](https://github.com/aristoteleo/spateo-viewer/blob/main/stviewer/assets/image/upload_folder.png)

2. Upload a folder via ``stv_explorer.py``:

```
from stviewer.explorer_app import static_server, state

if __name__ == "__main__":
    state.selected_dir = None
    static_server.start()
```

Replace ``None`` in ``state.selected_dir = None`` with the absolute path of the folder you want to upload. (Prefer this method when running on a remote server.)


### How to upload custom colors
Custom colors must be stored in the uploaded AnnData object, in ``adata.uns``, under a key that ends with "colors". The value should be a list of colors. A dict is also accepted, in which case its values are used. Each such key is registered as a colormap and offered alongside the built-in colormaps. For example:

```
adata.uns["anno_tissue_list_colors"] = ["#ef9b20", "#f46a9b", "#ece05a", "#ede15b", "#ea5545", "#9a82e0", "#87bc45", "#ec5646", "#bdcf32"]
```

--------------------------------------------------------------------------------
/usage/ReconstructorUsage.md:
--------------------------------------------------------------------------------

## 10 minutes to interactive-viewer

Welcome to interactive-viewer!

interactive-viewer is a web application for interactive 3D reconstruction of spatial transcriptomics data.


## How to use

### Installation
You can clone [**Spateo-Viewer**](https://github.com/aristoteleo/spateo-viewer) with ``git`` and install its dependencies with ``pip``:

    git clone https://github.com/aristoteleo/spateo-viewer.git
    cd spateo-viewer
    pip install -r requirements.txt

### Running (the port can be changed)

    python stv_interactive_app.py --port 1234

### How to generate the anndata object to upload

```
import anndata as ad

# Load the anndata object
adata = ad.read_h5ad("E7_8h_cellbin.h5ad")

# Make sure adata.obsm contains 'spatial', which stores the coordinates
adata.obsm["spatial"] = adata.obsm["3d_align_spatial"]

# Interactive-viewer reads all information contained in adata.obs,
# so make sure everything you need has been saved in adata.obs

# Save the anndata object
adata.write_h5ad("E7_8h_cellbin.h5ad", compression="gzip")
```

You can refer to the data structure included by default in the [**dataset**](https://github.com/aristoteleo/spateo-viewer/blob/main/stviewer/assets/dataset/drosophila_E7_8h/pc_models/0_Embryo_E7_8h_aligned_pc_model.vtk).

### How to upload data

1. Upload the file with the upload tool in the toolbar of the web application:

   ![UploadFile](https://github.com/aristoteleo/spateo-viewer/blob/main/stviewer/assets/image/upload_file.png)

2. 
   Upload the file via ``stv_interactive_app.py``:

```
from stviewer.interactive_app import interactive_server, state

if __name__ == "__main__":
    state.upload_anndata = None
    interactive_server.start()
```

Replace ``None`` in ``state.upload_anndata = None`` with the absolute path of the file you want to upload. (Prefer this method when running on a remote server.)

--------------------------------------------------------------------------------
/usage/spateo-viewer.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aristoteleo/spateo-viewer/4c3b5404009686418a680bd56d0dd0260420c512/usage/spateo-viewer.pdf
--------------------------------------------------------------------------------