├── .gitignore ├── LICENSE ├── requirements.txt ├── CITATION.cff ├── environment.yml ├── relion_jobs ├── aligntilt_job.py ├── tomograms_job.py ├── class2d_job.py ├── postprocess_job.py ├── ctffind_job.py ├── picking_job.py ├── mask_job.py ├── localres_job.py ├── excludetilt_job.py ├── polish_job.py ├── import_job.py ├── class3d_job.py └── motioncorr_job.py ├── static └── frg.svg ├── README.md └── follow_relion_gracefully.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *.egg-info/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dzyla/Follow_Relion_gracefully/HEAD/LICENSE -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy<2.0 2 | pandas 3 | mrcfile 4 | scikit-image 5 | scipy 6 | matplotlib 7 | seaborn 8 | pillow 9 | plotly 10 | streamlit 11 | pyvis 12 | gemmi 13 | stqdm 14 | PyMCubes 15 | altair 16 | morphosamplers -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Zyla" 5 | given-names: "Dawid" 6 | orcid: "https://orcid.org/0000-0001-8471-469X" 7 | title: "Follow_Relion_gracefully" 8 | version: v5 9 | doi: 10.5281/zenodo.10465899 10 | date-released: 2024-01-06 11 | url: "https://github.com/dzyla/Follow_Relion_gracefully" 12 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: FollowRelion 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.12 7 | - pip 8 | - numpy<2.0 9 | - pandas 10 | - mrcfile 11 | - scikit-image 12 | - scipy 13 | - matplotlib 14 | - seaborn 15 | - pillow 16 | - plotly 17 | - streamlit 18 | - pyvis 19 | - gemmi 20 | - altair 21 | - pip: 22 | - stqdm 23 | - PyMCubes 24 | - morphosamplers 25 | -------------------------------------------------------------------------------- /relion_jobs/aligntilt_job.py: -------------------------------------------------------------------------------- 1 | # aligntilt_job.py 2 | 3 | import os 4 | import logging 5 | from typing import List 6 | 7 | import streamlit as st 8 | import pandas as pd 9 | 10 | from lib.utils import parse_star, get_values_from_first_key, report_error 11 | 12 | logger = logging.getLogger("main_app") 13 | 14 | 15 | def plot_align_tilt_series(rln_folder: str, node_file: str) -> None: 16 | """Plot alignment tilt series for a RELION node star file. 17 | 18 | Args: 19 | rln_folder (str): Path to the RELION folder containing star files. 20 | node_file (str): Name of the primary star file for the tilt series. 
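    Returns:
        None. Output (warnings and, once implemented, plots) is written directly to the Streamlit app.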
21 | """ 22 | try: 23 | logger.info("Plotting alignment tilt series for node: %s", node_file) 24 | star_path = os.path.join(rln_folder, node_file) 25 | 26 | if not os.path.isfile(star_path): 27 | logger.error("Star file not found: %s", star_path) 28 | st.warning(f"Star file not found: {star_path}") 29 | return 30 | 31 | logger.debug("Parsing star file at: %s", star_path) 32 | try: 33 | star = parse_star(star_path) 34 | except Exception as exc: 35 | report_error(exc, f"Failed to parse star file: {star_path}") 36 | st.warning("Failed to parse the main star file.") 37 | return 38 | 39 | if "global" not in star: 40 | logger.error("Missing 'global' section in star file: %s", star_path) 41 | st.warning("The star file does not contain a 'global' section.") 42 | return 43 | 44 | tomo_files = star["global"].get("_rlnTomoTiltSeriesStarFile", []) 45 | if not isinstance(tomo_files, list) or not tomo_files: 46 | logger.error("No tilt-series references found in 'global' section.") 47 | st.warning("No tilt-series data found in the star file.") 48 | return 49 | 50 | dfs_tilt: List[pd.DataFrame] = [] 51 | for tomo_file in tomo_files: 52 | path = os.path.join(rln_folder, tomo_file) 53 | try: 54 | star_data = parse_star(path) 55 | df = get_values_from_first_key(star_data) 56 | df["FileSource"] = os.path.basename(path) 57 | dfs_tilt.append(df) 58 | except Exception as exc: 59 | logger.error("Error parsing star file %s: %s", path, exc) 60 | report_error(exc, f"Error parsing tilt-series star file: {path}") 61 | 62 | if not dfs_tilt: 63 | logger.error("Parsed tilt-series data is empty.") 64 | st.warning("No tilt-series data found. Cannot plot tilt angles.") 65 | return 66 | 67 | # TODO: Implement plotting logic here using dfs_tilt 68 | except Exception as exc: 69 | report_error(exc, "Unexpected error in plot_align_tilt_series") 70 | st.warning("An unexpected error occurred while plotting tilt series.") 71 | -------------------------------------------------------------------------------- /relion_jobs/tomograms_job.py: -------------------------------------------------------------------------------- 1 | import os 2 | import traceback 3 | import logging 4 | 5 | import streamlit as st 6 | import json 7 | 8 | from lib.image_utils import micrograph_viewer 9 | 10 | from lib.utils import ( 11 | parse_star, 12 | get_values_from_first_key, 13 | ) 14 | 15 | logger = logging.getLogger("main_app") 16 | 17 | def report_error(exc: Exception) -> None: 18 | """ 19 | Report an error using the global error handler if available. 20 | Otherwise, log the error with full traceback. 21 | """ 22 | error_info = traceback.format_exc() 23 | logger.error("An unexpected error occurred:\n%s", error_info) 24 | 25 | 26 | def plot_tomographs(rln_folder: str, node_files: str) -> None: 27 | logger.info(f"Plotting tomographs... 
rln_folder= {rln_folder}, node_files= {node_files}") 28 | 29 | star_path = os.path.join(rln_folder, node_files) 30 | logger.debug(f"star_path: {star_path}") 31 | 32 | if not os.path.exists(star_path): 33 | st.error(f"Star file not found: {star_path}") 34 | return 35 | 36 | try: 37 | star = parse_star(star_path) 38 | except Exception as exc: 39 | report_error(exc) 40 | st.error("Failed to parse the main star file.") 41 | return 42 | 43 | if "global" not in star: 44 | st.error("The star file does not contain a 'global' section.") 45 | return 46 | 47 | logger.debug(f"Parsed star file: {star}") 48 | 49 | # Gather paths to tilt-series star files from the exclude job 50 | tomo_star_files = star["global"]["_rlnTomoTiltSeriesStarFile"] 51 | tomo_mrc_files = star["global"]["_rlnTomoReconstructedTomogramHalf1"] 52 | tomo_star_files_paths = [os.path.join(rln_folder, f) for f in tomo_star_files] 53 | 54 | # Parse each tilt-series star file 55 | dfs_tilt = [] 56 | for path_tomo_star in tomo_star_files_paths: 57 | try: 58 | star_data = parse_star(path_tomo_star) 59 | if star_data: 60 | df_tilt = get_values_from_first_key(star_data) 61 | df_tilt["FileSource"] = os.path.basename(path_tomo_star) 62 | dfs_tilt.append(df_tilt) 63 | except Exception as exc: 64 | logger.error(f"Error parsing star file {path_tomo_star}: {exc}") 65 | report_error(exc) 66 | 67 | if not dfs_tilt: 68 | st.error("No tilt-series data found. Cannot plot tilt angles.") 69 | return 70 | 71 | if 'Denoise' in node_files: 72 | logger.info("Denoised tomograms detected.") 73 | 74 | if '_rlnTomoReconstructedTomogramDenoised' not in star["global"]: 75 | logger.info("Denoise training job detected") 76 | config_path = os.path.join(os.path.dirname(star_path), "external/training/train_data_config.json") 77 | if os.path.exists(config_path): 78 | config_json = json.load(open(config_path)) 79 | st.write("**Denoising training job config:**") 80 | st.json(config_json) 81 | 82 | 83 | else: 84 | tomo_mrc_files = star["global"]["_rlnTomoReconstructedTomogramDenoised"] 85 | micrograph_viewer(rln_folder=rln_folder, image_files=tomo_mrc_files) 86 | else: 87 | micrograph_viewer(rln_folder=rln_folder, image_files=tomo_mrc_files) -------------------------------------------------------------------------------- /relion_jobs/class2d_job.py: -------------------------------------------------------------------------------- 1 | # class2d_job.py 2 | """ 3 | Module handling 2D classification visualization and class selection. 4 | """ 5 | 6 | import os 7 | import glob 8 | import logging 9 | from datetime import datetime 10 | from typing import List 11 | 12 | import numpy as np 13 | import streamlit as st 14 | import matplotlib.pyplot as plt 15 | 16 | # ------------------------------------------------------------------ 17 | # Import your existing shared utilities. Adjust imports as needed. 18 | # ------------------------------------------------------------------ 19 | from lib.utils import ( 20 | interactive_scatter_plot, 21 | get_classes, 22 | ) 23 | from relion_jobs.select_job import display_classes # or wherever display_classes is defined 24 | 25 | 26 | logger = logging.getLogger("main_app") 27 | 28 | 29 | def plot_class_distribution(class_dist_: np.ndarray, PLOT_HEIGHT=500): 30 | """ 31 | Plots iteration-by-iteration distribution for 2D classes. 32 | Expects class_dist_ shape = (num_classes, num_iterations). 
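    Class distributions are given as fractions and are converted to percentages before plotting.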
33 | """ 34 | import plotly.graph_objects as go 35 | 36 | fig = go.Figure() 37 | for n, class_ in enumerate(class_dist_): 38 | class_ = class_.astype(float) * 100 39 | x = np.arange(0, class_dist_.shape[1]) 40 | fig.add_trace( 41 | go.Scatter( 42 | x=x, 43 | y=class_, 44 | name=f"Class {n + 1}", 45 | showlegend=True, 46 | hovertemplate=( 47 | f"Class {n + 1}
Iteration: %{{x}}" 48 | "
Cls dist: %{y:.2f}%" 49 | "" 50 | ), 51 | mode="lines", 52 | stackgroup="one", 53 | ) 54 | ) 55 | 56 | fig.update_xaxes(title_text="Iteration") 57 | fig.update_yaxes(title_text="Class distribution (%)") 58 | fig.update_layout(title="Class distribution over iterations") 59 | fig.update_layout(hovermode="x unified", height=PLOT_HEIGHT) 60 | 61 | st.plotly_chart(fig, use_container_width=True, height=PLOT_HEIGHT) 62 | 63 | 64 | def plot_class2d(rln_folder: str, nodes: List[str]) -> None: 65 | """ 66 | Show 2D classes for a given job, let user select classes, generate new star file, 67 | and optionally plot iteration-by-iteration distribution or do interactive plotting. 68 | 69 | Parameters: 70 | rln_folder: Root folder for the job data. 71 | nodes: List of relevant star/mrc/etc. files produced by the job. 72 | """ 73 | logger.debug(f"{datetime.now()}: plot_class2d started with {nodes}") 74 | st.subheader("2D Classification Job") 75 | 76 | # 1) Derive the job base folder from the first node path 77 | job_basefolder = os.path.join(rln_folder, os.path.dirname(nodes[0])) 78 | data_star = os.path.join(job_basefolder, nodes[0]) 79 | logger.debug(f"Job base folder: {job_basefolder}") 80 | 81 | # 2) Gather model.star files 82 | model_files = sorted( 83 | glob.glob(os.path.join(job_basefolder, "*model.star")), 84 | key=os.path.getmtime, 85 | ) 86 | if not model_files: 87 | st.write("No model files found.") 88 | return 89 | 90 | # 3) Basic info about classes from the last model 91 | class_paths, n_classes, iter_, class_dist, class_res, _, _ = get_classes( 92 | job_basefolder, [model_files[-1]] 93 | ) 94 | class_dist = np.squeeze(class_dist) 95 | logger.debug(f"Class paths: {class_paths}") 96 | 97 | if not len(class_paths): 98 | st.write("No classes found in the last model.") 99 | return 100 | 101 | # We place the class display in an expander so it is collapsible 102 | with st.expander("Class Averages and Selections", expanded=True): 103 | # This controls the UI for sorting classes 104 | col_sort, _ = st.columns([1, 10]) 105 | sort_classes = col_sort.checkbox("Sort classes?", True) 106 | 107 | # Show each class path 108 | for class_path in class_paths: 109 | # `display_classes` is from relion_jobs.select_job 110 | # Keep the call arguments as is for backward compatibility 111 | logger.debug(f"2D Class path: {class_path}") 112 | display_classes( 113 | class_path=class_path, 114 | class_distribution=class_dist, 115 | sort_by_distribution=sort_classes, 116 | raw_data_star=nodes[0], # pass the first node star 117 | rln_folder=job_basefolder 118 | ) 119 | 120 | # Now gather all model files for iteration-based distribution 121 | with st.expander("Distribution Over Iterations", expanded=False): 122 | class_paths_all, n_classes_all, iter_all, class_dist_all, class_res_all, _, _ = get_classes( 123 | job_basefolder, model_files 124 | ) 125 | # Plot distribution if it is not empty 126 | if class_dist_all.size > 0: 127 | plot_class_distribution(class_dist_all, PLOT_HEIGHT=500) 128 | else: 129 | st.write("No iteration-based distribution data available.") 130 | 131 | if st.checkbox("Show detailed statistics?"): 132 | interactive_scatter_plot(os.path.join(rln_folder, nodes[0])) 133 | st.info("Displaying interactive scatter plot for the main STAR file.") 134 | 135 | plt.close() 136 | logger.info(f"{datetime.now()}: plot_class2d done.") 137 | -------------------------------------------------------------------------------- /static/frg.svg: -------------------------------------------------------------------------------- 1 
| 2 | 3 | 6 | 7 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /relion_jobs/postprocess_job.py: -------------------------------------------------------------------------------- 1 | #postprocess_job.py 2 | 3 | import os 4 | import logging 5 | import math 6 | from datetime import datetime 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import streamlit as st 11 | import plotly.graph_objects as go 12 | import plotly.figure_factory as ff 13 | from skimage.transform import resize 14 | import mrcfile 15 | import seaborn as sns 16 | 17 | 18 | # Import your shared utilities. 19 | from lib.utils import parse_star, interactive_scatter_plot 20 | from lib.image_utils import normalize, display_volume_slices, plot_volume 21 | from relion_jobs.select_job import display_classes 22 | 23 | # Setup logger 24 | logger = logging.getLogger("main_app") 25 | 26 | 27 | def _load_postprocess_data(rln_folder: str, postprocess_star_path: str) -> dict: 28 | """ 29 | Load the postprocess.star file using parse_star and cache the result in session state. 30 | If the current postprocess file is the same as the cached one, return the cached data. 31 | """ 32 | cache_key = "postprocess_data" 33 | current_file_key = "current_postprocess_file" 34 | if ( 35 | current_file_key in st.session_state 36 | and st.session_state[current_file_key] == postprocess_star_path 37 | and cache_key in st.session_state 38 | ): 39 | logger.debug("Using cached postprocess star data.") 40 | return st.session_state[cache_key] 41 | 42 | try: 43 | logger.debug(f"Loading postprocess star file: {postprocess_star_path}") 44 | postprocess_data = parse_star(postprocess_star_path) 45 | st.session_state[cache_key] = postprocess_data 46 | st.session_state[current_file_key] = postprocess_star_path 47 | return postprocess_data 48 | except Exception as e: 49 | logger.error(f"Error loading postprocess star file: {e}") 50 | return {} 51 | 52 | 53 | def plot_postprocess(rln_folder, nodes): 54 | """ 55 | Main function to plot postprocess data in Streamlit. 56 | This function: 57 | 1) Loads (or retrieves from session state) the postprocess.star file. 58 | 2) Plots the FSC and Guinier curves side by side. 59 | 3) Displays masked volume slices. 60 | """ 61 | logger.debug(f"{datetime.now()}: plot_postprocess started for {rln_folder} with {nodes}") 62 | 63 | # Assume postprocess.star file is the 4th node. 64 | postprocess_star_path = os.path.join(rln_folder, nodes[3]) 65 | if not os.path.exists(postprocess_star_path): 66 | st.write("No postprocess.star file found") 67 | return 68 | 69 | # Load and cache postprocess data. 70 | postprocess_data = _load_postprocess_data(rln_folder, postprocess_star_path) 71 | if not postprocess_data: 72 | st.write("Failed to load postprocess.star data") 73 | return 74 | 75 | # Directly use Plotly for all plotting (remove Seaborn option). 76 | try: 77 | fsc_data = postprocess_data["fsc"].astype(float) 78 | guinier_data = postprocess_data["guinier"].astype(float) 79 | except Exception as e: 80 | st.write("Error processing postprocess.star data") 81 | logger.error(f"Error converting FSC/Guinier data to float: {e}") 82 | return 83 | 84 | st.write("## Postprocess Data Visualization") 85 | # Arrange FSC and Guinier plots into two columns. 
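    # FSC goes in the left column and the Guinier plot in the right; both helpers share Seaborn's "deep" palette.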
86 | col_fsc, col_guinier = st.columns(2) 87 | with col_fsc: 88 | st.markdown("### Fourier Shell Correlation (FSC) Curve") 89 | plot_fsc_curve(fsc_data) 90 | with col_guinier: 91 | st.markdown("### Guinier Plot") 92 | plot_guinier_curve(guinier_data) 93 | 94 | # Display Masked Volume Slices. 95 | with st.expander("Volume preview"): 96 | display_volume_slices(rln_folder, nodes) 97 | 98 | logger.info(f"{datetime.now()}: plot_postprocess done") 99 | 100 | 101 | def plot_fsc_curve(fsc_data, dpi=150): 102 | """ 103 | Plot the Fourier Shell Correlation (FSC) curve using Plotly. 104 | FSC curves for different map types are colored using Seaborn's 'deep' palette. 105 | """ 106 | # Retrieve FSC resolution data and convert to float. 107 | fsc_x = fsc_data["_rlnAngstromResolution"].astype(float) 108 | fsc_x_min = np.min(fsc_x) 109 | 110 | # FSC fields to plot. 111 | fsc_to_plot = [ 112 | "_rlnFourierShellCorrelationCorrected", 113 | "_rlnFourierShellCorrelationUnmaskedMaps", 114 | "_rlnFourierShellCorrelationMaskedMaps", 115 | "_rlnCorrectedFourierShellCorrelationPhaseRandomizedMaskedMaps", 116 | ] 117 | 118 | # Use Seaborn's default "deep" palette for colors. 119 | palette = sns.color_palette("deep").as_hex() 120 | 121 | fig = go.Figure() 122 | for i, meta in enumerate(fsc_to_plot): 123 | color = palette[i % len(palette)] 124 | reciprocal_x = 1 / fsc_x 125 | fig.add_trace( 126 | go.Scatter( 127 | x=reciprocal_x, 128 | y=fsc_data[meta].astype(float), 129 | mode="lines", 130 | line=dict(color=color, width=3), 131 | customdata=fsc_x, 132 | name=meta.replace("_rlnFourierShellCorrelation", "").replace("_rlnCorrectedFourierShellCorrelation", ""), 133 | hovertemplate="Resolution: %{customdata:.2f} Å
FSC: %{y:.2f}", 134 | ) 135 | ) 136 | 137 | # Add horizontal threshold lines. 138 | fig.add_hline(y=0.143, line_dash="dash", line_color="black", annotation_text="0.143") 139 | fig.add_hline(y=0.5, line_dash="dash", line_color="black", annotation_text="0.5") 140 | 141 | # Determine axis range and custom ticks. 142 | start_res = 1 / 50 143 | end_res = 1 / fsc_x_min 144 | custom_ticks = np.linspace(start_res, end_res, num=10) 145 | fig.update_layout( 146 | xaxis=dict( 147 | title="Resolution, Å", 148 | tickvals=custom_ticks, 149 | ticktext=[f"{round(1/res, 2)}" for res in custom_ticks], 150 | range=[start_res, end_res], 151 | ), 152 | yaxis=dict(title="FSC", range=[-0.05, 1.05]), 153 | title="Fourier Shell Correlation Curve", 154 | legend=dict(x=1, y=1, xanchor="left", yanchor="top", bgcolor="rgba(255,255,255,0)"), 155 | margin=dict(l=20, r=20, t=40, b=20), 156 | ) 157 | st.plotly_chart(fig, use_container_width=True) 158 | 159 | # Resolution annotations. 160 | fsc143 = 0.143 161 | idx_143 = np.argmin(np.abs(fsc_data["_rlnFourierShellCorrelationCorrected"].astype(float) - fsc143)) 162 | fsc05 = 0.5 163 | idx_05 = np.argmin(np.abs(fsc_data["_rlnFourierShellCorrelationCorrected"].astype(float) - fsc05)) 164 | st.markdown(f"Reported resolution @ FSC=0.143: **{round(fsc_x[idx_143], 2)} Å**") 165 | st.markdown(f"Reported resolution @ FSC=0.5: **{round(fsc_x[idx_05], 2)} Å**") 166 | 167 | def plot_guinier_curve(guinier_data): 168 | """ 169 | Plot the Guinier curve using Plotly. 170 | """ 171 | guiner_x = guinier_data["_rlnResolutionSquared"].astype(float) 172 | guinier_to_plot = [ 173 | "_rlnLogAmplitudesOriginal", 174 | "_rlnLogAmplitudesMTFCorrected", 175 | "_rlnLogAmplitudesWeighted", 176 | "_rlnLogAmplitudesSharpened", 177 | "_rlnLogAmplitudesIntercept", 178 | ] 179 | 180 | palette = sns.color_palette("deep").as_hex() 181 | 182 | fig = go.Figure() 183 | for i, meta in enumerate(guinier_to_plot): 184 | color = palette[i % len(palette)] 185 | try: 186 | y_data = guinier_data[meta].astype(float) 187 | # Replace invalid values (-99) with NaN. 
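            # -99 is treated here as a sentinel for undefined log-amplitudes; mapping it to NaN makes Plotly leave gaps instead of drawing spurious points.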
188 | y_data[y_data == -99] = float("nan") 189 | fig.add_trace( 190 | go.Scatter( 191 | x=guiner_x[1:], 192 | y=y_data[1:], 193 | mode="lines", 194 | line=dict(color=color, width=3), 195 | name=meta.replace("_rlnLogAmplitudes", ""), 196 | ) 197 | ) 198 | except Exception: 199 | pass 200 | 201 | fig.update_layout( 202 | title="Guinier Plot", 203 | xaxis_title="Resolution Squared, [1/Ų]", 204 | yaxis_title="Ln(Amplitudes)", 205 | margin=dict(l=20, r=20, t=40, b=20), 206 | ) 207 | st.plotly_chart(fig, use_container_width=True) 208 | 209 | 210 | 211 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Follow Relion Gracefully :microscope::rocket::globe_with_meridians: 2 | --- 3 | **v6: A complete dashboard for easy interaction with your cryo-EM data in Relion, now with ~~partial~~ full :sparkles: `#teamtomo` :sparkles: support!** 4 | 5 | * **Data sourced from [Relion5 tutorial](https://relion.readthedocs.io/en/latest/SPA_tutorial/index.html), [Relion4 STA](https://relion.readthedocs.io/en/release-4.0/STA_tutorial/index.html), and [Relion5 STA](https://zenodo.org/records/11068319)** 6 | * **Licensed under Non-Profit Open Software License 3.0 (NPOSL-3.0)** 7 | 8 | **Micrograph viewer** 9 | ![Screenshot 2025-04-26 132051](https://github.com/user-attachments/assets/f0bef20d-cdd8-4862-9e79-6de9cef4ea56) 10 | 11 | **Picking statistics preview** 12 | ![Screenshot 2025-04-26 132130](https://github.com/user-attachments/assets/62a328ac-8a12-4e0b-84ea-681a33425d2f) 13 | 14 | **Mask with original volume preview** 15 | ![Screenshot 2025-04-26 132427](https://github.com/user-attachments/assets/d283829e-eb4c-4d42-a5a0-9f91c9c1b6c6) 16 | 17 | **Local resolution directly in the browser** 18 | ![Screenshot 2025-04-26 132533](https://github.com/user-attachments/assets/b34df0fd-2083-4ed0-8f8b-f7a3605bae0a) 19 | 20 | **Tomography specific job preview:** 21 | ![Screenshot 2025-04-26 131941](https://github.com/user-attachments/assets/36e69169-1753-496e-b914-a3c6efb9c45e) 22 | 23 | **Tomogram viewer** 24 | ![Screenshot 2025-04-26 131931](https://github.com/user-attachments/assets/ad6234ae-7e20-40e3-a2be-d7eaaadd3a8c) 25 | 26 | **3D picking preview with annotations** 27 | ![Screenshot 2025-04-26 131907](https://github.com/user-attachments/assets/4a0a9800-d2b4-43ec-8d62-8035c5a2ef59) 28 | 29 | **3D picking with particles** 30 | ![Screenshot 2025-04-26 131849](https://github.com/user-attachments/assets/8b10cd27-ef06-4af2-9809-ae332f9338d7) 31 | 32 | 33 | 34 | #### :sparkles: Found this helpful in your research? Cite my work! :sparkles: 35 | 36 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.10465899.svg)](https://doi.org/10.5281/zenodo.10465899) 37 | 38 | 39 | #### Dawid Zyla. (2024). dzyla/Follow_Relion_gracefully: v5 (Version v5). Zenodo. https://doi.org/10.5281/zenodo.10465899 40 | Buy Me A Coffee 41 | 42 | 43 | --- 44 | 45 | ## Description :microscope: 46 | 47 | #### v6: :high_brightness: 48 | Version 6 introduces a complete dashboard for easy interaction with your cryo-EM data in Relion, now with full `#teamtomo` support! It allows users to visualize and analyze their data in real-time, providing a comprehensive overview of their projects. The new version also includes improved job previews, enhanced data visualization, and the ability to download volumes directly from the dashboard. 49 | 50 | #### v5: 51 | Version 5 improves the job preview by adopting a dynamic approach. 
Using [Streamlit](https://streamlit.io/), it allows users to interact directly with their data. The underlying Python framework facilitates real-time computation of statistics and data from most jobs, enabling users to engage with metadata and select preferred statistics for download and further analysis. 52 | 53 | #### v4: 54 | Version 4 introduced support for multiple projects and job visualization through an online interface using the Hugo framework. While this static job generator enabled job display with example data, it lacked interactive capabilities due to its static nature. 55 | 56 | ## v6 features :dizzy: 57 | * All the features of v5 but with improvements 58 | * Full support for `#teamtomo` jobs 59 | * Better job previews, including volume previews, statistics, and download options 60 | * Improved data visualization and publication-ready statistics 61 | * Ability to download volumes directly from the dashboard 62 | * Job flow chart overview, showing relationships between jobs 63 | * Improved speed (but increased RAM usage) 64 | * Included file browser for easier project switching 65 | * Plenty of QOL improvements 66 | * `#OpenSoftwareAcceleratesScience` 67 | 68 | 69 | 70 | 71 | ## Installation :rocket: 72 | 73 | Minor changes from `v5`, with a few new libraries added. Tested on `Windows 10/11`, `WSL2`, and `Ubuntu 22.04`. 74 | 75 | :heavy_exclamation_mark: 76 | ### If you are using a previously created environment, I strongly recommend creating a new environment with freshly installed dependencies, as some new libraries were added. Follow the steps below to set up a new environment. 77 | :heavy_exclamation_mark: 78 | 79 | ### Install Dependencies :snake: 80 | 81 | Install dependencies in a conda environment, as Python 3.12 is required and virtual environments are no longer supported (though they might still work). 82 | 83 | #### Conda Instructions 84 | 85 | 1. Install miniconda3 (*no root access required*, only if not installed already): 86 | 87 | ```bash 88 | wget -q -P . https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 89 | 90 | bash ./Miniconda3-latest-Linux-x86_64.sh -b -f 91 | ``` 92 | 93 | Activate conda for bash: 94 | 95 | ```bash 96 | conda init bash 97 | ``` 98 | 99 | Restart the shell or type `bash` to see the (base) prompt: 100 | 101 | ```bash 102 | (base) dzyla@GPU0 103 | ``` 104 | 105 | 2. Clone the GitHub repository and navigate to the folder: 106 | 107 | ```bash 108 | git clone https://github.com/dzyla/Follow_Relion_gracefully.git 109 | 110 | cd Follow_Relion_gracefully 111 | ``` 112 | 113 | 3. Create a conda environment and install dependencies using the `environment.yml` file: 114 | 115 | ```bash 116 | conda env create --file environment.yml 117 | 118 | conda activate FollowRelion 119 | ``` 120 | 121 | You should now see: 122 | 123 | ```bash 124 | (FollowRelion) dzyla@GPU0 125 | ``` 126 | 127 | #### UV Instructions 128 | UV is a package manager for Python that allows you to install and manage Python packages easily. It is similar to pip but is ultra-fast. To install UV, follow these steps: 129 | 130 | 1. 
Install UV using pip: 131 | 132 | ```bash 133 | pip install uv 134 | ``` 135 | 136 | or if you don't have python/pip/conda installed, use the following command: 137 | 138 | ```bash 139 | # with curl 140 | curl -LsSf https://astral.sh/uv/install.sh | sh 141 | 142 | # alternatively, with wget 143 | wget -qO- https://astral.sh/uv/install.sh | sh 144 | 145 | # or on Windows 146 | powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex" 147 | ``` 148 | 149 | 2. Clone the GitHub repository and navigate to the folder: 150 | 151 | ```bash 152 | git clone https://github.com/dzyla/Follow_Relion_gracefully.git 153 | 154 | cd Follow_Relion_gracefully 155 | ``` 156 | 157 | 3. Create a new virtual environment using UV: 158 | 159 | ```bash 160 | uv venv FollowRelion --python 3.12 161 | ``` 162 | 4. Activate the virtual environment: 163 | 164 | ```bash 165 | source FollowRelion/bin/activate 166 | ``` 167 | 168 | 169 | 5. Install the required packages using the `requirements.txt` file: 170 | 171 | ```bash 172 | uv pip install -r requirements.txt 173 | ``` 174 | 5. After the installation is complete, you should see a message indicating that the packages have been installed successfully. 175 | 176 | 177 | 178 | :sparkles:**Ready to start!** :sparkles: 179 | 180 | ## Usage :computer: 181 | 182 | ```text 183 | streamlit run follow_relion_gracefully.py 184 | ``` 185 | 186 | Additional command line parameters for extra features: 187 | 188 | ``` 189 | -h, --help Show this help message and exit. 190 | -i I, --folder I Path to the default folder. 191 | -p P, --password P Password for securing your instance. 192 | ``` 193 | 194 | ##### Example for live server updates and setting up a new project: 195 | 196 | To use command line parameters with streamlit, add `--` before the parameters: 197 | 198 | ``` 199 | conda activate FollowRelion 200 | 201 | streamlit run follow_relion_gracefully.py -- -p MyPassword$221#& -i /mnt/staging/240105_NewProcessing 202 | ``` 203 | 204 | This sets a password and default processing folder. 205 | 206 | ## Accessing the Dashboard via Browser :chart_with_upwards_trend: 207 | 208 | The dashboard should open automatically in your browser. For remote workstations, access it using the provided network URL, ensuring the port is not firewall-blocked. 209 | 210 | Remote access example: 211 | 212 | ``` 213 | (FollowRelion) dzyla@PC-HOME:~/Follow_Relion_gracefully$ streamlit run follow_relion_gracefully.py --server.port 8501 -- --p 1234 --i /mnt/f/linux/Tutorial5.0/ 214 | 215 | Local URL: http://localhost:8501 216 | Network URL: http://172.21.222.176:8501 217 | ``` 218 | 219 | Open the network URL in your browser to access the dashboard. 220 | 221 | For firewall issues, create an SSH tunnel: 222 | 223 | ```bash 224 | ssh -f username@workstation -L 8501:localhost:8501 -N 225 | ``` 226 | 227 | This allows remote dashboard access on your local computer: http://localhost:8501. 228 | 229 | ## Troubleshooting :wrench: 230 | 231 | * As previous versions, the code is a hobby project and may not work perfectly. Please report any issues on GitHub. 232 | * Large volumes (500px+) load slowly, especially for multiple class-3D classifications. Plotly does an excellent job with plotting but too large data can slow down the browser. 233 | * Ensure the correct environment is activated (`FollowRelion`). Deactivate others with `conda deactivate`. 234 | * Jobs run manually may not be processed, as the script reads from `default_pipeline.star`. 
235 | * There are some cases with warning about the session state. They can be ignored. 236 | * Rendering issues in the browser can often be resolved by refreshing (`F5`). 237 | * Mac support is untested, but it's assumed to work similarly to Linux. Please report any issues! 238 | * Please note that this code was developed by a Python enthusiast, not a professional developer. It has been tested under standard scenarios to ensure reliability. However, as the author, I cannot be held responsible for any issues or damages that may arise from its use. Users are encouraged to review and test the code thoroughly before implementation in their projects. 239 | 240 | 241 | ## To-do :memo: 242 | 243 | * Add support for `DynaMight` 244 | * Speed up processing for some jobs 245 | * Add cryoSPARC metadata support and export to Relion 246 | 247 | 248 | 249 | 250 | ## Questions/suggestions?:email: 251 | 252 | Dawid Zyla, La Jolla Institute for Immunology 253 | 254 | [Twitter](https://twitter.com/DawidZyla) 255 | 256 | [dzyla@lji.org](mailto:dzyla@lji.org) 257 | -------------------------------------------------------------------------------- /relion_jobs/ctffind_job.py: -------------------------------------------------------------------------------- 1 | #ctffind_job.py 2 | 3 | import os 4 | import logging 5 | from datetime import datetime 6 | 7 | import streamlit as st 8 | import pandas as pd 9 | import numpy as np 10 | import altair as alt 11 | import mrcfile 12 | 13 | from lib.utils import ( 14 | parse_star, 15 | get_values_from_first_key, 16 | report_error, 17 | interactive_scatter_plot, 18 | ) 19 | from lib.image_utils import clip, normalize 20 | 21 | logger = logging.getLogger("main_app") 22 | 23 | # ============================================================================= 24 | # Cached helpers 25 | # ============================================================================= 26 | @st.cache_data(show_spinner=False) 27 | def read_ctf_average_data(ctf_avrot_path: str) -> pd.DataFrame: 28 | if not os.path.exists(ctf_avrot_path): 29 | raise FileNotFoundError(ctf_avrot_path) 30 | df = pd.read_csv( 31 | ctf_avrot_path, 32 | skiprows=[0, 1, 2, 3, 4, 6, 10], 33 | sep=r"\s+", 34 | names=["Spatial_freq", "1D_Ave", "Fit", "Fit_CC"], 35 | engine="python", 36 | ) 37 | df["Resolution"] = 1.0 / df["Spatial_freq"] 38 | return df.melt( 39 | id_vars=["Spatial_freq", "Resolution"], 40 | value_vars=["1D_Ave", "Fit", "Fit_CC"], 41 | var_name="Type", 42 | value_name="CTF", 43 | ) 44 | 45 | 46 | @st.cache_data(show_spinner=False) 47 | def load_and_normalize_mrc(mrc_path: str) -> np.ndarray: 48 | if not os.path.exists(mrc_path): 49 | raise FileNotFoundError(mrc_path) 50 | with mrcfile.mmap(mrc_path, permissive=True) as mrc: 51 | data = np.squeeze(mrc.data) 52 | return normalize(clip(data, 1, 99)) 53 | 54 | 55 | @st.cache_data(show_spinner=True) 56 | def load_star_data(folder: str, star_file: str) -> pd.DataFrame: 57 | """ 58 | Return a dataframe with **one row per micrograph or tomogram**. 59 | 60 | For single-particle projects it takes the *micrographs* table. 61 | 62 | For tomography projects (STAR contains only the *global* table) it opens 63 | every `_rlnTomoTiltSeriesStarFile`, extracts the first table inside each 64 | of those STAR files, and concatenates them. 65 | 66 | The caller therefore always receives a dataframe with real, per-image 67 | statistics – never the empty *global* table. 
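    Raises ValueError when neither a *micrographs* table nor a usable *global* table can be found.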
68 | """ 69 | path = os.path.join(folder, star_file) 70 | star = parse_star(path) 71 | 72 | # ── single-particle ──────────────────────────────────────────────────── 73 | if "micrographs" in star: 74 | return star["micrographs"] 75 | 76 | # ── tomography: expand the global list ──────────────────────────────── 77 | if "global" in star and "_rlnTomoTiltSeriesStarFile" in star["global"].columns: 78 | series_files = star["global"]["_rlnTomoTiltSeriesStarFile"].astype(str) 79 | dfs: list[pd.DataFrame] = [] 80 | for rel_p in series_files: 81 | sub_path = os.path.join(folder, rel_p) 82 | try: 83 | df = get_values_from_first_key(parse_star(sub_path)) 84 | if isinstance(df, pd.DataFrame): 85 | dfs.append(df) 86 | except Exception as exc: 87 | report_error(exc, f"Parsing tilt-series STAR {sub_path}") 88 | if dfs: 89 | return pd.concat(dfs, ignore_index=True) 90 | 91 | raise ValueError("Could not find a usable table (micrographs / global).") 92 | 93 | # ============================================================================= 94 | # Plot functions (unchanged except for minor guards) 95 | # ============================================================================= 96 | def plot_CTF_average(FOLDER: str, ctf_file_path: str) -> None: 97 | """ 98 | Plot the 1D CTF fit per micrograph from a given text file using Altair. 99 | 100 | Parameters: 101 | FOLDER (str): Folder containing the CTF file. 102 | ctf_file_path (str): Path to the CTF file. 103 | 104 | Returns: 105 | None. 106 | """ 107 | col1, col2 = st.columns([3, 1]) 108 | try: 109 | power_spectrum_ave_rot_txt_paths = os.path.join(FOLDER, ctf_file_path.replace(".ctf:mrc", "_avrot.txt")) 110 | if not os.path.exists(power_spectrum_ave_rot_txt_paths): 111 | st.error("File not found: " + power_spectrum_ave_rot_txt_paths) 112 | return 113 | 114 | ave_rot = pd.read_csv( 115 | power_spectrum_ave_rot_txt_paths, 116 | skiprows=[0, 1, 2, 3, 4, 6, 10], 117 | header=None, 118 | sep=r'\s+' 119 | ).transpose() 120 | ave_rot.columns = ["Spatial_freq", "1D_Ave", "Fit", "Fit_CC"] 121 | 122 | ave_rot["Resolution"] = 1 / ave_rot["Spatial_freq"] 123 | ave_rot_melt = ave_rot.melt( 124 | id_vars=["Spatial_freq", "Resolution"], 125 | value_vars=["1D_Ave", "Fit", "Fit_CC"], 126 | var_name="Type", 127 | value_name="CTF" 128 | ) 129 | 130 | # Use three distinct colors for these three lines 131 | chart_ctf = alt.Chart(ave_rot_melt).mark_line().encode( 132 | x=alt.X("Spatial_freq:Q", title="Spatial Frequency (1/Å)"), 133 | y=alt.Y("CTF:Q", title="CTF"), 134 | color=alt.Color("Type:N", scale=alt.Scale(range=["#FA8072", "#6FC381", "#6495ED"])), 135 | tooltip=["Spatial_freq", "Resolution", "CTF", "Type"] 136 | ).properties( 137 | title="CTF Fit per Micrograph", 138 | width=600, 139 | height=400 140 | ) 141 | 142 | col1.altair_chart(chart_ctf, use_container_width=True) 143 | 144 | try: 145 | mrc_file = os.path.join(FOLDER, ctf_file_path.replace("_avrot.txt", ".ctf").replace(":mrc", "")) 146 | logger.debug(f"CTF file: {mrc_file}") 147 | 148 | mrc_data = np.squeeze(mrcfile.mmap(mrc_file).data) 149 | mrc_image = normalize(clip(mrc_data, 1, 99)) 150 | col2.image(mrc_image, caption=f"CTF Image: {mrc_file}") 151 | except Exception as exc: 152 | report_error(exc) 153 | col2.error("Error loading CTF image.") 154 | 155 | except Exception as exc: 156 | report_error(exc) 157 | logger.error("Error in plot_CTF_average function.") 158 | 159 | def plot_ctf_stats(folder: str, star_file: str) -> None: 160 | # ── data load (tomography-aware) ─────────────────────────────────────── 161 | try: 
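        # load_star_data is tomography-aware: it yields per-micrograph rows for SPA projects and per-tilt rows expanded from each tilt-series STAR for tomography projects.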
162 | star_df = load_star_data(folder, star_file) 163 | except Exception as exc: 164 | report_error(exc) 165 | st.error("Failed to load STAR data.") 166 | return 167 | 168 | star_df = star_df.apply(lambda c: pd.to_numeric(c, errors="ignore")) 169 | star_df["Index"] = np.arange(len(star_df)) 170 | 171 | cols = { 172 | "defocus_v": "_rlnDefocusV", 173 | "max_res": "_rlnCtfMaxResolution", 174 | } 175 | numeric_cols = { 176 | label: col 177 | for label, col in cols.items() 178 | if col in star_df.columns and pd.api.types.is_numeric_dtype(star_df[col]) 179 | } 180 | if not numeric_cols: 181 | st.error("No numeric CTF columns found.") 182 | return 183 | 184 | st.subheader("CTF Statistics per Image") 185 | 186 | # Defocus plot 187 | if "defocus_v" in numeric_cols: 188 | cv = numeric_cols["defocus_v"] 189 | c1, c2 = st.columns([4, 1]) 190 | line = ( 191 | alt.Chart(star_df) 192 | .mark_line(color="#6FC381") 193 | .encode(x="Index:Q", y=f"{cv}:Q", tooltip=["Index", f"{cv}:Q"]) 194 | .properties(title="Defocus V", height=300) 195 | 196 | ) 197 | kde = ( 198 | alt.Chart(star_df) 199 | .transform_density(density=cv, as_=[cv, "density"]) 200 | .mark_area(orient="horizontal") 201 | .encode(y=f"{cv}:Q", x="density:Q") 202 | .properties(height=300) 203 | ) 204 | c1.altair_chart(line, use_container_width=True) 205 | c2.altair_chart(kde, use_container_width=True) 206 | 207 | # Max-resolution plot 208 | if "max_res" in numeric_cols: 209 | mr = numeric_cols["max_res"] 210 | c1, c2 = st.columns([4, 1]) 211 | line = ( 212 | alt.Chart(star_df) 213 | .mark_line(color="#6495ED") 214 | .encode(x="Index:Q", y=f"{mr}:Q", tooltip=["Index", f"{mr}:Q"]) 215 | .properties(title="Max Resolution", height=300) 216 | 217 | ) 218 | kde = ( 219 | alt.Chart(star_df) 220 | .transform_density(density=mr, as_=[mr, "density"]) 221 | .mark_area(orient="horizontal") 222 | .encode(y=f"{mr}:Q", x="density:Q") 223 | .properties(height=300) 224 | ) 225 | c1.altair_chart(line, use_container_width=True) 226 | c2.altair_chart(kde, use_container_width=True) 227 | 228 | # 2-D histogram 229 | st.subheader("2-D Parameter Distribution") 230 | numeric_candidates = [ 231 | c for c in star_df.columns if pd.api.types.is_numeric_dtype(star_df[c]) and star_df[c].nunique() > 1 232 | ] 233 | if len(numeric_candidates) >= 2: 234 | x_col = st.selectbox("X-axis", numeric_candidates, key="2dh_x") 235 | y_col = st.selectbox( 236 | "Y-axis", numeric_candidates, key="2dh_y", index=1 if x_col == numeric_candidates[0] else 0 237 | ) 238 | if x_col != y_col: 239 | chart2d = ( 240 | alt.Chart(star_df) 241 | .mark_rect() 242 | .encode( 243 | x=alt.X(f"{x_col}:Q", bin=alt.Bin(maxbins=50)), 244 | y=alt.Y(f"{y_col}:Q", bin=alt.Bin(maxbins=50)), 245 | color=alt.Color("count()", scale=alt.Scale(scheme="viridis")), 246 | ) 247 | .properties(title=f"2-D Histogram: {x_col} vs {y_col}") 248 | 249 | ) 250 | st.altair_chart(chart2d, use_container_width=True) 251 | else: 252 | st.warning("Select two different columns.") 253 | else: 254 | st.info("Not enough numeric columns for a 2-D histogram.") 255 | 256 | # 1-D CTF fit per image (only if column exists) 257 | if "_rlnCtfImage" in star_df.columns and star_df["_rlnCtfImage"].notna().any(): 258 | st.subheader("1-D CTF Fit") 259 | img_series = star_df["_rlnCtfImage"].astype(str) 260 | idx = st.slider("Image index", 0, len(img_series) - 1, 0) 261 | plot_CTF_average(folder, img_series.iloc[idx]) 262 | 263 | # optional interactive scatter 264 | if st.checkbox("Detailed scatter plot?", key="scatter"): 265 | 
interactive_scatter_plot(os.path.join(folder, star_file)) 266 | 267 | logger.info(f"{datetime.now()}: plot_ctf_stats finished") 268 | -------------------------------------------------------------------------------- /relion_jobs/picking_job.py: -------------------------------------------------------------------------------- 1 | # picking_job.py 2 | 3 | import os 4 | import glob 5 | import tempfile 6 | import logging 7 | from datetime import datetime 8 | from typing import List, Tuple, Union 9 | 10 | import streamlit as st 11 | import pandas as pd 12 | import numpy as np 13 | import plotly.express as px 14 | import plotly.graph_objects as go 15 | 16 | from lib.utils import parse_star, star_from_df, interactive_scatter_plot, report_error 17 | from lib.image_utils import micrograph_viewer 18 | 19 | logger = logging.getLogger("main_app") 20 | 21 | def plot_histogram_particles_per_mic( 22 | particles_per_mic: Union[pd.Series, np.ndarray] 23 | ) -> go.Figure: 24 | """ 25 | Create a histogram of particles per micrograph. 26 | 27 | Parameters: 28 | particles_per_mic: Series or array of particle counts. 29 | 30 | Returns: 31 | A Plotly Figure. 32 | """ 33 | try: 34 | fig = px.histogram(particles_per_mic, nbins=30) 35 | fig.update_layout( 36 | title="Histogram of Particles Per Micrograph", 37 | xaxis_title="Particles", 38 | yaxis_title="Micrographs Count", 39 | ) 40 | return fig 41 | except Exception as exc: 42 | report_error(exc) 43 | logger.error("Error in plot_histogram_particles_per_mic.") 44 | return go.Figure() 45 | 46 | 47 | def plot_histogram_fom( 48 | figure_of_merit: Union[pd.Series, np.ndarray] 49 | ) -> go.Figure: 50 | """ 51 | Create a histogram of autopick figure-of-merit. 52 | 53 | Parameters: 54 | figure_of_merit: Series or array of FOM values. 55 | 56 | Returns: 57 | A Plotly Figure. 58 | """ 59 | try: 60 | fig = px.histogram(figure_of_merit, nbins=50) 61 | fig.update_layout( 62 | title="Histogram of Autopick Figure of Merit", 63 | xaxis_title="Autopick Figure of Merit", 64 | yaxis_title="Count", 65 | ) 66 | return fig 67 | except Exception as exc: 68 | report_error(exc) 69 | logger.error("Error in plot_histogram_fom.") 70 | return go.Figure() 71 | 72 | 73 | def get_coord_paths(job_folder: str, rln_folder: str) -> Tuple[List[str], List[str]]: 74 | """ 75 | Retrieve coordinate file paths and corresponding micrograph paths. 76 | 77 | Checks in order: 78 | - autopick.star 79 | - manualpick.star 80 | - files matching "coords_suffix_*" 81 | 82 | Parameters: 83 | job_folder (str): Path to the job folder. 84 | rln_folder (str): Base folder for micrographs. 85 | 86 | Returns: 87 | A tuple (coord_paths, mics_paths). 
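        Both lists are empty when no recognised coordinate source is found or parsing fails.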
88 | """ 89 | try: 90 | # Case 1: autopick.star 91 | autopick_star_path = os.path.join(job_folder, "autopick.star") 92 | if os.path.exists(autopick_star_path) and os.path.getsize(autopick_star_path) > 0: 93 | autopick_star = parse_star(autopick_star_path)["coordinate_files"] 94 | mics_paths = autopick_star["_rlnMicrographName"].to_numpy().tolist() 95 | coord_paths = autopick_star["_rlnMicrographCoordinates"].to_numpy().tolist() 96 | return coord_paths, mics_paths 97 | 98 | # Case 2: manualpick.star 99 | manualpick_star_path = os.path.join(job_folder, "manualpick.star") 100 | if os.path.exists(manualpick_star_path) and os.path.getsize(manualpick_star_path) > 0: 101 | manpick_star = parse_star(manualpick_star_path)["coordinate_files"] 102 | mics_paths = manpick_star["_rlnMicrographName"].to_numpy().tolist() 103 | coord_paths = manpick_star["_rlnMicrographCoordinates"].to_numpy().tolist() 104 | return coord_paths, mics_paths 105 | 106 | # Case 3: coords_suffix_* pattern 107 | suffix_files = glob.glob(os.path.join(job_folder, "coords_suffix_*")) 108 | if suffix_files: 109 | suffix_file = suffix_files[0] 110 | suffix = os.path.basename(suffix_file).replace("coords_suffix_", "").replace(".star", "") 111 | with open(suffix_file, "r") as f: 112 | mics_data_path = f.readline().strip() 113 | all_mics_paths = parse_star(os.path.join(rln_folder, mics_data_path))["micrographs"]["_rlnMicrographName"] 114 | mics_paths = [os.path.join(rln_folder, name) for name in all_mics_paths] 115 | coord_paths = [ 116 | os.path.join( 117 | job_folder, 118 | f"coords_{suffix}", 119 | os.path.basename(mic_path).replace(".mrc", f"_{suffix}.star") 120 | ) 121 | for mic_path in mics_paths 122 | ] 123 | return coord_paths, mics_paths 124 | 125 | # Case 4: No matching pattern found. 126 | return [], [] 127 | except Exception as exc: 128 | report_error(exc) 129 | logger.error("Error in get_coord_paths.") 130 | return [], [] 131 | 132 | 133 | 134 | 135 | 136 | def plot_picks( 137 | rln_folder: str, job_name: str, img_resize_fac: float = 0.2 138 | ) -> None: 139 | """ 140 | Display picking statistics and overlay picks on the selected micrograph using micrograph_viewer. 141 | Determines the coordinate source (autopick.star, manualpick.star, or coords_suffix_*) 142 | and passes the computed picks to micrograph_viewer (which now supports optional picks overlay). 143 | Also supports Topaz training statistics if no coordinate files are found. 144 | 145 | Parameters: 146 | rln_folder (str): Base folder for job and micrograph data. 147 | job_name (str): Name of the job folder. 148 | img_resize_fac (float): Initial resize factor. 149 | 150 | Returns: 151 | None. 152 | """ 153 | 154 | logger.debug(f"{datetime.now()}: plot_picks started with job_name: {job_name}") 155 | try: 156 | path_data = os.path.join(rln_folder, job_name) 157 | coord_paths, mics_paths = get_coord_paths(path_data, job_name) 158 | 159 | logger.debug(f"coord_paths: {coord_paths}") 160 | 161 | # Fallback for Topaz training statistics. 
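        # With no coordinate STAR files, this may be a Topaz training job: model_training.txt (if present) holds per-epoch metrics that are plotted instead of picks.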
162 | if not coord_paths: 163 | topaz_training_files = glob.glob(os.path.join(path_data, "model_training.txt")) 164 | if topaz_training_files: 165 | topaz_training_txt = topaz_training_files[0] 166 | data = pd.read_csv(topaz_training_txt, delimiter="\t") 167 | data_test = data[data["split"] == "test"] 168 | x = data_test["epoch"] 169 | data_test = data_test.drop(["iter", "split", "ge_penalty"], axis=1) 170 | fig = go.Figure() 171 | for column in data_test.columns: 172 | if column != "epoch": 173 | y = data_test[column] 174 | fig.add_scatter( 175 | x=x, 176 | y=y, 177 | name=column, 178 | hovertemplate=f"{column}
Epoch: %{{x}}
Y: %{{y:.2f}}", 179 | ) 180 | fig.update_xaxes(title_text="Epoch") 181 | fig.update_yaxes(title_text="Statistics") 182 | best_epoch = data_test[data_test["auprc"].astype(float) == np.max(data_test["auprc"].astype(float))]["epoch"].values 183 | fig.update_layout(title=f"Topaz training stats. Best model: {best_epoch}") 184 | fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)) 185 | fig.update_layout(hovermode="x unified") 186 | st.plotly_chart(fig, use_container_width=True) 187 | logger.info(f"{datetime.now()}: plot_picks_streamlit done (Topaz training)") 188 | return 189 | st.write("No coordinate files found.") 190 | logger.info(f"{datetime.now()}: plot_picks_streamlit done (No coordinate files)") 191 | return 192 | 193 | # Select a micrograph. 194 | col1, col2 = st.columns([1, 3]) 195 | 196 | # For autopick data (non-manual), get FOM slider and compute picks overlay. 197 | if "ManualPick" not in job_name: 198 | # problem with indexing and passing it to micrograph_viewer###### 199 | 200 | 201 | plot_all = col1.checkbox("Plot FOM statistics? (Might be slow for huge datasets)", value=False) 202 | fom_all_mics = [] 203 | tmp_file_path = None 204 | if plot_all: 205 | coords_df = pd.DataFrame() 206 | for idx in range(len(mics_paths)): 207 | try: 208 | star_data = parse_star(os.path.join(rln_folder, coord_paths[idx])) 209 | star_data = list(star_data.values())[0] 210 | star_data["_rlnMicrographName"] = mics_paths[idx] 211 | coords_df = pd.concat([coords_df, star_data], ignore_index=True) 212 | fom_stats = star_data["_rlnAutopickFigureOfMerit"].astype(float) 213 | fom_all_mics.extend(fom_stats) 214 | except Exception as exc: 215 | report_error(exc) 216 | logger.error(f"Error processing picking stats for micrograph index {idx}") 217 | modified_star = star_from_df({"particles": coords_df}) 218 | with tempfile.NamedTemporaryFile(delete=False, suffix=".star") as tmp_file: 219 | modified_star.write_file(tmp_file.name) 220 | tmp_file_path = tmp_file.name 221 | else: 222 | fom_slider = [-10000, 10000] 223 | picks_overlay = None 224 | plot_all = False 225 | 226 | # Use micrograph_viewer to display the micrograph with picks overlay. 227 | # If picks_overlay is None, micrograph_viewer behaves as before. 228 | micrograph_viewer( 229 | rln_folder=rln_folder, 230 | image_files=mics_paths, 231 | selected_filter="gaussian", 232 | default_gaussian=0.2, 233 | coord_paths=coord_paths, 234 | ) 235 | 236 | # For autopick, if "plot all" is enabled, show FOM histogram and detailed stats. 
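        # fom_all_mics is only populated when the user opted into the (potentially slow) full-statistics pass above.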
237 | if "ManualPick" not in job_name and plot_all and fom_all_mics: 238 | fom_all_mics = np.array(fom_all_mics) 239 | fom_histogram_fig = plot_histogram_fom(fom_all_mics) 240 | col2.plotly_chart(fom_histogram_fig, use_container_width=True) 241 | 242 | if st.checkbox("Show detailed statistics?") and tmp_file_path: 243 | interactive_scatter_plot(tmp_file_path) 244 | 245 | logger.info(f"{datetime.now()}: plot_picks_streamlit done") 246 | except Exception as exc: 247 | report_error(exc) 248 | logger.error("Error in plot_picks_streamlit function.") -------------------------------------------------------------------------------- /relion_jobs/mask_job.py: -------------------------------------------------------------------------------- 1 | #mask_job.py 2 | 3 | import os 4 | import logging 5 | import traceback 6 | from datetime import datetime 7 | from typing import List 8 | 9 | import numpy as np 10 | import mrcfile 11 | import streamlit as st 12 | import plotly.graph_objects as go 13 | import plotly.figure_factory as ff 14 | from skimage.transform import resize 15 | 16 | # Import shared utilities. 17 | from lib.utils import get_note, extract_source_job 18 | from lib.image_utils import normalize 19 | 20 | logger = logging.getLogger("mask_job") 21 | logger.setLevel(logging.DEBUG) 22 | 23 | 24 | def plot_volume_custom(volume: np.ndarray, threshold: float, max_size: int = 150, 25 | colormap_override: str = None, opacity_override: float = None, 26 | title: str = None) -> go.Figure: 27 | """ 28 | Generate an isosurface plot using marching cubes with Plotly. 29 | 30 | Args: 31 | volume (np.ndarray): 3D volume data. 32 | threshold (float): Fraction (0 to 1) used to compute the actual threshold intensity. 33 | max_size (int): Maximum dimension size to which the volume is resized. 34 | colormap_override (str): If provided, use a single color (e.g. "rgb(100,149,237)") 35 | for all facets instead of the default gray. 36 | opacity_override (float): If provided, set marker opacity (0 to 1). 37 | title (str): Title of the figure. 38 | 39 | Returns: 40 | A Plotly Figure containing the isosurface. 41 | """ 42 | try: 43 | # Create a writable copy. 44 | volume = np.array(volume, copy=True) 45 | volume.flags.writeable = True 46 | 47 | # Resize if needed. 48 | original_shape = volume.shape 49 | if np.any(np.array(original_shape) > max_size): 50 | resize_factor = max_size / np.max(volume.shape) 51 | new_shape = np.round(np.array(volume.shape) * resize_factor).astype(int) 52 | volume = resize(volume, new_shape, anti_aliasing=False) 53 | 54 | # Determine threshold intensity. 55 | min_val, max_val = np.min(volume), np.max(volume) 56 | actual_threshold = min_val + (max_val - min_val) * threshold 57 | 58 | # Compute mesh via marching cubes using mcubes. 59 | import mcubes 60 | verts, faces = mcubes.marching_cubes(volume, actual_threshold) 61 | 62 | # Use default gray color or override. 63 | if colormap_override is None: 64 | color_list = ["rgb(180,180,180)"] 65 | else: 66 | color_list = [colormap_override] 67 | 68 | fig_volume = ff.create_trisurf( 69 | x=verts[:, 2], 70 | y=verts[:, 1], 71 | z=verts[:, 0], 72 | simplices=faces, 73 | colormap=color_list, 74 | plot_edges=False, 75 | showbackground=True, 76 | show_colorbar=False, 77 | ) 78 | 79 | # Modify the Mesh3d trace. 
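        # create_trisurf returns a figure whose first trace is the Mesh3d surface; tune its shading and lighting in place.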
80 | mesh = fig_volume.data[0] 81 | mesh.update( 82 | flatshading=False, 83 | lighting=dict( 84 | ambient=0.25, 85 | diffuse=0.9, 86 | specular=0.2, 87 | roughness=0.7, 88 | fresnel=4, 89 | ), 90 | lightposition=dict(x=0, y=500, z=500), 91 | ) 92 | if opacity_override is not None: 93 | mesh.opacity = opacity_override 94 | 95 | # Layout: black background. 96 | fig_volume.update_layout( 97 | title=title if title is not None else "3D Volume Isosurface", 98 | scene=dict( 99 | camera=dict(eye=dict(x=2, y=2, z=2)), 100 | xaxis=dict(visible=False), 101 | yaxis=dict(visible=False), 102 | zaxis=dict(visible=False), 103 | aspectmode="data", 104 | bgcolor="black", 105 | ), 106 | paper_bgcolor="black", 107 | height=600, 108 | margin=dict(l=0, r=0, t=0, b=0), 109 | hovermode=False, 110 | ) 111 | 112 | return fig_volume 113 | 114 | except RuntimeError: 115 | return None 116 | except Exception as exc: 117 | logger.error(f"plot_volume_custom() error: {exc}") 118 | return None 119 | 120 | 121 | def overlay_projections(source: np.ndarray, mask: np.ndarray, axis: int = 0, alpha: float = 0.7) -> np.ndarray: 122 | """ 123 | Compute a maximum intensity projection (MIP) of the source and mask volumes along the given axis, 124 | then overlay the mask (colored in cornflower blue) on top of the source (in grayscale). 125 | 126 | Args: 127 | source (np.ndarray): 3D source volume. 128 | mask (np.ndarray): 3D mask volume. 129 | axis (int): Axis along which to compute the projection. 130 | alpha (float): Transparency factor for the mask. 131 | 132 | Returns: 133 | An RGB image (as a NumPy array) combining the source and mask projections. 134 | """ 135 | # Compute MIPs. 136 | proj_source = source 137 | proj_mask = mask 138 | 139 | # Normalize projections. 140 | norm_source = normalize(proj_source) 141 | norm_mask = normalize(proj_mask) 142 | 143 | # Convert grayscale source to RGB. 144 | source_rgb = np.dstack([norm_source] * 3) 145 | # Define cornflower blue in normalized RGB. 146 | blue = np.array([100,149,237], dtype=np.float32)/255.0 147 | 148 | # Create a mask image with cornflower blue. 149 | mask_rgb = np.ones_like(source_rgb) 150 | mask_rgb[..., 0] = blue[0] 151 | mask_rgb[..., 1] = blue[1] 152 | mask_rgb[..., 2] = blue[2] 153 | 154 | # Blend images: raw map + mask overlay. 155 | mask_intensity = norm_mask[..., np.newaxis] 156 | overlay = source_rgb * (1 - alpha * mask_intensity) + (alpha * mask_rgb * mask_intensity) 157 | overlay = np.clip(overlay, 0, 1) 158 | overlay_uint8 = (overlay * 255).astype(np.uint8) 159 | return overlay_uint8 160 | 161 | 162 | def plot_mask(rln_folder: str, nodes: List[str]) -> None: 163 | """ 164 | Plot the mask and the source volume together. 165 | 166 | The function: 167 | 1) Retrieves the mask path and source job path. 168 | 2) Caches the volumes in session state (job-specific) so that repeated 169 | threshold changes do not cause re-reading from disk. 170 | 3) Provides two view modes: 171 | - 3D: Isosurface overlays using plot_volume_custom. 172 | - 2D: Maximum intensity projections overlaid (raw map in grayscale and mask in cornflower blue). 173 | """ 174 | logger.info("Plotting mask for nodes: %s", nodes) 175 | 176 | # Determine job folder from nodes. 177 | job_folder = os.path.join(rln_folder, os.path.dirname(nodes[0])) 178 | # Use a unique job key for caching, e.g., based on job folder. 
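    # Caching per job key avoids re-reading the (potentially large) MRC volumes from disk when only the threshold sliders change.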
179 | job_key = f"mask_job_{job_folder}" 180 | if job_key not in st.session_state: 181 | st.session_state[job_key] = {} 182 | 183 | note_path = os.path.join(job_folder, "note.txt") 184 | note = get_note(note_path) 185 | source_job = extract_source_job(note) 186 | 187 | mask_path = os.path.join(rln_folder, nodes[0]) 188 | source_job_path = os.path.join(rln_folder, source_job) 189 | 190 | # Cache volumes in session state (job-specific). 191 | cache = st.session_state[job_key] 192 | if "mask_volume" not in cache: 193 | try: 194 | mask_volume = mrcfile.mmap(mask_path, mode='r').data 195 | cache["mask_volume"] = mask_volume 196 | except Exception as exc: 197 | st.error(f"Error loading mask: {exc}") 198 | return 199 | else: 200 | mask_volume = cache["mask_volume"] 201 | 202 | if "source_volume" not in cache: 203 | try: 204 | source_volume = mrcfile.mmap(source_job_path, mode='r').data 205 | cache["source_volume"] = source_volume 206 | except Exception as exc: 207 | st.error(f"Error loading source volume: {exc}") 208 | return 209 | else: 210 | source_volume = cache["source_volume"] 211 | 212 | # Display mode selection: 3D or 2D. 213 | c1, c2 = st.columns([1, 3]) 214 | view_mode = c1.radio("Display mode", ["3D", "2D"], horizontal=True) 215 | 216 | if view_mode == "3D": 217 | with c1: 218 | st.markdown("### 3D Rendering Controls") 219 | threshold_source = st.slider("Original Map Threshold Fraction", 0.0, 1.0, 0.5, 0.01) 220 | threshold_mask = st.slider("Mask Threshold Fraction", 0.0, 1.0, 0.5, 0.01) 221 | # Generate isosurface figures. 222 | fig_source = plot_volume_custom(source_volume, threshold_source, max_size=150, 223 | colormap_override=None, opacity_override=None, 224 | title="Original Map") 225 | cornflower_blue = "rgb(100,149,237)" 226 | fig_mask = plot_volume_custom(mask_volume, threshold_mask, max_size=150, 227 | colormap_override=cornflower_blue, opacity_override=0.5, 228 | title="Mask") 229 | # Overlay the two isosurface traces. 230 | fig_overlay = go.Figure() 231 | if fig_source and len(fig_source.data) > 0: 232 | fig_overlay.add_trace(fig_source.data[0]) 233 | if fig_mask and len(fig_mask.data) > 0: 234 | fig_overlay.add_trace(fig_mask.data[0]) 235 | fig_overlay.update_layout( 236 | title="Overlay: Original Map and Mask (3D)", 237 | scene=dict( 238 | camera=dict(eye=dict(x=2, y=2, z=2)), 239 | xaxis=dict(visible=False), 240 | yaxis=dict(visible=False), 241 | zaxis=dict(visible=False), 242 | aspectmode="data", 243 | bgcolor="black", 244 | ), 245 | paper_bgcolor="black", 246 | height=600, 247 | margin=dict(l=0, r=0, t=0, b=0), 248 | hovermode=False, 249 | ) 250 | c2.plotly_chart(fig_overlay, use_container_width=True) 251 | else: 252 | c1.markdown("### 2D Projection Controls") 253 | with c1: 254 | projection_axis = st.radio("Projection axis", ['XY', 'XZ', 'YZ'], horizontal=True) 255 | # Map projection axis to axis index: use convention where 'XY' is projection along Z. 256 | projection_axis = {'XY': 0, 'XZ': 1, 'YZ': 2}[projection_axis] 257 | st.markdown("Adjust projection settings as needed.") 258 | # Compute maximum intensity projections for both volumes. 
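        # Note: the source map uses a mean projection here (softer contrast), while the mask uses a true maximum-intensity projection.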
259 | proj_source = np.mean(source_volume, axis=projection_axis) 260 | proj_mask = np.max(mask_volume, axis=projection_axis) 261 | proj_source_norm = normalize(proj_source) 262 | proj_mask_norm = normalize(proj_mask) 263 | overlay_image = overlay_projections(proj_source_norm, proj_mask_norm, axis=projection_axis, alpha=0.7) 264 | c2.image(overlay_image, caption="Overlay of Original Map and Mask", width=400) 265 | 266 | logger.info(f"{datetime.now()}: plot_mask done.") 267 | -------------------------------------------------------------------------------- /relion_jobs/localres_job.py: -------------------------------------------------------------------------------- 1 | # localres_job.py 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | import re 7 | import traceback 8 | from typing import List, Tuple, Optional 9 | 10 | import mrcfile 11 | import numpy as np 12 | import plotly.graph_objects as go 13 | import streamlit as st 14 | from datetime import datetime 15 | from plotly.subplots import make_subplots 16 | from scipy.stats import gaussian_kde 17 | from skimage.transform import resize 18 | 19 | from lib.utils import report_error 20 | import logging 21 | 22 | logger = logging.getLogger("main_app") 23 | 24 | # ── pymcubes ──────────────────────────────────────────────────────────────── 25 | try: 26 | import mcubes 27 | except ImportError: 28 | mcubes = None 29 | 30 | # ───────────────────────────────────────────────────────────────────────────── 31 | # helper funcs 32 | # ───────────────────────────────────────────────────────────────────────────── 33 | def _load_mrc(path: str) -> np.ndarray: 34 | if not os.path.isfile(path): 35 | raise FileNotFoundError(path) 36 | with mrcfile.mmap(path, mode="r", permissive=True) as mrc: 37 | return mrc.data.copy() 38 | 39 | 40 | def _assign_maps(nodes: List[str], folder: str) -> Tuple[str, str]: 41 | raw, loc = "", "" 42 | for f in nodes: 43 | if not f.lower().endswith(".mrc"): 44 | continue 45 | base = os.path.basename(f).lower() 46 | full = os.path.join(folder, f) 47 | if "locres_filtered" in base: 48 | raw = full 49 | elif "locres" in base and "filtered" not in base: 50 | loc = full 51 | if raw and loc: 52 | break 53 | if not raw or not loc: 54 | raise ValueError("Missing relion_locres_filtered.mrc or relion_locres.mrc") 55 | return raw, loc 56 | 57 | 58 | def _orthogonal_slices(vol: np.ndarray, idx: int) -> tuple[np.ndarray, ...]: 59 | i = int(np.clip(idx, 0, min(vol.shape) - 1)) 60 | return vol[i], vol[:, i, :], vol[:, :, i] 61 | 62 | 63 | def _sample_vertex_values(vol: np.ndarray, verts: np.ndarray) -> np.ndarray: 64 | x = np.clip(np.round(verts[:, 0]).astype(int), 0, vol.shape[0] - 1) 65 | y = np.clip(np.round(verts[:, 1]).astype(int), 0, vol.shape[1] - 1) 66 | z = np.clip(np.round(verts[:, 2]).astype(int), 0, vol.shape[2] - 1) 67 | return vol[x, y, z] 68 | 69 | # ───────────────────────────── 3-D isosurface ─────────────────────────────── 70 | def _plot_isosurface( 71 | raw_vol: np.ndarray, 72 | loc_vol: np.ndarray, 73 | rel_thr: float, 74 | *, 75 | max_size: int, 76 | colourscale: str, 77 | ) -> Optional[go.Figure]: 78 | if mcubes is None: 79 | st.error("Install *pymcubes* for 3-D rendering.") 80 | return None 81 | if raw_vol.shape != loc_vol.shape: 82 | st.warning("Shape mismatch.") 83 | return None 84 | 85 | # down-sample for performance 86 | if max(raw_vol.shape) > max_size: 87 | fac = max_size / max(raw_vol.shape) 88 | new_shape = tuple(int(round(d * fac)) for d in raw_vol.shape) 89 | raw_vol = resize(raw_vol, new_shape, 
preserve_range=True, anti_aliasing=False) 90 | loc_vol = resize(loc_vol, new_shape, preserve_range=True, anti_aliasing=False) 91 | 92 | level = raw_vol.min() + rel_thr * (raw_vol.max() - raw_vol.min()) 93 | verts, faces = mcubes.marching_cubes(np.ascontiguousarray(raw_vol, np.float32), level) 94 | if verts.size == 0: 95 | return None 96 | 97 | intens = _sample_vertex_values(loc_vol, verts) 98 | cmin, cmax = float(loc_vol.min()), float(loc_vol.max()) 99 | if cmin == cmax: 100 | cmax = cmin + 1e-6 101 | 102 | fig = go.Figure( 103 | data=[ 104 | go.Mesh3d( 105 | x=verts[:, 0], y=verts[:, 1], z=verts[:, 2], 106 | i=faces[:, 0], j=faces[:, 1], k=faces[:, 2], 107 | intensity=intens, 108 | colorscale=colourscale, 109 | cmin=cmin, cmax=cmax, 110 | showscale=True, 111 | opacity=1.0, 112 | lighting=dict(ambient=0.3, diffuse=0.5, specular=0.2), 113 | lightposition=dict(x=0, y=0, z=0), 114 | hoverinfo="skip", 115 | ) 116 | ] 117 | ) 118 | fig.update_layout( 119 | scene=dict( 120 | xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False), 121 | aspectmode="data", bgcolor="black", dragmode="orbit" 122 | ), 123 | paper_bgcolor="black", 124 | margin=dict(l=10, r=10, t=30, b=10), 125 | height=600, 126 | title=f"Isosurface (thr ={rel_thr:.2f}) – clipped above ceiling", 127 | ) 128 | return fig 129 | 130 | # ───────────────────────────────────────────────────────────────────────────── 131 | # main entry 132 | # ───────────────────────────────────────────────────────────────────────────── 133 | def plot_locres(nodes: List[str], folder: str, job: str) -> None: 134 | """ 135 | Streamlit visualiser with: 136 | • view radio (3-D / 2-D) 137 | • threshold slider 138 | • max-dim slider 139 | • colour-scale select 140 | • *new* resolution-ceiling slider 141 | """ 142 | try: 143 | raw_p, loc_p = _assign_maps(nodes, folder) 144 | 145 | key_raw = f"raw_{job}_{os.path.getmtime(raw_p)}" 146 | key_loc = f"loc_{job}_{os.path.getmtime(loc_p)}" 147 | if key_raw not in st.session_state: 148 | with st.spinner("Loading raw map…"): 149 | st.session_state[key_raw] = _load_mrc(raw_p) 150 | if key_loc not in st.session_state: 151 | with st.spinner("Loading resolution map…"): 152 | st.session_state[key_loc] = _load_mrc(loc_p) 153 | 154 | raw = st.session_state[key_raw] 155 | loc = st.session_state[key_loc] 156 | 157 | # optional mask 158 | note = os.path.join(folder, job, "note.txt") 159 | if os.path.isfile(note): 160 | txt = open(note, encoding="utf-8").read() 161 | m = re.search(r"--mask\s+([\w\d/\\\.-]+\.mrc)", txt, re.I) 162 | if m: 163 | mpath = os.path.join(folder, m.group(1)) 164 | if os.path.isfile(mpath): 165 | mask = _load_mrc(mpath) > 0 166 | raw = np.where(mask, raw, 0) 167 | loc = np.where(mask, loc, 0) 168 | #st.info(f"Mask applied: {os.path.basename(mpath)}") 169 | 170 | # convert zeros to NaN 171 | loc_float = loc.astype(float) 172 | loc_float[loc == 0] = np.nan 173 | 174 | # ── UI controls ─────────────────────────────────────────────────── 175 | view = st.radio("View mode:", ("3-D isosurface", "2-D slices"), horizontal=True) 176 | 177 | # common resolution-ceiling slider 178 | finite_all = loc_float[np.isfinite(loc_float)] 179 | res_min, res_max = float(finite_all.min()), float(finite_all.max()) 180 | col1, col2, col3, col4 = st.columns([1, 2, 1, 1]) 181 | with col1: 182 | res_clip = st.slider( 183 | "Max resolution (Å)", 184 | min_value=round(res_min, 1), 185 | max_value=round(res_max, 1), 186 | value=round(res_max, 1), 187 | step=0.1, 188 | help="Values above this are clipped (set equal to 
the ceiling).", 189 | ) 190 | 191 | # apply clipping for visualisation 192 | loc_clip = np.minimum(loc_float, res_clip) 193 | 194 | if view == "2-D slices": 195 | idx = col2.slider( 196 | "Slice index", 0, min(loc.shape) - 1, (min(loc.shape) - 1) // 2, key=f"{job}_idx" 197 | ) 198 | slices = _orthogonal_slices(loc_clip, idx) 199 | vmin, vmax = float(loc_clip.min()), float(loc_clip.max()) 200 | if vmin == vmax: 201 | vmax = vmin + 1e-6 202 | 203 | fig_s = make_subplots(rows=1, cols=3, subplot_titles=("XY", "XZ", "YZ")) 204 | for c, slc in enumerate(slices, 1): 205 | fig_s.add_trace( 206 | go.Heatmap( 207 | z=slc.T, 208 | colorscale="RdYlBu_r", 209 | zmin=vmin, 210 | zmax=vmax, 211 | showscale=(c == 3), 212 | colorbar=dict(title="Å") if c == 3 else None, 213 | hoverinfo="skip", 214 | ), 215 | row=1, col=c, 216 | ) 217 | fig_s.update_xaxes(visible=False, row=1, col=c) 218 | fig_s.update_yaxes(visible=False, row=1, col=c) 219 | fig_s.update_layout(margin=dict(l=10, r=10, t=40, b=10), height=400) 220 | fig_s.update_layout( 221 | scene=dict( 222 | xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False), 223 | aspectmode="cube" 224 | ), 225 | margin=dict(l=10, r=10, t=30, b=10), 226 | height=600, 227 | ) 228 | 229 | st.plotly_chart(fig_s, use_container_width=True) 230 | 231 | else: # 3-D 232 | with col2: 233 | thr = st.slider("Relative threshold", 0.05, 0.95, 0.5, 0.01, key=f"{job}_thr") 234 | with col3: 235 | dim = st.slider("Max dimension (px)", 80, 256, 200, 16, key=f"{job}_dim") 236 | with col4: 237 | colours = st.selectbox( 238 | "Colour scale", 239 | ("Turbo", "Viridis", "Cividis", "Plasma", "Inferno", "Haline"), 240 | index=0, 241 | key=f"{job}_cmap", 242 | ) 243 | 244 | with st.spinner("Rendering 3-D view…"): 245 | fig_iso = _plot_isosurface(raw, loc_clip, thr, max_size=dim, colourscale=colours) 246 | if fig_iso: 247 | st.plotly_chart(fig_iso, use_container_width=True) 248 | 249 | # ── histogram / KDE (use clipped values) ─────────────────────────── 250 | st.subheader("Local-resolution distribution") 251 | finite_vals = loc_clip[np.isfinite(loc_clip)] 252 | if finite_vals.size: 253 | kde = gaussian_kde(finite_vals) 254 | xs = np.linspace(finite_vals.min(), finite_vals.max(), 200) 255 | p10, p25, p50, p75, p90 = np.percentile(finite_vals, [10, 25, 50, 75, 90]) 256 | 257 | fig_h = go.Figure() 258 | fig_h.add_trace( 259 | go.Histogram( 260 | x=finite_vals, 261 | histnorm="probability density", 262 | marker_color="#9cc2cb", 263 | opacity=0.6, 264 | name="Histogram", 265 | ) 266 | ) 267 | fig_h.add_trace( 268 | go.Scatter(x=xs, y=kde(xs), mode="lines", name="KDE", line=dict(color="#417996", width=3)) 269 | ) 270 | for val, lbl, colr in [ 271 | (p10, "10 %", "grey"), 272 | (p25, "25 %", "orange"), 273 | (p50, "50 %", "red"), 274 | (p75, "75 %", "orange"), 275 | (p90, "90 %", "grey"), 276 | ]: 277 | fig_h.add_vline(x=val, line=dict(color=colr, dash="dash"), annotation_text=lbl, annotation_position="top") 278 | 279 | fig_h.update_layout( 280 | xaxis_title="Resolution (Å)", 281 | yaxis_title="Density", 282 | margin=dict(l=20, r=20, t=20, b=20), 283 | ) 284 | st.plotly_chart(fig_h, use_container_width=True) 285 | st.caption( 286 | f"10 %: {p10:.2f} Å 25 %: {p25:.2f} Å Median: {p50:.2f} Å " 287 | f"75 %: {p75:.2f} Å 90 %: {p90:.2f} Å (clipped at {res_clip:.2f} Å)" 288 | ) 289 | else: 290 | st.warning("No finite resolution values to plot.") 291 | 292 | except Exception as exc: 293 | st.error(f"plot_locres failed: {exc}") 294 | report_error(exc) 295 | logger.error(f"plot_locres 
error\n{traceback.format_exc()}") 296 | -------------------------------------------------------------------------------- /relion_jobs/excludetilt_job.py: -------------------------------------------------------------------------------- 1 | # excludetilt_job.py 2 | 3 | import os 4 | import traceback 5 | import logging 6 | from datetime import datetime 7 | 8 | import streamlit as st 9 | import numpy as np 10 | import pandas as pd 11 | import plotly.graph_objects as go 12 | 13 | from lib.utils import ( 14 | parse_star, 15 | get_values_from_first_key, 16 | get_note, 17 | extract_source_job, 18 | report_error, 19 | ) 20 | 21 | logger = logging.getLogger("main_app") 22 | 23 | def plot_exclude_tilt(rln_folder: str, node_files: str) -> None: 24 | """ 25 | Loads a star file referencing multiple tomography star files. Plots the tilt angles 26 | (_rlnTomoNominalStageTiltAngle) for a user-selected tilt series as: 27 | 1) Blue diameter lines from -r to +r. 28 | 2) Blue points on the circumference for those angles. 29 | 30 | If a source job star file also exists (extracted from note.txt) and references 31 | the same tilt series, we identify the angles that are present in the source 32 | but missing from the selected dataset. We plot those missing angles in red 33 | (both diameter lines and points). 34 | """ 35 | logger.info(f"Plotting exclude tilt data with Plotly... {rln_folder}, {node_files}") 36 | star_path = os.path.join(rln_folder, node_files) 37 | logger.debug(f"star_path: {star_path}") 38 | 39 | working_dir = os.path.dirname(star_path) 40 | note = get_note(os.path.join(working_dir, "note.txt")) 41 | logger.debug(f"Note: {note}") 42 | 43 | source_job = os.path.join(rln_folder, extract_source_job(note)) 44 | logger.debug(f"Source job: {source_job}") 45 | 46 | # Try to parse the source job file, if it exists. 47 | dfs_source_tilt = [] 48 | if os.path.exists(source_job): 49 | try: 50 | source_star = parse_star(source_job) 51 | if source_star is not None and "global" in source_star: 52 | tomo_star_files_source = source_star["global"]["_rlnTomoTiltSeriesStarFile"] 53 | tomo_star_files_source_paths = [ 54 | os.path.join(rln_folder, f) for f in tomo_star_files_source 55 | ] 56 | 57 | for path_tomo_source_star in tomo_star_files_source_paths: 58 | try: 59 | star_data_source = parse_star(path_tomo_source_star) 60 | if star_data_source: 61 | df_src = get_values_from_first_key(star_data_source) 62 | df_src["FileSource"] = os.path.basename(path_tomo_source_star) 63 | dfs_source_tilt.append(df_src) 64 | except Exception as exc: 65 | logger.error(f"Error parsing star file {path_tomo_source_star}: {exc}") 66 | report_error(exc) 67 | else: 68 | logger.debug(f"Source star is None or missing 'global' key: {source_job}") 69 | except Exception as exc: 70 | logger.error(f"Error parsing source star file {source_job}: {exc}") 71 | report_error(exc) 72 | else: 73 | logger.error(f"Source job file not found: {source_job}") 74 | # If missing, we will just plot the main star data. 
75 | 76 | ############################## 77 | # Now parse the main star file (the exclude job) for the actual tilt data 78 | ############################## 79 | if not os.path.exists(star_path): 80 | st.error(f"Star file not found: {star_path}") 81 | return 82 | 83 | try: 84 | star = parse_star(star_path) 85 | except Exception as exc: 86 | report_error(exc) 87 | st.error("Failed to parse the main star file.") 88 | return 89 | 90 | if "global" not in star: 91 | st.error("The star file does not contain a 'global' section.") 92 | return 93 | 94 | # Gather paths to tilt-series star files from the exclude job 95 | tomo_star_files = star["global"]["_rlnTomoTiltSeriesStarFile"] 96 | tomo_star_files_paths = [os.path.join(rln_folder, f) for f in tomo_star_files] 97 | 98 | # Parse each tilt-series star file 99 | dfs_tilt = [] 100 | for path_tomo_star in tomo_star_files_paths: 101 | try: 102 | star_data = parse_star(path_tomo_star) 103 | if star_data: 104 | df_tilt = get_values_from_first_key(star_data) 105 | df_tilt["FileSource"] = os.path.basename(path_tomo_star) 106 | dfs_tilt.append(df_tilt) 107 | except Exception as exc: 108 | logger.error(f"Error parsing star file {path_tomo_star}: {exc}") 109 | report_error(exc) 110 | 111 | if not dfs_tilt: 112 | st.error("No tilt-series data found. Cannot plot tilt angles.") 113 | return 114 | 115 | ############################## 116 | # UI Layout: slider to pick which tilt series to display 117 | ############################## 118 | col1, col2, col3 = st.columns([1, 3, 3]) 119 | 120 | with col1: 121 | selected_idx = st.slider( 122 | label="Select Tilt Series", 123 | min_value=0, 124 | max_value=len(dfs_tilt), 125 | value=1, 126 | step=1 127 | ) - 1 128 | 129 | selected_df = dfs_tilt[selected_idx].copy() 130 | tilt_col = "_rlnTomoNominalStageTiltAngle" 131 | if tilt_col not in selected_df.columns: 132 | st.error(f"Column '{tilt_col}' not found in tilt-series data.") 133 | return 134 | 135 | # Convert from degrees to radians 136 | selected_df["TiltAngleDeg"] = selected_df[tilt_col].astype(float) 137 | selected_df["TiltAngleRad"] = np.deg2rad(selected_df["TiltAngleDeg"]) 138 | 139 | ############################## 140 | # Build lines/points for the main (blue) dataset 141 | ############################## 142 | lines_x_blue, lines_y_blue = [], [] 143 | points_x_blue, points_y_blue = [], [] 144 | text_labels_blue = [] 145 | 146 | for _, row in selected_df.iterrows(): 147 | angle_deg = row["TiltAngleDeg"] 148 | angle_rad = row["TiltAngleRad"] 149 | x_plus = np.cos(angle_rad) 150 | y_plus = np.sin(angle_rad) 151 | x_minus = -x_plus 152 | y_minus = -y_plus 153 | 154 | # lines 155 | lines_x_blue.extend([x_minus, x_plus, None]) 156 | lines_y_blue.extend([y_minus, y_plus, None]) 157 | 158 | # points at +r 159 | points_x_blue.append(x_plus) 160 | points_y_blue.append(y_plus) 161 | text_labels_blue.append(f"{row['FileSource']}
Tilt={angle_deg:.1f}°") 162 | 163 | ############################## 164 | # Check if we have a matching source dataset for this tilt series 165 | ############################## 166 | missing_df = pd.DataFrame() 167 | selected_file_source = selected_df["FileSource"].iloc[0] 168 | 169 | if len(dfs_source_tilt) > 0: 170 | # Find a source DF that has the same FileSource 171 | possible_matches = [ 172 | df_s for df_s in dfs_source_tilt 173 | if (df_s["FileSource"].unique()[0] == selected_file_source) 174 | ] 175 | if len(possible_matches) == 1: 176 | source_df = possible_matches[0].copy() 177 | if tilt_col in source_df.columns: 178 | source_df["TiltAngleDeg"] = source_df[tilt_col].astype(float) 179 | # Identify angles that are in source_df but missing from selected_df 180 | missing_df = source_df[~source_df["TiltAngleDeg"].isin(selected_df["TiltAngleDeg"])] 181 | missing_df["TiltAngleRad"] = np.deg2rad(missing_df["TiltAngleDeg"]) 182 | 183 | ############################## 184 | # Build lines/points for the missing (red) dataset 185 | ############################## 186 | lines_x_red, lines_y_red = [], [] 187 | points_x_red, points_y_red = [], [] 188 | text_labels_red = [] 189 | 190 | if not missing_df.empty: 191 | for _, row in missing_df.iterrows(): 192 | angle_deg = row["TiltAngleDeg"] 193 | angle_rad = row["TiltAngleRad"] 194 | x_plus = np.cos(angle_rad) 195 | y_plus = np.sin(angle_rad) 196 | x_minus = -x_plus 197 | y_minus = -y_plus 198 | 199 | lines_x_red.extend([x_minus, x_plus, None]) 200 | lines_y_red.extend([y_minus, y_plus, None]) 201 | 202 | points_x_red.append(x_plus) 203 | points_y_red.append(y_plus) 204 | text_labels_red.append(f"{row['FileSource']}
Tilt={angle_deg:.1f}°") 205 | 206 | ############################## 207 | # Plot with Plotly 208 | ############################## 209 | fig = go.Figure() 210 | 211 | # 1) Circle boundary 212 | fig.add_shape( 213 | type="circle", 214 | xref="x", yref="y", 215 | x0=-1, x1=1, 216 | y0=-1, y1=1, 217 | line=dict(color="lightgray") 218 | ) 219 | 220 | # 2) Blue lines 221 | fig.add_trace(go.Scatter( 222 | x=lines_x_blue, 223 | y=lines_y_blue, 224 | mode="lines", 225 | line=dict(color="#007acc"), # blue 226 | name="Kept tilt lines", # LEGEND: updated 227 | hoverinfo="none", 228 | showlegend=False # LEGEND: hidden 229 | )) 230 | 231 | # 3) Blue points 232 | fig.add_trace(go.Scatter( 233 | x=points_x_blue, 234 | y=points_y_blue, 235 | mode="markers", 236 | marker=dict(color="#007acc", size=8), 237 | text=text_labels_blue, 238 | hovertemplate="%{text}", 239 | name="Kept tilt angles" # LEGEND: updated 240 | )) 241 | 242 | # 4) Red lines for missing angles 243 | if not missing_df.empty: 244 | fig.add_trace(go.Scatter( 245 | x=lines_x_red, 246 | y=lines_y_red, 247 | mode="lines", 248 | line=dict(color="#cc3333", width=3), # red 249 | hoverinfo="none", 250 | showlegend=False # LEGEND: hidden 251 | )) 252 | # 5) Red points 253 | fig.add_trace(go.Scatter( 254 | x=points_x_red, 255 | y=points_y_red, 256 | mode="markers", 257 | marker=dict(color="#cc3333", size=15), 258 | text=text_labels_red, 259 | hovertemplate="%{text}", 260 | name="Excluded tilt angles" # LEGEND: updated 261 | )) 262 | 263 | # Make it square 264 | fig.update_xaxes( 265 | range=[-1.1, 1.1], 266 | scaleanchor="y", 267 | scaleratio=1, 268 | showgrid=False, 269 | zeroline=False, 270 | visible=False 271 | ) 272 | fig.update_yaxes( 273 | range=[-1.1, 1.1], 274 | showgrid=False, 275 | zeroline=False, 276 | visible=False 277 | ) 278 | 279 | fig.update_layout( 280 | title=f"Tilt Angles: {selected_file_source}", 281 | showlegend=True, # LEGEND: now shown 282 | width=600, 283 | height=600 284 | ) 285 | 286 | with col2: 287 | st.write("**Tilt Angles**") 288 | st.plotly_chart(fig, use_container_width=True) 289 | 290 | if "AlignTiltSeries" in node_files: 291 | logger.info("AlignTiltSeries job detected.") 292 | 293 | # plot the statistics from the tilt alignment: _rlnTomoXTilt, _rlnTomoYTilt _rlnTomoZRot _rlnTomoXShiftAngst _rlnTomoYShiftAngst using plotly 294 | fig2 = go.Figure() 295 | fig2.add_trace(go.Scatter( 296 | x=selected_df["TiltAngleDeg"], 297 | y=selected_df["_rlnTomoXTilt"], 298 | mode="markers", 299 | marker=dict(color="#007acc", size=10), 300 | name="_rlnTomoXTilt" 301 | )) 302 | fig2.add_trace(go.Scatter( 303 | x=selected_df["TiltAngleDeg"], 304 | y=selected_df["_rlnTomoYTilt"], 305 | mode="markers", 306 | marker=dict(color="#cc3333", size=10), 307 | name="_rlnTomoYTilt" 308 | )) 309 | fig2.add_trace(go.Scatter( 310 | x=selected_df["TiltAngleDeg"], 311 | y=selected_df["_rlnTomoZRot"], 312 | mode="markers", 313 | marker=dict(color="#ffd354", size=10), 314 | name="_rlnTomoZRot" 315 | )) 316 | fig2.add_trace(go.Scatter( 317 | x=selected_df["TiltAngleDeg"], 318 | y=selected_df["_rlnTomoXShiftAngst"], 319 | mode="markers", 320 | marker=dict(color="#b850c8", size=10), 321 | name="_rlnTomoXShiftAngst" 322 | )) 323 | fig2.add_trace(go.Scatter( 324 | x=selected_df["TiltAngleDeg"], 325 | y=selected_df["_rlnTomoYShiftAngst"], 326 | mode="markers", 327 | marker=dict(color="#45ab84", size=10), 328 | name="_rlnTomoYShiftAngst" 329 | )) 330 | fig2.update_layout( 331 | title=f"Tilt Angles: {selected_file_source}", 332 | xaxis_title="Tilt Angle (deg)", 
333 | yaxis_title="Tilt Alignment Statistics", 334 | width=600, 335 | height=600 336 | ) 337 | col3.write("**Tilt Alignment Statistics**") 338 | col3.plotly_chart(fig2, use_container_width=True) 339 | 340 | logger.info(f"{datetime.now()}: plot_exclude_tilt completed successfully with Plotly.") 341 | -------------------------------------------------------------------------------- /relion_jobs/polish_job.py: -------------------------------------------------------------------------------- 1 | #polish_job.py 2 | 3 | import os 4 | import traceback 5 | from datetime import datetime 6 | from typing import List 7 | import logging 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import streamlit as st 12 | import plotly.graph_objects as go 13 | 14 | # Import shared utilities (ensure these functions are defined in your project) 15 | from lib.utils import parse_star 16 | from lib.image_utils import report_error 17 | from relion_jobs.extract_job import show_random_particles 18 | 19 | logger = logging.getLogger("main_app") 20 | 21 | 22 | def plot_polish(FOLDER: str, node_files: List[str]) -> None: 23 | """ 24 | Main function to plot polish job results in Streamlit. 25 | 26 | Depending on the contents of node_files, this function handles: 27 | 1) Optimal parameter reporting (if "opt_params_all_groups.txt" is found), 28 | 2) B-factor and Guinier analysis (if "shiny.star" is found), 29 | 3) Tomogram motion visualization (if "tomograms.star" is found). 30 | 31 | The module uses Plotly for plotting and caches loaded data in session state. 32 | """ 33 | train_job = False 34 | 35 | logger.debug(f"{datetime.now()}: plot_polish called with folder: {FOLDER} and node_files: {node_files}") 36 | for file in node_files: 37 | if 'star' in file: 38 | star_file_path = os.path.join(FOLDER, file) 39 | elif 'opt_params_all_groups.txt' in file: 40 | 41 | star_file_path = os.path.join(FOLDER, file) 42 | train_job = True 43 | 44 | job_path = os.path.dirname(star_file_path) 45 | logger.debug(f"Job path: {job_path}") 46 | 47 | # Create a job-specific session state key. 48 | polish_key = f"polish_data_{job_path}" 49 | if polish_key not in st.session_state: 50 | st.session_state[polish_key] = {} 51 | cache = st.session_state[polish_key] 52 | 53 | # Case 1: Optimal parameters available. 54 | if any("opt_params_all_groups.txt" in element for element in node_files): 55 | logger.debug("Optimal parameters found.") 56 | params_path = star_file_path 57 | try: 58 | with open(params_path, "r") as f: 59 | parameters = f.readline().strip().split() 60 | st.markdown("### Optimal Parameters") 61 | st.code(f"--s_vel {parameters[0]} --s_div {parameters[1]} --s_acc {parameters[2]}") 62 | except Exception as exc: 63 | st.error(f"Error reading optimal parameters: {exc}") 64 | logger.error(f"Optimal parameters error: {exc}\n{traceback.format_exc()}") 65 | return 66 | 67 | # Case 2: B-factor analysis (shiny.star found). 68 | elif any("shiny.star" in element for element in node_files): 69 | try: 70 | bfactors_star_path = os.path.join(job_path, "bfactors.star") 71 | logger.debug(f"Loading B-factors from: {bfactors_star_path}") 72 | # Cache parsed bfactors data. 
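# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module): parse_star(...)["perframe_bfactors"]
# below returns the per-frame table RELION writes into bfactors.star. For readers
# unfamiliar with that layout, a rough stand-alone read using gemmi (already listed
# in requirements.txt) could look like this; treat it as an assumption-laden sketch,
# not the project's parser, and note the column names are RELION's.
import gemmi
import pandas as pd

def read_perframe_bfactors(path: str) -> pd.DataFrame:
    block = gemmi.cif.read_file(path).find_block("perframe_bfactors")
    cols = ["_rlnMovieFrameNumber",
            "_rlnBfactorUsedForSharpening",
            "_rlnFittedInterceptGuinierPlot"]
    return pd.DataFrame({c: [float(v) for v in block.find_loop(c)] for c in cols})

# Hypothetical usage: df = read_perframe_bfactors(os.path.join(job_path, "bfactors.star"))
# ---------------------------------------------------------------------------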
73 | bfactors_cache_key = f"bfactors_data_{job_path}" 74 | logger.debug(f"Cache key for bfactors: {bfactors_cache_key}") 75 | if bfactors_cache_key in cache: 76 | bfactors_data = cache[bfactors_cache_key] 77 | logger.debug("Using cached bfactors data.") 78 | else: 79 | bfactors_data = parse_star(bfactors_star_path)["perframe_bfactors"] 80 | cache[bfactors_cache_key] = bfactors_data 81 | 82 | # Build a Plotly figure with two traces. 83 | fig = go.Figure() 84 | fig.add_trace( 85 | go.Scatter( 86 | x=bfactors_data["_rlnMovieFrameNumber"], 87 | y=bfactors_data["_rlnBfactorUsedForSharpening"], 88 | mode="lines", 89 | name="Bfactor Used For Sharpening", 90 | line=dict(color="darkgoldenrod", width=3), 91 | hovertemplate="Movie Frame: %{x}
Bfactor: %{y}" 92 | ) 93 | ) 94 | fig.add_trace( 95 | go.Scatter( 96 | x=bfactors_data["_rlnMovieFrameNumber"], 97 | y=bfactors_data["_rlnFittedInterceptGuinierPlot"], 98 | mode="lines", 99 | name="Fitted Intercept Guinier Plot", 100 | line=dict(color="lightseagreen", width=3), 101 | yaxis="y2", 102 | hovertemplate="Movie Frame: %{x}
Intercept: %{y}" 103 | ) 104 | ) 105 | fig.update_layout( 106 | title="Polish Job Statistics", 107 | xaxis=dict(title="Movie Frame Number"), 108 | yaxis=dict(title="Bfactor Used For Sharpening"), 109 | yaxis2=dict( 110 | title="Fitted Intercept Guinier Plot", overlaying="y", side="right" 111 | ), 112 | margin=dict(l=20, r=20, t=40, b=20), 113 | paper_bgcolor="white", 114 | ) 115 | st.plotly_chart(fig, use_container_width=True) 116 | 117 | if st.checkbox(":gem: Show random shiny particles?"): 118 | # Assume show_random_particles is defined. 119 | c1, c2 = st.columns([1,4]) 120 | show_random_particles(star_file_path, FOLDER, c1, c2) 121 | 122 | except Exception as exc: 123 | logger.error(f"{datetime.now()}: Error in bfactors branch:\n{report_error(exc)}") 124 | st.error("Error loading B-factors data") 125 | return 126 | 127 | # Case 3: Tomogram motion analysis. 128 | elif any("tomograms.star" in element for element in node_files): 129 | try: 130 | # Paths: use node_files[2] for particle star and node_files[3] for motion star. 131 | particle_star_path = os.path.join(FOLDER, node_files[2]) 132 | motion_star_path = os.path.join(FOLDER, node_files[3]) 133 | 134 | # Load particle star data and cache it. 135 | particle_cache_key = f"particle_star_{job_path}" 136 | if particle_cache_key in cache: 137 | particle_star = cache[particle_cache_key] 138 | logger.debug("Using cached particle star data.") 139 | else: 140 | particle_star = parse_star(particle_star_path)["particles"] 141 | # Keep only the relevant columns. 142 | particle_star = particle_star[ 143 | ["_rlnTomoName", "_rlnCenteredCoordinateXAngst", 144 | "_rlnCenteredCoordinateYAngst", "_rlnCenteredCoordinateZAngst"] 145 | ] 146 | for col in particle_star.columns: 147 | if col.startswith("_rln"): 148 | try: 149 | particle_star[col] = particle_star[col].astype(float) 150 | except ValueError: 151 | logger.warning(f"Column {col} cannot be converted to float.") 152 | 153 | #particle_star = convert_to_float(particle_star) 154 | cache[particle_cache_key] = particle_star 155 | 156 | unique_tomo_names = np.unique(particle_star["_rlnTomoName"]) 157 | idx = 0 158 | 159 | c1, c2 = st.columns([1, 5]) 160 | 161 | 162 | if unique_tomo_names.size > 1: 163 | idx = c1.slider("Tomogram index", 0, unique_tomo_names.size - 1, 0) 164 | tomo_name = unique_tomo_names[idx] 165 | particle_star_selected = particle_star[particle_star["_rlnTomoName"] == tomo_name] 166 | st.markdown( 167 | f"**Selected Tomogram:** {tomo_name}. **Number of particles:** {particle_star_selected.shape[0]}" 168 | ) 169 | 170 | # Load or cache motion data. 171 | temp_folder = os.path.join(os.path.dirname(particle_star_path), "temp") 172 | temp_motion_file_path = os.path.join(temp_folder, f"{tomo_name}_motion.star") 173 | if os.path.exists(temp_motion_file_path): 174 | motion_cache_key = f"motion_data_{tomo_name}" 175 | if motion_cache_key in cache: 176 | motion_file = cache[motion_cache_key] 177 | logger.debug("Using cached motion star data.") 178 | else: 179 | motion_file = parse_star(temp_motion_file_path) 180 | cache[motion_cache_key] = motion_file 181 | 182 | # Calculate total motion per tomogram block. 
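# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module): each block of the per-tomogram
# motion STAR file holds one particle's trajectory in _rlnOriginX/Y/ZAngst. The
# loop below condenses a trajectory to |mean_x + mean_y + mean_z|; another common
# summary, shown here purely as an illustration (helper name hypothetical), is the
# accumulated path length in Angstrom.
import numpy as np
import pandas as pd

def path_length_angst(traj: pd.DataFrame) -> float:
    """Sum of frame-to-frame displacement norms for one particle trajectory."""
    xyz = traj[["_rlnOriginXAngst", "_rlnOriginYAngst", "_rlnOriginZAngst"]].astype(float).to_numpy()
    steps = np.diff(xyz, axis=0)
    return float(np.linalg.norm(steps, axis=1).sum())
# ---------------------------------------------------------------------------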
183 | all_motion = [] 184 | 185 | for key in motion_file.keys(): 186 | try: 187 | motion_df = motion_file[key].astype(float, errors='ignore') 188 | x_trace = np.mean(motion_df["_rlnOriginXAngst"]) 189 | y_trace = np.mean(motion_df["_rlnOriginYAngst"]) 190 | z_trace = np.mean(motion_df["_rlnOriginZAngst"]) 191 | total_motion = np.abs(x_trace + y_trace + z_trace) 192 | all_motion.append(total_motion) 193 | except Exception as exc: 194 | logger.error(f"Error processing motion data for key {key}: {exc}") 195 | all_motion.append(0) 196 | 197 | # Create a 3D scatter plot for particle positions, colored by total motion. 198 | scatter = go.Scatter3d( 199 | x=particle_star_selected["_rlnCenteredCoordinateXAngst"], 200 | y=particle_star_selected["_rlnCenteredCoordinateYAngst"], 201 | z=particle_star_selected["_rlnCenteredCoordinateZAngst"], 202 | mode="markers", 203 | marker=dict( 204 | size=5, 205 | color=all_motion, 206 | colorscale="Inferno", 207 | colorbar=dict(title="Total Motion"), 208 | ), 209 | name="Particles", 210 | hovertemplate="X: %{x}
<br>Y: %{y}<br>
Z: %{z}" 211 | ) 212 | fig_motion = go.Figure(data=[scatter]) 213 | fig_motion.update_layout( 214 | scene=dict( 215 | xaxis_title="X (Å)", 216 | yaxis_title="Y (Å)", 217 | zaxis_title="Z (Å)", 218 | aspectmode="data", 219 | ), 220 | height=600, 221 | title="Total Motion of Particles in Tomogram (summed XYZ motion)", 222 | showlegend=False, 223 | ) 224 | c2.plotly_chart(fig_motion, use_container_width=True) 225 | else: 226 | # Fallback: if the temp motion file does not exist, call process_tomo_motion. 227 | st.warning("No temporary motion file found; please run the motion correction step.") 228 | # You may call: process_tomo_motion(motion_star_path) 229 | 230 | except Exception as exc: 231 | logger.error(f"{datetime.now()}: Error in tomogram branch: {exc}\n{traceback.format_exc()}") 232 | st.error("Error processing tomogram motion data.") 233 | return 234 | 235 | else: 236 | st.write("No relevant data found for the Polish job.") 237 | 238 | logger.info(f"{datetime.now()}: plot_polish done") 239 | 240 | 241 | def process_tomo_motion(filename: str): 242 | """ 243 | Example function to process a tomogram motion star file. 244 | Splits the file into blocks and returns a list of (block_name, DataFrame). 245 | """ 246 | try: 247 | with open(filename, 'r') as file: 248 | lines = file.readlines() 249 | blocks = [] 250 | block = [] 251 | for line in lines: 252 | if line.startswith('data_'): 253 | if block: 254 | blocks.append(block) 255 | block = [line] 256 | else: 257 | block.append(line) 258 | if block: 259 | blocks.append(block) 260 | 261 | dataframes = [] 262 | for block in blocks: 263 | block_name = block[0].strip() 264 | data_section_started = False 265 | columns = [] 266 | data_rows = [] 267 | for line in block: 268 | line = line.strip() 269 | if line.startswith('_rln'): 270 | columns.append(line.split()[0]) 271 | elif data_section_started: 272 | if line.startswith('#') or 'None' in line: 273 | continue 274 | data_row = line.split() 275 | data_rows.append(data_row) 276 | elif line.startswith('loop_'): 277 | data_section_started = True 278 | df = pd.DataFrame(data_rows, columns=columns).dropna() 279 | df = df.astype(float, errors='ignore') 280 | dataframes.append((block_name, df)) 281 | return dataframes 282 | except Exception as exc: 283 | logger.error(f"Error processing tomogram motion file: {exc}") 284 | st.error("Error processing tomogram motion file.") 285 | return [] 286 | -------------------------------------------------------------------------------- /relion_jobs/import_job.py: -------------------------------------------------------------------------------- 1 | #import_job.py 2 | 3 | # Standard Library Imports 4 | import logging 5 | import os 6 | from concurrent.futures import ThreadPoolExecutor 7 | from datetime import datetime 8 | from typing import Dict, List, Optional, Tuple 9 | 10 | # Third-Party Imports 11 | import numpy as np 12 | import pandas as pd 13 | import plotly.graph_objects as go 14 | import streamlit as st 15 | 16 | # Local Imports 17 | from lib.image_utils import micrograph_viewer 18 | from lib.utils import ( 19 | get_first_key, 20 | get_modification_time, 21 | parse_star, 22 | report_error, 23 | ) 24 | 25 | # ============================================================================= 26 | # Logger Setup 27 | # ============================================================================= 28 | logger = logging.getLogger("main_app") # Use logger from main app 29 | 30 | 31 | def plot_tomogram_picks( 32 | n: int, # Tomogram index (for context/logging) 33 | coords_sel: 
Dict[str, List[float]], 34 | coords_rej: Dict[str, List[float]], 35 | scatter_size: Tuple[int, int] = (5, 2), 36 | ) -> go.Figure: 37 | """ 38 | Creates a 3D scatter plot of selected and rejected tomogram coordinates. 39 | 40 | Args: 41 | n: Tomogram index (used mainly for logging/titles). 42 | coords_sel: Dict with '_rlnCoordinateX/Y/Z' lists for selected points. 43 | coords_rej: Dict with '_rlnCoordinateX/Y/Z' lists for rejected points. 44 | scatter_size: Tuple of marker sizes (selected, rejected). 45 | 46 | Returns: 47 | A Plotly Figure object containing the 3D scatter plot. 48 | """ 49 | fig = go.Figure() 50 | required_keys = ["_rlnCoordinateX", "_rlnCoordinateY", "_rlnCoordinateZ"] 51 | 52 | try: 53 | # Plot selected points if data exists and is valid 54 | sel_x, sel_y, sel_z = [], [], [] 55 | if coords_sel and all(k in coords_sel for k in required_keys): 56 | sel_x, sel_y, sel_z = ( 57 | coords_sel["_rlnCoordinateX"], 58 | coords_sel["_rlnCoordinateY"], 59 | coords_sel["_rlnCoordinateZ"], 60 | ) 61 | if sel_x: # Check if list is not empty 62 | fig.add_trace( 63 | go.Scatter3d( 64 | x=sel_x, 65 | y=sel_y, 66 | z=sel_z, 67 | mode="markers", 68 | marker=dict(size=scatter_size[0], color="green", opacity=0.7), 69 | name=f"Selected ({len(sel_x)})", 70 | customdata=np.arange(len(sel_x)), 71 | hovertemplate="Selected Pick
<br>X: %{x:.1f}<br>
Y: %{y:.1f}<br>
Z: %{z:.1f}", 72 | ) 73 | ) 74 | else: 75 | logger.info(f"No selected coordinates to plot for tomogram index {n}.") 76 | else: 77 | logger.info( 78 | f"Selected coordinates data structure invalid or missing keys for tomogram index {n}." 79 | ) 80 | 81 | # Plot rejected points if data exists and is valid 82 | rej_x, rej_y, rej_z = [], [], [] 83 | if coords_rej and all(k in coords_rej for k in required_keys): 84 | rej_x, rej_y, rej_z = ( 85 | coords_rej["_rlnCoordinateX"], 86 | coords_rej["_rlnCoordinateY"], 87 | coords_rej["_rlnCoordinateZ"], 88 | ) 89 | if rej_x: # Check if list is not empty 90 | fig.add_trace( 91 | go.Scatter3d( 92 | x=rej_x, 93 | y=rej_y, 94 | z=rej_z, 95 | mode="markers", 96 | marker=dict(size=scatter_size[1], color="red", opacity=0.6), 97 | name=f"Rejected ({len(rej_x)})", 98 | customdata=np.arange(len(rej_x)), 99 | hovertemplate="Rejected Pick
<br>X: %{x:.1f}<br>
Y: %{y:.1f}<br>
Z: %{z:.1f}", 100 | ) 101 | ) 102 | else: 103 | logger.info(f"No rejected coordinates to plot for tomogram index {n}.") 104 | else: 105 | logger.info( 106 | f"Rejected coordinates data structure invalid or missing keys for tomogram index {n}." 107 | ) 108 | 109 | except KeyError as e: 110 | report_error( 111 | e, 112 | f"Missing coordinate key '{e}' while plotting tomogram picks for index {n}", 113 | ) 114 | st.warning(f"Coordinate data missing key: {e}. Cannot plot picks fully.") 115 | except Exception as e: 116 | report_error(e, f"Error plotting tomogram picks for index {n}") 117 | st.warning(f"Could not plot tomogram picks: {e}") 118 | 119 | fig.update_layout( 120 | title=f"Tomogram {n + 1}: Particle Picks 3D View", 121 | scene=dict( 122 | xaxis_title="X (px)", 123 | yaxis_title="Y (px)", 124 | zaxis_title="Z (px)", 125 | aspectmode="data", 126 | ), 127 | legend_title="Pick Status", 128 | hovermode="closest", 129 | height=600, 130 | margin=dict(l=10, r=10, t=50, b=10), 131 | ) 132 | return fig 133 | 134 | 135 | def plot_import(rln_folder: str, node_files: List[str]) -> None: 136 | """ 137 | Processes and displays information for a RELION Import job. 138 | 139 | Handles visualization for imported movies, tomograms (including Relion 5 140 | STAR-based imports), and particle coordinates. Logs errors instead of 141 | showing st.error messages. 142 | 143 | Args: 144 | rln_folder: Base directory of the RELION project. 145 | node_files: List of node file paths associated with this job. 146 | """ 147 | if not node_files: 148 | st.warning("No node files provided for Import job.") 149 | return 150 | 151 | star_file_path = os.path.join(rln_folder, node_files[0]) 152 | logger.info(f"Processing Import job using STAR file: {star_file_path}") 153 | 154 | if not os.path.exists(star_file_path): 155 | logger.error(f"Import job STAR file not found: {star_file_path}") 156 | st.warning(f"Import STAR file '{node_files[0]}' not found.") 157 | return 158 | 159 | try: 160 | star_data_all = parse_star(star_file_path) 161 | if not star_data_all: 162 | st.warning(f"Could not parse or data empty in: {node_files[0]}") 163 | return 164 | except Exception as e: 165 | report_error(e, f"Failed to parse STAR file {node_files[0]} for Import job") 166 | st.warning(f"Error parsing '{node_files[0]}'. 
Check logs.") 167 | return 168 | 169 | # Determine data type 170 | is_movie_import = any("movies" in node for node in node_files) 171 | is_tomo_import = any("tilt_series" in node for node in node_files) or any( 172 | "tomograms" in node for node in node_files 173 | ) 174 | is_particle_import = any("particles" in node for node in node_files) 175 | 176 | # --- Movie/Tomogram Data Processing --- 177 | if is_movie_import or is_tomo_import: 178 | logger.info("Processing movie/tomogram import.") 179 | file_names: Optional[List[str]] = None 180 | 181 | try: 182 | # --- Corrected Logic to find the data block --- 183 | data_block = None 184 | movies_block = star_data_all.get("movies") 185 | if isinstance(movies_block, pd.DataFrame) and not movies_block.empty: 186 | data_block = movies_block 187 | logger.debug("Using 'movies' data block.") 188 | else: 189 | global_block = star_data_all.get("global") 190 | if isinstance(global_block, pd.DataFrame) and not global_block.empty: 191 | data_block = global_block 192 | logger.debug("Using 'global' data block.") 193 | else: 194 | first_key = get_first_key(star_data_all) 195 | if first_key: 196 | first_block = star_data_all.get(first_key) 197 | if ( 198 | isinstance(first_block, pd.DataFrame) 199 | and not first_block.empty 200 | ): 201 | data_block = first_block 202 | logger.debug(f"Using first data block found: '{first_key}'") 203 | 204 | if data_block is None: 205 | block_names = list(star_data_all.keys()) 206 | msg = f"Could not find a suitable non-empty data block ('movies', 'global', or first) in STAR file. Blocks: {block_names}" 207 | logger.error(msg) 208 | raise KeyError(msg) 209 | # --- End Corrected Logic --- 210 | 211 | # Extract filenames 212 | if "_rlnMicrographMovieName" in data_block: 213 | file_names = data_block["_rlnMicrographMovieName"].tolist() 214 | logger.debug("Found movie names in _rlnMicrographMovieName.") 215 | elif "_rlnTomoTiltSeriesName" in data_block: 216 | file_names = data_block["_rlnTomoTiltSeriesName"].tolist() 217 | logger.debug("Found tomogram base names in _rlnTomoTiltSeriesName.") 218 | elif "_rlnTomoTiltSeriesStarFile" in data_block: # Relion 5+ Tomo 219 | tomo_star_files = data_block["_rlnTomoTiltSeriesStarFile"].tolist() 220 | logger.info( 221 | f"Relion 5 tomo import: {len(tomo_star_files)} tilt series STARs." 222 | ) 223 | all_tilt_movie_names = [] 224 | failed_parses = 0 225 | for ts_star_rel_path in tomo_star_files: 226 | ts_star_full_path = os.path.join(rln_folder, ts_star_rel_path) 227 | if not os.path.exists(ts_star_full_path): 228 | logger.warning( 229 | f"Tilt series STAR not found: {ts_star_full_path}" 230 | ) 231 | failed_parses += 1 232 | continue 233 | ts_data = parse_star(ts_star_full_path) 234 | ts_movies_block = ts_data.get(get_first_key(ts_data)) 235 | if ( 236 | ts_movies_block is not None 237 | and "_rlnMicrographMovieName" in ts_movies_block 238 | ): 239 | all_tilt_movie_names.extend( 240 | ts_movies_block["_rlnMicrographMovieName"].tolist() 241 | ) 242 | else: 243 | logger.warning( 244 | f"No movie names in tilt series STAR: {ts_star_rel_path}" 245 | ) 246 | failed_parses += 1 247 | if failed_parses > 0: 248 | st.warning( 249 | f"Failed to process {failed_parses} tilt series STAR file(s)." 
250 | ) 251 | file_names = all_tilt_movie_names 252 | else: 253 | raise KeyError("No known movie/tomogram filename column found.") 254 | 255 | except KeyError as e: 256 | report_error(e, f"Missing key in {node_files[0]} for movie/tomo filenames") 257 | st.warning( 258 | f"Could not find movie/tomogram filenames in STAR file: {e}. See logs." 259 | ) 260 | return 261 | except Exception as e: 262 | report_error( 263 | e, f"Error processing movie/tomo import data from {node_files[0]}" 264 | ) 265 | st.warning(f"Error processing movie/tomo data: {e}. See logs.") 266 | return 267 | 268 | if not file_names: 269 | st.warning("No movie or tomogram file paths were found in the STAR file.") 270 | return 271 | 272 | # --- Modification Time Plot --- 273 | st.subheader(f"Imported {len(file_names)} Files") 274 | file_names_abs = [os.path.join(rln_folder, fn) for fn in file_names] 275 | limit_display = 100 276 | files_to_check = file_names_abs 277 | if len(file_names_abs) > limit_display: 278 | if st.checkbox( 279 | f"Limit timeline plot to {limit_display} random files?", 280 | value=True, 281 | key="limit_import_plot", 282 | ): 283 | indices = np.random.choice( 284 | len(file_names_abs), limit_display, replace=False 285 | ) 286 | files_to_check = [file_names_abs[i] for i in indices] 287 | st.caption( 288 | f"Showing timeline for {limit_display}/{len(file_names_abs)} files." 289 | ) 290 | 291 | with st.spinner( 292 | f"Loading modification times for {len(files_to_check)} files..." 293 | ): 294 | try: 295 | with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: 296 | mod_times_opt: List[Optional[datetime]] = list( 297 | executor.map(get_modification_time, files_to_check) 298 | ) 299 | file_mod_times: List[datetime] = sorted( 300 | [mt for mt in mod_times_opt if mt is not None] 301 | ) 302 | except Exception as e: 303 | report_error(e, "Error fetching modification times.") 304 | st.warning(f"Could not fetch file modification times: {e}. See logs.") 305 | file_mod_times = [] 306 | 307 | if file_mod_times: 308 | fig = go.Figure() 309 | fig.add_trace( 310 | go.Scatter( 311 | x=list(range(len(file_mod_times))), 312 | y=file_mod_times, 313 | mode="markers", 314 | name="Timestamp", 315 | marker=dict(size=4), 316 | hovertemplate="Index: %{x}
Time: %{y|%Y-%m-%d %H:%M:%S}", 317 | ) 318 | ) 319 | fig.update_layout( 320 | title="File Import Timeline (by modification date)", 321 | xaxis_title="File Index (sorted by time)", 322 | yaxis_title="Modification Timestamp", 323 | height=300, 324 | margin=dict(l=50, r=20, t=40, b=40), 325 | ) 326 | with st.expander("Show Import Timeline", expanded=False): 327 | st.plotly_chart(fig, use_container_width=True) 328 | else: 329 | st.caption("No valid modification times found.") 330 | 331 | # --- Micrograph Viewer --- 332 | st.subheader("Micrograph/Tomogram Viewer") 333 | try: 334 | micrograph_viewer(rln_folder, file_names) 335 | except Exception as e: 336 | report_error(e, "Error displaying micrograph viewer.") 337 | st.warning(f"Could not display image viewer: {e}. See logs.") 338 | 339 | # --- Particle Data Processing --- 340 | elif is_particle_import: 341 | logger.info("Processing particle import.") 342 | particles_block = star_data_all.get("particles") 343 | if particles_block is None or particles_block.empty: 344 | st.warning("No 'particles' data block found in the STAR file.") 345 | return 346 | 347 | if "_rlnCoordinateZ" in particles_block.columns: 348 | logger.info("Detected tomogram particle coordinates.") 349 | st.subheader("Tomogram Particle Picks") 350 | tomo_name_col = "_rlnTomoName" 351 | if tomo_name_col not in particles_block.columns: 352 | logger.error( 353 | f"Missing '{tomo_name_col}' column for grouping particles." 354 | ) 355 | st.warning( 356 | f"Missing '{tomo_name_col}' column. Cannot display picks by tomogram." 357 | ) 358 | return 359 | 360 | unique_tomos = particles_block[tomo_name_col].unique() 361 | if len(unique_tomos) == 0: 362 | st.warning("No tomogram names found in particle data.") 363 | return 364 | 365 | tomo_idx = st.slider( 366 | "Select Tomogram Index:", 367 | 0, 368 | len(unique_tomos) - 1, 369 | 0, 370 | key="import_particle_tomo_idx", 371 | ) 372 | selected_tomo_name = unique_tomos[tomo_idx] 373 | st.caption(f"Showing picks for: {selected_tomo_name}") 374 | 375 | tomo_coords_df = particles_block[ 376 | particles_block[tomo_name_col] == selected_tomo_name 377 | ] 378 | # Ensure required columns exist before creating dict 379 | coord_cols = ["_rlnCoordinateX", "_rlnCoordinateY", "_rlnCoordinateZ"] 380 | if not all(col in tomo_coords_df.columns for col in coord_cols): 381 | logger.error( 382 | "Missing coordinate columns in filtered particle data for selected tomogram." 383 | ) 384 | st.warning( 385 | "Missing coordinate columns for selected tomogram. Cannot plot picks." 386 | ) 387 | return 388 | 389 | coords_dict = {k: tomo_coords_df[k].tolist() for k in coord_cols} 390 | coords_rej_empty: Dict[str, List[float]] = {k: [] for k in coord_cols} 391 | 392 | fig_3d = plot_tomogram_picks(tomo_idx, coords_dict, coords_rej_empty) 393 | st.plotly_chart(fig_3d, use_container_width=True) 394 | 395 | else: # SPA particle import 396 | logger.info("Detected SPA particle coordinates.") 397 | st.info( 398 | f"SPA particle coordinates imported ({len(particles_block)} particles). No 3D plot generated." 399 | ) 400 | 401 | else: 402 | st.warning( 403 | "Could not determine data type (movies/tomograms/particles) from node files." 
404 | ) 405 | logger.warning(f"Import job type unclear for node files: {node_files}") 406 | 407 | logger.info(f"Finished processing Import job: {node_files[0]}") 408 | -------------------------------------------------------------------------------- /relion_jobs/class3d_job.py: -------------------------------------------------------------------------------- 1 | # class3d_job.py 2 | import os 3 | import glob 4 | import logging 5 | import traceback 6 | import math 7 | from datetime import datetime 8 | from typing import List 9 | import re 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import mrcfile 14 | import streamlit as st 15 | from plotly.subplots import make_subplots 16 | 17 | # Shared utilities from your project. 18 | from lib.utils import ( 19 | parse_star, 20 | interactive_scatter_plot, 21 | report_error, 22 | get_angles, 23 | get_classes, 24 | get_note, 25 | ) 26 | from lib.image_utils import (downsample_volume, plot_angular_distribution_sphere, plot_fsc_stats, plot_class_resolution, 27 | plot_class_distribution, plot_projections, plot_volume, plot_angular_distribution_heatmap) 28 | 29 | 30 | logger = logging.getLogger("main_app") 31 | 32 | 33 | 34 | def _load_and_cache_job_data(rln_folder: str, job_name: str, nodes: List[str]) -> dict: 35 | """ 36 | Load all volumes, class distributions, resolutions, etc. for the given job_name. 37 | Downsample volumes and store everything in st.session_state[data_key]. 38 | """ 39 | data_key = f"class3d_data_{job_name}" 40 | if data_key not in st.session_state: 41 | st.session_state[data_key] = {} 42 | 43 | # If already loaded, just return 44 | if st.session_state[data_key].get("loaded", False): 45 | logger.debug(f"Data for job '{job_name}' found in session state.") 46 | return st.session_state[data_key] 47 | 48 | with st.spinner("Loading job data from disk..."): 49 | # 1) Construct path to job folder 50 | path_data = os.path.join(rln_folder, job_name) 51 | 52 | # 2) Gather model.star files 53 | model_files = glob.glob(os.path.join(path_data, "*model.star")) 54 | model_files.sort(key=os.path.getmtime) 55 | logger.debug(f"Found {len(model_files)} model.star files for job '{job_name}': {model_files}") 56 | 57 | # get symmetry information from the note 58 | symmetry = 'C1' 59 | note = get_note(os.path.join(path_data, 'note.txt')) 60 | logger.debug(f"Note file content: {note}") 61 | 62 | # find --sym XXX in the note using re 63 | if note: 64 | match = re.search(r'--sym\s+([A-Z]\d+)', note) 65 | if match: 66 | symmetry = match.group(1) 67 | logger.debug(f"Found symmetry: {symmetry}") 68 | 69 | # 3) If no model.star, fallback to MRC in the nodes 70 | if len(model_files) == 0: 71 | logger.info("No *_model.star files found. 
Will try to read MRC volumes directly.") 72 | mrc_list = [f for f in nodes if f.lower().endswith(".mrc")] 73 | logger.debug(f"original MRC files: {mrc_list}") 74 | 75 | if any("merged" in f for f in mrc_list): 76 | mrc_list = [f for f in mrc_list if "merged" in f] 77 | elif any("half1" in f for f in mrc_list): 78 | mrc_list = [f for f in mrc_list if "half1" in f] 79 | 80 | logger.debug(f"Filtered MRC files: {mrc_list}") 81 | 82 | volumes = [] 83 | for mrc_file in mrc_list: 84 | volume_path = os.path.join(rln_folder, mrc_file) 85 | try: 86 | logger.debug(f"Loading MRC volume: {volume_path}") 87 | with mrcfile.mmap(volume_path, permissive=True) as mrcf: 88 | volumes.append(mrcf.data) 89 | except Exception: 90 | error = traceback.format_exc() 91 | logger.error(f"Error displaying volume {mrc_file}:\n{error}") 92 | st.session_state[data_key] = { 93 | "volumes_downsampled": volumes, 94 | "class_dist": [1.0] * len(volumes), 95 | "class_res": np.array([]), 96 | "fsc_res": np.array([]), 97 | "fsc_vals": np.array([]), 98 | "angles": ([], [], []), 99 | "star_block": None, 100 | "n_classes": len(volumes), 101 | "class_paths": mrc_list, 102 | "loaded": True, 103 | "symmetry": symmetry, 104 | } 105 | return st.session_state[data_key] 106 | 107 | # 4) If we have model.star, parse them 108 | try: 109 | (class_paths, n_classes, _iter_count, 110 | class_dist, class_res, fsc_res, fsc_vals) = get_classes(path_data, model_files) 111 | except Exception as e: 112 | report_error(e) 113 | logger.error(f"Error in get_classes: {e}") 114 | st.session_state[data_key] = {} 115 | return {} 116 | 117 | # 5) Load volumes 118 | volumes_raw = [] 119 | for cls_path in class_paths: 120 | full_path = os.path.join(path_data, os.path.basename(cls_path)) 121 | logger.debug(f"Loading volume: {full_path}") 122 | with mrcfile.mmap(full_path, permissive=True) as mrcf: 123 | volumes_raw.append(mrcf.data) 124 | 125 | # 6) Angles 126 | try: 127 | rot, tilt, psi = get_angles(path_data) 128 | except Exception as e: 129 | logger.debug(f"Failed to load angles from {path_data}: {e}") 130 | rot, tilt, psi = ([], [], []) 131 | 132 | # 7) Attempt to parse star file for star_block 133 | star_block = None 134 | try: 135 | if nodes: 136 | star_main = parse_star(os.path.join(path_data, os.path.basename(nodes[0]))) 137 | star_block = star_main.get("particles", pd.DataFrame()) 138 | except Exception as e: 139 | logger.debug(f"Failed to parse star file from nodes[0]: {e}") 140 | 141 | # 8) Store in st.session_state 142 | st.session_state[data_key] = { 143 | "volumes_raw": volumes_raw, 144 | "volumes_downsampled": [], 145 | "class_dist": class_dist, 146 | "class_res": class_res, 147 | "fsc_res": fsc_res, 148 | "fsc_vals": fsc_vals, 149 | "angles": (rot, tilt, psi), 150 | "star_block": star_block, 151 | "n_classes": n_classes, 152 | "class_paths": class_paths, 153 | "loaded": True, 154 | "symmetry": symmetry, 155 | } 156 | 157 | return st.session_state[data_key] 158 | 159 | 160 | def _ensure_downsampled(job_data: dict, map_resize: int) -> List[np.ndarray]: 161 | """ 162 | Ensure the volumes in session_state are downsampled to 'map_resize' size. 163 | Only redo if the stored version does not match the requested size. 
164 | """ 165 | if not job_data.get("volumes_raw", []): 166 | return job_data.get("volumes_downsampled", []) 167 | 168 | # Check if we already have downsampled volumes for this size 169 | # For clarity, we store them in job_data["volumes_downsampled_{size}"] or a single key with the size 170 | cached_size = job_data.get("cached_map_resize", None) 171 | if cached_size == map_resize and job_data.get("volumes_downsampled", []): 172 | logger.debug(f"Using cached volumes downsampled to {map_resize}.") 173 | return job_data["volumes_downsampled"] 174 | 175 | # Otherwise, downsample now 176 | new_downsampled = [] 177 | for vol in job_data["volumes_raw"]: 178 | dvol = downsample_volume(vol, map_resize) 179 | new_downsampled.append(dvol) 180 | 181 | job_data["volumes_downsampled"] = new_downsampled 182 | job_data["cached_map_resize"] = map_resize 183 | return new_downsampled 184 | 185 | 186 | def plot_combined_classes(volumes, class_dist, session_key="plot_combined_classes"): 187 | """ 188 | Display multiple 3D volumes as isosurfaces in a grid of Plotly subplots, 189 | with user controls for threshold, resizing, columns, and row height. 190 | Uses the `plot_volume` function to generate iso-surfaces, caching the 191 | results in st.session_state to avoid repeated computation. 192 | 193 | 'session_key' can incorporate a job-specific identifier to keep caches separate. 194 | The background for each subplot is set to black. 195 | """ 196 | logger.debug(f"{datetime.now()}: plot_combined_classes called with {len(volumes)} volumes.") 197 | 198 | # Initialize session state cache for this session_key if needed. 199 | if session_key not in st.session_state: 200 | st.session_state[session_key] = {} 201 | 202 | # Retrieve local iso-surface cache. 203 | iso_cache = st.session_state.get("iso_surface_cache", {}) 204 | #logger.debug(f"Current iso_surface_cache: {iso_cache}") 205 | 206 | volumes_n = len(volumes) 207 | if volumes_n < 5: 208 | columns_to_show = volumes_n 209 | if volumes_n == 1: 210 | plot_height = 700 211 | else: 212 | plot_height = 600 213 | else: 214 | columns_to_show = 5 215 | plot_height = 300 216 | 217 | 218 | 219 | # Basic UI controls. 220 | c1, c2 = st.columns(2) 221 | with c1: 222 | threshold = st.slider("Select Volume Threshold (Fraction)", 0.0, 1.0, 0.5, 0.01) 223 | map_resize = st.slider("Map size (px)", 64, 256, 150, 2) 224 | with c2: 225 | n_columns = st.slider("Number of columns", min_value=1, max_value=5, value=columns_to_show, step=1) 226 | row_height = st.slider("Plot height", min_value=100, max_value=1000, value=plot_height, step=100) 227 | 228 | show_class = st.checkbox("Show class volumes?", value=True) 229 | if not show_class: 230 | st.info("Class volumes are hidden. Check the box to show them.") 231 | return 232 | 233 | num_classes = len(volumes) 234 | cols = min(n_columns, num_classes) 235 | rows = math.ceil(num_classes / n_columns) 236 | 237 | # Create subplots with type 'scene' for 3D. 238 | fig = make_subplots(rows=rows, cols=cols, specs=[[{"type": "scene"} for _ in range(cols)] for _ in range(rows)]) 239 | 240 | annotations = [] 241 | for idx, volume in enumerate(volumes): 242 | row_idx, col_idx = divmod(idx, cols) 243 | # Create a cache key from session_key, volume index, threshold, and map_resize. 
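# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module): the tuple key built just below
# memoizes each iso-surface on everything that can change its geometry -- the
# job/session, the class volume index, the threshold fraction and the requested
# map size. A stripped-down version of that memoization pattern (all names here
# are hypothetical, for illustration only):
from typing import Callable, Dict, Tuple

_iso_cache: Dict[Tuple, object] = {}

def cached_isosurface(make_figure: Callable[[], object], session_key: str,
                      class_idx: int, threshold: float, map_size: int):
    key = (session_key, class_idx, threshold, map_size)
    if key not in _iso_cache:
        _iso_cache[key] = make_figure()  # expensive marching-cubes + Mesh3d build
    return _iso_cache[key]

# Hypothetical usage:
#   fig_ = cached_isosurface(lambda: plot_volume(volume, threshold, max_size=map_resize),
#                            session_key, idx, threshold, map_resize)
# ---------------------------------------------------------------------------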
244 | cache_key = (session_key, idx, threshold, map_resize) 245 | if cache_key in iso_cache: 246 | fig_ = iso_cache[cache_key] 247 | logger.debug(f"Reusing cached iso-surface for volume {idx+1}.") 248 | else: 249 | with st.spinner(f"Creating iso-surface for class {idx+1}..."): 250 | fig_ = plot_volume(volume, threshold, max_size=map_resize) 251 | iso_cache[cache_key] = fig_ 252 | 253 | if fig_ and len(fig_.data) > 0: 254 | fig.add_trace(fig_.data[0], row=row_idx+1, col=col_idx+1) 255 | dist_percent = 0.0 256 | if idx < len(class_dist): 257 | dist_percent = round(float(class_dist[idx]) * 100, 2) 258 | ann_text = f"Class {idx+1}
Dist: {dist_percent}%" 259 | x_ = (col_idx + 0.5) / cols 260 | y_ = 1 - (row_idx / rows) - 0.05 261 | annotations.append( 262 | dict( 263 | x=x_, 264 | y=y_, 265 | xref="paper", 266 | yref="paper", 267 | text=ann_text, 268 | showarrow=False, 269 | xanchor="center", 270 | yanchor="bottom", 271 | font=dict(size=12), 272 | ) 273 | ) 274 | 275 | # Set overall layout with black background. 276 | fig.update_layout( 277 | hovermode=False, 278 | annotations=annotations, 279 | height=rows * row_height, 280 | margin=dict(l=0, r=0, t=0, b=0), 281 | paper_bgcolor="black", # overall background color 282 | ) 283 | fig.update_layout(scene_dragmode='orbit', title='3D Volume Isosurfaces') 284 | 285 | # Configure each subplot's scene with a black background. 286 | for n in range(num_classes): 287 | scene_id = f"scene{n+1}" if n > 0 else "scene" 288 | fig.update_layout( 289 | **{ 290 | scene_id: dict( 291 | xaxis=dict(visible=False), 292 | yaxis=dict(visible=False), 293 | zaxis=dict(visible=False), 294 | camera=dict(eye=dict(x=2.5, y=2.5, z=2.5)), 295 | bgcolor="black" # set each scene's background to black. 296 | ) 297 | } 298 | ) 299 | 300 | # Store the updated cache back into session state. 301 | st.session_state["iso_surface_cache"] = iso_cache 302 | 303 | st.plotly_chart(fig, use_container_width=True) 304 | 305 | 306 | def plot_class3d(rln_folder: str, nodes: List[str]) -> None: 307 | """ 308 | Main function to plot the results of a 3D classification or refinement job, 309 | caching data in session state to avoid repeated I/O and repeated downsampling. 310 | 311 | This function: 312 | 1) Clears/sets a job-specific session state if the job changed. 313 | 2) Loads data for the job, downsampling volumes as needed. 314 | 3) Displays combined classes, volume projections, class distributions, etc. 
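    4) For Refine3D jobs, additionally shows FSC curves and per-class resolution.
    5) Optionally plots the angular distribution (3D sphere or 2D heatmap) and an
       interactive scatter plot of the particle star file.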
315 | """ 316 | logger.debug(f"{datetime.now()}: plot_class3d called with nodes={nodes}") 317 | if not nodes: 318 | st.warning("No files provided in 'nodes'.") 319 | return 320 | 321 | job_name = os.path.dirname(nodes[0]) # Basic extraction of job folder name 322 | if not job_name: 323 | st.warning("Could not determine job name from nodes.") 324 | return 325 | 326 | # Check if this is a new job 327 | if "current_class3d_job" not in st.session_state or st.session_state["current_class3d_job"] != job_name: 328 | # It's a different job => reset iso_surface_cache or use a job-specific key 329 | st.session_state["current_class3d_job"] = job_name 330 | st.session_state["iso_surface_cache"] = {} # Clear old caches for iso surfaces 331 | 332 | if 'Class3D' in job_name: 333 | job_type = 'Class3D' 334 | elif 'Refine3D' in job_name: 335 | job_type = 'Refine3D' 336 | elif 'Reconstruct' in job_name: 337 | job_type = 'Reconstruct' 338 | else: 339 | job_type = '3D Classification' 340 | 341 | st.subheader(f"{job_type} job: {job_name}") 342 | 343 | # 1) Load or retrieve data from session state 344 | job_data = _load_and_cache_job_data(rln_folder, job_name, nodes) 345 | if not job_data.get("loaded", False): 346 | st.warning("Could not load the job data properly.") 347 | return 348 | 349 | # 2) If user modifies map_resize, only re-downsample volumes if needed 350 | # Optional: you can store 'plot_combined_classes_map_resize' in st.session_state 351 | map_resize_request = 150 352 | if "plot_combined_classes_map_resize" in st.session_state: 353 | map_resize_request = st.session_state["plot_combined_classes_map_resize"] 354 | with st.spinner("Downsampling Volumes", show_time=True): 355 | volumes_downsampled = _ensure_downsampled(job_data, map_resize_request) 356 | n_classes = job_data["n_classes"] 357 | class_dist = job_data["class_dist"] 358 | class_res = job_data["class_res"] 359 | 360 | # Ensure FSC data is NumPy arrays 361 | fsc_res = np.array(job_data["fsc_res"], dtype=float) if isinstance(job_data["fsc_res"], (list, tuple)) else job_data["fsc_res"] 362 | fsc_vals = np.array(job_data["fsc_vals"], dtype=float) if isinstance(job_data["fsc_vals"], (list, tuple)) else job_data["fsc_vals"] 363 | 364 | rot, tilt, psi = job_data["angles"] 365 | # Ensure angles are NumPy arrays not pandas Series 366 | rot = np.array(rot, dtype=float) if isinstance(rot, (pd.Series)) else rot 367 | tilt = np.array(tilt, dtype=float) if isinstance(tilt, (pd.Series)) else tilt 368 | psi = np.array(psi, dtype=float) if isinstance(psi, (pd.Series)) else psi 369 | 370 | 371 | star_block = job_data["star_block"] 372 | 373 | logger.debug(f'rot: {rot}, tilt: {tilt}, psi: {psi}') 374 | 375 | # Determine final distribution if multiple iterations 376 | if isinstance(class_dist, np.ndarray) and class_dist.ndim == 2: 377 | class_dist_final = class_dist[:, -1] # last iteration 378 | else: 379 | class_dist_final = class_dist 380 | 381 | if n_classes == 0 and len(volumes_downsampled) == 0: 382 | st.warning("No 3D class references found.") 383 | return 384 | 385 | st.write(f"Found {n_classes} classes in final iteration.") 386 | if len(volumes_downsampled) < n_classes: 387 | pass 388 | 389 | # 3) Show 3D volumes 390 | with st.expander("Combined 3D Volumes", expanded=True): 391 | # use a job-specific session key => "plot_combined_classes_{job_name}" 392 | job_session_key = f"plot_combined_classes_{job_name}" 393 | with st.spinner("plotting combined classes..."): 394 | plot_combined_classes(volumes_downsampled, class_dist_final, 
session_key=job_session_key) 395 | 396 | # 4) Volume projections 397 | with st.expander("Volume Projections", expanded=True): 398 | plot_projections(volumes_downsampled, class_dist_final) 399 | 400 | # 5) Class distribution 401 | if n_classes > 1 and isinstance(class_dist, np.ndarray) and class_dist.ndim == 2: 402 | with st.expander("Class Distribution Over Iterations", expanded=True): 403 | plot_class_distribution(class_dist) 404 | 405 | # 6) FSC for "Refine3D" jobs 406 | if "Refine3D" in job_name and fsc_res.size > 0 and fsc_vals.size > 0: 407 | col1, col2 = st.columns(2) 408 | with col1.expander("Fourier Shell Correlation (FSC) Stats", expanded=True): 409 | plot_fsc_stats(fsc_res, fsc_vals) 410 | 411 | # 7) Class resolution 412 | if isinstance(class_res, np.ndarray) and class_res.size > 0: 413 | if "Refine3D" in job_name: 414 | with col2.expander("Class Resolution", expanded=True): 415 | plot_class_resolution(class_res) 416 | else: 417 | with st.expander("Class Resolution", expanded=True): 418 | plot_class_resolution(class_res) 419 | 420 | # 8) Angular distribution 421 | try: 422 | if n_classes > 1 and star_block is not None and "_rlnClassNumber" in star_block.columns: 423 | cls_idx = star_block["_rlnClassNumber"] 424 | else: 425 | cls_idx = None 426 | 427 | if not (rot.all() and tilt.all() and psi.all()): 428 | logger.info("No angles found for angular distribution.") 429 | return 430 | else: 431 | with st.expander("Angular Distribution", expanded=True): 432 | plot_type = st.radio("Select plot type:", ("3D", "2D"), index=0, horizontal=True) 433 | # Example: 3D sphere representation 434 | # or you can call plot_angular_distribution(...) for 2D histograms 435 | if plot_type == "3D": 436 | plot_angular_distribution_sphere(psi, rot, tilt, cls_idx, symmetry=job_data["symmetry"]) 437 | else: 438 | plot_angular_distribution_heatmap(psi, rot, tilt, cls_idx, symmetry=job_data["symmetry"]) 439 | 440 | # Show star file scatter 441 | if st.checkbox("Show star file interactive scatter?"): 442 | interactive_scatter_plot(os.path.join(rln_folder, nodes[0])) 443 | 444 | except Exception as e: 445 | st.warning("No angle or particle data found for angular distribution.") 446 | logger.debug(f"Angular distribution error: {report_error(e)}") 447 | 448 | logger.info(f"{datetime.now()}: plot_class3d done.") -------------------------------------------------------------------------------- /relion_jobs/motioncorr_job.py: -------------------------------------------------------------------------------- 1 | # motioncorr_job.py 2 | 3 | import logging 4 | import os 5 | from typing import List, Optional 6 | 7 | # Third-Party Imports 8 | import altair as alt # Restore Altair import 9 | import numpy as np 10 | import pandas as pd 11 | import plotly.graph_objects as go 12 | import streamlit as st 13 | 14 | # Local Imports 15 | from lib.image_utils import micrograph_viewer 16 | from lib.utils import ( 17 | get_first_key, 18 | get_values_from_first_key, 19 | interactive_scatter_plot, 20 | parse_star, 21 | report_error, 22 | ) 23 | 24 | # ============================================================================= 25 | # Logger Setup 26 | # ============================================================================= 27 | logger = logging.getLogger("main_app") # Use logger from main app 28 | 29 | 30 | def show_motion(rln_folder: str, mic_paths: List[str]) -> None: 31 | """ 32 | Plots global and local motion shifts for a selected micrograph using Plotly. 
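    Local per-patch shifts are drawn as short line segments and the global drift track as a
    connected line with markers; both are magnified by a scale factor so the motion is
    visible at micrograph scale.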
33 | 34 | Args: 35 | rln_folder: Base folder containing the project data. 36 | mic_paths: List of relative paths to micrograph metadata STAR files. 37 | """ 38 | if not mic_paths: 39 | st.info("No micrograph metadata files provided to display motion.") 40 | return 41 | 42 | try: 43 | # --- Micrograph Selection --- 44 | col1, col2 = st.columns([1, 3]) 45 | with col1: 46 | if len(mic_paths) == 1: 47 | idx = 0 48 | st.caption(f"Displaying motion for: {os.path.basename(mic_paths[0])}") 49 | else: 50 | idx = st.slider("Select Micrograph Index:", 0, len(mic_paths) - 1, 0, key="motioncorr_mic_slider") 51 | mic_metadata_rel_path = mic_paths[idx] 52 | st.caption(f"File: {os.path.basename(mic_metadata_rel_path)}") 53 | 54 | mic_metadata_abs_path = os.path.join(rln_folder, mic_metadata_rel_path) 55 | logger.debug(f"Processing motion from file: {mic_metadata_abs_path}") 56 | 57 | if not os.path.exists(mic_metadata_abs_path): 58 | st.warning(f"Metadata file not found: {mic_metadata_rel_path}") 59 | return 60 | 61 | # --- Parse STAR File --- 62 | motion_star = parse_star(mic_metadata_abs_path) 63 | if not motion_star: 64 | st.warning(f"Could not parse or empty motion metadata file: {mic_metadata_rel_path}") 65 | return 66 | 67 | # --- Extract Motion Data --- 68 | global_shift_df = motion_star.get("global_shift") 69 | local_shift_df = motion_star.get("local_shift") 70 | 71 | if global_shift_df is None and local_shift_df is None: 72 | st.info("No global or local motion data found in this metadata file.") 73 | return 74 | 75 | fig = go.Figure() 76 | traces_added = False 77 | 78 | # --- Process and Plot Local Motion --- 79 | if isinstance(local_shift_df, pd.DataFrame) and not local_shift_df.empty: 80 | try: 81 | local_shift = local_shift_df.astype(float) 82 | coord_x_col, coord_y_col = "_rlnCoordinateX", "_rlnCoordinateY" 83 | shift_x_col, shift_y_col = "_rlnMicrographShiftX", "_rlnMicrographShiftY" 84 | required_local_cols = [coord_x_col, coord_y_col, shift_x_col, shift_y_col] 85 | 86 | if all(col in local_shift.columns for col in required_local_cols): 87 | fold_local_motion = 200 #col1.slider( 88 | # "Local Motion Scale:", 1, 500, 100, 10, key="local_motion_scale", 89 | #help="Magnifies local shifts for visibility.") 90 | 91 | local_shift["X_start"] = local_shift[coord_x_col] 92 | local_shift["Y_start"] = local_shift[coord_y_col] 93 | local_shift["X_end"] = local_shift[coord_x_col] + local_shift[shift_x_col] * fold_local_motion 94 | local_shift["Y_end"] = local_shift[coord_y_col] + local_shift[shift_y_col] * fold_local_motion 95 | 96 | grouped = local_shift.groupby([coord_x_col, coord_y_col]) 97 | logger.debug(f"Plotting {len(grouped)} local motion tracks.") 98 | all_x_local, all_y_local = [], [] 99 | for _, group_df in grouped: 100 | all_x_local.extend([group_df['X_start'].iloc[0], group_df['X_end'].iloc[-1], None]) 101 | all_y_local.extend([group_df['Y_start'].iloc[0], group_df['Y_end'].iloc[-1], None]) 102 | 103 | if all_x_local: 104 | fig.add_trace(go.Scatter( 105 | x=all_x_local, y=all_y_local, mode="lines", name="Local Shifts", 106 | line=dict(color="#039d83", width=1), hoverinfo='none' 107 | )) 108 | traces_added = True 109 | else: logger.warning(f"Missing required local motion columns in {mic_metadata_rel_path}") 110 | except Exception as exc: 111 | report_error(exc, f"Error processing local shift data for {mic_metadata_rel_path}") 112 | st.warning(f"Could not process local motion data: {exc}") 113 | 114 | # --- Process and Plot Global Motion --- 115 | if isinstance(global_shift_df, pd.DataFrame) 
and not global_shift_df.empty: 116 | try: 117 | global_shift = global_shift_df.astype(float) 118 | shift_x_col, shift_y_col = "_rlnMicrographShiftX", "_rlnMicrographShiftY" 119 | if all(col in global_shift.columns for col in [shift_x_col, shift_y_col]): 120 | center_x, center_y = 0.0, 0.0 121 | if isinstance(local_shift_df, pd.DataFrame) and not local_shift_df.empty and \ 122 | all(col in local_shift_df.columns for col in ["_rlnCoordinateX", "_rlnCoordinateY"]): 123 | try: 124 | center_x = local_shift_df["_rlnCoordinateX"].astype(float).mean() 125 | center_y = local_shift_df["_rlnCoordinateY"].astype(float).mean() 126 | except Exception: pass # Ignore errors, default to 0,0 127 | 128 | if isinstance(local_shift_df, pd.DataFrame) and not local_shift_df.empty and \ 129 | all(col in local_shift_df.columns for col in ["_rlnCoordinateX", "_rlnCoordinateY", "_rlnMicrographShiftX", "_rlnMicrographShiftY"]): 130 | try: 131 | max_coord_x = local_shift_df["_rlnCoordinateX"].astype(float).max() 132 | max_coord_y = local_shift_df["_rlnCoordinateY"].astype(float).max() 133 | max_gshift_x = global_shift[shift_x_col].abs().max() 134 | max_gshift_y = global_shift[shift_y_col].abs().max() 135 | scale_x = (max_coord_x / (2 * max_gshift_x)) if max_gshift_x > 1e-6 else 100 136 | scale_y = (max_coord_y / (2 * max_gshift_y)) if max_gshift_y > 1e-6 else 100 137 | fold_global_motion = max(1, int(min(scale_x, scale_y, 500))) 138 | except Exception as e: 139 | logger.warning(f"Could not calculate global fold factor: {e}. Using default.") 140 | fold_global_motion = 100 # Default scale if local data missing/problematic 141 | else: 142 | fold_global_motion = st.slider( 143 | "Global Motion Scale:", 1, 1000, 100, 10, key="global_motion_scale", 144 | help="Magnifies global shifts for visibility when local data is absent.") 145 | 146 | global_x = global_shift[shift_x_col] * fold_global_motion + center_x 147 | global_y = global_shift[shift_y_col] * fold_global_motion + center_y 148 | 149 | fig.add_trace(go.Scatter( 150 | x=global_x, y=global_y, mode="lines+markers", name="Global Shift", 151 | line=dict(color="#007acc", width=2), marker=dict(size=5), 152 | hovertemplate="Frame: %{customdata}
<br>X: %{x:.1f}<br>
Y: %{y:.1f}", 153 | customdata=np.arange(len(global_x)) 154 | )) 155 | traces_added = True 156 | else: logger.warning(f"Missing required global motion columns in {mic_metadata_rel_path}") 157 | except Exception as exc: 158 | report_error(exc, f"Error processing global shift data for {mic_metadata_rel_path}") 159 | st.warning(f"Could not process global motion data: {exc}") 160 | 161 | # --- Display Combined Plot --- 162 | if traces_added: 163 | fig.update_layout( 164 | title=f"Motion Track: {os.path.basename(mic_metadata_rel_path)}", 165 | xaxis_title="X Coordinate (px)", yaxis_title="Y Coordinate (px)", 166 | width=500, height=500, showlegend=True, hovermode="closest", 167 | legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01), 168 | xaxis=dict(scaleanchor="y", scaleratio=1), 169 | margin=dict(l=10, r=10, t=50, b=10) 170 | ) 171 | with col2: 172 | st.plotly_chart(fig, use_container_width=True) 173 | else: 174 | with col2: 175 | st.info("No plottable motion data found.") 176 | 177 | except FileNotFoundError: 178 | st.warning(f"Motion metadata file not found: {mic_metadata_rel_path}") 179 | except Exception as exc: 180 | report_error(exc, f"Error in show_motion for {mic_metadata_rel_path}") 181 | st.warning(f"An error occurred while displaying motion: {exc}") 182 | 183 | 184 | # ============================================================================= 185 | # Main Job Function 186 | # ============================================================================= 187 | 188 | def plot_motioncorr(rln_folder: str, node: str) -> None: 189 | """ 190 | Processes and plots MotionCorr job results, including Altair statistics plots. 191 | 192 | Args: 193 | rln_folder: Base directory of the RELION project. 194 | node: Specific node file for this job (e.g., "MotionCorr/.../corrected_micrographs.star"). 195 | """ 196 | star_path = os.path.join(rln_folder, node) 197 | logger.info(f"Processing MotionCorr job: {star_path}") 198 | 199 | if not os.path.exists(star_path): 200 | st.warning(f"MotionCorr STAR file not found: {node}") 201 | logger.error(f"File not found: {star_path}") 202 | return 203 | 204 | try: 205 | star = parse_star(star_path) 206 | if not star: 207 | st.warning(f"Could not parse or empty STAR file: {node}") 208 | return 209 | except Exception as exc: 210 | report_error(exc, f"Failed to parse MotionCorr STAR file: {node}") 211 | st.warning(f"Error parsing {node}. 
See logs.") 212 | return 213 | 214 | star_data: Optional[pd.DataFrame] = None 215 | # Handle different potential block names 216 | if 'micrographs' in star: 217 | star_data = star["micrographs"] 218 | logger.debug("Using 'micrographs' block.") 219 | elif 'global' in star: # Relion 5 Tomo case 220 | logger.debug("Found 'global' block (likely Relion 5 Tomo MotionCorr).") 221 | try: 222 | tomo_star_files = star["global"]['_rlnTomoTiltSeriesStarFile'] 223 | tomo_star_data_list = [] 224 | failed_loads = 0 225 | for tomo_star in tomo_star_files: 226 | tomo_star_path = os.path.join(rln_folder, tomo_star) 227 | if os.path.exists(tomo_star_path): 228 | parsed_tomo_star = parse_star(tomo_star_path) 229 | first_block_data = get_values_from_first_key(parsed_tomo_star) 230 | if isinstance(first_block_data, pd.DataFrame): 231 | tomo_star_data_list.append(first_block_data) 232 | else: failed_loads += 1; logger.warning(f"No DataFrame in {tomo_star}") 233 | else: failed_loads += 1; logger.warning(f"Tilt series STAR not found: {tomo_star_path}") 234 | if failed_loads > 0: st.warning(f"Failed to load data from {failed_loads} tilt series STAR(s).") 235 | if tomo_star_data_list: 236 | star_data = pd.concat(tomo_star_data_list, ignore_index=True) 237 | logger.info(f"Concatenated data from {len(tomo_star_data_list)} tilt series.") 238 | else: st.warning("Could not load motion data from linked tilt series STAR files."); return 239 | except KeyError as e: report_error(e, "Missing key in Relion 5 Tomo MotionCorr"); st.warning(f"Error: Missing key {e}."); return 240 | except Exception as e: report_error(e, "Error processing Relion 5 Tomo MotionCorr"); st.warning(f"Error processing Tomo data: {e}."); return 241 | else: 242 | first_key = get_first_key(star) 243 | if first_key and isinstance(star[first_key], pd.DataFrame): 244 | star_data = star[first_key] 245 | logger.warning(f"Using first block '{first_key}' as fallback data source.") 246 | else: st.warning("No 'micrographs' or 'global' block found."); return 247 | 248 | if star_data is None or star_data.empty: 249 | st.warning("No valid micrograph data found for MotionCorr.") 250 | return 251 | 252 | # --- Motion Statistics Plotting (Altair) --- 253 | st.subheader("Motion Statistics") 254 | meta_names = ["_rlnAccumMotionTotal", "_rlnAccumMotionEarly", "_rlnAccumMotionLate"] 255 | df_line = pd.DataFrame() 256 | df_hist = pd.DataFrame() 257 | percentile_99 = None # Initialize percentile 258 | 259 | try: 260 | valid_meta_found = False 261 | for meta in meta_names: 262 | if meta in star_data.columns: 263 | data_array = pd.to_numeric(star_data[meta], errors='coerce').dropna() 264 | if not data_array.empty: 265 | valid_meta_found = True 266 | # Calculate percentile for clipping (only once, based on Total if possible) 267 | if meta == "_rlnAccumMotionTotal" and percentile_99 is None: 268 | percentile_99 = np.percentile(data_array, 99.5) 269 | 270 | # Use the calculated percentile, or fallback if Total wasn't available first 271 | clip_value = percentile_99 if percentile_99 is not None else np.percentile(data_array, 99.5) 272 | data_clipped_line = np.clip(data_array, None, clip_value) 273 | # Use a fixed clip for histogram for better comparison range? 
e.g., 100 Angstroms 274 | data_clipped_hist = np.clip(data_array, None, 100) 275 | 276 | series_name = meta.replace("_rln", "").replace("AccumMotion", "") # Shorter name 277 | temp_df_line = pd.DataFrame({"Index": np.arange(len(data_clipped_line)), "Motion": data_clipped_line, "Series": series_name}) 278 | temp_df_hist = pd.DataFrame({"Motion": data_clipped_hist, "Series": series_name}) 279 | df_line = pd.concat([df_line, temp_df_line], ignore_index=True) 280 | df_hist = pd.concat([df_hist, temp_df_hist], ignore_index=True) 281 | else: logger.warning(f"Column '{meta}' is empty or has no numeric data.") 282 | else: logger.warning(f"Column '{meta}' not found in STAR data.") 283 | 284 | if not valid_meta_found: 285 | st.caption("No motion statistics columns found to plot.") 286 | else: 287 | # Create Altair charts 288 | clip_info = f" (line plot clipped at {percentile_99:.2f} Å)" if percentile_99 is not None else "" 289 | line_chart = alt.Chart(df_line).mark_line(point=False, opacity=0.8).encode( # point=False for large datasets 290 | x=alt.X("Index:Q", title="Micrograph Index"), 291 | y=alt.Y("Motion:Q", title="Accumulated Motion (Å)"), 292 | color=alt.Color("Series:N", title="Motion Type"), 293 | tooltip=["Index", "Motion", "Series"] 294 | ).properties( 295 | title=f"Motion Statistics{clip_info}" 296 | ).interactive() # Add interactivity 297 | 298 | hist_chart = alt.Chart(df_hist).mark_bar(opacity=0.6, binSpacing=0).encode( # Adjust binSpacing 299 | x=alt.X("Motion:Q", bin=alt.Bin(maxbins=50), title="Accumulated Motion (Å, clipped at 100)"), 300 | y=alt.Y("count()", title="Frequency", stack=None), # stack=None for overlaid bars 301 | color=alt.Color("Series:N", title="Motion Type"), 302 | tooltip=[alt.Tooltip("Motion:Q", bin=True), "count()", "Series:N"] 303 | ).properties( 304 | title="Motion Histograms" 305 | ).interactive() 306 | 307 | # Display side-by-side 308 | col1, col2 = st.columns(2) 309 | with col1: st.altair_chart(line_chart, use_container_width=True) 310 | with col2: st.altair_chart(hist_chart, use_container_width=True) 311 | 312 | except Exception as exc: 313 | report_error(exc, "Error generating MotionCorr statistics plots.") 314 | st.warning(f"Could not generate motion statistics plots: {exc}") 315 | 316 | st.divider() 317 | 318 | # --- Motion Tracks and Micrograph Display --- 319 | # Checkbox to show motion tracks plot 320 | metadata_col = '_rlnMicrographMetadata' 321 | if metadata_col in star_data.columns: 322 | # Use default value True to show motion by default if available 323 | if st.checkbox("Show Motion Tracks per Micrograph?", value=True, key="motioncorr_show_tracks"): 324 | st.subheader("Motion Tracks") 325 | motion_star_files = star_data[metadata_col].dropna().tolist() # Drop missing values 326 | valid_motion_files = [f for f in motion_star_files if isinstance(f, str) and f] 327 | if valid_motion_files: 328 | show_motion(rln_folder, valid_motion_files) 329 | else: st.info("No valid micrograph metadata files found.") 330 | else: st.caption(f"'{metadata_col}' column not found - cannot display individual motion tracks.") 331 | 332 | st.divider() 333 | 334 | # Display corrected micrographs 335 | st.subheader("Corrected Micrographs") 336 | mics_col = '_rlnMicrographName' 337 | if mics_col in star_data.columns: 338 | mics = star_data[mics_col].dropna().tolist() 339 | valid_mics = [m for m in mics if isinstance(m, str) and m] 340 | if valid_mics: 341 | try: 342 | micrograph_viewer(rln_folder, valid_mics) 343 | except Exception as exc: 344 | report_error(exc, "Error 
calling micrograph_viewer for corrected micrographs") 345 | st.warning(f"Error displaying corrected micrographs: {exc}") 346 | else: st.info("No valid corrected micrograph paths found.") 347 | else: st.caption(f"'{mics_col}' column not found - cannot display corrected micrographs.") 348 | 349 | st.divider() 350 | 351 | # Option to plot metadata interactively 352 | if st.checkbox("Plot Metadata Interactively?", key="motioncorr_plot_meta"): 353 | st.subheader("Metadata Distribution") 354 | try: 355 | interactive_scatter_plot(data_source=star_path, title_prefix="MotionCorr Metadata") 356 | except Exception as exc: 357 | report_error(exc, "Error calling interactive_scatter_plot for MotionCorr metadata") 358 | st.warning(f"Error plotting metadata: {exc}") 359 | 360 | logger.info(f"Finished processing MotionCorr job: {node}") -------------------------------------------------------------------------------- /follow_relion_gracefully.py: -------------------------------------------------------------------------------- 1 | # 2 | # ____ ___ ___ ____ ___ ____ ___ ___ ___ 3 | # /\ __\ /\_ \ /\_ \ /\ _`\ /\_ \ __ /\ _`\ /'___\ /\_ \ /\_ \ 4 | # \ \ \_ __\//\ \ \//\ \ ___ __ __ __ \ \ \L\ \ __\//\ \ /\_\ ___ ___ \ \ \L\_\ _ __ __ ___ __ /\ \__/ __ __\//\ \ \//\ \ __ __ 5 | # \ \ _\/ __`\\ \ \ \ \ \ / __`\/\ \/\ \/\ \ \ \ , / /'__`\\ \ \ \/\ \ / __`\ /' _ `\ \ \ \L_L /\`'__\/'__`\ /'___\ /'__`\ \ ,__\/\ \/\ \ \ \ \ \ \ \ /\ \/\ \ 6 | # \ \ \/\ \L\ \\_\ \_ \_\ \_/\ \L\ \ \ \_/ \_/ \ \ \ \\ \ /\ __/ \_\ \_\ \ \/\ \L\ \/\ \/\ \ \ \ \/, \ \ \//\ \L\.\_/\ \__//\ __/\ \ \_/\ \ \_\ \ \_\ \_ \_\ \_\ \ \_\ \ 7 | # \ \_\ \____//\____\/\____\ \____/\ \___x___/' \ \_\ \_\ \____\/\____\\ \_\ \____/\ \_\ \_\ \ \____/\ \_\\ \__/.\_\ \____\ \____\\ \_\ \ \____/ /\____\/\____\\/`____ \ 8 | # \/_/\/___/ \/____/\/____/\/___/ \/__//__/ \/_/\/ /\/____/\/____/ \/_/\/___/ \/_/\/_/ \/___/ \/_/ \/__/\/_/\/____/\/____/ \/_/ \/___/ \/____/\/____/ `/___/> \ 9 | # \\___/ 10 | # Follow Relion Gracefully (v6) 11 | # Developed by Dawid Zyla, La Jolla Institute for Immunology 12 | # Non-Profit Open Software License 3.0 13 | 14 | # update v6 (2025-04-26) 15 | 16 | # ## Main Changes 17 | # -> Better integration with Streamlit platform (https://streamlit.io/) 18 | # -> Added support for all (most) Relion jobs covering all cryo-ET and SPA jobs (except for DynaMight) 19 | # -> Temporary removed live and in-browser job execution 20 | # -> General QOL fixes and improvements 21 | # 22 | # ## New Features 23 | # -> Support for all cryo-ET jobs with job previews 24 | # -> Optimized visualizations for most of the Relion jobs 25 | # -> Divided the code into smaller modules for better readability and maintainability 26 | # -> Overhauled visualization of Local Resolution, picking, micrograph previews, and other jobs 27 | # -> Increased performance and reduced loading times (in most cases) 28 | # -> Added better logging and error handling 29 | # 30 | # 31 | # ## To Do 32 | # -> Add own DynaMight job preview (currently not supported) 33 | # -> Add support for cryoSPARC cs files and export to Relion (most likely via pyem) 34 | # -> Further speed optimization and code cleanup 35 | 36 | # Standard Library Imports 37 | import argparse 38 | import logging 39 | import os 40 | import re 41 | from typing import Callable, Dict, List, Optional, Tuple 42 | 43 | # Third-Party Imports 44 | import pandas as pd 45 | import streamlit as st 46 | 47 | # Local Imports 48 | from lib.jobs_utils import ( # Assuming these are correctly defined in jobs_utils 49 | 
create_network, 50 | display_job_info, 51 | format_display_name, 52 | ) 53 | from lib.utils import ( # Assuming these are correctly defined in utils 54 | check_password, 55 | custom_css, 56 | dynamic_folder_explorer, 57 | get_footer, 58 | # get_newest_change, # Not directly used here, used within display_job_info 59 | # get_note, # Not directly used here, used within display_job_info 60 | # get_relationships_df, # Not directly used here, used within display_job_info 61 | interactive_scatter_plot, 62 | parse_star, 63 | render_svg, 64 | report_error, # Use the central report_error 65 | ) 66 | 67 | # ============================================================================= 68 | # Constants 69 | # ============================================================================= 70 | # Column names in RELION STAR files 71 | RLN_PROCESS_TYPE_LABEL = "_rlnPipeLineProcessTypeLabel" 72 | RLN_PROCESS_ALIAS = "_rlnPipeLineProcessAlias" 73 | RLN_PROCESS_NAME = "_rlnPipeLineProcessName" 74 | RLN_STATUS_LABEL = "_rlnPipeLineProcessStatusLabel" 75 | 76 | # Keys for data blocks in parsed pipeline STAR dictionary 77 | PIPELINE_PROCESSES_KEY = "pipeline_processes" 78 | PIPELINE_NODES_KEY = "pipeline_nodes" 79 | PIPELINE_EDGES_KEY = "pipeline_input_edges" # Make sure this key matches parser output 80 | 81 | # Special process names used internally 82 | FLOWCHART_PROCESS = "relion.flowchart" 83 | INTERACTIVE_PLOT_PROCESS = "relion.InteractivePlot" 84 | 85 | # Session State Keys 86 | STATE_SELECTED_PROCESS = "selected_process" 87 | STATE_CURRENT_JOB = "current_job" 88 | STATE_JOB_PARAMS = "job_params" 89 | STATE_PROCESS_RADIO_KEY = "process_radio_key" 90 | STATE_JOB_RADIO_KEY = "job_radio_key" 91 | STATE_PASSWORD_ARGS = "password_args" 92 | STATE_DEFAULT_FOLDER = "default_job_folder" # Initial folder path from args/explorer 93 | STATE_CURRENT_FOLDER = "current_folder_path" # Path currently being viewed 94 | STATE_DISPLAY_TO_ORIGINAL_PROCESS = "display_to_original_process_map" 95 | STATE_JOBS_DICT = "jobs_dict" 96 | TEMP_DIR_PATH = "temp" 97 | 98 | # Logger configuration 99 | logging_level = logging.DEBUG 100 | 101 | # ============================================================================= 102 | # Logger and Global Error Handling 103 | # ============================================================================= 104 | logger = logging.getLogger("main_app") 105 | if not logger.handlers: 106 | logger.setLevel(logging_level) 107 | console_handler = logging.StreamHandler() 108 | console_handler.setLevel(logging_level) 109 | formatter = logging.Formatter( 110 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s", 111 | datefmt="%Y-%m-%d %H:%M:%S", 112 | ) 113 | console_handler.setFormatter(formatter) 114 | logger.addHandler(console_handler) 115 | logger.propagate = False 116 | 117 | ERROR_HANDLER: Optional[Callable[[Exception, str], None]] = None 118 | 119 | 120 | def local_report_error_handler(exc: Exception, error_info: str): 121 | """Local fallback error handler that logs the error.""" 122 | logger.error("An unexpected error occurred:\n%s", error_info) 123 | 124 | 125 | def set_error_handler(handler: Callable[[Exception, str], None]) -> None: 126 | """Sets the global error handler for this module.""" 127 | global ERROR_HANDLER 128 | ERROR_HANDLER = handler 129 | 130 | 131 | # ============================================================================= 132 | # UI Setup 133 | # ============================================================================= 134 | def set_style() -> None: 135 | """Configures 
Streamlit page settings and applies custom CSS.""" 136 | try: 137 | st.set_page_config( 138 | page_title="Follow Relion Gracefully", 139 | page_icon=":microscope:", 140 | initial_sidebar_state="expanded", 141 | layout="wide", 142 | ) 143 | st.markdown(custom_css(), unsafe_allow_html=True) 144 | st.markdown( 145 | """ 146 | """, 150 | unsafe_allow_html=True, 151 | ) 152 | except Exception as e: 153 | logger.error(f"Failed to set page style: {e}") 154 | 155 | 156 | # ============================================================================= 157 | # Utility Functions 158 | # ============================================================================= 159 | def parse_help_text(help_text: str) -> Dict[str, str]: 160 | """Parses --parameter (description) format from help text.""" 161 | pattern = ( 162 | r"--(\w+)(?: \(([^)]+)\))?" # Non-capturing group for optional description 163 | ) 164 | matches = re.findall(pattern, help_text) 165 | return {f"--{match[0]}": match[1].strip() if match[1] else "" for match in matches} 166 | 167 | 168 | def create_temp_directory(temp_dir: str = TEMP_DIR_PATH) -> bool: 169 | """Creates a temporary directory if it doesn't exist.""" 170 | if not os.path.exists(temp_dir): 171 | try: 172 | os.makedirs(temp_dir) 173 | logger.info("Created temporary directory: %s", temp_dir) 174 | return True 175 | except OSError as e: 176 | # Use imported report_error 177 | report_error(e, f"Failed to create temporary directory {temp_dir}") 178 | st.error(f"Failed to create required directory '{temp_dir}': {e}") 179 | return False 180 | return True 181 | 182 | 183 | # ============================================================================= 184 | # Caching Heavy Computations 185 | # ============================================================================= 186 | @st.cache_data(ttl=3600) # Cache STAR file parsing for 1 hour 187 | def load_pipeline_star(folder: str) -> Optional[Dict[str, pd.DataFrame]]: 188 | """Loads and parses the default_pipeline.star file from the given folder.""" 189 | default_pipeline_path = os.path.join(folder, "default_pipeline.star") 190 | if not os.path.exists(default_pipeline_path): 191 | logger.warning("File not found: %s", default_pipeline_path) 192 | return None 193 | try: 194 | logger.info("Loading and parsing STAR file: %s", default_pipeline_path) 195 | pipeline_star = parse_star(default_pipeline_path) 196 | if not pipeline_star: 197 | logger.warning( 198 | "Parsing STAR file returned empty: %s", default_pipeline_path 199 | ) 200 | return None 201 | logger.info("Successfully parsed STAR file: %s", default_pipeline_path) 202 | return pipeline_star 203 | except Exception as exc: 204 | report_error(exc, f"Failed to parse STAR file: {default_pipeline_path}") 205 | st.error( 206 | f"Failed to parse STAR file: {os.path.basename(default_pipeline_path)}. 
Error: {exc}" 207 | ) 208 | return None 209 | 210 | 211 | # ============================================================================= 212 | # Data Preparation Functions 213 | # ============================================================================= 214 | def get_pipeline_df(pipeline_star: Optional[Dict[str, pd.DataFrame]]) -> pd.DataFrame: 215 | """Extracts and prepares the 'pipeline_processes' DataFrame.""" 216 | if not pipeline_star: 217 | return pd.DataFrame() 218 | pipeline_processes = pipeline_star.get(PIPELINE_PROCESSES_KEY) 219 | if not isinstance(pipeline_processes, pd.DataFrame) or pipeline_processes.empty: 220 | logger.warning("'%s' missing or empty in STAR data.", PIPELINE_PROCESSES_KEY) 221 | return pd.DataFrame() 222 | 223 | df = pipeline_processes.copy() 224 | required_cols = { 225 | RLN_PROCESS_TYPE_LABEL: "", 226 | RLN_PROCESS_NAME: "", 227 | RLN_STATUS_LABEL: "Unknown", 228 | RLN_PROCESS_ALIAS: "None", 229 | } 230 | for col, default in required_cols.items(): 231 | if col not in df.columns: 232 | logger.warning( 233 | "Column '%s' missing, adding with default '%s'.", col, default 234 | ) 235 | df[col] = default 236 | df[RLN_PROCESS_ALIAS] = df[RLN_PROCESS_ALIAS].fillna("None").astype(str) 237 | df[RLN_STATUS_LABEL] = df[RLN_STATUS_LABEL].fillna("Unknown").astype(str) 238 | return df 239 | 240 | 241 | @st.cache_data 242 | def get_jobs_for_process( 243 | selected_process: str, df: pd.DataFrame 244 | ) -> Tuple[List[str], Dict[str, str]]: 245 | """ 246 | Filters jobs for a selected process type and creates a display key mapping. 247 | Jobs are returned in the order they appear in the DataFrame. 248 | 249 | Args: 250 | selected_process: Original process name (e.g., "relion.Class2D"). 251 | df: Prepared pipeline_processes DataFrame. 252 | 253 | Returns: 254 | Tuple: (list of display keys in original order, mapping display_key -> job_name). 
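        Example (hypothetical aliases): two jobs that share the alias "class2d_best" map to
        the display keys "class2d_best" and "class2d_best (1)", each pointing at its own job name.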
255 | """ 256 | if df.empty or RLN_PROCESS_TYPE_LABEL not in df.columns: 257 | logger.debug( 258 | "Cannot get jobs: DataFrame empty or missing '%s'.", RLN_PROCESS_TYPE_LABEL 259 | ) 260 | return [], {} 261 | 262 | jobs_df = df.loc[df[RLN_PROCESS_TYPE_LABEL] == selected_process] 263 | if jobs_df.empty: 264 | logger.debug("No jobs found for process type '%s'.", selected_process) 265 | return [], {} 266 | 267 | jobs_dict: Dict[str, str] = {} 268 | jobs_display_keys: List[str] = [] # Maintain order 269 | display_key_counts: Dict[str, int] = {} 270 | 271 | for _, row in jobs_df.iterrows(): # Iterate in DataFrame order 272 | alias = row.get(RLN_PROCESS_ALIAS, "None") 273 | job_name = row.get(RLN_PROCESS_NAME, "") 274 | if not job_name: 275 | continue 276 | 277 | base_display_key = ( 278 | alias if alias != "None" and alias.strip() != "" else job_name 279 | ) 280 | current_count = display_key_counts.get(base_display_key, 0) 281 | display_key = base_display_key 282 | if current_count > 0: 283 | display_key = ( 284 | f"{base_display_key} ({current_count})" # Append count for uniqueness 285 | ) 286 | display_key_counts[base_display_key] = current_count + 1 287 | 288 | if display_key not in jobs_dict: 289 | jobs_dict[display_key] = job_name 290 | jobs_display_keys.append(display_key) # Add to ordered list 291 | # Do not sort jobs_display_keys 292 | 293 | logger.debug( 294 | "Found %d jobs for process '%s'.", len(jobs_display_keys), selected_process 295 | ) 296 | return jobs_display_keys, jobs_dict 297 | 298 | 299 | # ============================================================================= 300 | # State Management Callbacks 301 | # ============================================================================= 302 | def handle_folder_change(): 303 | """Resets relevant session state variables when the folder changes.""" 304 | logger.info("Resetting process/job state due to folder change.") 305 | keys_to_reset = [ 306 | STATE_SELECTED_PROCESS, 307 | STATE_CURRENT_JOB, 308 | STATE_JOB_PARAMS, 309 | STATE_PROCESS_RADIO_KEY, 310 | STATE_JOB_RADIO_KEY, 311 | STATE_JOBS_DICT, 312 | STATE_DISPLAY_TO_ORIGINAL_PROCESS, 313 | ] 314 | for key in keys_to_reset: 315 | if key in st.session_state: 316 | st.session_state[key] = None # Reset to None or appropriate default 317 | st.session_state[STATE_JOBS_DICT] = {} # Ensure these are empty dicts 318 | st.session_state[STATE_JOB_PARAMS] = {} 319 | st.session_state[STATE_DISPLAY_TO_ORIGINAL_PROCESS] = {} 320 | 321 | 322 | def handle_process_change(): 323 | """Resets job state when the selected process type changes via the radio button.""" 324 | selected_display_name = st.session_state.get(STATE_PROCESS_RADIO_KEY) 325 | display_to_original = st.session_state.get(STATE_DISPLAY_TO_ORIGINAL_PROCESS, {}) 326 | newly_selected_process = display_to_original.get(selected_display_name) 327 | 328 | # Check if the *actual underlying process name* has changed 329 | if newly_selected_process != st.session_state.get(STATE_SELECTED_PROCESS): 330 | logger.info("Process selection changed to: '%s'", newly_selected_process) 331 | st.session_state[STATE_SELECTED_PROCESS] = newly_selected_process 332 | # Reset only job-related state 333 | st.session_state[STATE_CURRENT_JOB] = None 334 | st.session_state[STATE_JOB_PARAMS] = {} 335 | st.session_state[STATE_JOB_RADIO_KEY] = None 336 | st.session_state[STATE_JOBS_DICT] = {} 337 | logger.debug("Job state reset due to process change.") 338 | 339 | 340 | def handle_job_change(): 341 | """Updates the current job based on the job radio button 
selection.""" 342 | selected_display_key = st.session_state.get(STATE_JOB_RADIO_KEY) 343 | jobs_dict = st.session_state.get(STATE_JOBS_DICT, {}) 344 | newly_selected_job_name = jobs_dict.get(selected_display_key) 345 | 346 | # Check if the *actual underlying job name* has changed 347 | if newly_selected_job_name != st.session_state.get(STATE_CURRENT_JOB): 348 | logger.info("Job selection changed to: '%s'", newly_selected_job_name) 349 | st.session_state[STATE_CURRENT_JOB] = newly_selected_job_name 350 | st.session_state[STATE_JOB_PARAMS] = { 351 | "folder": st.session_state.get(STATE_CURRENT_FOLDER), 352 | "process": st.session_state.get(STATE_SELECTED_PROCESS), 353 | } 354 | 355 | 356 | # ============================================================================= 357 | # Main Application Logic 358 | # ============================================================================= 359 | def main() -> None: 360 | """Main function to run the Streamlit application.""" 361 | try: 362 | # --- Argument Parsing --- 363 | parser = argparse.ArgumentParser(description="Follow Relion Gracefully Viewer") 364 | parser.add_argument( 365 | "-i", 366 | "--folder", 367 | type=str, 368 | help="Initial Relion project folder", 369 | default="~", 370 | ) 371 | # parser.add_argument("--relion-path", type=str, help="Path to RELION installation (bin)", default="/") 372 | parser.add_argument( 373 | "-p", 374 | "--password", 375 | type=str, 376 | help="Password to protect the instance", 377 | default="", 378 | ) 379 | args, _ = parser.parse_known_args() 380 | 381 | # --- Session State Initialization --- 382 | # Initialize folder state using command line argument ONLY IF state is empty 383 | if STATE_DEFAULT_FOLDER not in st.session_state: 384 | initial_folder = os.path.abspath(os.path.expanduser(args.folder)) 385 | st.session_state[STATE_DEFAULT_FOLDER] = initial_folder 386 | st.session_state[STATE_CURRENT_FOLDER] = ( 387 | initial_folder # Also set current initially 388 | ) 389 | logger.info( 390 | "Initialized state: Default folder set from arg to %s", initial_folder 391 | ) 392 | # Ensure other states have default values if not set previously 393 | st.session_state.setdefault( 394 | STATE_CURRENT_FOLDER, st.session_state[STATE_DEFAULT_FOLDER] 395 | ) 396 | st.session_state.setdefault(STATE_PASSWORD_ARGS, args.password) 397 | st.session_state.setdefault(STATE_SELECTED_PROCESS, None) 398 | st.session_state.setdefault(STATE_CURRENT_JOB, None) 399 | st.session_state.setdefault(STATE_JOB_PARAMS, {}) 400 | st.session_state.setdefault(STATE_PROCESS_RADIO_KEY, None) 401 | st.session_state.setdefault(STATE_JOB_RADIO_KEY, None) 402 | st.session_state.setdefault(STATE_DISPLAY_TO_ORIGINAL_PROCESS, {}) 403 | st.session_state.setdefault(STATE_JOBS_DICT, {}) 404 | 405 | # --- Password Check --- 406 | if st.session_state[STATE_PASSWORD_ARGS]: 407 | if not check_password( 408 | st.session_state[STATE_PASSWORD_ARGS] 409 | ): # Pass correct arg 410 | st.warning("Password required.") 411 | st.stop() 412 | 413 | # --- UI Rendering & Folder Selection --- 414 | render_svg("./static/frg.svg") 415 | footer = get_footer() 416 | create_temp_directory() 417 | 418 | # dynamic_folder_explorer updates STATE_DEFAULT_FOLDER internally now 419 | # It uses STATE_DEFAULT_FOLDER as its starting point if uninitialized. 420 | # It returns the confirmed path, which we use to update STATE_CURRENT_FOLDER if needed. 
421 | selected_folder = dynamic_folder_explorer( 422 | st.session_state[STATE_DEFAULT_FOLDER] 423 | ) 424 | 425 | # Detect if the folder confirmed by the explorer differs from the current viewing folder 426 | if selected_folder != st.session_state.get(STATE_CURRENT_FOLDER): 427 | logger.info( 428 | "Folder changed via explorer: '%s' -> '%s'", 429 | st.session_state.get(STATE_CURRENT_FOLDER), 430 | selected_folder, 431 | ) 432 | st.session_state[STATE_CURRENT_FOLDER] = selected_folder 433 | # Crucially, update the DEFAULT folder as well if the user explicitly selected it 434 | st.session_state[STATE_DEFAULT_FOLDER] = selected_folder 435 | handle_folder_change() # Reset dependent states 436 | st.cache_data.clear() # Clear data cache on folder change 437 | st.rerun() 438 | 439 | current_folder = st.session_state[STATE_CURRENT_FOLDER] 440 | 441 | # --- Load Data for Current Folder --- 442 | pipeline_star = load_pipeline_star(current_folder) # Uses cache 443 | df = get_pipeline_df(pipeline_star) # Process potentially cached data 444 | 445 | # --- Sidebar: Process Selection --- 446 | st.sidebar.title("Process Types") 447 | process_types = [] 448 | if df.empty: 449 | st.sidebar.warning("No pipeline data found.") 450 | elif RLN_PROCESS_TYPE_LABEL not in df.columns: 451 | st.sidebar.error(f"Missing column: {RLN_PROCESS_TYPE_LABEL}") 452 | else: 453 | # Get unique process types IN THE ORDER THEY APPEAR in the DataFrame 454 | process_types = list(df[RLN_PROCESS_TYPE_LABEL].unique()) # No sorting here 455 | 456 | special_processes = [FLOWCHART_PROCESS, INTERACTIVE_PLOT_PROCESS] 457 | all_processes = ( 458 | process_types + special_processes 459 | ) # Add special views at the end 460 | 461 | if not all_processes: 462 | st.sidebar.info("No processes found.") 463 | st.info("Select a valid RELION project folder.") 464 | else: 465 | # Map original names to display names 466 | original_to_display = {p: format_display_name(p) for p in all_processes} 467 | display_to_original = {d: p for p, d in original_to_display.items()} 468 | st.session_state[STATE_DISPLAY_TO_ORIGINAL_PROCESS] = display_to_original 469 | display_options = list(original_to_display.values()) 470 | 471 | # Determine current selection index for the radio button 472 | current_display = st.session_state.get(STATE_PROCESS_RADIO_KEY) 473 | if current_display not in display_options: 474 | current_display = display_options[0] if display_options else None 475 | # If defaulting, update the underlying selected process state 476 | if current_display: 477 | st.session_state[STATE_PROCESS_RADIO_KEY] = current_display 478 | st.session_state[STATE_SELECTED_PROCESS] = display_to_original.get( 479 | current_display 480 | ) 481 | handle_process_change() # Reset job state implicitly 482 | 483 | current_index = ( 484 | display_options.index(current_display) 485 | if current_display in display_options 486 | else 0 487 | ) 488 | 489 | # Process Radio Button 490 | st.sidebar.radio( 491 | "Select View:", 492 | options=display_options, 493 | index=current_index, 494 | key=STATE_PROCESS_RADIO_KEY, 495 | on_change=handle_process_change, 496 | ) 497 | 498 | # --- Sidebar: Job Selection --- 499 | selected_process = st.session_state.get(STATE_SELECTED_PROCESS) 500 | if selected_process and selected_process not in special_processes: 501 | # Get jobs IN ORIGINAL ORDER 502 | jobs_display_keys, jobs_dict = get_jobs_for_process( 503 | selected_process, df 504 | ) # Uses cache 505 | st.session_state[STATE_JOBS_DICT] = jobs_dict 506 | 507 | if jobs_display_keys: 508 | 
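# Keep a still-valid selection; otherwise fall back to the last-listed job below.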
process_display_name = format_display_name(selected_process) 509 | st.sidebar.title(f"{process_display_name} Jobs") 510 | 511 | current_job_key = st.session_state.get(STATE_JOB_RADIO_KEY) 512 | # Check if current selection is valid, default to LAST job if not 513 | if current_job_key not in jobs_display_keys: 514 | current_job_key = jobs_display_keys[ 515 | -1 516 | ] # Default to last job in the list 517 | st.session_state[STATE_JOB_RADIO_KEY] = current_job_key 518 | st.session_state[STATE_CURRENT_JOB] = jobs_dict.get( 519 | current_job_key 520 | ) 521 | st.session_state[STATE_JOB_PARAMS] = { 522 | "folder": current_folder, 523 | "process": selected_process, 524 | } 525 | 526 | job_index = jobs_display_keys.index(current_job_key) 527 | 528 | # Job Radio Button 529 | st.sidebar.radio( 530 | "Select Job:", 531 | options=jobs_display_keys, 532 | index=job_index, 533 | key=STATE_JOB_RADIO_KEY, 534 | on_change=handle_job_change, 535 | ) 536 | else: 537 | st.sidebar.caption("No jobs of this type found.") 538 | if ( 539 | st.session_state.get(STATE_CURRENT_JOB) is not None 540 | ): # Clear state if no jobs 541 | st.session_state[STATE_CURRENT_JOB] = None 542 | st.session_state[STATE_JOB_PARAMS] = {} 543 | st.session_state[STATE_JOB_RADIO_KEY] = None 544 | 545 | # --- Main Area Display --- 546 | st.markdown("---") 547 | if selected_process == FLOWCHART_PROCESS: 548 | st.title("Pipeline Flowchart") 549 | orientation = st.radio( 550 | "Layout:", 551 | ["top-bottom", "left-right"], 552 | index=0, 553 | key="flowchart_orientation", 554 | horizontal=True, 555 | ) 556 | if pipeline_star: 557 | dot_string = create_network(pipeline_star, orientation=orientation) 558 | if dot_string: 559 | st.graphviz_chart(dot_string, use_container_width=True) 560 | else: 561 | st.warning("Could not generate flowchart.") 562 | else: 563 | st.warning("Load a project first.") 564 | 565 | elif selected_process == INTERACTIVE_PLOT_PROCESS: 566 | st.title("Interactive STAR File Plot") 567 | col1, col2 = st.columns([1, 2]) 568 | star_file_data = None 569 | plot_file_name = "Data" 570 | viewer_prefix = "iplot_" 571 | with col1.expander("Select Data", expanded=True): 572 | star_path = st.text_input( 573 | "STAR file path:", key=f"{viewer_prefix}path" 574 | ) 575 | up_file = st.file_uploader( 576 | "Or upload STAR:", type=["star"], key=f"{viewer_prefix}up" 577 | ) 578 | if up_file: 579 | tmp_path = os.path.join(TEMP_DIR_PATH, up_file.name) 580 | try: 581 | with open(tmp_path, "wb") as f: 582 | f.write(up_file.getbuffer()) 583 | star_file_data = parse_star(tmp_path) 584 | plot_file_name = up_file.name 585 | os.remove(tmp_path) 586 | except Exception as e: 587 | st.error(f"Error processing upload: {e}") 588 | elif star_path and os.path.exists(star_path): 589 | star_file_data = parse_star(star_path) 590 | plot_file_name = os.path.basename(star_path) 591 | elif star_path: 592 | st.warning("Path does not exist.") 593 | 594 | if star_file_data: 595 | blocks = list(star_file_data.keys()) 596 | if blocks: 597 | sel_block = col1.selectbox( 598 | "Select Block:", 599 | blocks, 600 | index=len(blocks) - 1, 601 | key=f"{viewer_prefix}block", 602 | ) 603 | if sel_block: 604 | with col2: 605 | interactive_scatter_plot( 606 | data_source=star_file_data, # Pass data dict 607 | block_selector_options=blocks, 608 | default_block=sel_block, 609 | title_prefix=plot_file_name, 610 | ) 611 | else: 612 | col1.warning("No data blocks found.") 613 | else: 614 | col1.info("Upload or provide path to a STAR file.") 615 | 616 | elif selected_process: # 
Regular RELION job type 617 | selected_job = st.session_state.get(STATE_CURRENT_JOB) 618 | if selected_job: 619 | display_job_info( 620 | selected_job, current_folder, df, pipeline_star or {} 621 | ) 622 | else: 623 | if selected_process in process_types: 624 | st.info( 625 | f"Select a job for '{format_display_name(selected_process)}' from the sidebar." 626 | ) 627 | 628 | else: # No process selected 629 | st.info("Select a process type or view from the sidebar.") 630 | 631 | st.sidebar.button( 632 | "Refresh", 633 | key="refresh_button", 634 | help="Refresh the page to clear any temporary data.", 635 | on_click=handle_folder_change, # Reset state on refresh 636 | ) 637 | # --- Footer --- 638 | st.sidebar.markdown("---") 639 | st.sidebar.markdown(footer, unsafe_allow_html=True) 640 | 641 | except Exception as exc: 642 | report_error(exc, "An unexpected error occurred in the main application.") 643 | st.error("An unexpected error occurred. Check logs for details.") 644 | 645 | 646 | if __name__ == "__main__": 647 | set_style() 648 | # Set local fallback error handler for logging 649 | set_error_handler(local_report_error_handler) 650 | main() 651 | --------------------------------------------------------------------------------
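# A minimal launch sketch, assuming a standard Streamlit installation and placeholder paths:
#
#   streamlit run follow_relion_gracefully.py -- --folder /path/to/relion/project
#   streamlit run follow_relion_gracefully.py -- --folder /path/to/relion/project --password mysecret
#
# The "--" separates Streamlit's own options from this script's argparse options
# (-i/--folder for the project directory, -p/--password to protect the instance).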