├── .github
└── workflows
│ ├── publish.yaml
│ └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .python-version
├── .readthedocs.yaml
├── Dockerfile
├── LICENSE
├── README.md
├── app
├── .gitignore
├── README.md
├── biome.json
├── bun.lockb
├── decs.d.ts
├── index.html
├── package.json
├── public
│ └── zndraw.png
├── src
│ ├── App.css
│ ├── App.tsx
│ ├── assets
│ │ └── react.svg
│ ├── components
│ │ ├── api.tsx
│ │ ├── cameraAndControls.tsx
│ │ ├── data.tsx
│ │ ├── floor.tsx
│ │ ├── geometries.tsx
│ │ ├── headbar.tsx
│ │ ├── lines.tsx
│ │ ├── meshes.tsx
│ │ ├── overlays.tsx
│ │ ├── particles.tsx
│ │ ├── particlesEditor.tsx
│ │ ├── plotting.tsx
│ │ ├── progressbar.tsx
│ │ ├── sidebar.tsx
│ │ ├── tooltips.tsx
│ │ ├── transforms.tsx
│ │ ├── utils.tsx
│ │ ├── utils
│ │ │ └── mergeInstancedMesh.tsx
│ │ └── vectorfield.tsx
│ ├── index.css
│ ├── main.tsx
│ ├── socket.tsx
│ └── vite-env.d.ts
├── tsconfig.json
├── tsconfig.node.json
└── vite.config.ts
├── docker-compose.yaml
├── docs
├── Makefile
├── make.bat
└── source
│ ├── 2025
│ └── 01.rst
│ ├── _static
│ ├── zndraw-dark.svg
│ └── zndraw-light.svg
│ ├── adventofcode.rst
│ ├── conf.py
│ ├── index.rst
│ └── python-api.rst
├── examples
├── md.ipynb
├── molecules.ipynb
└── stress_testing
│ ├── README.md
│ ├── multi_connection.py
│ └── single_connection.py
├── misc
├── darkmode
│ ├── analysis.png
│ ├── box.png
│ ├── overview.png
│ └── python.png
└── lightmode
│ ├── analysis.png
│ ├── box.png
│ ├── overview.png
│ └── python.png
├── pyproject.toml
├── tests
├── conftest.py
├── test_analysis.py
├── test_bookmarks.py
├── test_camera.py
├── test_config.py
├── test_figures.py
├── test_geometries.py
├── test_modifier.py
├── test_points.py
├── test_selection.py
├── test_serializer.py
├── test_step.py
├── test_tasks.py
├── test_utils.py
├── test_vectorfields.py
├── test_vis.py
└── test_zndraw.py
├── uv.lock
├── zndraw
├── .gitignore
├── __init__.py
├── abc.py
├── analyse
│ └── __init__.py
├── app.py
├── base.py
├── bonds
│ └── __init__.py
├── config.py
├── converter.py
├── draw
│ └── __init__.py
├── exceptions.py
├── figure.py
├── modify
│ ├── __init__.py
│ └── private.py
├── queue.py
├── selection
│ └── __init__.py
├── server
│ ├── __init__.py
│ ├── events.py
│ └── routes.py
├── standalone.py
├── tasks
│ └── __init__.py
├── type_defs.py
├── upload.py
├── utils.py
└── zndraw.py
└── zndraw_app
├── README.md
├── __init__.py
├── cli.py
├── healthcheck.py
└── make_celery.py
/.github/workflows/publish.yaml:
--------------------------------------------------------------------------------
1 | # This workflow uses actions that are not certified by GitHub.
2 | # They are provided by a third-party and are governed by
3 | # separate terms of service, privacy policy, and support
4 | # documentation.
5 |
6 | # GitHub recommends pinning actions to a commit SHA.
7 | # To get a newer version, you will need to update the SHA.
8 | # You can also reference a tag or branch, but the action may change without warning.
9 |
10 | name: Publish Docker image
11 |
12 | on:
13 | release:
14 | types: [published]
15 |
16 | jobs:
17 | push_to_registry:
18 | name: Push Docker image to Docker Hub
19 | runs-on: ubuntu-latest
20 | steps:
21 | - name: Check out the repo
22 | uses: actions/checkout@v4
23 |
24 | - name: Log in to Docker Hub
25 | uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a
26 | with:
27 | username: ${{ secrets.DOCKER_USERNAME }}
28 | password: ${{ secrets.DOCKER_ACCESS_TOKEN }}
29 |
30 | - name: Extract metadata (tags, labels) for Docker
31 | id: meta
32 | uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
33 | with:
34 | images: pythonf/zndraw
35 |
36 | - name: Build and push Docker image
37 | uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671
38 | with:
39 | context: .
40 | file: ./Dockerfile
41 | push: true
42 | tags: ${{ steps.meta.outputs.tags }}
43 | labels: ${{ steps.meta.outputs.labels }}
44 |
45 | publish-pypi:
46 | runs-on: ubuntu-latest
47 | steps:
48 | - uses: actions/checkout@v4
49 | - name: Install uv
50 | uses: astral-sh/setup-uv@v5
51 | - name: build frontend
52 | run: |
53 | npm install -g bun
54 | cd app && bun install && bun vite build && cd ..
55 | - name: Publish
56 | env:
57 | PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
58 | run: |
59 | uv build
60 | uv publish --token $PYPI_TOKEN
61 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: tests
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | schedule:
8 | - cron: "14 3 * * 1" # at 03:14 on Monday.
9 |
10 | jobs:
11 | pytest:
12 | runs-on: ${{ matrix.os }}
13 | strategy:
14 | fail-fast: false
15 | matrix:
16 | python-version:
17 | - "3.13"
18 | - "3.12"
19 | - "3.11"
20 | - "3.10"
21 | os:
22 | - ubuntu-latest
23 |
24 | services:
25 | # Label used to access the service container
26 | redis:
27 | # Docker Hub image
28 | image: redis
29 | # Set health checks to wait until redis has started
30 | options: >-
31 | --health-cmd "redis-cli ping"
32 | --health-interval 10s
33 | --health-timeout 5s
34 | --health-retries 5
35 | ports:
36 | # Maps port 6379 on service container to the host
37 | - 6379:6379
38 |
39 | steps:
40 | - uses: actions/checkout@v4
41 | - name: Install uv and set the python version
42 | uses: astral-sh/setup-uv@v5
43 | with:
44 | python-version: ${{ matrix.python-version }}
45 | - name: Install package
46 | run: |
47 | uv sync --all-extras --dev
48 | - name: Pytest
49 | run: |
50 | uv run python --version
51 | uv run pytest --cov --junitxml=junit.xml -o junit_family=legacy
52 | - name: Upload coverage to Codecov
53 | uses: codecov/codecov-action@v5
54 | with:
55 | token: ${{ secrets.CODECOV_TOKEN }}
56 | - name: Upload test results to Codecov
57 | if: ${{ !cancelled() }}
58 | uses: codecov/test-results-action@v1
59 | with:
60 | token: ${{ secrets.CODECOV_TOKEN }}
61 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | tmp/
163 | .vscode
164 | data/
165 | .zndraw/
166 | control/
167 | .DS_Store
168 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v5.0.0
6 | hooks:
7 | - id: check-added-large-files
8 | - id: check-case-conflict
9 | - id: check-docstring-first
10 | - id: check-executables-have-shebangs
11 | # - id: check-json
12 | - id: check-merge-conflict
13 | args: ["--assume-in-merge"]
14 | - id: check-toml
15 | - id: check-yaml
16 | - id: debug-statements
17 | - id: end-of-file-fixer
18 | - id: mixed-line-ending
19 | args: ["--fix=lf"]
20 | - id: sort-simple-yaml
21 | - id: trailing-whitespace
22 | - repo: https://github.com/codespell-project/codespell
23 | rev: v2.4.1
24 | hooks:
25 | - id: codespell
26 | additional_dependencies: ["tomli"]
27 | - repo: https://github.com/biomejs/pre-commit
28 | rev: v0.6.1 # Use the sha / tag you want to point at
29 | hooks:
30 | - id: biome-format # not using check because there are lots of things that need to be fixed
31 | additional_dependencies: ["@biomejs/biome@1.9.4"]
32 | - repo: https://github.com/astral-sh/ruff-pre-commit
33 | rev: v0.9.6
34 | hooks:
35 | - id: ruff
36 | args: [--fix]
37 | - id: ruff-format
38 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.11
2 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | version: 2
6 |
7 | submodules:
8 | include: all
9 |
10 | # Set the version of Python and other tools you might need
11 | build:
12 | os: ubuntu-22.04
13 | tools:
14 | python: "3.11"
15 | jobs:
16 | post_install:
17 | # see https://github.com/astral-sh/uv/issues/10074
18 | - pip install uv
19 | - UV_PROJECT_ENVIRONMENT=$READTHEDOCS_VIRTUALENV_PATH uv sync --all-extras --link-mode=copy --group=docs
20 |
21 | # Build documentation in the docs/ directory with Sphinx
22 | sphinx:
23 | configuration: docs/source/conf.py
24 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Builds the frontend bundle with bun, installs the zndraw package and
2 | # starts the `zndraw` CLI on port 5003 without opening a browser.
3 | FROM python:3.12
4 | # NOTE(review): the login shell presumably makes the bun install below visible on PATH - confirm.
5 | SHELL ["/bin/bash", "--login", "-c"]
6 |
7 | WORKDIR /usr/src/app
8 |
9 | # required for h5py
10 | # apt-get is the stable scripting interface (apt warns when used in scripts);
11 | # the package index is removed in the same layer so it does not bloat the image.
12 | RUN apt-get update && apt-get install -y gcc pkg-config libhdf5-dev build-essential && rm -rf /var/lib/apt/lists/*
13 | RUN curl -fsSL https://bun.sh/install | bash
14 |
15 | COPY ./ ./
16 | # build the frontend before installing the Python package
17 | RUN cd app && bun install && bun vite build && cd ..
18 | RUN pip install -e .
19 |
20 | ENTRYPOINT ["zndraw", "--port", "5003", "--no-browser"]
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | 
8 | 
9 |
10 | [](https://github.com/zincware)
11 | [](https://badge.fury.io/py/zndraw)
12 | [](https://arxiv.org/abs/2402.08708)
13 | [](https://codecov.io/gh/zincware/ZnDraw)
14 | [](https://discord.gg/7ncfwhsnm4)
15 | [](https://zndraw.readthedocs.io/en/latest/?badge=latest)
16 | 
17 |
18 |
19 |
20 | # ZnDraw - Display and Edit Molecules
21 | Welcome to ZnDraw, a powerful tool for visualizing and interacting with your trajectories.
22 |
23 | ## Installation
24 |
25 | You can install ZnDraw directly from PyPI via:
26 |
27 | ```bash
28 | pip install zndraw
29 | ```
30 |
31 | ## Quick Start
32 |
33 | Visualize your trajectories with a single command:
34 |
35 | ```bash
36 | zndraw
37 | ```
38 |
39 | > [!NOTE]
40 | > ZnDraw's webapp-based approach allows you to use port forwarding to work with trajectories on remote systems.
41 |
42 | 
43 | 
44 |
45 | ## Multi-User and Multi-Client Support
46 |
47 | ZnDraw supports multiple users and clients. Connect one or more Python clients to your ZnDraw instance:
48 |
49 | 1. Click on `Python access` in the ZnDraw UI.
50 | 2. Connect using the following code:
51 |
52 | ```python
53 | from zndraw import ZnDraw
54 |
55 | vis = ZnDraw(url="http://localhost:1234", token="")
56 | ```
57 |
58 | 
59 | 
60 |
61 | The `vis` object provides direct access to your visualized scene. It inherits from `collections.abc.MutableSequence`, so it can be used like a Python list, and any changes you make are reflected for all connected clients.
62 |
63 | ```python
64 | from ase.collections import s22
65 | vis.extend(list(s22))
66 | ```
67 |
68 | ## Additional Features
69 |
70 | You can modify various aspects of the visualization:
71 |
72 | - `vis.camera`
73 | - `vis.points`
74 | - `vis.selection`
75 | - `vis.step`
76 | - `vis.figures`
77 | - `vis.bookmarks`
78 | - `vis.geometries`
79 |
80 | For example, to add a geometry:
81 |
82 | ```python
83 | from zndraw import Box
84 |
85 | vis.geometries = [Box(position=[0, 1, 2])]
86 | ```
87 |
88 | 
89 | 
90 |
91 | ## Analyzing Data
92 |
93 | ZnDraw enables you to analyze your data and generate plots using [Plotly](https://plotly.com/). It automatically detects available properties and offers a convenient drop-down menu for selection.
94 |
95 | 
96 | 
97 |
98 | ZnDraw will look for the `step` and `atom` index in the [customdata](https://plotly.com/python/reference/scatter/#scatter-customdata)`[0]` and `[1]` respectively to highlight the steps and atoms.
99 |
100 | ## Writing Extensions
101 |
102 | Make your tools accessible via the ZnDraw UI by writing an extension:
103 |
104 | ```python
105 | from zndraw import Extension
106 |
107 | class AddMolecule(Extension):
108 | name: str
109 |
110 | def run(self, vis, **kwargs) -> None:
111 | structures = kwargs["structures"]
112 | vis.append(structures[self.name])
113 | vis.step = len(vis) - 1
114 |
115 | vis.register(AddMolecule, run_kwargs={"structures": s22}, public=True)
116 | vis.socket.wait() # This can be ignored when using Jupyter
117 | ```
118 |
119 | The `AddMolecule` extension will appear for all `tokens` and can be used by any client.
120 |
121 | # Hosted Version
122 |
123 | A hosted version of ZnDraw is available at https://zndraw.icp.uni-stuttgart.de . To upload data, use:
124 |
125 | ```bash
126 | zndraw --url https://zndraw.icp.uni-stuttgart.de
127 | ```
128 |
129 | ## Self-Hosting
130 |
131 | To host your own version of ZnDraw, use the following `docker-compose.yaml` setup:
132 |
133 | ```yaml
134 | version: "3.9"
135 |
136 | services:
137 | zndraw:
138 | image: pythonf/zndraw:latest
139 | command: --no-standalone /src/file.xyz
140 | volumes:
141 | - /path/to/files:/src
142 | restart: unless-stopped
143 | ports:
144 | - 5003:5003
145 | depends_on:
146 | - redis
147 | - worker
148 | environment:
149 | - FLASK_STORAGE=redis://redis:6379/0
150 | - FLASK_AUTH_TOKEN=super-secret-token
151 |
152 | worker:
153 | image: pythonf/zndraw:latest
154 | entrypoint: celery -A zndraw_app.make_celery worker --loglevel=info -P eventlet
155 | volumes:
156 | - /path/to/files:/src
157 | restart: unless-stopped
158 | depends_on:
159 | - redis
160 | environment:
161 | - FLASK_STORAGE=redis://redis:6379/0
162 | - FLASK_SERVER_URL="http://zndraw:5003"
163 | - FLASK_AUTH_TOKEN=super-secret-token
164 |
165 | redis:
166 | image: redis:latest
167 | restart: always
168 | environment:
169 | - REDIS_PORT=6379
170 | ```
171 |
172 | If you want to host ZnDraw under a subdirectory such as `domain.com/zndraw`, you need to adjust the environment variables and update `base: "/",` in `app/vite.config.ts` before building the app.
173 |
174 | # References
175 |
176 | If you use ZnDraw in your research and find it helpful please cite us.
177 |
178 | ```bibtex
179 | @misc{elijosiusZeroShotMolecular2024,
180 | title = {Zero {{Shot Molecular Generation}} via {{Similarity Kernels}}},
181 | author = {Elijo{\v s}ius, Rokas and Zills, Fabian and Batatia, Ilyes and Norwood, Sam Walton and Kov{\'a}cs, D{\'a}vid P{\'e}ter and Holm, Christian and Cs{\'a}nyi, G{\'a}bor},
182 | year = {2024},
183 | eprint = {2402.08708},
184 | archiveprefix = {arxiv},
185 | }
186 | ```
187 |
188 | # Acknowledgements
189 |
190 | The creation of ZnDraw was supported by the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) in the framework of the priority program SPP 2363, “Utilization and Development of Machine Learning for Molecular Applications - Molecular Machine Learning” Project No. 497249646. Further funding through the DFG under Germany's Excellence Strategy - EXC 2075 - 390740016 and the Stuttgart Center for Simulation Science (SimTech) was provided.
191 |
--------------------------------------------------------------------------------
/app/.gitignore:
--------------------------------------------------------------------------------
1 | # Logs
2 | logs
3 | *.log
4 | npm-debug.log*
5 | yarn-debug.log*
6 | yarn-error.log*
7 | pnpm-debug.log*
8 | lerna-debug.log*
9 |
10 | node_modules
11 | dist
12 | dist-ssr
13 | *.local
14 |
15 | # Editor directories and files
16 | .vscode/*
17 | !.vscode/extensions.json
18 | .idea
19 | .DS_Store
20 | *.suo
21 | *.ntvs*
22 | *.njsproj
23 | *.sln
24 | *.sw?
25 |
26 | TODO.md
27 |
--------------------------------------------------------------------------------
/app/README.md:
--------------------------------------------------------------------------------
1 | # React + TypeScript + Vite
2 |
3 | This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
4 |
5 | Currently, two official plugins are available:
6 |
7 | - [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh
8 | - [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
9 |
10 | ## Expanding the ESLint configuration
11 |
12 | If you are developing a production application, we recommend updating the configuration to enable type aware lint rules:
13 |
14 | - Configure the top-level `parserOptions` property like this:
15 |
16 | ```js
17 | export default {
18 | // other rules...
19 | parserOptions: {
20 | ecmaVersion: "latest",
21 | sourceType: "module",
22 | project: ["./tsconfig.json", "./tsconfig.node.json"],
23 | tsconfigRootDir: __dirname,
24 | },
25 | };
26 | ```
27 |
28 | - Replace `plugin:@typescript-eslint/recommended` with `plugin:@typescript-eslint/recommended-type-checked` or `plugin:@typescript-eslint/strict-type-checked`
29 | - Optionally add `plugin:@typescript-eslint/stylistic-type-checked`
30 | - Install [eslint-plugin-react](https://github.com/jsx-eslint/eslint-plugin-react) and add `plugin:react/recommended` & `plugin:react/jsx-runtime` to the `extends` list
31 |
--------------------------------------------------------------------------------
/app/biome.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://biomejs.dev/schemas/1.9.4/schema.json",
3 | "vcs": {
4 | "enabled": false,
5 | "clientKind": "git",
6 | "useIgnoreFile": false
7 | },
8 | "files": {
9 | "ignoreUnknown": false,
10 | "ignore": []
11 | },
12 | "formatter": {
13 | "enabled": true,
14 | "indentStyle": "tab"
15 | },
16 | "organizeImports": {
17 | "enabled": true
18 | },
19 | "linter": {
20 | "enabled": true,
21 | "rules": {
22 | "recommended": true
23 | }
24 | },
25 | "javascript": {
26 | "formatter": {
27 | "quoteStyle": "double"
28 | }
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/app/bun.lockb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zincware/ZnDraw/0a71d450976df5e1889b99d4764c0e793268b55e/app/bun.lockb
--------------------------------------------------------------------------------
/app/decs.d.ts:
--------------------------------------------------------------------------------
1 | declare module "@json-editor/json-editor" {
2 | const JSONEditor: any;
3 | export default JSONEditor;
4 | }
5 |
--------------------------------------------------------------------------------
/app/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | ZnDraw
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/app/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "zndraw",
3 | "private": true,
4 | "version": "0.5.8",
5 | "type": "module",
6 | "scripts": {
7 | "dev": "vite",
8 | "build": "tsc && vite build",
9 | "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0",
10 | "preview": "vite preview",
11 | "format": "prettier . --write"
12 | },
13 | "dependencies": {
14 | "@json-editor/json-editor": "^2.15.2",
15 | "@react-three/drei": "^9.116.3",
16 | "@react-three/fiber": "^8.17.10",
17 | "@react-three/gpu-pathtracer": "^0.2.0",
18 | "@types/lodash": "^4.17.13",
19 | "@types/three": "^0.165.0",
20 | "bootstrap": "5.3.3",
21 | "lodash": "^4.17.21",
22 | "plotly.js": "^2.35.2",
23 | "react": "^18.3.1",
24 | "react-bootstrap": "^2.10.5",
25 | "react-dom": "^18.3.1",
26 | "react-icons": "^5.3.0",
27 | "react-markdown": "^9.0.1",
28 | "react-plotly.js": "^2.6.0",
29 | "react-rnd": "^10.4.13",
30 | "react-select": "^5.8.3",
31 | "react-syntax-highlighter": "^15.6.1",
32 | "rehype-katex": "^7.0.1",
33 | "rehype-raw": "^7.0.0",
34 | "remark-breaks": "^4.0.0",
35 | "remark-gfm": "^4.0.0",
36 | "remark-math": "^6.0.0",
37 | "socket.io-client": "^4.8.1",
38 | "three": "^0.165.0",
39 | "znsocket": "^0.2.6"
40 | },
41 | "devDependencies": {
42 | "@biomejs/biome": "1.9.4",
43 | "@types/react": "^18.3.12",
44 | "@types/react-dom": "^18.3.1",
45 | "@vitejs/plugin-react-swc": "^3.7.1",
46 | "typescript": "^5.6.3",
47 | "vite": "^5.4.11"
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/app/public/zndraw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zincware/ZnDraw/0a71d450976df5e1889b99d4764c0e793268b55e/app/public/zndraw.png
--------------------------------------------------------------------------------
/app/src/App.css:
--------------------------------------------------------------------------------
1 | #root {
2 | max-width: 1280px;
3 | margin: 0 auto;
4 | padding: 2rem;
5 | text-align: center;
6 | }
7 |
8 | .logo {
9 | height: 6em;
10 | padding: 1.5em;
11 | will-change: filter;
12 | transition: filter 300ms;
13 | }
14 | .logo:hover {
15 | filter: drop-shadow(0 0 2em #646cffaa);
16 | }
17 | .logo.react:hover {
18 | filter: drop-shadow(0 0 2em #61dafbaa);
19 | }
20 |
21 | @keyframes logo-spin {
22 | from {
23 | transform: rotate(0deg);
24 | }
25 | to {
26 | transform: rotate(360deg);
27 | }
28 | }
29 |
30 | @media (prefers-reduced-motion: no-preference) {
31 | a:nth-of-type(2) .logo {
32 | animation: logo-spin infinite 20s linear;
33 | }
34 | }
35 |
36 | .card {
37 | padding: 2em;
38 | }
39 |
40 | .read-the-docs {
41 | color: #888;
42 | }
43 |
44 | .canvas-container {
45 | position: fixed;
46 | top: 0;
47 | left: 0;
48 | width: 100%;
49 | height: 100%;
50 | z-index: 0;
51 | }
52 |
53 | .frame-progress-bar .progress-bar {
54 | transition: none;
55 | }
56 |
57 | .blur-bg-90 {
58 | opacity: 0.95 !important;
59 | -webkit-backdrop-filter: blur(10px);
60 | backdrop-filter: blur(10px);
61 | }
62 |
63 | .custom-modal .modal-dialog {
64 | max-width: 100%;
65 | }
66 |
67 | .custom-modal .modal-content {
68 | height: 80vh;
69 | display: flex;
70 | flex-direction: column;
71 | }
72 |
73 | .custom-modal .modal-body-custom {
74 | flex: 1;
75 | padding: 0;
76 | display: flex;
77 | }
78 |
79 | .custom-modal .iframe-custom {
80 | width: 100%;
81 | height: 100%;
82 | border: none;
83 | }
84 | :root {
85 | --handle-size: 12px; /* Base size for the handle */
86 | --handle-color: rgb(48, 75, 183); /* Color of the handle */
87 | }
88 |
89 | .square {
90 | width: var(--handle-size);
91 | height: var(--handle-size);
92 | background-color: var(--handle-color); /* Color of the square */
93 | border-left: calc(var(--handle-size) / 2) solid transparent;
94 | border-right: calc(var(--handle-size) / 2) solid transparent;
95 | }
96 |
97 | .triangle {
98 | width: 0;
99 | height: 0;
100 | border-left: calc(var(--handle-size) / 2) solid transparent;
101 | border-right: calc(var(--handle-size) / 2) solid transparent;
102 | border-top: calc(var(--handle-size) / 2) solid var(--handle-color); /* Color of the triangle */
103 | }
104 |
105 | .handle {
106 | position: absolute;
107 | top: -15px; /* Adjust this value based on your design */
108 | transform: translateX(-50%);
109 | display: flex;
110 | align-items: center;
111 | flex-direction: column;
112 | z-index: 2; /* Ensure it is above other elements */
113 | }
114 |
115 | .progress-bar-v-line {
116 | position: absolute;
117 | top: 0;
118 | left: 0;
119 | transform: translateX(-50%);
120 | width: 2px; /* Thickness of the line */
121 | height: 100%; /* Full height of the column */
122 | background-color: var(--handle-color); /* Color of the line */
123 | z-index: 1; /* Ensure it is above tiles but below the bookmark */
124 | }
125 |
126 | .progress-bar-tick-line {
127 | position: absolute;
128 | top: 0;
129 | left: 0;
130 | transform: translateX(-50%) translateY(-100%);
131 | width: 1px; /* Thickness of the line */
132 | height: 14%; /* Size of the tick */
133 | z-index: 1; /* Ensure it is visible */
134 | }
135 |
136 | .progress-bar-bookmark {
137 | position: absolute;
138 | top: 0;
139 | left: 0;
140 | /* z-index: 1; */
141 | transform: translateX(-50%);
142 | }
143 |
--------------------------------------------------------------------------------
/app/src/assets/react.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/app/src/components/cameraAndControls.tsx:
--------------------------------------------------------------------------------
1 | import {
2 | OrbitControls,
3 | OrthographicCamera,
4 | PerspectiveCamera,
5 | TrackballControls,
6 | } from "@react-three/drei";
7 | import { Box } from "@react-three/drei";
8 | import { debounce } from "lodash";
9 | import {
10 | forwardRef,
11 | useCallback,
12 | useEffect,
13 | useMemo,
14 | useRef,
15 | useState,
16 | } from "react";
17 | import * as THREE from "three";
18 | import { getCentroid, useCentroid } from "./particlesEditor";
19 |
20 | const zeroVector = new THREE.Vector3(0, 0, 0); // default controls target; compared against to detect an untouched target
21 | const initialCameraVector = new THREE.Vector3(5, 5, 5); // default camera position; compared against to detect an untouched camera
22 | const upVector = new THREE.Vector3(0, 1, 0); // +Y unit vector; crossed with the view direction in rollCamera
23 |
24 | const MoveCameraTarget = forwardRef(
25 | ({ colorMode }: { colorMode: string }, ref: React.Ref) => {
26 | const shortDimension = 0.05;
27 | const longDimension = 0.5;
28 |
29 | return (
30 |
31 | {/* X axis box */}
32 |
33 |
36 |
37 |
38 | {/* Y axis box */}
39 |
40 |
43 |
44 |
45 | {/* Z axis box */}
46 |
47 |
50 |
51 |
52 | );
53 | },
54 | );
55 |
56 | type CameraAndControls = { // camera pose state applied to the camera and its controls
57 | camera: THREE.Vector3; // copied onto the camera's position
58 | target: THREE.Vector3; // copied onto the controls' target
59 | };
60 |
61 | type CameraAndControlsProps = { // props accepted by the CameraAndControls component
62 | cameraConfig: any; // `cameraConfig.camera` triggers a camera reset when it changes
63 | cameraAndControls: CameraAndControls; // current camera position and controls target
64 | setCameraAndControls: React.Dispatch<React.SetStateAction<CameraAndControls>>; // called with a plain value as well as an updater function
65 | currentFrame: any; // `currentFrame.positions` is iterated; its items expose `distanceTo`
66 | selectedIds: Set<number>; // assumes ids are numeric particle indices — TODO confirm
67 | colorMode: string;
68 | };
69 |
70 | const CameraAndControls: React.FC = ({
71 | cameraConfig,
72 | cameraAndControls,
73 | setCameraAndControls,
74 | currentFrame,
75 | selectedIds,
76 | colorMode,
77 | }) => {
78 | const cameraRef = useRef(null);
79 | const controlsRef = useRef(null);
80 | const centroid = useCentroid({
81 | frame: currentFrame,
82 | selectedIds: selectedIds,
83 | });
84 |
85 | // need this extra for the crosshair
86 | const crossHairRef = useRef(null);
87 |
88 | const controlsOnChangeFn = useCallback((e: any) => {
89 | if (!crossHairRef.current) {
90 | return;
91 | }
92 | crossHairRef.current.position.copy(e.target.target);
93 | }, []);
94 |
95 | const controlsOnEndFn = useCallback(
96 | debounce(() => {
97 | if (!cameraRef.current || !controlsRef.current) {
98 | return;
99 | }
100 | setCameraAndControls({
101 | camera: cameraRef.current.position,
102 | target: controlsRef.current.target,
103 | });
104 | }, 100),
105 | [],
106 | );
107 |
108 | const rollCamera = useCallback((angle: number) => {
109 | if (!cameraRef.current) {
110 | return;
111 | }
112 | // Get the current direction the camera is looking
113 | const looksTo = new THREE.Vector3();
114 | cameraRef.current.getWorldDirection(looksTo);
115 |
116 | // Calculate the rotation axis (perpendicular to both yDir and looksTo)
117 | const rotationAxis = new THREE.Vector3();
118 | rotationAxis.crossVectors(upVector, looksTo).normalize();
119 |
120 | // Compute the quaternion for the roll rotation
121 | const quaternion = new THREE.Quaternion();
122 | quaternion.setFromAxisAngle(looksTo, angle);
123 |
124 | // Update the `up` vector of the camera
125 | const newUp = new THREE.Vector3();
126 | newUp.copy(cameraRef.current.up).applyQuaternion(quaternion).normalize();
127 | cameraRef.current.up.copy(newUp);
128 |
129 | // Update the camera matrix and controls
130 | cameraRef.current.updateProjectionMatrix();
131 | if (controlsRef.current) {
132 | controlsRef.current.update();
133 | }
134 | }, []);
135 |
136 | const getResetCamera = useCallback(() => {
137 | if (currentFrame.positions.length === 0) {
138 | return;
139 | }
140 | if (cameraRef.current === null) {
141 | return;
142 | }
143 | // now calculate the camera positions
144 | const fullCentroid = getCentroid(currentFrame.positions, new Set());
145 |
146 | // Compute the bounding sphere radius
147 | let maxDistance = 0;
148 | currentFrame.positions.forEach((x) => {
149 | maxDistance = Math.max(maxDistance, x.distanceTo(fullCentroid));
150 | });
151 |
152 | const fov = (cameraRef.current.fov * Math.PI) / 180; // Convert FOV to radians
153 | let distance = maxDistance / Math.tan(fov / 2);
154 | // if distance is NaN, return - happens for OrthographicCamera
155 | if (Number.isNaN(distance)) {
156 | distance = 0;
157 | }
158 |
159 | return {
160 | camera: new THREE.Vector3(distance, distance, distance),
161 | target: fullCentroid,
162 | };
163 | }, [currentFrame.positions]);
164 |
165 | // if the camera positions and target positions is default, adapt them to the scene
166 | useEffect(() => {
167 | if (
168 | cameraAndControls.camera.equals(initialCameraVector) &&
169 | cameraAndControls.target.equals(zeroVector)
170 | ) {
171 | const resetCamera = getResetCamera();
172 | if (resetCamera) {
173 | setCameraAndControls(resetCamera);
174 | }
175 | }
176 | }, [currentFrame.positions]);
177 |
178 | // if the camera changes, run resetCamera
179 | useEffect(() => {
180 | const resetCamera = getResetCamera();
181 | if (resetCamera) {
182 | setCameraAndControls(resetCamera);
183 | }
184 | }, [cameraConfig.camera]);
185 |
186 | useEffect(() => {
187 | if (!cameraRef.current || !controlsRef.current) {
188 | return;
189 | }
190 | cameraRef.current.position.copy(cameraAndControls.camera);
191 | controlsRef.current.target.copy(cameraAndControls.target);
192 | controlsRef.current.update();
193 | }, [cameraAndControls]);
194 |
195 | // keyboard controls
196 | useEffect(() => {
197 | // page initialization
198 |
199 | const handleKeyDown = (event: KeyboardEvent) => {
200 | // if canvas is not focused, don't do anything
201 | if (document.activeElement !== document.body) {
202 | return;
203 | }
204 | if (event.key === "c") {
205 | setCameraAndControls((prev: any) => ({
206 | ...prev,
207 | target: centroid,
208 | }));
209 | } else if (event.key === "o") {
210 | const resetCamera = getResetCamera();
211 | if (resetCamera) {
212 | setCameraAndControls(resetCamera);
213 | }
214 | // reset the camera roll
215 | if (cameraRef.current) {
216 | cameraRef.current.up.copy(upVector);
217 | cameraRef.current.updateProjectionMatrix();
218 | if (controlsRef.current) {
219 | controlsRef.current.update();
220 | }
221 | }
222 | } else if (event.key === "r") {
223 | const roll = Math.PI / 100;
224 | if (event.ctrlKey) {
225 | rollCamera(-roll);
226 | } else {
227 | rollCamera(roll);
228 | }
229 | }
230 | };
231 |
232 | // Add the event listener
233 | window.addEventListener("keydown", handleKeyDown);
234 |
235 | // Clean up the event listener on unmount
236 | return () => {
237 | window.removeEventListener("keydown", handleKeyDown);
238 | };
239 | }, [currentFrame, selectedIds]);
240 |
241 | return (
242 | <>
243 | {cameraConfig.camera === "OrthographicCamera" && (
244 |
251 |
252 |
253 | )}
254 | {cameraConfig.camera === "PerspectiveCamera" && (
255 |
261 |
262 |
263 | )}
264 | {cameraConfig.controls === "OrbitControls" && (
265 |
272 | )}
273 | {cameraConfig.controls === "TrackballControls" && (
274 |
280 | )}
281 | {cameraConfig.crosshair && controlsRef.current.target && (
282 |
283 | )}
284 | >
285 | );
286 | };
287 |
288 | export default CameraAndControls;
289 |
--------------------------------------------------------------------------------
/app/src/components/data.tsx:
--------------------------------------------------------------------------------
1 | import { Color } from "three";
2 |
3 | export const JMOL_COLORS = [ // table of THREE.Color objects; presumably indexed by atomic number following the Jmol scheme (index 0 a placeholder) — confirm against consumers
4 | new Color("#ff0000"),
5 | new Color("#ffffff"),
6 | new Color("#d9ffff"),
7 | new Color("#cc80ff"),
8 | new Color("#c2ff00"),
9 | new Color("#ffb5b5"),
10 | new Color("#909090"),
11 | new Color("#2f50f8"),
12 | new Color("#ff0d0d"),
13 | new Color("#90df50"),
14 | new Color("#b3e2f5"),
15 | new Color("#ab5cf1"),
16 | new Color("#89ff00"),
17 | new Color("#bea6a6"),
18 | new Color("#efc79f"),
19 | new Color("#ff8000"),
20 | new Color("#ffff2f"),
21 | new Color("#1fef1f"),
22 | new Color("#80d1e2"),
23 | new Color("#8f40d3"),
24 | new Color("#3cff00"),
25 | new Color("#e6e6e6"),
26 | new Color("#bec2c6"),
27 | new Color("#a6a6ab"),
28 | new Color("#8999c6"),
29 | new Color("#9c79c6"),
30 | new Color("#df6633"),
31 | new Color("#ef909f"),
32 | new Color("#50d050"),
33 | new Color("#c78033"),
34 | new Color("#7c80af"),
35 | new Color("#c28f8f"),
36 | new Color("#668f8f"),
37 | new Color("#bc80e2"),
38 | new Color("#ffa000"),
39 | new Color("#a62929"),
40 | new Color("#5cb8d1"),
41 | new Color("#6f2daf"),
42 | new Color("#00ff00"),
43 | new Color("#93ffff"),
44 | new Color("#93dfdf"),
45 | new Color("#73c2c8"),
46 | new Color("#53b5b5"),
47 | new Color("#3a9e9e"),
48 | new Color("#238f8f"),
49 | new Color("#097c8b"),
50 | new Color("#006985"),
51 | new Color("#c0c0c0"),
52 | new Color("#ffd98f"),
53 | new Color("#a67573"),
54 | new Color("#668080"),
55 | new Color("#9e62b5"),
56 | new Color("#d37900"),
57 | new Color("#930093"),
58 | new Color("#429eaf"),
59 | new Color("#56168f"),
60 | new Color("#00c800"),
61 | new Color("#6fd3ff"),
62 | new Color("#ffffc6"),
63 | new Color("#d9ffc6"),
64 | new Color("#c6ffc6"),
65 | new Color("#a2ffc6"),
66 | new Color("#8fffc6"),
67 | new Color("#60ffc6"),
68 | new Color("#45ffc6"),
69 | new Color("#2fffc6"),
70 | new Color("#1fffc6"),
71 | new Color("#00ff9c"),
72 | new Color("#00e675"),
73 | new Color("#00d352"),
74 | new Color("#00be38"),
75 | new Color("#00ab23"),
76 | new Color("#4dc2ff"),
77 | new Color("#4da6ff"),
78 | new Color("#2093d5"),
79 | new Color("#257cab"),
80 | new Color("#256695"),
81 | new Color("#165386"),
82 | new Color("#d0d0df"),
83 | new Color("#ffd122"),
84 | new Color("#b8b8d0"),
85 | new Color("#a6534d"),
86 | new Color("#565860"),
87 | new Color("#9e4fb5"),
88 | new Color("#ab5c00"),
89 | new Color("#754f45"),
90 | new Color("#428295"),
91 | new Color("#420066"),
92 | new Color("#007c00"),
93 | new Color("#6fabf9"),
94 | new Color("#00b9ff"),
95 | new Color("#00a0ff"),
96 | new Color("#008fff"),
97 | new Color("#0080ff"),
98 | new Color("#006bff"),
99 | new Color("#535cf1"),
100 | new Color("#785ce2"),
101 | new Color("#894fe2"),
102 | new Color("#a036d3"),
103 | new Color("#b31fd3"),
104 | new Color("#b31fb9"),
105 | new Color("#b30da6"),
106 | new Color("#bc0d86"),
107 | new Color("#c60066"),
108 | new Color("#cc0058"),
109 | new Color("#d1004f"),
110 | new Color("#d90045"),
111 | new Color("#df0038"),
112 | new Color("#e6002d"),
113 | new Color("#eb0025"),
114 | ];
115 |
116 | export const covalentRadii = [ // numeric table; presumably covalent radii indexed by atomic number, parallel to JMOL_COLORS — units are not stated here (likely Å), confirm
117 | 1, 0.31, 0.28, 1.28, 0.96, 0.84, 0.76, 0.71, 0.66, 0.57, 0.58, 1.66, 1.41,
118 | 1.21, 1.11, 1.07, 1.05, 1.02, 1.06, 2.03, 1.76, 1.7, 1.6, 1.53, 1.39, 1.39,
119 | 1.32, 1.26, 1.24, 1.32, 1.22, 1.22, 1.2, 1.19, 1.2, 1.2, 1.16, 2.2, 1.95, 1.9,
120 | 1.75, 1.64, 1.54, 1.47, 1.46, 1.42, 1.39, 1.45, 1.44, 1.42, 1.39, 1.39, 1.38,
121 | 1.39, 1.4, 2.44, 2.15, 2.07, 2.04, 2.03, 2.01, 1.99, 1.98, 1.98, 1.96, 1.94,
122 | 1.92, 1.92, 1.89, 1.9, 1.87, 1.87, 1.75, 1.7, 1.62, 1.51, 1.44, 1.41, 1.36,
123 | 1.36, 1.32, 1.45, 1.46, 1.48, 1.4, 1.5, 1.5, 2.6, 2.21, 2.15, 2.06, 2.0, 1.96,
124 | 1.9, 1.87, 1.8, 1.69,
125 | ];
126 |
--------------------------------------------------------------------------------
/app/src/components/floor.tsx:
--------------------------------------------------------------------------------
1 | import { Plane } from "@react-three/drei";
2 | import { Line } from "@react-three/drei";
3 | import { useEffect, useState } from "react";
4 | import * as THREE from "three";
5 |
6 | function Grid({ // collects one element per grid line into `lines` and returns them; NOTE(review): the pushed JSX elements are missing from this dump
7 | position = [0, 0, 0],
8 | gridSpacing = 10, // loop step; NOTE(review): a value <= 0 would make both loops below never terminate
9 | sizeX = 500, // first loop runs from -sizeX / 2 to +sizeX / 2 inclusive
10 | sizeY = 500, // second loop runs from -sizeY / 2 to +sizeY / 2 inclusive
11 | color = "black",
12 | }) {
13 | const lines = [];
14 |
15 | // Generate vertical lines
16 | for (let x = -sizeX / 2; x <= sizeX / 2; x += gridSpacing) {
17 | lines.push(
18 | ,
27 | );
28 | }
29 |
30 | // Generate horizontal lines
31 | for (let y = -sizeY / 2; y <= sizeY / 2; y += gridSpacing) {
32 | lines.push(
33 | ,
42 | );
43 | }
44 |
45 | return {lines};
46 | }
47 |
48 | export const Floor: any = ({ colorMode, roomConfig }: any) => { // floor component; NOTE(review): the rendered JSX is missing from this dump
49 | const [bsColor, setBsColor] = useState({ // cached values of two CSS custom properties, "#fff" until first read
50 | "--bs-body-bg": "#fff",
51 | "--bs-secondary": "#fff",
52 | });
53 |
54 | useEffect(() => { // re-read both custom properties from the root element whenever colorMode changes
55 | setBsColor({
56 | "--bs-body-bg": getComputedStyle(
57 | document.documentElement,
58 | ).getPropertyValue("--bs-body-bg"),
59 | "--bs-secondary": getComputedStyle(
60 | document.documentElement,
61 | ).getPropertyValue("--bs-secondary"),
62 | });
63 | }, [colorMode]);
64 |
65 | return (
66 | <>
67 | {" "}
68 |
74 |
85 |
86 |
87 |
94 | >
95 | );
96 | };
97 |
--------------------------------------------------------------------------------
/app/src/components/lines.tsx:
--------------------------------------------------------------------------------
1 | import { CatmullRomLine, Dodecahedron } from "@react-three/drei";
2 | import { useThree } from "@react-three/fiber";
3 | import { useEffect, useState } from "react";
4 | import { socket } from "../socket";
5 |
6 | import * as THREE from "three";
7 |
8 | const findClosestPoint = (points: THREE.Vector3[], position: THREE.Vector3) => {
9 |   const closestPoint = new THREE.Vector3(); // result: copy of the nearest entry of `points`; stays (0, 0, 0) only when `points` is empty
10 |   let minDistance = Number.POSITIVE_INFINITY; // seed with +Infinity so the origin never competes with the real candidates
11 |   points.forEach((point) => {
12 |     const distance = point.distanceTo(position);
13 |     if (distance < minDistance) { minDistance = distance; closestPoint.copy(point); } // strictly closer: remember it
14 |   });
15 |   return closestPoint;
16 | };
17 |
18 | // TODO: ensure type consistency, every point/... should be THREE.Vector3
19 | export const Line3D = ({ // editable spline through `points`; NOTE(review): most of the rendered JSX is missing from this dump
20 | points,
21 | setPoints,
22 | setSelectedPoint,
23 | isDrawing,
24 | colorMode,
25 | hoveredId, // if null, hover virtual canvas -> close line
26 | setIsDrawing,
27 | setLineLength,
28 | }: {
29 | points: THREE.Vector3[];
30 | setPoints: any;
31 | setSelectedPoint: any;
32 | isDrawing: boolean;
33 | colorMode: string;
34 | hoveredId: number | null;
35 | setIsDrawing: any;
36 | setLineLength: (length: number) => void;
37 | }) => {
38 | // a virtual point is between every two points in the points array on the line
39 | const [virtualPoints, setVirtualPoints] = useState([]);
40 | const [lineColor, setLineColor] = useState("black");
41 | const [pointColor, setPointColor] = useState("black");
42 | const [virtualPointColor, setVirtualPointColor] = useState("darkcyan");
43 | const initalTriggerRef = useRef(true); // true until the isDrawing effect below has run once
44 |
45 | useEffect(() => { // choose colors: red tones while drawing with nothing hovered, otherwise per color mode
46 | // TODO: use bootstrap colors
47 | if ((hoveredId == null || hoveredId === -1) && isDrawing) {
48 | setLineColor("#f01d23");
49 | setPointColor("#710000");
50 | } else if (colorMode === "light") {
51 | setLineColor("#454b66");
52 | setPointColor("#191308");
53 | setVirtualPointColor("#677db7");
54 | } else {
55 | setLineColor("#f5fdc6");
56 | setPointColor("#41521f");
57 | setVirtualPointColor("#a89f68");
58 | }
59 | }, [colorMode, hoveredId, isDrawing]);
60 |
61 | const handleClick = (event: any) => { // not drawing: select the clicked point; drawing: append a point, or stop drawing when nothing is hovered
62 | if (!isDrawing) {
63 | setSelectedPoint(event.object.position.clone());
64 | } else {
65 | if (hoveredId != null && hoveredId !== -1) {
66 | const point = event.point.clone();
67 | setPoints([...points, point]);
68 | } else {
69 | setIsDrawing(false);
70 | }
71 | }
72 | };
73 |
74 | const handleVirtualClick = (index: number, event: any) => {
75 | // make the virtual point a real point and insert it at the correct position
76 | const newPoints = [...points];
77 | newPoints.splice(index + 1, 0, event.object.position.clone());
78 | setPoints(newPoints);
79 | setSelectedPoint(event.object.position.clone());
80 | };
81 |
82 | useEffect(() => { // recompute the line length and one virtual point per segment whenever `points` changes
83 | if (points.length < 2) {
84 | return;
85 | }
86 | // TODO: do not compute the curve twice
87 | // TODO: clean up types, reuse vector objects here
88 | const curve = new THREE.CatmullRomCurve3(points);
89 |
90 | setLineLength(curve.getLength());
91 |
92 | const linePoints = curve.getPoints(points.length * 20); // sampled curve points searched by findClosestPoint
93 | const position = new THREE.Vector3();
94 | let _newPoints: THREE.Vector3[] = [];
95 | for (let i = 0; i < points.length - 1; i++) {
96 | position.copy(points[i]);
97 | position.lerp(new THREE.Vector3(...points[i + 1]), 0.5); // straight-line midpoint of points[i] and points[i + 1]; NOTE(review): spreading relies on the point being iterable — confirm
98 |
99 | _newPoints = [..._newPoints, findClosestPoint(linePoints, position)]; // snap the midpoint onto the sampled curve
100 | }
101 | setVirtualPoints(_newPoints);
102 | }, [points]);
103 |
104 | useEffect(() => { // reacts to isDrawing changes only; the very first run is skipped via initalTriggerRef
105 | if (initalTriggerRef.current) {
106 | initalTriggerRef.current = false;
107 | return;
108 | }
109 | if (points.length > 0) {
110 | // add the moving point when going from not drawing -> drawing
111 | // this removes a point when triggered initially
112 | // This should not trigger initially, so the initalTriggerRef flag is
113 | // a strange workaround
114 | if (isDrawing) {
115 | setPoints([...points, points[points.length - 1]]); // duplicate the last point
116 | } else {
117 | setPoints(points.slice(0, points.length - 1)); // drop the last point again
118 | }
119 | }
120 | }, [isDrawing]);
121 |
122 | return ( // NOTE(review): `Number.parseInt(points.length * 20)` below is handed a number although parseInt is declared for strings; plain `points.length * 20` would do
123 | <>
124 | {points.map((point, index) => (
125 |
132 | ))}
133 | {points.length >= 2 && (
134 | <>
135 | new THREE.Vector3(...point))}
137 | color={lineColor}
138 | lineWidth={2}
139 | segments={Number.parseInt(points.length * 20)}
140 | />
141 | {virtualPoints.map((point, index) => (
142 | handleVirtualClick(index, event)}
148 | />
149 | ))}
150 | >
151 | )}
152 | >
153 | );
154 | };
155 |
156 | import { Plane } from "@react-three/drei";
157 | import { useFrame } from "@react-three/fiber";
158 | import { useRef } from "react";
159 |
160 | export const VirtualCanvas = ({ // camera-facing plane used as a hover target while drawing; NOTE(review): the rendered JSX is missing from this dump
161 | isDrawing,
162 | setPoints,
163 | points,
164 | hoveredId,
165 | setHoveredId,
166 | }: {
167 | isDrawing: boolean;
168 | setPoints: any;
169 | points: THREE.Vector3[];
170 | hoveredId: number | null;
171 | setHoveredId: (id: number | null) => void;
172 | }) => {
173 | const { camera, size } = useThree();
174 | const [distance, setDistance] = useState(10); // distance from the camera at which the plane is placed
175 | const [canvasVisible, setCanvasVisible] = useState(false);
176 | const canvasRef = useRef();
177 |
178 | // setDistance to camera <-> last point distance
179 |
180 | // useEffect(() => {
181 | // if (canvasRef.current) {
182 | // // Set the initial size of the plane
183 | // updatePlaneSize();
184 | // }
185 | // }, [camera, size]);
186 |
187 | const updatePlaneSize = () => { // scale the plane to the visible width/height at `distance`; NOTE(review): reads camera.fov and camera.aspect, so presumably expects a perspective camera — confirm
188 | const vFOV = THREE.MathUtils.degToRad(camera.fov); // Convert vertical FOV to radians
189 | const height = 2 * Math.tan(vFOV / 2) * distance; // Visible height
190 | const width = height * camera.aspect; // Visible width
191 |
192 | canvasRef.current.scale.set(width, height, 1); // Update the scale of the plane
193 | };
194 |
195 | const onHover = (event: any) => { // while drawing, replace the last point with the first visible intersection under the pointer
196 | if (isDrawing && event.object.visible) {
197 | if (!canvasRef.current) {
198 | return;
199 | }
200 | // this feature is temporarily disabled
201 | // if (event.shiftKey) {
202 | // setHoveredId(canvasRef.current);
203 | // // set opacity of the virtual canvas
204 | // setCanvasVisible(true);
205 | // } else {
206 | // setHoveredId(null);
207 | // console.log("virtual canvas");
208 | // setCanvasVisible(false);
209 | // }
210 |
211 | // find the index of the closest visible point from the camera
212 | // if nothing is being hovered, this is the virtual canvas
213 | let i = 0;
214 | while (
215 | i < event.intersections.length &&
216 | !event.intersections[i].object.visible
217 | ) {
218 | i++;
219 | }
220 |
221 | setPoints((prevPoints: THREE.Vector3[]) => [
222 | ...prevPoints.slice(0, prevPoints.length - 1),
223 | event.intersections[i].point, // NOTE(review): if no intersection is visible, i === intersections.length and this read throws
224 | ]);
225 | }
226 | };
227 |
228 | useFrame(() => { // per-frame work, only while drawing: resize, show/hide, and re-place the plane in front of the camera
229 | if (!isDrawing) {
230 | return;
231 | }
232 | if (canvasRef.current) {
233 | updatePlaneSize();
234 | // if nothing is hovered, the canvas should be visible
235 | canvasRef.current.visible =
236 | hoveredId == null ||
237 | hoveredId === canvasRef.current ||
238 | hoveredId === -1;
239 | }
240 |
241 | if (points.length >= 2) {
242 | const lastPoint = points[points.length - 2]; // second-to-last entry: the last one is the moving point
243 | // the lastPoint is the one we are currently drawing
244 | const dist = camera.position.distanceTo(lastPoint);
245 | setDistance(dist); // NOTE(review): React state update issued from the frame loop
246 | }
247 |
248 | if (canvasRef.current) {
249 | const direction = new THREE.Vector3(0, 0, -1).applyQuaternion(
250 | camera.quaternion,
251 | );
252 | const position = direction.multiplyScalar(distance).add(camera.position); // `distance` units along the rotated -Z axis from the camera
253 | canvasRef.current.position.copy(position);
254 |
255 | canvasRef.current.lookAt(camera.position);
256 | }
257 | });
258 |
259 | return (
260 | <>
261 | {isDrawing && (
262 |
270 |
277 |
278 | )}
279 | >
280 | );
281 | };
282 |
--------------------------------------------------------------------------------
/app/src/components/meshes.tsx:
--------------------------------------------------------------------------------
1 | import type React from "react";
2 | import { useEffect, useMemo, useRef } from "react";
3 | import * as THREE from "three";
4 | import { BufferGeometryUtils } from "three/examples/jsm/Addons.js";
5 | import { type ColorRange, type HSLColor, interpolateColor } from "./utils";
6 | import { useMergedMesh } from "./utils/mergeInstancedMesh";
7 |
8 | /**
9 |  * Build the arrow geometry: a thin cylinder shaft topped by a cone, merged
10 |  * into one geometry that starts at y = 0 and extends along +Y to y = 1.
11 |  */
12 | function createArrowMesh() {
13 |   const radialSegments = 32;
14 |   const shaft = { radius: 0.04, height: 0.6 };
15 |   const tip = { radius: 0.1, height: 0.4 };
16 |
17 |   // Geometries are centred on the origin, so lift the shaft by half its height.
18 |   const shaftGeometry = new THREE.CylinderGeometry(
19 |     shaft.radius,
20 |     shaft.radius,
21 |     shaft.height,
22 |     radialSegments,
23 |   ).translate(0, shaft.height / 2, 0);
24 |   // The cone sits on top of the shaft.
25 |   const tipGeometry = new THREE.ConeGeometry(
26 |     tip.radius, tip.height, radialSegments,
27 |   ).translate(0, shaft.height + tip.height / 2, 0);
28 |
29 |   // Shaft first, tip second: same merge order as before.
30 |   return BufferGeometryUtils.mergeGeometries([shaftGeometry, tipGeometry]);
31 | }
32 |
33 | interface ArrowsProps { // props of the Arrows component
34 | start: number[][]; // one coordinate array per arrow, read with Vector3.fromArray
35 | end: number[][]; // end coordinates, paired with `start` by index
36 | scale_vector_thickness?: boolean; // when truthy the arrow is scaled uniformly by its length, otherwise only along its axis
37 | colormap: HSLColor[]; // color stops passed to interpolateColor
38 | colorrange: ColorRange; // value range passed to interpolateColor
39 | opacity?: number;
40 | rescale?: number; // factor applied to the arrow length after its color was chosen
41 | pathTracingSettings: any | undefined; // `enabled` is read from it when present
42 | }
43 |
44 | const Arrows: React.FC = ({ // instanced arrows from start[i] to end[i], colored by length; NOTE(review): the rendered JSX is missing from this dump
45 | start,
46 | end,
47 | scale_vector_thickness,
48 | colormap,
49 | colorrange,
50 | opacity = 1.0,
51 | rescale = 1.0,
52 | pathTracingSettings = undefined,
53 | }) => {
54 | const meshRef = useRef(null);
55 | const materialRef = useRef(null);
56 |
57 | const geometry = useMemo(() => {
58 | const _geom = createArrowMesh();
59 | if (pathTracingSettings?.enabled) {
60 | // make invisible when path tracing is enabled
61 | _geom.scale(0, 0, 0);
62 | }
63 | return _geom;
64 | }, [pathTracingSettings]);
65 |
66 | const instancedGeometry = useMemo(() => { // unscaled copy of the arrow geometry handed to useMergedMesh
67 | return createArrowMesh();
68 | }, []);
69 |
70 | const mergedMesh = useMergedMesh(
71 | meshRef,
72 | instancedGeometry,
73 | pathTracingSettings,
74 | [
75 | start,
76 | end,
77 | scale_vector_thickness,
78 | colormap,
79 | colorrange,
80 | opacity,
81 | rescale,
82 | ],
83 | );
84 |
85 | useEffect(() => { // write one transform matrix and one color per arrow into the instanced mesh
86 | if (!meshRef.current) return;
87 | const matrix = new THREE.Matrix4();
88 | const up = new THREE.Vector3(0, 1, 0);
89 | const startVector = new THREE.Vector3();
90 | const endVector = new THREE.Vector3();
91 | const direction = new THREE.Vector3();
92 | const quaternion = new THREE.Quaternion();
93 |
94 | for (let i = 0; i < start.length; i++) {
95 | startVector.fromArray(start[i]);
96 | endVector.fromArray(end[i]); // NOTE(review): assumes `end` has at least as many entries as `start`
97 | direction.subVectors(endVector, startVector);
98 | let length = direction.length();
99 | const color = interpolateColor(colormap, colorrange, length);
100 | // rescale after the color interpolation
101 | length *= rescale;
102 |
103 | const scale = scale_vector_thickness
104 | ? new THREE.Vector3(length, length, length)
105 | : new THREE.Vector3(1, length, 1);
106 |
107 | quaternion.setFromUnitVectors(up, direction.clone().normalize()); // rotate the +Y arrow onto the start->end direction
108 | matrix.makeRotationFromQuaternion(quaternion);
109 | matrix.setPosition(startVector);
110 | matrix.scale(scale);
111 |
112 | meshRef.current.setColorAt(i, color);
113 | meshRef.current.setMatrixAt(i, matrix);
114 | }
115 | meshRef.current.instanceMatrix.needsUpdate = true;
116 | }, [start, end, scale_vector_thickness, colormap, colorrange]); // NOTE(review): `rescale` is used above but not listed, so changing it alone does not rerun this effect
117 |
118 | useEffect(() => { // flag the material and the per-instance colors for re-upload
119 | if (!materialRef.current) return;
120 | materialRef.current.needsUpdate = true; // TODO: check for particles as well
121 | if (!meshRef.current) return;
122 | if (!meshRef.current.instanceColor) return;
123 | meshRef.current.instanceColor.needsUpdate = true;
124 | }, [start, end, scale_vector_thickness, colormap, colorrange]);
125 |
126 | return (
127 |
128 |
134 |
135 | );
136 | };
137 |
138 | export default Arrows;
139 |
--------------------------------------------------------------------------------
/app/src/components/overlays.tsx:
--------------------------------------------------------------------------------
1 | import { Button, Card } from "react-bootstrap";
2 | import { Rnd } from "react-rnd";
3 | import type { Frame } from "./particles";
4 |
5 | export const ParticleInfoOverlay = ({ // lists the key/value pairs of `info` when `show` is true; NOTE(review): the wrapping JSX is missing from this dump
6 | show,
7 | info,
8 | position,
9 | }: {
10 | show: boolean;
11 | info: { [key: string]: any };
12 | position: { x: number; y: number };
13 | }) => {
14 | return ( // NOTE(review): the fragments produced by the map below carry no `key` prop in the visible code
15 | <>
16 | {show && (
17 |
28 |
29 |
30 | {Object.entries(info).map(([key, value]) => (
31 | <>
32 | {key}: {value}
33 |
34 | >
35 | ))}
36 |
37 |
38 |
39 | )}
40 | >
41 | );
42 | };
43 |
44 | export const SceneInfoOverlay = ({ // shows the frame's energy (when truthy) and its particle count; NOTE(review): the wrapping JSX is missing from this dump
45 | frame,
46 | setShowParticleInfo,
47 | }: {
48 | frame: Frame;
49 | setShowParticleInfo: any;
50 | }) => {
51 | return ( // NOTE(review): `frame.calc.energy &&` below treats an energy of exactly 0 as absent — confirm that is intended
52 |
67 |
75 |
76 | Info
77 |
79 |
80 | {frame.calc.energy && (
81 | <>
82 | Energy: {frame.calc.energy} eV
83 |
84 | >
85 | )}
86 | Particles: {frame.positions.length}
87 |
88 |
89 |
90 | );
91 | };
92 |
--------------------------------------------------------------------------------
/app/src/components/particlesEditor.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 | import { TransformControls } from "@react-three/drei";
3 | import { useCallback, useEffect, useMemo, useRef, useState } from "react";
4 | import { Euler, Vector3 } from "three";
5 | import { socket } from "../socket";
6 |
7 | /**
8 |  * Centroid of the selected positions, or of all positions when nothing (valid) is selected.
9 |  * Returns the zero vector for a missing or empty `positions` array.
10 |  */
11 | export const getCentroid = (positions: Vector3[], selection: Set<number>) => {
12 |   const centroid = new Vector3();
13 |   if (!positions || positions.length === 0) {
14 |     return centroid;
15 |   }
16 |   // Drop stale ids (e.g. a selection made on a larger frame): positions[i] would be
17 |   // undefined there and Vector3.add(undefined) throws.
18 |   const selected = Array.from(selection, (i) => positions[i]).filter(Boolean);
19 |   const members = selected.length > 0 ? selected : positions;
20 |   for (const position of members) {
21 |     centroid.add(position);
22 |   }
23 |   return centroid.divideScalar(members.length);
24 | };
25 |
26 | // Hook: memoised centroid of the selection, recomputed only when the positions or the selection change
27 | export const useCentroid = ({ frame, selectedIds }: any) => {
28 |   const { positions } = frame;
29 |   const computeCentroid = () => getCentroid(positions, selectedIds);
30 |   return useMemo(computeCentroid, [positions, selectedIds]);
31 | };
32 |
33 | interface ParticleControlsProps { // props of the ParticleControls component
34 | frame: any; // its `positions` array is read and replaced
35 | selectedIds: Set; // indices into frame.positions; NOTE(review): the type argument (presumably <number>) was lost in this dump
36 | setFrame: (frame: any) => void; // called with an updater function by ParticleControls
37 | roomLock: boolean; // while true, editing is forced to "none" and the toggle key is ignored
38 | editMode: string; // "none", "translate" or "rotate"
39 | setEditMode: (mode: string) => void; // NOTE(review): also called with an updater function, which this signature does not cover
40 | }
41 |
42 | export const ParticleControls: React.FC = ({ // transform gizmo that translates/rotates the selected particles; NOTE(review): the gizmo's JSX is missing from this dump
43 | frame,
44 | selectedIds,
45 | setFrame,
46 | roomLock,
47 | editMode,
48 | setEditMode,
49 | }: ParticleControlsProps) => {
50 | const controls = useRef(null);
51 | const controlsPostRef = useRef(new Vector3()); // gizmo position seen at the previous change event
52 | const controlsRotationRef = useRef(new Vector3()); // gizmo rotation (x, y, z) seen at the previous change event
53 |
54 | // State for the edit mode: "None", "translate", or "rotate"
55 |
56 | useEffect(() => { // a locked room forces the edit mode back to "none"
57 | if (roomLock) {
58 | setEditMode("none");
59 | }
60 | }, [roomLock]);
61 |
62 | // Efficiently calculate centroid and attach control to it when `selectedIds` changes
63 | const centroid = useCentroid({ frame, selectedIds });
64 |
65 | useEffect(() => { // move the gizmo (and the remembered position) onto the centroid
66 | if (controls.current && selectedIds.size > 0) {
67 | controls.current.object.position.copy(centroid);
68 | controlsPostRef.current.copy(centroid);
69 | }
70 | }, [centroid]);
71 |
72 | // Helper to update frame positions based on delta
73 | const applyDeltaToPositions = useCallback(
74 | (deltaPosition) => {
75 | setFrame((prevFrame) => ({
76 | ...prevFrame,
77 | positions: prevFrame.positions.map((pos, i) =>
78 | selectedIds.has(i) ? pos.clone().sub(deltaPosition) : pos, // delta is (previous - current), so subtracting moves the particle along with the gizmo
79 | ),
80 | }));
81 | },
82 | [setFrame, selectedIds],
83 | );
84 |
85 | // Helper to update frame rotations based on delta rotation
86 | const applyDeltaToRotations = useCallback(
87 | (deltaRotation) => {
88 | setFrame((prevFrame) => ({
89 | ...prevFrame,
90 | positions: prevFrame.positions.map((rot, i) => {
91 | if (selectedIds.has(i)) {
92 | // rotate the position around the centroid
93 | const position = rot.clone().sub(centroid);
94 | const euler = new Euler().setFromVector3(deltaRotation); // NOTE(review): delta is (previous - current) and is applied without negation, unlike the translation path — confirm the direction
95 | position.applyEuler(euler);
96 | position.add(centroid);
97 | return position;
98 | }
99 | return rot;
100 | }),
101 | }));
102 | },
103 | [setFrame, selectedIds, centroid],
104 | );
105 |
106 | // Handle control changes, applying only necessary updates to position and delta
107 | const handleControlsChange = useCallback(() => {
108 | if (editMode === "translate") {
109 | if (controls.current?.object?.position && selectedIds.size > 0) {
110 | const deltaPosition = controlsPostRef.current
111 | .clone()
112 | .sub(controls.current.object.position);
113 | applyDeltaToPositions(deltaPosition);
114 | controlsPostRef.current.copy(controls.current.object.position); // remember the position for the next delta
115 | }
116 | } else if (editMode === "rotate") {
117 | if (controls.current?.object?.rotation && selectedIds.size > 0) {
118 | const deltaRotation = controlsRotationRef.current
119 | .clone()
120 | .sub(controls.current.object.rotation); // Vector3.sub is handed the rotation object; presumably relies on its x/y/z fields — confirm
121 | applyDeltaToRotations(deltaRotation);
122 | controlsRotationRef.current.copy(controls.current.object.rotation);
123 | }
124 | }
125 | }, [applyDeltaToPositions, selectedIds, editMode, applyDeltaToRotations]);
126 |
127 | // Toggle mode between "None", "translate", and "rotate" on "E" key press
128 | useEffect(() => {
129 | const toggleMode = (event) => {
130 | if (document.activeElement !== document.body) { // ignore key presses while another element has focus
131 | return;
132 | }
133 | if (roomLock) {
134 | return;
135 | }
136 | if (event.key.toLowerCase() === "e") {
137 | socket.emit("room:copy"); // emitted on every "e" press, before the mode is cycled
138 | setEditMode((prevMode) => {
139 | switch (prevMode) {
140 | case "none":
141 | return "translate";
142 | case "translate":
143 | return "rotate";
144 | case "rotate":
145 | return "none";
146 | default:
147 | return "none";
148 | }
149 | });
150 | }
151 | };
152 |
153 | window.addEventListener("keydown", toggleMode);
154 | return () => {
155 | window.removeEventListener("keydown", toggleMode);
156 | };
157 | }, [roomLock]);
158 |
159 | // Apply mode to TransformControls whenever it changes
160 | useEffect(() => {
161 | if (controls.current) {
162 | controls.current.mode = editMode === "none" ? "" : editMode;
163 | }
164 | }, [editMode]);
165 |
166 | return (
167 | <>
168 | {selectedIds.size > 0 && editMode !== "none" && (
169 |
170 | )}
171 | >
172 | );
173 | };
174 |
--------------------------------------------------------------------------------
/app/src/components/tooltips.tsx:
--------------------------------------------------------------------------------
1 | import { OverlayTrigger, Tooltip } from "react-bootstrap";
2 | import type { Placement } from "react-bootstrap/esm/types";
3 |
4 | interface BtnTooltipProps { // props of the BtnTooltip component
5 | text: string; // tooltip text
6 | children: any; // element(s) the tooltip is attached to
7 | placement?: Placement; // BtnTooltip defaults this to "bottom"
8 | delayShow?: number; // BtnTooltip defaults this to 0; presumably milliseconds — confirm
9 | delayHide?: number; // BtnTooltip defaults this to 100; presumably milliseconds — confirm
10 | }
11 |
12 | export const BtnTooltip: React.FC = ({ // wraps `children` with a tooltip showing `text`; NOTE(review): the JSX below is truncated in this dump
13 | text,
14 | children,
15 | placement = "bottom",
16 | delayShow = 0,
17 | delayHide = 100,
18 | }) => {
19 | return (
20 | {text}}
24 | >
25 | {children}
26 |
27 | );
28 | };
29 |
--------------------------------------------------------------------------------
/app/src/components/transforms.tsx:
--------------------------------------------------------------------------------
1 | import { Dodecahedron, TransformControls } from "@react-three/drei";
2 | import { useEffect, useRef } from "react";
3 | import type * as THREE from "three";
4 |
5 | export default function ControlsBuilder({ // transform gizmo for the selected line point; NOTE(review): the rendered JSX is missing from this dump
6 | points,
7 | setPoints,
8 | selectedPoint,
9 | setSelectedPoint,
10 | }: {
11 | points: THREE.Vector3[];
12 | setPoints: any;
13 | selectedPoint: THREE.Vector3 | null;
14 | setSelectedPoint: any;
15 | }) {
16 | const mesh = useRef(null); // TODO: check type
17 | const controls = useRef(null);
18 |
19 | useEffect(() => { // on selection change: attach the gizmo to the mesh and mirror its moves into `points`
20 | if (selectedPoint == null) {
21 | return;
22 | }
23 |
24 | if (controls.current && mesh.current) {
25 | controls.current.attach(mesh.current);
26 |
27 | const handleChange = () => {
28 | if (mesh.current) {
29 | const index = points.findIndex( // first point within 0.1 of the selected position
30 | (point) => point.distanceTo(selectedPoint) < 0.1,
31 | );
32 | if (index === -1) {
33 | // TODO: check what would be best here?
34 | return;
35 | }
36 |
37 | const newPosition = mesh.current.position.clone();
38 | // update the position of points[index] to newPosition
39 | const newPoints = [...points]; // NOTE(review): `points` is the value captured when this effect ran (deps are [selectedPoint] only) — confirm it cannot go stale
40 | newPoints[index] = newPosition;
41 | setPoints(newPoints);
42 | }
43 | };
44 |
45 | controls.current.addEventListener("objectChange", handleChange);
46 |
47 | // Clean up event listeners on unmount
48 | return () => {
49 | if (controls.current) {
50 | controls.current.removeEventListener("objectChange", handleChange);
51 | }
52 | };
53 | }
54 | }, [selectedPoint]);
55 |
56 | return (
57 | <>
58 | {selectedPoint !== null && (
59 | <>
60 |
61 |
67 | >
68 | )}
69 | >
70 | );
71 | }
72 |
--------------------------------------------------------------------------------
/app/src/components/utils.tsx:
--------------------------------------------------------------------------------
1 | import { useEffect, useState } from "react";
2 |
3 | import * as THREE from "three";
4 |
5 | // Hook returning the active color mode and a toggle that flips and persists it.
6 | export const useColorMode = (): [string, () => void] => {
7 |   const [colorMode, setColorMode] = useState("light");
8 |   // On mount: a stored "theme" wins, otherwise follow the OS color-scheme preference.
9 |   useEffect(() => {
10 |     let initialMode = localStorage.getItem("theme");
11 |     if (!initialMode) {
12 |       const prefersDark = window.matchMedia("(prefers-color-scheme: dark)").matches;
13 |       initialMode = prefersDark ? "dark" : "light";
14 |     }
15 |     setTheme(initialMode, setColorMode);
16 |   }, []);
17 |
18 |   const toggleColorMode = () => {
19 |     const nextMode = colorMode !== "light" ? "light" : "dark";
20 |     setTheme(nextMode, setColorMode);
21 |     localStorage.setItem("theme", nextMode);
22 |   };
23 |   return [colorMode, toggleColorMode];
24 | };
25 |
26 | // Set the document's `data-bs-theme` attribute to `theme`, then pass `theme` on to `setColorMode`.
27 | function setTheme(theme: string, setColorMode: (mode: string) => void): void {
28 |   const root = document.documentElement;
29 |   root.setAttribute("data-bs-theme", theme);
30 |
31 |   setColorMode(theme);
32 | }
33 |
34 | export type HSLColor = [number, number, number]; // triple spread into THREE.Color.setHSL by interpolateColor, i.e. [h, s, l]
35 | export type ColorRange = [number, number]; // destructured as [min, max] by interpolateColor
36 |
37 | /**
38 |  * Map `value` onto a color by piecewise-linear interpolation (per HSL channel)
39 |  * between the stops of `colors`, which are spread evenly over `range`.
40 |  * Values at or outside the range clamp to the first / last stop. An empty
41 |  * `colors` array yields a default `THREE.Color`; a NaN `value` yields the first stop.
42 |  */
43 | export const interpolateColor = (
44 |   colors: HSLColor[],
45 |   range: ColorRange,
46 |   value: number,
47 | ): THREE.Color => {
48 |   if (colors.length === 0) {
49 |     return new THREE.Color(); // nothing to interpolate; avoids spreading `undefined` below
50 |   }
51 |   const [min, max] = range;
52 |   // NaN fails both clamp comparisons and would otherwise end up indexing `colors[NaN]`.
53 |   if (Number.isNaN(value) || value <= min) {
54 |     return new THREE.Color().setHSL(...colors[0]);
55 |   }
56 |   if (value >= max) {
57 |     return new THREE.Color().setHSL(...colors[colors.length - 1]);
58 |   }
59 |
60 |   // Fractional position of `value` along the list of color stops.
61 |   const scaledValue = ((value - min) / (max - min)) * (colors.length - 1);
62 |   const lowerIndex = Math.floor(scaledValue);
63 |   const [h0, s0, l0] = colors[lowerIndex];
64 |   const [h1, s1, l1] = colors[Math.ceil(scaledValue)];
65 |   const t = scaledValue - lowerIndex;
66 |
67 |   // Blend each HSL channel linearly between the two neighbouring stops.
68 |   const { lerp } = THREE.MathUtils;
69 |   return new THREE.Color().setHSL(lerp(h0, h1, t), lerp(s0, s1, t), lerp(l0, l1, t));
70 | };
71 |
72 | // State shape: an `active` flag together with a set of indices.
73 | export type IndicesState = {
74 |   active: boolean;
75 |   indices: Set<number>; // use Set<string> instead if the indices are strings
76 | };
77 |
--------------------------------------------------------------------------------
/app/src/components/utils/mergeInstancedMesh.tsx:
--------------------------------------------------------------------------------
1 | import { useThree } from "@react-three/fiber";
2 | import { usePathtracer } from "@react-three/gpu-pathtracer";
3 | import { type RefObject, useEffect, useMemo, useRef } from "react";
4 | import * as THREE from "three";
5 | import * as BufferGeometryUtils from "three/addons/utils/BufferGeometryUtils.js";
6 |
/** Settings object controlling path tracing for a merged mesh. */
interface PathTracingSettings {
	// When false, `useMergedMesh` neither builds nor adds a merged mesh.
	enabled: boolean;
	// Any further settings; `splitInstancedMesh` reads `metalness`,
	// `roughness`, `clearcoat` and `clearcoatRoughness` from this object.
	[key: string]: any;
}
11 |
/**
 * Hook that expands an InstancedMesh into a group of regular meshes, adds the
 * group to the scene and notifies the path tracer.
 *
 * The group is only built while `settings.enabled` is truthy and the
 * instanced mesh has instance matrices. It is removed from the scene, and its
 * per-instance geometries and materials are disposed, when it is replaced or
 * the component unmounts.
 *
 * @param meshRef - React ref of the InstancedMesh.
 * @param geometry - Geometry cloned for every instance of the merged mesh.
 * @param settings - Settings for path tracing and merging.
 * @param dependencies - Extra dependencies that trigger a rebuild of the merged mesh.
 *
 * @returns The group of per-instance meshes if created, otherwise null.
 */
export function useMergedMesh(
	meshRef: RefObject<THREE.InstancedMesh | null>,
	geometry: THREE.BufferGeometry,
	settings: PathTracingSettings,
	dependencies: unknown[] = [],
): THREE.Group | null {
	const { scene } = useThree();
	const { update } = usePathtracer();

	const mergedMesh = useMemo(() => {
		const instancedMesh = meshRef.current;
		// `?? 0` keeps the comparison well-typed when the ref is not attached yet.
		const matrixLength = instancedMesh?.instanceMatrix?.array?.length ?? 0;
		if (!instancedMesh || matrixLength === 0 || !settings?.enabled) {
			return null;
		}
		return splitInstancedMesh(instancedMesh, geometry, settings);
		// meshRef is a stable ref object; callers pass whatever should trigger
		// a rebuild through `dependencies`.
	}, [geometry, settings, ...dependencies]);

	useEffect(() => {
		if (!mergedMesh || !settings?.enabled) {
			return;
		}

		scene.add(mergedMesh);
		update();

		return () => {
			scene.remove(mergedMesh);
			// Each child owns a cloned geometry and its own material (see
			// splitInstancedMesh); release their GPU resources so rebuilt
			// groups do not accumulate.
			mergedMesh.traverse((child) => {
				if (child instanceof THREE.Mesh) {
					child.geometry.dispose();
					const materials = Array.isArray(child.material)
						? child.material
						: [child.material];
					for (const material of materials) {
						material.dispose();
					}
				}
			});
		};
	}, [mergedMesh, settings]);

	return mergedMesh;
}
59 |
/**
 * Merges an InstancedMesh into a single Mesh.
 *
 * The source geometry is cloned once per instance, transformed by that
 * instance's matrix and all clones are merged (with groups) into one
 * geometry. The instanced mesh's material is reused for the result.
 *
 * @param {THREE.InstancedMesh} instancedMesh - The instanced mesh to merge.
 * @returns {THREE.Mesh} - A single mesh with merged geometry; a mesh with an
 *   empty geometry if the instanced mesh has no instances.
 */
export function mergeInstancedMesh(
	instancedMesh: THREE.InstancedMesh,
): THREE.Mesh {
	const { count, geometry, material } = instancedMesh;

	// mergeGeometries() inspects the first list entry, so it cannot be called
	// with an empty list.
	if (count === 0) {
		return new THREE.Mesh(new THREE.BufferGeometry(), material);
	}

	// Scratch matrix reused for every instance
	const instanceMatrix = new THREE.Matrix4();
	const transformedGeometries: THREE.BufferGeometry[] = [];

	for (let i = 0; i < count; i++) {
		// Get the transformation matrix for the current instance and bake it
		// into a clone of the original geometry. The matrix is applied
		// directly; decomposing and recomposing it first is unnecessary.
		instancedMesh.getMatrixAt(i, instanceMatrix);
		transformedGeometries.push(geometry.clone().applyMatrix4(instanceMatrix));
	}

	// Merge all transformed geometries into one
	const mergedGeometry = BufferGeometryUtils.mergeGeometries(
		transformedGeometries,
		true,
	);

	// Return a single mesh
	return new THREE.Mesh(mergedGeometry, material);
}
104 |
/**
 * Expands an InstancedMesh into a Group holding one regular mesh per instance.
 *
 * Every instance gets a clone of `geometry` with the instance transform baked
 * in, plus its own MeshPhysicalMaterial carrying the instance color (white
 * when the instanced mesh has no color attribute).
 *
 * @param {THREE.InstancedMesh} instancedMesh - Source of the instance count, matrices and colors.
 * @param {THREE.BufferGeometry} geometry - Geometry cloned for every instance.
 * @param {any} pathTracingSettings - Supplies `metalness`, `roughness`, `clearcoat` and `clearcoatRoughness` for the materials.
 * @returns {THREE.Group} - A group containing the individual meshes.
 */
export function splitInstancedMesh(
	instancedMesh: THREE.InstancedMesh,
	geometry: THREE.BufferGeometry,
	pathTracingSettings: any,
): THREE.Group {
	const result = new THREE.Group();
	const total = instancedMesh.count;

	// Per-instance RGB triples, if the mesh carries a color attribute
	const instanceColors = instancedMesh.instanceColor?.array;

	// Scratch objects reused while reading each instance transform
	const matrix = new THREE.Matrix4();
	const translation = new THREE.Vector3();
	const rotation = new THREE.Quaternion();
	const scaling = new THREE.Vector3();

	for (let index = 0; index < total; index += 1) {
		// Read the instance transform and split it into its components
		instancedMesh.getMatrixAt(index, matrix);
		matrix.decompose(translation, rotation, scaling);

		// Bake the recomposed transform into a fresh copy of the geometry
		const bakedGeometry = geometry.clone();
		bakedGeometry.applyMatrix4(
			new THREE.Matrix4().compose(translation, rotation, scaling),
		);

		// Instance color, falling back to white without a color attribute
		const offset = index * 3;
		const color = instanceColors
			? new THREE.Color(
					instanceColors[offset],
					instanceColors[offset + 1],
					instanceColors[offset + 2],
				)
			: new THREE.Color(0xffffff);

		// One material per instance so every mesh keeps its own color
		const material = new THREE.MeshPhysicalMaterial({
			color: color,
			metalness: pathTracingSettings.metalness,
			roughness: pathTracingSettings.roughness,
			clearcoat: pathTracingSettings.clearcoat,
			clearcoatRoughness: pathTracingSettings.clearcoatRoughness,
		});

		result.add(new THREE.Mesh(bakedGeometry, material));
	}

	return result;
}
167 |
--------------------------------------------------------------------------------
/app/src/components/vectorfield.tsx:
--------------------------------------------------------------------------------
1 | import type React from "react";
2 | import { useEffect, useMemo, useState } from "react";
3 | import * as THREE from "three";
4 | import Arrows from "./meshes";
5 | import type { HSLColor } from "./utils";
6 |
/** Props accepted by the `VectorField` component. */
interface VectorFieldProps {
	// One entry per vector; entry[0] is used as the start point and entry[1]
	// as the end point, each an [x, y, z] triple.
	vectors: [number, number, number][][];
	// Defaults to true in the component.
	showArrows?: boolean;
	pathTracingSettings: any | undefined;
	arrowsConfig: {
		// When true, the color range becomes [0, longest vector length]
		// instead of `colorrange`.
		normalize: boolean;
		scale_vector_thickness: boolean;
		colormap: HSLColor[];
		// Color range used while `normalize` is false.
		colorrange: [number, number];
		opacity: number;
	};
}
19 |
20 | export const VectorField: React.FC = ({
21 | vectors,
22 | showArrows = true,
23 | arrowsConfig,
24 | pathTracingSettings,
25 | }) => {
26 | const [colorRange, setColorRange] = useState<[number, number]>(
27 | arrowsConfig.colorrange,
28 | );
29 |
30 | useEffect(() => {
31 | if (arrowsConfig.normalize) {
32 | const max = Math.max(
33 | ...vectors.map((vector) =>
34 | new THREE.Vector3(...vector[0]).distanceTo(
35 | new THREE.Vector3(...vector[1]),
36 | ),
37 | ),
38 | );
39 | setColorRange([0, max]);
40 | } else {
41 | setColorRange(arrowsConfig.colorrange);
42 | }
43 | }, [vectors, arrowsConfig.normalize, arrowsConfig.colorrange]);
44 |
45 | const startMap = useMemo(() => vectors.map((vector) => vector[0]), [vectors]);
46 | const endMap = useMemo(() => vectors.map((vector) => vector[1]), [vectors]);
47 |
48 | return (
49 |
59 | );
60 | };
61 |
62 | export default VectorField;
63 |
--------------------------------------------------------------------------------
/app/src/index.css:
--------------------------------------------------------------------------------
1 | :root {
2 | font-family: Inter, system-ui, Avenir, Helvetica, Arial, sans-serif;
3 | line-height: 1.5;
4 | font-weight: 400;
5 |
6 | color-scheme: light dark;
7 | /* color: rgba(255, 255, 255, 0.87);
8 | background-color: #242424; */
9 |
10 | font-synthesis: none;
11 | text-rendering: optimizeLegibility;
12 | -webkit-font-smoothing: antialiased;
13 | -moz-osx-font-smoothing: grayscale;
14 | }
15 |
16 | a {
17 | font-weight: 500;
18 | color: #646cff;
19 | text-decoration: inherit;
20 | }
21 | a:hover {
22 | color: #535bf2;
23 | }
24 |
25 | body {
26 | margin: 0;
27 | display: flex;
28 | place-items: center;
29 | min-width: 320px;
30 | min-height: 100vh;
31 | }
32 |
33 | h1 {
34 | font-size: 3.2em;
35 | line-height: 1.1;
36 | }
37 |
38 | button {
39 | border-radius: 8px;
40 | border: 1px solid transparent;
41 | padding: 0.6em 1.2em;
42 | font-size: 1em;
43 | font-weight: 500;
44 | font-family: inherit;
45 | /* background-color: #1a1a1a; */
46 | cursor: pointer;
47 | transition: border-color 0.25s;
48 | }
49 | button:hover {
50 | border-color: #646cff;
51 | }
52 | button:focus,
53 | button:focus-visible {
54 | outline: 4px auto -webkit-focus-ring-color;
55 | }
56 | /*
57 | @media (prefers-color-scheme: light) {
58 | :root {
59 | color: #213547;
60 | background-color: #ffffff;
61 | }
62 | a:hover {
63 | color: #747bff;
64 | }
65 | button {
66 | background-color: #f9f9f9;
67 | }
68 | } */
69 |
70 | .edit-mode-description {
71 | position: absolute;
72 | top: 60px;
73 | right: 10px;
74 | /* background-color: rgba(var(--bs-info-rgb), 0.8); */
75 | backdrop-filter: blur(10px);
76 | padding: 10px;
77 | padding-left: 20px;
78 | padding-right: 20px;
79 | border-radius: 10px;
80 | border: 2px solid rgba(var(--bs-info-rgb), 0.8);
81 | }
82 |
--------------------------------------------------------------------------------
/app/src/main.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 | import ReactDOM from "react-dom/client";
3 | import App from "./App.tsx";
4 | import "bootstrap/dist/css/bootstrap.min.css";
5 | import "./index.css";
6 |
// React strict mode renders the app twice to detect side effects.
// This would break our useRef-based socket detection, and messages that
// should not be sent would go through the socket, so the app is rendered
// without <React.StrictMode>.
ReactDOM.createRoot(document.getElementById("root")!).render(
	<>
		<App />
	</>,
);
15 |
--------------------------------------------------------------------------------
/app/src/socket.tsx:
--------------------------------------------------------------------------------
1 | import { useEffect, useRef } from "react";
2 | import { Manager } from "socket.io-client";
3 | import { createClient } from "znsocket";
4 |
5 | import * as THREE from "three";
6 |
/**
 * Create the socket.io connections used by the app.
 *
 * A single Manager is opened against the current origin. When the app is
 * served from a base path other than "/", the socket.io endpoint is looked
 * up below that base path.
 *
 * @returns The default-namespace socket and a znsocket client bound to the
 *   "/znsocket" namespace.
 */
function setupIO() {
	const basePath = import.meta.env.BASE_URL || "/";
	const origin = window.location.origin;

	// For local development, a Manager pointing at "http://localhost:1235"
	// can be used instead of the current origin.
	const manager =
		basePath === "/"
			? new Manager(origin)
			: new Manager(origin, { path: `${basePath}socket.io` });

	console.log("manager", manager);

	const socket = manager.socket("/");
	const client = createClient({ socket: manager.socket("/znsocket") });
	return { socket, client };
}
25 | export const { socket, client } = setupIO();
26 |
--------------------------------------------------------------------------------
/app/src/vite-env.d.ts:
--------------------------------------------------------------------------------
/// <reference types="vite/client" />
2 |
--------------------------------------------------------------------------------
/app/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "ES2020",
4 | "useDefineForClassFields": true,
5 | "lib": ["ES2020", "DOM", "DOM.Iterable"],
6 | "module": "ESNext",
7 | "skipLibCheck": true,
8 |
9 | // Bundler mode
10 | "moduleResolution": "bundler",
11 | "allowImportingTsExtensions": true,
12 | "resolveJsonModule": true,
13 | "isolatedModules": true,
14 | "noEmit": true,
15 | "jsx": "react-jsx",
16 |
17 | // Linting
18 | "strict": true,
19 | "noUnusedLocals": true,
20 | "noUnusedParameters": true,
21 | "noFallthroughCasesInSwitch": true
22 | },
23 | "include": ["src", "decs.d.ts"],
24 | "references": [{ "path": "./tsconfig.node.json" }]
25 | }
26 |
--------------------------------------------------------------------------------
/app/tsconfig.node.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "composite": true,
4 | "skipLibCheck": true,
5 | "module": "ESNext",
6 | "moduleResolution": "bundler",
7 | "allowSyntheticDefaultImports": true,
8 | "strict": true
9 | },
10 | "include": ["vite.config.ts"]
11 | }
12 |
--------------------------------------------------------------------------------
/app/vite.config.ts:
--------------------------------------------------------------------------------
1 | import react from "@vitejs/plugin-react-swc";
2 | import { defineConfig } from "vite";
3 |
4 | // https://vitejs.dev/config/
5 | export default defineConfig({
6 | plugins: [react()],
7 | base: "/",
8 | build: {
9 | outDir: "../zndraw/templates", // Output directory for templates
10 | emptyOutDir: true, // Clear the output directory before building
11 | },
12 | publicDir: "public", // Directory for static assets
13 | server: {
14 | proxy: {
15 | "/reset": "http://localhost:1234",
16 | "/token": "http://localhost:1234",
17 | "/upload": "http://localhost:1234",
18 | "/download": "http://localhost:1234",
19 | "/socket.io": {
20 | target: "ws://localhost:1234",
21 | ws: true,
22 | },
23 | },
24 | },
25 | });
26 |
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | zndraw:
3 | build: .
4 | healthcheck:
5 | test: ["CMD", "zndraw", "--healthcheck", "--url", "http://zndraw:5003"]
6 | interval: 30s
7 | timeout: 10s
8 | retries: 5
9 | command: --no-standalone
10 | restart: unless-stopped
11 | ports:
12 | - 5003:5003
13 | depends_on:
14 | - redis
15 | - worker
16 | environment:
17 | - FLASK_STORAGE=redis://redis:6379/0
18 | - FLASK_AUTH_TOKEN=super-secret-token
19 |
worker:
  build: .
  healthcheck:
    test: ["CMD", "zndraw", "--healthcheck", "--url", "http://zndraw:5003"]
    interval: 30s
    timeout: 10s
    retries: 5
  entrypoint: celery -A zndraw_app.make_celery worker --loglevel=info -P eventlet
  restart: unless-stopped
  depends_on:
    - redis
  environment:
    - FLASK_STORAGE=redis://redis:6379/0
    # No quotes around the value: in list syntax they would become part of
    # the environment variable itself.
    - FLASK_SERVER_URL=http://zndraw:5003
    - FLASK_AUTH_TOKEN=super-secret-token
35 |
36 | redis:
37 | image: redis:latest
38 | restart: always
39 | environment:
40 | - REDIS_PORT=6379
41 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/2025/01.rst:
--------------------------------------------------------------------------------
1 | 01 - Christmas Tree Builder
2 | ============================
3 |
4 | You can use a ZnDraw extension to provide a graphical user interface for building a Christmas tree out of molecules.
5 |
6 | .. code:: python
7 |
8 | from ase import Atoms
9 | from zndraw import ZnDraw, Extension
10 | from pydantic import Field
11 | from rdkit2ase import smiles2atoms, smiles2conformers
12 |
13 | vis = ZnDraw(url="http://localhost:5003/", token="tree")
14 |
class BuildChristmasTree(Extension):
    """Extension that builds a Christmas tree out of molecules.

    The fields below are the user-adjustable parameters of the tree;
    ``model_json_schema`` adds the widget hints (checkbox / range slider)
    for them.
    """

    smiles: str = Field(
        "CO", description="SMILES string of the molecule to use for the tree"
    )
    n: int = Field(5, description="Number of layers for the tree", ge=1, le=10)
    x_spacing: float = Field(
        4,
        description="Horizontal spacing between molecules in each layer (in Ångstroms)",
        ge=0,
        le=10,
    )
    y_spacing: float = Field(
        3, description="Vertical spacing between layers (in Ångstroms)", ge=0, le=10
    )
    trunk_height: int = Field(
        2, description="Number of molecules in the trunk", ge=0, le=10
    )
    conformers: bool = True

    def run(self, vis, **kwargs):
        """Build the tree from the configured molecule and append it to ``vis``."""
        # The layers hold n, n - 1, ..., 1 molecules, i.e. n * (n + 1) / 2 in
        # total, plus one molecule per trunk segment. This matches the amount
        # build_christmas_tree() requires.
        n_molecules = self.n * (self.n + 1) // 2 + self.trunk_height
        if self.conformers:
            molecules = smiles2conformers(self.smiles, numConfs=n_molecules)
        else:
            molecule = smiles2atoms(self.smiles)
            molecules = [molecule.copy() for _ in range(n_molecules)]
        tree = build_christmas_tree(
            molecules, self.n, self.trunk_height, self.x_spacing, self.y_spacing
        )
        vis.append(tree)

    @classmethod
    def model_json_schema(cls, *args, **kwargs):
        """Return the JSON schema extended with widget hints for each field.

        Positional and keyword arguments are forwarded unchanged to the
        parent implementation.
        """
        schema = super().model_json_schema(*args, **kwargs)
        properties = schema["properties"]
        properties["conformers"]["format"] = "checkbox"
        # Numeric fields are shown as range inputs; the float-valued spacings
        # additionally get a step of 0.1.
        properties["n"]["format"] = "range"
        properties["x_spacing"]["format"] = "range"
        properties["x_spacing"]["step"] = 0.1
        properties["y_spacing"]["format"] = "range"
        properties["y_spacing"]["step"] = 0.1
        properties["trunk_height"]["format"] = "range"
        return schema
59 |
60 |
def build_christmas_tree(
    molecules: list[Atoms],
    n: int = 5,
    trunk_height: int = 2,
    x_spacing: float = 3.0,
    y_spacing: float = 3.0,
) -> Atoms:
    """Build an atomic Christmas tree.

    Molecules are taken from the end of ``molecules``: first the trunk, then
    the layers from the widest (``n`` molecules) at the bottom to the tip
    (one molecule) at the top. The input list and the structures in it are
    not modified.

    Arguments
    ---------
    molecules : list[Atoms]
        A list of molecular structures to use for each part of the tree.
    n : int
        The number of layers for the tree.
    trunk_height : int
        The number of molecules in the trunk.
    x_spacing : float
        Horizontal spacing between molecules in each layer (in Ångstroms).
    y_spacing : float
        Vertical spacing between layers (in Ångstroms).

    Returns
    -------
    tree : Atoms
        An assembled "tree" with the trunk and branches built from the provided molecules.

    Raises
    ------
    ValueError
        If fewer than ``n * (n + 1) // 2 + trunk_height`` molecules are given.
    """
    # Ensure there are enough molecules to build the tree
    if len(molecules) < n * (n + 1) // 2 + trunk_height:
        raise ValueError(
            "Not enough molecules to build the tree and trunk with the given parameters."
        )

    # Work on copies so the caller's list and structures stay untouched,
    # and center every molecule individually.
    remaining = [mol.copy() for mol in molecules]
    for mol in remaining:
        mol.center()

    # Create an empty structure for the tree
    tree = Atoms()

    # Build the trunk: place one molecule, then lift everything that is
    # still unplaced so the next molecule ends up one step higher.
    for _ in range(trunk_height):
        tree += remaining.pop()
        for mol in remaining:
            mol.translate([0, y_spacing, 0])

    # Build the layers from bottom (n molecules) to top (1 molecule)
    for num_molecules in range(n, 0, -1):
        # Offset to center the layer horizontally
        x_offset = x_spacing * (num_molecules - 1) / 2

        for j in range(num_molecules):
            mol = remaining.pop()
            mol.translate([j * x_spacing - x_offset, 0, 0])
            tree += mol

        for mol in remaining:
            mol.translate([0, y_spacing, 0])

    return tree
124 |
125 | vis.register(BuildChristmasTree, public=True)
126 | vis.socket.wait()
127 |
128 | The extension appears in the modifier sidebar and gives you full control over the parameters of the tree builder.
129 |
130 | .. image:: https://github.com/user-attachments/assets/161e6b40-f539-45b9-9bab-cfa613e37b8f
131 | :width: 100%
132 | :alt: ZnDraw
133 | :class: only-light
134 |
135 | .. image:: https://github.com/user-attachments/assets/f1495096-c443-4a53-98c4-07368354b21d
136 | :width: 100%
137 | :alt: ZnDraw
138 | :class: only-dark
139 |
140 |
141 | .. tip::
142 |
143 | Use the PathTracer integrated with ZnDraw to make the Christmas tree reflective, like Christmas decorations.
144 |
145 | .. image:: https://github.com/user-attachments/assets/ca382068-1f17-4bcb-a6f2-ef48c671ac48
146 | :width: 100%
147 | :alt: ZnDraw
148 | :class: only-light
149 |
150 | .. image:: https://github.com/user-attachments/assets/5df7e7ab-a930-4361-ac5e-e4f0fcd10cc1
151 | :width: 100%
152 | :alt: ZnDraw
153 | :class: only-dark
154 |
--------------------------------------------------------------------------------
/docs/source/adventofcode.rst:
--------------------------------------------------------------------------------
1 | Advent of Code
2 | ==============
3 |
4 | Each day of Advent we will showcase one of ZnDraw's features for you!
5 |
6 | Each example - unless otherwise noted - assumes a running ZnDraw instance on ``http://localhost:5003``.
7 |
8 | .. toctree::
9 | :maxdepth: 2
10 | :glob:
11 |
12 | 2025/*
13 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 | import zndraw
10 |
11 | project = "ZnDraw"
12 | copyright = "2024, Fabian Zills"
13 | author = "Fabian Zills"
14 | release = zndraw.__version__
15 |
16 | # -- General configuration ---------------------------------------------------
17 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
18 |
19 | extensions = [
20 | "nbsphinx",
21 | "sphinx_copybutton",
22 | ]
23 |
24 | templates_path = ["_templates"]
25 | exclude_patterns = []
26 |
27 |
28 | # -- Options for HTML output -------------------------------------------------
29 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
30 |
31 | html_theme = "furo"
32 | html_static_path = ["_static"]
33 |
34 | html_theme_options: dict = {
35 | "light_logo": "zndraw-light.svg",
36 | "dark_logo": "zndraw-dark.svg",
37 | "footer_icons": [
38 | {
39 | "name": "GitHub",
40 | "url": "https://github.com/zincware/zndraw",
41 | "html": "",
42 | "class": "fa-brands fa-github fa-2x",
43 | },
44 | ],
45 | "source_repository": "https://github.com/zincware/zndraw",
46 | "source_branch": "main",
47 | "source_directory": "docs/source/",
48 | "navigation_with_keys": True,
49 | }
50 |
51 | # font-awesome logos
52 | html_css_files = [
53 | "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/brands.min.css",
54 | ]
55 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | ZnDraw documentation
2 | ====================
3 |
4 | A python-first visualisation and editing tool for atomic structures.
5 |
6 | Installation
7 | ------------
8 |
9 | ZnDraw can be installed via pip:
10 |
11 | .. code:: console
12 |
13 | $ pip install zndraw
14 |
15 | Usage
16 | -----
17 | You can use the command line interface to visualise atomic structures.
18 | ZnDraw uses a local web server to display the visualisation and runs in your default web browser.
19 | You can use port forwarding to access the visualisation from a remote machine.
20 |
21 | .. code:: console
22 |
23 | $ zndraw file.xyz
24 |
25 |
26 | .. toctree::
27 | :maxdepth: 2
28 | :hidden:
29 |
30 | python-api
31 | adventofcode
32 |
--------------------------------------------------------------------------------
/docs/source/python-api.rst:
--------------------------------------------------------------------------------
1 | Python interface
2 | ================
3 |
4 | The ``zndraw`` package provides a Python interface to interact with the visualisation tool.
5 | To use this API, you need to have a running instance of the ZnDraw web server.
6 |
7 | .. note::
8 |
9 | You can run a local webserver by using the command line interface.
10 |
11 | .. code:: console
12 |
13 | $ zndraw file.xyz --port 1234
14 |
15 |
16 | .. code:: python
17 |
18 | from zndraw import ZnDraw
19 |
20 | vis = ZnDraw(url="http://localhost:1234", token="c91bb84f")
21 |
22 |
23 | .. note::
24 |
25 | In ZnDraw each visualisation is associated with a unique token.
26 | You find this token in the URL of the visualisation.
27 | This token can be used to interact with the visualisation using the Python API.
28 | Additionally, you can use the token to share the visualisation with others or view
29 | the visualisation from different angles in different browser tabs.
30 |
31 |
32 | The ``vis`` object provides a Python interface to interact with the visualisation.
33 | At its most basic, it behaves like a Python list of `ase.Atoms <https://wiki.fysik.dtu.dk/ase/ase/atoms.html>`_ objects.
34 | Modifying the list will update the visualisation in real-time.
35 |
36 | .. code:: python
37 |
38 | import ase.io as aio
39 |
40 | frames = aio.read("file.xyz", index=":")
41 | vis.extend(frames)
42 |
--------------------------------------------------------------------------------
/examples/md.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Molecular Dynamics Simulation\n",
8 | "\n",
9 | "In this example we will run an MD simulation using ASE and visualise it using ZnDraw."
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 7,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "from ase import units\n",
19 | "from ase.calculators.emt import EMT\n",
20 | "from ase.lattice.cubic import FaceCenteredCubic\n",
21 | "from ase.md.velocitydistribution import MaxwellBoltzmannDistribution\n",
22 | "from ase.md.verlet import VelocityVerlet\n",
23 | "\n",
24 | "from zndraw import ZnDraw\n",
25 | "\n",
26 | "size = 2\n",
27 | "\n",
28 | "# Set up a crystal\n",
29 | "atoms = FaceCenteredCubic(\n",
30 | " directions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n",
31 | " symbol=\"Cu\",\n",
32 | " size=(size, size, size),\n",
33 | " pbc=True,\n",
34 | ")\n",
35 | "\n",
36 | "# Describe the interatomic interactions with the Effective Medium Theory\n",
37 | "atoms.calc = EMT()\n",
38 | "\n",
39 | "# Set the momenta corresponding to T=300K\n",
40 | "MaxwellBoltzmannDistribution(atoms, temperature_K=300)\n",
41 | "\n",
42 | "# We want to run MD with constant energy using the VelocityVerlet algorithm.\n",
43 | "dyn = VelocityVerlet(atoms, 5 * units.fs) # 5 fs time step.\n",
44 | "\n",
45 | "\n",
46 | "def printenergy(a=atoms): # store a reference to atoms in the definition.\n",
47 | " \"\"\"Function to print the potential, kinetic and total energy.\"\"\"\n",
48 | " epot = a.get_potential_energy() / len(a)\n",
49 | " ekin = a.get_kinetic_energy() / len(a)\n",
50 | " print(\n",
51 | " \"Energy per atom: Epot = %.3feV Ekin = %.3feV (T=%3.0fK) \"\n",
52 | " \"Etot = %.3feV\" % (epot, ekin, ekin / (1.5 * units.kB), epot + ekin)\n",
53 | " )"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 8,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "vis = ZnDraw(url=\"http://127.0.0.1:47823/\", token=\"fcd45c3917d34b1a8329bd5a5f172c0f\")"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 9,
68 | "metadata": {},
69 | "outputs": [
70 | {
71 | "name": "stdout",
72 | "output_type": "stream",
73 | "text": [
74 | "Energy per atom: Epot = -0.006eV Ekin = 0.040eV (T=306K) Etot = 0.034eV\n"
75 | ]
76 | },
77 | {
78 | "data": {
79 | "text/plain": [
80 | "True"
81 | ]
82 | },
83 | "execution_count": 9,
84 | "metadata": {},
85 | "output_type": "execute_result"
86 | }
87 | ],
88 | "source": [
89 | "# Now run the dynamics\n",
90 | "\n",
91 | "# dyn.attach(printenergy, interval=1)\n",
92 | "dyn.attach(lambda: vis.append(atoms), interval=1)\n",
93 | "printenergy()\n",
94 | "dyn.run(200)"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": []
103 | }
104 | ],
105 | "metadata": {
106 | "kernelspec": {
107 | "display_name": "zndraw",
108 | "language": "python",
109 | "name": "python3"
110 | },
111 | "language_info": {
112 | "codemirror_mode": {
113 | "name": "ipython",
114 | "version": 3
115 | },
116 | "file_extension": ".py",
117 | "mimetype": "text/x-python",
118 | "name": "python",
119 | "nbconvert_exporter": "python",
120 | "pygments_lexer": "ipython3",
121 | "version": "3.10.10"
122 | },
123 | "orig_nbformat": 4
124 | },
125 | "nbformat": 4,
126 | "nbformat_minor": 2
127 | }
128 |
--------------------------------------------------------------------------------
/examples/molecules.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from ase.collections import s22\n",
10 | "\n",
11 | "from zndraw import ZnDraw"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "metadata": {},
18 | "outputs": [
19 | {
20 | "name": "stdout",
21 | "output_type": "stream",
22 | "text": [
23 | "Starting ZnDraw server at http://127.0.0.1:1234\n",
24 | "Connected to ZnDraw server at http://127.0.0.1:1234\n"
25 | ]
26 | },
27 | {
28 | "data": {
29 | "text/html": [
30 | "\n",
31 | " \n",
39 | " "
40 | ],
41 | "text/plain": [
42 | "ZnDraw(url='http://127.0.0.1:1234', socket=, jupyter=True, bonds_calculator=ASEComputeBonds(single_bond_multiplier=1.2, double_bond_multiplier=0.9, triple_bond_multiplier=0.0), display_new=True, _retries=5)"
43 | ]
44 | },
45 | "execution_count": 2,
46 | "metadata": {},
47 | "output_type": "execute_result"
48 | }
49 | ],
50 | "source": [
51 | "vis = ZnDraw(jupyter=True)\n",
52 | "vis"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 3,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "vis.socket.sleep(0.5) # if we just started the server, we need to wait a bit"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "metadata": {},
67 | "source": [
68 | "# The `zndraw.ZnDraw` Object\n",
69 | "\n",
70 | "The `zndraw.ZnDraw` object acts a bit like a list.\n",
71 | "All objects in that list are visualized using `ZnDraw`.\n",
72 | "Objects cannot be changed in place, but otherwise typical list operations are possible."
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 4,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "vis.extend(s22)"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 5,
87 | "metadata": {},
88 | "outputs": [
89 | {
90 | "data": {
91 | "text/plain": [
92 | "[Atoms(symbols='', pbc=True),\n",
93 | " Atoms(symbols='NH3NH3', pbc=True),\n",
94 | " Atoms(symbols='OH2OH2', pbc=True),\n",
95 | " Atoms(symbols='CO2H2CO2H2', pbc=True),\n",
96 | " Atoms(symbols='CONH3CONH3', pbc=True),\n",
97 | " Atoms(symbols='OCNC3NOH4OCNC3NOH4', pbc=True),\n",
98 | " Atoms(symbols='ONC5H5NC5H4NH2', pbc=True),\n",
99 | " Atoms(symbols='NC3NCNCN2H5NC3NC2O2H6', pbc=True),\n",
100 | " Atoms(symbols='CH4CH4', pbc=True),\n",
101 | " Atoms(symbols='C2H4C2H4', pbc=True),\n",
102 | " Atoms(symbols='C6H6CH4', pbc=True),\n",
103 | " Atoms(symbols='C6H6C6H6', pbc=True),\n",
104 | " Atoms(symbols='C2NC2NH4C2NC2NH4', pbc=True),\n",
105 | " Atoms(symbols='C6H7C2HC2HCNC2HCH3', pbc=True),\n",
106 | " Atoms(symbols='NCHNC2NH2NCHNCHNCHC2H3CONHCOH', pbc=True),\n",
107 | " Atoms(symbols='C2H4C2H2', pbc=True),\n",
108 | " Atoms(symbols='C6H6OH2', pbc=True),\n",
109 | " Atoms(symbols='C6H6NCH', pbc=True),\n",
110 | " Atoms(symbols='C6H6C6H6', pbc=True),\n",
111 | " Atoms(symbols='C6H7NC8H6', pbc=True),\n",
112 | " Atoms(symbols='COHC5H5OCHC5H5', pbc=True)]"
113 | ]
114 | },
115 | "execution_count": 5,
116 | "metadata": {},
117 | "output_type": "execute_result"
118 | }
119 | ],
120 | "source": [
121 | "list(vis)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 6,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "vis.display(2) # display the second molecule"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 7,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "vis[0] = s22[\"Water_dimer\"]"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 8,
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "vis.append(s22[\"Formic_acid_dimer\"])"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "metadata": {},
155 | "outputs": [],
156 | "source": []
157 | }
158 | ],
159 | "metadata": {
160 | "kernelspec": {
161 | "display_name": "zndraw",
162 | "language": "python",
163 | "name": "python3"
164 | },
165 | "language_info": {
166 | "codemirror_mode": {
167 | "name": "ipython",
168 | "version": 3
169 | },
170 | "file_extension": ".py",
171 | "mimetype": "text/x-python",
172 | "name": "python",
173 | "nbconvert_exporter": "python",
174 | "pygments_lexer": "ipython3",
175 | "version": "3.11.4"
176 | },
177 | "orig_nbformat": 4
178 | },
179 | "nbformat": 4,
180 | "nbformat_minor": 2
181 | }
182 |
--------------------------------------------------------------------------------
/examples/stress_testing/README.md:
--------------------------------------------------------------------------------
1 | # Stress Testing
2 | The following scripts can be used to perform some stress testing on a given ZnDraw instance.
3 | Do not use them against public servers which you are not hosting yourself.
4 |
--------------------------------------------------------------------------------
/examples/stress_testing/multi_connection.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | import typer
4 |
5 | app = typer.Typer()
6 |
7 |
@app.command()
def main(file: str, n: int, browser: bool = False):
    # Launch ``n`` independent ``zndraw`` processes for ``file``. The processes
    # are started back to back and never waited on, so they run in parallel.
    # (Kept as a comment: a docstring would become the CLI help text.)
    extra_args = [] if browser else ["--no-browser"]
    command = ["zndraw", file, *extra_args]
    launched = 0
    while launched < n:
        subprocess.Popen(command)
        launched += 1
16 |
17 |
18 | if __name__ == "__main__":
19 | app()
20 |
--------------------------------------------------------------------------------
/examples/stress_testing/single_connection.py:
--------------------------------------------------------------------------------
# Stress test of a single ZnDraw connection.
# Run it both with and without the eventlet monkey patch below.
import eventlet

# must happen before the remaining imports
eventlet.monkey_patch()

import datetime  # noqa
import uuid
import os

import tqdm
from rdkit2ase import smiles2conformers

# ---- connect to ZnDraw: URL from the environment, fresh random room token ----
from zndraw import ZnDraw

vis = ZnDraw(url=os.environ["ZNDRAW_URL"], token=uuid.uuid4().hex)
# ------------------------------------------------------------------------------

# optional tuning knob, disabled by default:
# vis._refresh_client.delay_between_calls = datetime.timedelta(milliseconds=10)

conformers = smiles2conformers("CCCCCCCCCO", numConfs=1000)

# 1) append every conformer individually
for frame in tqdm.tqdm(conformers, desc="append", ncols=80):
    vis.append(frame)

# 2) read every frame back by index
for index in tqdm.trange(len(vis), desc="getitem", ncols=80):
    _ = vis[index]

# 3) delete all frames, last one first
for index in tqdm.tqdm(range(len(vis) - 1, -1, -1), desc="delete", ncols=80):
    del vis[index]

# 4) upload all conformers in a single call
vis.extend(conformers)

# 5) fetch everything with one slice
print(f"len(vis[:]): {len(vis[:])}")
40 |
--------------------------------------------------------------------------------
/misc/darkmode/analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zincware/ZnDraw/0a71d450976df5e1889b99d4764c0e793268b55e/misc/darkmode/analysis.png
--------------------------------------------------------------------------------
/misc/darkmode/box.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zincware/ZnDraw/0a71d450976df5e1889b99d4764c0e793268b55e/misc/darkmode/box.png
--------------------------------------------------------------------------------
/misc/darkmode/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zincware/ZnDraw/0a71d450976df5e1889b99d4764c0e793268b55e/misc/darkmode/overview.png
--------------------------------------------------------------------------------
/misc/darkmode/python.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zincware/ZnDraw/0a71d450976df5e1889b99d4764c0e793268b55e/misc/darkmode/python.png
--------------------------------------------------------------------------------
/misc/lightmode/analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zincware/ZnDraw/0a71d450976df5e1889b99d4764c0e793268b55e/misc/lightmode/analysis.png
--------------------------------------------------------------------------------
/misc/lightmode/box.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zincware/ZnDraw/0a71d450976df5e1889b99d4764c0e793268b55e/misc/lightmode/box.png
--------------------------------------------------------------------------------
/misc/lightmode/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zincware/ZnDraw/0a71d450976df5e1889b99d4764c0e793268b55e/misc/lightmode/overview.png
--------------------------------------------------------------------------------
/misc/lightmode/python.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zincware/ZnDraw/0a71d450976df5e1889b99d4764c0e793268b55e/misc/lightmode/python.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "zndraw"
3 | version = "0.5.10"
4 | description = "Display and Edit Molecular Structures and Trajectories in the Browser."
5 | authors = [
6 | { name = "Fabian Zills", email = "fzills@icp.uni-stuttgart.de" },
7 | { name = "Rokas Elijošius", email = "re344@cam.ac.uk" },
8 | { name = "Paul Hohenberger"},
9 | ]
10 | classifiers = ["License :: OSI Approved :: Eclipse Public License 2.0 (EPL-2.0)"]
11 | readme = "README.md"
12 | requires-python = ">=3.10"
13 | dependencies = [
14 | "ase>=3.24.0",
15 | "celery>=5.4.0",
16 | "eventlet>=0.39.0",
17 | "flask>=3.1.0",
18 | "flask-socketio>=5.5.1",
19 | "networkx>=3.4.2",
20 | "pandas>=2.2.3",
21 | "plotly>=6.0.0",
22 | "pydantic>=2.10.6",
23 | "python-socketio[client]>=5.12.1",
24 | "redis>=5.2.1",
25 | "splines>=0.3.2",
26 | "sqlalchemy>=2.0.38",
27 | "tqdm>=4.67.1",
28 | "typer>=0.15.1",
29 | "znjson>=0.2.6",
30 | "znsocket>=0.2.8",
31 | ]
32 |
33 | [project.urls]
34 | Repository = "https://github.com/zincware/ZnDraw"
35 | Releases = "https://github.com/zincware/ZnDraw/releases"
36 | Discord = "https://discord.gg/7ncfwhsnm4"
37 | Documentation = "https://zndraw.readthedocs.io/"
38 |
39 |
40 | [project.scripts]
41 | zndraw = 'zndraw_app.cli:cli'
42 |
43 |
44 | [build-system]
45 | requires = ["hatchling"]
46 | build-backend = "hatchling.build"
47 |
48 | [tool.hatch.build.targets.sdist]
49 | exclude = [
50 | "/app",
51 | ]
52 |
53 | [tool.hatch.build.targets.wheel]
54 | include = ["zndraw", "zndraw_app"]
55 | artifacts = [
56 | "zndraw/templates/**",
57 | ]
58 |
59 | [tool.ruff]
60 | line-length = 90
61 |
62 | [tool.ruff.lint]
63 | select = ["I", "F"]
64 |
65 | # by default do not run pytest marked with "chrome"
66 | [tool.pytest.ini_options]
67 | addopts = "-m 'not chrome'"
68 |
69 | [tool.codespell]
70 | skip = "*.svg,*.lock"
71 |
72 | [dependency-groups]
73 | dev = [
74 | "mdanalysis>=2.8.0",
75 | "pytest>=8.3.4",
76 | "pytest-cov>=6.0.0",
77 | "rdkit2ase>=0.1.4",
78 | "tidynamics>=1.1.2",
79 | "znh5md>=0.4.4",
80 | "zntrack>=0.8.2",
81 | ]
82 | docs = [
83 | "furo>=2024.8.6",
84 | "nbsphinx>=0.9.6",
85 | "sphinx>=8.1.3",
86 | "sphinx-copybutton>=0.5.2",
87 | ]
88 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import eventlet.wsgi
4 |
5 | eventlet.monkey_patch() # MUST BE THERE FOR THE TESTS TO WORK
6 |
7 | import random
8 |
9 | import ase.build
10 | import ase.collections
11 | import numpy as np
12 | import pytest
13 | import socketio.exceptions
14 | from ase.calculators.singlepoint import SinglePointCalculator
15 |
16 | from zndraw.app import create_app
17 | from zndraw.standalone import run_celery_worker
18 |
19 |
@pytest.fixture
def server():
    """Start a ZnDraw server and a celery worker; yield the server URL.

    A random port between 10000 and 20000 is chosen and exported through the
    ``FLASK_*`` environment variables before the celery worker and the app
    are created. The fixture polls the server with a socket.io client (up to
    100 attempts, 0.1 s apart) before yielding.

    Yields
    ------
    str
        Base URL of the running server, e.g. ``http://127.0.0.1:<port>``.

    Raises
    ------
    TimeoutError
        If the server does not accept a connection within the polling window.
    """
    port = random.randint(10000, 20000)

    os.environ["FLASK_PORT"] = str(port)
    os.environ["FLASK_STORAGE"] = "redis://localhost:6379/0"
    os.environ["FLASK_SERVER_URL"] = f"http://localhost:{port}"

    proc = run_celery_worker()
    thread = None

    def start_server():
        app = create_app()
        app.config["TESTING"] = True

        sio = app.extensions["socketio"]
        try:
            sio.run(
                app,
                host="0.0.0.0",
                port=app.config["PORT"],
            )
        finally:
            # always leave a clean storage behind for the next test
            app.extensions["redis"].flushall()

    # Everything after the worker has been started is guarded, so that a
    # failed startup (TimeoutError) does not leak the worker process or the
    # server green thread.
    try:
        thread = eventlet.spawn(start_server)

        # wait for the server to be ready
        for _ in range(100):
            try:
                with socketio.SimpleClient() as client:
                    client.connect(f"http://localhost:{port}")
                    break
            except socketio.exceptions.ConnectionError:
                eventlet.sleep(0.1)
        else:
            raise TimeoutError("Server did not start in time")

        yield f"http://127.0.0.1:{port}"
    finally:
        if thread is not None:
            thread.kill()
        proc.kill()
        proc.wait()
62 |
63 |
@pytest.fixture
def s22() -> list[ase.Atoms]:
    """Return all structures of the ASE S22 collection as a list."""
    return [structure for structure in ase.collections.s22]
68 |
69 |
@pytest.fixture
def water() -> ase.Atoms:
    """Return an H2O molecule built with ``ase.build.molecule``."""
    h2o = ase.build.molecule("H2O")
    return h2o
74 |
75 |
@pytest.fixture
def s22_energy_forces() -> list[ase.Atoms]:
    """Return the S22 structures with random single-point energy and forces."""

    def with_random_results(atoms):
        # energy is drawn before the forces, one force vector per atom
        energy = np.random.rand()
        forces = np.random.rand(len(atoms), 3)
        atoms.calc = SinglePointCalculator(atoms, energy=energy, forces=forces)
        return atoms

    return [with_random_results(atoms) for atoms in ase.collections.s22]
86 |
--------------------------------------------------------------------------------
/tests/test_analysis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import znsocket
3 |
4 | from zndraw import ZnDraw
5 | from zndraw.analyse import Properties1D
6 |
7 |
def run_queue(vis, key, msg: dict):
    """Write ``msg`` into the ``key`` queue of the room and trigger the worker.

    The queue is the ``znsocket.Dict`` stored under ``queue:<token>:<key>``.
    After emitting ``room:worker:run`` the call sleeps for 10 seconds.
    """
    storage = vis.r
    refresh_client = vis._refresh_client
    queue_name = f"queue:{vis.token}:{key}"
    queue = znsocket.Dict(r=storage, socket=refresh_client, key=queue_name)
    queue.update(msg)

    vis.socket.emit("room:worker:run")
    vis.socket.sleep(10)
17 |
18 |
def test_run_analysis_distance(server, s22_energy_forces):
    """Running ``Distance`` creates a figure whose x-axis is labelled "step"."""
    client = ZnDraw(url=server, token="test_token")
    client.extend(s22_energy_forces)
    assert client.figures == {}
    client.selection = [0, 1]

    request = {"Distance": {}}
    run_queue(client, "analysis", request)

    figure = client.figures["Distance"]
    x_label = figure.layout.xaxis.title.text
    assert x_label == "step"
30 |
31 |
def test_run_analysis_Properties1D(server, s22_energy_forces):
    """``Properties1D`` labels the x-axis "step" and the y-axis with the value."""
    client = ZnDraw(url=server, token="test_token")
    client.extend(s22_energy_forces)
    assert client.figures == {}

    request = {"Properties1D": {"value": "energy"}}
    run_queue(client, "analysis", request)

    figure = client.figures["Properties1D"]
    # x-axis: step, y-axis: the requested property
    x_label = figure.layout.xaxis.title.text
    assert x_label == "step"
    y_label = figure.layout.yaxis.title.text
    assert y_label == "energy"
43 |
44 |
def test_run_analysis_Properties2D(server, s22_energy_forces):
    """``Properties2D`` labels both axes with the requested data keys."""
    client = ZnDraw(url=server, token="test_token")
    client.extend(s22_energy_forces)
    assert client.figures == {}

    settings = {"x_data": "energy", "y_data": "step", "color": "energy"}
    run_queue(client, "analysis", {"Properties2D": settings})

    figure = client.figures["Properties2D"]
    # here the step is plotted on the y-axis and the energy on the x-axis
    y_label = figure.layout.yaxis.title.text
    assert y_label == "step"
    x_label = figure.layout.xaxis.title.text
    assert x_label == "energy"
60 |
61 |
def test_run_analysis_DihedralAngle(server, s22_energy_forces):
    """``DihedralAngle`` on four selected atoms plots against the step."""
    client = ZnDraw(url=server, token="test_token")
    client.extend(s22_energy_forces)
    assert client.figures == {}
    client.selection = [0, 1, 3, 4]

    request = {"DihedralAngle": {}}
    run_queue(client, "analysis", request)

    figure = client.figures["DihedralAngle"]
    x_label = figure.layout.xaxis.title.text
    assert x_label == "step"
73 |
74 |
def test_analysis_Properties1D_json_schema(s22_energy_forces):
    """The ``value`` enum of the schema lists info, calc and array keys."""
    # attach custom entries to info, calculator results and arrays
    for frame in s22_energy_forces:
        frame.info["custom"] = 42
        frame.info["custom2"] = np.random.rand(10)
        frame.info["custom3"] = np.random.rand(10, 5)
        frame.calc.results["custom4"] = np.random.rand(10)
        frame.arrays["arr"] = np.zeros_like(frame.get_positions())

    expected_keys = {
        "energy",
        "forces",
        "custom",
        "numbers",
        "positions",
        "arr",
        "custom2",
        "custom3",
        "custom4",
    }
    schema = Properties1D.model_json_schema_from_atoms(s22_energy_forces[0])
    offered_keys = set(schema["properties"]["value"]["enum"])
    assert offered_keys == expected_keys
96 |
--------------------------------------------------------------------------------
/tests/test_bookmarks.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from zndraw import ZnDraw
4 |
5 |
def test_bookmarks(server, s22):
    """Bookmarks start empty, can be set and updated, and reject bad values."""
    vis = ZnDraw(url=server, token="test_token")
    vis.extend(s22)

    assert vis.bookmarks == {}
    vis.bookmarks = {5: "Hey there!"}
    # the integer key is read back as a string
    assert vis.bookmarks == {"5": "Hey there!"}

    vis.bookmarks[5] = "Hello!"
    assert vis.bookmarks[5] == "Hello!"

    invalid_bookmarks = (
        ["Bookmarks are not a list!"],
        {5: object()},
        {"string": "Hey there!"},
    )
    for invalid in invalid_bookmarks:
        with pytest.raises(ValueError):
            vis.bookmarks = invalid
26 |
--------------------------------------------------------------------------------
/tests/test_camera.py:
--------------------------------------------------------------------------------
1 | from zndraw import ZnDraw
2 |
3 |
def test_camera(server, s22):
    """The camera has default position/target entries and can be replaced."""
    vis = ZnDraw(url=server, token="test_token")
    vis.extend(s22)

    # defaults
    assert vis.camera["position"] == [5, 5, 5]
    assert vis.camera["target"] == [0, 0, 0]
    assert len(vis.camera) == 2

    new_camera = {"position": [1, 0, 0], "target": [0, 1, 0]}
    vis.camera = new_camera

    assert vis.camera["position"] == [1, 0, 0]
    assert vis.camera["target"] == [0, 1, 0]
17 |
--------------------------------------------------------------------------------
/tests/test_config.py:
--------------------------------------------------------------------------------
1 | from zndraw import ZnDraw
2 |
3 | # def test_config_defaults(server):
4 | # vis = ZnDraw(url=server, token="test_token")
5 | # ref_config = ZnDrawConfig(vis=None)
6 |
7 | # vis.config == ref_config.to_dict()
8 |
9 |
def test_config_modify_arrows(server):
    """``Arrows.normalize`` can be toggled; other defaults stay untouched."""
    vis = ZnDraw(url=server, token="test_config_arrows")

    for flag in (False, True):
        vis.config["Arrows"]["normalize"] = flag
        assert vis.config["Arrows"]["normalize"] is flag

    # check other defaults
    assert vis.config["Arrows"]["opacity"] == 1.0
22 |
23 |
24 | # def test_config_replace_znsocket(server):
25 | # room = "test_config_znsocket"
26 | # vis = ZnDraw(url=server, token=room)
27 |
28 | # vis.config["scene"] = None
29 | # assert isinstance(vis.config["scene"], znsocket.Dict)
30 |
31 |
def test_config_modify_scene(server):
    """``Camera.fps`` can be changed; the camera type keeps its default."""
    vis = ZnDraw(url=server, token="test_config_scene")

    for fps in (30, 60):
        vis.config["Camera"]["fps"] = fps
        assert vis.config["Camera"]["fps"] == fps

    # check other defaults
    assert vis.config["Camera"]["camera"] == "PerspectiveCamera"
43 |
--------------------------------------------------------------------------------
/tests/test_figures.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | from zndraw import Figure, ZnDraw
4 |
5 |
def test_figures(server, s22):
    """A matplotlib figure wrapped in ``Figure`` can be stored in ``vis.figures``."""
    vis = ZnDraw(url=server, token="test_token")
    vis.extend(s22)

    mpl_figure, axes = plt.subplots()
    axes.plot([1, 2, 3], [1, 2, 3])

    vis.figures["mtpl"] = Figure(figure=mpl_figure)

    number_of_figures = len(vis.figures)
    assert number_of_figures == 1
17 |
--------------------------------------------------------------------------------
/tests/test_geometries.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import znsocket
3 | from ase.build import molecule
4 |
5 | from zndraw import Box, Sphere, ZnDraw
6 |
7 |
def test_geometries(server, s22):
    """Geometries can be set as a list of geometry objects and read back."""
    vis = ZnDraw(url=server, token="test_token")
    vis.extend(s22)

    assert vis.geometries == []
    vis.geometries = [Box(position=[1, 2, 3]), Sphere()]
    vis.socket.sleep(0.1)

    assert len(vis.geometries) == 2
    # types are preserved ...
    assert isinstance(vis.geometries[0], Box)
    assert isinstance(vis.geometries[1], Sphere)
    # ... and so are the positions (the sphere sits at the origin)
    assert vis.geometries[0].position == [1, 2, 3]
    assert vis.geometries[1].position == [0, 0, 0]

    invalid_values = (["Geometries are not string!"], "Geometries are not string!")
    for invalid in invalid_values:
        with pytest.raises(ValueError):
            vis.geometries = invalid
29 |
30 |
def test_geometry_selection_position(server):
    """A queued ``Plane`` is positioned according to the current selection."""
    vis = ZnDraw(url=server, token="test_token")
    vis.append(molecule("CH4"))
    assert len(vis.geometries) == 0
    geometry_queue = znsocket.Dict(
        r=vis.r,
        socket=vis._refresh_client,
        key=f"queue:{vis.token}:geometry",
    )

    def request_plane():
        # queue a new plane, trigger the worker and wait for it to finish
        geometry_queue["Plane"] = {
            "material": {
                "color": "#62929e",
                "opacity": 0.2,
                "wireframe": False,
                "outlines": False,
            },
            "width": 10,
            "height": 10,
        }
        vis.socket.emit("room:worker:run")
        vis.socket.sleep(8)

    # without a selection
    request_plane()
    assert len(vis.geometries) == 1
    assert vis.geometries[0].position == [0, 0, 0]

    # now with a selection
    vis.selection = [1]
    request_plane()
    assert len(vis.geometries) == 2
    assert vis.geometries[1].position == vis.atoms.positions[1].tolist()
    assert vis.geometries[1].position != [0, 0, 0]

    # now with a selection of multiple atoms
    vis.selection = [1, 2]
    request_plane()
    assert len(vis.geometries) == 3
    assert (
        vis.geometries[2].position
        == vis.atoms.get_center_of_mass(indices=[1, 2]).tolist()
    )
95 |
--------------------------------------------------------------------------------
/tests/test_modifier.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.testing as npt
3 | import pytest
4 | import znsocket
5 | from ase.build import bulk, molecule
6 |
7 | from zndraw import Extension, ZnDraw
8 | from zndraw.exceptions import RoomLockedError
9 |
10 |
def run_queue(vis, key, msg: dict):
    """Write ``msg`` into the ``key`` queue of the room and trigger the worker.

    The queue is the ``znsocket.Dict`` stored under ``queue:<token>:<key>``.
    After emitting ``room:worker:run`` the call sleeps for 10 seconds.
    """
    storage = vis.r
    refresh_client = vis._refresh_client
    queue_name = f"queue:{vis.token}:{key}"
    queue = znsocket.Dict(r=storage, socket=refresh_client, key=queue_name)
    queue.update(msg)

    vis.socket.emit("room:worker:run")
    vis.socket.sleep(10)
20 |
21 |
def test_run_selection(server, s22):
    """``ConnectedParticles`` grows the selection from atom 0 to atoms 0-3."""
    vis = ZnDraw(url=server, token="test_token")
    vis.extend(s22)
    vis.step = 0
    vis.selection = [0]

    request = {"ConnectedParticles": {}}
    run_queue(vis, "selection", request)

    expected_selection = [0, 1, 2, 3]
    assert vis.selection == expected_selection
32 |
33 |
def test_run_modifier(server):
    """Running ``Delete`` appends a frame without the selected atom."""
    vis = ZnDraw(url=server, token="test_token")
    vis.append(molecule("H2O"))
    vis.selection = [0]
    # one frame with three atoms to start with
    assert len(vis) == 1
    assert len(vis[-1]) == 3

    request = {"Delete": {}}
    run_queue(vis, "modifier", request)

    # a second frame with two atoms has been added
    assert len(vis) == 2
    assert len(vis[-1]) == 2
45 |
46 |
def test_register_modifier(server, s22, water):
    """A custom ``Extension`` registered as modifier is executed via the queue."""
    vis = ZnDraw(url=server, token="test_token")
    vis.extend(s22)

    class AppendWater(Extension):
        # appends the water fixture and jumps to the newly added frame
        def run(self, vis, **kwargs) -> None:
            vis.append(water)
            vis.step = len(vis) - 1

    vis.register_modifier(AppendWater)

    request = {"AppendWater": {}}
    run_queue(vis, "modifier", request)

    displayed = vis.atoms
    assert displayed == water
61 |
62 |
def test_locked(server):
    """A room is unlocked by default and can be locked from the client."""
    room = ZnDraw(url=server, token="test_token")

    # default state
    assert room.locked is False

    room.locked = True
    assert room.locked is True
69 |
70 |
71 | ##### Tests for each available modifier #####
72 |
73 |
def test_modify_delete(server):
    """``Delete`` appends a new frame that lacks the selected atom."""
    vis = ZnDraw(url=server, token="test_token")
    vis.append(molecule("H2O"))
    vis.selection = [0]
    # one frame with three atoms to start with
    assert len(vis) == 1
    assert len(vis[-1]) == 3

    request = {"Delete": {}}
    run_queue(vis, "modifier", request)

    # a second frame with two atoms has been added
    assert len(vis) == 2
    assert len(vis[-1]) == 2
85 |
86 |
def test_modify_rotate(server):
    """``Rotate`` with 10 steps results in 11 frames in total."""
    vis = ZnDraw(url=server, token="test_token")
    vis.append(molecule("H2O"))
    vis.selection = [0, 1]
    vis.points = [[0, 0, 0], [1, 0, 0]]

    request = {"Rotate": {"steps": 10}}
    run_queue(vis, "modifier", request)
    # additional waiting time on top of the one in run_queue
    vis.socket.sleep(5)

    number_of_frames = len(vis)
    assert number_of_frames == 11
    # TODO: test that the atoms rotated correctly
98 |
99 |
def test_modify_translate(server):
    """``Translate`` moves the selection along the points in 10 steps."""
    vis = ZnDraw(url=server, token="test_token")
    vis.append(molecule("H2O"))
    vis.selection = [0, 1, 2]
    vis.points = [[1, 0, 0], [0, 0, 0]]

    request = {"Translate": {"steps": 10}}
    run_queue(vis, "modifier", request)

    assert len(vis) == 11

    start = vis[0].positions
    full_shift = np.array([1, 0, 0])
    half_shift = np.array([0.5, 0, 0])
    npt.assert_allclose(vis[10].positions, start - full_shift)
    # spline interpolation is not an exact line
    npt.assert_allclose(vis[5].positions, start - half_shift, rtol=0.015)

    # the points themselves are kept
    assert len(vis.points) == 2
116 |
117 |
def test_modify_duplicate(server):
    """``Duplicate`` appends a frame with one extra copy of the selected atom."""
    vis = ZnDraw(url=server, token="test_token")
    vis.append(molecule("H2O"))
    vis.selection = [0]

    request = {"Duplicate": {}}
    run_queue(vis, "modifier", request)

    assert len(vis) == 2
    original_size = len(vis[0])
    assert original_size == 3
    duplicated_size = len(vis[1])
    assert duplicated_size == 4  # one duplicated atom
128 |
129 |
def test_modify_change_type(server):
    """``ChangeType`` sets the symbol of the selected atom in the new frame."""
    vis = ZnDraw(url=server, token="test_token")
    vis.append(molecule("H2O"))
    vis.selection = [0]

    request = {"ChangeType": {"symbol": "He"}}
    run_queue(vis, "modifier", request)

    changed_symbol = vis[1].symbols[0]
    assert changed_symbol == "He"
138 |
139 |
def test_modify_wrap(server):
    """``Wrap`` with ``all=True`` wraps every stored frame in place."""
    vis = ZnDraw(url=server, token="test_token")
    copper = bulk("Cu", cubic=True)
    copper.positions += 5  # shift, so wrapped is recognizable
    vis.extend([copper, copper])

    request = {"Wrap": {"all": True}}
    run_queue(vis, "modifier", request)

    # Wrap is an inplace modifier
    assert len(vis) == 2
    for frame_index in (0, 1):
        # wrapping a local copy again must not change anything
        reference = vis[frame_index]
        reference.wrap()
        assert not np.allclose(vis[frame_index].positions, copper.positions)
        assert np.allclose(vis[frame_index].positions, reference.positions)
155 |
156 |
def test_modify_replicate(server):
    """``Replicate`` 2x2x2 with ``all=True`` enlarges every frame eightfold."""
    vis = ZnDraw(url=server, token="test_token")
    wurtzite = bulk("ZnO", "wurtzite", a=3.25, c=5.2)
    vis.extend([wurtzite, wurtzite])

    settings = {"x": 2, "y": 2, "z": 2, "all": True}
    run_queue(vis, "modifier", {"Replicate": settings})
    # additional waiting time on top of the one in run_queue
    vis.socket.sleep(5)

    # Replicate is an inplace modifier
    assert len(vis) == 2
    for frame_index in (0, 1):
        assert len(vis[frame_index]) == 8 * len(wurtzite)
169 |
170 |
def test_modify_AddLineParticles(server):
    """``AddLineParticles`` appends a frame with particles at the points."""
    vis = ZnDraw(url=server, token="test_token")
    vis.append(molecule("H2O"))
    vis.points = [[0, 0, 0], [1, 0, 0]]

    settings = {"steps": 10, "symbol": "He"}
    run_queue(vis, "modifier", {"AddLineParticles": settings})

    # the original frame is untouched, the new one has two more atoms
    assert len(vis[0]) == 3
    assert len(vis[1]) == 5
    npt.assert_allclose(vis[1].positions[3], [0, 0, 0])
    npt.assert_allclose(vis[1].positions[4], [1, 0, 0])
    for added_index in (4, 3):
        assert vis[1].symbols[added_index] == "He"
184 |
185 |
def test_modify_center(server):
    """``Center`` moves the selected atom to the middle of the cell."""
    vis = ZnDraw(url=server, token="test_token")
    copper = bulk("Cu", cubic=True)
    vis.append(copper)
    vis.selection = [0]

    request = {"Center": {"all": True}}
    run_queue(vis, "modifier", request)

    selected_position = vis[0][0].position
    cell_center = np.diag(vis[0].cell) / 2
    assert np.allclose(selected_position, cell_center)
    assert not np.allclose(vis[0].positions, copper.positions)
196 |
197 |
def test_modify_RemoveAtoms(server):
    """``RemoveAtoms`` reduces the number of frames from two to one."""
    vis = ZnDraw(url=server, token="test_token")
    for _ in range(2):
        vis.append(molecule("H2O"))
    assert len(vis) == 2
    vis.step = 0

    request = {"RemoveAtoms": {}}
    run_queue(vis, "modifier", request)

    remaining_frames = len(vis)
    assert remaining_frames == 1
208 |
209 |
def test_modified_while_locked(server):
    """A locked room ignores queued modifiers and rejects client-side writes."""
    vis = ZnDraw(url=server, token="test_token")
    for _ in range(3):
        vis.append(molecule("H2O"))

    assert len(vis) == 3

    # unlocked: the modifier removes a frame
    run_queue(vis, "modifier", {"RemoveAtoms": {}})
    assert len(vis) == 2

    vis.locked = True

    # locked: the same modifier leaves the number of frames unchanged
    run_queue(vis, "modifier", {"RemoveAtoms": {}})
    assert len(vis) == 2

    # every write attempted from the client must raise
    rejected_writes = (
        lambda: vis.append(molecule("H2O")),
        lambda: setattr(vis, "step", 0),
        lambda: setattr(vis, "selection", [0]),
        lambda: setattr(vis, "atoms", molecule("H2O")),
        lambda: vis.extend([molecule("H2O")]),
    )
    for write in rejected_writes:
        with pytest.raises(RoomLockedError):
            write()
241 |
--------------------------------------------------------------------------------
/tests/test_points.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.testing as npt
3 |
4 | from zndraw import ZnDraw
5 |
6 |
def test_points_segments(server):
    """Points and segments are numpy arrays, empty until points are set."""
    vis = ZnDraw(url=server, token="test_token")

    for array in (vis.points, vis.segments):
        assert isinstance(array, np.ndarray)

    npt.assert_array_equal(vis.points, [])
    npt.assert_array_equal(vis.segments, [])

    vis.points = np.array([[0, 0, 0], [1, 1, 1]])
    # two points, 100 segment positions
    assert vis.points.shape == (2, 3)
    assert vis.segments.shape == (100, 3)

    npt.assert_array_equal(vis.points, [[0, 0, 0], [1, 1, 1]])
22 |
--------------------------------------------------------------------------------
/tests/test_selection.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from zndraw import ZnDraw
4 |
5 |
def test_selection(server, s22):
    """Selections must be unique, in-range lists of atom indices."""
    vis = ZnDraw(url=server, token="test_token")
    vis.extend(s22)

    assert len(vis.atoms) == 8

    rejected_first = (
        (ValueError, ["Hello"]),
        (IndexError, [0, 1, 2, 25]),
        (IndexError, [0, 1, 2, -10]),
    )
    for expected_error, selection in rejected_first:
        with pytest.raises(expected_error):
            vis.selection = selection

    # a valid selection keeps its order
    vis.selection = [0, 7, 6, 5, 4]
    assert vis.selection == [0, 7, 6, 5, 4]

    # duplicates and plain strings raise a ValueError
    for selection in ([0, 0, 2], "Hello"):
        with pytest.raises(ValueError):
            vis.selection = selection
30 |
--------------------------------------------------------------------------------
/tests/test_serializer.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import ase
4 | import numpy.testing as npt
5 | import pytest
6 | import znjson
7 | from ase.calculators.singlepoint import SinglePointCalculator
8 | from ase.constraints import FixAtoms
9 |
10 | from zndraw.converter import ASEConverter
11 |
12 |
def test_ase_converter(s22):
    """Round-trip the S22 structures through ``ASEConverter`` JSON."""
    bonds = [[0, 1, 1], [1, 2, 1], [2, 3, 1]]
    results = {"energy": 0.0, "predicted_energy": 1.0}

    s22[0].connectivity = [[0, 1, 1], [1, 2, 1], [2, 3, 1]]
    s22[3].calc = SinglePointCalculator(s22[3])
    s22[3].calc.results = {"energy": 0.0, "predicted_energy": 1.0}
    s22[4].info = {"key": "value"}

    encoder = znjson.ZnEncoder.from_converters([ASEConverter])
    structures_json = znjson.dumps(s22, cls=encoder)

    # inspect the raw JSON of the first structure
    first_value = json.loads(structures_json)[0]["value"]
    for key in ("numbers", "positions"):
        assert key not in first_value["arrays"]
    for key in ("pbc", "cell"):
        assert key not in first_value["info"]

    decoder = znjson.ZnDecoder.from_converters([ASEConverter])
    structures = znjson.loads(structures_json, cls=decoder)
    for original, restored in zip(s22, structures):
        assert original == restored

    # connectivity only exists on the structure it was set on
    npt.assert_array_equal(structures[0].connectivity, bonds)
    with pytest.raises(AttributeError):
        _ = structures[1].connectivity

    assert structures[3].calc.results == results

    for key in ("colors", "radii"):
        assert key not in structures[0].arrays

    assert structures[4].info == {"key": "value"}
45 |
46 |
def test_exotic_atoms():
    """Custom ``colors`` and ``radii`` arrays must survive a JSON round trip."""
    atoms = ase.Atoms("X", positions=[[0, 0, 0]])
    atoms.arrays["colors"] = ["#ff0000"]
    atoms.arrays["radii"] = [0.3]

    encoded = znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter]))
    new_atoms = znjson.loads(
        encoded, cls=znjson.ZnDecoder.from_converters([ASEConverter])
    )

    for key, expected in (("colors", ["#ff0000"]), ("radii", [0.3])):
        npt.assert_array_equal(new_atoms.arrays[key], expected)
58 |
59 |
def test_modified_atoms():
    """Atoms that were sliced or concatenated must still round-trip correctly."""

    def roundtrip(structure):
        # encode to JSON and decode again with the ASE converter
        encoded = znjson.dumps(
            structure, cls=znjson.ZnEncoder.from_converters([ASEConverter])
        )
        return znjson.loads(
            encoded, cls=znjson.ZnDecoder.from_converters([ASEConverter])
        )

    hydrogen = ase.Atoms("H2", positions=[[0, 0, 0], [0, 0, 1]])
    new_atoms = roundtrip(hydrogen)
    npt.assert_array_equal(new_atoms.get_atomic_numbers(), [1, 1])

    # subtract
    new_atoms = roundtrip(new_atoms[:1])
    npt.assert_array_equal(new_atoms.get_atomic_numbers(), [1])

    # add
    new_atoms = roundtrip(new_atoms + ase.Atoms("H", positions=[[0, 0, 1]]))
    npt.assert_array_equal(new_atoms.get_atomic_numbers(), [1, 1])
86 |
87 |
def test_constraints_fixed_atoms():
    """A ``FixAtoms`` constraint must survive a JSON round trip."""
    atoms = ase.Atoms("H2", positions=[[0, 0, 0], [0, 0, 1]])
    atoms.set_constraint(FixAtoms([0]))
    new_atoms = znjson.loads(
        znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])),
        cls=znjson.ZnDecoder.from_converters([ASEConverter]),
    )
    assert isinstance(new_atoms.constraints[0], FixAtoms)
    # ``index`` is a numpy array: compare it element-wise instead of relying on
    # the truthiness of a one-element ``array == list`` result.
    npt.assert_array_equal(new_atoms.constraints[0].index, [0])
97 |
--------------------------------------------------------------------------------
/tests/test_step.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from zndraw import ZnDraw
4 |
5 |
def test_step(server, s22):
    """Check that ``vis.step`` starts at 0, rejects bad values and stores valid ones."""
    vis = ZnDraw(url=server, token="test_token")
    vis.extend(s22)

    assert vis.step == 0
    assert isinstance(vis.step, int)

    assert len(vis) == 22

    # (rejected value, expected exception)
    rejected = ((-1, ValueError), (22, IndexError), ("string", ValueError))
    for bad_step, error in rejected:
        with pytest.raises(error):
            vis.step = bad_step

    vis.step = 5
    assert vis.step == 5
    assert isinstance(vis.step, int)
28 |
--------------------------------------------------------------------------------
/tests/test_tasks.py:
--------------------------------------------------------------------------------
1 | import ase
2 |
3 | from zndraw.tasks import FileIO, get_generator_from_filename
4 |
5 |
def test_get_generator_from_filename():
    """The generator built for a remote xyz URL must yield ``ase.Atoms`` objects."""
    url = "https://raw.githubusercontent.com/LarsSchaaf/Guaranteed-Non-Local-Molecular-Dataset/main/gnl-dataset/GNL-v0.2/gnl-v0.2-test.xyz"
    frames = get_generator_from_filename(FileIO(name=url))

    first_frame = next(iter(frames))
    assert isinstance(first_frame, ase.Atoms)
13 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.testing as npt
3 |
4 | from zndraw.utils import (
5 | convert_url_to_http,
6 | direction_to_euler,
7 | euler_to_direction,
8 | parse_url,
9 | )
10 |
11 |
def test_parse_url():
    """``parse_url`` splits a URL into its base and an optional path."""
    # input URL -> expected (base, path) tuple
    cases = (
        ("http://example.com", ("http://example.com", None)),
        ("http://example.com/path", ("http://example.com", "path")),
        ("http://localhost:5000", ("http://localhost:5000", None)),
        ("http://localhost:5000/path", ("http://localhost:5000", "path")),
        ("http://localhost:5000/path/", ("http://localhost:5000", "path")),
    )
    for url, expected in cases:
        assert parse_url(url) == expected
18 |
19 |
def test_conversion_utils():
    """Directions must survive a direction -> euler -> direction round trip."""
    skewed = np.array([1, 2, 3])
    # (direction, expected euler angles or None when they are not checked)
    cases = (
        (skewed / np.linalg.norm(skewed), None),
        (np.array([1, 0, 0]), [0, 0, 0]),
        (np.array([0, -1, 0]), None),
    )
    for direction, expected_euler in cases:
        euler = direction_to_euler(direction)
        new_direction = euler_to_direction(euler)

        npt.assert_allclose(direction, new_direction, atol=1e-6)
        if expected_euler is not None:
            npt.assert_allclose(euler, expected_euler, atol=1e-6)
41 |
42 |
def test_url():
    """``convert_url_to_http`` rewrites only the scheme of ws/wss URLs."""
    # websocket URL -> expected http(s) URL; the "eNwsdW5k" token contains
    # the substring "ws" and must be left untouched.
    cases = (
        ("ws://localhost:8000/token/1234", "http://localhost:8000/token/1234"),
        ("ws://localhost:8000/token/eNwsdW5k", "http://localhost:8000/token/eNwsdW5k"),
        ("wss://localhost:8000/token/1234", "https://localhost:8000/token/1234"),
        ("wss://localhost:8000/token/eNwsdW5k", "https://localhost:8000/token/eNwsdW5k"),
    )
    for before_url, expected in cases:
        assert convert_url_to_http(before_url) == expected
63 |
--------------------------------------------------------------------------------
/tests/test_vectorfields.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from ase.build import molecule
4 |
5 | from zndraw import ZnDraw
6 |
7 |
def test_vectors(server):
    """Vectors given as nested lists are returned unchanged."""
    vis = ZnDraw(url=server, token="test_token")

    frame = molecule("H2O")
    frame.info["vectors"] = [[[1, 0, 0], [0, 1, 0]]]
    vis.append(frame)

    stored = vis[0].info["vectors"]
    assert stored == [[[1, 0, 0], [0, 1, 0]]]
17 |
18 |
def test_vectors_list_numpy(server):
    """Vectors given as a list of numpy arrays come back as nested lists."""
    vis = ZnDraw(url=server, token="test_token")

    frame = molecule("H2O")
    frame.info["vectors"] = [np.array([[1, 0, 0], [0, 1, 0]])]
    vis.append(frame)

    stored = vis[0].info["vectors"]
    assert stored == [[[1, 0, 0], [0, 1, 0]]]
28 |
29 |
def test_vectors_numpy(server):
    """Vectors given as one numpy array come back as nested lists."""
    vis = ZnDraw(url=server, token="test_token")

    frame = molecule("H2O")
    frame.info["vectors"] = np.array([[[1, 0, 0], [0, 1, 0]]])
    vis.append(frame)

    stored = vis[0].info["vectors"]
    assert stored == [[[1, 0, 0], [0, 1, 0]]]
39 |
40 |
def test_vectors_format(server):
    """Vectors that are not shaped (n, 2, 3) are rejected with ``ValueError``."""
    vis = ZnDraw(url=server, token="test_token")
    water = molecule("H2O")

    # a flat list and a 2D list are both malformed
    for malformed in ([1, 2, 3], [[1, 2, 3], [4, 5, 6]]):
        water.info["vectors"] = malformed
        with pytest.raises(ValueError):
            vis.append(water)
54 |
--------------------------------------------------------------------------------
/tests/test_vis.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import redis
3 | import znjson
4 | from ase.build import molecule
5 |
6 | from zndraw import ZnDraw, ZnDrawLocal
7 | from zndraw.converter import ASEConverter
8 |
9 |
@pytest.fixture
def full(server):
    """Provide a ``ZnDraw`` client connected to the test server."""
    client = ZnDraw(url=server, token="test_token")
    return client
13 |
14 |
@pytest.fixture
def local(server):
    """Provide a ``ZnDrawLocal`` client that is handed a Redis connection."""
    connection = redis.Redis.from_url("redis://localhost:6379/0")
    return ZnDrawLocal(url=server, token="test_token", r=connection)
20 |
21 |
@pytest.mark.parametrize("ref", ["full", "local"])
def test_append_atoms(ref, request):
    """An appended ``ase.Atoms`` object becomes the last frame."""
    client = request.getfixturevalue(ref)
    water = molecule("H2O")
    client.append(water)

    last_frame = client[-1]
    assert last_frame == water
30 |
31 |
@pytest.mark.parametrize("ref", ["full", "local"])
def test_append_dump(ref, request):
    """Appending a frame and reading it back yields an equal structure."""
    client = request.getfixturevalue(ref)

    water = molecule("H2O")
    client.append(water)

    last_frame = client[-1]
    assert last_frame == water
41 |
42 |
@pytest.mark.parametrize("ref", ["full", "local"])
def test_append_faulty(ref, request):
    """Pre-encoded dicts and plain floats are rejected by append / extend."""
    client = request.getfixturevalue(ref)
    parse_error = "Unable to parse provided data object"

    water = molecule("H2O")
    encoded = ASEConverter().encode(water)
    with pytest.raises(ValueError, match=parse_error):
        client.append(encoded)
    with pytest.raises(ValueError, match=parse_error):
        client.extend(3.14)
54 |
55 |
@pytest.mark.parametrize("ref", ["full", "local"])
def test_setitem_atoms(ref, request, s22):
    """Frames can be replaced by single index and by index list."""
    client = request.getfixturevalue(ref)
    client.extend(s22)
    water = molecule("H2O")

    client[0] = water
    assert client[0] == water

    client[[1, 2]] = [water, water]
    assert client[[0, 1, 2]] == [water] * 3
67 |
68 |
@pytest.mark.parametrize("ref", ["local", "full"])
def test_setitem_dump(ref, request, s22):
    """A replaced frame equals a freshly built molecule of the same kind."""
    client = request.getfixturevalue(ref)
    client.extend(s22)
    water = molecule("H2O")

    client[0] = water
    assert client[0] == molecule("H2O")

    client[[1, 2]] = [water, water]
    assert client[[0, 1, 2]] == [water] * 3
80 |
81 |
@pytest.mark.parametrize("ref", ["full", "local"])
def test_setitem_faulty(ref, request, s22):
    """Assigning a pre-encoded dict or extending with a float raises ``ValueError``."""
    client = request.getfixturevalue(ref)
    client.extend(s22)
    parse_error = "Unable to parse provided data object"

    water = molecule("H2O")
    encoded = ASEConverter().encode(water)
    with pytest.raises(ValueError, match=parse_error):
        client[0] = encoded
    with pytest.raises(ValueError, match=parse_error):
        client.extend(3.14)
94 |
95 |
@pytest.mark.parametrize("ref", ["full", "local"])
def test_extend_atoms(ref, request, s22):
    """Extending with a list of structures stores all of them in order."""
    client = request.getfixturevalue(ref)
    client.extend(s22)

    all_frames = client[:]
    assert all_frames == s22
102 |
103 |
@pytest.mark.parametrize("ref", ["local", "full"])
def test_extend_dump(ref, request, s22):
    """A full slice after extending returns structures equal to the input."""
    client = request.getfixturevalue(ref)
    client.extend(s22)

    all_frames = client[:]
    assert all_frames == s22
110 |
111 |
@pytest.mark.parametrize("ref", ["full", "local"])
def test_extend_faulty(ref, request, s22):
    """JSON strings, pre-encoded dicts and floats are rejected by ``extend``."""
    client = request.getfixturevalue(ref)
    client.extend(s22)
    parse_error = "Unable to parse provided data object"

    # a serialized JSON string is not accepted
    with pytest.raises(ValueError, match=parse_error):
        client.extend(
            znjson.dumps(s22, cls=znjson.ZnEncoder.from_converters(ASEConverter))
        )

    # neither is a list of already encoded dictionaries
    encoded_frames = [ASEConverter().encode(structure) for structure in s22]
    with pytest.raises(ValueError, match=parse_error):
        client.extend(encoded_frames)

    with pytest.raises(ValueError, match=parse_error):
        client.extend(3.14)
128 |
--------------------------------------------------------------------------------
/tests/test_zndraw.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import socketio.exceptions
3 |
4 | from zndraw import ZnDraw
5 |
6 |
@pytest.fixture
def lst():
    """Provide an empty built-in list as the reference implementation."""
    reference = []
    return reference
10 |
11 |
@pytest.fixture
def vis(server):
    """Provide a ``ZnDraw`` client connected to the test server."""
    client = ZnDraw(url=server, token="test_token")
    return client
15 |
16 |
@pytest.mark.parametrize("ref", ["lst", "vis"])
def test_zndraw(ref, request):
    """A fresh list and a fresh ``ZnDraw`` client are both empty."""
    container = request.getfixturevalue(ref)
    size = len(container)
    assert size == 0
22 |
23 |
def test_zndraw_no_connection():
    """Connecting to a URL where no server listens raises a socketio ``ConnectionError``."""
    expected_error = socketio.exceptions.ConnectionError
    with pytest.raises(expected_error, match="Could not connect"):
        ZnDraw(url="http://localhost:8080", token="test_token")
30 |
31 |
@pytest.mark.parametrize("ref", ["lst", "vis"])
def test_extend(ref, s22, request):
    """``extend`` adds every structure for both the list and the client."""
    container = request.getfixturevalue(ref)
    container.extend(s22)

    expected_size = len(s22)
    assert len(container) == expected_size
38 |
39 |
@pytest.mark.parametrize("ref", ["lst", "vis"])
def test_getitem(ref, s22, request):
    """Integer indexing returns the stored frame; out-of-range raises ``IndexError``."""
    container = request.getfixturevalue(ref)
    container.extend(s22)

    for index in (0, 10):
        assert container[index] == s22[index]

    with pytest.raises(IndexError):
        container[100]
50 |
51 |
@pytest.mark.parametrize("ref", ["lst", "vis"])
def test_insert(ref, s22, water, request):
    """``insert`` grows the container by one and places the frame at the index."""
    container = request.getfixturevalue(ref)
    container.extend(s22)

    # (insert position, expected number of extra frames afterwards)
    for position, extra in ((0, 1), (10, 2)):
        container.insert(position, water)
        assert len(container) == len(s22) + extra
        assert container[position] == water
63 |
64 |
@pytest.mark.parametrize("ref", ["lst", "vis"])
def test_setitem(ref, s22, water, request):
    """Item assignment replaces a frame without changing the length."""
    container = request.getfixturevalue(ref)
    container.extend(s22)

    for index in (0, 10):
        container[index] = water
        assert container[index] == water

    assert len(container) == len(s22)
75 |
76 |
@pytest.mark.parametrize("ref", ["lst", "vis"])
def test_setitem_slice(ref, s22, water, request):
    """Slice assignment replaces the slice and leaves the remainder untouched."""
    container = request.getfixturevalue(ref)
    container.extend(s22)

    container[0:5] = [water] * 5

    head, tail = container[0:5], container[5:]
    assert head == [water] * 5
    assert tail == s22[5:]
84 |
85 |
@pytest.mark.parametrize("ref", ["lst", "vis"])
def test_delitem(ref, s22, request):
    """Deleting single frames shrinks the container by one each time."""
    container = request.getfixturevalue(ref)
    container.extend(s22)

    # (index to delete, number of frames removed so far afterwards)
    for index, removed in ((0, 1), (10, 2)):
        del container[index]
        assert len(container) == len(s22) - removed
94 |
95 |
@pytest.mark.parametrize("ref", ["lst", "vis"])
def test_delitem_slice(ref, server, s22, request):
    """Deleting slices removes exactly the frames covered by each slice."""
    container = request.getfixturevalue(ref)
    container.extend(s22)
    initial = len(s22)

    del container[0:5]
    assert len(container) == initial - 5

    del container[5:10]
    assert len(container) == initial - 10

    # an open-ended slice drops everything from index 5 onwards
    del container[5:]
    assert len(container) == 5
106 |
107 |
@pytest.mark.parametrize("ref", ["lst", "vis"])
def test_append(ref, s22, water, request):
    """``append`` adds one frame at the end."""
    container = request.getfixturevalue(ref)
    container.extend(s22)
    container.append(water)

    last_frame = container[-1]
    assert len(container) == len(s22) + 1
    assert last_frame == water
115 |
--------------------------------------------------------------------------------
/zndraw/.gitignore:
--------------------------------------------------------------------------------
1 | templates/
2 |
--------------------------------------------------------------------------------
/zndraw/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib.metadata
2 |
3 | from zndraw.base import Extension
4 | from zndraw.draw import (
5 | Box,
6 | Circle,
7 | Cone,
8 | Custom2DShape,
9 | Cylinder,
10 | Dodecahedron,
11 | Ellipsoid,
12 | Icosahedron,
13 | Material,
14 | Octahedron,
15 | Plane,
16 | Rhomboid,
17 | Ring,
18 | Sphere,
19 | Tetrahedron,
20 | Torus,
21 | TorusKnot,
22 | )
23 | from zndraw.figure import Figure
24 | from zndraw.zndraw import ZnDraw, ZnDrawLocal
25 |
26 | __all__ = [
27 | "ZnDraw",
28 | "ZnDrawLocal",
29 | "Extension",
30 | "Plane",
31 | "Sphere",
32 | "Box",
33 | "Circle",
34 | "Cone",
35 | "Cylinder",
36 | "Dodecahedron",
37 | "Icosahedron",
38 | "Octahedron",
39 | "Ring",
40 | "Tetrahedron",
41 | "Torus",
42 | "TorusKnot",
43 | "Rhomboid",
44 | "Ellipsoid",
45 | "Material",
46 | "Custom2DShape",
47 | "Figure",
48 | ]
49 |
50 | __version__ = importlib.metadata.version("zndraw")
51 |
--------------------------------------------------------------------------------
/zndraw/abc.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 |
class Message(t.TypedDict):
    """A message to be sent to the client."""

    # Time of the message, carried as a string (the format is not fixed here).
    time: str
    # The message text.
    msg: str
    # Origin of the message, or None — presumably the sender's name; confirm
    # against the code that builds these dictionaries.
    origin: str | None
10 |
--------------------------------------------------------------------------------
/zndraw/app.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import time
3 |
4 | import redis
5 | import znsocket.exceptions
6 | from celery import Celery, Task
7 | from flask import Flask
8 | from flask_socketio import SocketIO
9 |
10 | from zndraw.server import init_socketio_events, main_blueprint
11 |
12 |
def celery_init_app(app: Flask) -> Celery:
    """Create the Celery application that belongs to ``app``.

    The Celery app is configured from ``app.config["CELERY"]``, set as the
    default Celery app and stored under ``app.extensions["celery"]``.

    Parameters
    ----------
    app: Flask
        The Flask application; its ``CELERY`` config entry must already be set.

    Returns
    -------
    Celery
        The configured Celery application.
    """

    # Task base class that runs every task inside the Flask application context.
    class FlaskTask(Task):
        def __call__(self, *args: object, **kwargs: object) -> object:
            with app.app_context():
                return self.run(*args, **kwargs)

    celery = Celery(app.name, task_cls=FlaskTask)
    celery.config_from_object(app.config["CELERY"])
    celery.set_default()

    app.extensions["celery"] = celery
    return celery
24 |
25 |
def storage_init_app(app: Flask) -> None:
    """Create the storage client and store it in ``app.extensions["redis"]``.

    The backend is chosen from the prefix of ``app.config["STORAGE"]``:
    ``redis...`` creates a ``redis.Redis`` client, ``znsocket...`` creates a
    ``znsocket.Client``.

    Parameters
    ----------
    app: Flask
        The Flask application whose ``STORAGE`` config entry is read.

    Raises
    ------
    ConnectionError
        If the znsocket server could not be reached within 100 attempts.
    ValueError
        If the storage URL starts with neither ``redis`` nor ``znsocket``.
    """
    storage = app.config["STORAGE"]

    if storage.startswith("redis"):
        app.extensions["redis"] = redis.Redis.from_url(storage, decode_responses=True)
    elif storage.startswith("znsocket"):
        last_error = None
        for _ in range(100):  # try to connect to znsocket for 10 s
            # if we start znsocket via celery it will take some time to start
            try:
                app.extensions["redis"] = znsocket.Client.from_url(storage)
                break
            except ConnectionError as err:
                # wait for znsocket to start, if started together with the server
                last_error = err
                time.sleep(0.1)
        else:
            # Every attempt failed: fail loudly here instead of returning with
            # ``app.extensions["redis"]`` unset and crashing later on.
            raise ConnectionError(
                f"Could not connect to znsocket storage '{storage}' after 100 attempts"
            ) from last_error
    else:
        raise ValueError(f"Unknown storage type: {storage}")
42 |
43 |
def create_app() -> Flask:
    """Build the fully configured Flask application.

    The function sets the session / secret defaults, loads all ``FLASK_``
    prefixed environment variables, configures Celery (redis broker or a
    filesystem broker below ``~/.zincware/zndraw/celery``), creates the
    SocketIO server, initializes Celery and the storage client, attaches the
    znsocket events and registers the blueprint and socketio events.

    Returns
    -------
    Flask
        The application, with ``celery``, ``redis`` and ``socketio`` entries in
        ``app.extensions``.
    """
    app = Flask(__name__)
    app.config.update(SESSION_COOKIE_SAMESITE="None", SESSION_COOKIE_SECURE=True)
    app.config["SECRET_KEY"] = "secret!"
    # loads all FLASK_ prefixed environment variables into the app config
    app.config.from_prefixed_env()

    uses_redis = app.config["STORAGE"].startswith("redis")

    # TODO: this will not work without redis!!!!
    if uses_redis:
        celery_config = {
            "broker_url": app.config["STORAGE"],
            "result_backend": app.config["STORAGE"],
            "task_ignore_result": True,
        }
    else:
        # nothing else supported, using filesystem storage
        celery_root = pathlib.Path("~/.zincware/zndraw/celery").expanduser()
        folders = {name: celery_root / name for name in ("out", "processed", "ctrl")}
        for folder in folders.values():
            folder.mkdir(parents=True, exist_ok=True)

        celery_config = {
            "broker_url": "filesystem://",
            "result_backend": "cache",
            "cache_backend": "memory",
            "task_ignore_result": True,
            "broker_transport_options": {
                "data_folder_in": folders["out"].as_posix(),
                "data_folder_out": folders["out"].as_posix(),
                "data_folder_processed": folders["processed"].as_posix(),
                "control_folder": folders["ctrl"].as_posix(),
            },
        }
    app.config.from_mapping(CELERY=celery_config)

    # Initialize SocketIO
    socketio_options = {
        # the broker doubles as the message queue, but only for redis
        "message_queue": app.config["CELERY"]["broker_url"] if uses_redis else None,
        "cors_allowed_origins": "*",
    }
    if "SOCKETIO_PING_TIMEOUT" in app.config:
        socketio_options["ping_timeout"] = app.config["SOCKETIO_PING_TIMEOUT"]
    buffer_size = app.config.get("MAX_HTTP_BUFFER_SIZE")
    if buffer_size is not None:
        socketio_options["max_http_buffer_size"] = int(buffer_size)

    socketio = SocketIO(app, logger=False, engineio_logger=False, **socketio_options)

    # Initialize Celery
    celery_init_app(app)

    # Initialize storage
    storage_init_app(app)
    # we only need this server if we are using redis
    # otherwise a znsocket server will run anyhow
    if uses_redis:
        events_storage = redis.Redis.from_url(
            app.config["STORAGE"], decode_responses=True
        )
    else:
        events_storage = znsocket.Client.from_url(
            app.config["STORAGE"], decode_responses=True
        )
    znsocket.attach_events(socketio.server, storage=events_storage)

    # Register routes and socketio events
    app.register_blueprint(main_blueprint)
    init_socketio_events(socketio)

    # Add socketio to app extensions for easy access
    app.extensions["socketio"] = socketio

    return app
135 |
--------------------------------------------------------------------------------
/zndraw/base.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import logging
3 | import typing as t
4 | from abc import abstractmethod
5 |
6 | from pydantic import BaseModel
7 |
8 | if t.TYPE_CHECKING:
9 | from zndraw import ZnDraw
10 |
11 | log = logging.getLogger(__name__)
12 |
13 |
# Pydantic model used as the common base class for extensions.
# Subclasses are expected to override ``run``.
class Extension(BaseModel):
    @abstractmethod
    def run(self, vis: "ZnDraw", **kwargs) -> None:
        """Execute the extension against the given ``ZnDraw`` instance.

        The base implementation only raises ``NotImplementedError``.
        """
        raise NotImplementedError("run method must be implemented in subclass")
18 |
19 |
20 | @dataclasses.dataclass # TODO: move to a separate file, so it can be imported in other files
21 | class FileIO:
22 | name: str | None = None
23 | start: int = 0
24 | stop: int | None = None
25 | step: int = 1
26 | remote: str | None = None
27 | rev: str | None = None
28 | convert_nan: bool = False
29 |
30 | def to_dict(self):
31 | return dataclasses.asdict(self)
32 |
--------------------------------------------------------------------------------
/zndraw/bonds/__init__.py:
--------------------------------------------------------------------------------
1 | import ase
2 | import networkx as nx
3 | import numpy as np
4 | from ase.neighborlist import natural_cutoffs
5 | from networkx.exception import NetworkXError
6 | from pydantic import BaseModel, Field
7 |
8 |
# Computes bonds from inter-atomic distances using ASE's natural cutoffs.
# Each multiplier scales the natural cutoffs for one bond-order threshold.
class ASEComputeBonds(BaseModel):
    single_bond_multiplier: float = Field(1.2, le=2, ge=0)
    double_bond_multiplier: float = Field(0.9, le=1, ge=0)
    triple_bond_multiplier: float = Field(0.0, le=1, ge=0)

    def build_graph(self, atoms: ase.Atoms):
        """Build a connectivity graph from the pairwise distances of ``atoms``.

        For every multiplier the pairwise cutoff is the sum of the two scaled
        natural cutoffs; each threshold a pair satisfies adds one to its entry
        in the connectivity matrix, which becomes the edge weight.
        Periodic boundary conditions are ignored (a non-periodic copy is used).

        Returns
        -------
        nx.Graph
            Graph created from the integer connectivity matrix.
        """
        multipliers = (
            self.single_bond_multiplier,
            self.double_bond_multiplier,
            self.triple_bond_multiplier,
        )
        atoms_copy = atoms.copy()
        n_atoms = len(atoms_copy)
        connectivity_matrix = np.zeros((n_atoms, n_atoms), dtype=int)
        atoms_copy.pbc = False
        distance_matrix = atoms_copy.get_all_distances(mic=False)
        # an atom must never be bonded to itself
        np.fill_diagonal(distance_matrix, np.inf)
        for multiplier in multipliers:
            # distinct names: do not rebind the sequence that is being iterated
            radii = np.array(natural_cutoffs(atoms_copy, mult=multiplier))
            pair_cutoffs = radii[:, None] + radii[None, :]
            connectivity_matrix[distance_matrix <= pair_cutoffs] += 1
        return nx.from_numpy_array(connectivity_matrix)

    def update_graph_using_modifications(self, atoms: ase.Atoms):
        """Apply ``atoms.info["modifications"]`` to ``atoms.connectivity`` in place.

        Each key is a pair of atom indices and each value the new edge weight;
        a weight of 0 removes the edge.
        """
        modifications = atoms.info.get("modifications", {})
        graph = atoms.connectivity
        for (atom_1, atom_2), weight in modifications.items():
            if weight == 0:
                self.remove_edge(graph, atom_1, atom_2)
            else:
                graph.add_edge(atom_1, atom_2, weight=weight)

    @staticmethod
    def remove_edge(graph, atom_1, atom_2):
        """Remove the edge between two atoms; a missing edge is ignored."""
        try:
            graph.remove_edge(atom_1, atom_2)
        except NetworkXError:
            # the edge does not exist, nothing to remove
            pass

    def get_bonds(self, atoms: ase.Atoms, graph: nx.Graph | None = None):
        """Return the bonds as a list of ``(atom_1, atom_2, weight)`` tuples.

        Parameters
        ----------
        atoms: ase.Atoms
            Structure used to build the graph when ``graph`` is not given.
        graph: nx.Graph, optional
            Existing connectivity graph; every edge needs a ``weight`` entry.
        """
        if graph is None:
            graph = self.build_graph(atoms)
        return [
            (atom_1, atom_2, data["weight"])
            for atom_1, atom_2, data in graph.edges(data=True)
        ]

    def update_bond_order(self, atoms: ase.Atoms, particles: list[int], order: int):
        """Record a bond-order change for two atoms in ``atoms.info["modifications"]``.

        Raises
        ------
        ValueError
            If ``particles`` does not contain exactly two entries.
        """
        if len(particles) != 2:
            raise ValueError("Exactly two particles must be selected")
        modifications = atoms.info.get("modifications", {})
        # sorted so that (a, b) and (b, a) address the same bond
        sorted_particles = tuple(sorted(particles))
        modifications[sorted_particles] = order
        atoms.info["modifications"] = modifications
65 |
--------------------------------------------------------------------------------
/zndraw/converter.py:
--------------------------------------------------------------------------------
1 | import ase
2 | import numpy as np
3 | import znjson
4 | from ase.calculators.singlepoint import SinglePointCalculator
5 | from ase.constraints import FixAtoms
6 |
7 | from zndraw.draw import Object3D
8 | from zndraw.type_defs import ASEDict
9 |
10 |
class ASEConverter(znjson.ConverterBase):
    """Encode/Decode ``ase.Atoms`` objects to and from a JSON-compatible dict.

    Attributes
    ----------
    level: int
        Priority of this converter over others.
        A higher level will be used first, if there
        are multiple converters available
    representation: str
        A unique identifier for this converter.
    instance:
        Used to select the correct converter.
        This should fulfill isinstance(other, self.instance)
        or __eq__ should be overwritten.
    """

    level = 100
    representation = "ase.Atoms"
    instance = ase.Atoms

    def encode(self, obj: ase.Atoms) -> ASEDict:
        """Convert the ``ase.Atoms`` object into an ``ASEDict``.

        Only info / calculator entries that are plain scalars, strings, lists
        or numpy arrays are kept. ``info["vectors"]`` is exported separately
        and must have the shape (n, 2, 3). Of the constraints only
        ``FixAtoms`` is serialized.

        Raises
        ------
        ValueError
            If vectors are given but not of shape (n, 2, 3).
        """
        numbers = obj.numbers.tolist()
        positions = obj.positions.tolist()
        pbc = obj.pbc.tolist()
        cell = obj.cell.tolist()

        info = {
            k: v
            for k, v in obj.info.items()
            if isinstance(v, (float, int, str, bool, list))
        }
        info |= {k: v.tolist() for k, v in obj.info.items() if isinstance(v, np.ndarray)}

        # Build a new list: the popped value may be the very list stored in
        # ``obj.info`` and must not be modified while encoding.
        vectors = [
            vector.tolist() if isinstance(vector, np.ndarray) else vector
            for vector in info.pop("vectors", [])
        ]
        if len(vectors) != 0:
            vectors = np.array(vectors)
            if vectors.ndim != 3 or vectors.shape[1:] != (2, 3):
                raise ValueError(
                    f"Vectors must be of shape (n, 2, 3), found '{vectors.shape}'"
                )
            vectors = vectors.tolist()

        if obj.calc is not None:
            calc = {
                k: v
                for k, v in obj.calc.results.items()
                if isinstance(v, (float, int, str, bool, list))
            }
            calc |= {
                k: v.tolist()
                for k, v in obj.calc.results.items()
                if isinstance(v, np.ndarray)
            }
        else:
            calc = {}

        # All additional information should be stored in calc.results
        # and not in calc.arrays, thus we will not convert it here!
        # We don't want to send positions / numbers twice, so they are skipped.
        arrays = {
            key: value.tolist() if isinstance(value, np.ndarray) else value
            for key, value in obj.arrays.items()
            if key not in ("positions", "numbers")
        }

        connectivity = getattr(obj, "connectivity", None)
        if connectivity is None:
            connectivity = []
        elif isinstance(connectivity, np.ndarray):
            connectivity = connectivity.tolist()

        # Can not serialize other constraints, they are skipped
        constraints = [
            {"type": "FixAtoms", "indices": constraint.index.tolist()}
            for constraint in obj.constraints
            if isinstance(constraint, FixAtoms)
        ]

        return ASEDict(
            numbers=numbers,
            positions=positions,
            connectivity=connectivity,
            arrays=arrays,
            info=info,
            calc=calc,
            pbc=pbc,
            cell=cell,
            vectors=vectors,
            constraints=constraints,
        )

    def decode(self, value: ASEDict) -> ase.Atoms:
        """Create an ``ase.Atoms`` object from an ``ASEDict``.

        Empty ``connectivity``, ``calc``, ``vectors`` and ``constraints``
        entries are skipped, so e.g. no ``connectivity`` attribute is set when
        none was encoded.
        """
        atoms = ase.Atoms(
            numbers=value["numbers"],
            positions=value["positions"],
            info=value["info"],
            pbc=value["pbc"],
            cell=value["cell"],
        )
        if connectivity := value.get("connectivity"):
            # or do we want this to be nx.Graph?
            atoms.connectivity = np.array(connectivity)
        for key, val in value["arrays"].items():
            atoms.arrays[key] = np.array(val)
        if calc := value.get("calc"):
            atoms.calc = SinglePointCalculator(atoms)
            atoms.calc.results.update(calc)
        if vectors := value.get("vectors"):
            atoms.info["vectors"] = vectors

        # ``set_constraint`` replaces all existing constraints, so collect
        # every constraint first and apply them in one call.
        fixed = [
            FixAtoms(constraint["indices"])
            for constraint in value.get("constraints") or []
            if constraint["type"] == "FixAtoms"
        ]
        if fixed:
            atoms.set_constraint(fixed)
        return atoms
155 |
156 |
class Object3DConverter(znjson.ConverterBase):
    """Encode/Decode ``Object3D`` instances by class name and field data."""

    instance: type = Object3D
    representation: str = "zndraw.Object3D"
    level: int = 100

    def encode(self, obj: Object3D) -> dict:
        """Return the class name and the dumped model fields of ``obj``."""
        return {"class": obj.__class__.__name__, "data": obj.model_dump()}

    def decode(self, value: dict) -> Object3D:
        """Rebuild the object from the ``class`` / ``data`` dictionary.

        The class is looked up by name in the ``zndraw`` package.

        Raises
        ------
        AttributeError
            If ``zndraw`` has no attribute with the given class name.
        TypeError
            If the named attribute is not an ``Object3D`` subclass.
        """
        # imported here and not at module level — presumably to avoid a
        # circular import with the zndraw package; confirm before moving it
        import zndraw

        cls = getattr(zndraw, value["class"])
        # The name comes from serialized data: only ever instantiate geometry
        # classes, never an arbitrary attribute of the package.
        if not (isinstance(cls, type) and issubclass(cls, Object3D)):
            raise TypeError(f"'{value['class']}' is not an Object3D subclass")

        return cls(**value["data"])
171 |
--------------------------------------------------------------------------------
/zndraw/draw/__init__.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | from pydantic import BaseModel, ConfigDict, Field
4 |
5 |
6 | def _update_material_schema(schema: dict) -> dict:
7 | schema["properties"]["wireframe"]["format"] = "checkbox"
8 | schema["properties"]["color"]["format"] = "color"
9 | schema["properties"]["opacity"]["format"] = "range"
10 | schema["properties"]["opacity"]["step"] = 0.01
11 | schema["properties"]["outlines"]["format"] = "checkbox"
12 | return schema
13 |
14 |
15 | def _update_object3d_schema(schema: dict) -> dict:
16 | """Remote position, rotation, and scale from the schema."""
17 | schema["properties"].pop("position", None)
18 | schema["properties"].pop("rotation", None)
19 | schema["properties"].pop("scale", None)
20 | return schema
21 |
22 |
# Appearance settings shared by the 3D objects.
# Documented with comments on purpose: a class docstring on a pydantic model
# would show up as "description" in the generated JSON schema.
class Material(BaseModel):
    # color string; the default is a hex color
    color: str = "#62929E"
    # validated to lie within [0.0, 1.0]
    opacity: float = Field(default=0.2, ge=0.0, le=1.0)
    wireframe: bool = False
    outlines: bool = False

    # the schema hook adds "format" / "step" hints to the properties
    model_config = ConfigDict(json_schema_extra=_update_material_schema)
30 |
31 |
class Object3D(BaseModel):
    """Base class for all 3D objects."""

    material: Material = Material()

    # each transform accepts either a 3-tuple or a list of floats
    position: t.Tuple[float, float, float] | list[float] = (0, 0, 0)
    rotation: t.Tuple[float, float, float] | list[float] = (0, 0, 0)
    scale: t.Tuple[float, float, float] | list[float] = (1, 1, 1)

    # the schema hook removes position, rotation and scale from the JSON schema
    model_config = ConfigDict(json_schema_extra=_update_object3d_schema)

    def run(self, vis, **kwargs) -> None:
        # Places the object at the center of mass of the current selection
        # (if any) and appends it to ``vis.geometries``.
        # get the selected particles and compute the COM
        if len(vis.selection) > 0:
            selected = vis.atoms[vis.selection]
            self.position = selected.get_center_of_mass().tolist()
        # self.position = vis.points[0].tolist()
        print(f"Running {self.__class__.__name__} at {self.position}")

        vis.geometries.append(self)  # TODO: dump / load without pickle
52 |
53 |
54 | class Plane(Object3D):
55 | width: float = 10.0
56 | height: float = 10.0
57 |
58 |
class Box(Object3D):
    # Box extents; material and transform are inherited from Object3D.
    width: float = 10.0
    height: float = 10.0
    depth: float = 10.0
63 |
64 |
class Circle(Object3D):
    # Circle parameter; material and transform are inherited from Object3D.
    radius: float = 5.0
67 |
68 |
class Cone(Object3D):
    # Cone parameters; material and transform are inherited from Object3D.
    radius: float = 5.0
    height: float = 10.0
72 |
73 |
class Cylinder(Object3D):
    # Cylinder parameters with independent top and bottom radii;
    # material and transform are inherited from Object3D.
    radius_top: float = 5.0
    radius_bottom: float = 5.0
    height: float = 10.0
78 |
79 |
class Dodecahedron(Object3D):
    # Dodecahedron parameter; material and transform are inherited from Object3D.
    radius: float = 5.0
82 |
83 |
class Icosahedron(Object3D):
    # Icosahedron parameter; material and transform are inherited from Object3D.
    radius: float = 5.0
86 |
87 |
class Octahedron(Object3D):
    # Octahedron parameter; material and transform are inherited from Object3D.
    radius: float = 5.0
90 |
91 |
class Ring(Object3D):
    # Ring parameters (inner and outer radius);
    # material and transform are inherited from Object3D.
    inner_radius: float = 1
    outer_radius: float = 4.0
95 |
96 |
class Sphere(Object3D):
    # Sphere parameter; material and transform are inherited from Object3D.
    radius: float = 4.0
99 |
100 |
class Tetrahedron(Object3D):
    # Tetrahedron parameter; material and transform are inherited from Object3D.
    radius: float = 5.0
103 |
104 |
class Torus(Object3D):
    # Torus parameters; material and transform are inherited from Object3D.
    # NOTE(review): `tube` is presumably the tube radius -- confirm against the renderer.
    radius: float = 3.0
    tube: float = 1.0
108 |
109 |
class TorusKnot(Object3D):
    # Torus-knot parameters; material and transform are inherited from Object3D.
    # NOTE(review): `tube` is presumably the tube radius -- confirm against the renderer.
    radius: float = 3.0
    tube: float = 1.0
113 |
114 |
class Rhomboid(Object3D):
    # Shape described by three vectors; the defaults are axis-aligned vectors
    # of length 10.  Material and transform are inherited from Object3D.
    vectorA: t.Tuple[float, float, float] | list[float] = (10, 0, 0)
    vectorB: t.Tuple[float, float, float] | list[float] = (0, 10, 0)
    vectorC: t.Tuple[float, float, float] | list[float] = (0, 0, 10)
119 |
120 |
class Custom2DShape(Object3D):
    # Shape given as a list of 2D points.  `points` has no default and must be
    # supplied; this class is not listed in the `geometries` registry.
    points: list[tuple[float, float]]
123 |
124 |
class Ellipsoid(Object3D):
    # Ellipsoid parameters; material and transform are inherited from Object3D.
    # NOTE(review): a, b, c are presumably the semi-axes -- confirm against the renderer.
    a: float = 10.0
    b: float = 5.0
    c: float = 5.0
129 |
130 |
# Registry of the selectable geometry classes, keyed by class name.
geometries: dict[str, t.Type[Object3D]] = {
    shape.__name__: shape
    for shape in (
        Plane,
        Box,
        Circle,
        Cone,
        Cylinder,
        Dodecahedron,
        Icosahedron,
        Octahedron,
        Ring,
        Sphere,
        Tetrahedron,
        Torus,
        TorusKnot,
        Rhomboid,
        Ellipsoid,
    )
}
148 |
--------------------------------------------------------------------------------
/zndraw/exceptions.py:
--------------------------------------------------------------------------------
class ZnDrawException(Exception):
    """Base ZnDraw exception."""


class RoomNotFound(ZnDrawException):
    """Raised when a room is not found on the server.

    Derives from :class:`ZnDrawException` so that all package errors share
    one base class; it is still an ``Exception`` subclass, so existing
    ``except`` clauses keep working.
    """


class RoomLockedError(ZnDrawException):
    """Raised when trying to modify a locked room."""
11 |
--------------------------------------------------------------------------------
/zndraw/figure.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import dataclasses
3 |
4 | import matplotlib.pyplot as plt
5 | import znjson
6 |
7 |
@dataclasses.dataclass(frozen=True)
class Figure:
    """Visualize a file or a matplotlib figure.

    At most one of ``path`` and ``figure`` may be given.  An instance with
    neither is allowed to exist (``FigureConverter.decode`` creates one when
    the encoded ``path`` is ``None``), but it cannot be converted to base64.
    """

    path: str | None = None
    figure: plt.Figure | None = None

    def __post_init__(self):
        if self.path is not None and self.figure is not None:
            raise ValueError("Figure can't have both path and figure")

    def to_base64(self) -> str:
        """Return the image as a base64 encoded string.

        The file at ``path`` is read as-is; a matplotlib ``figure`` is
        rendered to PNG first.

        Raises:
            ValueError: if neither ``path`` nor ``figure`` is set.
        """
        if self.path is not None:
            with open(self.path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode("utf-8")
        if self.figure is None:
            # Previously this failed with an opaque AttributeError on None.
            raise ValueError("Figure has neither path nor figure to encode")

        import io

        buf = io.BytesIO()
        self.figure.savefig(buf, format="png")
        return base64.b64encode(buf.getvalue()).decode("utf-8")
30 |
31 |
class FigureConverter(znjson.ConverterBase):
    """znjson converter for :class:`Figure` objects."""

    level = 100
    instance = Figure
    representation = "zndraw.Figure"

    def encode(self, obj: Figure) -> dict:
        """Serialize ``obj`` into its path and the base64 encoded image."""
        encoded_image = obj.to_base64()
        return dict(path=obj.path, base64=encoded_image)

    def decode(self, value: dict) -> Figure:
        """Rebuild a ``Figure`` from the stored path (the image data is not used)."""
        stored_path = value["path"]
        return Figure(stored_path)
42 |
--------------------------------------------------------------------------------
/zndraw/modify/private.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | import ase
4 |
5 | if t.TYPE_CHECKING:
6 | from zndraw.zndraw import ZnDraw
7 |
8 | from . import UpdateScene
9 |
10 |
class NewScene(UpdateScene):
    """Clear the scene, deleting all atoms and points."""

    def run(self, vis: "ZnDraw", **kwargs) -> None:
        # Drop every frame after the current one, then start from an empty frame.
        del vis[vis.step + 1 :]
        vis.points = []
        vis.append(ase.Atoms())
        vis.selection = []

        # Jump to the newly appended (last) frame and bookmark it.
        last_index = len(vis) - 1
        vis.step = last_index
        vis.bookmarks = vis.bookmarks | {last_index: "New Scene"}

        # Reset the camera.
        vis.camera = dict(position=[0, 0, 20], target=[0, 0, 0])
23 |
24 |
class ClearTools(UpdateScene):
    """Clear the tools, removing all guiding points and undoing any selection."""

    def run(self, vis: "ZnDraw", **kwargs) -> None:
        # Both attributes are reset to empty lists; the frames are untouched.
        vis.points = []
        vis.selection = []
31 |
--------------------------------------------------------------------------------
/zndraw/queue.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import traceback
3 | import typing as t
4 |
5 | import socketio.exceptions
6 | import znsocket.exceptions
7 |
8 | from zndraw.base import Extension
9 |
10 | if t.TYPE_CHECKING:
11 | from zndraw import ZnDraw
12 |
13 |
14 | log = logging.getLogger(__name__)
15 | TASK_RUNNING = "ZNDRAW TASK IS RUNNING"
16 |
17 |
def check_queue(vis: "ZnDraw") -> None:
    """Poll the private and public modifier queues forever.

    Nothing is processed while no modifier is registered on ``vis``.  Every
    iteration ends with a one second sleep.  On a connection error the
    socket is disconnected before sleeping.
    """
    while True:
        if vis._modifiers:
            try:
                process_modifier_queue(vis)
                process_public_queue(vis)
                vis.socket.sleep(1)
            except (
                znsocket.exceptions.ZnSocketError,
                socketio.exceptions.SocketIOError,
            ):
                log.warning("Connection to ZnDraw server lost. Reconnecting...")
                vis.socket.disconnect()
                vis.socket.sleep(1)
        else:
            vis.socket.sleep(1)
32 |
33 |
def process_modifier_queue(vis: "ZnDraw") -> None:
    """Run queued tasks of this room for modifiers registered on ``vis``."""
    private_queue = znsocket.Dict(
        r=vis.r,
        socket=vis._refresh_client,
        key=f"queue:{vis.token}:modifier",
    )

    for name in private_queue:
        # Entries for modifiers this client did not register are left alone.
        if name not in vis._modifiers:
            continue
        try:
            task = private_queue.pop(name)
            modifier_cls = vis._modifiers[name]["cls"]
            extra_kwargs = vis._modifiers[name]["run_kwargs"]
            run_queued_task(vis, modifier_cls, task, private_queue, extra_kwargs)
        except IndexError:
            pass
51 |
52 |
def process_public_queue(vis: "ZnDraw") -> None:
    """Run queued tasks of any room for the public modifiers of ``vis``.

    A temporary ``ZnDraw`` client is created for the room a task belongs to
    and is disconnected again once the task has been handled.
    """
    from zndraw import ZnDraw

    serves_public = any(mod["public"] for mod in vis._modifiers.values())
    if not serves_public:
        return

    public_queue = znsocket.Dict(
        r=vis.r,
        socket=vis._refresh_client,
        key="queue:default:modifier",
    )

    for room, room_queue in public_queue.items():
        for name in room_queue:
            if not (name in vis._modifiers and vis._modifiers[name]["public"]):
                continue
            room_vis = ZnDraw(url=vis.url, token=room, r=vis.r)
            try:
                task = room_queue.pop(name)
                modifier_cls = vis._modifiers[name]["cls"]
                extra_kwargs = vis._modifiers[name]["run_kwargs"]
                run_queued_task(room_vis, modifier_cls, task, room_queue, extra_kwargs)
            except IndexError:
                pass
            finally:
                room_vis.socket.sleep(1)
                room_vis.socket.disconnect()
81 |
82 |
def run_queued_task(
    vis: "ZnDraw",
    cls: t.Type[Extension],
    task: dict,
    queue: znsocket.Dict,
    run_kwargs: dict | None = None,
) -> None:
    """Instantiate ``cls`` from ``task`` and run it on ``vis``.

    While the task runs, the ``TASK_RUNNING`` key is set in ``queue``; it is
    removed again afterwards.  Any exception is reported through ``vis.log``
    as a formatted traceback instead of being raised.
    """
    run_kwargs = run_kwargs or {}
    try:
        queue[TASK_RUNNING] = True
        extension = cls(**task)
        extension.run(vis, **run_kwargs)
    except Exception:
        vis.log(f"""Error running `{cls}`:
```python
{traceback.format_exc()}
```""")
    finally:
        queue.pop(TASK_RUNNING)
103 |
--------------------------------------------------------------------------------
/zndraw/selection/__init__.py:
--------------------------------------------------------------------------------
1 | import random
2 | import typing as t
3 |
4 | import networkx as nx
5 | import numpy as np
6 | from pydantic import Field
7 |
8 | from zndraw.base import Extension
9 | from zndraw.zndraw import ZnDraw
10 |
11 | try:
12 | from zndraw.select import mda # noqa: F401
13 | except ImportError:
14 | # mdanalysis is not installed
15 | pass
16 |
17 |
class NoneSelection(Extension):
    def run(self, vis) -> None:
        # Deselect everything by assigning an empty selection.
        vis.selection = []
21 |
22 |
class All(Extension):
    """Select all atoms."""

    def run(self, vis) -> None:
        # Every index of the frame at the current step.
        n_atoms = len(vis[vis.step])
        vis.selection = list(range(n_atoms))
29 |
30 |
class Invert(Extension):
    def run(self, vis) -> None:
        # Select exactly those atoms of the current frame that are not selected.
        frame = vis[vis.step]
        currently_selected = vis.selection
        every_index = set(range(len(frame)))
        vis.selection = list(every_index - set(currently_selected))
36 |
37 |
class Range(Extension):
    start: int = Field(0, description="Start index")
    end: int = Field(5, description="End index")
    step: int = Field(1, description="Step size")

    def run(self, vis) -> None:
        # Same semantics as the built-in range: `end` is exclusive.
        indices = range(self.start, self.end, self.step)
        vis.selection = list(indices)
45 |
46 |
class Random(Extension):
    count: int = Field(..., description="Number of atoms to select")

    def run(self, vis) -> None:
        # Select `count` distinct atoms of the current frame at random.
        atoms = vis[vis.step]
        n_atoms = len(atoms)
        # random.sample raises ValueError for a negative sample size or one
        # larger than the population, so clamp the request to [0, n_atoms].
        n_selected = max(0, min(self.count, n_atoms))
        vis.selection = random.sample(range(n_atoms), n_selected)
53 |
54 |
class IdenticalSpecies(Extension):
    def run(self, vis) -> None:
        # Extend the selection to every atom whose chemical symbol matches
        # the symbol of at least one currently selected atom.
        atoms = vis[vis.step]
        selected_ids = set(vis.selection)

        # Collect the selected symbols once, then scan the atoms a single
        # time instead of once per selected atom.
        selected_symbols = {atoms[idx].symbol for idx in selected_ids}
        selected_ids.update(
            idx for idx, atom in enumerate(atoms) if atom.symbol in selected_symbols
        )
        vis.selection = list(selected_ids)
66 |
67 |
class ConnectedParticles(Extension):
    def run(self, vis) -> None:
        # Extend the selection to all atoms connected to a selected atom
        # through the bonds listed in `atoms.connectivity`.
        atoms = vis.atoms
        selected_ids = vis.selection
        try:
            edges = atoms.connectivity
        except AttributeError:
            # No connectivity information: leave the selection untouched.
            return

        graph = nx.Graph()
        for node_a, node_b, weight in edges:
            graph.add_edge(node_a, node_b, weight=weight)

        total_ids = []
        for node_id in selected_ids:
            if node_id in graph:
                total_ids += list(nx.node_connected_component(graph, node_id))
            else:
                # An atom without any bond is not a node of the graph;
                # nx.node_connected_component would raise KeyError for it.
                total_ids.append(node_id)

        # Go through numpy so that every id ends up as a plain Python scalar.
        vis.selection = [x.item() for x in set(np.array(total_ids))]
87 |
88 |
class Neighbour(Extension):
    """Select the nth order neighbours of the selected atoms."""

    order: int = Field(1, description="Order of neighbour")

    def run(self, vis) -> None:
        atoms = vis[vis.step]
        selected_ids = vis.selection
        try:
            connectivity = atoms.connectivity
        except AttributeError:
            # No connectivity information: leave the selection untouched.
            return

        if isinstance(connectivity, nx.Graph):
            graph = connectivity
        else:
            # `connectivity` is an edge list of (node_a, node_b, weight), as
            # consumed by ConnectedParticles; networkx needs a graph object.
            graph = nx.Graph()
            for node_a, node_b, weight in connectivity:
                graph.add_edge(node_a, node_b, weight=weight)

        total_ids = []
        for node_id in selected_ids:
            if node_id not in graph:
                # Atoms without bonds have no neighbours but stay selected.
                total_ids.append(node_id)
                continue
            # Iterating the returned mapping yields the reachable node ids.
            total_ids += list(
                nx.single_source_shortest_path_length(graph, node_id, self.order)
            )

        # Go through numpy so that every id ends up as a plain Python scalar.
        vis.selection = [x.item() for x in set(np.array(total_ids))]
109 |
110 |
class UpdateSelection(Extension):
    """Reload Selection."""

    def run(self, vis: ZnDraw) -> None:
        # Re-assign the selection to itself.
        # NOTE(review): presumably this re-triggers the `selection` setter so
        # that clients are refreshed -- confirm against ZnDraw.selection.
        vis.selection = vis.selection
116 |
117 |
# Registry of the available selection extensions, keyed by class name.
selections: dict[str, t.Type[Extension]] = {
    extension.__name__: extension
    for extension in (
        ConnectedParticles,
        NoneSelection,
        All,
        Invert,
        Range,
        Random,
        IdenticalSpecies,
        Neighbour,
    )
}
128 |
--------------------------------------------------------------------------------
/zndraw/server/__init__.py:
--------------------------------------------------------------------------------
1 | from zndraw.server.events import init_socketio_events
2 | from zndraw.server.routes import main as main_blueprint
3 |
4 | __all__ = ["init_socketio_events", "main_blueprint"]
5 |
--------------------------------------------------------------------------------
/zndraw/server/events.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import importlib
3 | import importlib.metadata
4 | import importlib.util
5 | import logging
6 | import uuid
7 |
8 | import znsocket
9 | from flask import current_app, request, session
10 | from flask_socketio import SocketIO, emit, join_room
11 | from redis import Redis
12 |
13 | from zndraw.tasks import (
14 | inspect_zntrack_node,
15 | load_zntrack_figures,
16 | load_zntrack_frames,
17 | run_room_copy,
18 | run_room_worker,
19 | run_scene_dependent_schema,
20 | run_schema,
21 | run_upload_file,
22 | )
23 |
24 | log = logging.getLogger(__name__)
25 | __version__ = importlib.metadata.version("zndraw")
26 |
27 |
@dataclasses.dataclass
class DummyClient:
    """Dummy replacement for znsocket.Client."""

    # Passed as `socket=` to znsocket.Dict in the disconnect handler.
    sio: SocketIO
    # Every instance gets its own, initially empty list.
    refresh_callbacks: list = dataclasses.field(default_factory=list)
34 |
35 |
def init_socketio_events(io: SocketIO):
    """Register all ZnDraw socket.io event handlers on ``io``."""

    @io.on("connect")
    def connect():
        emit("version", __version__)

    @io.on("shutdown")
    def shutdown():
        # `session.get`: a client that never sent `webclient:connect` or
        # `join` has no "authenticated" key and must be treated as
        # unauthenticated instead of failing with KeyError.
        if "AUTH_TOKEN" not in current_app.config or session.get(
            "authenticated", False
        ):
            log.critical("Shutting down server")
            current_app.extensions["redis"].flushall()

            current_app.extensions["celery"].control.purge()
            current_app.extensions["celery"].control.broadcast("shutdown")

            io.stop()
        else:
            log.critical("Unauthenticated user tried to shut down the server.")

    @io.on("disconnect")
    def disconnect():
        try:
            room = str(session["token"])
        except KeyError:
            log.critical(f"disconnecting {request.sid}")
            return
        r = current_app.extensions["redis"]

        if "name" in session:
            log.critical(f"disconnecting (webclient) {request.sid} from room {room}")
            r.srem(f"room:{room}:webclients", session["name"])
            emit(
                "room:users:refresh", list(r.smembers(f"room:{room}:webclients")), to=room
            )
        else:
            log.critical(f"disconnecting (pyclient) {request.sid}")

        for ref_token in [room, "default"]:
            modifier_registry = znsocket.Dict(
                r=r, key=f"registry:{ref_token}:modifier", repr_type="full"
            )

            modifier_schema = znsocket.Dict(
                r=r,
                key=f"schema:{ref_token}:modifier",
                repr_type="full",
                socket=DummyClient(sio=io),
            )

            # TODO: default room modifier
            for modifier in modifier_registry.pop(request.sid, []):
                # Keep the schema while any other client still provides
                # the same modifier.
                for other in modifier_registry.values():
                    if modifier in other:
                        break
                else:
                    log.debug(f"Remove {modifier} from room {room}")
                    # A missing schema entry must not abort the cleanup of
                    # the remaining modifiers and registries.
                    modifier_schema.pop(
                        modifier, None
                    )  # TODO this does not work well with public yet.

    @io.on("webclient:connect")
    def webclient_connect():
        try:
            token = session["token"]
        except KeyError:
            token = uuid.uuid4().hex[:8]
            session["token"] = token

        if "AUTH_TOKEN" not in current_app.config:
            session["authenticated"] = True
        else:
            if "authenticated" not in session:
                session["authenticated"] = False

        room = str(session["token"])
        join_room(room)  # rename token to room or room_token

        run_schema.delay(room)
        run_scene_dependent_schema.delay(room)

        session["name"] = uuid.uuid4().hex[:8]

        r = current_app.extensions["redis"]
        r.sadd(f"room:{room}:webclients", session["name"])

        # TODO: this is currently not used afaik
        emit("room:users:refresh", list(r.smembers(f"room:{room}:webclients")), to=room)

        log.critical(f"connecting (webclient) {request.sid} to {room}")

        if "TUTORIAL" in current_app.config:
            emit("tutorial:url", current_app.config["TUTORIAL"])
        if "SIMGEN" in current_app.config:
            emit("showSiMGen", True)

        return {
            "name": session["name"],
            "room": room,
            "authenticated": session["authenticated"],
        }

    @io.on("join")  # rename pyclient:connect
    def join(data: dict):
        """
        Arguments:
            data: {"token": str, "auth_token": str}
        """
        # TODO: prohibit "token" to be "default"

        if "AUTH_TOKEN" not in current_app.config:
            session["authenticated"] = True
        else:
            # A payload without "auth_token" is compared as None instead of
            # raising KeyError.
            session["authenticated"] = (
                data.get("auth_token") == current_app.config["AUTH_TOKEN"]
            )
        token = data["token"]
        session["token"] = token
        room = str(session["token"])

        join_room(room)
        log.critical(f"connecting (pyclient) {request.sid} to {room}")
        # join_room(f"pyclients_{token}")

    @io.on("room:worker:run")
    def room_worker_run():
        """Start a worker to process all (available) queued tasks."""
        room = session.get("token")
        run_room_worker.delay(room)

    @io.on("schema:refresh")
    def schema_refresh():
        room = session.get("token")
        run_scene_dependent_schema.delay(room)

    @io.on("room:alert")
    def room_alert(msg: str):
        """Forward the alert message to every client in the room"""
        # TODO: identify the source client.
        room = session.get("token")
        emit("room:alert", msg, to=room)

    @io.on("room:upload:file")
    def room_upload_file(data: dict):
        room = session.get("token")
        run_upload_file.delay(room, data)

    @io.on("room:lock:set")
    def room_lock_set(locked: bool):
        room = session.get("token")
        r: Redis = current_app.extensions["redis"]
        r.set(f"room:{room}:locked", str(locked))
        emit("room:lock:set", locked, to=room)

    @io.on("room:lock:get")
    def room_lock_get() -> bool:
        room = session.get("token")
        r: Redis = current_app.extensions["redis"]
        locked = r.get(f"room:{room}:locked")
        # Redis returns bytes unless the client decodes responses; accept both.
        return locked in ("True", b"True")

    @io.on("room:token:get")
    def room_token_get() -> str:
        return session.get("token")

    @io.on("zntrack:available")
    def check_zntrack_available() -> bool:
        return importlib.util.find_spec("zntrack") is not None

    @io.on("zntrack:list-stages")
    def zntrack_list_stages(data: dict):
        try:
            import dvc.api

            fs = dvc.api.DVCFileSystem(url=data.get("remote"), rev=data.get("rev"))
            return [x.name for x in fs.repo.stage.collect() if hasattr(x, "name")]
        except Exception:
            # Best effort: any failure is reported as "no stages".
            return []

    @io.on("zntrack:inspect-stage")
    def zntrack_inspect_stage(data: dict):
        return inspect_zntrack_node(**data)

    @io.on("zntrack:load-frames")
    def zntrack_load_frames(data: dict):
        load_zntrack_frames.delay(room=session.get("token"), **data)

    @io.on("zntrack:load-figure")
    def zntrack_load_figure(data: dict):
        load_zntrack_figures.delay(room=session.get("token"), **data)

    @io.on("room:copy")
    def room_copy():
        room = session.get("token")
        run_room_copy.delay(room)
229 |
--------------------------------------------------------------------------------
/zndraw/server/routes.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import uuid
3 | from io import BytesIO, StringIO
4 |
5 | import ase.io
6 | from flask import (
7 | Blueprint,
8 | current_app,
9 | redirect,
10 | request,
11 | send_file,
12 | send_from_directory,
13 | session,
14 | )
15 |
16 | main = Blueprint("main", __name__)
17 |
18 | log = logging.getLogger(__name__)
19 |
20 |
@main.route("/")
def index():
    """Render the main ZnDraw page.

    Ensures the session has a room token, applies the optional ``step`` and
    ``selection`` query arguments to that room and redirects to its URL.
    """
    if "token" not in session:
        session["token"] = uuid.uuid4().hex[:8]
    token = session["token"]

    query = request.args.to_dict()
    if query:
        from zndraw import ZnDraw

        vis = ZnDraw(
            r=current_app.extensions["redis"],
            url=current_app.config["SERVER_URL"],
            token=session.get("token"),
        )
        if "step" in query:
            vis.step = int(query["step"])
        if "selection" in query:
            raw_selection = query["selection"]
            if raw_selection == "null":
                vis.selection = []
            else:
                vis.selection = [int(i) for i in raw_selection.split(",")]

    if "APPLICATION_ROOT" in current_app.config:
        return redirect(f"{current_app.config['APPLICATION_ROOT']}token/{token}")
    return redirect(f"/token/{token}")
50 |
51 |
# The rule must bind `filename`; a bare "/" collides with `index` and makes
# Flask call this view without its required argument.
@main.route("/<path:filename>")
def main_files(filename):
    """Serve ``filename`` from the ``templates/`` directory."""
    return send_from_directory("templates/", filename)
55 |
56 |
# The rule must bind `filename`, otherwise the view is called without its
# required argument.
@main.route("/assets/<path:filename>")
def assets(filename):
    """Serve ``filename`` from the ``templates/assets`` directory."""
    return send_from_directory("templates/assets", filename)
60 |
61 |
# The rule must bind `token`, otherwise the view is called without its
# required argument.
@main.route("/token/<token>")
def token(token):
    """Store ``token`` as the session's room and serve the web app.

    The token ``"default"`` is rejected with a 403 response.
    """
    if token == "default":
        return "Invalid token", 403
    session["token"] = token
    return send_from_directory("templates", "index.html")
68 |
69 |
@main.route("/reset")
def reset():
    """Assign a fresh room token to the session and redirect to that room."""
    session["token"] = uuid.uuid4().hex[:8]  # TODO: how should reset work locally?
    if "APPLICATION_ROOT" in current_app.config:
        prefix = f"{current_app.config['APPLICATION_ROOT']}"
    else:
        prefix = "/"
    return redirect(f"{prefix}token/{session['token']}")
78 |
79 |
# The second rule must bind `token`; without the URL variable the auth token
# could never be passed through the URL.
@main.route("/exit")
@main.route("/exit/<token>")
def exit_route(token: str | None = None):
    """Exit the session.

    The server is stopped when the session is authenticated or when
    ``token`` equals the configured ``AUTH_TOKEN``; otherwise a 403 response
    is returned.
    """
    if not session.get("authenticated", False) and token != current_app.config.get(
        "AUTH_TOKEN"
    ):
        return "Invalid auth token", 403

    # Log only once the request is authorized, so rejected requests do not
    # leave a misleading "shutting down" record.
    log.critical("Server shutting down...")
    current_app.extensions["socketio"].stop()
    return "Server shutting down..."
92 |
93 |
# The rule must bind `auth_token`; without the URL variable no token could
# ever be supplied and every login would be rejected.
@main.route("/login/<auth_token>")
def login_route(auth_token: str | None = None):
    """Create an authenticated session.

    The session is authenticated when ``auth_token`` equals the configured
    ``AUTH_TOKEN``; otherwise a 403 response is returned.
    """
    session["authenticated"] = auth_token == current_app.config.get("AUTH_TOKEN", "NONE")
    if session["authenticated"]:
        if "APPLICATION_ROOT" in current_app.config:
            return redirect(f"{current_app.config['APPLICATION_ROOT']}")
        return redirect("/")
    return "Invalid auth token", 403
103 |
104 |
@main.route("/logout")
def logout_route():
    """Drop the authenticated flag of the session and redirect to the root.

    Sessions that are not authenticated receive a 403 response.
    """
    if not session.get("authenticated", False):
        return "Can only log out, if you logged in before.", 403

    session["authenticated"] = False
    if "APPLICATION_ROOT" in current_app.config:
        destination = f"{current_app.config['APPLICATION_ROOT']}"
    else:
        destination = "/"
    return redirect(destination)
113 |
114 |
@main.route("/upload", methods=["POST"])
def upload():
    """Upload a file to the server.

    The uploaded file is decoded as UTF-8, parsed with ``ase.io.iread``
    using the file extension as format, and appended to the session's room.
    Returns 401 without a session token and 500 with the error message when
    anything fails.
    """
    from zndraw import ZnDrawLocal

    file = request.files["file"]
    token = session.get("token")

    if not token:
        return "Unauthorized", 401

    try:
        # Extract the file format from the filename
        file_format = file.filename.split(".")[-1]
        file_content = file.read()

        stream = StringIO(file_content.decode("utf-8"))

        vis = ZnDrawLocal(
            r=current_app.extensions["redis"],
            url=current_app.config["SERVER_URL"],
            token=token,
        )
        try:
            structures = list(ase.io.iread(stream, format=file_format))
            vis.extend(structures)
        finally:
            # Always release the client connection, also when parsing or
            # extending fails.
            vis.socket.disconnect()

        return "File uploaded", 200

    except Exception as e:
        log.error(f"Error uploading file: {e}")
        return str(e), 500
147 |
148 |
@main.route("/download", methods=["GET"])
def download():
    """Download a file to the client.

    All frames of the session's room are written in xyz format and sent as
    ``trajectory.xyz``.  Errors while collecting frames are logged and the
    frames written so far are still sent.
    """
    from zndraw import ZnDrawLocal

    token = session.get("token")

    file = StringIO()
    vis = ZnDrawLocal(
        r=current_app.extensions["redis"],
        url=current_app.config["SERVER_URL"],
        token=token,
    )
    try:
        try:
            for atoms in vis:
                ase.io.write(file, atoms, format="xyz", append=True)
        finally:
            # Release the client connection, as `upload` does.
            vis.socket.disconnect()
    except Exception as e:
        log.error(f"Error downloading file: {e}")

    # convert StringIO to BytesIO
    file = BytesIO(file.getvalue().encode("utf-8"))
    try:
        return send_file(file, as_attachment=True, download_name="trajectory.xyz")
    except Exception as e:
        log.error(f"Error downloading file: {e}")
        return str(e), 500
175 |
--------------------------------------------------------------------------------
/zndraw/standalone.py:
--------------------------------------------------------------------------------
1 | """Utils for running ZnDraw standalone, without redis or external celery worker."""
2 |
3 | import os
4 | import platform
5 | import subprocess
6 | import threading
7 | import time
8 |
9 | import znsocket.exceptions
10 |
11 |
def run_znsocket(port) -> subprocess.Popen:
    """Run a znsocket server instead of redis.

    Starts ``znsocket --port <port>`` as a subprocess and waits until a
    client can connect (up to 1000 attempts, 0.1 s apart).

    Returns:
        The handle of the running server process.

    Raises:
        RuntimeError: if the server exits early or never accepts a
            connection.  The subprocess is terminated before raising.
    """

    server = subprocess.Popen(["znsocket", "--port", str(port)])

    for trial in range(1000):
        try:
            znsocket.Client.from_url(f"znsocket://localhost:{port}")
            break
        except znsocket.exceptions.ConnectionError:
            if server.poll() is not None:
                # The process already exited; waiting longer cannot help.
                raise RuntimeError("Unable to start ZnSocket server!")
            time.sleep(0.1)
            if trial % 10 == 0:
                print("Waiting for znsocket to start...")
    else:
        # Do not leave an unreachable server process behind.
        server.terminate()
        server.wait()
        raise RuntimeError("Unable to start ZnSocket server!")

    return server
29 |
30 |
def run_celery_thread_worker() -> threading.Thread:
    """Run a celery worker in a thread of the current process.

    Returns:
        The already started worker thread.
    """
    # NOTE(review): an environment copy with
    # OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES (Apple Silicon workaround) was
    # built here but never used, so it had no effect and has been removed.
    # The subprocess variant `run_celery_worker` does apply it.

    def run_celery_worker():
        from zndraw_app.make_celery import celery_app

        celery_app.worker_main(
            argv=["worker", "--loglevel=info", "--without-gossip", "--pool=eventlet"]
        )

    worker = threading.Thread(target=run_celery_worker)
    worker.start()
    return worker
48 |
49 |
# We use this for running tests for now
def run_celery_worker() -> subprocess.Popen:
    """Start a celery worker subprocess and return its handle."""
    worker_env = os.environ.copy()
    on_apple_silicon = platform.system() == "Darwin" and platform.processor() == "arm"
    if on_apple_silicon:
        # fix celery worker issue on apple silicon
        worker_env["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "YES"

    command = [
        "celery",
        "-A",
        "zndraw_app.make_celery",
        "worker",
        "--loglevel=info",
        "-P",
        "eventlet",
    ]
    return subprocess.Popen(command, env=worker_env)
71 |
--------------------------------------------------------------------------------
/zndraw/type_defs.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | if t.TYPE_CHECKING:
4 | from zndraw.base import Extension
5 |
6 |
class RegisterModifier(t.TypedDict):
    # NOTE(review): presumably the shape of the entries of `ZnDraw._modifiers`,
    # whose "cls", "run_kwargs" and "public" keys are read in zndraw/queue.py
    # -- confirm against the ZnDraw class.
    cls: t.Type["Extension"]  # extension class to instantiate
    run_kwargs: dict  # extra keyword arguments for the `run` call
    public: bool  # whether the modifier is served to the public queue
11 |
12 |
class TimeoutConfig(t.TypedDict):
    """Timeout configuration for the ZnDraw client."""

    # NOTE(review): units are not visible in this file -- presumably seconds; confirm.
    connection: int
    modifier: float
    between_calls: float

    # Retry counts.
    # NOTE(review): presumably forwarded as `retries` to emit_with_retry /
    # call_with_retry in zndraw/utils.py -- confirm.
    emit_retries: int
    call_retries: int
    connect_retries: int
23 |
24 |
class JupyterConfig(t.TypedDict):
    # Display size; each value may be a string or an integer.
    # NOTE(review): presumably CSS sizes or pixel counts -- confirm where it is used.
    width: str | int
    height: str | int
28 |
29 |
class CameraData(t.TypedDict):
    # Camera state as two lists of floats, e.g.
    # {"position": [0, 0, 20], "target": [0, 0, 0]} as set by NewScene.
    position: list[float]
    target: list[float]
33 |
34 |
class ASEDict(t.TypedDict):
    # Dictionary form of an atoms object; it is the `value` payload of ASEJson.
    numbers: list[int]
    positions: list[list[float]]
    # Each entry is a 3-tuple of ints.
    # NOTE(review): presumably (atom_a, atom_b, bond order) -- confirm against
    # the code that produces it.
    connectivity: list[tuple[int, int, int]]
    arrays: dict[str, list[float | int | list[float | int]]]
    info: dict[str, float | int]
    # calc: dict[str, float|int|np.ndarray] # should this be split into arrays and info?
    pbc: list[bool]
    cell: list[list[float]]
    vectors: list[list[list[float]]]
    constraints: list[dict]
46 |
47 |
class ASEJson(t.TypedDict):
    # Tagged wrapper: `_type` is always the literal "ase.Atoms" and `value`
    # carries the ASEDict payload.
    _type: t.Literal["ase.Atoms"]
    value: ASEDict
51 |
52 |
53 | # Type hint is string, but correctly it is 'json.dumps(ASEJson)'
54 | ATOMS_LIKE = t.Union[ASEDict, str]
55 |
--------------------------------------------------------------------------------
/zndraw/upload.py:
--------------------------------------------------------------------------------
1 | """zndraw-upload API"""
2 |
3 | import typing as t
4 | import uuid
5 | import webbrowser
6 |
7 | import typer
8 |
9 | from zndraw import ZnDraw
10 |
11 | from .tasks import FileIO, get_generator_from_filename
12 | from .utils import load_plots_to_dict
13 |
14 | cli = typer.Typer()
15 |
16 |
def upload(
    url: str,
    token: t.Optional[str],
    fileio: FileIO,
    append: bool,
    plots: list[str],
    browser: bool,
    batch_size: int = 16,
):
    """Upload a file to ZnDraw.

    Frames are read from ``fileio`` and sent to the room in batches of
    ``batch_size``.  Unless ``append`` is set, existing frames are removed
    first.  A random token is generated when ``token`` is ``None``, and the
    room is opened in a web browser when ``browser`` is set.  The figures
    listed in ``plots`` are added at the end.
    """
    if token is None:
        token = str(uuid.uuid4())
    vis = ZnDraw(url=url, token=token, convert_nan=fileio.convert_nan)
    typer.echo(f"Uploading to: {url}/token/{vis.token}")

    if not append:
        del vis[:]

    generator = get_generator_from_filename(fileio)

    if browser:
        webbrowser.open(f"{url}/token/{vis.token}")

    batch = []
    for frame in generator:
        batch.append(frame)
        if len(batch) == batch_size:
            # Flush a full batch and start collecting the next one.
            vis.extend(batch)
            batch = []
    # Send whatever is left over.
    vis.extend(batch)

    figures = load_plots_to_dict(plots, fileio.remote, fileio.rev)
    vis.figures.update(figures)
49 |
--------------------------------------------------------------------------------
/zndraw/utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 | import pathlib
4 | import socket
5 | import typing as t
6 | import urllib.parse
7 | from urllib.parse import urlparse
8 |
9 | import numpy as np
10 | import plotly.graph_objects as go
11 | import plotly.graph_objs
12 | import socketio.exceptions
13 | import znjson
14 | from ase.data import covalent_radii
15 |
16 | log = logging.getLogger(__name__)
17 |
18 |
def parse_url(input_url) -> t.Tuple[str, t.Optional[str]]:
    """Split a URL into ``scheme://netloc`` and its path.

    The path is returned without leading or trailing slashes, or as
    ``None`` when it is empty.
    """
    parts = urlparse(input_url)
    origin = f"{parts.scheme}://{parts.netloc}"
    trimmed = parts.path.strip("/") if parts.path else None
    if not trimmed:
        return origin, None
    return origin, trimmed
24 |
25 |
def rgb2hex(value):
    """Convert an RGB triple with components in [0, 1] to ``#rrggbb``.

    ``value`` may be a numpy array, list or tuple of three numbers.  The
    scaled components are truncated towards zero, not rounded.
    """
    # np.asarray: multiplying a list or tuple by 255 would repeat it instead
    # of scaling its components.
    r, g, b = (np.asarray(value) * 255).astype(int).tolist()
    return "#%02x%02x%02x" % (r, g, b)
29 |
30 |
@functools.lru_cache(maxsize=128)
def get_scaled_radii() -> np.ndarray:
    """Scale down the covalent radii to visualize bonds better.

    The radii are shifted and normalised to [0, 1] and then offset by 0.3,
    so the returned values lie in [0.3, 1.3].
    """
    shifted = covalent_radii - np.min(covalent_radii)
    normalised = shifted / np.max(shifted)
    return normalised + 0.3
40 |
41 |
def get_port(default: int) -> int:
    """Get an open port.

    Returns ``default`` when it can be bound, otherwise a free port chosen
    by the operating system.  The probing sockets are always closed, so the
    port is only known to have been free at the time of the call.
    """
    try:
        with socket.socket() as sock:
            sock.bind(("", default))
        return default
    except OSError:
        # `default` is unavailable: bind to port 0 so the OS picks a free one.
        with socket.socket() as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]
55 |
56 |
class ZnDrawLoggingHandler(logging.Handler):
    """Logging handler which emits log messages to the ZnDraw server."""

    def __init__(self, zndraw):
        super().__init__()
        self.zndraw = zndraw

    def emit(self, record):
        """Format ``record`` and pass the message to ``zndraw.log``.

        A ``RecursionError`` is re-raised; any other exception is handed to
        ``handleError`` after printing a short notice.
        """
        try:
            self.zndraw.log(self.format(record))
        except RecursionError:  # See StreamHandler
            raise
        except Exception:
            print("Something went wrong")
            self.handleError(record)
73 |
74 |
def emit_with_retry(
    socket: socketio.Client,
    event,
    data=None,
    namespace=None,
    callback=None,
    retries: int = 1,
    delay: float = 0.1,
    increase_delay: float = 0.1,
    reconnect: bool = False,
) -> None:
    """Emit data to a socket with retries.

    ``retries`` is the total number of attempts; at least one attempt is
    always made.  After a failed attempt the function sleeps for ``delay``
    and increases ``delay`` by ``increase_delay``.  The exception of the
    last attempt is re-raised.

    Raises:
        ValueError: if ``reconnect`` is set and a ``BadNamespaceError``
            occurs (reconnecting is not implemented).
    """
    # With `retries <= 0` nothing was emitted at all, silently.
    attempts = max(1, retries)
    for idx in range(attempts):
        try:
            socket.emit(event=event, data=data, namespace=namespace, callback=callback)
            break
        except socketio.exceptions.BadNamespaceError as err:
            log.error(f"Retrying {event} due to {err}")
            if idx == attempts - 1:
                raise
            if reconnect:
                raise ValueError("Reconnect not implemented yet")
        except Exception:
            if idx == attempts - 1:
                raise
        socket.sleep(delay)
        delay += increase_delay
102 |
103 |
def call_with_retry(
    socket: socketio.Client,
    event,
    data=None,
    namespace=None,
    timeout=60,
    retries: int = 1,
    delay: float = 0.1,
    increase_delay: float = 0.1,
    reconnect: bool = False,
) -> t.Any:
    """Run ``socket.call`` for ``event``, trying up to ``retries`` times.

    Returns whatever ``socket.call`` returns on the first successful
    attempt. After every failed attempt except the last one the function
    sleeps for ``delay`` seconds (via ``socket.sleep``) and then grows
    ``delay`` by ``increase_delay``. The exception of the final attempt is
    re-raised.

    ``TimeoutError`` and ``BadNamespaceError`` are additionally logged; if
    ``reconnect`` is set and further attempts remain, a ``ValueError`` is
    raised because reconnecting is not implemented.

    ``None`` is returned when ``retries`` leaves no attempt to make.
    """
    final_attempt = retries - 1
    for attempt in range(retries):
        try:
            response = socket.call(
                event=event, data=data, namespace=namespace, timeout=timeout
            )
        except (
            socketio.exceptions.TimeoutError,
            socketio.exceptions.BadNamespaceError,
        ) as err:
            log.error(f"Retrying {event} due to {err} ({attempt} / {retries})")
            if attempt == final_attempt:
                raise err
            if reconnect:
                raise ValueError("Reconnect not implemented yet")
        except Exception as err:
            if attempt == final_attempt:
                raise err
        else:
            return response
        # Back off before the next attempt, waiting a little longer each time.
        socket.sleep(delay)
        delay += increase_delay
    return None
136 |
137 |
def direction_to_euler(direction, roll=0):
    """
    Turn a direction vector into ``[yaw, pitch, roll]`` euler angles.

    The vector is normalised first. Yaw is the angle of the vector in the
    x/y plane, pitch its elevation above that plane. A direction does not
    determine a roll angle, so ``roll`` is passed through unchanged.
    """
    vector = np.array(direction)
    unit = vector / np.linalg.norm(vector)
    x, y, z = unit

    in_plane_length = np.sqrt(x**2 + y**2)
    return np.array([np.arctan2(y, x), np.arctan2(z, in_plane_length), roll])
153 |
154 |
def euler_to_direction(angles):
    """
    Turn ``(yaw, pitch, roll)`` euler angles into a unit direction vector.

    The roll angle has no influence on the direction and is ignored.
    """
    yaw, pitch, _roll = angles
    cos_pitch = np.cos(pitch)

    return np.array(
        [
            np.cos(yaw) * cos_pitch,
            np.sin(yaw) * cos_pitch,
            np.sin(pitch),
        ]
    )
168 |
169 |
def convert_url_to_http(url: str) -> str:
    """Rewrite a WebSocket URL to the matching HTTP(S) URL.

    The scheme ``wss`` becomes ``https`` and ``ws`` becomes ``http``. URLs
    with any other scheme keep their scheme; they are still passed through
    ``urlparse``/``urlunparse``.

    Parameters
    ----------
    url:
        The URL to convert.

    Returns
    -------
    str
        The URL with the scheme replaced where applicable.
    """
    # Keep the parsed result in its own name instead of rebinding the ``str``
    # parameter to a ``ParseResult``.
    parsed = urllib.parse.urlparse(url)
    http_scheme = {"wss": "https", "ws": "http"}.get(parsed.scheme)
    if http_scheme is not None:
        parsed = parsed._replace(scheme=http_scheme)

    return urllib.parse.urlunparse(parsed)
179 |
180 |
def get_schema_with_instance_defaults(self) -> dict:
    """Return the JSON schema with defaults taken from the instance values.

    The schema comes from ``self.get_updated_schema()``; if that raises
    ``AttributeError`` the plain ``self.model_json_schema()`` is used. For
    every instance attribute that also appears under ``"properties"`` the
    property's ``"default"`` is overwritten with the current value.
    """
    try:
        schema = self.get_updated_schema()
    except AttributeError:
        schema = self.model_json_schema()

    for name, current in vars(self).items():
        if name not in schema["properties"]:
            continue
        schema["properties"][name]["default"] = current
    return schema
191 |
192 |
193 | def get_plots_from_zntrack(path: str, remote: str | None, rev: str | None):
194 | node_name, attribute = path.split(".", 1)
195 | try:
196 | import os
197 |
198 | ## FIX for zntrack bug https://github.com/zincware/ZnTrack/issues/806
199 | import sys
200 |
201 | import zntrack
202 |
203 | sys.path.insert(0, os.getcwd())
204 | ##
205 |
206 | node = zntrack.from_rev(node_name, remote=remote, rev=rev)
207 | return getattr(node, attribute)
208 | except ImportError as err:
209 | raise ImportError(
210 | "You need to install ZnTrack to use the remote feature."
211 | ) from err
212 |
213 |
def load_plots_to_dict(
    paths: list[str], remote: str | None, rev: str | None
) -> dict[str, go.Figure]:
    """Collect plotly figures from files or ZnTrack nodes into one dict.

    For every entry of ``paths``:

    * if it names an existing file, the file content is decoded with
      ``znjson.loads``;
    * otherwise, if ``remote`` or ``rev`` is given, the entry is resolved
      through ``get_plots_from_zntrack``;
    * otherwise a ``FileNotFoundError`` is raised.

    The loaded object may be a single figure (stored under ``path``), a dict
    of figures (stored under ``f"{path}_{key}"``) or a list of figures
    (stored under ``f"{path}_{index}"``).

    Raises
    ------
    FileNotFoundError
        If an entry is not an existing file and neither ``remote`` nor
        ``rev`` is set.
    ValueError
        If the loaded object is not a figure, dict or list, or if a dict or
        list contains a value that is not a figure.
    """
    data = {}
    for path in paths:
        file = pathlib.Path(path)
        if file.exists():
            plots = znjson.loads(file.read_text())
        elif remote is not None or rev is not None:
            plots = get_plots_from_zntrack(path, remote, rev)
        else:
            raise FileNotFoundError(f"File {path} does not exist")

        if isinstance(plots, plotly.graph_objs.Figure):
            data[path] = plots
        elif isinstance(plots, dict):
            if not all(isinstance(v, plotly.graph_objs.Figure) for v in plots.values()):
                raise ValueError("All values in the plots dict must be plotly.graph_objs")
            data.update({f"{path}_{k}": v for k, v in plots.items()})
        elif isinstance(plots, list):
            if not all(isinstance(v, plotly.graph_objs.Figure) for v in plots):
                raise ValueError("All values in the plots list must be plotly.graph_objs")
            data.update({f"{path}_{i}": v for i, v in enumerate(plots)})
        else:
            raise ValueError("The plots must be a dict, list or Figure")

    return data
240 |
241 |
def update_figure_layout(fig: go.Figure) -> None:
    """Apply the common colours and grid styling to ``fig``.

    Sets the plot/paper background and font colour through
    ``fig.update_layout`` and gives both axes the same grid settings.
    Nothing is returned.
    """
    # NOTE(review): the font colour is written as ``rgb(...)`` with four
    # components, and the colour itself is the same blue as the background
    # tint rather than a gray -- confirm whether ``rgba`` was intended.
    fig.update_layout(
        plot_bgcolor="rgba(64, 128, 230, 0.05)",
        paper_bgcolor="rgba(255, 255, 255, 0)",
        font_color="rgb(64, 128, 230, 0.75)",
    )

    # Both axes share identical grid settings.
    grid_style = {
        "showgrid": True,
        "gridwidth": 1,
        "gridcolor": "rgba(120, 120, 120, 0.3)",
        "zeroline": False,
    }
    fig.update_xaxes(**grid_style)
    fig.update_yaxes(**grid_style)
254 |
--------------------------------------------------------------------------------
/zndraw_app/README.md:
--------------------------------------------------------------------------------
1 | # ZnDraw App
2 |
To provide access to the `zndraw.ZnDraw` class without calling `eventlet.monkey_patch()`, all the features that rely on `monkey_patch` have been extracted into this minimal package.
4 |
--------------------------------------------------------------------------------
/zndraw_app/__init__.py:
--------------------------------------------------------------------------------
# The monkey patch is applied before the remaining imports below; keep this
# statement order when editing the module.
import eventlet

eventlet.monkey_patch()

import os

from engineio.payload import Payload

# Optional override: when ENGINEIO_PAYLOAD_MAX_DECODE_PACKETS is set to a
# non-empty value, its integer value replaces ``Payload.max_decode_packets``.
if max_decode_packets := os.environ.get("ENGINEIO_PAYLOAD_MAX_DECODE_PACKETS"):
    Payload.max_decode_packets = int(max_decode_packets)
11 |
--------------------------------------------------------------------------------
/zndraw_app/cli.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import datetime
3 | import os
4 | import pathlib
5 | import shutil
6 | import signal
7 | import typing as t
8 | import webbrowser
9 |
10 | import typer
11 |
12 | from zndraw.app import create_app
13 | from zndraw.base import FileIO
14 | from zndraw.standalone import run_celery_thread_worker, run_znsocket
15 | from zndraw.tasks import read_file, read_plots
16 | from zndraw.upload import upload
17 | from zndraw.utils import get_port
18 | from zndraw_app.healthcheck import run_healthcheck
19 |
# Typer application object; commands are registered on it via ``@cli.command()``.
cli = typer.Typer()
21 |
22 |
23 | @dataclasses.dataclass
24 | class EnvOptions:
25 | FLASK_PORT: str | None = None
26 | FLASK_STORAGE: str | None = None
27 | FLASK_AUTH_TOKEN: str | None = None
28 | FLASK_TUTORIAL: str | None = None
29 | FLASK_SIMGEN: str | None = None
30 | FLASK_SERVER_URL: str | None = None
31 | FLASK_STORAGE_PORT: str | None = None
32 | FLASK_COMPUTE_BONDS: str | None = None
33 | FLASK_MAX_HTTP_BUFFER_SIZE: str | None = None
34 |
35 | @classmethod
36 | def from_env(cls):
37 | return cls(
38 | **{
39 | field.name: os.environ.get(field.name)
40 | for field in dataclasses.fields(cls)
41 | }
42 | )
43 |
44 | def save_to_env(self):
45 | for field in dataclasses.fields(self):
46 | value = getattr(self, field.name)
47 | if value is not None:
48 | os.environ[field.name] = value
49 |
50 |
@cli.command()
def main(
    filename: t.Optional[str] = typer.Argument(
        None,
        help="Path to the file which should be visualized in ZnDraw. Can also be the name and attribute of a ZnTrack Node like 'MyNode.atoms' if at least '--remote .' is provided. ",
    ),
    url: t.Optional[str] = typer.Option(
        None,
        help="URL to a running ZnDraw server. Use this server instead of starting a new one.",
        envvar="ZNDRAW_URL",
    ),
    append: bool = typer.Option(
        False, help="Append the file to the existing data on the server."
    ),
    token: t.Optional[str] = typer.Option(
        None, help="Only valid if 'url' is provided. Room token to upload the file to."
    ),
    port: int = typer.Option(
        None, help="""Port to use for the ZnDraw server. Default port is 1234"""
    ),
    browser: bool = typer.Option(
        True, help="""Whether to open the ZnDraw GUI in the default web browser."""
    ),
    start: int = typer.Option(
        None,
        help="""First frame to be visualized. If set to 0, the first frame will be visualized.""",
    ),
    stop: int = typer.Option(
        None,
        help="""Last frame to be visualized. If set to None, the last frame will be visualized.""",
    ),
    step: int = typer.Option(
        None,
        help="""Stepsize for the frames to be visualized. If set to 1, all frames will be visualized.
        If e.g. set to 2, every second frame will be visualized.""",
    ),
    remote: str = typer.Option(
        None,
        help="URL to a ZnTrack repository to stream data from.",
    ),
    rev: str = typer.Option(
        None,
        help="Revision of the ZnTrack repository to stream data from.",
    ),
    tutorial: str = typer.Option(
        None,
        help="Show the tutorial from the URL inside an IFrame.",
    ),
    auth_token: str = typer.Option(
        None,
        help="Token to authenticate pyclient requests to the ZnDraw server, e.g., for adding defaults to all webclients.",
    ),
    simgen: bool = typer.Option(
        False,
        help="Show the SiMGen demo UI.",
    ),
    storage: str = typer.Option(
        None,
        help="URL to the redis `redis://localhost:6379/0` or znsocket `znsocket://127.0.0.1:6379` server. If None is provided, a local znsocket server will be started.",
    ),
    storage_port: int = typer.Option(
        None, help="Port to use for the storage server. Default port is 6374"
    ),
    standalone: bool = typer.Option(
        True,
        help="Run ZnDraw without additional tools. If disabled, redis and celery must be started manually.",
    ),
    bonds: bool = typer.Option(
        True,
        help="Compute bonds based on covalent distances. This can be slow for large structures.",
    ),
    max_http_buffer_size: int = typer.Option(
        None, help="Maximum size of the HTTP buffer in bytes. Default is 1MB."
    ),
    plots: list[str] = typer.Option(
        None, "--plots", "-p", help="List of plots to be shown in the ZnDraw GUI."
    ),
    convert_nan: bool = typer.Option(
        False,
        help="Convert NaN values to None. This is slow and experimental, but if your file contains NaN/inf values, it is required.",
        envvar="ZNDRAW_CONVERT_NAN",
    ),
    healthcheck: bool = typer.Option(False, help="Run the healthcheck."),
):
    """Start the ZnDraw server.

    Visualize Trajectories, Structures, and more in ZnDraw.
    """
    # NOTE: the docstring above is kept short on purpose; Typer shows it as the
    # command's help text, so implementation notes live in comments instead.

    # --- argument validation -------------------------------------------------
    if healthcheck:
        if url is None:
            raise ValueError("You need to provide a URL to use the healthcheck feature.")
        # run_healthcheck() terminates the process via sys.exit().
        run_healthcheck(url)
    if plots is None:
        plots = []
    if token is not None and url is None:
        raise ValueError("You need to provide a URL to use the token feature.")
    if url is not None and port is not None:
        raise ValueError(
            "You cannot provide a URL and a port at the same time. Use something like '--url http://localhost:1234' instead."
        )

    # --- resolve the configuration -------------------------------------------
    # Precedence for each setting: CLI option, then an already-set environment
    # variable, then (for the two ports) a free port near the documented default.
    env_config = EnvOptions.from_env()

    if storage_port is not None:
        env_config.FLASK_STORAGE_PORT = str(storage_port)
    elif env_config.FLASK_STORAGE_PORT is None:
        env_config.FLASK_STORAGE_PORT = str(get_port(default=6374))

    if port is not None:
        env_config.FLASK_PORT = str(port)
    elif env_config.FLASK_PORT is None:
        env_config.FLASK_PORT = str(get_port(default=1234))
    if storage is not None:
        env_config.FLASK_STORAGE = storage
    if auth_token is not None:
        env_config.FLASK_AUTH_TOKEN = auth_token
    if tutorial is not None:
        env_config.FLASK_TUTORIAL = tutorial
    if simgen:
        env_config.FLASK_SIMGEN = "TRUE"
    if bonds:
        env_config.FLASK_COMPUTE_BONDS = "TRUE"
    if max_http_buffer_size is not None:
        env_config.FLASK_MAX_HTTP_BUFFER_SIZE = str(int(max_http_buffer_size))

    env_config.FLASK_SERVER_URL = f"http://localhost:{env_config.FLASK_PORT}"

    if standalone and storage is None:
        env_config.FLASK_STORAGE = f"znsocket://localhost:{env_config.FLASK_STORAGE_PORT}"

    # Export the resolved settings through the environment.
    env_config.save_to_env()

    # Without --remote/--rev the positional argument must be an existing file.
    if remote is None and rev is None and filename is not None:
        if not pathlib.Path(filename).exists():
            typer.echo(f"File {filename} does not exist.")
            raise typer.Exit(code=1)

    # --- start the local helper services -------------------------------------
    # Only done when no external server is targeted. The same flag guards the
    # shutdown handler below, so ``server``/``worker`` are never touched when
    # they were not created.
    local_services = standalone and url is None
    if local_services:
        if env_config.FLASK_STORAGE.startswith("znsocket"):
            # standalone with redis would assume a running instance of redis
            server = run_znsocket(env_config.FLASK_STORAGE_PORT)
        worker = run_celery_thread_worker()

    fileio = FileIO(
        name=filename,
        remote=remote,
        rev=rev,
        start=start,
        stop=stop,
        step=step,
        convert_nan=convert_nan,
    )

    # With --url the data is uploaded to the existing server and we are done.
    if url is not None:
        upload(url, token, fileio, append, plots, browser)
        return

    # Report the resolved port: the ``port`` option itself is None unless it
    # was passed explicitly on the command line.
    typer.echo(
        f"{datetime.datetime.now().isoformat()}: Starting zndraw server on port {env_config.FLASK_PORT}"
    )

    app = create_app()

    if browser:
        webbrowser.open(f"http://localhost:{env_config.FLASK_PORT}")

    socketio = app.extensions["socketio"]

    def signal_handler(sig, frame):
        # Shut down the helper services (if we started them), then the web
        # server, and finally wait for the worker thread.
        if local_services:
            print("---------------------- SHUTDOWN CELERY ----------------------")
            celery_app = app.extensions["celery"]
            celery_app.control.broadcast("shutdown")
            print("---------------------- SHUTDOWN ZNSOCKET ----------------------")
            if env_config.FLASK_STORAGE.startswith("znsocket"):
                server.terminate()
                server.wait()
                print("znsocket server terminated.")
        socketio.stop()
        if local_services:
            worker.join()

    signal.signal(
        signal.SIGINT, signal_handler
    )  # need to have the signal handler to avoid stalling the celery worker

    # Queue the initial data loading as celery tasks.
    read_file.s(fileio.to_dict()).apply_async()
    read_plots.s(plots, fileio.remote, fileio.rev).apply_async()

    try:
        socketio.run(
            app,
            host="0.0.0.0",
            port=app.config["PORT"],
        )
    finally:
        # get the celery broker config
        if app.config["CELERY"]["broker_url"] == "filesystem://":
            # Remove the directories used by the filesystem broker.
            print("---------------------- REMOVE CELERY CTRL ----------------------")
            for path in app.config["CELERY"]["broker_transport_options"].values():
                if os.path.exists(path):
                    shutil.rmtree(path)
252 |
--------------------------------------------------------------------------------
/zndraw_app/healthcheck.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 |
def run_healthcheck(server: str):
    """Check the health of a server.

    Creates a ``ZnDraw`` client for ``server`` with the token
    ``"healthcheck"`` and then exits the process with status 0. If creating
    the client raises, that exception propagates instead.
    """
    # Imported lazily so that importing this module does not import zndraw.
    from zndraw import ZnDraw

    client = ZnDraw(url=server, token="healthcheck")
    sys.exit(0)
10 |
--------------------------------------------------------------------------------
/zndraw_app/make_celery.py:
--------------------------------------------------------------------------------
1 | from zndraw import tasks # noqa used for registering tasks at the moment
2 | from zndraw.app import create_app
3 |
# Module-level application objects: the Flask app is created on import and the
# Celery app is taken from the "celery" extension registered on it.
flask_app = create_app()
celery_app = flask_app.extensions["celery"]
6 |
--------------------------------------------------------------------------------