├── .gitignore
├── LICENSE
├── requirements.txt
├── CITATION.cff
├── environment.yml
├── relion_jobs
│   ├── aligntilt_job.py
│   ├── tomograms_job.py
│   ├── class2d_job.py
│   ├── postprocess_job.py
│   ├── ctffind_job.py
│   ├── picking_job.py
│   ├── mask_job.py
│   ├── localres_job.py
│   ├── excludetilt_job.py
│   ├── polish_job.py
│   ├── import_job.py
│   ├── class3d_job.py
│   └── motioncorr_job.py
├── static
│   └── frg.svg
├── README.md
└── follow_relion_gracefully.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *.egg-info/
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzyla/Follow_Relion_gracefully/HEAD/LICENSE
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy<2.0
2 | pandas
3 | mrcfile
4 | scikit-image
5 | scipy
6 | matplotlib
7 | seaborn
8 | pillow
9 | plotly
10 | streamlit
11 | pyvis
12 | gemmi
13 | stqdm
14 | PyMCubes
15 | altair
16 | morphosamplers
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Zyla"
5 | given-names: "Dawid"
6 | orcid: "https://orcid.org/0000-0001-8471-469X"
7 | title: "Follow_Relion_gracefully"
8 | version: v5
9 | doi: 10.5281/zenodo.10465899
10 | date-released: 2024-01-06
11 | url: "https://github.com/dzyla/Follow_Relion_gracefully"
12 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: FollowRelion
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - python=3.12
7 | - pip
8 | - numpy<2.0
9 | - pandas
10 | - mrcfile
11 | - scikit-image
12 | - scipy
13 | - matplotlib
14 | - seaborn
15 | - pillow
16 | - plotly
17 | - streamlit
18 | - pyvis
19 | - gemmi
20 | - altair
21 | - pip:
22 | - stqdm
23 | - PyMCubes
24 | - morphosamplers
25 |
--------------------------------------------------------------------------------
/relion_jobs/aligntilt_job.py:
--------------------------------------------------------------------------------
1 | # aligntilt_job.py
2 |
3 | import os
4 | import logging
5 | from typing import List
6 |
7 | import streamlit as st
8 | import pandas as pd
9 |
10 | from lib.utils import parse_star, get_values_from_first_key, report_error
11 |
12 | logger = logging.getLogger("main_app")
13 |
14 |
def plot_align_tilt_series(rln_folder: str, node_file: str) -> None:
    """Plot alignment tilt series for a RELION node star file.

    The node star file must contain a ``global`` table with a
    ``_rlnTomoTiltSeriesStarFile`` entry listing per-tilt-series star files
    (paths relative to ``rln_folder``). Each listed file is parsed, the values
    of its first table are collected, and a ``FileSource`` column holding the
    file's basename is added. Files that fail to parse are reported and
    skipped.

    All problems are reported through ``st.warning`` / ``report_error`` and the
    function returns early; it does not raise.

    Args:
        rln_folder (str): Path to the RELION folder containing star files.
        node_file (str): Name of the primary star file for the tilt series,
            relative to ``rln_folder``.
    """
    try:
        logger.info("Plotting alignment tilt series for node: %s", node_file)
        star_path = os.path.join(rln_folder, node_file)

        if not os.path.isfile(star_path):
            logger.error("Star file not found: %s", star_path)
            st.warning(f"Star file not found: {star_path}")
            return

        logger.debug("Parsing star file at: %s", star_path)
        try:
            star = parse_star(star_path)
        except Exception as exc:
            report_error(exc, f"Failed to parse star file: {star_path}")
            st.warning("Failed to parse the main star file.")
            return

        if "global" not in star:
            logger.error("Missing 'global' section in star file: %s", star_path)
            st.warning("The star file does not contain a 'global' section.")
            return

        # ``star["global"]`` may be a pandas DataFrame, in which case ``.get``
        # returns a Series, not a list. A plain ``isinstance(..., list)`` check
        # would reject valid data. Normalise any iterable of paths (or a single
        # path string) to a list of str.
        raw_files = star["global"].get("_rlnTomoTiltSeriesStarFile")
        tomo_files: List[str]
        if raw_files is None:
            tomo_files = []
        elif isinstance(raw_files, str):
            tomo_files = [raw_files]
        else:
            tomo_files = [str(f) for f in raw_files]

        if not tomo_files:
            logger.error("No tilt-series references found in 'global' section.")
            st.warning("No tilt-series data found in the star file.")
            return

        dfs_tilt: List[pd.DataFrame] = []
        for tomo_file in tomo_files:
            path = os.path.join(rln_folder, tomo_file)
            try:
                star_data = parse_star(path)
                df = get_values_from_first_key(star_data)
                df["FileSource"] = os.path.basename(path)
                dfs_tilt.append(df)
            except Exception as exc:
                # Best-effort: one broken tilt-series file must not hide the rest.
                logger.error("Error parsing star file %s: %s", path, exc)
                report_error(exc, f"Error parsing tilt-series star file: {path}")

        if not dfs_tilt:
            logger.error("Parsed tilt-series data is empty.")
            st.warning("No tilt-series data found. Cannot plot tilt angles.")
            return

        logger.debug("Parsed %d tilt-series tables.", len(dfs_tilt))
        # TODO: Implement plotting logic here using dfs_tilt
    except Exception as exc:
        report_error(exc, "Unexpected error in plot_align_tilt_series")
        st.warning("An unexpected error occurred while plotting tilt series.")
71 |
--------------------------------------------------------------------------------
/relion_jobs/tomograms_job.py:
--------------------------------------------------------------------------------
1 | import os
2 | import traceback
3 | import logging
4 |
5 | import streamlit as st
6 | import json
7 |
8 | from lib.image_utils import micrograph_viewer
9 |
10 | from lib.utils import (
11 | parse_star,
12 | get_values_from_first_key,
13 | )
14 |
15 | logger = logging.getLogger("main_app")
16 |
def report_error(exc: Exception, message: str = "") -> None:
    """
    Log ``exc`` with its full traceback on the module logger.

    This helper only logs; it does not display anything in the UI.

    Args:
        exc: The exception to report. Its own ``__traceback__`` is formatted,
            so the output is correct even when this is called outside the
            ``except`` block that caught it (``traceback.format_exc()`` would
            not be).
        message: Optional context prefix for the log line. When empty, the
            generic "An unexpected error occurred" text is used, so existing
            one-argument callers get the same log output as before.
    """
    error_info = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )
    logger.error("%s:\n%s", message or "An unexpected error occurred", error_info)
24 |
25 |
def _load_json_config(config_path: str):
    """Return the parsed JSON content of ``config_path``, or None if it cannot be read."""
    try:
        with open(config_path, encoding="utf-8") as handle:
            return json.load(handle)
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        logger.error("Could not read JSON config %s: %s", config_path, exc)
        return None


def plot_tomographs(rln_folder: str, node_files: str) -> None:
    """
    Display the tomograms listed in a RELION tomography job's star file.

    The star file at ``rln_folder/node_files`` must contain a ``global`` table
    with a ``_rlnTomoTiltSeriesStarFile`` column. At least one of the listed
    tilt-series star files must parse before anything is shown.

    * Denoising jobs (``"Denoise"`` in ``node_files``): denoised tomograms
      (``_rlnTomoReconstructedTomogramDenoised``) are shown if listed.
      Otherwise the job is treated as a training run, and its
      ``external/training/train_data_config.json`` is displayed when present.
    * Other jobs: tomograms from ``_rlnTomoReconstructedTomogramHalf1`` are
      shown. If that column is absent, ``_rlnTomoReconstructedTomogram`` is
      used instead.

    Errors are reported in the UI via ``st.error``; the function does not raise.

    Args:
        rln_folder: RELION project folder; star/MRC paths are relative to it.
        node_files: Path of the job's output star file, relative to ``rln_folder``.
    """
    logger.info(f"Plotting tomographs... rln_folder= {rln_folder}, node_files= {node_files}")

    star_path = os.path.join(rln_folder, node_files)
    logger.debug(f"star_path: {star_path}")

    if not os.path.exists(star_path):
        st.error(f"Star file not found: {star_path}")
        return

    try:
        star = parse_star(star_path)
    except Exception as exc:
        report_error(exc)
        st.error("Failed to parse the main star file.")
        return

    if "global" not in star:
        st.error("The star file does not contain a 'global' section.")
        return

    logger.debug(f"Parsed star file: {star}")
    global_table = star["global"]

    if "_rlnTomoTiltSeriesStarFile" not in global_table:
        st.error("The 'global' section does not list any tilt-series star files.")
        return

    # Parse each tilt-series star file. The tables only serve as a sanity check
    # that tilt-series data exists before any tomogram is displayed.
    dfs_tilt = []
    for tomo_star in global_table["_rlnTomoTiltSeriesStarFile"]:
        path_tomo_star = os.path.join(rln_folder, tomo_star)
        try:
            star_data = parse_star(path_tomo_star)
            if star_data:
                df_tilt = get_values_from_first_key(star_data)
                df_tilt["FileSource"] = os.path.basename(path_tomo_star)
                dfs_tilt.append(df_tilt)
        except Exception as exc:
            logger.error(f"Error parsing star file {path_tomo_star}: {exc}")
            report_error(exc)

    if not dfs_tilt:
        st.error("No tilt-series data found. Cannot plot tilt angles.")
        return

    if "Denoise" in node_files:
        logger.info("Denoised tomograms detected.")

        if "_rlnTomoReconstructedTomogramDenoised" in global_table:
            micrograph_viewer(
                rln_folder=rln_folder,
                image_files=global_table["_rlnTomoReconstructedTomogramDenoised"],
            )
            return

        # No denoised tomograms listed: treat this as a denoise *training* job.
        logger.info("Denoise training job detected")
        config_path = os.path.join(
            os.path.dirname(star_path), "external/training/train_data_config.json"
        )
        if os.path.exists(config_path):
            config_json = _load_json_config(config_path)
            if config_json is not None:
                st.write("**Denoising training job config:**")
                st.json(config_json)
        return

    # Prefer the half-1 tomograms (previous behaviour). Fall back to the full
    # reconstruction when half tomograms were not written.
    for column in ("_rlnTomoReconstructedTomogramHalf1", "_rlnTomoReconstructedTomogram"):
        if column in global_table:
            micrograph_viewer(rln_folder=rln_folder, image_files=global_table[column])
            return

    st.error("No reconstructed tomogram column found in the 'global' section.")
--------------------------------------------------------------------------------
/relion_jobs/class2d_job.py:
--------------------------------------------------------------------------------
1 | # class2d_job.py
2 | """
3 | Module handling 2D classification visualization and class selection.
4 | """
5 |
6 | import os
7 | import glob
8 | import logging
9 | from datetime import datetime
10 | from typing import List
11 |
12 | import numpy as np
13 | import streamlit as st
14 | import matplotlib.pyplot as plt
15 |
16 | # ------------------------------------------------------------------
17 | # Import your existing shared utilities. Adjust imports as needed.
18 | # ------------------------------------------------------------------
19 | from lib.utils import (
20 | interactive_scatter_plot,
21 | get_classes,
22 | )
23 | from relion_jobs.select_job import display_classes # or wherever display_classes is defined
24 |
25 |
26 | logger = logging.getLogger("main_app")
27 |
28 |
def plot_class_distribution(class_dist_: np.ndarray, PLOT_HEIGHT=500):
    """
    Plot the iteration-by-iteration distribution of 2D classes as stacked areas.

    Args:
        class_dist_: Array of shape (num_classes, num_iterations). Each row is
            one class; values are multiplied by 100 for display as percent.
        PLOT_HEIGHT: Figure height in pixels.
    """
    import plotly.graph_objects as go

    # The iteration axis is identical for every class, so build it once.
    iterations = np.arange(0, class_dist_.shape[1])

    fig = go.Figure()
    for n, class_ in enumerate(class_dist_):
        fig.add_trace(
            go.Scatter(
                x=iterations,
                y=class_.astype(float) * 100,
                name=f"Class {n + 1}",
                showlegend=True,
                # Plotly hover templates are HTML: <br> breaks the line and
                # <extra></extra> hides the secondary trace-name box (the class
                # name is already part of the template).
                hovertemplate=(
                    f"Class {n + 1}<br>Iteration: %{{x}}"
                    "<br>Cls dist: %{y:.2f}%"
                    "<extra></extra>"
                ),
                mode="lines",
                stackgroup="one",
            )
        )

    fig.update_xaxes(title_text="Iteration")
    fig.update_yaxes(title_text="Class distribution (%)")
    fig.update_layout(title="Class distribution over iterations")
    fig.update_layout(hovermode="x unified", height=PLOT_HEIGHT)

    st.plotly_chart(fig, use_container_width=True, height=PLOT_HEIGHT)
62 |
63 |
def plot_class2d(rln_folder: str, nodes: List[str]) -> None:
    """
    Show 2D classes for a given job, let the user select classes, and
    optionally plot the per-iteration class distribution or an interactive
    scatter plot.

    Parameters:
        rln_folder: Root folder for the job data (RELION project folder).
        nodes: Relevant star/mrc/etc. files produced by the job, relative to
            ``rln_folder``. The directory of ``nodes[0]`` is used as the job
            folder.
    """
    logger.debug(f"{datetime.now()}: plot_class2d started with {nodes}")
    st.subheader("2D Classification Job")

    if not nodes:
        st.write("No output nodes found for this job.")
        return

    # 1) The job folder is the directory of the first (project-relative) node.
    job_basefolder = os.path.join(rln_folder, os.path.dirname(nodes[0]))
    logger.debug(f"Job base folder: {job_basefolder}")

    # 2) Gather model.star files, oldest first, so the last one is the newest.
    model_files = sorted(
        glob.glob(os.path.join(job_basefolder, "*model.star")),
        key=os.path.getmtime,
    )
    if not model_files:
        st.write("No model files found.")
        return

    # 3) Class paths and distribution from the most recent model only.
    class_paths, _, _, class_dist, _, _, _ = get_classes(
        job_basefolder, [model_files[-1]]
    )
    class_dist = np.squeeze(class_dist)
    logger.debug(f"Class paths: {class_paths}")

    if not len(class_paths):
        st.write("No classes found in the last model.")
        return

    # Class display lives in an expander so it is collapsible.
    with st.expander("Class Averages and Selections", expanded=True):
        col_sort, _ = st.columns([1, 10])
        sort_classes = col_sort.checkbox("Sort classes?", True)

        for class_path in class_paths:
            logger.debug(f"2D Class path: {class_path}")
            # NOTE(review): raw_data_star is project-relative while rln_folder
            # is the job folder; confirm display_classes resolves this pair as
            # intended. Arguments kept as-is for backward compatibility.
            display_classes(
                class_path=class_path,
                class_distribution=class_dist,
                sort_by_distribution=sort_classes,
                raw_data_star=nodes[0],
                rln_folder=job_basefolder,
            )

    # Distribution across all iterations (every model file).
    with st.expander("Distribution Over Iterations", expanded=False):
        _, _, _, class_dist_all, _, _, _ = get_classes(job_basefolder, model_files)
        # asarray so `.size` works whether get_classes returns a list or ndarray.
        class_dist_all = np.asarray(class_dist_all)
        if class_dist_all.size > 0:
            plot_class_distribution(class_dist_all, PLOT_HEIGHT=500)
        else:
            st.write("No iteration-based distribution data available.")

    if st.checkbox("Show detailed statistics?"):
        interactive_scatter_plot(os.path.join(rln_folder, nodes[0]))
        st.info("Displaying interactive scatter plot for the main STAR file.")

    plt.close()
    logger.info(f"{datetime.now()}: plot_class2d done.")
137 |
--------------------------------------------------------------------------------
/static/frg.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/relion_jobs/postprocess_job.py:
--------------------------------------------------------------------------------
1 | #postprocess_job.py
2 |
3 | import os
4 | import logging
5 | import math
6 | from datetime import datetime
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import streamlit as st
11 | import plotly.graph_objects as go
12 | import plotly.figure_factory as ff
13 | from skimage.transform import resize
14 | import mrcfile
15 | import seaborn as sns
16 |
17 |
18 | # Import your shared utilities.
19 | from lib.utils import parse_star, interactive_scatter_plot
20 | from lib.image_utils import normalize, display_volume_slices, plot_volume
21 | from relion_jobs.select_job import display_classes
22 |
23 | # Setup logger
24 | logger = logging.getLogger("main_app")
25 |
26 |
def _load_postprocess_data(rln_folder: str, postprocess_star_path: str) -> dict:
    """
    Load postprocess.star via ``parse_star`` and cache it in ``st.session_state``.

    The cached copy is reused only while both the path and the file's
    modification time match the cached entry. A postprocess job that is re-run
    in place (same path, new contents) is therefore re-read instead of showing
    stale curves.

    Args:
        rln_folder: Project folder. Not used; kept for interface compatibility.
        postprocess_star_path: Full path to the postprocess.star file.

    Returns:
        The parsed star data, or an empty dict if parsing fails (the error is
        logged and nothing is cached).
    """
    cache_key = "postprocess_data"
    current_file_key = "current_postprocess_file"
    current_mtime_key = "current_postprocess_mtime"

    try:
        mtime = os.path.getmtime(postprocess_star_path)
    except OSError:
        mtime = None

    if (
        current_file_key in st.session_state
        and st.session_state[current_file_key] == postprocess_star_path
        and st.session_state.get(current_mtime_key) == mtime
        and cache_key in st.session_state
    ):
        logger.debug("Using cached postprocess star data.")
        return st.session_state[cache_key]

    try:
        logger.debug(f"Loading postprocess star file: {postprocess_star_path}")
        postprocess_data = parse_star(postprocess_star_path)
    except Exception as e:
        logger.error(f"Error loading postprocess star file: {e}")
        return {}

    st.session_state[cache_key] = postprocess_data
    st.session_state[current_file_key] = postprocess_star_path
    st.session_state[current_mtime_key] = mtime
    return postprocess_data
51 |
52 |
def _find_postprocess_star(nodes) -> str:
    """Return the node named 'postprocess.star', else nodes[3] if it exists, else ''."""
    for node in nodes:
        if os.path.basename(node) == "postprocess.star":
            return node
    # Fall back to the historical assumption that it is the 4th node.
    return nodes[3] if len(nodes) > 3 else ""


def plot_postprocess(rln_folder, nodes):
    """
    Main function to plot postprocess data in Streamlit.

    This function:
    1) Loads (or retrieves from session state) the postprocess.star file.
    2) Plots the FSC and Guinier curves side by side.
    3) Displays masked volume slices.

    Args:
        rln_folder: RELION project folder; node paths are relative to it.
        nodes: Output node paths of the postprocess job. The node whose
            basename is ``postprocess.star`` is used, falling back to
            ``nodes[3]``.
    """
    logger.debug(f"{datetime.now()}: plot_postprocess started for {rln_folder} with {nodes}")

    postprocess_node = _find_postprocess_star(nodes)
    if not postprocess_node:
        st.write("No postprocess.star file found")
        return

    postprocess_star_path = os.path.join(rln_folder, postprocess_node)
    if not os.path.exists(postprocess_star_path):
        st.write("No postprocess.star file found")
        return

    # Load and cache postprocess data.
    postprocess_data = _load_postprocess_data(rln_folder, postprocess_star_path)
    if not postprocess_data:
        st.write("Failed to load postprocess.star data")
        return

    # All plotting is done with Plotly.
    try:
        fsc_data = postprocess_data["fsc"].astype(float)
        guinier_data = postprocess_data["guinier"].astype(float)
    except Exception as e:
        st.write("Error processing postprocess.star data")
        logger.error(f"Error converting FSC/Guinier data to float: {e}")
        return

    st.write("## Postprocess Data Visualization")
    # FSC and Guinier plots side by side.
    col_fsc, col_guinier = st.columns(2)
    with col_fsc:
        st.markdown("### Fourier Shell Correlation (FSC) Curve")
        plot_fsc_curve(fsc_data)
    with col_guinier:
        st.markdown("### Guinier Plot")
        plot_guinier_curve(guinier_data)

    # Masked volume slices.
    with st.expander("Volume preview"):
        display_volume_slices(rln_folder, nodes)

    logger.info(f"{datetime.now()}: plot_postprocess done")
99 |
100 |
def _fsc_resolution_at(resolution, fsc, threshold):
    """Return the resolution of the last shell before FSC first drops below ``threshold``."""
    resolution = np.asarray(resolution, dtype=float)
    fsc = np.asarray(fsc, dtype=float)
    if resolution.size == 0:
        return float("nan")
    below = np.flatnonzero(fsc < threshold)
    if below.size == 0:
        # Never crosses: the curve stays above threshold up to the last shell.
        return float(resolution[-1])
    return float(resolution[max(below[0] - 1, 0)])


def plot_fsc_curve(fsc_data, dpi=150):
    """
    Plot the Fourier Shell Correlation (FSC) curves with Plotly.

    Each FSC column present in ``fsc_data`` is drawn against 1/resolution,
    coloured from Seaborn's 'deep' palette. Missing columns are skipped.
    Resolutions at FSC=0.143 and FSC=0.5 are reported from the corrected FSC
    as the last shell before the curve first falls below each threshold. A
    "closest value anywhere" search can latch onto noise past the crossing.

    Args:
        fsc_data: Table (e.g. a DataFrame) with an ``_rlnAngstromResolution``
            column and one or more FSC columns.
        dpi: Unused; kept for backward compatibility.
    """
    # Plain arrays so later indexing is positional, independent of the index.
    fsc_x = np.asarray(fsc_data["_rlnAngstromResolution"], dtype=float)
    fsc_x_min = np.min(fsc_x)
    reciprocal_x = 1 / fsc_x

    fsc_to_plot = [
        "_rlnFourierShellCorrelationCorrected",
        "_rlnFourierShellCorrelationUnmaskedMaps",
        "_rlnFourierShellCorrelationMaskedMaps",
        "_rlnCorrectedFourierShellCorrelationPhaseRandomizedMaskedMaps",
    ]

    palette = sns.color_palette("deep").as_hex()

    fig = go.Figure()
    for i, meta in enumerate(fsc_to_plot):
        if meta not in fsc_data:
            logger.debug("FSC column %s not present; skipping.", meta)
            continue
        color = palette[i % len(palette)]
        fig.add_trace(
            go.Scatter(
                x=reciprocal_x,
                y=np.asarray(fsc_data[meta], dtype=float),
                mode="lines",
                line=dict(color=color, width=3),
                customdata=fsc_x,
                name=meta.replace("_rlnFourierShellCorrelation", "").replace("_rlnCorrectedFourierShellCorrelation", ""),
                hovertemplate="Resolution: %{customdata:.2f} Å<br>FSC: %{y:.2f}",
            )
        )

    # Horizontal threshold lines.
    fig.add_hline(y=0.143, line_dash="dash", line_color="black", annotation_text="0.143")
    fig.add_hline(y=0.5, line_dash="dash", line_color="black", annotation_text="0.5")

    # The x axis is in 1/Å; the tick labels show the corresponding Å values.
    start_res = 1 / 50
    end_res = 1 / fsc_x_min
    custom_ticks = np.linspace(start_res, end_res, num=10)
    fig.update_layout(
        xaxis=dict(
            title="Resolution, Å",
            tickvals=custom_ticks,
            ticktext=[f"{round(1/res, 2)}" for res in custom_ticks],
            range=[start_res, end_res],
        ),
        yaxis=dict(title="FSC", range=[-0.05, 1.05]),
        title="Fourier Shell Correlation Curve",
        legend=dict(x=1, y=1, xanchor="left", yanchor="top", bgcolor="rgba(255,255,255,0)"),
        margin=dict(l=20, r=20, t=40, b=20),
    )
    st.plotly_chart(fig, use_container_width=True)

    if "_rlnFourierShellCorrelationCorrected" not in fsc_data:
        return
    fsc_corrected = np.asarray(fsc_data["_rlnFourierShellCorrelationCorrected"], dtype=float)
    res_143 = _fsc_resolution_at(fsc_x, fsc_corrected, 0.143)
    res_05 = _fsc_resolution_at(fsc_x, fsc_corrected, 0.5)
    st.markdown(f"Reported resolution @ FSC=0.143: **{round(res_143, 2)} Å**")
    st.markdown(f"Reported resolution @ FSC=0.5: **{round(res_05, 2)} Å**")
166 |
def plot_guinier_curve(guinier_data):
    """
    Plot the Guinier curve (log amplitudes vs. squared resolution) with Plotly.

    Each ``_rlnLogAmplitudes*`` column present in ``guinier_data`` becomes one
    line. Absent or non-numeric columns are skipped and logged, not silently
    ignored. Values equal to -99 are treated as missing and drawn as gaps. The
    first row is not plotted (presumably the zero-frequency shell).

    Args:
        guinier_data: Table (e.g. a DataFrame) with ``_rlnResolutionSquared``
            and ``_rlnLogAmplitudes*`` columns.
    """
    guinier_x = np.asarray(guinier_data["_rlnResolutionSquared"], dtype=float)
    guinier_to_plot = [
        "_rlnLogAmplitudesOriginal",
        "_rlnLogAmplitudesMTFCorrected",
        "_rlnLogAmplitudesWeighted",
        "_rlnLogAmplitudesSharpened",
        "_rlnLogAmplitudesIntercept",
    ]

    palette = sns.color_palette("deep").as_hex()

    fig = go.Figure()
    for i, meta in enumerate(guinier_to_plot):
        if meta not in guinier_data:
            logger.debug("Guinier column %s not present; skipping.", meta)
            continue
        try:
            y_data = np.asarray(guinier_data[meta], dtype=float)
        except (TypeError, ValueError) as exc:
            logger.warning("Guinier column %s is not numeric; skipping: %s", meta, exc)
            continue
        # -99 marks invalid amplitudes; NaN makes Plotly leave a gap.
        y_data = np.where(y_data == -99, np.nan, y_data)
        color = palette[i % len(palette)]
        fig.add_trace(
            go.Scatter(
                x=guinier_x[1:],
                y=y_data[1:],
                mode="lines",
                line=dict(color=color, width=3),
                name=meta.replace("_rlnLogAmplitudes", ""),
            )
        )

    fig.update_layout(
        title="Guinier Plot",
        xaxis_title="Resolution Squared, [1/Å²]",
        yaxis_title="Ln(Amplitudes)",
        margin=dict(l=20, r=20, t=40, b=20),
    )
    st.plotly_chart(fig, use_container_width=True)
208 |
209 |
210 |
211 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Follow Relion Gracefully :microscope::rocket::globe_with_meridians:
2 | ---
3 | **v6: A complete dashboard for easy interaction with your cryo-EM data in Relion, now with ~~partial~~ full :sparkles: `#teamtomo` :sparkles: support!**
4 |
5 | * **Data sourced from [Relion5 tutorial](https://relion.readthedocs.io/en/latest/SPA_tutorial/index.html), [Relion4 STA](https://relion.readthedocs.io/en/release-4.0/STA_tutorial/index.html), and [Relion5 STA](https://zenodo.org/records/11068319)**
6 | * **Licensed under Non-Profit Open Software License 3.0 (NPOSL-3.0)**
7 |
8 | **Micrograph viewer**
9 | 
10 |
11 | **Picking statistics preview**
12 | 
13 |
14 | **Mask with original volume preview**
15 | 
16 |
17 | **Local resolution directly in the browser**
18 | 
19 |
20 | **Tomography specific job preview:**
21 | 
22 |
23 | **Tomogram viewer**
24 | 
25 |
26 | **3D picking preview with annotations**
27 | 
28 |
29 | **3D picking with particles**
30 | 
31 |
32 |
33 |
34 | #### :sparkles: Found this helpful in your research? Cite my work! :sparkles:
35 |
36 | [](https://doi.org/10.5281/zenodo.10465899)
37 |
38 |
39 | #### Dawid Zyla. (2024). dzyla/Follow_Relion_gracefully: v5 (Version v5). Zenodo. https://doi.org/10.5281/zenodo.10465899
40 |
41 |
42 |
43 | ---
44 |
45 | ## Description :microscope:
46 |
47 | #### v6: :high_brightness:
48 | Version 6 introduces a complete dashboard for easy interaction with your cryo-EM data in Relion, now with full `#teamtomo` support! It allows users to visualize and analyze their data in real-time, providing a comprehensive overview of their projects. The new version also includes improved job previews, enhanced data visualization, and the ability to download volumes directly from the dashboard.
49 |
50 | #### v5:
51 | Version 5 improves the job preview by adopting a dynamic approach. Using [Streamlit](https://streamlit.io/), it allows users to interact directly with their data. The underlying Python framework facilitates real-time computation of statistics and data from most jobs, enabling users to engage with metadata and select preferred statistics for download and further analysis.
52 |
53 | #### v4:
54 | Version 4 introduced support for multiple projects and job visualization through an online interface using the Hugo framework. While this static job generator enabled job display with example data, it lacked interactive capabilities due to its static nature.
55 |
56 | ## v6 features :dizzy:
57 | * All the features of v5 but with improvements
58 | * Full support for `#teamtomo` jobs
59 | * Better job previews, including volume previews, statistics, and download options
60 | * Improved data visualization and publication-ready statistics
61 | * Ability to download volumes directly from the dashboard
62 | * Job flow chart overview, showing relationships between jobs
63 | * Improved speed (but increased RAM usage)
64 | * Included file browser for easier project switching
65 | * Plenty of QOL improvements
66 | * `#OpenSoftwareAcceleratesScience`
67 |
68 |
69 |
70 |
71 | ## Installation :rocket:
72 |
73 | Minor changes from `v5`, with a few new libraries added. Tested on `Windows 10/11`, `WSL2`, and `Ubuntu 22.04`.
74 |
75 | :heavy_exclamation_mark:
76 | ### If you are using a previously created environment, I strongly recommend creating a new environment with freshly installed dependencies, as some new libraries were added. Follow the steps below to set up a new environment.
77 | :heavy_exclamation_mark:
78 |
79 | ### Install Dependencies :snake:
80 |
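81 | Python 3.12 is required. The recommended setup is a conda environment (below); a virtual environment created with UV (see **UV Instructions**) also works. Other virtual-environment setups are no longer officially supported, though they might still work.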
82 |
83 | #### Conda Instructions
84 |
85 | 1. Install miniconda3 (*no root access required*, only if not installed already):
86 |
87 | ```bash
88 | wget -q -P . https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
89 |
90 | bash ./Miniconda3-latest-Linux-x86_64.sh -b -f
91 | ```
92 |
93 | Activate conda for bash:
94 |
95 | ```bash
96 | conda init bash
97 | ```
98 |
99 | Restart the shell or type `bash` to see the (base) prompt:
100 |
101 | ```bash
102 | (base) dzyla@GPU0
103 | ```
104 |
105 | 2. Clone the GitHub repository and navigate to the folder:
106 |
107 | ```bash
108 | git clone https://github.com/dzyla/Follow_Relion_gracefully.git
109 |
110 | cd Follow_Relion_gracefully
111 | ```
112 |
113 | 3. Create a conda environment and install dependencies using the `environment.yml` file:
114 |
115 | ```bash
116 | conda env create --file environment.yml
117 |
118 | conda activate FollowRelion
119 | ```
120 |
121 | You should now see:
122 |
123 | ```bash
124 | (FollowRelion) dzyla@GPU0
125 | ```
126 |
127 | #### UV Instructions
128 | UV is a package manager for Python that allows you to install and manage Python packages easily. It works like pip but is much faster. To install UV, follow these steps:
129 |
130 | 1. Install UV using pip:
131 |
132 | ```bash
133 | pip install uv
134 | ```
135 |
136 | or if you don't have python/pip/conda installed, use the following command:
137 |
138 | ```bash
139 | # with curl
140 | curl -LsSf https://astral.sh/uv/install.sh | sh
141 |
142 | # alternatively, with wget
143 | wget -qO- https://astral.sh/uv/install.sh | sh
144 |
145 | # or on Windows
146 | powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
147 | ```
148 |
149 | 2. Clone the GitHub repository and navigate to the folder:
150 |
151 | ```bash
152 | git clone https://github.com/dzyla/Follow_Relion_gracefully.git
153 |
154 | cd Follow_Relion_gracefully
155 | ```
156 |
157 | 3. Create a new virtual environment using UV:
158 |
159 | ```bash
160 | uv venv FollowRelion --python 3.12
161 | ```
162 | 4. Activate the virtual environment:
163 |
164 | ```bash
165 | source FollowRelion/bin/activate
166 | ```
167 |
168 |
169 | 5. Install the required packages using the `requirements.txt` file:
170 |
171 | ```bash
172 | uv pip install -r requirements.txt
173 | ```
174 | 6. After the installation is complete, you should see a message indicating that the packages have been installed successfully.
175 |
176 |
177 |
178 | :sparkles:**Ready to start!** :sparkles:
179 |
180 | ## Usage :computer:
181 |
182 | ```text
183 | streamlit run follow_relion_gracefully.py
184 | ```
185 |
186 | Additional command line parameters for extra features:
187 |
188 | ```
189 | -h, --help Show this help message and exit.
190 | -i I, --folder I Path to the default folder.
191 | -p P, --password P Password for securing your instance.
192 | ```
193 |
194 | ##### Example for live server updates and setting up a new project:
195 |
196 | To use command line parameters with streamlit, add `--` before the parameters:
197 |
198 | ```
199 | conda activate FollowRelion
200 |
201 | streamlit run follow_relion_gracefully.py -- -p 'MyPassword$221#&' -i /mnt/staging/240105_NewProcessing
202 | ```
203 |
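204 | This sets a password and the default processing folder. Wrap the password in single quotes if it contains shell special characters such as `$`, `#`, or `&`; otherwise the shell will expand or split it.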
205 |
206 | ## Accessing the Dashboard via Browser :chart_with_upwards_trend:
207 |
208 | The dashboard should open automatically in your browser. For remote workstations, access it using the provided network URL, ensuring the port is not firewall-blocked.
209 |
210 | Remote access example:
211 |
212 | ```
213 | (FollowRelion) dzyla@PC-HOME:~/Follow_Relion_gracefully$ streamlit run follow_relion_gracefully.py --server.port 8501 -- -p 1234 -i /mnt/f/linux/Tutorial5.0/
214 |
215 | Local URL: http://localhost:8501
216 | Network URL: http://172.21.222.176:8501
217 | ```
218 |
219 | Open the network URL in your browser to access the dashboard.
220 |
221 | For firewall issues, create an SSH tunnel:
222 |
223 | ```bash
224 | ssh -f username@workstation -L 8501:localhost:8501 -N
225 | ```
226 |
227 | This allows remote dashboard access on your local computer: http://localhost:8501.
228 |
229 | ## Troubleshooting :wrench:
230 |
231 | * As with previous versions, the code is a hobby project and may not work perfectly. Please report any issues on GitHub.
232 | * Large volumes (500px+) load slowly, especially for 3D classifications with many classes. Plotly handles plotting well, but very large data can slow down the browser.
233 | * Ensure the correct environment is activated (`FollowRelion`). Deactivate others with `conda deactivate`.
234 | * Jobs run manually may not be processed, as the script reads from `default_pipeline.star`.
235 | * In some cases Streamlit shows warnings about the session state. These can be ignored.
236 | * Rendering issues in the browser can often be resolved by refreshing (`F5`).
237 | * Mac support is untested, but it's assumed to work similarly to Linux. Please report any issues!
238 | * Please note that this code was developed by a Python enthusiast, not a professional developer. It has been tested under standard scenarios to ensure reliability. However, as the author, I cannot be held responsible for any issues or damages that may arise from its use. Users are encouraged to review and test the code thoroughly before implementation in their projects.
239 |
240 |
241 | ## To-do :memo:
242 |
243 | * Add support for `DynaMight`
244 | * Speed up processing for some jobs
245 | * Add cryoSPARC metadata support and export to Relion
246 |
247 |
248 |
249 |
250 | ## Questions/suggestions?:email:
251 |
252 | Dawid Zyla, La Jolla Institute for Immunology
253 |
254 | [Twitter](https://twitter.com/DawidZyla)
255 |
256 | [dzyla@lji.org](mailto:dzyla@lji.org)
257 |
--------------------------------------------------------------------------------
/relion_jobs/ctffind_job.py:
--------------------------------------------------------------------------------
1 | #ctffind_job.py
2 |
3 | import os
4 | import logging
5 | from datetime import datetime
6 |
7 | import streamlit as st
8 | import pandas as pd
9 | import numpy as np
10 | import altair as alt
11 | import mrcfile
12 |
13 | from lib.utils import (
14 | parse_star,
15 | get_values_from_first_key,
16 | report_error,
17 | interactive_scatter_plot,
18 | )
19 | from lib.image_utils import clip, normalize
20 |
21 | logger = logging.getLogger("main_app")
22 |
23 | # =============================================================================
24 | # Cached helpers
25 | # =============================================================================
@st.cache_data(show_spinner=False)
def read_ctf_average_data(ctf_avrot_path: str) -> pd.DataFrame:
    """
    Load a rotational-average CTF text file into long ("tidy") format.

    Rows 0-4, 6 and 10 of the file are skipped. The remaining
    whitespace-separated data is read into the columns Spatial_freq, 1D_Ave,
    Fit and Fit_CC, and a Resolution column (1 / Spatial_freq) is added. The
    three value columns are then melted into (Type, CTF) pairs keyed by
    Spatial_freq and Resolution. The file is presumably CTFFIND's avrot
    output; confirm the row layout if the producer changes.

    Raises:
        FileNotFoundError: if ``ctf_avrot_path`` does not exist.
    """
    if not os.path.exists(ctf_avrot_path):
        raise FileNotFoundError(ctf_avrot_path)

    skipped_rows = [0, 1, 2, 3, 4, 6, 10]
    column_names = ["Spatial_freq", "1D_Ave", "Fit", "Fit_CC"]
    wide = pd.read_csv(
        ctf_avrot_path,
        skiprows=skipped_rows,
        sep=r"\s+",
        names=column_names,
        engine="python",
    )
    wide = wide.assign(Resolution=1.0 / wide["Spatial_freq"])

    long_form = wide.melt(
        id_vars=["Spatial_freq", "Resolution"],
        value_vars=["1D_Ave", "Fit", "Fit_CC"],
        var_name="Type",
        value_name="CTF",
    )
    return long_form
44 |
45 |
@st.cache_data(show_spinner=False)
def load_and_normalize_mrc(mrc_path: str) -> np.ndarray:
    """
    Return the image stored in an MRC file, percentile-clipped and normalised.

    Singleton dimensions are squeezed away, values are clipped with
    ``clip(data, 1, 99)`` and the result is passed through ``normalize``.

    Parameters:
        mrc_path (str): Path to the MRC file.

    Returns:
        np.ndarray: The processed image.

    Raises:
        FileNotFoundError: If ``mrc_path`` does not exist.
    """
    if os.path.exists(mrc_path):
        with mrcfile.mmap(mrc_path, permissive=True) as handle:
            image = np.squeeze(handle.data)
            return normalize(clip(image, 1, 99))
    raise FileNotFoundError(mrc_path)
53 |
54 |
@st.cache_data(show_spinner=True)
def load_star_data(folder: str, star_file: str) -> pd.DataFrame:
    """
    Return a per-image statistics table for a CTF job's STAR file.

    * Single-particle STAR files: the ``micrographs`` table is returned as is.
    * Tomography STAR files: every path listed in the ``global`` table's
      ``_rlnTomoTiltSeriesStarFile`` column is parsed (relative to
      ``folder``). The first table of each file is kept, and all of them are
      concatenated with a fresh index. Files that fail to parse are reported
      and skipped.

    Parameters:
        folder (str): Project folder; all STAR paths are relative to it.
        star_file (str): STAR file to read.

    Returns:
        pd.DataFrame: One row per micrograph / tilt image.

    Raises:
        ValueError: If neither layout yields a usable table.
    """
    star = parse_star(os.path.join(folder, star_file))

    if "micrographs" in star:
        return star["micrographs"]

    is_tomo = "global" in star and "_rlnTomoTiltSeriesStarFile" in star["global"].columns
    if is_tomo:
        tables: list[pd.DataFrame] = []
        for rel_path in star["global"]["_rlnTomoTiltSeriesStarFile"].astype(str):
            series_path = os.path.join(folder, rel_path)
            try:
                first_table = get_values_from_first_key(parse_star(series_path))
                if isinstance(first_table, pd.DataFrame):
                    tables.append(first_table)
            except Exception as exc:
                report_error(exc, f"Parsing tilt-series STAR {series_path}")
        if tables:
            return pd.concat(tables, ignore_index=True)

    raise ValueError("Could not find a usable table (micrographs / global).")
92 |
93 | # =============================================================================
94 | # Plot functions
95 | # =============================================================================
def plot_CTF_average(FOLDER: str, ctf_file_path: str) -> None:
    """
    Plot the 1D CTF fit of one micrograph next to its 2D CTF image using Altair.

    Parameters:
        FOLDER (str): Project folder that ``ctf_file_path`` is relative to.
        ctf_file_path (str): CTF image entry of the micrograph, e.g.
            ``.../name.ctf:mrc``. ``.../name_avrot.txt`` (1D data) and
            ``.../name.ctf`` (2D image) are derived from it.

    Returns:
        None. Problems are reported in the UI and logged.
    """
    col1, col2 = st.columns([3, 1])
    try:
        power_spectrum_ave_rot_txt_paths = os.path.join(FOLDER, ctf_file_path.replace(".ctf:mrc", "_avrot.txt"))
        if not os.path.exists(power_spectrum_ave_rot_txt_paths):
            st.error("File not found: " + power_spectrum_ave_rot_txt_paths)
            return

        # Each quantity is stored as one row; transpose to get one column each.
        ave_rot = pd.read_csv(
            power_spectrum_ave_rot_txt_paths,
            skiprows=[0, 1, 2, 3, 4, 6, 10],
            header=None,
            sep=r'\s+'
        ).transpose()
        ave_rot.columns = ["Spatial_freq", "1D_Ave", "Fit", "Fit_CC"]

        ave_rot["Resolution"] = 1 / ave_rot["Spatial_freq"]
        ave_rot_melt = ave_rot.melt(
            id_vars=["Spatial_freq", "Resolution"],
            value_vars=["1D_Ave", "Fit", "Fit_CC"],
            var_name="Type",
            value_name="CTF"
        )

        # Use three distinct colors for these three lines
        chart_ctf = alt.Chart(ave_rot_melt).mark_line().encode(
            x=alt.X("Spatial_freq:Q", title="Spatial Frequency (1/Å)"),
            y=alt.Y("CTF:Q", title="CTF"),
            color=alt.Color("Type:N", scale=alt.Scale(range=["#FA8072", "#6FC381", "#6495ED"])),
            tooltip=["Spatial_freq", "Resolution", "CTF", "Type"]
        ).properties(
            title="CTF Fit per Micrograph",
            width=600,
            height=400
        )

        col1.altair_chart(chart_ctf, use_container_width=True)

        try:
            mrc_file = os.path.join(FOLDER, ctf_file_path.replace("_avrot.txt", ".ctf").replace(":mrc", ""))
            logger.debug(f"CTF file: {mrc_file}")

            # Close the memory map deterministically: this function runs on
            # every Streamlit rerun, and the original never closed the file.
            # The data is copied out before the file is closed.
            with mrcfile.mmap(mrc_file, permissive=True) as mrc:
                mrc_data = np.squeeze(np.array(mrc.data, copy=True))
            mrc_image = normalize(clip(mrc_data, 1, 99))
            col2.image(mrc_image, caption=f"CTF Image: {mrc_file}")
        except Exception as exc:
            report_error(exc)
            col2.error("Error loading CTF image.")

    except Exception as exc:
        report_error(exc)
        logger.error("Error in plot_CTF_average function.")
158 |
def _coerce_numeric(column: pd.Series) -> pd.Series:
    """
    Convert *column* to a numeric dtype, or return it unchanged if that fails.

    Gives the same result as ``pd.to_numeric(column, errors="ignore")``, whose
    ``errors="ignore"`` option is deprecated in recent pandas releases.
    """
    try:
        return pd.to_numeric(column)
    except (ValueError, TypeError):
        return column


def _plot_series_with_density(df: pd.DataFrame, column: str, title: str, color: str) -> None:
    """Draw *column* against ``Index`` as a line chart, with a horizontal density plot beside it."""
    c1, c2 = st.columns([4, 1])
    line = (
        alt.Chart(df)
        .mark_line(color=color)
        .encode(x="Index:Q", y=f"{column}:Q", tooltip=["Index", f"{column}:Q"])
        .properties(title=title, height=300)
    )
    kde = (
        alt.Chart(df)
        .transform_density(density=column, as_=[column, "density"])
        .mark_area(orient="horizontal")
        .encode(y=f"{column}:Q", x="density:Q")
        .properties(height=300)
    )
    c1.altair_chart(line, use_container_width=True)
    c2.altair_chart(kde, use_container_width=True)


def plot_ctf_stats(folder: str, star_file: str) -> None:
    """
    Display per-image CTF statistics for a CTF estimation job.

    The view shows:
    * Defocus V and max-resolution traces, each with a density plot.
    * A 2-D histogram of any two varying numeric columns.
    * The 1-D CTF fit of a selected image, when ``_rlnCtfImage`` is present.
    * An optional interactive scatter plot.

    Parameters:
        folder (str): Project folder; ``star_file`` is relative to it.
        star_file (str): STAR file of the job, either single-particle
            (``micrographs`` table) or tomography (``global`` table).

    Returns:
        None. Load failures are reported in the UI.
    """
    # ── data load (tomography-aware) ───────────────────────────────────────
    try:
        star_df = load_star_data(folder, star_file)
    except Exception as exc:
        report_error(exc)
        st.error("Failed to load STAR data.")
        return

    star_df = star_df.apply(_coerce_numeric)
    star_df["Index"] = np.arange(len(star_df))

    cols = {
        "defocus_v": "_rlnDefocusV",
        "max_res": "_rlnCtfMaxResolution",
    }
    numeric_cols = {
        label: col
        for label, col in cols.items()
        if col in star_df.columns and pd.api.types.is_numeric_dtype(star_df[col])
    }
    if not numeric_cols:
        st.error("No numeric CTF columns found.")
        return

    st.subheader("CTF Statistics per Image")

    if "defocus_v" in numeric_cols:
        _plot_series_with_density(star_df, numeric_cols["defocus_v"], "Defocus V", "#6FC381")
    if "max_res" in numeric_cols:
        _plot_series_with_density(star_df, numeric_cols["max_res"], "Max Resolution", "#6495ED")

    # 2-D histogram
    st.subheader("2-D Parameter Distribution")
    numeric_candidates = [
        c for c in star_df.columns if pd.api.types.is_numeric_dtype(star_df[c]) and star_df[c].nunique() > 1
    ]
    if len(numeric_candidates) >= 2:
        x_col = st.selectbox("X-axis", numeric_candidates, key="2dh_x")
        y_col = st.selectbox(
            "Y-axis", numeric_candidates, key="2dh_y", index=1 if x_col == numeric_candidates[0] else 0
        )
        if x_col != y_col:
            chart2d = (
                alt.Chart(star_df)
                .mark_rect()
                .encode(
                    x=alt.X(f"{x_col}:Q", bin=alt.Bin(maxbins=50)),
                    y=alt.Y(f"{y_col}:Q", bin=alt.Bin(maxbins=50)),
                    color=alt.Color("count()", scale=alt.Scale(scheme="viridis")),
                )
                .properties(title=f"2-D Histogram: {x_col} vs {y_col}")
            )
            st.altair_chart(chart2d, use_container_width=True)
        else:
            st.warning("Select two different columns.")
    else:
        st.info("Not enough numeric columns for a 2-D histogram.")

    # 1-D CTF fit per image (only if column exists)
    if "_rlnCtfImage" in star_df.columns and star_df["_rlnCtfImage"].notna().any():
        st.subheader("1-D CTF Fit")
        img_series = star_df["_rlnCtfImage"].astype(str)
        # st.slider rejects min_value == max_value, so a single image is shown
        # directly instead of through a slider.
        if len(img_series) > 1:
            idx = st.slider("Image index", 0, len(img_series) - 1, 0)
        else:
            idx = 0
        plot_CTF_average(folder, img_series.iloc[idx])

    # optional interactive scatter
    if st.checkbox("Detailed scatter plot?", key="scatter"):
        interactive_scatter_plot(os.path.join(folder, star_file))

    logger.info(f"{datetime.now()}: plot_ctf_stats finished")
268 |
--------------------------------------------------------------------------------
/relion_jobs/picking_job.py:
--------------------------------------------------------------------------------
1 | # picking_job.py
2 |
3 | import os
4 | import glob
5 | import tempfile
6 | import logging
7 | from datetime import datetime
8 | from typing import List, Tuple, Union
9 |
10 | import streamlit as st
11 | import pandas as pd
12 | import numpy as np
13 | import plotly.express as px
14 | import plotly.graph_objects as go
15 |
16 | from lib.utils import parse_star, star_from_df, interactive_scatter_plot, report_error
17 | from lib.image_utils import micrograph_viewer
18 |
# Module logger; records from this module go to the logger named "main_app".
logger = logging.getLogger("main_app")
20 |
def plot_histogram_particles_per_mic(
    particles_per_mic: Union[pd.Series, np.ndarray]
) -> go.Figure:
    """
    Build a 30-bin histogram of the number of particles per micrograph.

    Parameters:
        particles_per_mic: Particle count of every micrograph.

    Returns:
        A Plotly Figure; an empty ``go.Figure()`` if plotting fails (the error
        is reported and logged).
    """
    labels = {
        "title": "Histogram of Particles Per Micrograph",
        "xaxis_title": "Particles",
        "yaxis_title": "Micrographs Count",
    }
    try:
        histogram = px.histogram(particles_per_mic, nbins=30)
        histogram.update_layout(**labels)
    except Exception as exc:
        report_error(exc)
        logger.error("Error in plot_histogram_particles_per_mic.")
        return go.Figure()
    return histogram
45 |
46 |
def plot_histogram_fom(
    figure_of_merit: Union[pd.Series, np.ndarray]
) -> go.Figure:
    """
    Build a 50-bin histogram of autopick figure-of-merit (FOM) values.

    Parameters:
        figure_of_merit: FOM value of every pick.

    Returns:
        A Plotly Figure; an empty ``go.Figure()`` if plotting fails (the error
        is reported and logged).
    """
    labels = {
        "title": "Histogram of Autopick Figure of Merit",
        "xaxis_title": "Autopick Figure of Merit",
        "yaxis_title": "Count",
    }
    try:
        histogram = px.histogram(figure_of_merit, nbins=50)
        histogram.update_layout(**labels)
    except Exception as exc:
        report_error(exc)
        logger.error("Error in plot_histogram_fom.")
        return go.Figure()
    return histogram
71 |
72 |
def get_coord_paths(job_folder: str, rln_folder: str) -> Tuple[List[str], List[str]]:
    """
    Find the coordinate files of a picking job and the micrographs they belong to.

    Sources are tried in this order:
      1. ``autopick.star`` (if present and non-empty)
      2. ``manualpick.star`` (if present and non-empty)
      3. the first ``coords_suffix_*`` file. Its first line names the
         micrographs STAR file (relative to ``rln_folder``). The coordinate
         files are then expected at
         ``<job_folder>/coords_<suffix>/<micrograph>_<suffix>.star``.

    Parameters:
        job_folder (str): Path to the job folder.
        rln_folder (str): Project folder used to resolve micrograph paths (case 3).

    Returns:
        (coord_paths, mics_paths). Both lists are empty when nothing is found
        or an error occurs (the error is reported and logged).
    """
    try:
        for index_name in ("autopick.star", "manualpick.star"):
            index_path = os.path.join(job_folder, index_name)
            if not (os.path.exists(index_path) and os.path.getsize(index_path) > 0):
                continue
            table = parse_star(index_path)["coordinate_files"]
            mics = table["_rlnMicrographName"].to_numpy().tolist()
            coords = table["_rlnMicrographCoordinates"].to_numpy().tolist()
            return coords, mics

        suffix_matches = glob.glob(os.path.join(job_folder, "coords_suffix_*"))
        if not suffix_matches:
            return [], []

        suffix_path = suffix_matches[0]
        suffix = os.path.basename(suffix_path).replace("coords_suffix_", "").replace(".star", "")
        with open(suffix_path, "r") as handle:
            mics_star = handle.readline().strip()
        names = parse_star(os.path.join(rln_folder, mics_star))["micrographs"]["_rlnMicrographName"]
        mics = [os.path.join(rln_folder, name) for name in names]
        coords_dir = os.path.join(job_folder, f"coords_{suffix}")
        coords = [
            os.path.join(coords_dir, os.path.basename(mic).replace(".mrc", f"_{suffix}.star"))
            for mic in mics
        ]
        return coords, mics
    except Exception as exc:
        report_error(exc)
        logger.error("Error in get_coord_paths.")
        return [], []
131 |
132 |
133 |
134 |
135 |
def _plot_topaz_training(path_data: str) -> bool:
    """
    Plot Topaz training statistics (test split) from ``model_training.txt``.

    Parameters:
        path_data (str): Job folder that may contain ``model_training.txt``.

    Returns:
        bool: True if the file was found and plotted, False if it is absent.
    """
    training_files = glob.glob(os.path.join(path_data, "model_training.txt"))
    if not training_files:
        return False

    data = pd.read_csv(training_files[0], delimiter="\t")
    data_test = data[data["split"] == "test"]
    epochs = data_test["epoch"]
    data_test = data_test.drop(["iter", "split", "ge_penalty"], axis=1)

    fig = go.Figure()
    for column in data_test.columns:
        if column == "epoch":
            continue
        fig.add_scatter(
            x=epochs,
            y=data_test[column],
            name=column,
            # Plotly hover templates are HTML: <br> separates the lines.
            hovertemplate=f"{column}<br>Epoch: %{{x}}<br>Y: %{{y:.2f}}",
        )
    fig.update_xaxes(title_text="Epoch")
    fig.update_yaxes(title_text="Statistics")
    auprc = data_test["auprc"].astype(float)
    best_epoch = data_test[auprc == np.max(auprc)]["epoch"].values
    fig.update_layout(title=f"Topaz training stats. Best model: {best_epoch}")
    fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))
    fig.update_layout(hovermode="x unified")
    st.plotly_chart(fig, use_container_width=True)
    return True


def _collect_autopick_coords(
    rln_folder: str, coord_paths: List[str], mics_paths: List[str]
) -> Tuple[pd.DataFrame, List[float]]:
    """
    Read every coordinate STAR file and gather all picks and their FOM values.

    Files that cannot be read are reported, logged and skipped.

    Returns:
        (coords_df, fom_values): all picks, each tagged with its
        ``_rlnMicrographName``, plus every ``_rlnAutopickFigureOfMerit`` as float.
    """
    tables: List[pd.DataFrame] = []
    fom_values: List[float] = []
    for idx, (coord_path, mic_path) in enumerate(zip(coord_paths, mics_paths)):
        try:
            star_data = parse_star(os.path.join(rln_folder, coord_path))
            table = list(star_data.values())[0]
            table["_rlnMicrographName"] = mic_path
            # Keep the picks even if the FOM column below turns out to be missing.
            tables.append(table)
            fom_values.extend(table["_rlnAutopickFigureOfMerit"].astype(float))
        except Exception as exc:
            report_error(exc)
            logger.error(f"Error processing picking stats for micrograph index {idx}")
    # Concatenate once; concatenating inside the loop is quadratic in the pick count.
    coords_df = pd.concat(tables, ignore_index=True) if tables else pd.DataFrame()
    return coords_df, fom_values


def plot_picks(
    rln_folder: str, job_name: str, img_resize_fac: float = 0.2
) -> None:
    """
    Display picking statistics and overlay picks on micrographs.

    ``get_coord_paths`` resolves the coordinate source (autopick.star,
    manualpick.star or coords_suffix_*). The micrographs and coordinate files
    are then passed to ``micrograph_viewer``. For jobs whose name does not
    contain "ManualPick", the user can also plot a histogram of the autopick
    figure-of-merit and an interactive scatter plot of all picks. If no
    coordinate files exist, Topaz training statistics (``model_training.txt``)
    are shown instead when available.

    Parameters:
        rln_folder (str): Project folder containing the job folders.
        job_name (str): Job folder relative to ``rln_folder``.
        img_resize_fac (float): Kept for backward compatibility; currently unused.

    Returns:
        None. Errors are reported in the UI and logged.
    """
    logger.debug(f"{datetime.now()}: plot_picks started with job_name: {job_name}")
    tmp_file_path = None
    try:
        path_data = os.path.join(rln_folder, job_name)
        # The second argument must be the project folder: for coords_suffix_*
        # jobs, get_coord_paths resolves the micrographs STAR path against it.
        coord_paths, mics_paths = get_coord_paths(path_data, rln_folder)
        logger.debug(f"coord_paths: {coord_paths}")

        if not coord_paths:
            if _plot_topaz_training(path_data):
                logger.info(f"{datetime.now()}: plot_picks_streamlit done (Topaz training)")
                return
            st.write("No coordinate files found.")
            logger.info(f"{datetime.now()}: plot_picks_streamlit done (No coordinate files)")
            return

        col1, col2 = st.columns([1, 3])

        plot_all = False
        fom_all_mics: List[float] = []
        if "ManualPick" not in job_name:
            plot_all = col1.checkbox("Plot FOM statistics? (Might be slow for huge datasets)", value=False)
        if plot_all:
            coords_df, fom_all_mics = _collect_autopick_coords(rln_folder, coord_paths, mics_paths)
            if not coords_df.empty:
                # Reserve a name, then write after the handle is closed.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".star") as tmp_file:
                    tmp_file_path = tmp_file.name
                star_from_df({"particles": coords_df}).write_file(tmp_file_path)

        micrograph_viewer(
            rln_folder=rln_folder,
            image_files=mics_paths,
            selected_filter="gaussian",
            default_gaussian=0.2,
            coord_paths=coord_paths,
        )

        if plot_all and fom_all_mics:
            fom_histogram_fig = plot_histogram_fom(np.array(fom_all_mics))
            col2.plotly_chart(fom_histogram_fig, use_container_width=True)

            if st.checkbox("Show detailed statistics?") and tmp_file_path:
                interactive_scatter_plot(tmp_file_path)

        logger.info(f"{datetime.now()}: plot_picks_streamlit done")
    except Exception as exc:
        report_error(exc)
        logger.error("Error in plot_picks_streamlit function.")
    finally:
        # NamedTemporaryFile(delete=False) is never removed automatically, and
        # Streamlit reruns this function on every widget interaction.
        if tmp_file_path:
            try:
                os.remove(tmp_file_path)
            except OSError:
                pass
--------------------------------------------------------------------------------
/relion_jobs/mask_job.py:
--------------------------------------------------------------------------------
1 | #mask_job.py
2 |
3 | import os
4 | import logging
5 | import traceback
6 | from datetime import datetime
7 | from typing import List
8 |
9 | import numpy as np
10 | import mrcfile
11 | import streamlit as st
12 | import plotly.graph_objects as go
13 | import plotly.figure_factory as ff
14 | from skimage.transform import resize
15 |
16 | # Import shared utilities.
17 | from lib.utils import get_note, extract_source_job
18 | from lib.image_utils import normalize
19 |
# Module logger for the mask viewer (named "mask_job", unlike the "main_app"
# logger used by the other job modules).
# NOTE(review): setLevel(DEBUG) runs at import time and overrides any level
# configured elsewhere for this logger — confirm this is intended.
logger = logging.getLogger("mask_job")
logger.setLevel(logging.DEBUG)
22 |
23 |
def plot_volume_custom(volume: np.ndarray, threshold: float, max_size: int = 150,
                       colormap_override: str = None, opacity_override: float = None,
                       title: str = None) -> go.Figure:
    """
    Generate an isosurface plot of a 3-D volume using marching cubes and Plotly.

    Args:
        volume (np.ndarray): 3-D volume data (may be read-only, e.g. memory-mapped).
        threshold (float): Fraction (0 to 1) of the volume's value range at which
            the surface is drawn: ``min + (max - min) * threshold``.
        max_size (int): If any dimension exceeds this, the volume is resized so
            that its largest dimension is about ``max_size`` (aspect kept).
        colormap_override (str): Single facet colour, e.g. "rgb(100,149,237)";
            gray ("rgb(180,180,180)") if None.
        opacity_override (float): Mesh opacity (0 to 1); Plotly default if None.
        title (str): Figure title; "3D Volume Isosurface" if None.

    Returns:
        A Plotly Figure containing the isosurface, or None if no surface exists
        at this threshold or the figure could not be built.
    """
    try:
        # Private copy so that read-only inputs can be processed safely.
        volume = np.array(volume, copy=True)

        if max(volume.shape) > max_size:
            resize_factor = max_size / max(volume.shape)
            new_shape = np.round(np.array(volume.shape) * resize_factor).astype(int)
            volume = resize(volume, new_shape, anti_aliasing=False)

        min_val, max_val = np.min(volume), np.max(volume)
        actual_threshold = min_val + (max_val - min_val) * threshold

        # PyMCubes, imported here rather than at module level.
        import mcubes
        verts, faces = mcubes.marching_cubes(volume, actual_threshold)
        if len(verts) == 0 or len(faces) == 0:
            # Nothing crosses this level (e.g. a threshold at the data extremes);
            # report "nothing to draw" instead of failing inside create_trisurf.
            logger.debug("plot_volume_custom(): empty isosurface at threshold %s", threshold)
            return None

        color_list = ["rgb(180,180,180)"] if colormap_override is None else [colormap_override]

        fig_volume = ff.create_trisurf(
            x=verts[:, 2],
            y=verts[:, 1],
            z=verts[:, 0],
            simplices=faces,
            colormap=color_list,
            plot_edges=False,
            showbackground=True,
            show_colorbar=False,
        )

        # Smooth shading and lighting on the single Mesh3d trace.
        mesh = fig_volume.data[0]
        mesh.update(
            flatshading=False,
            lighting=dict(
                ambient=0.25,
                diffuse=0.9,
                specular=0.2,
                roughness=0.7,
                fresnel=4,
            ),
            lightposition=dict(x=0, y=500, z=500),
        )
        if opacity_override is not None:
            mesh.opacity = opacity_override

        # Layout: black background, hidden axes.
        fig_volume.update_layout(
            title=title if title is not None else "3D Volume Isosurface",
            scene=dict(
                camera=dict(eye=dict(x=2, y=2, z=2)),
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
                aspectmode="data",
                bgcolor="black",
            ),
            paper_bgcolor="black",
            height=600,
            margin=dict(l=0, r=0, t=0, b=0),
            hovermode=False,
        )

        return fig_volume

    except RuntimeError as exc:
        # Previously swallowed silently; keep returning None but leave a trace.
        logger.debug(f"plot_volume_custom() runtime error: {exc}")
        return None
    except Exception as exc:
        logger.error(f"plot_volume_custom() error: {exc}")
        return None
119 |
120 |
def overlay_projections(source: np.ndarray, mask: np.ndarray, axis: int = 0, alpha: float = 0.7) -> np.ndarray:
    """
    Blend a mask (cornflower blue) over a grayscale image of the source map.

    2-D inputs are treated as ready-made projections and used as-is. 3-D
    inputs are first reduced to a maximum intensity projection along ``axis``.
    Both images are passed through ``normalize`` before blending.

    Args:
        source (np.ndarray): 2-D projection or 3-D source volume.
        mask (np.ndarray): 2-D projection or 3-D mask volume (same shape as source).
        axis (int): Projection axis; only used for 3-D inputs.
        alpha (float): Strength of the mask colour; at mask intensity 1 the pixel
            is ``(1 - alpha) * source + alpha * blue``.

    Returns:
        np.ndarray: ``(H, W, 3)`` uint8 RGB image, clipped to [0, 255].
    """
    # Project volumes here so that ``axis`` actually has an effect; 2-D inputs
    # (what existing callers pass) are unchanged.
    proj_source = np.max(source, axis=axis) if np.ndim(source) == 3 else source
    proj_mask = np.max(mask, axis=axis) if np.ndim(mask) == 3 else mask

    # NOTE(review): assumes normalize maps values into [0, 1]; the blend is
    # clipped to [0, 1] below either way.
    norm_source = normalize(proj_source)
    norm_mask = normalize(proj_mask)

    # Grayscale source replicated into three RGB channels.
    source_rgb = np.dstack([norm_source] * 3)
    # Cornflower blue in normalized RGB.
    blue = np.array([100, 149, 237], dtype=np.float32) / 255.0
    mask_rgb = np.empty_like(source_rgb)
    mask_rgb[...] = blue

    # Per-pixel blend weighted by the mask intensity.
    mask_intensity = norm_mask[..., np.newaxis]
    overlay = source_rgb * (1 - alpha * mask_intensity) + (alpha * mask_rgb * mask_intensity)
    overlay = np.clip(overlay, 0, 1)
    return (overlay * 255).astype(np.uint8)
160 |
161 |
def _cached_volume(cache: dict, key: str, path: str, label: str):
    """
    Return ``cache[key]``, memory-mapping the MRC file at *path* on first use.

    If the file cannot be opened, an error naming *label* is shown and None
    is returned.
    """
    if key in cache:
        return cache[key]
    try:
        # NOTE(review): the mrcfile object is not closed; only its memory-mapped
        # data array is kept in session state. Confirm this trade-off is intended.
        volume = mrcfile.mmap(path, mode='r').data
    except Exception as exc:
        st.error(f"Error loading {label}: {exc}")
        return None
    cache[key] = volume
    return volume


def plot_mask(rln_folder: str, nodes: List[str]) -> None:
    """
    Plot a mask together with the map it was created from.

    The mask is ``nodes[0]``. The source map is found from the job's note.txt
    via ``extract_source_job``. Both volumes are cached in ``st.session_state``
    under a job-specific key, so changing a threshold does not reload them
    from disk. Two display modes are offered:
      - 3D: isosurfaces of the map (gray) and the mask (cornflower blue, 50 %
        opacity) in one scene.
      - 2D: mean projection of the map in grayscale, with the mask's maximum
        projection blended on top in cornflower blue.

    Parameters:
        rln_folder (str): Project folder that node paths are relative to.
        nodes (List[str]): Output nodes of the mask job; the first is the mask.

    Returns:
        None. Problems are reported with ``st.error``.
    """
    logger.info("Plotting mask for nodes: %s", nodes)
    if not nodes:
        st.error("No output nodes found for this mask job.")
        return

    # Cache volumes per job folder.
    job_folder = os.path.join(rln_folder, os.path.dirname(nodes[0]))
    job_key = f"mask_job_{job_folder}"
    if job_key not in st.session_state:
        st.session_state[job_key] = {}
    cache = st.session_state[job_key]

    note = get_note(os.path.join(job_folder, "note.txt"))
    source_job = extract_source_job(note)
    # NOTE(review): assumes extract_source_job returns a falsy value when no
    # input map is named; os.path.join would raise TypeError on None.
    if not source_job:
        st.error("Could not determine the source map from note.txt.")
        return

    mask_volume = _cached_volume(cache, "mask_volume", os.path.join(rln_folder, nodes[0]), "mask")
    if mask_volume is None:
        return
    source_volume = _cached_volume(
        cache, "source_volume", os.path.join(rln_folder, source_job), "source volume"
    )
    if source_volume is None:
        return

    # Display mode selection: 3D or 2D.
    c1, c2 = st.columns([1, 3])
    view_mode = c1.radio("Display mode", ["3D", "2D"], horizontal=True)

    if view_mode == "3D":
        with c1:
            st.markdown("### 3D Rendering Controls")
            threshold_source = st.slider("Original Map Threshold Fraction", 0.0, 1.0, 0.5, 0.01)
            threshold_mask = st.slider("Mask Threshold Fraction", 0.0, 1.0, 0.5, 0.01)
        fig_source = plot_volume_custom(source_volume, threshold_source, max_size=150,
                                        colormap_override=None, opacity_override=None,
                                        title="Original Map")
        cornflower_blue = "rgb(100,149,237)"
        fig_mask = plot_volume_custom(mask_volume, threshold_mask, max_size=150,
                                      colormap_override=cornflower_blue, opacity_override=0.5,
                                      title="Mask")
        # Combine both isosurface traces in one scene.
        fig_overlay = go.Figure()
        if fig_source and len(fig_source.data) > 0:
            fig_overlay.add_trace(fig_source.data[0])
        if fig_mask and len(fig_mask.data) > 0:
            fig_overlay.add_trace(fig_mask.data[0])
        fig_overlay.update_layout(
            title="Overlay: Original Map and Mask (3D)",
            scene=dict(
                camera=dict(eye=dict(x=2, y=2, z=2)),
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
                aspectmode="data",
                bgcolor="black",
            ),
            paper_bgcolor="black",
            height=600,
            margin=dict(l=0, r=0, t=0, b=0),
            hovermode=False,
        )
        c2.plotly_chart(fig_overlay, use_container_width=True)
    else:
        c1.markdown("### 2D Projection Controls")
        with c1:
            projection_axis = st.radio("Projection axis", ['XY', 'XZ', 'YZ'], horizontal=True)
            # 'XY' projects along array axis 0, 'XZ' along 1, 'YZ' along 2.
            projection_axis = {'XY': 0, 'XZ': 1, 'YZ': 2}[projection_axis]
            st.markdown("Adjust projection settings as needed.")
        # Mean projection for the map, maximum projection for the mask.
        proj_source = np.mean(source_volume, axis=projection_axis)
        proj_mask = np.max(mask_volume, axis=projection_axis)
        proj_source_norm = normalize(proj_source)
        proj_mask_norm = normalize(proj_mask)
        overlay_image = overlay_projections(proj_source_norm, proj_mask_norm, axis=projection_axis, alpha=0.7)
        c2.image(overlay_image, caption="Overlay of Original Map and Mask", width=400)

    logger.info(f"{datetime.now()}: plot_mask done.")
267 |
--------------------------------------------------------------------------------
/relion_jobs/localres_job.py:
--------------------------------------------------------------------------------
1 | # localres_job.py
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | import re
7 | import traceback
8 | from typing import List, Tuple, Optional
9 |
10 | import mrcfile
11 | import numpy as np
12 | import plotly.graph_objects as go
13 | import streamlit as st
14 | from datetime import datetime
15 | from plotly.subplots import make_subplots
16 | from scipy.stats import gaussian_kde
17 | from skimage.transform import resize
18 |
19 | from lib.utils import report_error
20 | import logging
21 |
# Module logger; records from this module go to the logger named "main_app".
logger = logging.getLogger("main_app")
23 |
24 | # ── pymcubes ────────────────────────────────────────────────────────────────
25 | try:
26 | import mcubes
27 | except ImportError:
28 | mcubes = None
29 |
30 | # ─────────────────────────────────────────────────────────────────────────────
31 | # helper funcs
32 | # ─────────────────────────────────────────────────────────────────────────────
33 | def _load_mrc(path: str) -> np.ndarray:
34 | if not os.path.isfile(path):
35 | raise FileNotFoundError(path)
36 | with mrcfile.mmap(path, mode="r", permissive=True) as mrc:
37 | return mrc.data.copy()
38 |
39 |
40 | def _assign_maps(nodes: List[str], folder: str) -> Tuple[str, str]:
41 | raw, loc = "", ""
42 | for f in nodes:
43 | if not f.lower().endswith(".mrc"):
44 | continue
45 | base = os.path.basename(f).lower()
46 | full = os.path.join(folder, f)
47 | if "locres_filtered" in base:
48 | raw = full
49 | elif "locres" in base and "filtered" not in base:
50 | loc = full
51 | if raw and loc:
52 | break
53 | if not raw or not loc:
54 | raise ValueError("Missing relion_locres_filtered.mrc or relion_locres.mrc")
55 | return raw, loc
56 |
57 |
58 | def _orthogonal_slices(vol: np.ndarray, idx: int) -> tuple[np.ndarray, ...]:
59 | i = int(np.clip(idx, 0, min(vol.shape) - 1))
60 | return vol[i], vol[:, i, :], vol[:, :, i]
61 |
62 |
63 | def _sample_vertex_values(vol: np.ndarray, verts: np.ndarray) -> np.ndarray:
64 | x = np.clip(np.round(verts[:, 0]).astype(int), 0, vol.shape[0] - 1)
65 | y = np.clip(np.round(verts[:, 1]).astype(int), 0, vol.shape[1] - 1)
66 | z = np.clip(np.round(verts[:, 2]).astype(int), 0, vol.shape[2] - 1)
67 | return vol[x, y, z]
68 |
69 | # ───────────────────────────── 3-D isosurface ───────────────────────────────
def _plot_isosurface(
    raw_vol: np.ndarray,
    loc_vol: np.ndarray,
    rel_thr: float,
    *,
    max_size: int,
    colourscale: str,
) -> Optional[go.Figure]:
    """
    Render an isosurface of *raw_vol* coloured by the values of *loc_vol*.

    Parameters:
        raw_vol: Map used to extract the surface.
        loc_vol: Same-shape map sampled at each vertex for the colour; may
            contain NaN (e.g. masked-out voxels).
        rel_thr: Fraction (0-1) of raw_vol's value range used as iso-level.
        max_size: Both volumes are resized so that the largest dimension is
            about this size when it is exceeded.
        colourscale: Plotly colourscale name.

    Returns:
        A Plotly figure, or None when PyMCubes is missing, the shapes differ
        or no surface exists at this level (a UI message is shown for the
        first two cases).
    """
    if mcubes is None:
        st.error("Install *pymcubes* for 3-D rendering.")
        return None
    if raw_vol.shape != loc_vol.shape:
        st.warning("Shape mismatch.")
        return None

    # down-sample for performance
    if max(raw_vol.shape) > max_size:
        fac = max_size / max(raw_vol.shape)
        new_shape = tuple(int(round(d * fac)) for d in raw_vol.shape)
        raw_vol = resize(raw_vol, new_shape, preserve_range=True, anti_aliasing=False)
        loc_vol = resize(loc_vol, new_shape, preserve_range=True, anti_aliasing=False)

    level = raw_vol.min() + rel_thr * (raw_vol.max() - raw_vol.min())
    verts, faces = mcubes.marching_cubes(np.ascontiguousarray(raw_vol, np.float32), level)
    if verts.size == 0:
        return None

    intens = _sample_vertex_values(loc_vol, verts)
    # Colour range from finite values only: plain min()/max() return NaN as
    # soon as any voxel is NaN, which breaks the colour scale.
    finite = loc_vol[np.isfinite(loc_vol)]
    if finite.size:
        cmin, cmax = float(finite.min()), float(finite.max())
    else:
        cmin, cmax = 0.0, 1.0
    if cmin == cmax:
        cmax = cmin + 1e-6

    fig = go.Figure(
        data=[
            go.Mesh3d(
                x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
                i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
                intensity=intens,
                colorscale=colourscale,
                cmin=cmin, cmax=cmax,
                showscale=True,
                opacity=1.0,
                lighting=dict(ambient=0.3, diffuse=0.5, specular=0.2),
                lightposition=dict(x=0, y=0, z=0),
                hoverinfo="skip",
            )
        ]
    )
    fig.update_layout(
        scene=dict(
            xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False),
            aspectmode="data", bgcolor="black", dragmode="orbit"
        ),
        paper_bgcolor="black",
        margin=dict(l=10, r=10, t=30, b=10),
        height=600,
        title=f"Isosurface (thr ={rel_thr:.2f}) – clipped above ceiling",
    )
    return fig
129 |
130 | # ─────────────────────────────────────────────────────────────────────────────
131 | # main entry
132 | # ─────────────────────────────────────────────────────────────────────────────
def plot_locres(nodes: List[str], folder: str, job: str) -> None:
    """
    Streamlit visualiser for a local-resolution job.

    Controls:
      • view radio (3-D isosurface / 2-D orthogonal slices)
      • resolution-ceiling slider (values above it are clipped to it)
      • 3-D only: relative threshold, max-dimension and colour-scale select
    Below the view, a histogram (+ KDE when possible) of the clipped
    local-resolution values is drawn with 10/25/50/75/90 % percentile markers.

    Args:
        nodes: Node files of the job; ``_assign_maps`` picks the raw map and
            the local-resolution map from them.
        folder: RELION project directory.
        job: Job directory (relative to *folder*); used for cache/widget keys
            and to locate the job's ``note.txt``.

    Errors are not raised: they are shown with ``st.error``, passed to
    ``report_error`` and logged with the traceback.
    """
    try:
        raw_p, loc_p = _assign_maps(nodes, folder)

        # Cache keys include the file mtime so a rewritten map is reloaded.
        key_raw = f"raw_{job}_{os.path.getmtime(raw_p)}"
        key_loc = f"loc_{job}_{os.path.getmtime(loc_p)}"
        if key_raw not in st.session_state:
            with st.spinner("Loading raw map…"):
                st.session_state[key_raw] = _load_mrc(raw_p)
        if key_loc not in st.session_state:
            with st.spinner("Loading resolution map…"):
                st.session_state[key_loc] = _load_mrc(loc_p)

        raw = st.session_state[key_raw]
        loc = st.session_state[key_loc]

        # optional mask, taken from the --mask argument recorded in note.txt
        note = os.path.join(folder, job, "note.txt")
        if os.path.isfile(note):
            with open(note, encoding="utf-8") as fh:
                txt = fh.read()
            m = re.search(r"--mask\s+([\w\d/\\\.-]+\.mrc)", txt, re.I)
            if m:
                mpath = os.path.join(folder, m.group(1))
                if os.path.isfile(mpath):
                    mask = _load_mrc(mpath) > 0
                    raw = np.where(mask, raw, 0)
                    loc = np.where(mask, loc, 0)

        # Zero voxels (e.g. outside the mask) carry no resolution estimate:
        # mark them NaN so they are excluded from limits and statistics.
        loc_float = loc.astype(float)
        loc_float[loc == 0] = np.nan

        finite_all = loc_float[np.isfinite(loc_float)]
        if finite_all.size == 0:
            st.warning("No finite resolution values to plot.")
            return
        res_min, res_max = float(finite_all.min()), float(finite_all.max())

        # ── UI controls ───────────────────────────────────────────────────
        view = st.radio("View mode:", ("3-D isosurface", "2-D slices"), horizontal=True)

        # Snap the slider range outwards to the 0.1 Å step so the default
        # ceiling (= the upper bound) never clips real values.
        slider_lo = float(np.floor(res_min * 10) / 10)
        slider_hi = float(np.ceil(res_max * 10) / 10)
        col1, col2, col3, col4 = st.columns([1, 2, 1, 1])
        with col1:
            if slider_hi > slider_lo:
                res_clip = st.slider(
                    "Max resolution (Å)",
                    min_value=slider_lo,
                    max_value=slider_hi,
                    value=slider_hi,
                    step=0.1,
                    help="Values above this are clipped (set equal to the ceiling).",
                )
            else:
                # Uniform map: a zero-width slider range is degenerate and
                # there is nothing to clip.
                res_clip = slider_hi

        # apply clipping for visualisation (NaN = "no estimate" propagates)
        loc_clip = np.minimum(loc_float, res_clip)

        if view == "2-D slices":
            idx = col2.slider(
                "Slice index", 0, min(loc.shape) - 1, (min(loc.shape) - 1) // 2, key=f"{job}_idx"
            )
            slices = _orthogonal_slices(loc_clip, idx)
            # nan-aware limits: plain .min()/.max() return NaN because
            # loc_clip holds NaN wherever there is no estimate.
            vmin, vmax = float(np.nanmin(loc_clip)), float(np.nanmax(loc_clip))
            if vmin == vmax:
                vmax = vmin + 1e-6

            fig_s = make_subplots(rows=1, cols=3, subplot_titles=("XY", "XZ", "YZ"))
            for c, slc in enumerate(slices, 1):
                fig_s.add_trace(
                    go.Heatmap(
                        z=slc.T,
                        colorscale="RdYlBu_r",
                        zmin=vmin,
                        zmax=vmax,
                        showscale=(c == 3),
                        colorbar=dict(title="Å") if c == 3 else None,
                        hoverinfo="skip",
                    ),
                    row=1, col=c,
                )
                fig_s.update_xaxes(visible=False, row=1, col=c)
                fig_s.update_yaxes(visible=False, row=1, col=c)
            fig_s.update_layout(margin=dict(l=10, r=10, t=30, b=10), height=600)

            st.plotly_chart(fig_s, use_container_width=True)

        else:  # 3-D
            with col2:
                thr = st.slider("Relative threshold", 0.05, 0.95, 0.5, 0.01, key=f"{job}_thr")
            with col3:
                dim = st.slider("Max dimension (px)", 80, 256, 200, 16, key=f"{job}_dim")
            with col4:
                colours = st.selectbox(
                    "Colour scale",
                    ("Turbo", "Viridis", "Cividis", "Plasma", "Inferno", "Haline"),
                    index=0,
                    key=f"{job}_cmap",
                )

            with st.spinner("Rendering 3-D view…"):
                fig_iso = _plot_isosurface(raw, loc_clip, thr, max_size=dim, colourscale=colours)
            if fig_iso:
                st.plotly_chart(fig_iso, use_container_width=True)

        # ── histogram / KDE (use clipped values) ───────────────────────────
        st.subheader("Local-resolution distribution")
        finite_vals = loc_clip[np.isfinite(loc_clip)]
        if finite_vals.size:
            p10, p25, p50, p75, p90 = np.percentile(finite_vals, [10, 25, 50, 75, 90])

            fig_h = go.Figure()
            fig_h.add_trace(
                go.Histogram(
                    x=finite_vals,
                    histnorm="probability density",
                    marker_color="#9cc2cb",
                    opacity=0.6,
                    name="Histogram",
                )
            )
            # gaussian_kde fails for a single value or zero variance (singular
            # covariance), e.g. when every voxel is clipped to the ceiling.
            if finite_vals.size > 1 and np.ptp(finite_vals) > 0:
                kde = gaussian_kde(finite_vals)
                xs = np.linspace(finite_vals.min(), finite_vals.max(), 200)
                fig_h.add_trace(
                    go.Scatter(x=xs, y=kde(xs), mode="lines", name="KDE", line=dict(color="#417996", width=3))
                )
            for val, lbl, colr in [
                (p10, "10 %", "grey"),
                (p25, "25 %", "orange"),
                (p50, "50 %", "red"),
                (p75, "75 %", "orange"),
                (p90, "90 %", "grey"),
            ]:
                fig_h.add_vline(x=val, line=dict(color=colr, dash="dash"), annotation_text=lbl, annotation_position="top")

            fig_h.update_layout(
                xaxis_title="Resolution (Å)",
                yaxis_title="Density",
                margin=dict(l=20, r=20, t=20, b=20),
            )
            st.plotly_chart(fig_h, use_container_width=True)
            st.caption(
                f"10 %: {p10:.2f} Å 25 %: {p25:.2f} Å Median: {p50:.2f} Å "
                f"75 %: {p75:.2f} Å 90 %: {p90:.2f} Å (clipped at {res_clip:.2f} Å)"
            )
        else:
            st.warning("No finite resolution values to plot.")

    except Exception as exc:
        st.error(f"plot_locres failed: {exc}")
        report_error(exc)
        logger.error(f"plot_locres error\n{traceback.format_exc()}")
296 |
--------------------------------------------------------------------------------
/relion_jobs/excludetilt_job.py:
--------------------------------------------------------------------------------
1 | # excludetilt_job.py
2 |
3 | import os
4 | import traceback
5 | import logging
6 | from datetime import datetime
7 |
8 | import streamlit as st
9 | import numpy as np
10 | import pandas as pd
11 | import plotly.graph_objects as go
12 |
13 | from lib.utils import (
14 | parse_star,
15 | get_values_from_first_key,
16 | get_note,
17 | extract_source_job,
18 | report_error,
19 | )
20 |
21 | logger = logging.getLogger("main_app")
22 |
def _load_tilt_series_frames(rln_folder: str, star_files) -> list:
    """Parse each per-tilt-series STAR file into a DataFrame tagged with 'FileSource'.

    Args:
        rln_folder: RELION project directory the paths are relative to.
        star_files: Iterable of relative STAR-file paths.

    Returns:
        One DataFrame per file that parsed to non-empty data. Files that fail
        to parse are logged, passed to ``report_error`` and skipped.
    """
    frames = []
    for rel_path in star_files:
        path = os.path.join(rln_folder, rel_path)
        try:
            star_data = parse_star(path)
            if star_data:
                df = get_values_from_first_key(star_data)
                df["FileSource"] = os.path.basename(path)
                frames.append(df)
        except Exception as exc:
            logger.error(f"Error parsing star file {path}: {exc}")
            report_error(exc)
    return frames


def _tilt_geometry(df: pd.DataFrame):
    """Build unit-circle plot coordinates for the tilt angles in *df*.

    *df* must have 'TiltAngleDeg', 'TiltAngleRad' and 'FileSource' columns.

    Returns:
        (line_x, line_y, point_x, point_y, labels): a diameter from -r to +r
        per angle (segments separated by None so one Plotly trace draws them
        all), a marker at +r per angle, and an HTML hover label per marker.
    """
    line_x, line_y, point_x, point_y, labels = [], [], [], [], []
    for deg, rad, source in zip(df["TiltAngleDeg"], df["TiltAngleRad"], df["FileSource"]):
        x, y = float(np.cos(rad)), float(np.sin(rad))
        line_x.extend([-x, x, None])
        line_y.extend([-y, y, None])
        point_x.append(x)
        point_y.append(y)
        labels.append(f"{source}<br>Tilt={deg:.1f}°")
    return line_x, line_y, point_x, point_y, labels


def plot_exclude_tilt(rln_folder: str, node_files: str) -> None:
    """
    Loads a star file referencing multiple tomography star files. Plots the tilt angles
    (_rlnTomoNominalStageTiltAngle) for a user-selected tilt series as:
      1) Blue diameter lines from -r to +r.
      2) Blue points on the circumference for those angles.

    If a source job star file also exists (extracted from note.txt) and references
    the same tilt series, the angles present in the source but missing from the
    selected dataset are plotted in red (both diameter lines and points).

    For AlignTiltSeries jobs, per-tilt alignment statistics are additionally
    plotted against the nominal tilt angle.

    Args:
        rln_folder: RELION project directory.
        node_files: Path (relative to *rln_folder*) of the job's output STAR file.
    """
    logger.info(f"Plotting exclude tilt data with Plotly... {rln_folder}, {node_files}")
    star_path = os.path.join(rln_folder, node_files)
    logger.debug(f"star_path: {star_path}")

    working_dir = os.path.dirname(star_path)
    note = get_note(os.path.join(working_dir, "note.txt"))
    logger.debug(f"Note: {note}")

    # Guard against extract_source_job finding nothing (None / empty string),
    # which would otherwise break os.path.join.
    source_rel = extract_source_job(note)
    source_job = os.path.join(rln_folder, source_rel) if source_rel else None
    logger.debug(f"Source job: {source_job}")

    # Try to parse the source job file, if it exists.
    dfs_source_tilt = []
    if source_job and os.path.exists(source_job):
        try:
            source_star = parse_star(source_job)
            if source_star is not None and "global" in source_star:
                dfs_source_tilt = _load_tilt_series_frames(
                    rln_folder, source_star["global"]["_rlnTomoTiltSeriesStarFile"]
                )
            else:
                logger.debug(f"Source star is None or missing 'global' key: {source_job}")
        except Exception as exc:
            logger.error(f"Error parsing source star file {source_job}: {exc}")
            report_error(exc)
    else:
        # If missing, we will just plot the main star data.
        logger.error(f"Source job file not found: {source_job}")

    ##############################
    # Now parse the main star file (the exclude job) for the actual tilt data
    ##############################
    if not os.path.exists(star_path):
        st.error(f"Star file not found: {star_path}")
        return

    try:
        star = parse_star(star_path)
    except Exception as exc:
        report_error(exc)
        st.error("Failed to parse the main star file.")
        return

    if not star or "global" not in star:
        st.error("The star file does not contain a 'global' section.")
        return
    if "_rlnTomoTiltSeriesStarFile" not in star["global"]:
        st.error("The 'global' section does not list _rlnTomoTiltSeriesStarFile.")
        return

    dfs_tilt = _load_tilt_series_frames(rln_folder, star["global"]["_rlnTomoTiltSeriesStarFile"])
    if not dfs_tilt:
        st.error("No tilt-series data found. Cannot plot tilt angles.")
        return

    ##############################
    # UI Layout: slider to pick which tilt series to display
    ##############################
    col1, col2, col3 = st.columns([1, 3, 3])

    # 1-based slider over exactly len(dfs_tilt) positions; with a single
    # series there is nothing to choose and a zero-width slider is avoided.
    if len(dfs_tilt) > 1:
        with col1:
            selected_idx = st.slider(
                label="Select Tilt Series",
                min_value=1,
                max_value=len(dfs_tilt),
                value=1,
                step=1,
            ) - 1
    else:
        selected_idx = 0

    selected_df = dfs_tilt[selected_idx].copy()
    tilt_col = "_rlnTomoNominalStageTiltAngle"
    if tilt_col not in selected_df.columns:
        st.error(f"Column '{tilt_col}' not found in tilt-series data.")
        return
    if selected_df.empty:
        st.error("The selected tilt series contains no tilts.")
        return

    # Convert from degrees to radians
    selected_df["TiltAngleDeg"] = selected_df[tilt_col].astype(float)
    selected_df["TiltAngleRad"] = np.deg2rad(selected_df["TiltAngleDeg"])

    lines_x_blue, lines_y_blue, points_x_blue, points_y_blue, text_labels_blue = _tilt_geometry(selected_df)

    ##############################
    # Check if we have a matching source dataset for this tilt series
    ##############################
    missing_df = pd.DataFrame()
    selected_file_source = selected_df["FileSource"].iloc[0]

    possible_matches = [
        df_s for df_s in dfs_source_tilt
        if not df_s.empty and df_s["FileSource"].iloc[0] == selected_file_source
    ]
    if len(possible_matches) == 1:
        source_df = possible_matches[0].copy()
        if tilt_col in source_df.columns:
            source_df["TiltAngleDeg"] = source_df[tilt_col].astype(float)
            # Angles present in the source but missing from the selected dataset.
            # .copy() so the new column is added to an independent frame.
            missing_df = source_df[~source_df["TiltAngleDeg"].isin(selected_df["TiltAngleDeg"])].copy()
            missing_df["TiltAngleRad"] = np.deg2rad(missing_df["TiltAngleDeg"])

    ##############################
    # Plot with Plotly
    ##############################
    fig = go.Figure()

    # 1) Circle boundary
    fig.add_shape(
        type="circle",
        xref="x", yref="y",
        x0=-1, x1=1,
        y0=-1, y1=1,
        line=dict(color="lightgray")
    )

    # 2) Blue lines
    fig.add_trace(go.Scatter(
        x=lines_x_blue,
        y=lines_y_blue,
        mode="lines",
        line=dict(color="#007acc"),  # blue
        name="Kept tilt lines",
        hoverinfo="none",
        showlegend=False
    ))

    # 3) Blue points
    fig.add_trace(go.Scatter(
        x=points_x_blue,
        y=points_y_blue,
        mode="markers",
        marker=dict(color="#007acc", size=8),
        text=text_labels_blue,
        hovertemplate="%{text}",
        name="Kept tilt angles"
    ))

    if not missing_df.empty:
        lines_x_red, lines_y_red, points_x_red, points_y_red, text_labels_red = _tilt_geometry(missing_df)
        # 4) Red lines for missing angles
        fig.add_trace(go.Scatter(
            x=lines_x_red,
            y=lines_y_red,
            mode="lines",
            line=dict(color="#cc3333", width=3),  # red
            hoverinfo="none",
            showlegend=False
        ))
        # 5) Red points
        fig.add_trace(go.Scatter(
            x=points_x_red,
            y=points_y_red,
            mode="markers",
            marker=dict(color="#cc3333", size=15),
            text=text_labels_red,
            hovertemplate="%{text}",
            name="Excluded tilt angles"
        ))

    # Make it square
    fig.update_xaxes(
        range=[-1.1, 1.1],
        scaleanchor="y",
        scaleratio=1,
        showgrid=False,
        zeroline=False,
        visible=False
    )
    fig.update_yaxes(
        range=[-1.1, 1.1],
        showgrid=False,
        zeroline=False,
        visible=False
    )

    fig.update_layout(
        title=f"Tilt Angles: {selected_file_source}",
        showlegend=True,
        width=600,
        height=600
    )

    with col2:
        st.write("**Tilt Angles**")
        st.plotly_chart(fig, use_container_width=True)

    if "AlignTiltSeries" in node_files:
        logger.info("AlignTiltSeries job detected.")

        # Alignment statistics to plot against the nominal tilt angle.
        stat_specs = [
            ("_rlnTomoXTilt", "#007acc"),
            ("_rlnTomoYTilt", "#cc3333"),
            ("_rlnTomoZRot", "#ffd354"),
            ("_rlnTomoXShiftAngst", "#b850c8"),
            ("_rlnTomoYShiftAngst", "#45ab84"),
        ]
        fig2 = go.Figure()
        for stat_col, colour in stat_specs:
            if stat_col not in selected_df.columns:
                logger.warning(f"Column {stat_col} not found; skipped in alignment statistics.")
                continue
            fig2.add_trace(go.Scatter(
                x=selected_df["TiltAngleDeg"],
                # numeric y so values are placed on a linear axis, not as categories
                y=pd.to_numeric(selected_df[stat_col], errors="coerce"),
                mode="markers",
                marker=dict(color=colour, size=10),
                name=stat_col
            ))
        fig2.update_layout(
            title=f"Tilt Angles: {selected_file_source}",
            xaxis_title="Tilt Angle (deg)",
            yaxis_title="Tilt Alignment Statistics",
            width=600,
            height=600
        )
        col3.write("**Tilt Alignment Statistics**")
        col3.plotly_chart(fig2, use_container_width=True)

    logger.info(f"{datetime.now()}: plot_exclude_tilt completed successfully with Plotly.")
341 |
--------------------------------------------------------------------------------
/relion_jobs/polish_job.py:
--------------------------------------------------------------------------------
1 | #polish_job.py
2 |
3 | import os
4 | import traceback
5 | from datetime import datetime
6 | from typing import List
7 | import logging
8 |
9 | import numpy as np
10 | import pandas as pd
11 | import streamlit as st
12 | import plotly.graph_objects as go
13 |
14 | # Import shared utilities (ensure these functions are defined in your project)
15 | from lib.utils import parse_star
16 | from lib.image_utils import report_error
17 | from relion_jobs.extract_job import show_random_particles
18 |
19 | logger = logging.getLogger("main_app")
20 |
21 |
def _polish_show_optimal_params(params_path: str) -> bool:
    """Show the --s_vel/--s_div/--s_acc values from the first line of *params_path*.

    Returns False (after reporting the error in the UI and log) on failure.
    """
    try:
        with open(params_path, "r") as f:
            parameters = f.readline().strip().split()
        st.markdown("### Optimal Parameters")
        st.code(f"--s_vel {parameters[0]} --s_div {parameters[1]} --s_acc {parameters[2]}")
    except Exception as exc:
        st.error(f"Error reading optimal parameters: {exc}")
        logger.error(f"Optimal parameters error: {exc}\n{traceback.format_exc()}")
        return False
    return True


def _polish_show_bfactors(folder: str, job_path: str, shiny_path: str, cache: dict) -> bool:
    """Plot per-frame B-factor and Guinier intercept from <job>/bfactors.star.

    Parsed data is cached in *cache*. Optionally shows random particles from
    *shiny_path*. Returns False (after reporting) on failure.
    """
    try:
        bfactors_star_path = os.path.join(job_path, "bfactors.star")
        logger.debug(f"Loading B-factors from: {bfactors_star_path}")
        bfactors_cache_key = f"bfactors_data_{job_path}"
        logger.debug(f"Cache key for bfactors: {bfactors_cache_key}")
        if bfactors_cache_key in cache:
            bfactors_data = cache[bfactors_cache_key]
            logger.debug("Using cached bfactors data.")
        else:
            bfactors_data = parse_star(bfactors_star_path)["perframe_bfactors"]
            cache[bfactors_cache_key] = bfactors_data

        # Two traces sharing the x axis; the intercept uses a secondary y axis.
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=bfactors_data["_rlnMovieFrameNumber"],
                y=bfactors_data["_rlnBfactorUsedForSharpening"],
                mode="lines",
                name="Bfactor Used For Sharpening",
                line=dict(color="darkgoldenrod", width=3),
                hovertemplate="Movie Frame: %{x}<br>Bfactor: %{y}",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=bfactors_data["_rlnMovieFrameNumber"],
                y=bfactors_data["_rlnFittedInterceptGuinierPlot"],
                mode="lines",
                name="Fitted Intercept Guinier Plot",
                line=dict(color="lightseagreen", width=3),
                yaxis="y2",
                hovertemplate="Movie Frame: %{x}<br>Intercept: %{y}",
            )
        )
        fig.update_layout(
            title="Polish Job Statistics",
            xaxis=dict(title="Movie Frame Number"),
            yaxis=dict(title="Bfactor Used For Sharpening"),
            yaxis2=dict(
                title="Fitted Intercept Guinier Plot", overlaying="y", side="right"
            ),
            margin=dict(l=20, r=20, t=40, b=20),
            paper_bgcolor="white",
        )
        st.plotly_chart(fig, use_container_width=True)

        if st.checkbox(":gem: Show random shiny particles?"):
            c1, c2 = st.columns([1, 4])
            show_random_particles(shiny_path, folder, c1, c2)

    except Exception as exc:
        report_error(exc)
        logger.error(f"{datetime.now()}: Error in bfactors branch:\n{traceback.format_exc()}")
        st.error("Error loading B-factors data")
        return False
    return True


def _polish_show_tomogram_motion(folder: str, node_files: List[str], job_path: str, cache: dict) -> bool:
    """Show one tomogram's particles in 3-D, coloured by a per-block motion score.

    Returns False (after reporting) on failure; a missing temp motion file is
    only a warning and returns True.
    """
    coord_cols = [
        "_rlnCenteredCoordinateXAngst",
        "_rlnCenteredCoordinateYAngst",
        "_rlnCenteredCoordinateZAngst",
    ]
    try:
        # NOTE(review): the particle STAR file is taken positionally from
        # node_files[2]; confirm this matches the job's node ordering.
        particle_star_path = os.path.join(folder, node_files[2])

        particle_cache_key = f"particle_star_{job_path}"
        if particle_cache_key in cache:
            particle_star = cache[particle_cache_key]
            logger.debug("Using cached particle star data.")
        else:
            particle_star = parse_star(particle_star_path)["particles"][["_rlnTomoName"] + coord_cols].copy()
            # Only the coordinates are numeric. _rlnTomoName must stay a string:
            # it is used verbatim to build the temp motion file name below.
            for col in coord_cols:
                try:
                    particle_star[col] = particle_star[col].astype(float)
                except ValueError:
                    logger.warning(f"Column {col} cannot be converted to float.")
            cache[particle_cache_key] = particle_star

        unique_tomo_names = np.unique(particle_star["_rlnTomoName"])
        c1, c2 = st.columns([1, 5])

        idx = 0
        if unique_tomo_names.size > 1:
            idx = c1.slider("Tomogram index", 0, unique_tomo_names.size - 1, 0)
        tomo_name = unique_tomo_names[idx]
        particle_star_selected = particle_star[particle_star["_rlnTomoName"] == tomo_name]
        st.markdown(
            f"**Selected Tomogram:** {tomo_name}. **Number of particles:** {particle_star_selected.shape[0]}"
        )

        temp_folder = os.path.join(os.path.dirname(particle_star_path), "temp")
        temp_motion_file_path = os.path.join(temp_folder, f"{tomo_name}_motion.star")
        if not os.path.exists(temp_motion_file_path):
            st.warning("No temporary motion file found; please run the motion correction step.")
            return True

        motion_cache_key = f"motion_data_{tomo_name}"
        if motion_cache_key in cache:
            motion_file = cache[motion_cache_key]
            logger.debug("Using cached motion star data.")
        else:
            motion_file = parse_star(temp_motion_file_path)
            cache[motion_cache_key] = motion_file

        # One score per data block of the motion file.
        # NOTE(review): assumes the blocks line up one-to-one, in order, with
        # the rows of particle_star_selected — verify. The score sums the
        # signed mean X/Y/Z origins before abs(), so opposite shifts cancel.
        all_motion = []
        for key in motion_file.keys():
            try:
                motion_df = motion_file[key].astype(float, errors="ignore")
                summed = (
                    np.mean(motion_df["_rlnOriginXAngst"])
                    + np.mean(motion_df["_rlnOriginYAngst"])
                    + np.mean(motion_df["_rlnOriginZAngst"])
                )
                all_motion.append(np.abs(summed))
            except Exception as exc:
                logger.error(f"Error processing motion data for key {key}: {exc}")
                all_motion.append(0)

        scatter = go.Scatter3d(
            x=particle_star_selected["_rlnCenteredCoordinateXAngst"],
            y=particle_star_selected["_rlnCenteredCoordinateYAngst"],
            z=particle_star_selected["_rlnCenteredCoordinateZAngst"],
            mode="markers",
            marker=dict(
                size=5,
                color=all_motion,
                colorscale="Inferno",
                colorbar=dict(title="Total Motion"),
            ),
            name="Particles",
            hovertemplate="X: %{x}<br>Y: %{y}<br>Z: %{z}",
        )
        fig_motion = go.Figure(data=[scatter])
        fig_motion.update_layout(
            scene=dict(
                xaxis_title="X (Å)",
                yaxis_title="Y (Å)",
                zaxis_title="Z (Å)",
                aspectmode="data",
            ),
            height=600,
            title="Total Motion of Particles in Tomogram (summed XYZ motion)",
            showlegend=False,
        )
        c2.plotly_chart(fig_motion, use_container_width=True)

    except Exception as exc:
        logger.error(f"{datetime.now()}: Error in tomogram branch: {exc}\n{traceback.format_exc()}")
        st.error("Error processing tomogram motion data.")
        return False
    return True


def plot_polish(FOLDER: str, node_files: List[str]) -> None:
    """
    Main function to plot polish job results in Streamlit.

    Depending on the contents of node_files, this function handles (first match wins):
      1) Optimal parameter reporting (if "opt_params_all_groups.txt" is found),
      2) B-factor and Guinier analysis (if "shiny.star" is found),
      3) Tomogram motion visualization (if "tomograms.star" is found).

    Plots use Plotly; parsed data is cached in a job-specific dict stored in
    ``st.session_state``.

    Args:
        FOLDER: RELION project directory.
        node_files: Node file paths (relative to FOLDER) of the job.
    """
    logger.debug(f"{datetime.now()}: plot_polish called with folder: {FOLDER} and node_files: {node_files}")

    # Pick each input explicitly by name instead of "last matching file wins".
    params_node = next((f for f in node_files if "opt_params_all_groups.txt" in f), None)
    shiny_node = next((f for f in node_files if "shiny.star" in f), None)
    star_nodes = [f for f in node_files if "star" in f]
    has_tomograms = any("tomograms.star" in f for f in node_files)

    reference_node = params_node or (star_nodes[-1] if star_nodes else None)
    if reference_node is None:
        st.write("No relevant data found for the Polish job.")
        return

    job_path = os.path.dirname(os.path.join(FOLDER, reference_node))
    logger.debug(f"Job path: {job_path}")

    # Create a job-specific session state key.
    polish_key = f"polish_data_{job_path}"
    if polish_key not in st.session_state:
        st.session_state[polish_key] = {}
    cache = st.session_state[polish_key]

    if params_node is not None:
        logger.debug("Optimal parameters found.")
        ok = _polish_show_optimal_params(os.path.join(FOLDER, params_node))
    elif shiny_node is not None:
        ok = _polish_show_bfactors(FOLDER, job_path, os.path.join(FOLDER, shiny_node), cache)
    elif has_tomograms:
        ok = _polish_show_tomogram_motion(FOLDER, node_files, job_path, cache)
    else:
        st.write("No relevant data found for the Polish job.")
        ok = True

    if ok:
        logger.info(f"{datetime.now()}: plot_polish done")
239 |
240 |
def process_tomo_motion(filename: str):
    """
    Parse a multi-block STAR file (e.g. a tomogram motion file) into DataFrames.

    The file is split at every line starting with ``data_``. Any preamble
    before the first ``data_`` line (e.g. ``# version`` comments) is ignored.
    Inside a block, ``_rln...`` lines name the columns. After a ``loop_``
    line, every non-blank line that is not a ``#`` comment and does not
    contain ``None`` is a whitespace-separated data row.

    Args:
        filename: Path to the STAR file.

    Returns:
        A list of ``(block_name, DataFrame)`` tuples in file order, where
        ``block_name`` is the stripped ``data_...`` line. Rows with missing
        values are dropped and columns are converted to float where possible.
        On any error the problem is logged, shown via ``st.error`` and an
        empty list is returned.
    """
    try:
        # Group lines by data_ block; lines before the first data_ are skipped.
        blocks = []
        current = None
        with open(filename, "r", encoding="utf-8") as file:
            for line in file:
                if line.startswith("data_"):
                    current = [line]
                    blocks.append(current)
                elif current is not None:
                    current.append(line)

        dataframes = []
        for block in blocks:
            block_name = block[0].strip()
            data_section_started = False
            columns = []
            data_rows = []
            for raw_line in block[1:]:
                line = raw_line.strip()
                if line.startswith("_rln"):
                    columns.append(line.split()[0])
                elif data_section_started:
                    # Blank lines would become empty rows, which break the
                    # DataFrame constructor when a loop has no data.
                    if not line or line.startswith("#") or "None" in line:
                        continue
                    data_rows.append(line.split())
                elif line.startswith("loop_"):
                    data_section_started = True
            df = pd.DataFrame(data_rows, columns=columns).dropna()
            df = df.astype(float, errors="ignore")
            dataframes.append((block_name, df))
        return dataframes
    except Exception as exc:
        logger.error(f"Error processing tomogram motion file: {exc}")
        st.error("Error processing tomogram motion file.")
        return []
286 |
--------------------------------------------------------------------------------
/relion_jobs/import_job.py:
--------------------------------------------------------------------------------
1 | #import_job.py
2 |
3 | # Standard Library Imports
4 | import logging
5 | import os
6 | from concurrent.futures import ThreadPoolExecutor
7 | from datetime import datetime
8 | from typing import Dict, List, Optional, Tuple
9 |
10 | # Third-Party Imports
11 | import numpy as np
12 | import pandas as pd
13 | import plotly.graph_objects as go
14 | import streamlit as st
15 |
16 | # Local Imports
17 | from lib.image_utils import micrograph_viewer
18 | from lib.utils import (
19 | get_first_key,
20 | get_modification_time,
21 | parse_star,
22 | report_error,
23 | )
24 |
25 | # =============================================================================
26 | # Logger Setup
27 | # =============================================================================
28 | logger = logging.getLogger("main_app") # Use logger from main app
29 |
30 |
def plot_tomogram_picks(
    n: int,  # Tomogram index (for context/logging)
    coords_sel: Dict[str, List[float]],
    coords_rej: Dict[str, List[float]],
    scatter_size: Tuple[int, int] = (5, 2),
) -> go.Figure:
    """
    Creates a 3D scatter plot of selected and rejected tomogram coordinates.

    Args:
        n: Tomogram index (used for logging and the 1-based figure title).
        coords_sel: Mapping with '_rlnCoordinateX/Y/Z' sequences for selected points.
        coords_rej: Mapping with '_rlnCoordinateX/Y/Z' sequences for rejected points.
        scatter_size: Marker sizes (selected, rejected).

    Returns:
        A Plotly Figure with one trace per non-empty group (selected in green,
        rejected in red). Invalid or empty groups are skipped and logged;
        unexpected errors are reported and shown with ``st.warning``, and the
        (possibly partial) figure is still returned.
    """
    fig = go.Figure()
    required_keys = ("_rlnCoordinateX", "_rlnCoordinateY", "_rlnCoordinateZ")

    # (label, lowercase tag for logs, coordinates, marker size, colour, opacity)
    groups = (
        ("Selected", "selected", coords_sel, scatter_size[0], "green", 0.7),
        ("Rejected", "rejected", coords_rej, scatter_size[1], "red", 0.6),
    )

    try:
        for label, tag, coords, size, colour, opacity in groups:
            if coords is None or not all(k in coords for k in required_keys):
                logger.info(
                    f"{label} coordinates data structure invalid or missing keys for tomogram index {n}."
                )
                continue
            xs, ys, zs = (coords[k] for k in required_keys)
            # len() instead of truthiness so NumPy arrays / pandas Series work
            # too (their truth value is ambiguous and raises).
            if len(xs) == 0:
                logger.info(f"No {tag} coordinates to plot for tomogram index {n}.")
                continue
            fig.add_trace(
                go.Scatter3d(
                    x=xs,
                    y=ys,
                    z=zs,
                    mode="markers",
                    marker=dict(size=size, color=colour, opacity=opacity),
                    name=f"{label} ({len(xs)})",
                    customdata=np.arange(len(xs)),
                    hovertemplate=label + " Pick<br>X: %{x:.1f}<br>Y: %{y:.1f}<br>Z: %{z:.1f}",
                )
            )

    except KeyError as e:
        report_error(
            e,
            f"Missing coordinate key '{e}' while plotting tomogram picks for index {n}",
        )
        st.warning(f"Coordinate data missing key: {e}. Cannot plot picks fully.")
    except Exception as e:
        report_error(e, f"Error plotting tomogram picks for index {n}")
        st.warning(f"Could not plot tomogram picks: {e}")

    fig.update_layout(
        title=f"Tomogram {n + 1}: Particle Picks 3D View",
        scene=dict(
            xaxis_title="X (px)",
            yaxis_title="Y (px)",
            zaxis_title="Z (px)",
            aspectmode="data",
        ),
        legend_title="Pick Status",
        hovermode="closest",
        height=600,
        margin=dict(l=10, r=10, t=50, b=10),
    )
    return fig
133 |
134 |
def _find_import_data_block(star_data_all: dict) -> Optional[pd.DataFrame]:
    """Return the first non-empty DataFrame among the 'movies', 'global' and first STAR blocks."""
    for block_name in ("movies", "global", get_first_key(star_data_all)):
        if not block_name:
            continue
        block = star_data_all.get(block_name)
        if isinstance(block, pd.DataFrame) and not block.empty:
            logger.debug(f"Using '{block_name}' data block.")
            return block
    return None


def _collect_tilt_series_movie_names(rln_folder: str, tomo_star_files: List[str]) -> List[str]:
    """Gather all _rlnMicrographMovieName entries from Relion 5 tilt-series STAR files."""
    all_tilt_movie_names: List[str] = []
    failed_parses = 0
    for ts_star_rel_path in tomo_star_files:
        ts_star_full_path = os.path.join(rln_folder, ts_star_rel_path)
        if not os.path.exists(ts_star_full_path):
            logger.warning(f"Tilt series STAR not found: {ts_star_full_path}")
            failed_parses += 1
            continue
        try:
            ts_data = parse_star(ts_star_full_path)
        except Exception as e:
            # One unreadable tilt series should not hide all the others.
            report_error(e, f"Failed to parse tilt series STAR: {ts_star_rel_path}")
            failed_parses += 1
            continue
        # parse_star may yield an empty result; guard before looking up a block.
        ts_movies_block = ts_data.get(get_first_key(ts_data)) if ts_data else None
        if ts_movies_block is not None and "_rlnMicrographMovieName" in ts_movies_block:
            all_tilt_movie_names.extend(
                ts_movies_block["_rlnMicrographMovieName"].tolist()
            )
        else:
            logger.warning(f"No movie names in tilt series STAR: {ts_star_rel_path}")
            failed_parses += 1
    if failed_parses > 0:
        st.warning(f"Failed to process {failed_parses} tilt series STAR file(s).")
    return all_tilt_movie_names


def _plot_import_movies(rln_folder: str, node_files: List[str], star_data_all: dict) -> None:
    """Show the modification-time timeline and image viewer for imported movies/tomograms."""
    logger.info("Processing movie/tomogram import.")

    try:
        data_block = _find_import_data_block(star_data_all)
        if data_block is None:
            block_names = list(star_data_all.keys())
            msg = f"Could not find a suitable non-empty data block ('movies', 'global', or first) in STAR file. Blocks: {block_names}"
            logger.error(msg)
            raise KeyError(msg)

        # Extract filenames from whichever known column the block provides.
        if "_rlnMicrographMovieName" in data_block:
            file_names = data_block["_rlnMicrographMovieName"].tolist()
            logger.debug("Found movie names in _rlnMicrographMovieName.")
        elif "_rlnTomoTiltSeriesName" in data_block:
            file_names = data_block["_rlnTomoTiltSeriesName"].tolist()
            logger.debug("Found tomogram base names in _rlnTomoTiltSeriesName.")
        elif "_rlnTomoTiltSeriesStarFile" in data_block:  # Relion 5+ Tomo
            tomo_star_files = data_block["_rlnTomoTiltSeriesStarFile"].tolist()
            logger.info(
                f"Relion 5 tomo import: {len(tomo_star_files)} tilt series STARs."
            )
            file_names = _collect_tilt_series_movie_names(rln_folder, tomo_star_files)
        else:
            raise KeyError("No known movie/tomogram filename column found.")

    except KeyError as e:
        report_error(e, f"Missing key in {node_files[0]} for movie/tomo filenames")
        st.warning(
            f"Could not find movie/tomogram filenames in STAR file: {e}. See logs."
        )
        return
    except Exception as e:
        report_error(
            e, f"Error processing movie/tomo import data from {node_files[0]}"
        )
        st.warning(f"Error processing movie/tomo data: {e}. See logs.")
        return

    if not file_names:
        st.warning("No movie or tomogram file paths were found in the STAR file.")
        return

    # --- Modification Time Plot ---
    st.subheader(f"Imported {len(file_names)} Files")
    file_names_abs = [os.path.join(rln_folder, fn) for fn in file_names]
    limit_display = 100
    files_to_check = file_names_abs
    if len(file_names_abs) > limit_display:
        if st.checkbox(
            f"Limit timeline plot to {limit_display} random files?",
            value=True,
            key="limit_import_plot",
        ):
            indices = np.random.choice(
                len(file_names_abs), limit_display, replace=False
            )
            files_to_check = [file_names_abs[i] for i in indices]
            st.caption(
                f"Showing timeline for {limit_display}/{len(file_names_abs)} files."
            )

    with st.spinner(
        f"Loading modification times for {len(files_to_check)} files..."
    ):
        try:
            # Stat calls are I/O-bound, so a thread pool overlaps them.
            with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
                mod_times_opt: List[Optional[datetime]] = list(
                    executor.map(get_modification_time, files_to_check)
                )
            file_mod_times: List[datetime] = sorted(
                mt for mt in mod_times_opt if mt is not None
            )
        except Exception as e:
            report_error(e, "Error fetching modification times.")
            st.warning(f"Could not fetch file modification times: {e}. See logs.")
            file_mod_times = []

    if file_mod_times:
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=list(range(len(file_mod_times))),
                y=file_mod_times,
                mode="markers",
                name="Timestamp",
                marker=dict(size=4),
                # Plotly hover labels use <br> for line breaks.
                hovertemplate="Index: %{x}<br>Time: %{y|%Y-%m-%d %H:%M:%S}",
            )
        )
        fig.update_layout(
            title="File Import Timeline (by modification date)",
            xaxis_title="File Index (sorted by time)",
            yaxis_title="Modification Timestamp",
            height=300,
            margin=dict(l=50, r=20, t=40, b=40),
        )
        with st.expander("Show Import Timeline", expanded=False):
            st.plotly_chart(fig, use_container_width=True)
    else:
        st.caption("No valid modification times found.")

    # --- Micrograph Viewer ---
    st.subheader("Micrograph/Tomogram Viewer")
    try:
        micrograph_viewer(rln_folder, file_names)
    except Exception as e:
        report_error(e, "Error displaying micrograph viewer.")
        st.warning(f"Could not display image viewer: {e}. See logs.")


def _plot_import_particles(star_data_all: dict) -> None:
    """Show imported particles: a 3D pick plot per tomogram, or a count for SPA coordinates."""
    logger.info("Processing particle import.")
    particles_block = star_data_all.get("particles")
    if not isinstance(particles_block, pd.DataFrame) or particles_block.empty:
        st.warning("No 'particles' data block found in the STAR file.")
        return

    if "_rlnCoordinateZ" not in particles_block.columns:  # SPA particle import
        logger.info("Detected SPA particle coordinates.")
        st.info(
            f"SPA particle coordinates imported ({len(particles_block)} particles). No 3D plot generated."
        )
        return

    logger.info("Detected tomogram particle coordinates.")
    st.subheader("Tomogram Particle Picks")
    tomo_name_col = "_rlnTomoName"
    if tomo_name_col not in particles_block.columns:
        logger.error(
            f"Missing '{tomo_name_col}' column for grouping particles."
        )
        st.warning(
            f"Missing '{tomo_name_col}' column. Cannot display picks by tomogram."
        )
        return

    unique_tomos = particles_block[tomo_name_col].unique()
    if len(unique_tomos) == 0:
        st.warning("No tomogram names found in particle data.")
        return

    # A slider needs min < max, so a single tomogram is selected directly.
    if len(unique_tomos) == 1:
        tomo_idx = 0
    else:
        tomo_idx = st.slider(
            "Select Tomogram Index:",
            0,
            len(unique_tomos) - 1,
            0,
            key="import_particle_tomo_idx",
        )
    selected_tomo_name = unique_tomos[tomo_idx]
    st.caption(f"Showing picks for: {selected_tomo_name}")

    tomo_coords_df = particles_block[
        particles_block[tomo_name_col] == selected_tomo_name
    ]
    # Ensure required columns exist before creating dict
    coord_cols = ["_rlnCoordinateX", "_rlnCoordinateY", "_rlnCoordinateZ"]
    if not all(col in tomo_coords_df.columns for col in coord_cols):
        logger.error(
            "Missing coordinate columns in filtered particle data for selected tomogram."
        )
        st.warning(
            "Missing coordinate columns for selected tomogram. Cannot plot picks."
        )
        return

    coords_dict = {k: tomo_coords_df[k].tolist() for k in coord_cols}
    # Imports have no rejected picks; pass empty lists for the rejected trace.
    coords_rej_empty: Dict[str, List[float]] = {k: [] for k in coord_cols}

    fig_3d = plot_tomogram_picks(tomo_idx, coords_dict, coords_rej_empty)
    st.plotly_chart(fig_3d, use_container_width=True)


def plot_import(rln_folder: str, node_files: List[str]) -> None:
    """
    Process and display information for a RELION Import job.

    Handles visualization for imported movies, tomograms (including Relion 5
    STAR-based imports), and particle coordinates. Problems are reported via
    the log and ``st.warning`` instead of ``st.error`` messages.

    Args:
        rln_folder: Base directory of the RELION project.
        node_files: List of node file paths (relative to ``rln_folder``)
            associated with this job. ``node_files[0]`` is parsed as the
            job's STAR file; the data type is inferred from the node names.
    """
    if not node_files:
        st.warning("No node files provided for Import job.")
        return

    star_file_path = os.path.join(rln_folder, node_files[0])
    logger.info(f"Processing Import job using STAR file: {star_file_path}")

    if not os.path.exists(star_file_path):
        logger.error(f"Import job STAR file not found: {star_file_path}")
        st.warning(f"Import STAR file '{node_files[0]}' not found.")
        return

    try:
        star_data_all = parse_star(star_file_path)
        if not star_data_all:
            st.warning(f"Could not parse or data empty in: {node_files[0]}")
            return
    except Exception as e:
        report_error(e, f"Failed to parse STAR file {node_files[0]} for Import job")
        st.warning(f"Error parsing '{node_files[0]}'. Check logs.")
        return

    # Determine data type from substrings of the node file names.
    is_movie_import = any("movies" in node for node in node_files)
    is_tomo_import = any(
        "tilt_series" in node or "tomograms" in node for node in node_files
    )
    is_particle_import = any("particles" in node for node in node_files)

    if is_movie_import or is_tomo_import:
        _plot_import_movies(rln_folder, node_files, star_data_all)
    elif is_particle_import:
        _plot_import_particles(star_data_all)
    else:
        st.warning(
            "Could not determine data type (movies/tomograms/particles) from node files."
        )
        logger.warning(f"Import job type unclear for node files: {node_files}")

    logger.info(f"Finished processing Import job: {node_files[0]}")
408 |
--------------------------------------------------------------------------------
/relion_jobs/class3d_job.py:
--------------------------------------------------------------------------------
1 | # class3d_job.py
2 | import os
3 | import glob
4 | import logging
5 | import traceback
6 | import math
7 | from datetime import datetime
8 | from typing import List
9 | import re
10 |
11 | import numpy as np
12 | import pandas as pd
13 | import mrcfile
14 | import streamlit as st
15 | from plotly.subplots import make_subplots
16 |
17 | # Shared utilities from your project.
18 | from lib.utils import (
19 | parse_star,
20 | interactive_scatter_plot,
21 | report_error,
22 | get_angles,
23 | get_classes,
24 | get_note,
25 | )
26 | from lib.image_utils import (downsample_volume, plot_angular_distribution_sphere, plot_fsc_stats, plot_class_resolution,
27 | plot_class_distribution, plot_projections, plot_volume, plot_angular_distribution_heatmap)
28 |
29 |
30 | logger = logging.getLogger("main_app")
31 |
32 |
33 |
def _load_and_cache_job_data(rln_folder: str, job_name: str, nodes: List[str]) -> dict:
    """
    Load volumes, class distributions, resolutions and angles for ``job_name``.

    The result is cached in ``st.session_state[f"class3d_data_{job_name}"]``;
    once that dict has ``"loaded": True``, later calls return it without
    touching the disk. Raw volumes are stored under ``"volumes_raw"`` with an
    empty ``"volumes_downsampled"`` list, for both the ``*model.star`` path
    and the MRC-only fallback, so that ``_ensure_downsampled`` resizes them
    consistently.

    Args:
        rln_folder: Base directory of the RELION project.
        job_name: Job folder, relative to ``rln_folder``.
        nodes: Node file paths of the job. ``.mrc`` entries are used when no
            ``*model.star`` file exists; otherwise ``nodes[0]`` is parsed for
            its ``particles`` block.

    Returns:
        The cached job-data dict, or ``{}`` when ``get_classes`` failed.
    """
    data_key = f"class3d_data_{job_name}"
    if data_key not in st.session_state:
        st.session_state[data_key] = {}

    # If already loaded, just return
    if st.session_state[data_key].get("loaded", False):
        logger.debug(f"Data for job '{job_name}' found in session state.")
        return st.session_state[data_key]

    with st.spinner("Loading job data from disk..."):
        # 1) Construct path to job folder
        path_data = os.path.join(rln_folder, job_name)

        # 2) Gather model.star files, oldest first
        model_files = glob.glob(os.path.join(path_data, "*model.star"))
        model_files.sort(key=os.path.getmtime)
        logger.debug(f"Found {len(model_files)} model.star files for job '{job_name}': {model_files}")

        # Symmetry defaults to C1 unless the job note contains "--sym <group>".
        # Matched case-insensitively (e.g. "--sym c4") and normalised to upper case.
        symmetry = 'C1'
        note = get_note(os.path.join(path_data, 'note.txt'))
        logger.debug(f"Note file content: {note}")
        if note:
            match = re.search(r'--sym\s+([A-Za-z]\d+)', note)
            if match:
                symmetry = match.group(1).upper()
                logger.debug(f"Found symmetry: {symmetry}")

        # 3) If no model.star, fall back to the MRC files listed in the nodes
        if not model_files:
            logger.info("No *_model.star files found. Will try to read MRC volumes directly.")
            mrc_list = [f for f in nodes if f.lower().endswith(".mrc")]
            logger.debug(f"original MRC files: {mrc_list}")

            # Prefer merged maps; otherwise show only one half map.
            if any("merged" in f for f in mrc_list):
                mrc_list = [f for f in mrc_list if "merged" in f]
            elif any("half1" in f for f in mrc_list):
                mrc_list = [f for f in mrc_list if "half1" in f]

            logger.debug(f"Filtered MRC files: {mrc_list}")

            volumes = []
            for mrc_file in mrc_list:
                volume_path = os.path.join(rln_folder, mrc_file)
                try:
                    logger.debug(f"Loading MRC volume: {volume_path}")
                    with mrcfile.mmap(volume_path, permissive=True) as mrcf:
                        volumes.append(mrcf.data)
                except Exception:
                    error = traceback.format_exc()
                    logger.error(f"Error loading volume {mrc_file}:\n{error}")

            # Store as raw volumes so _ensure_downsampled resizes them like
            # the model.star path instead of treating full-size maps as
            # already downsampled.
            st.session_state[data_key] = {
                "volumes_raw": volumes,
                "volumes_downsampled": [],
                "class_dist": [1.0] * len(volumes),
                "class_res": np.array([]),
                "fsc_res": np.array([]),
                "fsc_vals": np.array([]),
                "angles": ([], [], []),
                "star_block": None,
                "n_classes": len(volumes),
                "class_paths": mrc_list,
                "loaded": True,
                "symmetry": symmetry,
            }
            return st.session_state[data_key]

        # 4) If we have model.star, parse them
        try:
            (class_paths, n_classes, _iter_count,
             class_dist, class_res, fsc_res, fsc_vals) = get_classes(path_data, model_files)
        except Exception as e:
            report_error(e)
            logger.error(f"Error in get_classes: {e}")
            st.session_state[data_key] = {}
            return {}

        # 5) Load volumes (memory-mapped) from the job folder
        volumes_raw = []
        for cls_path in class_paths:
            full_path = os.path.join(path_data, os.path.basename(cls_path))
            logger.debug(f"Loading volume: {full_path}")
            with mrcfile.mmap(full_path, permissive=True) as mrcf:
                volumes_raw.append(mrcf.data)

        # 6) Angles
        try:
            rot, tilt, psi = get_angles(path_data)
        except Exception as e:
            logger.debug(f"Failed to load angles from {path_data}: {e}")
            rot, tilt, psi = ([], [], [])

        # 7) Attempt to parse the first node's STAR file for its particles block
        star_block = None
        try:
            if nodes:
                star_main = parse_star(os.path.join(path_data, os.path.basename(nodes[0])))
                star_block = star_main.get("particles", pd.DataFrame())
        except Exception as e:
            logger.debug(f"Failed to parse star file from nodes[0]: {e}")

        # 8) Store in st.session_state
        st.session_state[data_key] = {
            "volumes_raw": volumes_raw,
            "volumes_downsampled": [],
            "class_dist": class_dist,
            "class_res": class_res,
            "fsc_res": fsc_res,
            "fsc_vals": fsc_vals,
            "angles": (rot, tilt, psi),
            "star_block": star_block,
            "n_classes": n_classes,
            "class_paths": class_paths,
            "loaded": True,
            "symmetry": symmetry,
        }

    return st.session_state[data_key]
158 |
159 |
160 | def _ensure_downsampled(job_data: dict, map_resize: int) -> List[np.ndarray]:
161 | """
162 | Ensure the volumes in session_state are downsampled to 'map_resize' size.
163 | Only redo if the stored version does not match the requested size.
164 | """
165 | if not job_data.get("volumes_raw", []):
166 | return job_data.get("volumes_downsampled", [])
167 |
168 | # Check if we already have downsampled volumes for this size
169 | # For clarity, we store them in job_data["volumes_downsampled_{size}"] or a single key with the size
170 | cached_size = job_data.get("cached_map_resize", None)
171 | if cached_size == map_resize and job_data.get("volumes_downsampled", []):
172 | logger.debug(f"Using cached volumes downsampled to {map_resize}.")
173 | return job_data["volumes_downsampled"]
174 |
175 | # Otherwise, downsample now
176 | new_downsampled = []
177 | for vol in job_data["volumes_raw"]:
178 | dvol = downsample_volume(vol, map_resize)
179 | new_downsampled.append(dvol)
180 |
181 | job_data["volumes_downsampled"] = new_downsampled
182 | job_data["cached_map_resize"] = map_resize
183 | return new_downsampled
184 |
185 |
def plot_combined_classes(volumes, class_dist, session_key="plot_combined_classes"):
    """
    Display 3D volumes as iso-surfaces in a grid of Plotly 3D subplots.

    User controls select the iso-surface threshold, the map size, the number
    of grid columns and the per-row height. Iso-surfaces come from
    ``plot_volume`` and are cached in ``st.session_state["iso_surface_cache"]``
    under ``(session_key, class index, threshold, map size)``. Passing a
    job-specific ``session_key`` therefore keeps the caches of different jobs
    apart. Every subplot scene, and the figure paper, uses a black background.

    The "Map size" slider is keyed as ``"plot_combined_classes_map_resize"``
    so its value is available in ``st.session_state`` to code that downsamples
    the volumes before calling this function.

    Args:
        volumes: Sequence of 3D volume arrays, one per class.
        class_dist: Per-class distribution values, multiplied by 100 for the
            label. Classes without an entry are labelled 0.0%.
        session_key: Prefix for cache keys. It is also initialised as an
            entry in ``st.session_state``.
    """
    logger.debug(f"{datetime.now()}: plot_combined_classes called with {len(volumes)} volumes.")

    # Initialize session state cache for this session_key if needed.
    if session_key not in st.session_state:
        st.session_state[session_key] = {}

    # Retrieve local iso-surface cache.
    iso_cache = st.session_state.get("iso_surface_cache", {})

    num_classes = len(volumes)
    if num_classes == 0:
        # Nothing to draw; this also avoids a column-slider default below its
        # minimum of 1 and an empty 0x0 subplot grid.
        st.info("No class volumes available to display.")
        return

    # Layout defaults: few classes get wide, tall plots; many get a 5-column grid.
    if num_classes == 1:
        columns_to_show, plot_height = 1, 700
    elif num_classes < 5:
        columns_to_show, plot_height = num_classes, 600
    else:
        columns_to_show, plot_height = 5, 300

    # Basic UI controls.
    c1, c2 = st.columns(2)
    with c1:
        threshold = st.slider("Select Volume Threshold (Fraction)", 0.0, 1.0, 0.5, 0.01)
        map_resize = st.slider(
            "Map size (px)", 64, 256, 150, 2, key="plot_combined_classes_map_resize"
        )
    with c2:
        n_columns = st.slider("Number of columns", min_value=1, max_value=5, value=columns_to_show, step=1)
        row_height = st.slider("Plot height", min_value=100, max_value=1000, value=plot_height, step=100)

    show_class = st.checkbox("Show class volumes?", value=True)
    if not show_class:
        st.info("Class volumes are hidden. Check the box to show them.")
        return

    cols = min(n_columns, num_classes)
    rows = math.ceil(num_classes / cols)

    # Create subplots with type 'scene' for 3D.
    fig = make_subplots(
        rows=rows,
        cols=cols,
        specs=[[{"type": "scene"} for _ in range(cols)] for _ in range(rows)],
    )

    annotations = []
    for idx, volume in enumerate(volumes):
        row_idx, col_idx = divmod(idx, cols)
        # Cache key from session_key, volume index, threshold, and map_resize.
        cache_key = (session_key, idx, threshold, map_resize)
        if cache_key in iso_cache:
            fig_ = iso_cache[cache_key]
            logger.debug(f"Reusing cached iso-surface for volume {idx+1}.")
        else:
            with st.spinner(f"Creating iso-surface for class {idx+1}..."):
                fig_ = plot_volume(volume, threshold, max_size=map_resize)
            iso_cache[cache_key] = fig_

        if fig_ and len(fig_.data) > 0:
            fig.add_trace(fig_.data[0], row=row_idx + 1, col=col_idx + 1)

        dist_percent = 0.0
        if idx < len(class_dist):
            dist_percent = round(float(class_dist[idx]) * 100, 2)
        # Plotly text uses <br> for line breaks.
        ann_text = f"Class {idx+1}<br>Dist: {dist_percent}%"
        annotations.append(
            dict(
                x=(col_idx + 0.5) / cols,
                y=1 - (row_idx / rows) - 0.05,
                xref="paper",
                yref="paper",
                text=ann_text,
                showarrow=False,
                xanchor="center",
                yanchor="bottom",
                font=dict(size=12),
            )
        )

    # Set overall layout with black background.
    fig.update_layout(
        hovermode=False,
        annotations=annotations,
        height=rows * row_height,
        margin=dict(l=0, r=0, t=0, b=0),
        paper_bgcolor="black",
        scene_dragmode="orbit",
        title="3D Volume Isosurfaces",
    )

    # Style every scene in the grid (including unused trailing cells) so no
    # stray axes or default backgrounds appear. Scenes are named "scene",
    # "scene2", ... in subplot order.
    for n in range(rows * cols):
        scene_id = f"scene{n+1}" if n > 0 else "scene"
        fig.update_layout(
            **{
                scene_id: dict(
                    xaxis=dict(visible=False),
                    yaxis=dict(visible=False),
                    zaxis=dict(visible=False),
                    camera=dict(eye=dict(x=2.5, y=2.5, z=2.5)),
                    bgcolor="black",
                )
            }
        )

    # Store the updated cache back into session state.
    st.session_state["iso_surface_cache"] = iso_cache

    st.plotly_chart(fig, use_container_width=True)
304 |
305 |
def plot_class3d(rln_folder: str, nodes: List[str]) -> None:
    """
    Render the results of a 3D classification, refinement or reconstruction job.

    Job data are loaded once per job and cached in ``st.session_state``.
    Volumes are downsampled to ``st.session_state["plot_combined_classes_map_resize"]``
    if that key is set, otherwise to 150 px.

    Steps:
        1) Reset the iso-surface cache when a different job is selected.
        2) Load (or reuse) job data and downsample volumes as needed.
        3) Show combined iso-surfaces, projections, class distribution over
           iterations, FSC / class resolution and the angular distribution.

    Args:
        rln_folder: Base directory of the RELION project.
        nodes: Node file paths of the job; the job folder is the directory
            part of ``nodes[0]``.
    """
    logger.debug(f"{datetime.now()}: plot_class3d called with nodes={nodes}")
    if not nodes:
        st.warning("No files provided in 'nodes'.")
        return

    job_name = os.path.dirname(nodes[0])  # Basic extraction of job folder name
    if not job_name:
        st.warning("Could not determine job name from nodes.")
        return

    # A different job invalidates cached iso-surfaces.
    if st.session_state.get("current_class3d_job") != job_name:
        st.session_state["current_class3d_job"] = job_name
        st.session_state["iso_surface_cache"] = {}

    job_type = next(
        (name for name in ("Class3D", "Refine3D", "Reconstruct") if name in job_name),
        "3D Classification",
    )
    st.subheader(f"{job_type} job: {job_name}")

    # 1) Load or retrieve data from session state
    job_data = _load_and_cache_job_data(rln_folder, job_name, nodes)
    if not job_data.get("loaded", False):
        st.warning("Could not load the job data properly.")
        return

    # 2) Only re-downsample volumes when the requested size changed
    map_resize_request = st.session_state.get("plot_combined_classes_map_resize", 150)
    with st.spinner("Downsampling Volumes", show_time=True):
        volumes_downsampled = _ensure_downsampled(job_data, map_resize_request)

    n_classes = job_data["n_classes"]
    class_dist = job_data["class_dist"]
    class_res = job_data["class_res"]

    # Ensure FSC data is NumPy arrays
    fsc_res = job_data["fsc_res"]
    if isinstance(fsc_res, (list, tuple)):
        fsc_res = np.array(fsc_res, dtype=float)
    fsc_vals = job_data["fsc_vals"]
    if isinstance(fsc_vals, (list, tuple)):
        fsc_vals = np.array(fsc_vals, dtype=float)

    star_block = job_data["star_block"]

    # A 2-D distribution is (class, iteration); keep the last iteration.
    if isinstance(class_dist, np.ndarray) and class_dist.ndim == 2:
        class_dist_final = class_dist[:, -1]
    else:
        class_dist_final = class_dist

    if n_classes == 0 and len(volumes_downsampled) == 0:
        st.warning("No 3D class references found.")
        return

    st.write(f"Found {n_classes} classes in final iteration.")
    if len(volumes_downsampled) < n_classes:
        logger.warning(
            f"Only {len(volumes_downsampled)} of {n_classes} class volumes are available."
        )

    # 3) Show 3D volumes (job-specific key keeps iso-surface caches apart)
    with st.expander("Combined 3D Volumes", expanded=True):
        job_session_key = f"plot_combined_classes_{job_name}"
        with st.spinner("plotting combined classes..."):
            plot_combined_classes(volumes_downsampled, class_dist_final, session_key=job_session_key)

    # 4) Volume projections
    with st.expander("Volume Projections", expanded=True):
        plot_projections(volumes_downsampled, class_dist_final)

    # 5) Class distribution
    if n_classes > 1 and isinstance(class_dist, np.ndarray) and class_dist.ndim == 2:
        with st.expander("Class Distribution Over Iterations", expanded=True):
            plot_class_distribution(class_dist)

    # 6) FSC for "Refine3D" jobs, shown in a two-column row
    class_res_container = st
    if "Refine3D" in job_name and fsc_res.size > 0 and fsc_vals.size > 0:
        col1, col2 = st.columns(2)
        with col1.expander("Fourier Shell Correlation (FSC) Stats", expanded=True):
            plot_fsc_stats(fsc_res, fsc_vals)
        # Class resolution goes next to the FSC plot only when that row exists;
        # otherwise it falls back to a full-width expander.
        class_res_container = col2

    # 7) Class resolution
    if isinstance(class_res, np.ndarray) and class_res.size > 0:
        with class_res_container.expander("Class Resolution", expanded=True):
            plot_class_resolution(class_res)

    # 8) Angular distribution
    try:
        if n_classes > 1 and star_block is not None and "_rlnClassNumber" in star_block.columns:
            cls_idx = star_block["_rlnClassNumber"]
        else:
            cls_idx = None

        # Normalise lists/Series to float arrays so emptiness can be checked
        # uniformly (a 0.0 angle is valid data, not "missing").
        rot, tilt, psi = (np.asarray(values, dtype=float) for values in job_data["angles"])
        logger.debug(f'rot: {rot}, tilt: {tilt}, psi: {psi}')

        if rot.size == 0 or tilt.size == 0 or psi.size == 0:
            logger.info("No angles found for angular distribution.")
        else:
            with st.expander("Angular Distribution", expanded=True):
                plot_type = st.radio("Select plot type:", ("3D", "2D"), index=0, horizontal=True)
                if plot_type == "3D":
                    plot_angular_distribution_sphere(psi, rot, tilt, cls_idx, symmetry=job_data["symmetry"])
                else:
                    plot_angular_distribution_heatmap(psi, rot, tilt, cls_idx, symmetry=job_data["symmetry"])

            # Show star file scatter
            if st.checkbox("Show star file interactive scatter?"):
                interactive_scatter_plot(os.path.join(rln_folder, nodes[0]))

    except Exception as e:
        st.warning("No angle or particle data found for angular distribution.")
        logger.debug(f"Angular distribution error: {report_error(e)}")

    logger.info(f"{datetime.now()}: plot_class3d done.")
--------------------------------------------------------------------------------
/relion_jobs/motioncorr_job.py:
--------------------------------------------------------------------------------
1 | # motioncorr_job.py
2 |
3 | import logging
4 | import os
5 | from typing import List, Optional
6 |
7 | # Third-Party Imports
8 | import altair as alt # Restore Altair import
9 | import numpy as np
10 | import pandas as pd
11 | import plotly.graph_objects as go
12 | import streamlit as st
13 |
14 | # Local Imports
15 | from lib.image_utils import micrograph_viewer
16 | from lib.utils import (
17 | get_first_key,
18 | get_values_from_first_key,
19 | interactive_scatter_plot,
20 | parse_star,
21 | report_error,
22 | )
23 |
24 | # =============================================================================
25 | # Logger Setup
26 | # =============================================================================
27 | logger = logging.getLogger("main_app") # Use logger from main app
28 |
29 |
30 | def show_motion(rln_folder: str, mic_paths: List[str]) -> None:
31 | """
32 | Plots global and local motion shifts for a selected micrograph using Plotly.
33 |
34 | Args:
35 | rln_folder: Base folder containing the project data.
36 | mic_paths: List of relative paths to micrograph metadata STAR files.
37 | """
38 | if not mic_paths:
39 | st.info("No micrograph metadata files provided to display motion.")
40 | return
41 |
42 | try:
43 | # --- Micrograph Selection ---
44 | col1, col2 = st.columns([1, 3])
45 | with col1:
46 | if len(mic_paths) == 1:
47 | idx = 0
48 | st.caption(f"Displaying motion for: {os.path.basename(mic_paths[0])}")
49 | else:
50 | idx = st.slider("Select Micrograph Index:", 0, len(mic_paths) - 1, 0, key="motioncorr_mic_slider")
51 | mic_metadata_rel_path = mic_paths[idx]
52 | st.caption(f"File: {os.path.basename(mic_metadata_rel_path)}")
53 |
54 | mic_metadata_abs_path = os.path.join(rln_folder, mic_metadata_rel_path)
55 | logger.debug(f"Processing motion from file: {mic_metadata_abs_path}")
56 |
57 | if not os.path.exists(mic_metadata_abs_path):
58 | st.warning(f"Metadata file not found: {mic_metadata_rel_path}")
59 | return
60 |
61 | # --- Parse STAR File ---
62 | motion_star = parse_star(mic_metadata_abs_path)
63 | if not motion_star:
64 | st.warning(f"Could not parse or empty motion metadata file: {mic_metadata_rel_path}")
65 | return
66 |
67 | # --- Extract Motion Data ---
68 | global_shift_df = motion_star.get("global_shift")
69 | local_shift_df = motion_star.get("local_shift")
70 |
71 | if global_shift_df is None and local_shift_df is None:
72 | st.info("No global or local motion data found in this metadata file.")
73 | return
74 |
75 | fig = go.Figure()
76 | traces_added = False
77 |
78 | # --- Process and Plot Local Motion ---
79 | if isinstance(local_shift_df, pd.DataFrame) and not local_shift_df.empty:
80 | try:
81 | local_shift = local_shift_df.astype(float)
82 | coord_x_col, coord_y_col = "_rlnCoordinateX", "_rlnCoordinateY"
83 | shift_x_col, shift_y_col = "_rlnMicrographShiftX", "_rlnMicrographShiftY"
84 | required_local_cols = [coord_x_col, coord_y_col, shift_x_col, shift_y_col]
85 |
86 | if all(col in local_shift.columns for col in required_local_cols):
87 | fold_local_motion = 200 #col1.slider(
88 | # "Local Motion Scale:", 1, 500, 100, 10, key="local_motion_scale",
89 | #help="Magnifies local shifts for visibility.")
90 |
91 | local_shift["X_start"] = local_shift[coord_x_col]
92 | local_shift["Y_start"] = local_shift[coord_y_col]
93 | local_shift["X_end"] = local_shift[coord_x_col] + local_shift[shift_x_col] * fold_local_motion
94 | local_shift["Y_end"] = local_shift[coord_y_col] + local_shift[shift_y_col] * fold_local_motion
95 |
96 | grouped = local_shift.groupby([coord_x_col, coord_y_col])
97 | logger.debug(f"Plotting {len(grouped)} local motion tracks.")
98 | all_x_local, all_y_local = [], []
99 | for _, group_df in grouped:
100 | all_x_local.extend([group_df['X_start'].iloc[0], group_df['X_end'].iloc[-1], None])
101 | all_y_local.extend([group_df['Y_start'].iloc[0], group_df['Y_end'].iloc[-1], None])
102 |
103 | if all_x_local:
104 | fig.add_trace(go.Scatter(
105 | x=all_x_local, y=all_y_local, mode="lines", name="Local Shifts",
106 | line=dict(color="#039d83", width=1), hoverinfo='none'
107 | ))
108 | traces_added = True
109 | else: logger.warning(f"Missing required local motion columns in {mic_metadata_rel_path}")
110 | except Exception as exc:
111 | report_error(exc, f"Error processing local shift data for {mic_metadata_rel_path}")
112 | st.warning(f"Could not process local motion data: {exc}")
113 |
114 | # --- Process and Plot Global Motion ---
115 | if isinstance(global_shift_df, pd.DataFrame) and not global_shift_df.empty:
116 | try:
117 | global_shift = global_shift_df.astype(float)
118 | shift_x_col, shift_y_col = "_rlnMicrographShiftX", "_rlnMicrographShiftY"
119 | if all(col in global_shift.columns for col in [shift_x_col, shift_y_col]):
120 | center_x, center_y = 0.0, 0.0
121 | if isinstance(local_shift_df, pd.DataFrame) and not local_shift_df.empty and \
122 | all(col in local_shift_df.columns for col in ["_rlnCoordinateX", "_rlnCoordinateY"]):
123 | try:
124 | center_x = local_shift_df["_rlnCoordinateX"].astype(float).mean()
125 | center_y = local_shift_df["_rlnCoordinateY"].astype(float).mean()
126 | except Exception: pass # Ignore errors, default to 0,0
127 |
128 | if isinstance(local_shift_df, pd.DataFrame) and not local_shift_df.empty and \
129 | all(col in local_shift_df.columns for col in ["_rlnCoordinateX", "_rlnCoordinateY", "_rlnMicrographShiftX", "_rlnMicrographShiftY"]):
130 | try:
131 | max_coord_x = local_shift_df["_rlnCoordinateX"].astype(float).max()
132 | max_coord_y = local_shift_df["_rlnCoordinateY"].astype(float).max()
133 | max_gshift_x = global_shift[shift_x_col].abs().max()
134 | max_gshift_y = global_shift[shift_y_col].abs().max()
135 | scale_x = (max_coord_x / (2 * max_gshift_x)) if max_gshift_x > 1e-6 else 100
136 | scale_y = (max_coord_y / (2 * max_gshift_y)) if max_gshift_y > 1e-6 else 100
137 | fold_global_motion = max(1, int(min(scale_x, scale_y, 500)))
138 | except Exception as e:
139 | logger.warning(f"Could not calculate global fold factor: {e}. Using default.")
140 | fold_global_motion = 100 # Default scale if local data missing/problematic
141 | else:
142 | fold_global_motion = st.slider(
143 | "Global Motion Scale:", 1, 1000, 100, 10, key="global_motion_scale",
144 | help="Magnifies global shifts for visibility when local data is absent.")
145 |
146 | global_x = global_shift[shift_x_col] * fold_global_motion + center_x
147 | global_y = global_shift[shift_y_col] * fold_global_motion + center_y
148 |
149 | fig.add_trace(go.Scatter(
150 | x=global_x, y=global_y, mode="lines+markers", name="Global Shift",
151 | line=dict(color="#007acc", width=2), marker=dict(size=5),
152 | hovertemplate="Frame: %{customdata}
X: %{x:.1f}
Y: %{y:.1f}",
153 | customdata=np.arange(len(global_x))
154 | ))
155 | traces_added = True
156 | else: logger.warning(f"Missing required global motion columns in {mic_metadata_rel_path}")
157 | except Exception as exc:
158 | report_error(exc, f"Error processing global shift data for {mic_metadata_rel_path}")
159 | st.warning(f"Could not process global motion data: {exc}")
160 |
161 | # --- Display Combined Plot ---
162 | if traces_added:
163 | fig.update_layout(
164 | title=f"Motion Track: {os.path.basename(mic_metadata_rel_path)}",
165 | xaxis_title="X Coordinate (px)", yaxis_title="Y Coordinate (px)",
166 | width=500, height=500, showlegend=True, hovermode="closest",
167 | legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
168 | xaxis=dict(scaleanchor="y", scaleratio=1),
169 | margin=dict(l=10, r=10, t=50, b=10)
170 | )
171 | with col2:
172 | st.plotly_chart(fig, use_container_width=True)
173 | else:
174 | with col2:
175 | st.info("No plottable motion data found.")
176 |
177 | except FileNotFoundError:
178 | st.warning(f"Motion metadata file not found: {mic_metadata_rel_path}")
179 | except Exception as exc:
180 | report_error(exc, f"Error in show_motion for {mic_metadata_rel_path}")
181 | st.warning(f"An error occurred while displaying motion: {exc}")
182 |
183 |
184 | # =============================================================================
185 | # Main Job Function
186 | # =============================================================================
187 |
def _load_tomo_tilt_series_data(rln_folder: str, global_block) -> Optional[pd.DataFrame]:
    """Concatenate the per-tilt-series tables referenced by a RELION 5 tomo 'global' block.

    Each ``_rlnTomoTiltSeriesStarFile`` entry is resolved relative to
    ``rln_folder`` and the first data block of that STAR file is collected.
    Entries that are blank/non-string, missing on disk, fail to parse, or hold
    no DataFrame are counted and skipped, so one bad tilt series does not hide
    all the others.

    Args:
        rln_folder: Base directory of the RELION project.
        global_block: The parsed ``global`` block of the MotionCorr STAR file.

    Returns:
        The concatenated DataFrame, or ``None`` (after a UI warning) when nothing
        usable could be loaded.
    """
    try:
        tomo_star_files = global_block["_rlnTomoTiltSeriesStarFile"]
        tomo_frames = []
        failed_loads = 0
        for tomo_star in tomo_star_files:
            # NaN / empty cells would make os.path.join raise; treat them as failures.
            if not isinstance(tomo_star, str) or not tomo_star:
                failed_loads += 1
                logger.warning(f"Invalid tilt series STAR entry: {tomo_star!r}")
                continue
            tomo_star_path = os.path.join(rln_folder, tomo_star)
            if not os.path.exists(tomo_star_path):
                failed_loads += 1
                logger.warning(f"Tilt series STAR not found: {tomo_star_path}")
                continue
            try:
                first_block_data = get_values_from_first_key(parse_star(tomo_star_path))
            except Exception as exc:
                # One unreadable file must not abort the whole job.
                failed_loads += 1
                logger.warning(f"Failed to parse tilt series STAR {tomo_star_path}: {exc}")
                continue
            if isinstance(first_block_data, pd.DataFrame):
                tomo_frames.append(first_block_data)
            else:
                failed_loads += 1
                logger.warning(f"No DataFrame in {tomo_star}")

        if failed_loads > 0:
            st.warning(f"Failed to load data from {failed_loads} tilt series STAR(s).")
        if not tomo_frames:
            st.warning("Could not load motion data from linked tilt series STAR files.")
            return None
        star_data = pd.concat(tomo_frames, ignore_index=True)
        logger.info(f"Concatenated data from {len(tomo_frames)} tilt series.")
        return star_data
    except KeyError as e:
        report_error(e, "Missing key in Relion 5 Tomo MotionCorr")
        st.warning(f"Error: Missing key {e}.")
        return None
    except Exception as e:
        report_error(e, "Error processing Relion 5 Tomo MotionCorr")
        st.warning(f"Error processing Tomo data: {e}.")
        return None


def _load_motioncorr_star_data(rln_folder: str, star_path: str, node: str) -> Optional[pd.DataFrame]:
    """Parse the MotionCorr STAR file and select the table to visualize.

    Preference order: the ``micrographs`` block, then the tilt series referenced
    by a ``global`` block (RELION 5 tomo), then the first block as a fallback.
    Every failure path shows its own UI warning and returns ``None`` so the
    caller can simply stop.

    Args:
        rln_folder: Base directory of the RELION project.
        star_path: Full path to the MotionCorr STAR file.
        node: Node path relative to ``rln_folder`` (used in messages).

    Returns:
        A non-empty DataFrame, or ``None``.
    """
    try:
        star = parse_star(star_path)
    except Exception as exc:
        report_error(exc, f"Failed to parse MotionCorr STAR file: {node}")
        st.warning(f"Error parsing {node}. See logs.")
        return None
    if not star:
        st.warning(f"Could not parse or empty STAR file: {node}")
        return None

    if "micrographs" in star:
        star_data = star["micrographs"]
        logger.debug("Using 'micrographs' block.")
    elif "global" in star:  # Relion 5 Tomo case
        logger.debug("Found 'global' block (likely Relion 5 Tomo MotionCorr).")
        star_data = _load_tomo_tilt_series_data(rln_folder, star["global"])
        if star_data is None:
            return None  # warnings already shown by the helper
    else:
        first_key = get_first_key(star)
        if not (first_key and isinstance(star[first_key], pd.DataFrame)):
            st.warning("No 'micrographs' or 'global' block found.")
            return None
        star_data = star[first_key]
        logger.warning(f"Using first block '{first_key}' as fallback data source.")

    # A 'micrographs' block that is not a DataFrame cannot be plotted either.
    if not isinstance(star_data, pd.DataFrame) or star_data.empty:
        st.warning("No valid micrograph data found for MotionCorr.")
        return None
    return star_data


def _plot_motion_statistics(star_data: pd.DataFrame) -> None:
    """Render Altair line and histogram charts of the accumulated-motion columns.

    The line plot clips outliers at the 99.5th percentile of
    ``_rlnAccumMotionTotal``, or at each series' own 99.5th percentile when
    Total is unavailable. The histogram clips at a fixed 100 Å so all series
    share a comparable range. Errors are reported and shown as a warning.
    """
    st.subheader("Motion Statistics")
    meta_names = ["_rlnAccumMotionTotal", "_rlnAccumMotionEarly", "_rlnAccumMotionLate"]

    try:
        # Collect the numeric, non-NaN values of every available motion column.
        numeric_columns = {}
        for meta in meta_names:
            if meta not in star_data.columns:
                logger.warning(f"Column '{meta}' not found in STAR data.")
                continue
            values = pd.to_numeric(star_data[meta], errors="coerce").dropna()
            if values.empty:
                logger.warning(f"Column '{meta}' is empty or has no numeric data.")
                continue
            numeric_columns[meta] = values.to_numpy()

        if not numeric_columns:
            st.caption("No motion statistics columns found to plot.")
            return

        total = numeric_columns.get("_rlnAccumMotionTotal")
        percentile_99 = np.percentile(total, 99.5) if total is not None else None

        line_frames, hist_frames = [], []
        for meta, values in numeric_columns.items():
            clip_value = percentile_99 if percentile_99 is not None else np.percentile(values, 99.5)
            series_name = meta.replace("_rln", "").replace("AccumMotion", "")  # e.g. "Total"
            line_frames.append(pd.DataFrame({
                "Index": np.arange(len(values)),
                "Motion": np.clip(values, None, clip_value),
                "Series": series_name,
            }))
            hist_frames.append(pd.DataFrame({
                "Motion": np.clip(values, None, 100),
                "Series": series_name,
            }))
        # Concatenate once instead of growing a DataFrame inside the loop.
        df_line = pd.concat(line_frames, ignore_index=True)
        df_hist = pd.concat(hist_frames, ignore_index=True)

        if percentile_99 is not None:
            clip_info = f" (line plot clipped at {percentile_99:.2f} Å)"
        else:
            clip_info = " (line plot clipped at 99.5th percentile per series)"

        line_chart = alt.Chart(df_line).mark_line(point=False, opacity=0.8).encode(  # point=False for large datasets
            x=alt.X("Index:Q", title="Micrograph Index"),
            y=alt.Y("Motion:Q", title="Accumulated Motion (Å)"),
            color=alt.Color("Series:N", title="Motion Type"),
            tooltip=["Index", "Motion", "Series"]
        ).properties(
            title=f"Motion Statistics{clip_info}"
        ).interactive()

        hist_chart = alt.Chart(df_hist).mark_bar(opacity=0.6, binSpacing=0).encode(
            x=alt.X("Motion:Q", bin=alt.Bin(maxbins=50), title="Accumulated Motion (Å, clipped at 100)"),
            y=alt.Y("count()", title="Frequency", stack=None),  # stack=None for overlaid bars
            color=alt.Color("Series:N", title="Motion Type"),
            tooltip=[alt.Tooltip("Motion:Q", bin=True), "count()", "Series:N"]
        ).properties(
            title="Motion Histograms"
        ).interactive()

        col1, col2 = st.columns(2)
        with col1:
            st.altair_chart(line_chart, use_container_width=True)
        with col2:
            st.altair_chart(hist_chart, use_container_width=True)

    except Exception as exc:
        report_error(exc, "Error generating MotionCorr statistics plots.")
        st.warning(f"Could not generate motion statistics plots: {exc}")


def plot_motioncorr(rln_folder: str, node: str) -> None:
    """
    Processes and plots MotionCorr job results, including Altair statistics plots.

    Sections rendered, in order: motion statistics charts, optional
    per-micrograph motion tracks, the corrected micrograph viewer, and an
    optional interactive metadata scatter plot.

    Args:
        rln_folder: Base directory of the RELION project.
        node: Specific node file for this job (e.g., "MotionCorr/.../corrected_micrographs.star").
    """
    star_path = os.path.join(rln_folder, node)
    logger.info(f"Processing MotionCorr job: {star_path}")

    if not os.path.exists(star_path):
        st.warning(f"MotionCorr STAR file not found: {node}")
        logger.error(f"File not found: {star_path}")
        return

    star_data = _load_motioncorr_star_data(rln_folder, star_path, node)
    if star_data is None:
        return  # the loader already displayed the reason

    # --- Motion Statistics Plotting (Altair) ---
    _plot_motion_statistics(star_data)

    st.divider()

    # --- Motion Tracks ---
    metadata_col = "_rlnMicrographMetadata"
    if metadata_col in star_data.columns:
        # Shown by default when per-micrograph metadata is available.
        if st.checkbox("Show Motion Tracks per Micrograph?", value=True, key="motioncorr_show_tracks"):
            st.subheader("Motion Tracks")
            valid_motion_files = [
                f for f in star_data[metadata_col].dropna().tolist() if isinstance(f, str) and f
            ]
            if valid_motion_files:
                show_motion(rln_folder, valid_motion_files)
            else:
                st.info("No valid micrograph metadata files found.")
    else:
        st.caption(f"'{metadata_col}' column not found - cannot display individual motion tracks.")

    st.divider()

    # --- Corrected Micrographs ---
    st.subheader("Corrected Micrographs")
    mics_col = "_rlnMicrographName"
    if mics_col in star_data.columns:
        valid_mics = [m for m in star_data[mics_col].dropna().tolist() if isinstance(m, str) and m]
        if valid_mics:
            try:
                micrograph_viewer(rln_folder, valid_mics)
            except Exception as exc:
                report_error(exc, "Error calling micrograph_viewer for corrected micrographs")
                st.warning(f"Error displaying corrected micrographs: {exc}")
        else:
            st.info("No valid corrected micrograph paths found.")
    else:
        st.caption(f"'{mics_col}' column not found - cannot display corrected micrographs.")

    st.divider()

    # --- Optional interactive metadata plot ---
    if st.checkbox("Plot Metadata Interactively?", key="motioncorr_plot_meta"):
        st.subheader("Metadata Distribution")
        try:
            interactive_scatter_plot(data_source=star_path, title_prefix="MotionCorr Metadata")
        except Exception as exc:
            report_error(exc, "Error calling interactive_scatter_plot for MotionCorr metadata")
            st.warning(f"Error plotting metadata: {exc}")

    logger.info(f"Finished processing MotionCorr job: {node}")
--------------------------------------------------------------------------------
/follow_relion_gracefully.py:
--------------------------------------------------------------------------------
1 | #
2 | # ____ ___ ___ ____ ___ ____ ___ ___ ___
3 | # /\ __\ /\_ \ /\_ \ /\ _`\ /\_ \ __ /\ _`\ /'___\ /\_ \ /\_ \
4 | # \ \ \_ __\//\ \ \//\ \ ___ __ __ __ \ \ \L\ \ __\//\ \ /\_\ ___ ___ \ \ \L\_\ _ __ __ ___ __ /\ \__/ __ __\//\ \ \//\ \ __ __
5 | # \ \ _\/ __`\\ \ \ \ \ \ / __`\/\ \/\ \/\ \ \ \ , / /'__`\\ \ \ \/\ \ / __`\ /' _ `\ \ \ \L_L /\`'__\/'__`\ /'___\ /'__`\ \ ,__\/\ \/\ \ \ \ \ \ \ \ /\ \/\ \
6 | # \ \ \/\ \L\ \\_\ \_ \_\ \_/\ \L\ \ \ \_/ \_/ \ \ \ \\ \ /\ __/ \_\ \_\ \ \/\ \L\ \/\ \/\ \ \ \ \/, \ \ \//\ \L\.\_/\ \__//\ __/\ \ \_/\ \ \_\ \ \_\ \_ \_\ \_\ \ \_\ \
7 | # \ \_\ \____//\____\/\____\ \____/\ \___x___/' \ \_\ \_\ \____\/\____\\ \_\ \____/\ \_\ \_\ \ \____/\ \_\\ \__/.\_\ \____\ \____\\ \_\ \ \____/ /\____\/\____\\/`____ \
8 | # \/_/\/___/ \/____/\/____/\/___/ \/__//__/ \/_/\/ /\/____/\/____/ \/_/\/___/ \/_/\/_/ \/___/ \/_/ \/__/\/_/\/____/\/____/ \/_/ \/___/ \/____/\/____/ `/___/> \
9 | # \\___/
10 | # Follow Relion Gracefully (v6)
11 | # Developed by Dawid Zyla, La Jolla Institute for Immunology
12 | # Non-Profit Open Software License 3.0
13 |
14 | # update v6 (2025-04-26)
15 |
16 | # ## Main Changes
17 | # -> Better integration with Streamlit platform (https://streamlit.io/)
18 | # -> Added support for all (most) Relion jobs covering all cryo-ET and SPA jobs (except for DynaMight)
19 | # -> Temporarily removed live and in-browser job execution
20 | # -> General QOL fixes and improvements
21 | #
22 | # ## New Features
23 | # -> Support for all cryo-ET jobs with job previews
24 | # -> Optimized visualizations for most of the Relion jobs
25 | # -> Divided the code into smaller modules for better readability and maintainability
26 | # -> Overhauled visualization of Local Resolution, picking, micrograph previews, and other jobs
27 | # -> Increased performance and reduced loading times (in most cases)
28 | # -> Added better logging and error handling
29 | #
30 | #
31 | # ## To Do
32 | # -> Add own DynaMight job preview (currently not supported)
33 | # -> Add support for cryoSPARC cs files and export to Relion (most likely via pyem)
34 | # -> Further speed optimization and code cleanup
35 |
36 | # Standard Library Imports
37 | import argparse
38 | import logging
39 | import os
40 | import re
41 | from typing import Callable, Dict, List, Optional, Tuple
42 |
43 | # Third-Party Imports
44 | import pandas as pd
45 | import streamlit as st
46 |
47 | # Local Imports
48 | from lib.jobs_utils import ( # Assuming these are correctly defined in jobs_utils
49 | create_network,
50 | display_job_info,
51 | format_display_name,
52 | )
53 | from lib.utils import ( # Assuming these are correctly defined in utils
54 | check_password,
55 | custom_css,
56 | dynamic_folder_explorer,
57 | get_footer,
58 | # get_newest_change, # Not directly used here, used within display_job_info
59 | # get_note, # Not directly used here, used within display_job_info
60 | # get_relationships_df, # Not directly used here, used within display_job_info
61 | interactive_scatter_plot,
62 | parse_star,
63 | render_svg,
64 | report_error, # Use the central report_error
65 | )
66 |
# =============================================================================
# Constants
# =============================================================================
# Column names in RELION STAR files
# Process type, compared against selections such as "relion.Class2D".
RLN_PROCESS_TYPE_LABEL = "_rlnPipeLineProcessTypeLabel"
# User-given job alias; get_pipeline_df fills missing values with the string "None".
RLN_PROCESS_ALIAS = "_rlnPipeLineProcessAlias"
# Job name that display keys in the sidebar map back to.
RLN_PROCESS_NAME = "_rlnPipeLineProcessName"
# Job status; get_pipeline_df fills missing values with "Unknown".
RLN_STATUS_LABEL = "_rlnPipeLineProcessStatusLabel"

# Keys for data blocks in parsed pipeline STAR dictionary
PIPELINE_PROCESSES_KEY = "pipeline_processes"
PIPELINE_NODES_KEY = "pipeline_nodes"
PIPELINE_EDGES_KEY = "pipeline_input_edges" # Make sure this key matches parser output

# Special process names used internally
# These pseudo-processes are appended after the real process types in the sidebar.
FLOWCHART_PROCESS = "relion.flowchart"
INTERACTIVE_PLOT_PROCESS = "relion.InteractivePlot"

# Session State Keys
STATE_SELECTED_PROCESS = "selected_process"  # underlying (non-display) process type
STATE_CURRENT_JOB = "current_job"  # underlying job name of the selected job
STATE_JOB_PARAMS = "job_params"  # dict with "folder" and "process" of the selected job
STATE_PROCESS_RADIO_KEY = "process_radio_key"  # widget key of the process radio (display name)
STATE_JOB_RADIO_KEY = "job_radio_key"  # widget key of the job radio (display key)
STATE_PASSWORD_ARGS = "password_args"
STATE_DEFAULT_FOLDER = "default_job_folder" # Initial folder path from args/explorer
STATE_CURRENT_FOLDER = "current_folder_path" # Path currently being viewed
STATE_DISPLAY_TO_ORIGINAL_PROCESS = "display_to_original_process_map"  # display name -> process type
STATE_JOBS_DICT = "jobs_dict"  # display key -> job name for the selected process
TEMP_DIR_PATH = "temp"  # default directory created by create_temp_directory()
97 |
# Logger configuration
logging_level = logging.DEBUG

# =============================================================================
# Logger and Global Error Handling
# =============================================================================
# Only attach a handler when none exists yet, so re-executing this module
# (presumably on Streamlit reruns) does not duplicate console output.
logger = logging.getLogger("main_app")
if not logger.handlers:
    formatter = logging.Formatter(
        fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    console_handler.setLevel(logging_level)
    logger.addHandler(console_handler)
    logger.setLevel(logging_level)
    logger.propagate = False  # keep records out of the root logger's handlers

# Module-wide error callback taking (exception, context message); set via set_error_handler().
ERROR_HANDLER: Optional[Callable[[Exception, str], None]] = None
118 |
119 |
def local_report_error_handler(exc: Exception, error_info: str):
    """Local fallback error handler that logs the error together with its exception.

    Args:
        exc: The exception that was raised. Its type, message and traceback
            (when present) are attached to the log record.
        error_info: Human-readable context describing where the error occurred.
    """
    # Pass the exception via exc_info so the actual error and its traceback are
    # logged, not just the context string.
    logger.error("An unexpected error occurred:\n%s", error_info, exc_info=exc)
122 | logger.error("An unexpected error occurred:\n%s", error_info)
123 |
124 |
def set_error_handler(handler: Callable[[Exception, str], None]) -> None:
    """Install *handler* as this module's error callback.

    Args:
        handler: Callable taking ``(exception, context_message)``; it is stored
            in the module-level ``ERROR_HANDLER``, replacing any previous one.
    """
    global ERROR_HANDLER  # rebind the module global rather than creating a local
    ERROR_HANDLER = handler
129 |
130 |
131 | # =============================================================================
132 | # UI Setup
133 | # =============================================================================
def set_style() -> None:
    """Configures Streamlit page settings and applies custom CSS.

    Sets the page title, icon, expanded sidebar and wide layout, then injects
    the HTML/CSS returned by ``custom_css()``. Any exception is logged via
    ``logger.error`` and otherwise ignored, so styling problems never stop the app.
    """
    try:
        st.set_page_config(
            page_title="Follow Relion Gracefully",
            page_icon=":microscope:",
            initial_sidebar_state="expanded",
            layout="wide",
        )
        st.markdown(custom_css(), unsafe_allow_html=True)
        # NOTE(review): this literal is (nearly) empty in the reviewed copy, so
        # the call renders nothing visible; confirm whether extra inline HTML/CSS
        # was intended here.
        st.markdown(
            """
            """,
            unsafe_allow_html=True,
        )
    except Exception as e:
        logger.error(f"Failed to set page style: {e}")
155 |
156 | # =============================================================================
157 | # Utility Functions
158 | # =============================================================================
def parse_help_text(help_text: Optional[str]) -> Dict[str, str]:
    """Parse ``--parameter (description)`` pairs out of a help string.

    Parameter names may contain word characters joined by internal hyphens
    (e.g. ``--do-ctf``), so hyphenated options are no longer truncated at the
    first ``-``. The parenthesised description is optional and must follow the
    name after a single space.

    Args:
        help_text: Free-form help text. ``None`` or ``""`` yields ``{}``.

    Returns:
        Mapping ``"--name" -> description``. The description is stripped, and is
        an empty string when none follows. If a parameter appears more than once,
        the last occurrence wins.
    """
    if not help_text:
        return {}
    # Group 1: name (word runs joined by single hyphens); group 2: optional description.
    pattern = r"--(\w+(?:-\w+)*)(?: \(([^)]+)\))?"
    # re.findall yields "" for a non-participating group, so strip() is always safe.
    return {
        f"--{name}": description.strip()
        for name, description in re.findall(pattern, help_text)
    }
166 |
167 |
def create_temp_directory(temp_dir: str = TEMP_DIR_PATH) -> bool:
    """Ensure *temp_dir* exists as a directory, creating it if needed.

    ``os.makedirs(..., exist_ok=True)`` means that a concurrent creation
    between the check and the call (e.g. by another process or session) is
    not reported as a failure. A path that exists but is *not* a directory is
    reported as an error instead of being silently accepted.

    Args:
        temp_dir: Directory to ensure; defaults to ``TEMP_DIR_PATH``.

    Returns:
        True if the directory exists or was created. False if it could not be
        created; in that case the error is reported and shown in the UI.
    """
    if os.path.isdir(temp_dir):
        return True
    try:
        # Raises FileExistsError (an OSError) if temp_dir exists as a non-directory.
        os.makedirs(temp_dir, exist_ok=True)
    except OSError as e:
        # Use imported report_error
        report_error(e, f"Failed to create temporary directory {temp_dir}")
        st.error(f"Failed to create required directory '{temp_dir}': {e}")
        return False
    logger.info("Created temporary directory: %s", temp_dir)
    return True
181 |
182 |
183 | # =============================================================================
184 | # Caching Heavy Computations
185 | # =============================================================================
@st.cache_data(ttl=3600)  # Cache STAR file parsing for 1 hour
def load_pipeline_star(folder: str) -> Optional[Dict[str, pd.DataFrame]]:
    """Read ``default_pipeline.star`` from *folder* and return its parsed blocks.

    Streamlit caches the return value per *folder* for up to an hour
    (``ttl=3600``); ``main`` clears the data cache when the folder changes.
    NOTE(review): jobs added to the pipeline during that hour may not appear
    until the cache expires or is cleared.

    Args:
        folder: RELION project directory.

    Returns:
        Mapping of block name to DataFrame, or ``None`` when the file is
        missing, parses to an empty result, or parsing raises. On a parse
        error, the error is reported and shown in the UI.
    """
    pipeline_file = os.path.join(folder, "default_pipeline.star")
    if not os.path.exists(pipeline_file):
        logger.warning("File not found: %s", pipeline_file)
        return None

    try:
        logger.info("Loading and parsing STAR file: %s", pipeline_file)
        blocks = parse_star(pipeline_file)
        if blocks:
            logger.info("Successfully parsed STAR file: %s", pipeline_file)
            return blocks
        logger.warning(
            "Parsing STAR file returned empty: %s", pipeline_file
        )
        return None
    except Exception as exc:
        report_error(exc, f"Failed to parse STAR file: {pipeline_file}")
        st.error(
            f"Failed to parse STAR file: {os.path.basename(pipeline_file)}. Error: {exc}"
        )
        return None
209 |
210 |
211 | # =============================================================================
212 | # Data Preparation Functions
213 | # =============================================================================
def get_pipeline_df(pipeline_star: Optional[Dict[str, pd.DataFrame]]) -> pd.DataFrame:
    """Return a normalized copy of the 'pipeline_processes' block.

    Missing process-type, name, status and alias columns are added with
    defaults ("", "", "Unknown", "None") and a warning is logged for each.
    Alias and status values then have NaNs filled and are cast to ``str``.
    The input is never modified.

    Args:
        pipeline_star: Parsed pipeline STAR blocks, or ``None``.

    Returns:
        The prepared DataFrame, or an empty DataFrame when no usable
        process table is present.
    """
    if not pipeline_star:
        return pd.DataFrame()

    processes = pipeline_star.get(PIPELINE_PROCESSES_KEY)
    usable = isinstance(processes, pd.DataFrame) and not processes.empty
    if not usable:
        logger.warning("'%s' missing or empty in STAR data.", PIPELINE_PROCESSES_KEY)
        return pd.DataFrame()

    prepared = processes.copy()
    column_defaults = (
        (RLN_PROCESS_TYPE_LABEL, ""),
        (RLN_PROCESS_NAME, ""),
        (RLN_STATUS_LABEL, "Unknown"),
        (RLN_PROCESS_ALIAS, "None"),
    )
    for column, fallback in column_defaults:
        if column in prepared.columns:
            continue
        logger.warning(
            "Column '%s' missing, adding with default '%s'.", column, fallback
        )
        prepared[column] = fallback

    # Alias first, then status, each with its own placeholder for NaNs.
    for column, fallback in ((RLN_PROCESS_ALIAS, "None"), (RLN_STATUS_LABEL, "Unknown")):
        prepared[column] = prepared[column].fillna(fallback).astype(str)
    return prepared
239 |
240 |
@st.cache_data
def get_jobs_for_process(
    selected_process: str, df: pd.DataFrame
) -> Tuple[List[str], Dict[str, str]]:
    """
    Filters jobs for a selected process type and creates a display key mapping.
    Jobs are returned in the order they appear in the DataFrame.

    The display key is the job alias when one is set (a non-blank string other
    than the placeholder "None"), otherwise the job name. Repeated keys get a
    " (n)" suffix. The suffix is increased until the key is unique, so a job is
    never dropped because its generated key collides with another job's literal
    alias (e.g. an alias that is itself "X (1)"). Rows without a string job name
    are skipped.

    Args:
        selected_process: Original process name (e.g., "relion.Class2D").
        df: Prepared pipeline_processes DataFrame.

    Returns:
        Tuple: (list of display keys in original order, mapping display_key -> job_name).
    """
    if df.empty or RLN_PROCESS_TYPE_LABEL not in df.columns:
        logger.debug(
            "Cannot get jobs: DataFrame empty or missing '%s'.", RLN_PROCESS_TYPE_LABEL
        )
        return [], {}

    jobs_df = df.loc[df[RLN_PROCESS_TYPE_LABEL] == selected_process]
    if jobs_df.empty:
        logger.debug("No jobs found for process type '%s'.", selected_process)
        return [], {}

    jobs_dict: Dict[str, str] = {}
    jobs_display_keys: List[str] = []  # Maintain DataFrame order (no sorting)
    display_key_counts: Dict[str, int] = {}

    for _, row in jobs_df.iterrows():  # Iterate in DataFrame order
        job_name = row.get(RLN_PROCESS_NAME, "")
        # Skip empty names and non-string values such as NaN.
        if not isinstance(job_name, str) or not job_name:
            continue

        alias = row.get(RLN_PROCESS_ALIAS, "None")
        # Non-string aliases (e.g. NaN in an unprepared DataFrame) count as unset.
        has_alias = isinstance(alias, str) and alias != "None" and alias.strip() != ""
        base_display_key = alias if has_alias else job_name

        count = display_key_counts.get(base_display_key, 0)
        display_key = base_display_key if count == 0 else f"{base_display_key} ({count})"
        # Bump the suffix until unique so no job is silently lost on collision.
        while display_key in jobs_dict:
            count += 1
            display_key = f"{base_display_key} ({count})"
        display_key_counts[base_display_key] = count + 1

        jobs_dict[display_key] = job_name
        jobs_display_keys.append(display_key)

    logger.debug(
        "Found %d jobs for process '%s'.", len(jobs_display_keys), selected_process
    )
    return jobs_display_keys, jobs_dict
297 |
298 |
299 | # =============================================================================
300 | # State Management Callbacks
301 | # =============================================================================
def handle_folder_change():
    """Clear process/job selection state after the project folder changes.

    Each tracked selection key that already exists in the session is set to
    ``None``. The jobs map, job parameters and display-name map are then
    (re)initialized to empty dicts, whether or not they existed before.
    """
    logger.info("Resetting process/job state due to folder change.")
    state = st.session_state
    tracked_keys = (
        STATE_SELECTED_PROCESS,
        STATE_CURRENT_JOB,
        STATE_JOB_PARAMS,
        STATE_PROCESS_RADIO_KEY,
        STATE_JOB_RADIO_KEY,
        STATE_JOBS_DICT,
        STATE_DISPLAY_TO_ORIGINAL_PROCESS,
    )
    for key in (k for k in tracked_keys if k in state):
        state[key] = None
    # These three are read as dicts elsewhere, so they must not stay None.
    for dict_key in (STATE_JOBS_DICT, STATE_JOB_PARAMS, STATE_DISPLAY_TO_ORIGINAL_PROCESS):
        state[dict_key] = {}
320 |
321 |
def handle_process_change():
    """Sync the selected process with the process radio and reset job state.

    Used as the process radio's ``on_change`` callback. The radio's display
    name is translated back to the underlying process type. If that type
    differs from the stored selection, the selection is updated and all
    job-level state (current job, job params, job radio, jobs map) is cleared.
    """
    state = st.session_state
    display_name = state.get(STATE_PROCESS_RADIO_KEY)
    process = state.get(STATE_DISPLAY_TO_ORIGINAL_PROCESS, {}).get(display_name)

    # Compare underlying process names, not display labels.
    if process == state.get(STATE_SELECTED_PROCESS):
        return

    logger.info("Process selection changed to: '%s'", process)
    state[STATE_SELECTED_PROCESS] = process
    # Reset only job-related state; folder and display maps are kept.
    for key, fresh_value in (
        (STATE_CURRENT_JOB, None),
        (STATE_JOB_PARAMS, {}),
        (STATE_JOB_RADIO_KEY, None),
        (STATE_JOBS_DICT, {}),
    ):
        state[key] = fresh_value
    logger.debug("Job state reset due to process change.")
338 |
339 |
def handle_job_change():
    """Record the job chosen in the job radio button.

    The radio's display key is resolved to a job name through the stored jobs
    map. When the name differs from the current job, it becomes the current
    job, and the job params are set to the current folder and selected process.
    """
    state = st.session_state
    job_name = state.get(STATE_JOBS_DICT, {}).get(state.get(STATE_JOB_RADIO_KEY))

    # Nothing to do when the underlying job name is unchanged.
    if job_name == state.get(STATE_CURRENT_JOB):
        return

    logger.info("Job selection changed to: '%s'", job_name)
    state[STATE_CURRENT_JOB] = job_name
    state[STATE_JOB_PARAMS] = dict(
        folder=state.get(STATE_CURRENT_FOLDER),
        process=state.get(STATE_SELECTED_PROCESS),
    )
354 |
355 |
356 | # =============================================================================
357 | # Main Application Logic
358 | # =============================================================================
359 | def main() -> None:
360 | """Main function to run the Streamlit application."""
361 | try:
362 | # --- Argument Parsing ---
363 | parser = argparse.ArgumentParser(description="Follow Relion Gracefully Viewer")
364 | parser.add_argument(
365 | "-i",
366 | "--folder",
367 | type=str,
368 | help="Initial Relion project folder",
369 | default="~",
370 | )
371 | # parser.add_argument("--relion-path", type=str, help="Path to RELION installation (bin)", default="/")
372 | parser.add_argument(
373 | "-p",
374 | "--password",
375 | type=str,
376 | help="Password to protect the instance",
377 | default="",
378 | )
379 | args, _ = parser.parse_known_args()
380 |
381 | # --- Session State Initialization ---
382 | # Initialize folder state using command line argument ONLY IF state is empty
383 | if STATE_DEFAULT_FOLDER not in st.session_state:
384 | initial_folder = os.path.abspath(os.path.expanduser(args.folder))
385 | st.session_state[STATE_DEFAULT_FOLDER] = initial_folder
386 | st.session_state[STATE_CURRENT_FOLDER] = (
387 | initial_folder # Also set current initially
388 | )
389 | logger.info(
390 | "Initialized state: Default folder set from arg to %s", initial_folder
391 | )
392 | # Ensure other states have default values if not set previously
393 | st.session_state.setdefault(
394 | STATE_CURRENT_FOLDER, st.session_state[STATE_DEFAULT_FOLDER]
395 | )
396 | st.session_state.setdefault(STATE_PASSWORD_ARGS, args.password)
397 | st.session_state.setdefault(STATE_SELECTED_PROCESS, None)
398 | st.session_state.setdefault(STATE_CURRENT_JOB, None)
399 | st.session_state.setdefault(STATE_JOB_PARAMS, {})
400 | st.session_state.setdefault(STATE_PROCESS_RADIO_KEY, None)
401 | st.session_state.setdefault(STATE_JOB_RADIO_KEY, None)
402 | st.session_state.setdefault(STATE_DISPLAY_TO_ORIGINAL_PROCESS, {})
403 | st.session_state.setdefault(STATE_JOBS_DICT, {})
404 |
405 | # --- Password Check ---
406 | if st.session_state[STATE_PASSWORD_ARGS]:
407 | if not check_password(
408 | st.session_state[STATE_PASSWORD_ARGS]
409 | ): # Pass correct arg
410 | st.warning("Password required.")
411 | st.stop()
412 |
413 | # --- UI Rendering & Folder Selection ---
414 | render_svg("./static/frg.svg")
415 | footer = get_footer()
416 | create_temp_directory()
417 |
418 | # dynamic_folder_explorer updates STATE_DEFAULT_FOLDER internally now
419 | # It uses STATE_DEFAULT_FOLDER as its starting point if uninitialized.
420 | # It returns the confirmed path, which we use to update STATE_CURRENT_FOLDER if needed.
421 | selected_folder = dynamic_folder_explorer(
422 | st.session_state[STATE_DEFAULT_FOLDER]
423 | )
424 |
425 | # Detect if the folder confirmed by the explorer differs from the current viewing folder
426 | if selected_folder != st.session_state.get(STATE_CURRENT_FOLDER):
427 | logger.info(
428 | "Folder changed via explorer: '%s' -> '%s'",
429 | st.session_state.get(STATE_CURRENT_FOLDER),
430 | selected_folder,
431 | )
432 | st.session_state[STATE_CURRENT_FOLDER] = selected_folder
433 | # Crucially, update the DEFAULT folder as well if the user explicitly selected it
434 | st.session_state[STATE_DEFAULT_FOLDER] = selected_folder
435 | handle_folder_change() # Reset dependent states
436 | st.cache_data.clear() # Clear data cache on folder change
437 | st.rerun()
438 |
439 | current_folder = st.session_state[STATE_CURRENT_FOLDER]
440 |
441 | # --- Load Data for Current Folder ---
442 | pipeline_star = load_pipeline_star(current_folder) # Uses cache
443 | df = get_pipeline_df(pipeline_star) # Process potentially cached data
444 |
445 | # --- Sidebar: Process Selection ---
446 | st.sidebar.title("Process Types")
447 | process_types = []
448 | if df.empty:
449 | st.sidebar.warning("No pipeline data found.")
450 | elif RLN_PROCESS_TYPE_LABEL not in df.columns:
451 | st.sidebar.error(f"Missing column: {RLN_PROCESS_TYPE_LABEL}")
452 | else:
453 | # Get unique process types IN THE ORDER THEY APPEAR in the DataFrame
454 | process_types = list(df[RLN_PROCESS_TYPE_LABEL].unique()) # No sorting here
455 |
456 | special_processes = [FLOWCHART_PROCESS, INTERACTIVE_PLOT_PROCESS]
457 | all_processes = (
458 | process_types + special_processes
459 | ) # Add special views at the end
460 |
461 | if not all_processes:
462 | st.sidebar.info("No processes found.")
463 | st.info("Select a valid RELION project folder.")
464 | else:
465 | # Map original names to display names
466 | original_to_display = {p: format_display_name(p) for p in all_processes}
467 | display_to_original = {d: p for p, d in original_to_display.items()}
468 | st.session_state[STATE_DISPLAY_TO_ORIGINAL_PROCESS] = display_to_original
469 | display_options = list(original_to_display.values())
470 |
471 | # Determine current selection index for the radio button
472 | current_display = st.session_state.get(STATE_PROCESS_RADIO_KEY)
473 | if current_display not in display_options:
474 | current_display = display_options[0] if display_options else None
475 | # If defaulting, update the underlying selected process state
476 | if current_display:
477 | st.session_state[STATE_PROCESS_RADIO_KEY] = current_display
478 | st.session_state[STATE_SELECTED_PROCESS] = display_to_original.get(
479 | current_display
480 | )
481 | handle_process_change() # Reset job state implicitly
482 |
483 | current_index = (
484 | display_options.index(current_display)
485 | if current_display in display_options
486 | else 0
487 | )
488 |
489 | # Process Radio Button
490 | st.sidebar.radio(
491 | "Select View:",
492 | options=display_options,
493 | index=current_index,
494 | key=STATE_PROCESS_RADIO_KEY,
495 | on_change=handle_process_change,
496 | )
497 |
498 | # --- Sidebar: Job Selection ---
499 | selected_process = st.session_state.get(STATE_SELECTED_PROCESS)
500 | if selected_process and selected_process not in special_processes:
501 | # Get jobs IN ORIGINAL ORDER
502 | jobs_display_keys, jobs_dict = get_jobs_for_process(
503 | selected_process, df
504 | ) # Uses cache
505 | st.session_state[STATE_JOBS_DICT] = jobs_dict
506 |
507 | if jobs_display_keys:
508 | process_display_name = format_display_name(selected_process)
509 | st.sidebar.title(f"{process_display_name} Jobs")
510 |
511 | current_job_key = st.session_state.get(STATE_JOB_RADIO_KEY)
512 | # Check if current selection is valid, default to LAST job if not
513 | if current_job_key not in jobs_display_keys:
514 | current_job_key = jobs_display_keys[
515 | -1
516 | ] # Default to last job in the list
517 | st.session_state[STATE_JOB_RADIO_KEY] = current_job_key
518 | st.session_state[STATE_CURRENT_JOB] = jobs_dict.get(
519 | current_job_key
520 | )
521 | st.session_state[STATE_JOB_PARAMS] = {
522 | "folder": current_folder,
523 | "process": selected_process,
524 | }
525 |
526 | job_index = jobs_display_keys.index(current_job_key)
527 |
528 | # Job Radio Button
529 | st.sidebar.radio(
530 | "Select Job:",
531 | options=jobs_display_keys,
532 | index=job_index,
533 | key=STATE_JOB_RADIO_KEY,
534 | on_change=handle_job_change,
535 | )
536 | else:
537 | st.sidebar.caption("No jobs of this type found.")
538 | if (
539 | st.session_state.get(STATE_CURRENT_JOB) is not None
540 | ): # Clear state if no jobs
541 | st.session_state[STATE_CURRENT_JOB] = None
542 | st.session_state[STATE_JOB_PARAMS] = {}
543 | st.session_state[STATE_JOB_RADIO_KEY] = None
544 |
545 | # --- Main Area Display ---
546 | st.markdown("---")
547 | if selected_process == FLOWCHART_PROCESS:
548 | st.title("Pipeline Flowchart")
549 | orientation = st.radio(
550 | "Layout:",
551 | ["top-bottom", "left-right"],
552 | index=0,
553 | key="flowchart_orientation",
554 | horizontal=True,
555 | )
556 | if pipeline_star:
557 | dot_string = create_network(pipeline_star, orientation=orientation)
558 | if dot_string:
559 | st.graphviz_chart(dot_string, use_container_width=True)
560 | else:
561 | st.warning("Could not generate flowchart.")
562 | else:
563 | st.warning("Load a project first.")
564 |
565 | elif selected_process == INTERACTIVE_PLOT_PROCESS:
566 | st.title("Interactive STAR File Plot")
567 | col1, col2 = st.columns([1, 2])
568 | star_file_data = None
569 | plot_file_name = "Data"
570 | viewer_prefix = "iplot_"
571 | with col1.expander("Select Data", expanded=True):
572 | star_path = st.text_input(
573 | "STAR file path:", key=f"{viewer_prefix}path"
574 | )
575 | up_file = st.file_uploader(
576 | "Or upload STAR:", type=["star"], key=f"{viewer_prefix}up"
577 | )
578 | if up_file:
579 | tmp_path = os.path.join(TEMP_DIR_PATH, up_file.name)
580 | try:
581 | with open(tmp_path, "wb") as f:
582 | f.write(up_file.getbuffer())
583 | star_file_data = parse_star(tmp_path)
584 | plot_file_name = up_file.name
585 | os.remove(tmp_path)
586 | except Exception as e:
587 | st.error(f"Error processing upload: {e}")
588 | elif star_path and os.path.exists(star_path):
589 | star_file_data = parse_star(star_path)
590 | plot_file_name = os.path.basename(star_path)
591 | elif star_path:
592 | st.warning("Path does not exist.")
593 |
594 | if star_file_data:
595 | blocks = list(star_file_data.keys())
596 | if blocks:
597 | sel_block = col1.selectbox(
598 | "Select Block:",
599 | blocks,
600 | index=len(blocks) - 1,
601 | key=f"{viewer_prefix}block",
602 | )
603 | if sel_block:
604 | with col2:
605 | interactive_scatter_plot(
606 | data_source=star_file_data, # Pass data dict
607 | block_selector_options=blocks,
608 | default_block=sel_block,
609 | title_prefix=plot_file_name,
610 | )
611 | else:
612 | col1.warning("No data blocks found.")
613 | else:
614 | col1.info("Upload or provide path to a STAR file.")
615 |
616 | elif selected_process: # Regular RELION job type
617 | selected_job = st.session_state.get(STATE_CURRENT_JOB)
618 | if selected_job:
619 | display_job_info(
620 | selected_job, current_folder, df, pipeline_star or {}
621 | )
622 | else:
623 | if selected_process in process_types:
624 | st.info(
625 | f"Select a job for '{format_display_name(selected_process)}' from the sidebar."
626 | )
627 |
628 | else: # No process selected
629 | st.info("Select a process type or view from the sidebar.")
630 |
631 | st.sidebar.button(
632 | "Refresh",
633 | key="refresh_button",
634 | help="Refresh the page to clear any temporary data.",
635 | on_click=handle_folder_change, # Reset state on refresh
636 | )
637 | # --- Footer ---
638 | st.sidebar.markdown("---")
639 | st.sidebar.markdown(footer, unsafe_allow_html=True)
640 |
641 | except Exception as exc:
642 | report_error(exc, "An unexpected error occurred in the main application.")
643 | st.error("An unexpected error occurred. Check logs for details.")
644 |
645 |
646 | if __name__ == "__main__":  # Script entry point (presumably reached via `streamlit run`, which executes this file as __main__)
647 | set_style()  # Called first, before main() renders anything; set_style is defined elsewhere (not visible here)
648 | # Register local_report_error_handler as the fallback error handler for logging before main() runs (presumably used by report_error() in main's except clause — verify)
649 | set_error_handler(local_report_error_handler)
650 | main()  # Build the app; main() wraps its body in try/except Exception and reports unexpected errors itself
651 |
--------------------------------------------------------------------------------