├── src ├── ui │ ├── __init__.py │ ├── sidebar.py │ └── widgets.py ├── core │ ├── __init__.py │ └── core.py ├── lib │ ├── __init__.py │ ├── serializer.py │ ├── test_utils.py │ └── utils.py ├── pages │ ├── __init__.py │ ├── app_page.py │ ├── loader_page.py │ ├── plotter_page.py │ └── fit_page.py ├── scripts │ └── generate_data.py └── app.py ├── mypy.ini ├── README.md ├── .gitignore ├── pyrightconfig.json └── requirements.txt /src/ui/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/pages/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/ui/sidebar.py: -------------------------------------------------------------------------------- 1 | from typing import List, cast 2 | 3 | import streamlit as st 4 | 5 | 6 | def radio_button(title: str, options: List[str]) -> str: 7 | return cast(str, st.sidebar.radio(title, options)) 8 | -------------------------------------------------------------------------------- /src/lib/serializer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import jsonpickle 4 | import jsonpickle.ext.numpy 5 | 6 | jsonpickle.ext.numpy.register_handlers() 7 | 8 | 9 | class CustomSerializer: 10 | @staticmethod 11 | def serialize(obj: Any) -> None: 12 | jsonpickle.encode(obj) 13 | 14 | @staticmethod 15 | def deserialize(obj: Any) -> Any: 16 | return jsonpickle.decode(obj) 17 | -------------------------------------------------------------------------------- /src/pages/app_page.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from core.core import SessionState 4 | 5 | 6 | class AppPage(ABC): 7 | """ 8 | An application page. Each application page class must be a callable function 9 | to the nature of Streamlit 10 | """ 11 | 12 | @abstractmethod 13 | def __call__(self, sess: SessionState) -> None: 14 | raise NotImplementedError 15 | -------------------------------------------------------------------------------- /src/ui/widgets.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import streamlit as st 4 | from streamlit.uploaded_file_manager import UploadedFile 5 | 6 | 7 | def file_uploader(label: str) -> List[UploadedFile]: 8 | files: List[UploadedFile] = st.file_uploader( 9 | label=label, 10 | accept_multiple_files=True, 11 | ) 12 | return files 13 | 14 | 15 | def select_box(label: str, options: List[str]) -> str: 16 | selected: str = st.selectbox(label, options) 17 | return selected 18 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | # MyPy config file 2 | # File reference here: 3 | # http://mypy.readthedocs.io/en/latest/config_file.html#config-file 4 | 5 | [mypy] 6 | warn_redundant_casts = True 7 | warn_unused_ignores = True 8 | 9 | # Needed because of bug in MyPy 10 | disallow_subclassing_any = False 11 | scripts_are_modules = False 12 | mypy_path = stubs 13 | disallow_untyped_calls = True 14 | disallow_untyped_defs = True 15 | check_untyped_defs = True 16 | warn_return_any = True 17 | no_implicit_any = True 18 | no_implicit_optional = True 19 | strict_optional = True 20 | 21 | ignore_missing_imports = True 22 | 23 | [mypy-session.*] 24 | ignore_errors = True -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## What is this? 2 | An example of how one can track state between multiple pages in *Streamlit* without reloading pages. 3 | 4 | Inspired by https://gist.github.com/okld/0aba4869ba6fdc8d49132e6974e2e662 5 | 6 | In addition to this, it also demonstrates Mypy type-safety with classes, decorators, inheritance, wrapping Streamlit, etc. 7 | 8 | Instead of using `self` you'll find that all data lives in the session `sess`. 9 | 10 | Tested on Python 3.7 and Streamlit 0.73. 11 | 12 | ## How to run this 13 | 1. Create a venv and populate it with `pip install -r requirements.txt` 14 | 2. Generate some sample data for the app `python scripts/generate_data.py` 15 | 3. Start the app with `streamlit run src/app.py` 16 | 4. Run tests (there's one, I haven't gotten to that yet) with `pytest` -------------------------------------------------------------------------------- /src/lib/test_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import cast 4 | from unittest.mock import Mock 5 | 6 | from mock import patch 7 | 8 | from core.core import SessionState 9 | from lib.utils import sync_state_after 10 | from pages.app_page import AppPage 11 | 12 | 13 | class MockPage(AppPage): 14 | @sync_state_after 15 | def __call__(self, sess: SessionState) -> None: 16 | pass 17 | 18 | 19 | def test_autosync_decorator_calls_session_state() -> None: 20 | page: AppPage = MockPage() 21 | session_state: SessionState = cast(Mock, SessionState) 22 | 23 | with patch.object( 24 | session_state, SessionState.sync.__name__, return_value=None 25 | ) as mock: 26 | page(session_state) 27 | assert mock.called 28 | -------------------------------------------------------------------------------- /src/scripts/generate_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | 9 | path = Path(f"~/Desktop/test").expanduser() 10 | path.mkdir(exist_ok=True, parents=True) 11 | 12 | for i in range(10): 13 | slope = int(np.random.randint(50, 150, 1)) 14 | noise_max = int(np.random.randint(1, 5, 1)) 15 | 16 | noise = np.random.normal(0, noise_max, 100) 17 | 18 | ypts = np.random.normal(0, 15, 100) + np.linspace(0, slope, 100) 19 | xpts = np.linspace(0, 1, 100) 20 | 21 | plt.plot(xpts, ypts, alpha=0.5, color="black") 22 | 23 | filepath = path / f"test_file_{i}.csv" 24 | pd.DataFrame({"x": xpts, "y": ypts}).to_csv(filepath) 25 | 26 | plt.title(f"Generated files in {path}") 27 | plt.show() 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | venv/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *,cover 45 | .pytest_cache 46 | 47 | # Translations 48 | *.mo 49 | *.pot 50 | 51 | # Django stuff: 52 | *.log 53 | 54 | # Sphinx documentation 55 | docs/_build/ 56 | 57 | # PyBuilder 58 | target/ 59 | 60 | # Auto-generated type hints 61 | typings 62 | 63 | # Local file for messing around 64 | playground.py 65 | 66 | # Pycharm 67 | .idea/ -------------------------------------------------------------------------------- /src/lib/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import wraps 3 | from typing import Any, Callable, TypeVar 4 | 5 | from core.core import SessionState 6 | from pages.app_page import AppPage 7 | 8 | 9 | def timeit(f: Callable) -> Callable: 10 | """Decorator to time functions and methods for optimization""" 11 | 12 | def _print_elapsed(start: float, end: float, name: str = "") -> None: 13 | print("'{}' {:.2f} ms".format(name, (end - start) * 1e3)) 14 | 15 | @wraps(f) 16 | def _timed(*args: Any, **kwargs: Any) -> Any: 17 | ts = time.time() 18 | result = f(*args, **kwargs) 19 | te = time.time() 20 | _print_elapsed(name=f.__name__, start=ts, end=te) 21 | return result 22 | 23 | return _timed 24 | 25 | 26 | # equivalent to C# TPage where TPage : AppPage 27 | TPage = TypeVar("TPage", bound=AppPage) 28 | 29 | 30 | def sync_state_after( 31 | f: Callable[[TPage, SessionState], None] 32 | ) -> Callable[[TPage, SessionState], None]: 33 | """ 34 | Synchronizes state after a page's __call__ method has been invoked 35 | """ 36 | 37 | def decorated(self: TPage, sess: SessionState) -> None: 38 | f(self, sess) 39 | sess.sync() 40 | 41 | return decorated 42 | -------------------------------------------------------------------------------- /src/app.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import streamlit as st 4 | 5 | from core.core import get_session_state, SessionState 6 | from lib.serializer import CustomSerializer 7 | from lib.utils import sync_state_after 8 | from pages.app_page import AppPage 9 | from pages.fit_page import FitterPage 10 | from pages.loader_page import LoaderPage 11 | from pages.plotter_page import PlotterPage 12 | from ui import sidebar 13 | 14 | 15 | class Main(AppPage): 16 | @sync_state_after 17 | def __call__(self, sess: SessionState) -> None: 18 | pages: Dict[str, AppPage] = { 19 | "Load files": LoaderPage(serializer=CustomSerializer()), 20 | "Fit": FitterPage(), 21 | "Plot": PlotterPage(), 22 | } 23 | 24 | st.sidebar.title(":floppy_disk: Page states") 25 | selected_page = sidebar.radio_button("Select your page", list(pages.keys())) 26 | 27 | self.display_page_with_session(sess, pages, selected_page) 28 | 29 | @staticmethod 30 | def display_page_with_session( 31 | sess: SessionState, pages: Dict[str, AppPage], selected_page: str 32 | ) -> None: 33 | pages[selected_page](sess) 34 | 35 | 36 | if __name__ == "__main__": 37 | Main()(get_session_state()) 38 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "venvPath": "venv", 3 | "useLibraryCodeForTypes": true, 4 | "typeCheckingMode": "basic", 5 | "enableTypeIgnoreComments": true, 6 | "reportMissingTypeArgument": "error", 7 | "reportUndefinedVariable": "error", 8 | "strictListInference": true, 9 | "strictDictionaryInference": true, 10 | "strictParameterNoneValue": true, 11 | "reportImplicitStringConcatenation": "error", 12 | "reportUnusedVariable": "error", 13 | "reportOptionalCall": "error", 14 | "reportPrivateUsage": "error", 15 | "reportMissingTypeStubs": "error", 16 | "reportOptionalOperand": "error", 17 | "reportIncompatibleVariableOverride": "error", 18 | "reportUnknownVariableType": true, 19 | "reportAssertAlwaysTrue": true, 20 | "reportUntypedBaseClass": true, 21 | "reportUntypedClassDecorator": true, 22 | "reportUntypedNamedTuple": true, 23 | "reportGeneralTypeIssues": "error", 24 | "reportConstantRedefinition": "error", 25 | "reportWildcardImportFromLibrary": "error", 26 | "reportIncompatibleMethodOverride": "error", 27 | "reportUnnecessaryCast": "warning", 28 | "reportUnnecessaryIsInstance": "warning", 29 | "reportInvalidTypeVarUse": "error", 30 | "reportUnknownMemberType": "none" 31 | } -------------------------------------------------------------------------------- /src/pages/loader_page.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pandas as pd 4 | import streamlit as st 5 | from streamlit.uploaded_file_manager import UploadedFile 6 | 7 | from core.core import SessionState, Trace 8 | from lib.serializer import CustomSerializer 9 | from lib.utils import sync_state_after 10 | from pages.app_page import AppPage 11 | from ui import widgets 12 | 13 | 14 | class LoaderPage(AppPage): 15 | def __init__(self, serializer: CustomSerializer): 16 | # Not actually using this for anything. Just demonstrating DI... 17 | self.serializer = serializer 18 | 19 | @sync_state_after 20 | def __call__(self, sess: SessionState) -> None: 21 | st.title(":floppy_disk: Load data here") 22 | 23 | uploaded = widgets.file_uploader("Load file") 24 | st.write(uploaded) 25 | 26 | sess.data.all_filenames = [file.name for file in uploaded] 27 | sess.data.selected_filenames = st.multiselect( 28 | "Select trace(s) to plot", 29 | sess.data.all_filenames, 30 | sess.data.selected_filenames, 31 | ) 32 | 33 | traces = [self.parse(file) for file in uploaded] 34 | 35 | sess.data.set_traces(sess.data.all_filenames, traces) 36 | 37 | @staticmethod 38 | def parse(file: UploadedFile) -> Trace: 39 | return Trace(name=file.name, data=pd.read_csv(file)) 40 | -------------------------------------------------------------------------------- /src/pages/plotter_page.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import streamlit as st 6 | from matplotlib.pyplot import Figure, Axes 7 | 8 | from core.core import SessionState 9 | from lib.utils import sync_state_after, timeit 10 | from pages.app_page import AppPage 11 | 12 | 13 | class PlotterPage(AppPage): 14 | def __init__(self) -> None: 15 | pass 16 | 17 | @sync_state_after 18 | def __call__(self, sess: SessionState) -> None: 19 | st.title(":icecream: Plotting page") 20 | st.subheader(":point_left: Select traces to plot from *Load files*") 21 | traces = sess.data.get_selected_traces() 22 | 23 | if traces: 24 | st.header("Traces") 25 | selected_traces_array = sess.data.traces_to_array(traces, use_y_fit=False) 26 | if selected_traces_array is not None: 27 | self.plot(selected_traces_array, color="black") 28 | 29 | st.header("Fits") 30 | fitted_traces_array = sess.data.traces_to_array(traces, use_y_fit=True) 31 | if fitted_traces_array is not None: 32 | self.plot(fitted_traces_array, color="firebrick") 33 | else: 34 | st.subheader("There are no fits!") 35 | 36 | @staticmethod 37 | def plot(traces: np.ndarray[np.float64], color: str) -> None: 38 | ypts = traces 39 | xpts = np.linspace(0, ypts.shape[1], ypts.shape[1]) 40 | 41 | fig: Figure 42 | ax: Axes 43 | 44 | fig, ax = plt.subplots() 45 | for n in range(ypts.shape[0]): 46 | ax.plot(xpts, ypts[n, ...], alpha=0.5, color=color) # type: ignore 47 | 48 | st.write(fig) 49 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | altair==4.1.0 2 | appdirs==1.4.4 3 | appnope==0.1.2 4 | argon2-cffi==20.1.0 5 | astor==0.8.1 6 | async-generator==1.10 7 | attrs==20.3.0 8 | backcall==0.2.0 9 | base58==2.0.1 10 | black==20.8b1 11 | bleach==3.2.1 12 | blinker==1.4 13 | cachetools==4.2.0 14 | certifi==2020.12.5 15 | cffi==1.14.4 16 | chardet==4.0.0 17 | click==7.1.2 18 | cycler==0.10.0 19 | data-science-types==0.2.22 20 | decorator==4.4.2 21 | defusedxml==0.6.0 22 | entrypoints==0.3 23 | gitdb==4.0.5 24 | GitPython==3.1.11 25 | idna==2.10 26 | importlib-metadata==3.3.0 27 | iniconfig==1.1.1 28 | ipykernel==5.4.2 29 | ipython==7.19.0 30 | ipython-genutils==0.2.0 31 | ipywidgets==7.6.2 32 | jedi==0.18.0 33 | Jinja2==2.11.2 34 | jsonpickle==1.4.2 35 | jsonschema==3.2.0 36 | jupyter-client==6.1.7 37 | jupyter-core==4.7.0 38 | jupyterlab-pygments==0.1.2 39 | jupyterlab-widgets==1.0.0 40 | kiwisolver==1.3.1 41 | MarkupSafe==1.1.1 42 | matplotlib==3.3.3 43 | mistune==0.8.4 44 | mock==4.0.3 45 | mypy==0.790 46 | mypy-extensions==0.4.3 47 | nbclient==0.5.1 48 | nbconvert==6.0.7 49 | nbformat==5.0.8 50 | nest-asyncio==1.4.3 51 | notebook==6.1.6 52 | numpy==1.19.4 53 | packaging==20.8 54 | pandas==1.2.0 55 | pandocfilters==1.4.3 56 | parso==0.8.1 57 | pathspec==0.8.1 58 | pexpect==4.8.0 59 | pickleshare==0.7.5 60 | Pillow==8.1.0 61 | pluggy==0.13.1 62 | prometheus-client==0.9.0 63 | prompt-toolkit==3.0.8 64 | protobuf==3.14.0 65 | ptyprocess==0.7.0 66 | py==1.10.0 67 | pyarrow==2.0.0 68 | pycparser==2.20 69 | pydeck==0.5.0 70 | Pygments==2.7.3 71 | pyparsing==2.4.7 72 | pyrsistent==0.17.3 73 | pytest==6.2.1 74 | python-dateutil==2.8.1 75 | pytz==2020.5 76 | pyzmq==20.0.0 77 | regex==2020.11.13 78 | requests==2.25.1 79 | scipy==1.6.0 80 | Send2Trash==1.5.0 81 | six==1.15.0 82 | smmap==3.0.4 83 | streamlit==0.73.1 84 | terminado==0.9.1 85 | testpath==0.4.4 86 | toml==0.10.2 87 | toolz==0.11.1 88 | tornado==6.1 89 | traitlets==5.0.5 90 | typed-ast==1.4.2 91 | typing-extensions==3.7.4.3 92 | tzlocal==2.1 93 | urllib3==1.26.2 94 | validators==0.18.2 95 | watchdog==1.0.2 96 | wcwidth==0.2.5 97 | webencodings==0.5.1 98 | widgetsnbextension==3.5.1 99 | zipp==3.4.0 100 | -------------------------------------------------------------------------------- /src/pages/fit_page.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Callable 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import streamlit as st 8 | from scipy.optimize import curve_fit 9 | 10 | from core.core import SessionState, Trace 11 | from lib.utils import sync_state_after 12 | from pages.app_page import AppPage 13 | from ui import widgets 14 | 15 | 16 | class FitterPage(AppPage): 17 | def __init__(self) -> None: 18 | pass 19 | 20 | @sync_state_after 21 | def __call__(self, sess: SessionState) -> None: 22 | st.title(":rocket: Single trace fitting page") 23 | 24 | selection = widgets.select_box("Select trace to fit", sess.data.all_filenames) 25 | if selection is None: 26 | return 27 | 28 | selected_trace = sess.data.get_trace(selection) 29 | 30 | fit_single = st.button("Fit current trace") 31 | fit_all = st.button("Fit all traces") 32 | clear_single = st.button("Clear current fit") 33 | clear_all = st.button("Clear all fits") 34 | 35 | if fit_single: 36 | self._fit_single(selection, sess) 37 | 38 | if clear_single: 39 | sess.data.clear_single_fit(selection) 40 | 41 | if fit_all: 42 | self._fit_all(sess) 43 | 44 | if clear_all: 45 | sess.data.clear_all_fits() 46 | 47 | self.plot(selected_trace) 48 | 49 | def _fit_single(self, selection: str, sess: SessionState) -> None: 50 | fitted_trace = self.fit(self.line, sess.data.get_trace(selection)) 51 | sess.data.set_trace(fitted_trace) 52 | 53 | def _fit_all(self, sess: SessionState) -> None: 54 | traces = sess.data.get_all_traces() 55 | for trace in traces: 56 | fitted_trace = self.fit(self.line, trace) 57 | sess.data.set_trace(fitted_trace) 58 | 59 | @staticmethod 60 | def line( 61 | x: np.ndarray[np.float64], slope: float, intercept: float 62 | ) -> np.ndarray[np.float64]: 63 | return slope * x + intercept 64 | 65 | @staticmethod 66 | def fit(f: Callable, trace: Trace) -> Trace: 67 | popt, pcov = curve_fit(f, xdata=trace.x, ydata=trace.y) 68 | 69 | trace.y_fit = f(trace.x, *popt) 70 | 71 | return trace 72 | 73 | @staticmethod 74 | def plot(trace: Trace) -> None: 75 | x = trace.x.values 76 | y = trace.y.values 77 | 78 | fig, ax = plt.subplots() 79 | 80 | ax.plot(x, y, color="black") 81 | if trace.y_fit is not None: 82 | y_fit: np.ndarray[np.float64] = trace.y_fit.values 83 | ax.plot(x, y_fit, color="firebrick", linestyle="-") 84 | 85 | st.write(fig) 86 | -------------------------------------------------------------------------------- /src/core/core.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC 4 | from typing import Optional, Any, Dict, Union, List 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from streamlit.hashing import _CodeHasher 9 | from streamlit.report_session import ReportSession 10 | from streamlit.report_thread import get_report_ctx 11 | from streamlit.server.server import Server 12 | 13 | 14 | class BaseSessionState(ABC): 15 | """ 16 | Adapted from https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92 17 | 18 | This object is very dynamic in order to set and retrieve any value, at any time, 19 | so it inherits from AvailableSessionItems to provide reliable types and autocomplete 20 | """ 21 | 22 | def __init__( 23 | self, 24 | session: Union[SessionState, ReportSession], 25 | ) -> None: 26 | 27 | self.__dict__["_state"] = { 28 | "data": {}, 29 | "hash": None, 30 | "hasher": _CodeHasher(), 31 | "is_rerun": False, 32 | "session": session, 33 | } 34 | 35 | def __call__(self, **kwargs: Dict[str, Any]) -> None: 36 | """Initialize session data once.""" 37 | for item, value in kwargs.items(): 38 | if item not in self._state["data"]: 39 | self._state["data"][item] = value 40 | 41 | def __getitem__(self, item: Any) -> Any: 42 | """Return a saved session value, None if item is undefined.""" 43 | return self._state["data"].get(item, None) 44 | 45 | def __getattr__(self, item: Any) -> Any: 46 | """Return a saved session value, None if item is undefined.""" 47 | return self._state["data"].get(item, None) 48 | 49 | def __setitem__(self, item: str, value: Any) -> None: 50 | """Set session value.""" 51 | self._state["data"][item] = value 52 | 53 | def __setattr__(self, item: str, value: Any) -> None: 54 | """Set session value.""" 55 | self._state["data"][item] = value 56 | 57 | def clear(self) -> None: 58 | """Clear session session and request a rerun.""" 59 | self._state["data"].clear() 60 | self._state["session"].request_rerun() 61 | 62 | def sync(self) -> None: 63 | """ 64 | Rerun the app with all session values up to date from the beginning to fix 65 | rollbacks. 66 | """ 67 | 68 | # Ensure to rerun only once to avoid infinite loops 69 | # caused by a constantly changing session value at each run. 70 | # 71 | # Example: session.value += 1 72 | if self._state["is_rerun"]: 73 | self._state["is_rerun"] = False 74 | 75 | elif self._state["hash"] is not None: 76 | if self._state["hash"] != self._state["hasher"].to_bytes( 77 | self._state["data"], None 78 | ): 79 | self._state["is_rerun"] = True 80 | self._state["session"].request_rerun() 81 | 82 | self._state["hash"] = self._state["hasher"].to_bytes(self._state["data"], None) 83 | 84 | 85 | class LoadedData: 86 | """Typesafe container for loaded data""" 87 | 88 | def __init__(self) -> None: 89 | self._traces: Dict[str, Trace] = {} 90 | self.all_filenames: List[str] = [] 91 | self.selected_filenames: List[str] = [] 92 | 93 | def set_traces(self, names: List[str], traces: List[Trace]) -> None: 94 | for name, trace in zip(names, traces): 95 | if name in self._traces: 96 | continue 97 | self._traces[name] = trace 98 | 99 | self.all_filenames = names 100 | 101 | def get_trace(self, name: str) -> Trace: 102 | return self._traces[name] 103 | 104 | def set_trace(self, trace: Trace) -> None: 105 | self._traces[trace.name] = trace 106 | 107 | def get_all_traces(self) -> List[Trace]: 108 | return list(self._traces.values()) 109 | 110 | def get_selected_traces(self) -> List[Trace]: 111 | return [self._traces[name] for name in self.selected_filenames] 112 | 113 | def clear_single_fit(self, name: str) -> None: 114 | self._traces[name].clear_fit() 115 | 116 | def clear_all_fits(self) -> None: 117 | for name in self.all_filenames: 118 | self._traces[name].clear_fit() 119 | 120 | @staticmethod 121 | def traces_to_array( 122 | traces: List[Trace], use_y_fit: bool 123 | ) -> Optional[np.ndarray[np.float64]]: 124 | 125 | if not traces: 126 | return None 127 | 128 | if use_y_fit: 129 | y_data = [trace.y_fit.values for trace in traces if trace.y_fit is not None] 130 | if not y_data: 131 | return None 132 | else: 133 | y_data = [trace.y.values for trace in traces] 134 | return np.stack(y_data) 135 | 136 | 137 | class SessionState(BaseSessionState): 138 | """A catalogue to provide type-safe access to certain attributes""" 139 | 140 | def __init__( 141 | self, 142 | session: Union[SessionState, ReportSession], 143 | data: LoadedData, 144 | ): 145 | super().__init__(session) 146 | self.data = data 147 | 148 | 149 | def _get_report_session() -> ReportSession: 150 | try: 151 | session_id: str = get_report_ctx().session_id 152 | except AttributeError: 153 | raise RuntimeError("Couldn't start Streamlit application.") 154 | session_info = Server.get_current()._get_session_info(session_id) 155 | 156 | if session_info is None: 157 | raise RuntimeError("Couldn't get your Streamlit Session object.") 158 | 159 | return session_info.session 160 | 161 | 162 | def get_session_state() -> SessionState: 163 | """Gets the session state that is fed through the main call entrypoint""" 164 | 165 | streamlit_session = _get_report_session() 166 | loaded_data_cache = LoadedData() 167 | 168 | if not hasattr(streamlit_session, "_custom_session_state"): 169 | streamlit_session._custom_session_state = SessionState( 170 | streamlit_session, 171 | loaded_data_cache, 172 | ) 173 | 174 | session_state: SessionState = streamlit_session._custom_session_state 175 | 176 | return session_state 177 | 178 | 179 | class Trace: 180 | """A single trace entity""" 181 | 182 | def __init__(self, name: str, data: pd.DataFrame): 183 | self.name = name 184 | self.data = self._validate_data(data) 185 | self.x: pd.Series[float] = self.data["x"] 186 | self.y: pd.Series[float] = self.data["y"] 187 | self.y_fit: Optional[np.ndarray[np.float64]] = None 188 | 189 | @staticmethod 190 | def _validate_data(data: pd.DataFrame) -> pd.DataFrame: 191 | for c in "x", "y": 192 | if c not in data.columns: 193 | raise ValueError 194 | return data 195 | 196 | def clear_fit(self) -> None: 197 | self.y_fit = None 198 | --------------------------------------------------------------------------------