├── examples
│   ├── chainlit.md
│   ├── public
│   │   ├── favicon.png
│   │   ├── logo_light.png
│   │   ├── hide-watermark.js
│   │   ├── theme.json
│   │   ├── elements
│   │   │   └── DataDisplay.jsx
│   │   └── logo_light.svg
│   ├── particle_trajectory_analysis.png
│   ├── README.md
│   ├── particle_trajectory_analysis.yaml
│   ├── .chainlit
│   │   └── config.toml
│   ├── writing_improvement.py
│   └── trajectory_analysis.py
├── assets
│   └── logo.jpg
├── pyproject.toml
├── requirements.txt
├── MANIFEST.in
├── CITATION.cff
├── setup.py
├── LICENSE
├── nodeology
│   ├── __init__.py
│   ├── log.py
│   ├── client.py
│   ├── state.py
│   ├── interface.py
│   └── node.py
├── CONTRIBUTING.md
├── .gitignore
├── README.md
└── tests
    ├── test_state.py
    └── test_node.py
/examples/chainlit.md:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyin-anl/Nodeology/HEAD/assets/logo.jpg
--------------------------------------------------------------------------------
/examples/public/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyin-anl/Nodeology/HEAD/examples/public/favicon.png
--------------------------------------------------------------------------------
/examples/public/logo_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyin-anl/Nodeology/HEAD/examples/public/logo_light.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "wheel"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/examples/particle_trajectory_analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyin-anl/Nodeology/HEAD/examples/particle_trajectory_analysis.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | pyyaml
3 | typing-extensions
4 | numpy
5 | plotly
6 | kaleido
7 | langgraph<=0.2.45
8 | litellm>=1.0.0
9 | langfuse>=2.0.0
10 | chainlit>=2.0.0
11 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include README.md
3 | include requirements.txt
4 | include CITATION.cff
5 | include CONTRIBUTING.md
6 | recursive-include examples *
7 | recursive-include tests *
8 | recursive-include nodeology *
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 |   - family-names: "Yin"
5 |     given-names: "Xiangyu"
6 |     orcid: "https://orcid.org/0000-0003-2868-1728"
7 | title: "Nodeology"
8 | version: 0.0.1
9 | date-released: 2024-11-20
10 | url: "https://github.com/xyin-anl/Nodeology"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 |     long_description = fh.read()
5 |
6 | with open("requirements.txt", "r", encoding="utf-8") as fh:
7 |     requirements = [
8 |         line.strip() for line in fh if line.strip() and not line.startswith("#")
9 |     ]
10 |
11 | setup(
12 |     name="nodeology",
13 |     version="0.0.2",
14 |     author="Xiangyu Yin",
15 |     author_email="xyin@anl.gov",
16 |     description="Foundation AI-Enhanced Scientific Workflow",
17 |     long_description=long_description,
18 |     long_description_content_type="text/markdown",
19 |     url="https://github.com/xyin-anl/nodeology",
20 |     packages=find_packages(),
21 |     include_package_data=True,
22 |     package_data={
23 |         "": ["*.md", "*.cff", "LICENSE"],
24 |         "nodeology": ["examples/*", "tests/*"],
25 |     },
26 |     classifiers=[
27 |         "Intended Audience :: Science/Research",
28 |         "Programming Language :: Python :: 3",
29 |     ],
30 |     python_requires=">=3.10",
31 |     install_requires=requirements,
32 | )
33 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Nodeology Examples
2 |
3 | This directory contains example applications built with Nodeology, demonstrating features of the framework.
4 |
5 | ## Prerequisites
6 |
7 | Before running the examples, ensure you have `nodeology` installed:
8 |
9 | ```bash
10 | pip install nodeology
11 | ```
12 |
13 | ## Directory Structure
14 |
15 | - `writing_improvement.py` - Text analysis and improvement workflow
16 | - `trajectory_analysis.py` - Particle trajectory simulation and visualization
17 | - `public/` - Static assets for the examples (needed for `nodeology` UI elements)
18 | - `.chainlit/` - Chainlit configuration files (needed for `nodeology` UI settings)
19 |
20 | ## Available Examples
21 |
22 | ### 1. Writing Improvement (`writing_improvement.py`)
23 |
24 | An interactive application that helps users improve their writing through analysis and suggestions. This example demonstrates:
25 |
26 | - State management with Nodeology
27 | - Interactive user input handling
28 | - Text analysis workflow
29 | - Chainlit UI integration
30 |
31 | To run this example:
32 |
33 | ```bash
34 | cd examples
35 | python writing_improvement.py
36 | ```
37 |
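Under the hood, the workflow reads and writes a small typed state (copied from `writing_improvement.py`):

```python
from nodeology.state import State

class TextAnalysisState(State):
    analysis: dict  # Analysis results
    text: str  # Enhanced text
    continue_improving: bool  # Whether to continue improving
```
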
38 | ### 2. Particle Trajectory Analysis (`trajectory_analysis.py`)
39 |
40 | A scientific application that simulates and visualizes particle trajectories under electromagnetic fields. This example showcases:
41 |
42 | - Complex scientific calculations
43 | - Interactive parameter input
44 | - Data visualization
45 | - State management for scientific workflows
46 | - Advanced Chainlit UI features
47 |
48 | To run this example:
49 |
50 | ```bash
51 | cd examples
52 | python trajectory_analysis.py
53 | ```
54 |
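Both examples loop until the user chooses to stop. The branching is a conditional flow keyed on a boolean state field; this fragment is adapted from `writing_improvement.py` (it runs inside `create_workflow`, not standalone):

```python
from langgraph.graph import END

# Go back to "analyze" while continue_improving is True, otherwise finish
self.add_conditional_flow(
    "ask_continue",          # source node
    "continue_improving",    # boolean state field to test
    "analyze",               # next node when True
    END,                     # end the workflow when False
)
```
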
55 | ## Usage Tips
56 |
57 | 1. Each example will open in your default web browser when launched
58 | 2. Follow the interactive prompts in the Chainlit UI
59 | 3. You can modify parameters and experiment with different inputs
60 | 4. Use the chat interface to interact with the applications
61 |
62 | ## License
63 |
64 | These examples are provided under the same license as the main Nodeology project. See the license headers in individual files for details.
65 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
2 |
3 | Copyright 2024. UChicago Argonne, LLC. This software was produced
4 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
5 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
6 | U.S. Department of Energy. The U.S. Government has rights to use,
7 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
8 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
9 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
10 | modified to produce derivative works, such modified software should
11 | be clearly marked, so as not to confuse it with the version available
12 | from ANL.
13 |
14 | Additionally, redistribution and use in source and binary forms, with
15 | or without modification, are permitted provided that the following
16 | conditions are met:
17 |
18 | * Redistributions of source code must retain the above copyright
19 | notice, this list of conditions and the following disclaimer.
20 |
21 | * Redistributions in binary form must reproduce the above copyright
22 | notice, this list of conditions and the following disclaimer in
23 | the documentation and/or other materials provided with the
24 | distribution.
25 |
26 | * Neither the name of UChicago Argonne, LLC, Argonne National
27 | Laboratory, ANL, the U.S. Government, nor the names of its
28 | contributors may be used to endorse or promote products derived
29 | from this software without specific prior written permission.
30 |
31 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
32 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
33 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
34 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
35 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
36 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
37 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
38 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
39 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
40 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
41 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
42 | POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/nodeology/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
3 |
4 | Copyright 2024. UChicago Argonne, LLC. This software was produced
5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
7 | U.S. Department of Energy. The U.S. Government has rights to use,
8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
11 | modified to produce derivative works, such modified software should
12 | be clearly marked, so as not to confuse it with the version available
13 | from ANL.
14 |
15 | Additionally, redistribution and use in source and binary forms, with
16 | or without modification, are permitted provided that the following
17 | conditions are met:
18 |
19 | * Redistributions of source code must retain the above copyright
20 | notice, this list of conditions and the following disclaimer.
21 |
22 | * Redistributions in binary form must reproduce the above copyright
23 | notice, this list of conditions and the following disclaimer in
24 | the documentation and/or other materials provided with the
25 | distribution.
26 |
27 | * Neither the name of UChicago Argonne, LLC, Argonne National
28 | Laboratory, ANL, the U.S. Government, nor the names of its
29 | contributors may be used to endorse or promote products derived
30 | from this software without specific prior written permission.
31 |
32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
43 | POSSIBILITY OF SUCH DAMAGE.
44 | """
45 |
46 | ### Initial Author <2024>: Xiangyu Yin
47 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Nodeology
2 |
3 | Thank you for your interest in contributing to Nodeology! This document provides guidelines and information about contributing to the project.
4 |
5 | ## Ways to Contribute
6 |
7 | ### Code Contributions
8 |
9 | 1. **Core Framework**
10 |
11 | - Bug fixes and improvements
12 | - Performance optimizations
13 | - New features
14 | - Test coverage
15 |
16 | 2. **Pre-built Components**
17 | - New node types
18 | - State definitions
19 | - Workflow templates
20 | - Domain-specific tools
21 |
22 | ### Documentation
23 |
24 | - API documentation
25 | - Usage examples
26 | - Tutorials
27 | - Best practices
28 |
29 | ### Research Collaborations
30 |
31 | 1. **Workflow Patterns**
32 |
33 | - Novel automation patterns
34 | - Optimization strategies
35 | - Human-AI interaction interfaces
36 | - Error handling approaches
37 |
38 | 2. **Scientific Integration**
39 |
40 | - Domain-specific applications
41 | - Instrument interfaces
42 | - Data processing pipelines
43 | - Analysis tools
44 |
45 | 3. **Evaluation Methods**
46 |
47 | - Benchmark development
48 | - Performance metrics
49 | - Reliability assessment
50 | - Comparison frameworks
51 |
52 | 4. **AI Integration**
53 | - Prompt optimization
54 | - Model evaluation
55 | - Hybrid processing
56 | - Knowledge integration
57 |
58 | ## Getting Started
59 |
60 | 1. **Development Setup**
61 |
62 | ```bash
63 | # Clone repository
64 | git clone https://github.com/xyin-anl/nodeology.git
65 | cd nodeology
66 |
67 | # Create virtual environment using venv or conda
68 | python -m venv venv
69 | source venv/bin/activate
70 |
71 | # Install dependencies
72 | pip install -r requirements.txt
73 |
74 | # Install pytest
75 | pip install pytest
76 |
77 | # Run tests
78 | pytest tests/
79 | ```
80 |
81 | ## Guidelines
82 |
83 | 1. **Code Contributions**
84 |
85 | - Fork repository
86 | - Create feature branch
87 | - Submit pull request
88 | - Use `black` formatter
89 | - Write clear commit messages
90 | - Add unit tests (see the sketch after this list)
91 | - Update/create documentation if possible
92 | - Include example usage if possible
93 |
94 | 2. **Documentation**
95 |
96 | - Clear, concise writing
97 | - Practical examples
98 | - Proper formatting
99 | - Complete coverage
100 |
101 | 3. **Community**
102 | - Be respectful
103 | - Provide constructive feedback
104 | - Help new users
105 | - Share knowledge
106 |
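As a concrete illustration of the unit-test guideline above, here is a minimal pytest sketch (the function mirrors `parse_human_input` from `examples/writing_improvement.py`; the file and test names are illustrative):

```python
# tests/test_example_nodes.py (illustrative path)
def parse_human_input(human_input: str) -> str:
    # Same logic as the @as_node-decorated function in examples/writing_improvement.py
    return human_input

def test_parse_human_input_roundtrip():
    assert parse_human_input("hello") == "hello"
```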
--------------------------------------------------------------------------------
/examples/particle_trajectory_analysis.yaml:
--------------------------------------------------------------------------------
1 | name: TrajectoryWorkflow_03_13_2025_20_06_45
2 | state_defs:
3 | - current_node_type: str
4 | - previous_node_type: str
5 | - human_input: str
6 | - input: str
7 | - output: str
8 | - messages: List[dict]
9 | - mass: float
10 | - charge: float
11 | - initial_velocity: ndarray
12 | - E_field: ndarray
13 | - B_field: ndarray
14 | - confirm_parameters: bool
15 | - parameters_updater_output: str
16 | - positions: List[ndarray]
17 | - trajectory_plot: str
18 | - trajectory_plot_path: str
19 | - analysis_result: dict
20 | - continue_simulation: bool
21 | nodes:
22 |   display_parameters:
23 |     type: display_parameters
24 |     next: ask_confirm_parameters
25 |   ask_confirm_parameters:
26 |     type: ask_confirm_parameters
27 |     sink: confirm_parameters
28 |     next:
29 |       condition: confirm_parameters
30 |       then: calculate_trajectory
31 |       otherwise: ask_parameters_input
32 |   ask_parameters_input:
33 |     type: ask_parameters_input
34 |     sink: human_input
35 |     next: update_parameters
36 |   update_parameters:
37 |     type: prompt
38 |     template: 'Update the parameters based on the user''s input. Current parameters:
39 |       mass: {mass} charge: {charge} initial_velocity: {initial_velocity} E_field:
40 |       {E_field} B_field: {B_field} User input: {human_input} Please return the updated
41 |       parameters in JSON format. {{ "mass": float, "charge": float, "initial_velocity":
42 |       list[float], "E_field": list[float], "B_field": list[float] }}'
43 |     sink: parameters_updater_output
44 |     next: display_parameters
45 |   calculate_trajectory:
46 |     type: calculate_trajectory
47 |     sink: positions
48 |     next: plot_trajectory
49 |   plot_trajectory:
50 |     type: plot_trajectory
51 |     sink: [trajectory_plot, trajectory_plot_path]
52 |     next: analyze_trajectory
53 |   analyze_trajectory:
54 |     type: prompt
55 |     template: 'Analyze this particle trajectory plot. Please determine: 1. The type
56 |       of motion (linear, circular, helical, or chaotic) 2. Key physical features (radius,
57 |       period, pitch angle if applicable) 3. Explanation of the motion 4. Anomalies
58 |       in the motion Output in JSON format: {{ "trajectory_type": "type_name", "key_features":
59 |       { "feature1": value, "feature2": value }, "explanation": "detailed explanation",
60 |       "anomalies": "anomaly description" }}'
61 |     sink: analysis_result
62 |     image_keys: trajectory_plot_path
63 |     next: ask_continue_simulation
64 |   ask_continue_simulation:
65 |     type: ask_continue_simulation
66 |     sink: continue_simulation
67 |     next:
68 |       condition: continue_simulation
69 |       then: display_parameters
70 |       otherwise: END
71 | entry_point: display_parameters
72 | llm: gemini/gemini-2.0-flash
73 | vlm: gemini/gemini-2.0-flash
74 | exit_commands: [stop workflow, quit workflow, terminate workflow]
75 |
--------------------------------------------------------------------------------
/examples/public/hide-watermark.js:
--------------------------------------------------------------------------------
1 | function hideWatermark() {
2 |   // Try multiple selector approaches
3 |   const selectors = [
4 |     "#chainlit-copilot",
5 |     ".cl-copilot-container",
6 |     "[data-testid='copilot-container']",
7 |     // Add any other potential selectors
8 |   ];
9 |
10 |   for (const selector of selectors) {
11 |     const elements = document.querySelectorAll(selector);
12 |
13 |     elements.forEach(element => {
14 |       // Try to access shadow DOM if it exists
15 |       if (element.shadowRoot) {
16 |         const watermarks = element.shadowRoot.querySelectorAll("a.watermark, .watermark, [class*='watermark']");
17 |         watermarks.forEach(watermark => {
18 |           watermark.style.display = "none";
19 |           watermark.style.visibility = "hidden";
20 |           watermark.remove(); // Try to remove it completely
21 |         });
22 |       }
23 |
24 |       // Also check for watermarks in the regular DOM
25 |       const directWatermarks = element.querySelectorAll("a.watermark, .watermark, [class*='watermark']");
26 |       directWatermarks.forEach(watermark => {
27 |         watermark.style.display = "none";
28 |         watermark.style.visibility = "hidden";
29 |         watermark.remove(); // Try to remove it completely
30 |       });
31 |     });
32 |   }
33 |
34 |   // Add CSS to hide watermarks globally
35 |   const style = document.createElement('style');
36 |   style.textContent = `
37 |     a.watermark, .watermark, [class*='watermark'] {
38 |       display: none !important;
39 |       visibility: hidden !important;
40 |       opacity: 0 !important;
41 |       pointer-events: none !important;
42 |     }
43 |   `;
44 |   document.head.appendChild(style);
45 | }
46 |
47 | // More aggressive approach with mutation observer for the entire document
48 | function setupGlobalObserver() {
49 |   const observer = new MutationObserver((mutations) => {
50 |     let shouldCheck = false;
51 |
52 |     for (const mutation of mutations) {
53 |       if (mutation.addedNodes.length > 0) {
54 |         shouldCheck = true;
55 |         break;
56 |       }
57 |     }
58 |
59 |     if (shouldCheck) {
60 |       hideWatermark();
61 |     }
62 |   });
63 |
64 |   observer.observe(document.body, {
65 |     childList: true,
66 |     subtree: true
67 |   });
68 | }
69 |
70 | // Run on page load
71 | document.addEventListener("DOMContentLoaded", function() {
72 |   // Try immediately
73 |   hideWatermark();
74 |
75 |   // Setup global observer
76 |   setupGlobalObserver();
77 |
78 |   // Try again after delays to catch late-loading elements
79 |   setTimeout(hideWatermark, 1000);
80 |   setTimeout(hideWatermark, 3000);
81 |
82 |   // Periodically check
83 |   setInterval(hideWatermark, 5000);
84 | });
85 |
86 | // Also run the script immediately in case the DOM is already loaded
87 | if (document.readyState === "complete" || document.readyState === "interactive") {
88 |   hideWatermark();
89 |   setTimeout(setupGlobalObserver, 0);
90 | }
91 |
--------------------------------------------------------------------------------
/examples/public/theme.json:
--------------------------------------------------------------------------------
1 | {
2 |   "custom_fonts": [],
3 |   "variables": {
4 |     "light": {
5 |       "--font-sans": "'Inter', sans-serif",
6 |       "--font-mono": "source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace",
7 |       "--background": "0 0% 100%",
8 |       "--foreground": "0 0% 5%",
9 |       "--card": "0 0% 100%",
10 |       "--card-foreground": "0 0% 5%",
11 |       "--popover": "0 0% 100%",
12 |       "--popover-foreground": "0 0% 5%",
13 |       "--primary": "150 40% 60%",
14 |       "--primary-foreground": "0 0% 100%",
15 |       "--secondary": "150 30% 90%",
16 |       "--secondary-foreground": "150 30% 20%",
17 |       "--muted": "0 0% 90%",
18 |       "--muted-foreground": "150 15% 30%",
19 |       "--accent": "0 0% 95%",
20 |       "--accent-foreground": "150 30% 20%",
21 |       "--destructive": "0 84.2% 60.2%",
22 |       "--destructive-foreground": "210 40% 98%",
23 |       "--border": "150 30% 75%",
24 |       "--input": "150 30% 75%",
25 |       "--ring": "150 40% 60%",
26 |       "--radius": "0.75rem",
27 |       "--sidebar-background": "0 0% 98%",
28 |       "--sidebar-foreground": "240 5.3% 26.1%",
29 |       "--sidebar-primary": "150 40% 60%",
30 |       "--sidebar-primary-foreground": "0 0% 98%",
31 |       "--sidebar-accent": "240 4.8% 95.9%",
32 |       "--sidebar-accent-foreground": "240 5.9% 10%",
33 |       "--sidebar-border": "220 13% 91%",
34 |       "--sidebar-ring": "217.2 91.2% 59.8%"
35 |     },
36 |     "dark": {
37 |       "--font-sans": "'Inter', sans-serif",
38 |       "--font-mono": "source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace",
39 |       "--background": "0 0% 13%",
40 |       "--foreground": "0 0% 93%",
41 |       "--card": "0 0% 18%",
42 |       "--card-foreground": "210 40% 98%",
43 |       "--popover": "0 0% 18%",
44 |       "--popover-foreground": "210 40% 98%",
45 |       "--primary": "150 45% 50%",
46 |       "--primary-foreground": "0 0% 100%",
47 |       "--secondary": "150 35% 25%",
48 |       "--secondary-foreground": "0 0% 98%",
49 |       "--muted": "150 15% 30%",
50 |       "--muted-foreground": "150 10% 80%",
51 |       "--accent": "150 40% 40%",
52 |       "--accent-foreground": "0 0% 98%",
53 |       "--destructive": "0 62.8% 30.6%",
54 |       "--destructive-foreground": "210 40% 98%",
55 |       "--border": "150 30% 40%",
56 |       "--input": "150 30% 40%",
57 |       "--ring": "150 45% 50%",
58 |       "--sidebar-background": "0 0% 9%",
59 |       "--sidebar-foreground": "240 4.8% 95.9%",
60 |       "--sidebar-primary": "150 45% 50%",
61 |       "--sidebar-primary-foreground": "0 0% 100%",
62 |       "--sidebar-accent": "150 25% 20%",
63 |       "--sidebar-accent-foreground": "240 4.8% 95.9%",
64 |       "--sidebar-border": "240 3.7% 15.9%",
65 |       "--sidebar-ring": "217.2 91.2% 59.8%"
66 |     }
67 |   }
68 | }
--------------------------------------------------------------------------------
/examples/.chainlit/config.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | # Whether to enable telemetry (default: true). No personal data is collected.
3 | enable_telemetry = false
4 |
5 | # List of environment variables to be provided by each user to use the app.
6 | user_env = []
7 |
8 | # Duration (in seconds) during which the session is saved when the connection is lost
9 | session_timeout = 3600
10 |
11 | # Duration (in seconds) of the user session expiry
12 | user_session_timeout = 1296000 # 15 days
13 |
14 | # Enable third-party caching (e.g. LangChain cache)
15 | cache = false
16 |
17 | # Authorized origins
18 | allow_origins = ["*"]
19 |
20 | [features]
21 | # Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
22 | unsafe_allow_html = false
23 |
24 | # Process and display mathematical expressions. This can clash with "$" characters in messages.
25 | latex = false
26 |
27 | # Automatically tag threads with the current chat profile (if a chat profile is used)
28 | auto_tag_thread = true
29 |
30 | # Allow users to edit their own messages
31 | edit_message = true
32 |
33 | # Authorize users to spontaneously upload files with messages
34 | [features.spontaneous_file_upload]
35 | enabled = true
36 | # Define accepted file types using MIME types
37 | # Examples:
38 | # 1. For specific file types:
39 | # accept = ["image/jpeg", "image/png", "application/pdf"]
40 | # 2. For all files of certain type:
41 | # accept = ["image/*", "audio/*", "video/*"]
42 | # 3. For specific file extensions:
43 | # accept = { "application/octet-stream" = [".xyz", ".pdb"] }
44 | # Note: Using "*/*" is not recommended as it may cause browser warnings
45 | accept = ["*/*"]
46 | max_files = 20
47 | max_size_mb = 500
48 |
49 | [features.audio]
50 | # Sample rate of the audio
51 | sample_rate = 24000
52 |
53 | [UI]
54 | # Name of the assistant.
55 | name = "Assistant"
56 |
57 | # default_theme = "light"
58 |
59 | # layout = "wide"
60 |
61 | # Description of the assistant. This is used for HTML tags.
62 | # description = ""
63 |
64 | # Chain of Thought (CoT) display mode. Can be "hidden", "tool_call" or "full".
65 | cot = "full"
66 |
67 | # Specify a CSS file that can be used to customize the user interface.
68 | # The CSS file can be served from the public directory or via an external link.
69 | # custom_css = "/public/test.css"
70 |
71 | # Specify a Javascript file that can be used to customize the user interface.
72 | # The Javascript file can be served from the public directory.
73 | custom_js = "/public/hide-watermark.js"
74 |
75 | # Specify a custom meta image url.
76 | # custom_meta_image_url = "https://chainlit-cloud.s3.eu-west-3.amazonaws.com/logo/chainlit_banner.png"
77 |
78 | # Specify a custom build directory for the frontend.
79 | # This can be used to customize the frontend code.
80 | # Be careful: If this is a relative path, it should not start with a slash.
81 | # custom_build = "./public/build"
82 |
83 | # Specify optional one or more custom links in the header.
84 | # [[UI.header_links]]
85 | # name = "Nodeology"
86 | # icon_url = "https://avatars.githubusercontent.com/u/128686189?s=200&v=4"
87 | # url = "https://github.com/xyin-anl/Nodeology"
88 |
89 | [meta]
90 | generated_by = "2.2.1"
91 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | # Apple
165 | .DS_Store
166 |
167 | # Nodeology
168 | artifacts/
169 | logs/
170 |
171 | # PyPI
172 | dist/
173 | *.egg-info/
174 | .pypirc
175 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | > [!IMPORTANT]
2 | > This package is actively in development, and breaking changes may occur.
3 |
4 |
5 |
6 |
7 |
8 |
9 | ## 🤖 Foundation AI-Enhanced Scientific Workflow
10 |
11 | Foundation AI holds enormous potential for scientific research, especially in analyzing unstructured data, automating complex reasoning tasks, and simplifying human-computer interactions. However, integrating foundation AI models like LLMs and VLMs into scientific workflows poses challenges: handling diverse data types beyond text and images, managing model inaccuracies (hallucinations), and adapting general-purpose models to highly specialized scientific contexts.
12 |
13 | `nodeology` addresses these challenges by combining the strengths of foundation AI with traditional scientific methods and expert oversight. Built on `langgraph`'s state machine framework, it simplifies creating robust, AI-driven workflows through an intuitive, accessible interface. Originally developed at Argonne National Lab, the framework enables researchers—especially those without extensive programming experience—to quickly design and deploy full-stack AI workflows simply using prompt templates and existing functions as reusable nodes.
14 |
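For a concrete taste, here is a condensed sketch in the spirit of `examples/writing_improvement.py` (the names and the one-node graph are illustrative; see the examples for full runnable workflows):

```python
from nodeology.state import State
from nodeology.node import Node
from nodeology.workflow import Workflow

# A typed state plus a prompt-template node are the core building blocks
class SummaryState(State):
    text: str
    summary: str

summarize = Node(
    prompt_template="Summarize the following text:\n\n{text}",
    sink="summary",  # state field that receives the model output
)

class SummaryWorkflow(Workflow):
    state_schema = SummaryState

    def create_workflow(self):
        self.add_node("summarize", summarize)
        self.set_entry("summarize")
        self.compile()

workflow = SummaryWorkflow(llm_name="gpt-4o")
```
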
15 | Key features include:
16 |
17 | - Easy creation of AI-integrated workflows without complex syntax
18 | - Flexible and composable node architecture for various tasks
19 | - Seamless human-in-the-loop interactions for expert oversight
20 | - Portable workflow templates for collaboration and reproducibility
21 | - Quick setup of simple chatbots for immediate AI interaction
22 | - Built-in tracing and telemetry for workflow monitoring and optimization
23 |
24 | ## 🚀 Getting Started
25 |
26 | ### Install the package
27 |
28 | To use the latest development version:
29 |
30 | ```bash
31 | pip install git+https://github.com/xyin-anl/Nodeology.git
32 | ```
33 |
34 | To use the latest release version:
35 |
36 | ```bash
37 | pip install nodeology
38 | ```
39 |
40 | ### Access foundation models
41 |
42 | Nodeology supports various cloud-based and local foundation models via [LiteLLM](https://docs.litellm.ai/docs/); see the [provider list](https://docs.litellm.ai/docs/providers). Most cloud-based models require setting up an API key. For example:
43 |
44 | ```bash
45 | # For OpenAI models
46 | export OPENAI_API_KEY='your-api-key'
47 |
48 | # For Anthropic models
49 | export ANTHROPIC_API_KEY='your-api-key'
50 |
51 | # For Gemini models
52 | export GEMINI_API_KEY='your-api-key'
53 |
54 | # For Together AI hosted open weight models
55 | export TOGETHER_API_KEY='your-api-key'
56 | ```
57 | ```
58 | > **💡 Tip:** The field of foundation models is evolving rapidly with new and improved models emerging frequently. As of **February 2025**, we recommend the following models based on their strengths:
59 | >
60 | > - **gpt-4o**: Excellent for broad general knowledge, writing tasks, and conversational interactions
61 | > - **o3-mini**: Good balance of math, coding, and reasoning capabilities at a lower price point
62 | > - **anthropic/claude-3.7**: Strong performance in general knowledge, math, science, and coding with well-constrained outputs
63 | > - **gemini/gemini-2.0-flash**: Effective for general knowledge tasks with a large context window for processing substantial information
64 | > - **together_ai/deepseek-ai/DeepSeek-R1**: Exceptional reasoning, math, science, and coding capabilities with transparent thinking processes
65 |

66 | **For Argonne Users:** if you are within the Argonne network, you have free access to OpenAI models through Argonne's ARGO inference service and to open-weight models through ALCF's inference service. Please check this [link](https://gist.github.com/xyin-anl/0cc744a7862e153414857b15fe31b239) to see how to use them.
67 |
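Model names follow LiteLLM's `provider/model` convention and are passed when a workflow is constructed; for instance, the example workflow in `examples/writing_improvement.py` does:

```python
# Swap the string to change providers, e.g. "gpt-4o" or "gemini/gemini-2.0-flash"
workflow = TextEnhancementWorkflow(llm_name="gemini/gemini-2.0-flash")
```
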
68 | ### Langfuse Tracing (Optional)
69 |
70 | Nodeology supports [Langfuse](https://langfuse.com/) for observability and tracing of LLM/VLM calls. To use Langfuse:
71 |
72 | 1. Set up a Langfuse account and get your API keys
73 | 2. Configure Langfuse with your keys:
74 |
75 | ```bash
76 | # Set environment variables
77 | export LANGFUSE_PUBLIC_KEY='your-public-key'
78 | export LANGFUSE_SECRET_KEY='your-secret-key'
79 | export LANGFUSE_HOST='https://cloud.langfuse.com' # Or your self-hosted URL
80 | ```
81 |
82 | Or configure programmatically:
83 |
84 | ```python
85 | from nodeology.client import configure_langfuse
86 |
87 | configure_langfuse(
88 | public_key='your-public-key',
89 | secret_key='your-secret-key',
90 | host='https://cloud.langfuse.com' # Optional
91 | )
92 | ```
93 |
94 | ### Chainlit Interface (Optional)
95 |
96 | Nodeology supports [Chainlit](https://docs.chainlit.io/get-started/overview) for creating chat-based user interfaces. To use this feature, simply set `ui=True` when running your workflow:
97 |
98 | ```python
99 | # Create your workflow
100 | workflow = MyWorkflow()
101 |
102 | # Run with UI enabled
103 | workflow.run(ui=True)
104 | ```
105 |
106 | This will automatically launch a Chainlit server with a chat interface for interacting with your workflow. The interface preserves your workflow's state and configuration, so users can drive the full workflow from a friendly chat window.
107 |
108 | When the Chainlit server starts, you can access the interface through your web browser at `http://localhost:8000` by default.
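
Nodes can also push rich UI elements to the chat. Here is a pattern taken from `examples/writing_improvement.py`, where a node's `post_process` hook sends analysis results to the UI as a custom element:

```python
import json
import chainlit as cl
from chainlit import Message, run_sync

def report_analysis(state, client, **kwargs):
    # Parse the node's JSON output and render it with a custom UI element
    analysis = json.loads(state["analysis"])
    run_sync(
        Message(
            content="Below is the analysis of the text:",
            elements=[cl.CustomElement(name="DataDisplay", props={"data": analysis})],
        ).send()
    )
    return state

analyze_text.post_process = report_analysis  # analyze_text is a Node defined earlier
```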
109 |
110 | ## 🧪 Illustrating Examples
111 |
112 | ### [Writing Improvement](https://github.com/xyin-anl/Nodeology/blob/HEAD/examples/writing_improvement.py)
113 |
114 |
119 |
120 | ### [Trajectory Analysis](https://github.com/xyin-anl/Nodeology/blob/HEAD/examples/trajectory_analysis.py)
121 |
122 |
127 |
128 | ## 🔬 Scientific Applications
129 |
130 | - [PEAR: Ptychography automation framework](https://arxiv.org/abs/2410.09034)
131 | - [AutoScriptCopilot: TEM experiment control](https://github.com/xyin-anl/AutoScriptCopilot)
132 |
133 | ## 👥 Contributing & Collaboration
134 |
135 | We welcome comments, feedback, bug reports, code contributions, and research collaborations. Please refer to CONTRIBUTING.md.
136 |
137 | If you find `nodeology` useful or it inspires your research, please use the **Cite this repository** function.
138 |
--------------------------------------------------------------------------------
/examples/writing_improvement.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
3 |
4 | Copyright 2024. UChicago Argonne, LLC. This software was produced
5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
7 | U.S. Department of Energy. The U.S. Government has rights to use,
8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
11 | modified to produce derivative works, such modified software should
12 | be clearly marked, so as not to confuse it with the version available
13 | from ANL.
14 |
15 | Additionally, redistribution and use in source and binary forms, with
16 | or without modification, are permitted provided that the following
17 | conditions are met:
18 |
19 | * Redistributions of source code must retain the above copyright
20 | notice, this list of conditions and the following disclaimer.
21 |
22 | * Redistributions in binary form must reproduce the above copyright
23 | notice, this list of conditions and the following disclaimer in
24 | the documentation and/or other materials provided with the
25 | distribution.
26 |
27 | * Neither the name of UChicago Argonne, LLC, Argonne National
28 | Laboratory, ANL, the U.S. Government, nor the names of its
29 | contributors may be used to endorse or promote products derived
30 | from this software without specific prior written permission.
31 |
32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
43 | POSSIBILITY OF SUCH DAMAGE.
44 | """
45 |
46 | ### Initial Author <2025>: Xiangyu Yin
47 |
48 | import json
49 | from nodeology.state import State
50 | from nodeology.node import Node, as_node
51 | from nodeology.workflow import Workflow
52 |
53 | import chainlit as cl
54 | from chainlit import Message, AskActionMessage, run_sync
55 | from langgraph.graph import END
56 |
57 |
58 | # 1. Define your state
59 | class TextAnalysisState(State):
60 |     analysis: dict  # Analysis results
61 |     text: str  # Enhanced text
62 |     continue_improving: bool  # Whether to continue improving
63 |
64 |
65 | # 2. Create nodes
66 | @as_node(sink="text")
67 | def parse_human_input(human_input: str):
68 |     return human_input
69 |
70 |
71 | analyze_text = Node(
72 |     prompt_template="""Text to analyze: {text}
73 |
74 | Analyze the above text for:
75 | - Clarity (1-10)
76 | - Grammar (1-10)
77 | - Style (1-10)
78 | - Suggestions for improvement
79 |
80 | Output as JSON:
81 | {{
82 |     "clarity_score": int,
83 |     "grammar_score": int,
84 |     "style_score": int,
85 |     "suggestions": str
86 | }}
87 | """,
88 |     sink="analysis",
89 |     sink_format="json",
90 | )
91 |
92 |
93 | def report_analysis(state, client, **kwargs):
94 |     analysis = json.loads(state["analysis"])
95 |     run_sync(
96 |         Message(
97 |             content="Below is the analysis of the text:",
98 |             elements=[cl.CustomElement(name="DataDisplay", props={"data": analysis})],
99 |         ).send()
100 |     )
101 |     return state
102 |
103 |
104 | analyze_text.post_process = report_analysis
105 |
106 | improve_text = Node(
107 |     prompt_template="""Text to improve: {text}
108 |
109 | Analysis: {analysis}
110 |
111 | Rewrite the text incorporating the suggestions while maintaining the original meaning.
112 | Focus on clarity, grammar, and style improvements. Return the improved text only.""",
113 |     sink="text",
114 | )
115 |
116 |
117 | def report_improvement(state, client, **kwargs):
118 |     text_md = f"{state['text']}"
119 |     run_sync(
120 |         Message(
121 |             content="Below is the improved text:", elements=[cl.Text(content=text_md)]
122 |         ).send()
123 |     )
124 |     return state
125 |
126 |
127 | improve_text.post_process = report_improvement
128 |
129 |
130 | @as_node(sink="continue_improving")
131 | def ask_continue_improve():
132 |     res = run_sync(
133 |         AskActionMessage(
134 |             content="Would you like to further improve the text?",
135 |             timeout=300,
136 |             actions=[
137 |                 cl.Action(
138 |                     name="continue",
139 |                     payload={"value": "continue"},
140 |                     label="Continue Improving",
141 |                 ),
142 |                 cl.Action(
143 |                     name="finish",
144 |                     payload={"value": "finish"},
145 |                     label="Finish",
146 |                 ),
147 |             ],
148 |         ).send()
149 |     )
150 |
151 |     # Return the user's choice
152 |     if res and res.get("payload").get("value") == "continue":
153 |         return True
154 |     else:
155 |         return False
156 |
157 |
158 | # 3. Create workflow
159 | class TextEnhancementWorkflow(Workflow):
160 |     state_schema = TextAnalysisState
161 |
162 |     def create_workflow(self):
163 |         # Add nodes
164 |         self.add_node("parse_human_input", parse_human_input)
165 |         self.add_node("analyze", analyze_text)
166 |         self.add_node("improve", improve_text)
167 |         self.add_node("ask_continue", ask_continue_improve)
168 |
169 |         # Connect nodes
170 |         self.add_flow("parse_human_input", "analyze")
171 |         self.add_flow("analyze", "improve")
172 |         self.add_flow("improve", "ask_continue")
173 |
174 |         # Add conditional flow based on user's choice
175 |         self.add_conditional_flow(
176 |             "ask_continue",
177 |             "continue_improving",
178 |             "analyze",
179 |             END,
180 |         )
181 |
182 |         # Set entry point
183 |         self.set_entry("parse_human_input")
184 |
185 |         # Compile workflow
186 |         self.compile(
187 |             interrupt_before=["parse_human_input"],
188 |             interrupt_before_phrases={
189 |                 "parse_human_input": "Please enter the text to analyze."
190 |             },
191 |         )
192 |
193 |
194 | # 4. Run workflow
195 | workflow = TextEnhancementWorkflow(
196 |     llm_name="gemini/gemini-2.0-flash", save_artifacts=True
197 | )
198 |
199 | if __name__ == "__main__":
200 |     result = workflow.run(ui=True)
201 |
--------------------------------------------------------------------------------
/examples/public/elements/DataDisplay.jsx:
--------------------------------------------------------------------------------
1 | import { Card, CardHeader, CardTitle, CardContent } from "@/components/ui/card"
2 | import { Badge } from "@/components/ui/badge"
3 | import { Separator } from "@/components/ui/separator"
4 | import { ScrollArea } from "@/components/ui/scroll-area"
5 | import { Button } from "@/components/ui/button"
6 | import { Copy, ChevronDown, ChevronRight } from "lucide-react"
7 | import { useState } from "react"
8 |
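// This custom element is rendered from Python in examples/writing_improvement.py via
// cl.CustomElement(name="DataDisplay", props={"data": ...})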
9 | export default function DataDisplay() {
10 |   // Data is passed via props
11 |   const data = props.data || {};
12 |   const title = props.title || "Data";
13 |   const badge = props.badge || null;
14 |   const maxHeight = props.maxHeight || "300px";
15 |   const showScrollArea = props.showScrollArea !== false;
16 |   const collapsible = props.collapsible !== false;
17 |   const theme = props.theme || "default"; // default, compact, or expanded
18 |
19 |   // State for collapsible sections
20 |   const [expandedSections, setExpandedSections] = useState({});
21 |
22 |   // Toggle section expansion
23 |   const toggleSection = (key) => {
24 |     setExpandedSections(prev => ({
25 |       ...prev,
26 |       [key]: !prev[key]
27 |     }));
28 |   };
29 |
30 |   // Copy value to clipboard
31 |   const copyToClipboard = (value) => {
32 |     let textToCopy;
33 |
34 |     if (Array.isArray(value)) {
35 |       // Format arrays properly for copying
36 |       textToCopy = JSON.stringify(value);
37 |     } else if (typeof value === 'object' && value !== null) {
38 |       textToCopy = JSON.stringify(value, null, 2);
39 |     } else {
40 |       textToCopy = String(value);
41 |     }
42 |
43 |     navigator.clipboard.writeText(textToCopy);
44 |     // Could use sonner for toast notification here
45 |   };
46 |
47 |   // Helper to format values with scientific notation for very small/large numbers
48 |   const formatValue = (value) => {
49 |     if (value === undefined || value === null) return "N/A";
50 |
51 |     if (Array.isArray(value) || (typeof value === 'object' && value.length)) {
52 |       return formatArray(value);
53 |     }
54 |
55 |     if (typeof value === 'object' && value !== null) {
56 |       return null; // Handled separately in the render
57 |     }
58 |
59 |     if (typeof value === 'number') {
60 |       // Use scientific notation for very small or large numbers
61 |       if (Math.abs(value) < 0.001 || Math.abs(value) > 10000) {
62 |         return value.toExponential(4);
63 |       }
64 |       // Format with appropriate precision
65 |       return Number.isInteger(value) ? value.toString() : value.toFixed(4).replace(/\.?0+$/, '');
66 |     }
67 |
68 |     if (typeof value === 'boolean') {
69 |       return value ? "true" : "false";
70 |     }
71 |
72 |     return String(value);
73 |   };
74 |
75 |   // Helper function to format arrays nicely
76 |   const formatArray = (arr) => {
77 |     if (!arr) return "N/A";
78 |     if (typeof arr === 'string') return arr;
79 |
80 |     try {
81 |       // Handle both array-like objects and actual arrays
82 |       if (arr.length > 10) {
83 |         const displayed = Array.from(arr).slice(0, 10);
84 |         return `[${displayed.map(val =>
85 |           typeof val === 'object' ? '{...}' :
86 |           typeof val === 'number' ? formatValue(val) :
87 |           JSON.stringify(val)
88 |         ).join(", ")}, ... ${arr.length - 10} more]`;
89 |       }
90 |
91 |       return `[${Array.from(arr).map(val =>
92 |         typeof val === 'object' ? '{...}' :
93 |         typeof val === 'number' ? formatValue(val) :
94 |         JSON.stringify(val)
95 |       ).join(", ")}]`;
96 |     } catch (e) {
97 |       return String(arr);
98 |     }
99 |   };
100 |
101 |   // Helper function to render nested objects
102 |   const renderNestedObject = (obj, path = "", level = 0) => {
103 |     if (!obj || typeof obj !== 'object') return <span>{String(obj || "N/A")}</span>;
104 |
105 |     if (Array.isArray(obj)) {
106 |       return (
107 |         <div className="flex items-center gap-1">
108 |           <span>{formatArray(obj)}</span>
109 |           <Button
110 |             variant="ghost"
111 |             size="icon"
112 |             className="h-4 w-4"
113 |             onClick={() => copyToClipboard(obj)}
114 |           >
115 |             <Copy className="h-3 w-3" />
116 |           </Button>
117 |         </div>
118 |       );
119 |     }
120 |
121 |     const isExpanded = expandedSections[path] !== false; // Default to expanded
122 |
123 |     return (
124 |       <div className={`space-y-1 ${level > 0 ? "pl-4" : ""}`}>
125 |         {Object.entries(obj).map(([key, value], index) => {
126 |           const currentPath = path ? `${path}.${key}` : key;
127 |           const isObject = typeof value === 'object' && value !== null;
128 |
129 |           return (
130 |             <div key={currentPath}>
131 |               {index > 0 && level === 0 && <Separator className="my-1" />}
132 |               <div className="flex items-center gap-1">
133 |                 {isObject && collapsible && (
134 |                   <Button
135 |                     variant="ghost"
136 |                     size="icon"
137 |                     className="h-4 w-4"
138 |                     onClick={() => toggleSection(currentPath)}
139 |                   >
140 |                     {expandedSections[currentPath] === false ?
141 |                       <ChevronRight className="h-3 w-3" /> :
142 |                       <ChevronDown className="h-3 w-3" />
143 |                     }
144 |                   </Button>
145 |                 )}
146 |                 <span className="font-medium">{key}</span>
147 |               </div>
148 |               <div className={`
149 |                 ${level > 0 ? "border-l-2 border-gray-200 dark:border-gray-700 pl-2" : ""}
150 |                 ${isObject && expandedSections[currentPath] === false ? "hidden" : ""}
151 |               `}>
152 |                 {isObject ?
153 |                   renderNestedObject(value, currentPath, level + 1) :
154 |                   <div className="flex items-center gap-1">
155 |                     <span>{formatValue(value)}</span>
156 |                     <Button
157 |                       variant="ghost"
158 |                       size="icon"
159 |                       className="h-4 w-4"
160 |                       onClick={() => copyToClipboard(value)}
161 |                     >
162 |                       <Copy className="h-3 w-3" />
163 |                     </Button>
164 |                   </div>
165 |                 }
166 |               </div>
167 |             </div>
168 |           );
169 |         })}
170 |       </div>
171 |     );
172 |   };
173 |
174 |   // Apply theme styles
175 |   const getThemeStyles = () => {
176 |     switch (theme) {
177 |       case 'compact':
178 |         return "text-xs";
179 |       case 'expanded':
180 |         return "text-base";
181 |       default:
182 |         return "text-sm";
183 |     }
184 |   };
185 |
186 |   return (
187 |     <Card className={getThemeStyles()}>
188 |       <CardHeader className="flex flex-row items-center justify-between">
189 |         <CardTitle>{title}</CardTitle>
190 |         <div className="flex items-center gap-1">
191 |           {badge && (
192 |             <Badge>{badge}</Badge>
193 |           )}
194 |           <Button
195 |             variant="ghost"
196 |             size="icon"
197 |             onClick={() => copyToClipboard(data)}
198 |             title="Copy all data"
199 |           >
200 |             <Copy className="h-4 w-4" />
201 |           </Button>
202 |         </div>
203 |       </CardHeader>
204 |       <Separator />
205 |       <CardContent>
206 |         {Object.keys(data).length === 0 ? (
207 |           <p>No data available</p>
208 |         ) : (
209 |           showScrollArea ? (
210 |             <ScrollArea style={{ maxHeight }}>
211 |               {renderNestedObject(data)}
212 |             </ScrollArea>
213 |           ) : (
214 |             renderNestedObject(data)
215 |           )
216 |         )}
217 |       </CardContent>
218 |     </Card>
219 |   )
220 | }
--------------------------------------------------------------------------------
/nodeology/log.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
3 |
4 | Copyright 2024. UChicago Argonne, LLC. This software was produced
5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
7 | U.S. Department of Energy. The U.S. Government has rights to use,
8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
11 | modified to produce derivative works, such modified software should
12 | be clearly marked, so as not to confuse it with the version available
13 | from ANL.
14 |
15 | Additionally, redistribution and use in source and binary forms, with
16 | or without modification, are permitted provided that the following
17 | conditions are met:
18 |
19 | * Redistributions of source code must retain the above copyright
20 | notice, this list of conditions and the following disclaimer.
21 |
22 | * Redistributions in binary form must reproduce the above copyright
23 | notice, this list of conditions and the following disclaimer in
24 | the documentation and/or other materials provided with the
25 | distribution.
26 |
27 | * Neither the name of UChicago Argonne, LLC, Argonne National
28 | Laboratory, ANL, the U.S. Government, nor the names of its
29 | contributors may be used to endorse or promote products derived
30 | from this software without specific prior written permission.
31 |
32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
43 | POSSIBILITY OF SUCH DAMAGE.
44 | """
45 |
46 | ### Initial Author <2024>: Xiangyu Yin
47 |
48 | import os, logging, logging.config
49 | import sys
50 |
51 | logger = logging.getLogger(__name__)
52 |
53 |
54 | def setup_logging(log_dir, log_name, debug_mode=False, base_dir=None):
55 | """Configure the logging system with console and/or file handlers.
56 |
57 | Args:
58 | log_dir (str): Directory where log files will be stored
59 | log_name (str): Name of the log file (without extension)
60 | debug_mode (bool): If True, only console logging with debug level is enabled
61 | base_dir (str, optional): Base directory to prepend to log_dir
62 | """
63 | # Get root logger
64 | root_logger = logging.getLogger()
65 |
66 | # Remove any existing handlers first
67 | for handler in root_logger.handlers[:]:
68 | handler.close() # Properly close handlers
69 | root_logger.removeHandler(handler)
70 |
71 | # Create handlers
72 | console_handler = logging.StreamHandler(sys.stdout)
73 |
74 | # Set the root logger level to DEBUG to capture all messages
75 | root_logger.setLevel(logging.DEBUG)
76 |
77 | # Use base_dir if provided, otherwise use log_dir directly
78 | full_log_dir = os.path.join(base_dir, log_dir) if base_dir else log_dir
79 |
80 | if not os.path.exists(full_log_dir):
81 | os.makedirs(full_log_dir)
82 |
83 | log_file_path = f"{full_log_dir}/{log_name}.log"
84 |
85 | if os.path.isfile(log_file_path):
86 | log_print_color(
87 | f"WARNING: {log_file_path} already exists and will be overwritten.",
88 | "red",
89 | )
90 |
91 | # Create file handler for both modes
92 | file_handler = logging.FileHandler(log_file_path, "w")
93 |
94 | if debug_mode:
95 | # Debug mode configuration
96 | # Console shows DEBUG and above
97 | console_handler.setLevel(logging.DEBUG)
98 | console_format = logging.Formatter(
99 | "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
100 | )
101 | console_handler.setFormatter(console_format)
102 |
103 | # File handler captures everything
104 | file_handler.setLevel(logging.DEBUG)
105 | file_format = logging.Formatter(
106 | "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
107 | )
108 | file_handler.setFormatter(file_format)
109 | else:
110 | # Production mode configuration
111 | # Console only shows PRINTLOG and WARNING+ messages
112 | console_handler.setLevel(logging.PRINTLOG)
113 | console_format = logging.Formatter("%(message)s")
114 | console_handler.setFormatter(console_format)
115 |
116 | # File handler captures everything (DEBUG and above)
117 | file_handler.setLevel(logging.DEBUG)
118 | file_format = logging.Formatter(
119 | "%(asctime)s - %(message)s", datefmt="%Y%m%d-%H:%M:%S"
120 | )
121 | file_handler.setFormatter(file_format)
122 |
123 | # Add handlers to root logger
124 | root_logger.addHandler(console_handler)
125 | root_logger.addHandler(file_handler)
126 |
127 | # Configure third-party loggers to be less verbose
128 | # This prevents them from cluttering the console
129 | for logger_name in logging.root.manager.loggerDict:
130 | if logger_name != __name__ and not logger_name.startswith("nodeology"):
131 | third_party_logger = logging.getLogger(logger_name)
132 | if debug_mode:
133 | # In debug mode, third-party loggers show WARNING and above
134 | third_party_logger.setLevel(logging.WARNING)
135 | else:
136 | # In production mode, third-party loggers show ERROR and above
137 | third_party_logger.setLevel(logging.ERROR)
138 |
139 | # Store handlers in logger for later cleanup
140 | root_logger.handlers_to_close = root_logger.handlers[:]
141 |
142 |
143 | def cleanup_logging():
144 | """Properly clean up logging handlers to prevent resource leaks."""
145 | root_logger = logging.getLogger()
146 |
147 | # Close and remove any existing handlers
148 | if hasattr(root_logger, "handlers_to_close"):
149 | for handler in root_logger.handlers_to_close:
150 | try:
151 | handler.close()
152 | except:
153 | pass # Ignore errors during cleanup
154 | if handler in root_logger.handlers:
155 | root_logger.removeHandler(handler)
156 | root_logger.handlers_to_close = []
157 |
158 |
159 | # https://stackoverflow.com/questions/2183233/how-to-add-a-custom-loglevel-to-pythons-logging-facility/35804945#35804945
160 | def add_logging_level(levelName, levelNum, methodName=None):
161 | """Add a new logging level to the logging module.
162 |
163 | Args:
164 | levelName (str): Name of the new level (e.g., 'TRACE')
165 | levelNum (int): Numeric value for the level
166 | methodName (str, optional): Method name to add. Defaults to levelName.lower()
167 |
168 | Raises:
169 | AttributeError: If levelName or methodName already exists
170 | """
171 | if not methodName:
172 | methodName = levelName.lower()
173 |
174 | if hasattr(logging, levelName):
175 | raise AttributeError("{} already defined in logging module".format(levelName))
176 | if hasattr(logging, methodName):
177 | raise AttributeError("{} already defined in logging module".format(methodName))
178 | if hasattr(logging.getLoggerClass(), methodName):
179 | raise AttributeError("{} already defined in logger class".format(methodName))
180 |
181 | # This method was inspired by the answers to Stack Overflow post
182 | # http://stackoverflow.com/q/2183233/2988730, especially
183 | # http://stackoverflow.com/a/13638084/2988730
184 | def logForLevel(self, message, *args, **kwargs):
185 | if self.isEnabledFor(levelNum):
186 | self._log(levelNum, message, args, **kwargs)
187 |
188 | def logToRoot(message, *args, **kwargs):
189 | logging.log(levelNum, message, *args, **kwargs)
190 |
191 | logging.addLevelName(levelNum, levelName)
192 | setattr(logging, levelName, levelNum)
193 | setattr(logging.getLoggerClass(), methodName, logForLevel)
194 | setattr(logging, methodName, logToRoot)
195 |
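# Usage sketch (editorial note, not part of the module): the custom PRINTLOG
# level referenced in the setup above and the `logonly` method used by
# log_print_color below would be registered like this; the numeric values
# shown are assumptions, not the module's actual choices.
#
#     add_logging_level("PRINTLOG", logging.INFO + 5)
#     add_logging_level("LOGONLY", logging.INFO + 1)
#     logging.getLogger(__name__).printlog("visible on console in production mode")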
196 |
197 | def log_print_color(text, color="", print_to_console=True):
198 | """Print colored text to console and log it to file.
199 |
200 | Args:
201 | text (str): Text to print and log
202 | color (str): Color name ('green', 'red', 'blue', 'yellow', or '' for white)
203 | print_to_console (bool): If True, print the text to console
204 | """
205 | # Define color codes as constants at the top of the function
206 | COLOR_CODES = {
207 | "green": "\033[92m",
208 | "red": "\033[91m",
209 | "blue": "\033[94m",
210 | "yellow": "\033[93m",
211 | "": "\033[97m", # default white
212 | }
213 |
214 | # Get color code from dictionary, defaulting to white
215 | ansi_code = COLOR_CODES.get(color, COLOR_CODES[""])
216 |
217 | if print_to_console:
218 | print(ansi_code + text + "\033[0m")
219 | logger.logonly(text)
220 |
--------------------------------------------------------------------------------
/tests/test_state.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
3 |
4 | Copyright 2024. UChicago Argonne, LLC. This software was produced
5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
7 | U.S. Department of Energy. The U.S. Government has rights to use,
8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
11 | modified to produce derivative works, such modified software should
12 | be clearly marked, so as not to confuse it with the version available
13 | from ANL.
14 |
15 | Additionally, redistribution and use in source and binary forms, with
16 | or without modification, are permitted provided that the following
17 | conditions are met:
18 |
19 | * Redistributions of source code must retain the above copyright
20 | notice, this list of conditions and the following disclaimer.
21 |
22 | * Redistributions in binary form must reproduce the above copyright
23 | notice, this list of conditions and the following disclaimer in
24 | the documentation and/or other materials provided with the
25 | distribution.
26 |
27 | * Neither the name of UChicago Argonne, LLC, Argonne National
28 | Laboratory, ANL, the U.S. Government, nor the names of its
29 | contributors may be used to endorse or promote products derived
30 | from this software without specific prior written permission.
31 |
32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
43 | POSSIBILITY OF SUCH DAMAGE.
44 | """
45 |
46 | ### Initial Author <2024>: Xiangyu Yin
47 |
48 | import numpy as np
49 | from typing import List, Dict, Union
50 | import pytest
51 | from nodeology.state import (
52 | process_state_definitions,
53 | _resolve_state_type,
54 | _type_from_str,
55 | )
56 |
57 |
58 | class TestTypeResolution:
59 | """Tests for basic type resolution functionality"""
60 |
61 | def test_primitive_types(self):
62 | """Test resolution of primitive types"""
63 | assert _resolve_state_type("str") == str
64 | assert _resolve_state_type("int") == int
65 | assert _resolve_state_type("float") == float
66 | assert _resolve_state_type("bool") == bool
67 | assert _resolve_state_type("ndarray") == np.ndarray
68 |
69 | def test_list_types(self):
70 | """Test resolution of List types"""
71 | assert _resolve_state_type("List[str]") == List[str]
72 | assert _resolve_state_type("List[int]") == List[int]
73 | assert _resolve_state_type("List[bool]") == List[bool]
74 |
75 | def test_dict_types(self):
76 | """Test resolution of Dict types"""
77 | assert _resolve_state_type("Dict[str, int]") == Dict[str, int]
78 | assert _resolve_state_type("Dict[str, List[int]]") == Dict[str, List[int]]
79 | assert (
80 | _resolve_state_type("Dict[str, Dict[str, bool]]")
81 | == Dict[str, Dict[str, bool]]
82 | )
83 |
84 | def test_nested_types(self):
85 | """Test resolution of deeply nested types"""
86 | complex_type = "Dict[str, List[Dict[str, Union[int, str]]]]"
87 | expected = Dict[str, List[Dict[str, Union[int, str]]]]
88 | assert _resolve_state_type(complex_type) == expected
89 |
90 | def test_numpy_composite_types(self):
91 | """Test resolution of composite types involving numpy arrays"""
92 | assert _resolve_state_type("List[ndarray]") == List[np.ndarray]
93 | assert _resolve_state_type("Dict[str, ndarray]") == Dict[str, np.ndarray]
94 | assert (
95 | _resolve_state_type("Dict[str, List[ndarray]]")
96 | == Dict[str, List[np.ndarray]]
97 | )
98 | assert _resolve_state_type("Union[ndarray, int]") == Union[np.ndarray, int]
99 |
100 | def test_type_conversion_symmetry(self):
101 | """Test that type conversion is symmetrical"""
102 | test_cases = [
103 | str,
104 | int,
105 | float,
106 | bool,
107 | List[str],
108 | List[int],
109 | Dict[str, int],
110 | Dict[str, List[str]],
111 | Union[str, int],
112 | Union[str, List[int]],
113 | np.ndarray,
114 | List[np.ndarray],
115 | Dict[str, np.ndarray],
116 | Dict[str, List[np.ndarray]],
117 | Union[np.ndarray, int],
118 | Union[List[np.ndarray], Dict[str, np.ndarray]],
119 | ]
120 |
121 | for type_obj in test_cases:
122 | # Convert type to string
123 | type_str = _type_from_str(type_obj)
124 | # Convert string back to type
125 | resolved_type = _resolve_state_type(type_str)
126 | # Verify they're equivalent
127 | assert str(resolved_type) == str(
128 | type_obj
129 | ), f"Type conversion failed for {type_obj}"
130 |
131 |
132 | class TestErrorHandling:
133 | """Tests for error handling in type resolution"""
134 |
135 | def test_invalid_type_names(self):
136 | """Test handling of invalid type names"""
137 | with pytest.raises(ValueError, match="Unknown state type"):
138 | _resolve_state_type("InvalidType")
139 |
140 | with pytest.raises(ValueError, match="Unknown state type"):
141 | _resolve_state_type("List[InvalidType]")
142 |
143 | def test_malformed_type_strings(self):
144 | """Test handling of malformed type strings"""
145 | invalid_types = [
146 | "List[str", # Missing closing bracket
147 | "Dict[str]", # Missing value type
148 | "Dict[str,]", # Empty value type
149 | "Union[]", # Empty union
150 | "List[]", # Empty list type
151 | "[str]", # Invalid format
152 | ]
153 |
154 | for invalid_type in invalid_types:
155 | with pytest.raises(ValueError):
156 | _resolve_state_type(invalid_type)
157 |
158 | def test_invalid_dict_formats(self):
159 | """Test handling of invalid dictionary formats"""
160 | with pytest.raises(ValueError):
161 | _resolve_state_type("Dict")
162 |
163 | with pytest.raises(ValueError):
164 | _resolve_state_type("Dict[str, int, bool]")
165 |
166 |
167 | class TestStateDefinitionProcessing:
168 | """Tests for state definition processing"""
169 |
170 | def test_dict_state_definition(self):
171 | """Test processing of dictionary state definitions"""
172 | state_def = {"name": "test_field", "type": "str"}
173 | result = process_state_definitions([state_def], {})
174 | assert result == [("test_field", str)]
175 |
176 | def test_custom_type_processing(self):
177 | """Test processing with custom types in registry"""
178 |
179 | class CustomType:
180 | pass
181 |
182 | registry = {"CustomType": CustomType}
183 |
184 | # Test direct custom type reference
185 | assert process_state_definitions(["CustomType"], registry) == [CustomType]
186 |
187 | # Test custom type in dictionary definition
188 | state_def = {
189 | "name": "custom_field",
190 | "type": "str",
191 | } # Can't use CustomType directly in type string
192 | result = process_state_definitions([state_def], registry)
193 | assert result == [("custom_field", str)]
194 |
195 | def test_list_state_definition(self):
196 | """Test processing of list format state definitions"""
197 | # Single list definition
198 | state_def = ["test_field", "List[int]"]
199 | result = process_state_definitions([state_def], {})
200 | assert result == [("test_field", List[int])]
201 |
202 | # Multiple list definitions
203 | state_defs = [["field1", "str"], ["field2", "List[int]"]]
204 | result = process_state_definitions(state_defs, {})
205 | assert result == [("field1", str), ("field2", List[int])]
206 |
207 | def test_process_mixed_definitions(self):
208 | """Test processing mixed format state definitions"""
209 | state_defs = [
210 | {"name": "field1", "type": "str"},
211 | ["field2", "List[int]"],
212 | {"name": "field3", "type": "Dict[str, bool]"},
213 | ]
214 | result = process_state_definitions(state_defs, {})
215 | assert result == [
216 | ("field1", str),
217 | ("field2", List[int]),
218 | ("field3", Dict[str, bool]),
219 | ]
220 |
221 | def test_numpy_array_state_definition(self):
222 | """Test processing of numpy array state definitions"""
223 | # Test direct ndarray type
224 | state_def = {"name": "array_field", "type": "ndarray"}
225 | result = process_state_definitions([state_def], {})
226 | assert result == [("array_field", np.ndarray)]
227 |
228 | # Test in list format
229 | state_def = ["array_field2", "ndarray"]
230 | result = process_state_definitions([state_def], {})
231 | assert result == [("array_field2", np.ndarray)]
232 |
233 | # Test in mixed definitions
234 | state_defs = [
235 | {"name": "field1", "type": "str"},
236 | ["array_field", "ndarray"],
237 | {"name": "field3", "type": "Dict[str, bool]"},
238 | ]
239 | result = process_state_definitions(state_defs, {})
240 | assert result == [
241 | ("field1", str),
242 | ("array_field", np.ndarray),
243 | ("field3", Dict[str, bool]),
244 | ]
245 |
246 | def test_numpy_composite_state_definition(self):
247 | """Test processing of composite state definitions with numpy arrays"""
248 | state_defs = [
249 | {"name": "array_list", "type": "List[ndarray]"},
250 | {"name": "array_dict", "type": "Dict[str, ndarray]"},
251 | ["nested_arrays", "Dict[str, List[ndarray]]"],
252 | {"name": "mixed_type", "type": "Union[ndarray, int]"},
253 | ]
254 |
255 | result = process_state_definitions(state_defs, {})
256 | assert result == [
257 | ("array_list", List[np.ndarray]),
258 | ("array_dict", Dict[str, np.ndarray]),
259 | ("nested_arrays", Dict[str, List[np.ndarray]]),
260 | ("mixed_type", Union[np.ndarray, int]),
261 | ]
262 |
--------------------------------------------------------------------------------
/examples/public/logo_light.svg:
--------------------------------------------------------------------------------
1 | [SVG markup stripped from this dump; the logo's only embedded text reads "LangGraph Nodeology"]
--------------------------------------------------------------------------------
/examples/trajectory_analysis.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
3 |
4 | Copyright 2024. UChicago Argonne, LLC. This software was produced
5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
7 | U.S. Department of Energy. The U.S. Government has rights to use,
8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
11 | modified to produce derivative works, such modified software should
12 | be clearly marked, so as not to confuse it with the version available
13 | from ANL.
14 |
15 | Additionally, redistribution and use in source and binary forms, with
16 | or without modification, are permitted provided that the following
17 | conditions are met:
18 |
19 | * Redistributions of source code must retain the above copyright
20 | notice, this list of conditions and the following disclaimer.
21 |
22 | * Redistributions in binary form must reproduce the above copyright
23 | notice, this list of conditions and the following disclaimer in
24 | the documentation and/or other materials provided with the
25 | distribution.
26 |
27 | * Neither the name of UChicago Argonne, LLC, Argonne National
28 | Laboratory, ANL, the U.S. Government, nor the names of its
29 | contributors may be used to endorse or promote products derived
30 | from this software without specific prior written permission.
31 |
32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
43 | POSSIBILITY OF SUCH DAMAGE.
44 | """
45 |
46 | ### Initial Author <2025>: Xiangyu Yin
47 |
48 | import os, json
49 | import tempfile
50 | import numpy as np
51 | from scipy.integrate import solve_ivp
52 | from typing import List, Dict
53 | from langgraph.graph import END
54 | import chainlit as cl
55 | from chainlit import Message, AskUserMessage, AskActionMessage, run_sync
56 | from nodeology.state import State
57 | from nodeology.node import Node, as_node
58 | from nodeology.workflow import Workflow
59 | import plotly.graph_objects as go
60 |
61 |
62 | class TrajectoryState(State):
63 | """State for particle trajectory analysis workflow"""
64 |
65 | # Parameters
66 | mass: float # Particle mass (kg)
67 | charge: float # Particle charge (C)
68 | initial_velocity: np.ndarray # Initial velocity vector [vx, vy, vz]
69 | E_field: np.ndarray # Electric field vector [Ex, Ey, Ez]
70 | B_field: np.ndarray # Magnetic field vector [Bx, By, Bz]
71 |
72 | # Confirm parameters
73 | confirm_parameters: bool
74 |
75 | # Parameters updater
76 | parameters_updater_output: str
77 |
78 | # Calculation results
79 | positions: List[np.ndarray] # Position vectors at each time point
80 |
81 | # Image
82 | trajectory_plot: str
83 | trajectory_plot_path: str
84 |
85 | # Analysis results
86 | analysis_result: Dict
87 |
88 | # Continue simulation
89 | continue_simulation: bool
90 |
91 |
92 | @as_node(sink=[])
93 | def display_parameters(
94 | mass: float,
95 | charge: float,
96 | initial_velocity: np.ndarray,
97 | E_field: np.ndarray,
98 | B_field: np.ndarray,
99 | ):
100 | # Create a dictionary of parameters for the custom element
101 | parameters = {
102 | "Mass (kg)": mass,
103 | "Charge (C)": charge,
104 | "Initial Velocity (m/s)": initial_velocity.tolist(),
105 | "Electric Field (N/C)": E_field.tolist(),
106 | "Magnetic Field (T)": B_field.tolist(),
107 | }
108 |
109 | # Use the custom element to display parameters
110 | run_sync(
111 | Message(
112 | content="Below are the current simulation parameters:",
113 | elements=[
114 | cl.CustomElement(
115 | name="DataDisplay",
116 | props={
117 | "data": parameters,
118 | "title": "Particle Parameters",
119 | "badge": "Configured",
120 | "showScrollArea": False,
121 | },
122 | )
123 | ],
124 | ).send()
125 | )
126 | return
127 |
128 |
129 | @as_node(sink="confirm_parameters")
130 | def ask_confirm_parameters():
131 | res = run_sync(
132 | AskActionMessage(
133 | content="Are you happy with the parameters?",
134 | timeout=300,
135 | actions=[
136 | cl.Action(
137 | name="yes",
138 | payload={"value": "yes"},
139 | label="Yes",
140 | ),
141 | cl.Action(
142 | name="no",
143 | payload={"value": "no"},
144 | label="No",
145 | ),
146 | ],
147 | ).send()
148 | )
149 | if res and res.get("payload").get("value") == "yes":
150 | return True
151 | else:
152 | return False
153 |
154 |
155 | @as_node(sink=["human_input"])
156 | def ask_parameters_input():
157 |     res = run_sync(
158 |         AskUserMessage(
159 |             content="Please let me know how you want to change any of the parameters :)",
160 |             timeout=300,
161 |         ).send()
162 |     )
163 |     return res["output"] if res else ""  # AskUserMessage returns None on timeout
164 |
165 |
166 | parameters_updater = Node(
167 | node_type="parameters_updater",
168 | prompt_template="""Update the parameters based on the user's input.
169 |
170 | Current parameters:
171 | mass: {mass}
172 | charge: {charge}
173 | initial_velocity: {initial_velocity}
174 | E_field: {E_field}
175 | B_field: {B_field}
176 |
177 | User input:
178 | {human_input}
179 |
180 | Please return the updated parameters in JSON format.
181 | {{
182 | "mass": float,
183 | "charge": float,
184 | "initial_velocity": list[float],
185 | "E_field": list[float],
186 | "B_field": list[float]
187 | }}
188 | """,
189 | sink="parameters_updater_output",
190 | sink_format="json",
191 | )
192 |
193 |
194 | def parameters_updater_transform(state, client, **kwargs):
195 | params_dict = json.loads(state["parameters_updater_output"])
196 | state["mass"] = params_dict["mass"]
197 | state["charge"] = params_dict["charge"]
198 | state["initial_velocity"] = np.array(params_dict["initial_velocity"])
199 | state["E_field"] = np.array(params_dict["E_field"])
200 | state["B_field"] = np.array(params_dict["B_field"])
201 | return state
202 |
203 |
204 | parameters_updater.post_process = parameters_updater_transform
205 |
206 |
207 | @as_node(sink=["positions"])
208 | def calculate_trajectory(
209 | mass: float,
210 | charge: float,
211 | initial_velocity: np.ndarray,
212 | E_field: np.ndarray,
213 | B_field: np.ndarray,
214 | ) -> List[np.ndarray]:
215 | """Calculate particle trajectory under Lorentz force with automatic time steps"""
216 | B_magnitude = np.linalg.norm(B_field)
217 | if B_magnitude == 0 or charge == 0:
218 | # Handle the case where B=0 or charge=0 (no magnetic force)
219 | cyclotron_period = 1e-6 # Arbitrary time scale
220 | else:
221 | cyclotron_frequency = np.abs(charge) * B_magnitude / mass
222 | cyclotron_period = 2 * np.pi / cyclotron_frequency
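        # Worked example (editorial): for an electron (m ≈ 9.11e-31 kg,
        # |q| ≈ 1.60e-19 C) in a 1 T field, T = 2*pi*m/(|q|*B) ≈ 3.6e-11 s,
        # so the 5 periods sampled below span roughly 0.18 ns.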
223 |
224 | # Determine total simulation time and time steps
225 | num_periods = 5 # Simulate over 5 cyclotron periods
226 | num_points_per_period = 100 # At least 100 points per period
227 | total_time = num_periods * cyclotron_period
228 | total_points = int(num_periods * num_points_per_period)
229 | time_points = np.linspace(0, total_time, total_points)
230 |
231 | def lorentz_force(t, state):
232 | """Compute acceleration from Lorentz force"""
233 | vel = state[3:]
234 | force = charge * (E_field + np.cross(vel, B_field))
235 | acc = force / mass
236 | return np.concatenate([vel, acc])
237 |
238 | # Initial state vector [x, y, z, vx, vy, vz]
239 | initial_position = np.array([0.0, 0.0, 0.0])
240 | initial_state = np.concatenate([initial_position, initial_velocity])
241 |
242 | # Solve equations of motion
243 | solution = solve_ivp(
244 | lorentz_force,
245 | (time_points[0], time_points[-1]),
246 | initial_state,
247 | t_eval=time_points,
248 | method="RK45",
249 | rtol=1e-8,
250 | )
251 |
252 | if not solution.success:
253 | return [np.zeros(3) for _ in range(len(time_points))]
254 |
255 | return [solution.y[:3, i] for i in range(len(time_points))]
256 |
257 |
258 | @as_node(sink=["trajectory_plot", "trajectory_plot_path"])
259 | def plot_trajectory(positions: List[np.ndarray]) -> tuple:
260 | """Plot 3D particle trajectory and save to temp file
261 |
262 | Returns:
263 | tuple: (Plotly figure object, path to saved plot image)
264 | """
265 | positions = np.array(positions)
266 |
267 | # Create a Plotly 3D scatter plot
268 | fig = go.Figure(
269 | data=[
270 | go.Scatter3d(
271 | x=positions[:, 0],
272 | y=positions[:, 1],
273 | z=positions[:, 2],
274 | mode="lines",
275 | line=dict(width=4, color="green"),
276 | )
277 | ]
278 | )
279 |
280 | # Update layout
281 | fig.update_layout(
282 | scene=dict(xaxis_title="X (m)", yaxis_title="Y (m)", zaxis_title="Z (m)"),
283 | )
284 |
285 |     fd, temp_path = tempfile.mkstemp(suffix=".png")  # mkstemp avoids the deprecated mktemp
286 |     os.close(fd)  # release the OS-level handle; write_image reopens the file by path
287 |     fig.write_image(temp_path)
288 |
289 | run_sync(
290 | Message(
291 | content="Below is the trajectory plot of the particle:",
292 | elements=[cl.Plotly(figure=fig)],
293 | ).send()
294 | )
295 |
296 | return fig, temp_path
297 |
298 |
299 | trajectory_analyzer = Node(
300 | node_type="trajectory_analyzer",
301 | prompt_template="""Analyze this particle trajectory plot.
302 |
303 | Please determine:
304 | 1. The type of motion (linear, circular, helical, or chaotic)
305 | 2. Key physical features (radius, period, pitch angle if applicable)
306 | 3. Explanation of the motion
307 | 4. Anomalies in the motion
308 | Output in JSON format:
309 | {{
310 | "trajectory_type": "type_name",
311 |     "key_features": {{
312 |         "feature1": value,
313 |         "feature2": value
314 |     }},
315 | "explanation": "detailed explanation",
316 | "anomalies": "anomaly description"
317 | }}""",
318 | sink="analysis_result",
319 | sink_format="json",
320 | image_keys=["trajectory_plot_path"],
321 | )
322 |
323 |
324 | def display_trajectory_analyzer_result(state, client, **kwargs):
325 | state["analysis_result"] = json.loads(state["analysis_result"])
326 |
327 | # Use the custom element to display analysis results
328 | run_sync(
329 | Message(
330 | content="Below is the analysis of the particle trajectory:",
331 | elements=[
332 | cl.CustomElement(
333 | name="DataDisplay",
334 | props={
335 | "data": state["analysis_result"],
336 | "title": "Trajectory Analysis",
337 | "badge": state["analysis_result"].get(
338 | "trajectory_type", "Unknown"
339 | ),
340 | "maxHeight": "400px",
341 | },
342 | )
343 | ],
344 | ).send()
345 | )
346 | return state
347 |
348 |
349 | trajectory_analyzer.post_process = display_trajectory_analyzer_result
350 |
351 |
352 | @as_node(sink="continue_simulation")
353 | def ask_continue_simulation():
354 | res = run_sync(
355 | AskActionMessage(
356 | content="Would you like to continue the simulation?",
357 | timeout=300,
358 | actions=[
359 | cl.Action(
360 | name="continue",
361 | payload={"value": "continue"},
362 | label="Continue Simulation",
363 | ),
364 | cl.Action(
365 | name="finish",
366 | payload={"value": "finish"},
367 | label="Finish",
368 | ),
369 | ],
370 | ).send()
371 | )
372 |
373 | # Return the user's choice
374 | if res and res.get("payload").get("value") == "continue":
375 | return True
376 | else:
377 | return False
378 |
379 |
380 | class TrajectoryWorkflow(Workflow):
381 | """Workflow for particle trajectory analysis"""
382 |
383 | def create_workflow(self):
384 | """Define the workflow graph structure"""
385 | # Add nodes
386 | self.add_node("display_parameters", display_parameters)
387 | self.add_node("ask_confirm_parameters", ask_confirm_parameters)
388 | self.add_node("ask_parameters_input", ask_parameters_input)
389 | self.add_node("update_parameters", parameters_updater)
390 | self.add_node("calculate_trajectory", calculate_trajectory)
391 | self.add_node("plot_trajectory", plot_trajectory)
392 | self.add_node("analyze_trajectory", trajectory_analyzer)
393 | self.add_node("ask_continue_simulation", ask_continue_simulation)
394 |
395 | self.add_flow("display_parameters", "ask_confirm_parameters")
396 | self.add_conditional_flow(
397 | "ask_confirm_parameters",
398 | "confirm_parameters",
399 | then="calculate_trajectory",
400 | otherwise="ask_parameters_input",
401 | )
402 | self.add_flow("ask_parameters_input", "update_parameters")
403 | self.add_flow("update_parameters", "display_parameters")
404 | self.add_flow("calculate_trajectory", "plot_trajectory")
405 | self.add_flow("plot_trajectory", "analyze_trajectory")
406 | self.add_flow("analyze_trajectory", "ask_continue_simulation")
407 | self.add_conditional_flow(
408 | "ask_continue_simulation",
409 | "continue_simulation",
410 | then="display_parameters",
411 | otherwise=END,
412 | )
413 |
414 | # Set entry point
415 | self.set_entry("display_parameters")
416 |
417 | # Compile workflow
418 | self.compile()
419 |
420 |
421 | if __name__ == "__main__":
422 | workflow = TrajectoryWorkflow(
423 | state_defs=TrajectoryState,
424 | llm_name="gemini/gemini-2.0-flash",
425 | vlm_name="gemini/gemini-2.0-flash",
426 | debug_mode=False,
427 | )
428 |
429 | # # Export workflow to YAML file
430 | # workflow.to_yaml("particle_trajectory_analysis.yaml")
431 |
432 | # # Print workflow graph
433 | # workflow.graph.get_graph().draw_mermaid_png(
434 | # output_file_path="particle_trajectory_analysis.png"
435 | # )
436 |
437 | initial_state = {
438 | "mass": 9.1093837015e-31, # electron mass in kg
439 | "charge": -1.602176634e-19, # electron charge in C
440 | "initial_velocity": np.array([1e6, 1e6, 1e6]), # 1e6 m/s in each direction
441 |         "E_field": np.array([5e6, 1e6, 5e6]),  # E-field components in N/C
442 |         "B_field": np.array(
443 |             [0.0, 0.0, 50000.0]
444 |         ),  # deliberately mistyped magnitude, to be caught by validation
445 | }
446 |
447 | result = workflow.run(init_values=initial_state, ui=True)
448 |
--------------------------------------------------------------------------------
/nodeology/client.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
3 |
4 | Copyright 2024. UChicago Argonne, LLC. This software was produced
5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
7 | U.S. Department of Energy. The U.S. Government has rights to use,
8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
11 | modified to produce derivative works, such modified software should
12 | be clearly marked, so as not to confuse it with the version available
13 | from ANL.
14 |
15 | Additionally, redistribution and use in source and binary forms, with
16 | or without modification, are permitted provided that the following
17 | conditions are met:
18 |
19 | * Redistributions of source code must retain the above copyright
20 | notice, this list of conditions and the following disclaimer.
21 |
22 | * Redistributions in binary form must reproduce the above copyright
23 | notice, this list of conditions and the following disclaimer in
24 | the documentation and/or other materials provided with the
25 | distribution.
26 |
27 | * Neither the name of UChicago Argonne, LLC, Argonne National
28 | Laboratory, ANL, the U.S. Government, nor the names of its
29 | contributors may be used to endorse or promote products derived
30 | from this software without specific prior written permission.
31 |
32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
43 | POSSIBILITY OF SUCH DAMAGE.
44 | """
45 |
46 | ### Initial Author <2024>: Xiangyu Yin
47 |
48 | import os, base64, json, getpass
49 | from abc import ABC, abstractmethod
50 | import litellm
51 | from datetime import datetime
52 |
53 |
54 | def get_client(model_name, **kwargs):
55 | """
56 | Factory function to create appropriate client based on model name.
57 |
58 | Handles three scenarios:
59 | 1. Just model name (e.g., "gpt-4o") - Let LiteLLM figure out the provider
60 | 2. Model name with provider keyword (e.g., model="gpt-4o", provider="openai")
61 | 3. Provider/name convention (e.g., "openai/gpt-4o")
62 |
63 | Args:
64 | model_name (str): Name of the model to use
65 | **kwargs: Additional arguments including optional 'provider'
66 |
67 | Returns:
68 | LLM_Client or VLM_Client: Appropriate client instance for the requested model
69 | """
70 | # Handle special clients first
71 | if model_name == "mock":
72 | return Mock_LLM_Client(**kwargs)
73 | elif model_name == "mock_vlm":
74 | return Mock_VLM_Client(**kwargs)
75 |
76 | # Get provider from kwargs if specified (Scenario 2)
77 | provider = kwargs.pop("provider", None)
78 |
79 | # Get tracing_enabled from kwargs
80 | tracing_enabled = kwargs.pop("tracing_enabled", False)
81 |
82 | # Handle provider/model format (Scenario 3)
83 | if "/" in model_name and provider is None:
84 | provider, model_name = model_name.split("/", 1)
85 |
86 | # Create LiteLLM client - for Scenario 1, provider will be None
87 | try:
88 | return LiteLLM_Client(
89 | model_name, provider=provider, tracing_enabled=tracing_enabled, **kwargs
90 | )
91 | except Exception as e:
92 | raise ValueError(f"Error creating client for model {model_name}: {e}")
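# Usage sketch (editorial): the three naming scenarios from the docstring above;
# the model and provider names are illustrative placeholders.
#
#     client = get_client("gpt-4o")                     # Scenario 1: provider inferred
#     client = get_client("gpt-4o", provider="openai")  # Scenario 2: provider keyword
#     client = get_client("openai/gpt-4o")              # Scenario 3: provider/name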
93 |
94 |
95 | def configure_langfuse(public_key=None, secret_key=None, host=None, enabled=True):
96 | """
97 | Configure Langfuse for observability.
98 |
99 | Args:
100 | public_key (str, optional): Langfuse public key. Defaults to LANGFUSE_PUBLIC_KEY env var.
101 | secret_key (str, optional): Langfuse secret key. Defaults to LANGFUSE_SECRET_KEY env var.
102 | host (str, optional): Langfuse host URL. Defaults to LANGFUSE_HOST env var or https://cloud.langfuse.com.
103 | enabled (bool, optional): Whether to enable Langfuse tracing. Defaults to True.
104 | """
105 | if not enabled:
106 | litellm.success_callback = []
107 | litellm.failure_callback = []
108 | return
109 |
110 | # Set environment variables if provided
111 | if public_key:
112 | os.environ["LANGFUSE_PUBLIC_KEY"] = public_key
113 | if secret_key:
114 | os.environ["LANGFUSE_SECRET_KEY"] = secret_key
115 | if host:
116 | os.environ["LANGFUSE_HOST"] = host
117 |
118 | litellm.success_callback = ["langfuse"]
119 | litellm.failure_callback = ["langfuse"]
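# Usage sketch (editorial): the key values are placeholders; called with no
# arguments, the function relies on the LANGFUSE_* environment variables.
#
#     configure_langfuse(public_key="pk-...", secret_key="sk-...")  # enable tracing
#     configure_langfuse(enabled=False)                             # clear callbacks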
120 |
121 |
122 | class LLM_Client(ABC):
123 | """Base abstract class for Language Model clients."""
124 |
125 | def __init__(self) -> None:
126 | pass
127 |
128 | @abstractmethod
129 | def __call__(self, messages, **kwargs) -> str:
130 | """
131 | Process messages and return model response.
132 |
133 | Args:
134 | messages (list): List of message dictionaries with 'role' and 'content'
135 | **kwargs: Additional model-specific parameters
136 |
137 | Returns:
138 | str: Model's response text
139 | """
140 | pass
141 |
142 |
143 | class VLM_Client(LLM_Client):
144 | """Base abstract class for Vision Language Model clients."""
145 |
146 | def __init__(self) -> None:
147 | super().__init__()
148 |
149 | @abstractmethod
150 | def process_images(self, messages, images, **kwargs) -> list:
151 | """
152 | Process and format images for the model.
153 |
154 | Args:
155 | messages (list): List of message dictionaries
156 | images (list): List of image file paths
157 | **kwargs: Additional processing parameters
158 |
159 | Returns:
160 | list: Updated messages with processed images
161 | """
162 | pass
163 |
164 |
165 | class Mock_LLM_Client(LLM_Client):
166 | def __init__(self, response=None, **kwargs) -> None:
167 | super().__init__()
168 | self.response = response
169 | self.model_name = "mock"
170 |
171 | def __call__(self, messages, **kwargs) -> str:
172 | response = (
173 | "\n".join([msg["role"] + ": " + msg["content"] for msg in messages])
174 | if self.response is None
175 | else self.response
176 | )
177 | return response
178 |
179 |
180 | class Mock_VLM_Client(VLM_Client):
181 | def __init__(self, response=None, **kwargs) -> None:
182 | super().__init__()
183 | self.response = response
184 | self.model_name = "mock_vlm"
185 |
186 | def __call__(self, messages, images=None, **kwargs) -> str:
187 | if images is not None:
188 | messages = self.process_images(messages, images)
189 | if self.response is None:
190 | message_parts = []
191 | for msg in messages:
192 | content = msg["content"]
193 | if isinstance(content, str):
194 | message_parts.append(f"{msg['role']}: {content}")
195 | else: # content is already a list of text/image objects
196 | parts = []
197 | for item in content:
198 | if item["type"] == "text":
199 | parts.append(item["text"])
200 | elif item["type"] == "image":
201 | parts.append(f"[Image: {item['image_url']['url']}]")
202 | message_parts.append(f"{msg['role']}: {' '.join(parts)}")
203 | return "\n".join(message_parts)
204 | return self.response
205 |
206 | def process_images(self, messages, images, **kwargs) -> list:
207 |         # Copy the message dicts so the caller's list is not mutated
208 |         messages = [dict(m) for m in messages]
209 |
210 | # Simply append a placeholder for each image
211 | for img in images:
212 | if isinstance(messages[-1]["content"], str):
213 | messages[-1]["content"] = [
214 | {"type": "text", "text": messages[-1]["content"]},
215 | {"type": "image", "image_url": {"url": f"mock_processed_{img}"}},
216 | ]
217 | elif isinstance(messages[-1]["content"], list):
218 |             messages[-1]["content"] = messages[-1]["content"] + [
219 |                 {"type": "image", "image_url": {"url": f"mock_processed_{img}"}}
220 |             ]
221 | return messages
222 |
223 |
224 | class LiteLLM_Client(VLM_Client):
225 | """
226 | Unified client for all LLM/VLM providers using LiteLLM.
227 | Supports both text and image inputs across multiple providers.
228 | """
229 |
230 | def __init__(
231 | self,
232 | model_name,
233 | provider=None,
234 | model_options=None,
235 | api_key=None,
236 | tracing_enabled=False,
237 | ) -> None:
238 | """
239 | Initialize LiteLLM client.
240 |
241 | Args:
242 | model_name (str): Name of the model to use
243 | provider (str, optional): Provider name (openai, anthropic, etc.)
244 | model_options (dict): Model parameters like temperature and top_p
245 | api_key (str, optional): API key for the specified provider
246 | tracing_enabled (bool, optional): Whether to enable Langfuse tracing. Defaults to False.
247 | """
248 | super().__init__()
249 | self.model_options = model_options if model_options else {}
250 | self.tracing_enabled = tracing_enabled
251 |
252 | # Set API key if provided
253 | if api_key and provider:
254 | os.environ[f"{provider.upper()}_API_KEY"] = api_key
255 |
256 | # Construct the model name for LiteLLM based on whether provider is specified
257 | # If provider is None, LiteLLM will infer the provider from the model name
258 | self.model_name = f"{provider}/{model_name}" if provider else model_name
259 |
260 | def collect_langfuse_metadata(
261 | self,
262 | workflow=None,
263 | node=None,
264 | **kwargs,
265 | ):
266 | """
267 | Collect metadata for Langfuse tracing from workflow and node information.
268 |
269 | Args:
270 | workflow: The workflow instance (optional)
271 | node: The node instance (optional)
272 | **kwargs: Additional metadata to include
273 |
274 | Returns:
275 | dict: Metadata dictionary formatted for Langfuse
276 | """
277 | metadata = {}
278 |
279 | timestamp = datetime.now().strftime("%Y%m%d")
280 | user_id = getpass.getuser()
281 | session_id_str = f"{user_id}-{timestamp}"
282 |
283 | metadata["trace_metadata"] = {}
284 |
285 | # Extract workflow metadata if available
286 | if workflow:
287 | # Use workflow class name as generation name
288 | metadata["generation_name"] = workflow.__class__.__name__
289 | session_id_str += f"-{workflow.__class__.__name__}"
290 |
291 | # Create a generation ID based on workflow name and timestamp
292 | metadata["generation_id"] = f"gen-{workflow.name}-{timestamp}"
293 |
294 | # Add user ID if available
295 | if hasattr(workflow, "user_name"):
296 | metadata["trace_user_id"] = workflow.user_name
297 | else:
298 | metadata["trace_user_id"] = user_id
299 |
300 | # Extract node metadata if available
301 | if node:
302 | # Use node type as trace name
303 | metadata["trace_name"] = node.node_type
304 | session_id_str += f"-{node.node_type}"
305 |
306 | # Add node metadata to trace metadata
307 | metadata["trace_metadata"].update(
308 | {
309 | "required_keys": node.required_keys,
310 | "sink": node.sink,
311 | "sink_format": node.sink_format,
312 | "image_keys": node.image_keys,
313 | "use_conversation": node.use_conversation,
314 | "prompt_template": node.prompt_template,
315 | }
316 | )
317 |
318 | # Add session ID based on timestamp
319 | metadata["session_id"] = f"session-{session_id_str}"
320 |
321 | # Add any additional metadata from kwargs
322 | metadata["trace_metadata"].update(kwargs)
323 |
324 | return metadata
325 |
326 | def process_images(self, messages, images):
327 | """
328 | Process and format images for the model using LiteLLM's format.
329 |
330 | Args:
331 | messages (list): List of message dictionaries
332 | images (list): List of image file paths
333 |
334 | Returns:
335 | list: Updated messages with processed images
336 | """
337 |         # Copy the message dicts so the caller's list is not mutated
338 |         messages = [dict(m) for m in messages]
339 |
340 |         # Convert images to base64 (the MIME type is hardcoded as image/jpeg)
341 | image_contents = []
342 | for img in images:
343 | with open(img, "rb") as image_file:
344 | base64_image = base64.b64encode(image_file.read()).decode("utf-8")
345 | image_contents.append(
346 | {
347 | "type": "image_url",
348 | "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
349 | }
350 | )
351 |
352 | # Add images to the last message
353 | if isinstance(messages[-1]["content"], str):
354 | messages[-1]["content"] = [
355 | {"type": "text", "text": messages[-1]["content"]}
356 | ] + image_contents
357 | elif isinstance(messages[-1]["content"], list):
358 |             messages[-1]["content"] = messages[-1]["content"] + image_contents
359 |
360 | return messages
361 |
362 | def __call__(
363 | self, messages, images=None, format=None, workflow=None, node=None, **kwargs
364 | ) -> str:
365 | """
366 | Process messages and return model response using LiteLLM.
367 |
368 | Args:
369 | messages (list): List of message dictionaries
370 | images (list, optional): List of image file paths
371 | format (str, optional): Response format (e.g., 'json')
372 | workflow (optional): The workflow instance for metadata extraction
373 | node (optional): The node instance for metadata extraction
374 | **kwargs: Additional parameters including metadata for Langfuse
375 |
376 | Returns:
377 | str: Model's response text
378 | """
379 | # Process images if provided
380 | if images is not None:
381 | messages = self.process_images(messages, images)
382 |
383 | # Set up response format if needed
384 | response_format = {"type": "json_object"} if format == "json" else None
385 |
386 | # Extract Langfuse metadata only if tracing is enabled
387 | langfuse_metadata = {}
388 | if self.tracing_enabled:
389 | langfuse_metadata = self.collect_langfuse_metadata(
390 | workflow=workflow,
391 | node=node,
392 | **kwargs,
393 | )
394 |
395 | try:
396 | # Use LiteLLM's built-in retry mechanism with Langfuse metadata
397 | response = litellm.completion(
398 | model=self.model_name,
399 | messages=messages,
400 | response_format=response_format,
401 | num_retries=3,
402 | metadata=langfuse_metadata if self.tracing_enabled else {},
403 | **self.model_options,
404 | )
405 |
406 | content = response.choices[0].message.content
407 |
408 | # Validate JSON if requested
409 | if format == "json":
410 | try:
411 | json.loads(content)
412 | except json.JSONDecodeError:
413 | raise ValueError(f"Invalid JSON response from {self.model_name}")
414 |
415 | return content
416 |
417 | except Exception as e:
418 | raise ValueError(
419 | f"Failed to generate response from {self.model_name}. Error: {str(e)}"
420 | )
421 |
--------------------------------------------------------------------------------
/nodeology/state.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
3 |
4 | Copyright 2024. UChicago Argonne, LLC. This software was produced
5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
7 | U.S. Department of Energy. The U.S. Government has rights to use,
8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
11 | modified to produce derivative works, such modified software should
12 | be clearly marked, so as not to confuse it with the version available
13 | from ANL.
14 |
15 | Additionally, redistribution and use in source and binary forms, with
16 | or without modification, are permitted provided that the following
17 | conditions are met:
18 |
19 | * Redistributions of source code must retain the above copyright
20 | notice, this list of conditions and the following disclaimer.
21 |
22 | * Redistributions in binary form must reproduce the above copyright
23 | notice, this list of conditions and the following disclaimer in
24 | the documentation and/or other materials provided with the
25 | distribution.
26 |
27 | * Neither the name of UChicago Argonne, LLC, Argonne National
28 | Laboratory, ANL, the U.S. Government, nor the names of its
29 | contributors may be used to endorse or promote products derived
30 | from this software without specific prior written permission.
31 |
32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
43 | POSSIBILITY OF SUCH DAMAGE.
44 | """
45 |
46 | ### Initial Author <2024>: Xiangyu Yin
47 |
48 | import json
49 | import logging
50 | import numpy as np
51 | import plotly.graph_objects as go
52 | import plotly.io as pio
53 | from typing import TypedDict, List, Dict, Union, Any, get_origin, get_args
54 | from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
55 | import msgpack
56 |
57 | logger = logging.getLogger(__name__)
58 |
59 | StateBaseT = Union[str, int, float, bool, np.ndarray]
60 |
61 | """
62 | State management module for nodeology.
63 | Handles type definitions, state processing, and state registry management.
64 | """
65 |
66 |
67 | class State(TypedDict):
68 | """
69 | Base state class representing the core state structure.
70 | Contains node information, input/output data, and message history.
71 | """
72 |
73 | current_node_type: str
74 | previous_node_type: str
75 | human_input: str
76 | input: str
77 | output: str
78 | messages: List[dict]
79 |
80 |
81 | def _split_by_top_level_comma(s: str) -> List[str]:
82 | """Helper function to split by comma while respecting brackets"""
83 | parts = []
84 | current = []
85 | bracket_count = 0
86 |
87 | for char in s:
88 | if char == "[":
89 | bracket_count += 1
90 | elif char == "]":
91 | bracket_count -= 1
92 | elif char == "," and bracket_count == 0:
93 | parts.append("".join(current).strip())
94 | current = []
95 | continue
96 | current.append(char)
97 |
98 | if current:
99 | parts.append("".join(current).strip())
100 | return parts
101 |
102 |
103 | def _resolve_state_type(type_str: str):
104 |     """
105 |     Resolve string representations of types to actual Python types (memoized).
106 |     """
107 |     if not hasattr(_resolve_state_type, "_cache"):
108 |         _resolve_state_type._cache = {}
109 |
110 |     if type_str in _resolve_state_type._cache:
111 |         return _resolve_state_type._cache[type_str]
112 |
113 |     try:
114 |         # Handle basic types
115 |         if type_str in (
116 |             "str",
117 |             "int",
118 |             "float",
119 |             "bool",
120 |             "dict",
121 |             "list",
122 |             "bytes",
123 |             "tuple",
124 |             "ndarray",
125 |         ):
126 |             resolved = np.ndarray if type_str == "ndarray" else eval(type_str)
127 |
128 |         elif type_str.startswith("List[") and type_str.endswith("]"):
129 |             inner_type = type_str[5:-1]
130 |             resolved = List[_resolve_state_type(inner_type)]
131 |
132 |         elif type_str.startswith("Dict[") and type_str.endswith("]"):
133 |             inner_str = type_str[5:-1]
134 |             parts = _split_by_top_level_comma(inner_str)
135 |             if len(parts) != 2:
136 |                 raise ValueError(f"Invalid Dict type format: {type_str}")
137 |
138 |             key_type = _resolve_state_type(parts[0])
139 |             value_type = _resolve_state_type(parts[1])
140 |             resolved = Dict[key_type, value_type]
141 |
142 |         elif type_str.startswith("Union[") and type_str.endswith("]"):
143 |             inner_str = type_str[6:-1]
144 |             types = [
145 |                 _resolve_state_type(t) for t in _split_by_top_level_comma(inner_str)
146 |             ]
147 |             resolved = Union[tuple(types)]
148 |
149 |         else:
150 |             raise ValueError(f"Unknown state type: {type_str}")
151 |
152 |     except Exception as e:
153 |         raise ValueError(f"Failed to resolve type '{type_str}': {str(e)}")
154 |     _resolve_state_type._cache[type_str] = resolved  # the cache was never filled before
155 |     return resolved
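# Examples (editorial), matching the behavior exercised in tests/test_state.py:
#
#     _resolve_state_type("List[str]")            # -> List[str]
#     _resolve_state_type("Dict[str, ndarray]")   # -> Dict[str, np.ndarray]
#     _resolve_state_type("Union[ndarray, int]")  # -> Union[np.ndarray, int]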
156 |
157 |
158 | def _process_dict_state_def(state_def: Dict) -> tuple:
159 | """
160 | Process a dictionary-format state definition.
161 |
162 | Supports both formats:
163 | - {'name': 'type'} format
164 | - {'name': str, 'type': str} format
165 |
166 | Args:
167 | state_def (Dict): Dictionary containing state definition
168 |
169 | Returns:
170 | tuple: (name, resolved_type)
171 |
172 | Raises:
173 | ValueError: If state definition is missing required fields
174 | """
175 | if len(state_def) == 1:
176 | # Handle {'name': 'type'} format
177 | name, type_str = next(iter(state_def.items()))
178 | else:
179 | # Handle {'name': str, 'type': str} format
180 | name = state_def.get("name")
181 | type_str = state_def.get("type")
182 | if not name or not type_str:
183 | raise ValueError(f"Invalid state definition: {state_def}")
184 |
185 | state_type = _resolve_state_type(type_str)
186 | return (name, state_type)
187 |
188 |
189 | def _process_list_state_def(state_def: List) -> List:
190 | """
191 | Process a list-format state definition.
192 |
193 | Supports two formats:
194 | 1. Single definition: [name, type_str]
195 | 2. Multiple definitions: [[name1, type_str1], [name2, type_str2], ...]
196 |
197 | Args:
198 | state_def (List): List containing state definitions
199 |
200 | Returns:
201 | List[tuple]: List of (name, resolved_type) tuples
202 |
203 | Raises:
204 | ValueError: If state definition format is invalid
205 | """
206 | if len(state_def) == 2 and isinstance(state_def[0], str):
207 | # Single list format [name, type_str]
208 | name, type_str = state_def
209 | state_type = _resolve_state_type(type_str)
210 | return [(name, state_type)]
211 | else:
212 | processed_lists = []
213 | for item in state_def:
214 | if not isinstance(item, list) or len(item) != 2:
215 | raise ValueError(f"Invalid state definition item: {item}")
216 | name, type_str = item
217 | state_type = _resolve_state_type(type_str)
218 | processed_lists.append((name, state_type))
219 | return processed_lists
220 |
221 |
222 | def process_state_definitions(state_defs: List, state_registry: dict):
223 | """
224 | Process state definitions from template format to internal format.
225 |
226 | Supports multiple input formats:
227 | - Dictionary format: {'name': 'type'} or {'name': str, 'type': str}
228 | - List format: [name, type_str] or [[name1, type_str1], ...]
229 | - String format: References to pre-defined states in state_registry
230 |
231 | Args:
232 | state_defs (List): List of state definitions in various formats
233 | state_registry (dict): Registry of pre-defined states
234 |
235 | Returns:
236 | List[tuple]: List of processed (name, type) pairs
237 |
238 | Raises:
239 | ValueError: If state definition format is invalid or state type is unknown
240 | """
241 | processed_state_defs = []
242 |
243 | for state_def in state_defs:
244 | if isinstance(state_def, dict):
245 | processed_state_defs.append(_process_dict_state_def(state_def))
246 | elif isinstance(state_def, list):
247 | processed_state_defs.extend(_process_list_state_def(state_def))
248 | elif isinstance(state_def, str):
249 | if state_def in state_registry:
250 | processed_state_defs.append(state_registry[state_def])
251 | else:
252 | raise ValueError(f"Unknown state type: {state_def}")
253 | else:
254 | raise ValueError(
255 | f"Invalid state definition format: {state_def}. Must be a string, "
256 | "[name, type] list, or {'name': 'type'} dictionary"
257 | )
258 |
259 | return processed_state_defs
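# Usage sketch (editorial), combining the accepted formats; "CustomState" is a
# hypothetical entry registered in the state_registry:
#
#     process_state_definitions(
#         [
#             {"name": "field1", "type": "str"},   # dict format
#             ["field2", "List[int]"],             # list format
#             "CustomState",                       # registry lookup
#         ],
#         {"CustomState": CustomState},
#     )
#     # -> [("field1", str), ("field2", List[int]), CustomState]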
260 |
261 |
262 | def _type_from_str(type_obj: type) -> str:
263 | """
264 | Convert a Python type object to a string representation that _resolve_state_type can parse.
265 | """
266 | # Add handling for numpy arrays
267 | if type_obj is np.ndarray:
268 | return "ndarray"
269 |
270 | # Handle basic types
271 | if type_obj in (str, int, float, bool, dict, list, bytes, tuple):
272 | return type_obj.__name__
273 |
274 | # Get the origin type
275 | origin = get_origin(type_obj)
276 | if origin is None:
277 | # More explicit handling of unknown types
278 | logger.warning(f"Unknown type {type_obj}, defaulting to None")
279 | return None
280 |
281 | # Handle List types
282 | if origin is list or origin is List:
283 | args = get_args(type_obj)
284 | if not args:
285 | return "list" # Default to list if no type args
286 | inner_type = _type_from_str(args[0])
287 | if inner_type is None:
288 | return "list"
289 | return f"List[{inner_type}]"
290 |
291 | # Handle Dict types
292 | if origin is dict or origin is Dict:
293 | args = get_args(type_obj)
294 | if not args or len(args) != 2:
295 | return "dict" # Default if no/invalid type args
296 | key_type = _type_from_str(args[0]) # Recursive call for key type
297 | value_type = _type_from_str(args[1]) # Recursive call for value type
298 | if key_type is None or value_type is None:
299 | return "dict"
300 | return f"Dict[{key_type}, {value_type}]"
301 |
302 | # Handle Union types
303 | if origin is Union:
304 | args = get_args(type_obj)
305 | if not args:
306 | return "tuple"
307 | types = [_type_from_str(arg) for arg in args]
308 | if any(t is None for t in types):
309 | return "tuple"
310 | return f"Union[{','.join(types)}]"
311 |
312 | # Default case
313 | return "str"
314 |
315 |
316 | class StateEncoder(json.JSONEncoder):
317 | """Custom JSON encoder for serializing workflow states."""
318 |
319 | def default(self, obj):
320 | try:
321 | if isinstance(obj, np.ndarray):
322 | return {
323 | "__type__": "ndarray",
324 | "data": obj.tolist(),
325 | "dtype": str(obj.dtype),
326 | }
327 | if isinstance(obj, go.Figure):
328 | return {
329 | "__type__": "plotly_figure",
330 | "data": pio.to_json(obj),
331 | }
332 | if hasattr(obj, "to_dict"):
333 | return obj.to_dict()
334 | if isinstance(obj, bytes):
335 | return obj.decode("utf-8")
336 | if isinstance(obj, set):
337 | return list(obj)
338 | if hasattr(obj, "__dict__"):
339 | return obj.__dict__
340 | return super().default(obj)
341 | except TypeError as e:
342 | logger.warning(f"Could not serialize object of type {type(obj)}: {str(e)}")
343 | return str(obj)
344 |
345 |
346 | class CustomSerializer(JsonPlusSerializer):
347 | NDARRAY_EXT_TYPE = 42 # Ensure this code doesn't conflict with other ExtTypes
348 | PLOTLY_FIGURE_EXT_TYPE = 43 # New extension type for Plotly figures
349 |
350 | def _default(self, obj: Any) -> Union[str, Dict[str, Any]]:
351 | if isinstance(obj, np.ndarray):
352 | return {
353 | "lc": 2,
354 | "type": "ndarray",
355 | "data": obj.tolist(),
356 | "dtype": str(obj.dtype),
357 | }
358 | if isinstance(obj, go.Figure):
359 | return {
360 | "lc": 2,
361 | "type": "plotly_figure",
362 | "data": pio.to_json(obj),
363 | }
364 | return super()._default(obj)
365 |
366 | def _reviver(self, value: Dict[str, Any]) -> Any:
367 | if value.get("lc", None) == 2:
368 | if value.get("type", None) == "ndarray":
369 | return np.array(value["data"], dtype=value["dtype"])
370 | elif value.get("type", None) == "plotly_figure":
371 | return pio.from_json(value["data"])
372 | return super()._reviver(value)
373 |
374 | # Override dumps_typed to use instance method _msgpack_enc
375 | def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
376 | if isinstance(obj, bytes):
377 | return "bytes", obj
378 | elif isinstance(obj, bytearray):
379 | return "bytearray", obj
380 | else:
381 | try:
382 | return "msgpack", self._msgpack_enc(obj)
383 | except UnicodeEncodeError:
384 | return "json", self.dumps(obj)
385 |
386 | # Provide instance-level _msgpack_enc
387 | def _msgpack_enc(self, data: Any) -> bytes:
388 | enc = msgpack.Packer(default=self._msgpack_default)
389 | return enc.pack(data)
390 |
391 | # Provide instance-level _msgpack_default
392 | def _msgpack_default(self, obj: Any) -> Any:
393 | if isinstance(obj, np.ndarray):
394 | # Prepare metadata for ndarray
395 | metadata = {
396 | "dtype": str(obj.dtype),
397 | "shape": obj.shape,
398 | }
399 | metadata_packed = msgpack.packb(metadata, use_bin_type=True)
400 | data_packed = obj.tobytes()
401 | combined = metadata_packed + data_packed
402 | return msgpack.ExtType(self.NDARRAY_EXT_TYPE, combined)
403 | elif isinstance(obj, np.number):
404 | # Handle NumPy scalar types
405 | return obj.item()
406 | elif isinstance(obj, go.Figure):
407 | figure_json = pio.to_json(obj)
408 | figure_packed = msgpack.packb(figure_json, use_bin_type=True)
409 | return msgpack.ExtType(self.PLOTLY_FIGURE_EXT_TYPE, figure_packed)
410 |
411 | return super()._msgpack_default(obj)
412 |
413 | # Provide instance-level loads_typed
414 | def loads_typed(self, data: tuple[str, bytes]) -> Any:
415 | type_, data_ = data
416 | if type_ == "bytes":
417 | return data_
418 | elif type_ == "bytearray":
419 | return bytearray(data_)
420 | elif type_ == "json":
421 | return self.loads(data_)
422 | elif type_ == "msgpack":
423 | return msgpack.unpackb(
424 | data_, ext_hook=self._msgpack_ext_hook, strict_map_key=False
425 | )
426 | else:
427 | raise NotImplementedError(f"Unknown serialization type: {type_}")
428 |
429 | # Provide instance-level _msgpack_ext_hook
430 | def _msgpack_ext_hook(self, code: int, data: bytes) -> Any:
431 | if code == self.NDARRAY_EXT_TYPE:
432 | # Unpack metadata
433 | unpacker = msgpack.Unpacker(use_list=False, raw=False)
434 | unpacker.feed(data)
435 | metadata = unpacker.unpack()
436 | buffer_offset = unpacker.tell()
437 | array_data = data[buffer_offset:]
438 | array = np.frombuffer(array_data, dtype=metadata["dtype"])
439 | array = array.reshape(metadata["shape"])
440 | return array
441 | elif code == self.PLOTLY_FIGURE_EXT_TYPE:
442 | figure_json = msgpack.unpackb(
443 | data, strict_map_key=False, ext_hook=self._msgpack_ext_hook
444 | )
445 | return pio.from_json(figure_json)
446 | else:
447 | return super()._msgpack_ext_hook(code, data)
448 |
449 |
450 | def convert_serialized_objects(obj):
451 | """
452 | Convert serialized objects back to their original form.
453 | Currently handles:
454 | - NumPy arrays (serialized as {"__type__": "ndarray", "data": [...], "dtype": "..."})
455 | - Plotly figures (serialized as {"__type__": "plotly_figure", "data": "..."})
456 |
457 | Args:
458 | obj: The object to convert, which may contain serialized objects
459 |
460 | Returns:
461 | The object with any serialized objects converted back to their original form
462 | """
463 | if isinstance(obj, dict):
464 | if "__type__" in obj:
465 | if obj["__type__"] == "ndarray":
466 | return np.array(obj["data"], dtype=obj["dtype"])
467 | elif obj["__type__"] == "plotly_figure":
468 | return pio.from_json(obj["data"])
469 | return {k: convert_serialized_objects(v) for k, v in obj.items()}
470 | elif isinstance(obj, list):
471 | return [convert_serialized_objects(item) for item in obj]
472 | return obj
473 |
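# A minimal usage sketch for convert_serialized_objects, assuming the
# StateEncoder-style payload described in the docstring (the key names here
# are illustrative only):
#
#   payload = {"arr": {"__type__": "ndarray", "data": [[1, 2]], "dtype": "int64"}}
#   restored = convert_serialized_objects(payload)
#   assert isinstance(restored["arr"], np.ndarray)  # back to a real array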
474 |
475 | if __name__ == "__main__":
476 | serializer = CustomSerializer()
477 | original_data = {
478 | "array": np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64),
479 | "scalar": np.float32(7.5),
480 | "message": "Test serialization",
481 | "nested": {
482 | "array": np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64),
483 | "scalar": np.float32(7.5),
484 | "list_of_arrays": [
485 | np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64),
486 | np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float64),
487 | ],
488 | },
489 | }
490 |
491 | # Create a more complex figure with multiple traces and customization
492 | fig = go.Figure()
493 |
494 | # Add a scatter plot with markers
495 | fig.add_trace(
496 | go.Scatter(
497 | x=[1, 2, 3, 4, 5],
498 | y=[4, 5.2, 6, 3.2, 8],
499 | mode="markers+lines",
500 | name="Series A",
501 | marker=dict(size=10, color="blue", symbol="circle"),
502 | )
503 | )
504 |
505 | # Add a bar chart
506 | fig.add_trace(
507 | go.Bar(
508 | x=[1, 2, 3, 4, 5], y=[2, 3, 1, 5, 3], name="Series B", marker_color="green"
509 | )
510 | )
511 |
512 | # Add a line plot with different style
513 | fig.add_trace(
514 | go.Scatter(
515 | x=[1, 2, 3, 4, 5],
516 | y=[7, 6, 9, 8, 7],
517 | mode="lines",
518 | name="Series C",
519 | line=dict(width=3, dash="dash", color="red"),
520 | )
521 | )
522 |
523 | # Update layout with title and axis labels
524 | fig.update_layout(
525 | title="Complex Test Figure",
526 | xaxis_title="X Axis",
527 | yaxis_title="Y Axis",
528 | legend_title="Legend",
529 | template="plotly_white",
530 | )
531 | original_data["figure"] = fig
532 |
533 | # Serialize the data
534 |     type_tag, serialized = serializer.dumps_typed(original_data)
535 |     # Deserialize the data using the type tag returned by dumps_typed
536 |     deserialized_data = serializer.loads_typed((type_tag, serialized))
537 |
538 | # Assertions
539 | assert isinstance(deserialized_data["array"], np.ndarray)
540 | assert np.array_equal(deserialized_data["array"], original_data["array"])
541 | assert isinstance(deserialized_data["scalar"], float)
542 | assert deserialized_data["scalar"] == float(original_data["scalar"])
543 | assert deserialized_data["message"] == original_data["message"]
544 | assert isinstance(deserialized_data["nested"]["array"], np.ndarray)
545 | assert np.array_equal(
546 | deserialized_data["nested"]["array"], original_data["nested"]["array"]
547 | )
548 | assert isinstance(deserialized_data["nested"]["scalar"], float)
549 | assert deserialized_data["nested"]["scalar"] == float(
550 | original_data["nested"]["scalar"]
551 | )
552 | assert isinstance(deserialized_data["nested"]["list_of_arrays"], list)
553 | assert all(
554 | isinstance(arr, np.ndarray)
555 | for arr in deserialized_data["nested"]["list_of_arrays"]
556 | )
557 | assert all(
558 | np.array_equal(arr, original_arr)
559 | for arr, original_arr in zip(
560 | deserialized_data["nested"]["list_of_arrays"],
561 | original_data["nested"]["list_of_arrays"],
562 | )
563 | )
564 |
565 | assert isinstance(deserialized_data["figure"], go.Figure)
566 | assert len(deserialized_data["figure"].data) == len(fig.data)
567 | for i, trace in enumerate(fig.data):
568 | assert deserialized_data["figure"].data[i].type == trace.type
569 | # Compare x and y data if they exist
570 | if hasattr(trace, "x") and trace.x is not None:
571 | assert np.array_equal(deserialized_data["figure"].data[i].x, trace.x)
572 | if hasattr(trace, "y") and trace.y is not None:
573 | assert np.array_equal(deserialized_data["figure"].data[i].y, trace.y)
574 |
575 | # Compare layout properties
576 | assert deserialized_data["figure"].layout.title.text == fig.layout.title.text
577 |
578 | print("Serialization and deserialization test passed.")
579 |
--------------------------------------------------------------------------------
/nodeology/interface.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
3 |
4 | Copyright 2024. UChicago Argonne, LLC. This software was produced
5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
7 | U.S. Department of Energy. The U.S. Government has rights to use,
8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
11 | modified to produce derivative works, such modified software should
12 | be clearly marked, so as not to confuse it with the version available
13 | from ANL.
14 |
15 | Additionally, redistribution and use in source and binary forms, with
16 | or without modification, are permitted provided that the following
17 | conditions are met:
18 |
19 | * Redistributions of source code must retain the above copyright
20 | notice, this list of conditions and the following disclaimer.
21 |
22 | * Redistributions in binary form must reproduce the above copyright
23 | notice, this list of conditions and the following disclaimer in
24 | the documentation and/or other materials provided with the
25 | distribution.
26 |
27 | * Neither the name of UChicago Argonne, LLC, Argonne National
28 | Laboratory, ANL, the U.S. Government, nor the names of its
29 | contributors may be used to endorse or promote products derived
30 | from this software without specific prior written permission.
31 |
32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
43 | POSSIBILITY OF SUCH DAMAGE.
44 | """
45 |
46 | ### Initial Author <2025>: Xiangyu Yin
47 |
48 | import os, json, importlib, threading, traceback, contextvars
49 | import logging
50 | import chainlit as cl
51 | from chainlit.cli import run_chainlit
52 | from nodeology.state import StateEncoder, convert_serialized_objects
53 |
54 |
55 | logger = logging.getLogger(__name__)
56 |
57 |
58 | def run_chainlit_for_workflow(workflow, initial_state=None):
59 | """
60 | Called by workflow.run(ui=True). This function:
61 |     1. stores the workflow class path, init kwargs, and initial_state in environment variables
62 |     2. starts the Chainlit server
63 | 3. returns the final state when the workflow completes
64 | """
65 | os.environ["NODEOLOGY_WORKFLOW_CLASS"] = (
66 | workflow.__class__.__module__ + "." + workflow.__class__.__name__
67 | )
68 |
69 | logger.info(f"Starting UI for workflow: {workflow.__class__.__name__}")
70 |
71 | # Save the initialization arguments to ensure they're passed when recreating the workflow
72 | if hasattr(workflow, "_init_kwargs"):
73 | logger.info(
74 | f"Found initialization kwargs: {list(workflow._init_kwargs.keys())}"
75 | )
76 |
77 | # We need to handle non-serializable objects in the kwargs
78 | serializable_kwargs = {}
79 | for key, value in workflow._init_kwargs.items():
80 | # Handle special cases
81 | if key == "state_defs":
82 | # For state_defs, we need to handle it specially
83 | if value is None:
84 | # If None, we'll use the workflow's state_schema
85 | serializable_kwargs[key] = None
86 | logger.info(
87 | f"Serializing {key} as None (will use workflow's state_schema)"
88 | )
89 | elif isinstance(value, type) and hasattr(value, "__annotations__"):
90 | # If it's a TypedDict or similar class with annotations, we'll use its name
91 | # The workflow class will handle recreating it
92 | serializable_kwargs["_state_defs_class"] = (
93 | f"{value.__module__}.{value.__name__}"
94 | )
95 | logger.info(
96 | f"Serializing state_defs as class reference: {serializable_kwargs['_state_defs_class']}"
97 | )
98 | elif isinstance(value, list):
99 | # If it's a list of state definitions, we'll try to serialize it
100 | # This is complex and might not work for all cases
101 | try:
102 | # Convert any TypedDict classes to their module.name string
103 | serialized_list = []
104 | for item in value:
105 | if isinstance(item, type) and hasattr(
106 | item, "__annotations__"
107 | ):
108 | serialized_list.append(
109 | f"{item.__module__}.{item.__name__}"
110 | )
111 | elif isinstance(item, tuple) and len(item) == 2:
112 | # Handle (name, type) tuples
113 | name, type_hint = item
114 | if isinstance(type_hint, type):
115 | serialized_list.append(
116 | [
117 | name,
118 | f"{type_hint.__module__}.{type_hint.__name__}",
119 | ]
120 | )
121 | else:
122 | serialized_list.append([name, str(type_hint)])
123 | elif isinstance(item, dict) and len(item) == 1:
124 | # Handle {"name": type} dictionaries
125 | name, type_hint = next(iter(item.items()))
126 | if isinstance(type_hint, type):
127 | serialized_list.append(
128 | {
129 | name: f"{type_hint.__module__}.{type_hint.__name__}"
130 | }
131 | )
132 | else:
133 | serialized_list.append({name: str(type_hint)})
134 | else:
135 | # Skip items we can't serialize
136 | logger.info(
137 | f"Skipping non-serializable state_def item: {item}"
138 | )
139 |
140 | if serialized_list:
141 | serializable_kwargs["_state_defs_list"] = serialized_list
142 | logger.info(
143 | f"Serializing state_defs as list: {serialized_list}"
144 | )
145 | else:
146 | logger.info("Could not serialize any state_defs items")
147 | except Exception as e:
148 | logger.error(f"Error serializing state_defs list: {str(e)}")
149 | else:
150 | logger.info(
151 | f"Cannot serialize state_defs of type {type(value).__name__}"
152 | )
153 | elif key == "checkpointer":
154 |                 # For checkpointer: keep string values as-is, otherwise fall back to "memory"
155 | if isinstance(value, str):
156 | serializable_kwargs[key] = value
157 | else:
158 | serializable_kwargs[key] = "memory"
159 | logger.info(f"Serializing checkpointer as: {serializable_kwargs[key]}")
160 | elif isinstance(value, (str, int, float, bool, type(None))):
161 | serializable_kwargs[key] = value
162 | logger.info(
163 | f"Serializing {key} as primitive type: {type(value).__name__}"
164 | )
165 | elif isinstance(value, list) and all(
166 | isinstance(item, (str, int, float, bool, type(None))) for item in value
167 | ):
168 | serializable_kwargs[key] = value
169 | logger.info(f"Serializing {key} as list of primitives")
170 | elif isinstance(value, dict) and all(
171 | isinstance(k, str)
172 | and isinstance(v, (str, int, float, bool, type(None)))
173 | for k, v in value.items()
174 | ):
175 | serializable_kwargs[key] = value
176 | logger.info(f"Serializing {key} as dict of primitives")
177 | else:
178 | logger.info(
179 | f"Skipping non-serializable {key} of type {type(value).__name__}"
180 | )
181 | # Skip other complex objects that can't be easily serialized
182 |
183 | # For client objects, just store their names
184 | if (
185 | "llm_name" in workflow._init_kwargs
186 | and hasattr(workflow, "llm_client")
187 | and hasattr(workflow.llm_client, "model_name")
188 | ):
189 | serializable_kwargs["llm_name"] = workflow.llm_client.model_name
190 | logger.info(
191 | f"Using llm_client.model_name: {workflow.llm_client.model_name}"
192 | )
193 |
194 | if (
195 | "vlm_name" in workflow._init_kwargs
196 | and hasattr(workflow, "vlm_client")
197 | and hasattr(workflow.vlm_client, "model_name")
198 | ):
199 | serializable_kwargs["vlm_name"] = workflow.vlm_client.model_name
200 | logger.info(
201 | f"Using vlm_client.model_name: {workflow.vlm_client.model_name}"
202 | )
203 |
204 | # Store the workflow's state_schema class name if available
205 | if hasattr(workflow, "state_schema") and hasattr(
206 | workflow.state_schema, "__name__"
207 | ):
208 | serializable_kwargs["_state_schema_class"] = (
209 | f"{workflow.state_schema.__module__}.{workflow.state_schema.__name__}"
210 | )
211 | logger.info(
212 | f"Storing state_schema class: {serializable_kwargs['_state_schema_class']}"
213 | )
214 |
215 | os.environ["NODEOLOGY_WORKFLOW_ARGS"] = json.dumps(
216 | serializable_kwargs, cls=StateEncoder
217 | )
218 | logger.info(f"Serialized kwargs: {list(serializable_kwargs.keys())}")
219 | else:
220 | logger.info("No initialization kwargs found on workflow")
221 |
222 | # Serialize any initial state if needed
223 | if initial_state:
224 |         # Use StateEncoder to handle NumPy arrays and other complex objects
225 | os.environ["NODEOLOGY_INITIAL_STATE"] = json.dumps(
226 | initial_state, cls=StateEncoder
227 | )
228 | logger.info("Serialized initial state")
229 |
230 | # Create a shared variable to store the final state
231 | os.environ["NODEOLOGY_FINAL_STATE"] = "{}"
232 |
233 |     # This file is nodeology/interface.py; Chainlit runs it as the app target
234 |     this_file = os.path.abspath(__file__)
235 |     # Launch the Chainlit server (blocks until the server shuts down)
236 |     logger.info("Starting Chainlit server")
237 | run_chainlit(target=this_file)
238 |
239 | # Return the final state from the last session
240 | final_state = {}
241 | if (
242 | "NODEOLOGY_FINAL_STATE" in os.environ
243 | and os.environ["NODEOLOGY_FINAL_STATE"] != "{}"
244 | ):
245 | try:
246 | final_state_json = os.environ["NODEOLOGY_FINAL_STATE"]
247 | final_state_dict = json.loads(final_state_json)
248 | logger.info(
249 | f"Retrieved final state with keys: {list(final_state_dict.keys())}"
250 | )
251 |
252 |             # Convert any serialized objects (NumPy arrays, Plotly figures) back
253 | final_state = convert_serialized_objects(final_state_dict)
254 | logger.info("Converted any serialized objects in final state")
255 |
256 | except Exception as e:
257 | logger.error(f"Error parsing final state: {str(e)}")
258 |
259 | return final_state
260 |
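# A minimal calling sketch, assuming a Workflow subclass whose run(ui=True)
# delegates here (MyWorkflow and its kwargs are hypothetical):
#
#   workflow = MyWorkflow(llm_name="gpt-4o")
#   final_state = run_chainlit_for_workflow(workflow, initial_state={"input": "hi"})
#   # Blocks until the Chainlit server exits, then returns the final state dict.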
261 |
262 | @cl.on_chat_start
263 | async def on_chat_start():
264 | """
265 |     Called when a new user session starts in the Chainlit UI.
266 |     A fresh workflow instance is created for each session.
267 | """
268 | try:
269 | # Get the workflow class from environment variable
270 | workflow_class_path = os.environ.get("NODEOLOGY_WORKFLOW_CLASS")
271 | if not workflow_class_path:
272 | await cl.Message(content="No workflow class specified.").send()
273 | return
274 |
275 | logger.info(f"Creating workflow from class: {workflow_class_path}")
276 |
277 | # Import the workflow class dynamically
278 | module_path, class_name = workflow_class_path.rsplit(".", 1)
279 | module = importlib.import_module(module_path)
280 | WorkflowClass = getattr(module, class_name)
281 | logger.info(f"Successfully imported workflow class: {class_name}")
282 |
283 | # Get the saved initialization arguments
284 | workflow_args = {}
285 | state_defs_processed = False
286 |
287 | if "NODEOLOGY_WORKFLOW_ARGS" in os.environ:
288 | try:
289 | serialized_args = json.loads(os.environ["NODEOLOGY_WORKFLOW_ARGS"])
290 | logger.info(f"Loaded serialized args: {list(serialized_args.keys())}")
291 |
292 | # Handle special parameters
293 |
294 | # 1. Handle state_defs
295 | if "_state_defs_class" in serialized_args:
296 | # We have a class reference for state_defs
297 | state_defs_class_path = serialized_args.pop("_state_defs_class")
298 | try:
299 | module_path, class_name = state_defs_class_path.rsplit(".", 1)
300 | module = importlib.import_module(module_path)
301 | state_defs_class = getattr(module, class_name)
302 | workflow_args["state_defs"] = state_defs_class
303 | logger.info(
304 | f"Imported state_defs class: {state_defs_class_path}"
305 | )
306 | state_defs_processed = True
307 | except Exception as e:
308 | logger.error(f"Error importing state_defs class: {str(e)}")
309 | elif "_state_defs_list" in serialized_args:
310 | # We have a list of state definitions
311 | state_defs_list = serialized_args.pop("_state_defs_list")
312 | try:
313 | # Process each item in the list
314 | processed_list = []
315 | for item in state_defs_list:
316 | if isinstance(item, str):
317 | # It's a class reference
318 | try:
319 | module_path, class_name = item.rsplit(".", 1)
320 | module = importlib.import_module(module_path)
321 | class_obj = getattr(module, class_name)
322 | processed_list.append(class_obj)
323 | except Exception as e:
324 | logger.error(
325 | f"Error importing state def class {item}: {str(e)}"
326 | )
327 | elif isinstance(item, list) and len(item) == 2:
328 | # It's a [name, type] tuple
329 | name, type_str = item
330 | if "." in type_str:
331 | # It's a class reference
332 | try:
333 | module_path, class_name = type_str.rsplit(
334 | ".", 1
335 | )
336 | module = importlib.import_module(module_path)
337 | type_obj = getattr(module, class_name)
338 | processed_list.append((name, type_obj))
339 | except Exception as e:
340 | logger.error(
341 | f"Error importing type {type_str}: {str(e)}"
342 | )
343 | # Fall back to string representation
344 | processed_list.append((name, type_str))
345 | else:
346 | # It's a primitive type string
347 | processed_list.append((name, type_str))
348 | elif isinstance(item, dict) and len(item) == 1:
349 | # It's a {name: type} dict
350 | name, type_str = next(iter(item.items()))
351 | if "." in type_str:
352 | # It's a class reference
353 | try:
354 | module_path, class_name = type_str.rsplit(
355 | ".", 1
356 | )
357 | module = importlib.import_module(module_path)
358 | type_obj = getattr(module, class_name)
359 | processed_list.append({name: type_obj})
360 | except Exception as e:
361 | logger.error(
362 | f"Error importing type {type_str}: {str(e)}"
363 | )
364 | # Fall back to string representation
365 | processed_list.append({name: type_str})
366 | else:
367 | # It's a primitive type string
368 | processed_list.append({name: type_str})
369 |
370 | if processed_list:
371 | workflow_args["state_defs"] = processed_list
372 | logger.info(
373 | f"Processed state_defs list with {len(processed_list)} items"
374 | )
375 | state_defs_processed = True
376 | else:
377 | logger.info("No state_defs items could be processed")
378 | except Exception as e:
379 | logger.error(f"Error processing state_defs list: {str(e)}")
380 | elif (
381 | "state_defs" in serialized_args
382 | and serialized_args["state_defs"] is None
383 | ):
384 | # Explicit None value
385 | workflow_args["state_defs"] = None
386 | serialized_args.pop("state_defs")
387 | logger.info("Using None for state_defs")
388 | state_defs_processed = True
389 |
390 | # 2. Handle state_schema if needed
391 | if "_state_schema_class" in serialized_args:
392 | # We have a class reference for state_schema
393 | state_schema_class_path = serialized_args.pop("_state_schema_class")
394 | logger.info(
395 | f"Found state_schema class: {state_schema_class_path} (will be handled by workflow)"
396 | )
397 |
398 | # If we couldn't process state_defs, try to use the state_schema class as a fallback
399 | if not state_defs_processed:
400 | try:
401 | module_path, class_name = state_schema_class_path.rsplit(
402 | ".", 1
403 | )
404 | module = importlib.import_module(module_path)
405 | state_schema_class = getattr(module, class_name)
406 | workflow_args["state_defs"] = state_schema_class
407 | logger.info(
408 | f"Using state_schema class as fallback for state_defs: {state_schema_class_path}"
409 | )
410 | state_defs_processed = True
411 | except Exception as e:
412 | logger.error(
413 | f"Error importing state_schema class as fallback: {str(e)}"
414 | )
415 |
416 | # Add remaining arguments
417 | for key, value in serialized_args.items():
418 | workflow_args[key] = value
419 |
420 |                 # Convert any serialized objects (NumPy arrays, Plotly figures) back
421 | workflow_args = convert_serialized_objects(workflow_args)
422 | logger.info(f"Final workflow args: {list(workflow_args.keys())}")
423 | except Exception as e:
424 | logger.error(f"Error parsing workflow arguments: {str(e)}")
425 | traceback.print_exc()
426 | # Continue with empty args if there's an error
427 |
428 | # If we couldn't process state_defs, check if the workflow class has a state_schema attribute
429 | if not state_defs_processed and hasattr(WorkflowClass, "state_schema"):
430 |             logger.info("Using workflow class's state_schema attribute as fallback")
431 | # We don't need to set state_defs explicitly, the workflow will use its state_schema
432 |
433 | # Create a new instance of the workflow with the saved arguments
434 | logger.info(
435 | f"Creating workflow instance with args: {list(workflow_args.keys())}"
436 | )
437 | workflow = WorkflowClass(**workflow_args)
438 | logger.info(f"Successfully created workflow instance: {workflow.name}")
439 |
440 | # Check if VLM client is available
441 | if hasattr(workflow, "vlm_client") and workflow.vlm_client is not None:
442 |             logger.info("VLM client is available")
443 | else:
444 | logger.info("VLM client is not available")
445 |
446 | # Get initial state if available
447 | initial_state = None
448 | initial_state_json = os.environ.get("NODEOLOGY_INITIAL_STATE")
449 | if initial_state_json:
450 | logger.info("Found initial state in environment")
451 |         # Parse the JSON; serialized objects are converted back just below
452 | initial_state_dict = json.loads(initial_state_json)
453 | logger.info(
454 | f"Loaded initial state with keys: {list(initial_state_dict.keys())}"
455 | )
456 |
457 |         # Convert any serialized objects (NumPy arrays, Plotly figures) back
458 | initial_state = convert_serialized_objects(initial_state_dict)
459 | logger.info("Converted any serialized objects in initial state")
460 |
461 | # Initialize the workflow
462 | if initial_state:
463 | logger.info(
464 | f"Initializing workflow with initial state: {list(initial_state.keys())}"
465 | )
466 | workflow.initialize(initial_state)
467 | else:
468 | logger.info("Initializing workflow with default state")
469 | workflow.initialize()
470 | logger.info("Workflow initialized successfully")
471 |
472 | # Store in user session
473 | cl.user_session.set("workflow", workflow)
474 | logger.info("Stored workflow in user session")
475 |
476 | # Capture the current Chainlit context
477 | parent_ctx = contextvars.copy_context()
478 | logger.info("Captured Chainlit context")
479 |
480 | # Create a function to save the final state when workflow completes
481 | def save_final_state(state):
482 | try:
483 | # Use StateEncoder to handle NumPy arrays and other complex objects
484 | serialized_state = json.dumps(state, cls=StateEncoder)
485 | os.environ["NODEOLOGY_FINAL_STATE"] = serialized_state
486 | logger.info(f"Saved final state with keys: {list(state.keys())}")
487 | except Exception as e:
488 | logger.error(f"Error saving final state: {str(e)}")
489 | traceback.print_exc()
490 |
491 | # Start the workflow in a background thread from within the Chainlit context
492 | def run_workflow_with_polling():
493 | try:
494 | logger.info("Starting workflow execution in background thread")
495 | # Run the workflow inside the captured context
496 | final_state = parent_ctx.run(
497 | lambda: workflow._run(initial_state, ui=True)
498 | )
499 | logger.info("Workflow execution completed")
500 |
501 | # Save the final state
502 | if final_state:
503 | save_final_state(final_state)
504 | logger.info("Final state saved")
505 | except Exception as e:
506 | logger.error(f"Error in workflow execution: {str(e)}")
507 | traceback.print_exc()
508 |
509 | # Start the workflow thread
510 | workflow_thread = threading.Thread(
511 | target=run_workflow_with_polling, daemon=True
512 | )
513 | cl.user_session.set("workflow_thread", workflow_thread)
514 | logger.info("Created workflow thread")
515 | workflow_thread.start()
516 | logger.info("Started workflow thread")
517 |
518 | await cl.Message(
519 |             content=f"Welcome to {workflow.__class__.__name__}, powered by Nodeology!"
520 | ).send()
521 | logger.info("Sent welcome message")
522 | except Exception as e:
523 | logger.error(f"Error in on_chat_start: {str(e)}")
524 | traceback.print_exc()
525 | await cl.Message(content=f"Error initializing workflow: {str(e)}").send()
526 |
--------------------------------------------------------------------------------
/tests/test_node.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
3 |
4 | Copyright 2024. UChicago Argonne, LLC. This software was produced
5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
7 | U.S. Department of Energy. The U.S. Government has rights to use,
8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
11 | modified to produce derivative works, such modified software should
12 | be clearly marked, so as not to confuse it with the version available
13 | from ANL.
14 |
15 | Additionally, redistribution and use in source and binary forms, with
16 | or without modification, are permitted provided that the following
17 | conditions are met:
18 |
19 | * Redistributions of source code must retain the above copyright
20 | notice, this list of conditions and the following disclaimer.
21 |
22 | * Redistributions in binary form must reproduce the above copyright
23 | notice, this list of conditions and the following disclaimer in
24 | the documentation and/or other materials provided with the
25 | distribution.
26 |
27 | * Neither the name of UChicago Argonne, LLC, Argonne National
28 | Laboratory, ANL, the U.S. Government, nor the names of its
29 | contributors may be used to endorse or promote products derived
30 | from this software without specific prior written permission.
31 |
32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
43 | POSSIBILITY OF SUCH DAMAGE.
44 | """
45 |
46 | ### Initial Author <2024>: Xiangyu Yin
47 |
48 | import os, logging
49 | import pytest
50 | from typing import Any, Dict, List
51 | import numpy as np
52 | from nodeology.client import VLM_Client
53 | from nodeology.node import (
54 | Node,
55 | as_node,
56 | remove_markdown_blocks_formatting,
57 | )
58 | from nodeology.state import State
59 | from nodeology.log import add_logging_level
60 |
61 | add_logging_level("PRINTLOG", logging.INFO + 5)
62 | add_logging_level("LOGONLY", logging.INFO + 1)
63 |
64 |
65 | # Basic Node Tests
66 | class TestBasicNodeFunctionality:
67 | def test_node_creation(self):
68 | """Test basic node creation and configuration"""
69 | node = Node(
70 | node_type="test_node",
71 | prompt_template="Test prompt with {key1} and {key2}",
72 | sink=["output"],
73 | )
74 |
75 | assert node.node_type == "test_node"
76 | assert "key1" in node.required_keys
77 | assert "key2" in node.required_keys
78 | assert node.sink == ["output"]
79 |
80 | def test_node_error_handling(self):
81 | """Test node error handling for missing required keys"""
82 | node = Node(
83 | node_type="test_node", prompt_template="Test {required_key}", sink="output"
84 | )
85 |
86 | state = State()
87 | state["messages"] = []
88 |
89 | with pytest.raises(ValueError, match="Required key 'required_key' not found"):
90 | node(state, None)
91 |
92 | def test_empty_sink_list(self):
93 | """Test node behavior with empty sink list"""
94 | node = Node(node_type="test_node", prompt_template="Test", sink=[])
95 |
96 | state = State()
97 | state["messages"] = []
98 |
99 | result_state = node(state, lambda **k: "response")
100 | assert result_state == state # State should remain unchanged
101 |
102 | def test_state_type_tracking_chain(self):
103 | """Test node type tracking through multiple nodes"""
104 | node1 = Node(node_type="node1", prompt_template="Test", sink="output1")
105 | node2 = Node(node_type="node2", prompt_template="Test", sink="output2")
106 |
107 | state = State()
108 | state["messages"] = []
109 |
110 | state = node1(state, lambda **k: "response1")
111 | assert state["current_node_type"] == "node1"
112 | assert state["previous_node_type"] == ""
113 |
114 | state = node2(state, lambda **k: "response2")
115 | assert state["current_node_type"] == "node2"
116 | assert state["previous_node_type"] == "node1"
117 |
118 | def test_state_preservation(self):
119 | """Test that unrelated state data is preserved"""
120 | node = Node(node_type="test_node", prompt_template="Test", sink="output")
121 |
122 | state = State()
123 | state["messages"] = []
124 | state["preserved_key"] = "preserved_value"
125 | state["input"] = "test_input"
126 |
127 | result_state = node(state, lambda **k: "response")
128 | assert result_state["preserved_key"] == "preserved_value"
129 | assert result_state["input"] == "test_input"
130 |
131 | def test_state_immutability(self):
132 | """Test that original state is not modified"""
133 | node = Node(node_type="test_node", prompt_template="Test", sink="output")
134 |
135 | original_state = State()
136 | original_state["messages"] = []
137 | original_state["value"] = "original"
138 |
139 | state_copy = original_state.copy()
140 | result_state = node(state_copy, lambda **k: "response")
141 |
142 | assert original_state["value"] == "original"
143 | assert original_state != result_state
144 |
145 |
146 | # Execution Tests
147 | class TestNodeExecution:
148 | class MockClient:
149 | def __call__(self, messages, **kwargs):
150 | return "Test response"
151 |
152 | @pytest.fixture
153 | def basic_state(self):
154 | state = State()
155 | state["input"] = "value"
156 | state["messages"] = []
157 | return state
158 |
159 | def test_basic_execution(self, basic_state):
160 | """Test basic node execution"""
161 | node = Node(
162 | node_type="test_node", prompt_template="Test {input}", sink="output"
163 | )
164 |
165 | result_state = node(basic_state, self.MockClient())
166 | assert result_state["output"] == "Test response"
167 |
168 | def test_multiple_sinks(self, basic_state):
169 | """Test node with multiple output sinks"""
170 |
171 | class MultiResponseMock:
172 | def __call__(self, messages, **kwargs):
173 | return ["Response 1", "Response 2"]
174 |
175 | node = Node(
176 | node_type="test_node",
177 | prompt_template="Test {input}",
178 | sink=["output1", "output2"],
179 | )
180 |
181 | result_state = node(basic_state, MultiResponseMock())
182 | assert result_state["output1"] == "Response 1"
183 | assert result_state["output2"] == "Response 2"
184 |
185 | def test_sink_format(self, basic_state):
186 | """Test node with specific sink format"""
187 |
188 | class FormatMock:
189 | def __call__(self, messages, **kwargs):
190 | assert kwargs.get("format") == "json"
191 | return '{"key": "value"}'
192 |
193 | node = Node(
194 | node_type="test_node",
195 | prompt_template="Test {input}",
196 | sink="output",
197 | sink_format="json",
198 | )
199 |
200 | result_state = node(basic_state, FormatMock())
201 | assert result_state["output"] == '{"key": "value"}'
202 |
203 | def test_invalid_sink_format(self):
204 | """Test handling of invalid sink format specification"""
205 | node = Node(
206 | node_type="test_node",
207 | prompt_template="Test",
208 | sink="output",
209 | sink_format="invalid_format",
210 | )
211 |
212 | state = State()
213 | state["messages"] = []
214 |
215 | # The client should handle invalid format
216 | class FormatTestClient:
217 | def __call__(self, messages, **kwargs):
218 | assert kwargs.get("format") == "invalid_format"
219 | return "response"
220 |
221 | result_state = node(state, FormatTestClient())
222 | assert result_state["output"] == "response"
223 |
224 | def test_mismatched_sink_response(self):
225 | """Test handling of mismatched sink and response count"""
226 | node = Node(
227 | node_type="test_node", prompt_template="Test", sink=["output1", "output2"]
228 | )
229 |
230 | state = State()
231 | state["messages"] = []
232 |
233 | with pytest.raises(
234 | ValueError, match="Number of responses .* doesn't match number of sink"
235 | ):
236 | node(state, lambda **k: ["single_response"])
237 |
238 |
239 | # Expression Evaluation Tests
240 | class TestExpressionEvaluation:
241 | @pytest.fixture
242 | def sample_state(self):
243 | class SampleState(State):
244 | items: List[str]
245 | numbers: np.ndarray
246 | text: str
247 | data: Dict[str, Any]
248 | key: str
249 |
250 | state = SampleState()
251 | state["messages"] = []
252 | state["items"] = ["a", "b", "c", "d"]
253 | state["numbers"] = np.array([1, 2, 3, 4, 5])
254 | state["text"] = "Hello World"
255 | state["data"] = {"name": "John", "age": 30}
256 | state["key"] = "name"
257 | return state
258 |
259 | def test_basic_indexing(self, sample_state):
260 | """Test basic list indexing"""
261 | node = Node(
262 | node_type="test_node",
263 | prompt_template="First: {items[0]}, Last: {items[-1]}",
264 | sink="output",
265 | )
266 |
267 | result_state = node(sample_state, lambda **k: "response")
268 | assert "First: a, Last: d" in result_state["messages"][-1]["content"]
269 |
270 | def test_list_slicing(self, sample_state):
271 | """Test list slicing operations"""
272 | node = Node(
273 | node_type="test_node",
274 | prompt_template="Slice: {items[1:3]}, Reverse: {items[::-1]}",
275 | sink="output",
276 | )
277 |
278 | result_state = node(sample_state, lambda **k: "response")
279 | assert (
280 | "Slice: ['b', 'c'], Reverse: ['d', 'c', 'b', 'a']"
281 | in result_state["messages"][-1]["content"]
282 | )
283 |
284 | def test_string_methods(self, sample_state):
285 | """Test string method calls"""
286 | node = Node(
287 | node_type="test_node",
288 | prompt_template="""
289 | Upper: {text.upper()}
290 | Lower: {text.lower()}
291 | Title: {text.title()}
292 | Strip: {text.strip()}
293 | """,
294 | sink="output",
295 | )
296 |
297 | result_state = node(sample_state, lambda **k: "response")
298 | message = result_state["messages"][-1]["content"]
299 | assert "Upper: HELLO WORLD" in message
300 | assert "Lower: hello world" in message
301 | assert "Title: Hello World" in message
302 |
303 | def test_built_in_functions(self, sample_state):
304 | """Test allowed built-in function calls"""
305 | node = Node(
306 | node_type="test_node",
307 | prompt_template="""
308 | Length: {len(items)}
309 | Maximum: {max(numbers)}
310 | Minimum: {min(numbers)}
311 | Sum: {sum(numbers)}
312 | Absolute: {abs(-42)}
313 | """,
314 | sink="output",
315 | )
316 |
317 | result_state = node(sample_state, lambda **k: "response")
318 | message = result_state["messages"][-1]["content"]
319 | assert "Length: 4" in message
320 | assert "Maximum: 5" in message
321 | assert "Minimum: 1" in message
322 | assert "Sum: 15" in message
323 | assert "Absolute: 42" in message
324 |
325 | def test_dict_access(self, sample_state):
326 | """Test dictionary access methods"""
327 | node = Node(
328 | node_type="test_node",
329 | prompt_template="""
330 | Direct key: {data['name']}
331 | Variable key: {data[key]}
332 | """,
333 | sink="output",
334 | )
335 |
336 | result_state = node(sample_state, lambda **k: "response")
337 | message = result_state["messages"][-1]["content"]
338 | assert "Direct key: John" in message
339 | assert "Variable key: John" in message
340 |
341 | def test_type_conversions(self, sample_state):
342 | """Test type conversion functions"""
343 | sample_state["value"] = "42"
344 | node = Node(
345 | node_type="test_node",
346 | prompt_template="""
347 | Integer: {int(value)}
348 | Float: {float(value)}
349 | String: {str(numbers[0])}
350 | """,
351 | sink="output",
352 | )
353 |
354 | result_state = node(sample_state, lambda **k: "response")
355 | message = result_state["messages"][-1]["content"]
356 | assert "Integer: 42" in message
357 | assert "Float: 42.0" in message
358 | assert "String: 1" in message
359 |
360 | def test_invalid_expressions(self, sample_state):
361 | """Test error handling for invalid expressions"""
362 | # Test invalid function
363 | with pytest.raises(ValueError, match="Function not allowed: print"):
364 | node = Node(
365 | node_type="test_node", prompt_template="{print(text)}", sink="output"
366 | )
367 | node(sample_state, lambda **k: "response")
368 |
369 | # Test invalid method
370 | with pytest.raises(ValueError, match="String method not allowed: split"):
371 | node = Node(
372 | node_type="test_node", prompt_template="{text.split()}", sink="output"
373 | )
374 | node(sample_state, lambda **k: "response")
375 |
376 | # Test invalid index
377 | with pytest.raises(ValueError):
378 | node = Node(
379 | node_type="test_node", prompt_template="{items[10]}", sink="output"
380 | )
381 | node(sample_state, lambda **k: "response")
382 |
383 | # Test invalid key
384 | with pytest.raises(ValueError):
385 | node = Node(
386 | node_type="test_node",
387 | prompt_template="{data['invalid_key']}",
388 | sink="output",
389 | )
390 | node(sample_state, lambda **k: "response")
391 |
392 |
393 | # Pre/Post Processing Tests
394 | class TestPrePostProcessing:
395 | @pytest.fixture
396 | def processed_list(self):
397 | return []
398 |
399 | def test_pre_post_processing(self, processed_list):
400 | """Test node with pre and post processing functions"""
401 |
402 | def pre_process(state, client, **kwargs):
403 | processed_list.append("pre")
404 | return state
405 |
406 | def post_process(state, client, **kwargs):
407 | processed_list.append("post")
408 | return state
409 |
410 | node = Node(
411 | node_type="test_node",
412 | prompt_template="Test {input}",
413 | sink="output",
414 | pre_process=pre_process,
415 | post_process=post_process,
416 | )
417 |
418 | state = State()
419 | state["input"] = "value"
420 | state["messages"] = []
421 |
422 | node(state, lambda **k: "response")
423 | assert processed_list == ["pre", "post"]
424 |
425 | def test_none_pre_post_process(self):
426 | """Test node behavior when pre/post process returns None"""
427 |
428 | def pre_process(state, client, **kwargs):
429 | return None
430 |
431 | def post_process(state, client, **kwargs):
432 | return None
433 |
434 | node = Node(
435 | node_type="test_node",
436 | prompt_template="Test {input}",
437 | sink="output",
438 | pre_process=pre_process,
439 | post_process=post_process,
440 | )
441 |
442 | state = State()
443 | state["input"] = "value"
444 | state["messages"] = []
445 |
446 | result_state = node(state, lambda **k: "response")
447 | assert result_state == state
448 |
449 |
450 | # Source Mapping Tests
451 | class TestSourceMapping:
452 | @pytest.fixture
453 | def state_with_mapping(self):
454 | state = State()
455 | state["different_key"] = "mapped value"
456 | state["input_key"] = "value"
457 | state["messages"] = []
458 | return state
459 |
460 | def test_dict_source_mapping(self, state_with_mapping):
461 | """Test node with source key mapping"""
462 | node = Node(
463 | node_type="test_node", prompt_template="Test {value}", sink="output"
464 | )
465 |
466 | result_state = node(
467 | state_with_mapping,
468 | lambda **k: "Test response",
469 | source={"value": "different_key"},
470 | )
471 | assert result_state["output"] == "Test response"
472 |
473 | def test_string_source_mapping(self, state_with_mapping):
474 | """Test node with string source mapping"""
475 | node = Node(
476 | node_type="test_node", prompt_template="Test {source}", sink="output"
477 | )
478 |
479 | result_state = node(
480 | state_with_mapping, lambda **k: "response", source="input_key"
481 | )
482 | assert result_state["output"] == "response"
483 |
484 | def test_invalid_source_mapping(self):
485 | """Test handling of invalid source mapping"""
486 | node = Node(
487 | node_type="test_node", prompt_template="Test {value}", sink="output"
488 | )
489 |
490 | state = State()
491 | state["messages"] = []
492 |
493 | with pytest.raises(ValueError):
494 | node(state, None, source={"value": "nonexistent_key"})
495 |
496 |
497 | # VLM Integration Tests
498 | class TestVLMIntegration:
499 | class MockVLMClient(VLM_Client):
500 | def __init__(self):
501 | super().__init__()
502 |
503 | def process_images(self, messages, images, **kwargs):
504 | # Verify images are valid paths (for testing)
505 | assert all(isinstance(img, str) for img in images)
506 | return messages
507 |
508 | def __call__(self, messages, images=None, **kwargs):
509 | if images is not None:
510 | messages = self.process_images(messages, images)
511 | return "Image description response"
512 |
513 | def test_vlm_execution(self):
514 | """Test node execution with VLM client"""
515 | node = Node(
516 | node_type="test_vlm_node",
517 | prompt_template="Describe this image",
518 | sink="output",
519 | image_keys=["image_path"],
520 | )
521 |
522 | state = State()
523 | state["messages"] = []
524 |
525 | # Create a temporary test image file
526 | test_image_path = "test_image.jpg"
527 | with open(test_image_path, "w") as f:
528 | f.write("dummy image content")
529 |
530 | try:
531 | result_state = node(state, self.MockVLMClient(), image_path=test_image_path)
532 | assert result_state["output"] == "Image description response"
533 | finally:
534 | # Clean up the test image
535 | if os.path.exists(test_image_path):
536 | os.remove(test_image_path)
537 |
538 | def test_vlm_multiple_images(self):
539 | """Test VLM node with multiple image inputs"""
540 | node = Node(
541 | node_type="test_vlm_node",
542 | prompt_template="Describe these images",
543 | sink="output",
544 | image_keys=["image1", "image2"],
545 | )
546 |
547 | state = State()
548 | state["messages"] = []
549 |
550 | # Create temporary test image files
551 | test_images = ["test1.jpg", "test2.jpg"]
552 | for img_path in test_images:
553 | with open(img_path, "w") as f:
554 | f.write(f"dummy image content for {img_path}")
555 |
556 | try:
557 | result_state = node(
558 | state,
559 | self.MockVLMClient(),
560 | image1=test_images[0],
561 | image2=test_images[1],
562 | )
563 | assert result_state["output"] == "Image description response"
564 | finally:
565 | # Clean up test images
566 | for img_path in test_images:
567 | if os.path.exists(img_path):
568 | os.remove(img_path)
569 |
570 | def test_vlm_missing_image(self):
571 | """Test VLM node execution without required image"""
572 | node = Node(
573 | node_type="test_vlm_node",
574 | prompt_template="Describe this image",
575 | sink="output",
576 | image_keys=["image_path"],
577 | )
578 |
579 | state = State()
580 | state["messages"] = []
581 |
582 | with pytest.raises(ValueError, match="At least one image key must be provided"):
583 | node(state, self.MockVLMClient())
584 |
585 | def test_vlm_invalid_image_path(self):
586 | """Test VLM node with invalid image path"""
587 | node = Node(
588 | node_type="test_vlm_node",
589 | prompt_template="Test",
590 | sink="output",
591 | image_keys=["image"],
592 | )
593 |
594 | state = State()
595 | state["messages"] = []
596 |
597 | with pytest.raises(TypeError, match="should be string"):
598 | node(state, self.MockVLMClient(), image=None)
599 |
600 |
601 | # Decorator Tests
602 | class TestDecorators:
603 | def test_as_node_decorator(self):
604 | """Test @as_node decorator functionality"""
605 |
606 | @as_node(["output"])
607 | def test_function(input_value):
608 | return f"Processed {input_value}"
609 |
610 | state = State()
611 | state["input_value"] = "test"
612 | state["messages"] = []
613 |
614 | result_state = test_function(state, None)
615 | assert result_state["output"] == "Processed test"
616 |
617 | def test_custom_function_defaults(self):
618 | """Test node with custom function having default parameters"""
619 |
620 | def custom_func(required_param, optional_param="default"):
621 | return f"{required_param}-{optional_param}"
622 |
623 | node = Node(
624 | node_type="test_node",
625 | prompt_template="",
626 | sink="output",
627 | custom_function=custom_func,
628 | )
629 |
630 | state = State()
631 | state["required_param"] = "value"
632 | state["messages"] = []
633 |
634 | result_state = node(state, None)
635 | assert result_state["output"] == "value-default"
636 |
637 | def test_invalid_custom_function_args(self):
638 | """Test handling of custom function with missing required arguments"""
639 |
640 | def custom_func(required_arg):
641 | return required_arg
642 |
643 | node = Node(
644 | node_type="test_node",
645 | prompt_template="",
646 | sink="output",
647 | custom_function=custom_func,
648 | )
649 |
650 | state = State()
651 | state["messages"] = []
652 |
653 | with pytest.raises(ValueError, match="Required key 'required_arg' not found"):
654 | node(state, None)
655 |
656 | def test_as_node_with_multiple_sinks(self):
657 | """Test @as_node decorator with multiple output sinks"""
658 |
659 | @as_node(["output1", "output2"])
660 | def multi_output_function(value):
661 | return [f"First {value}", f"Second {value}"]
662 |
663 | state = State()
664 | state["value"] = "test"
665 | state["messages"] = []
666 |
667 | result_state = multi_output_function(state, None)
668 | assert result_state["output1"] == "First test"
669 | assert result_state["output2"] == "Second test"
670 |
671 | def test_as_node_with_pre_post_processing(self):
672 | """Test @as_node decorator with pre and post processing"""
673 | processed = []
674 |
675 | def pre_process(state, client, **kwargs):
676 | processed.append("pre")
677 | return state
678 |
679 | def post_process(state, client, **kwargs):
680 | processed.append("post")
681 | return state
682 |
683 | @as_node(["output"], pre_process=pre_process, post_process=post_process)
684 | def test_function(value):
685 | processed.append("main")
686 | return f"Result: {value}"
687 |
688 | state = State()
689 | state["value"] = "test"
690 | state["messages"] = []
691 |
692 | result_state = test_function(state, None)
693 | assert result_state["output"] == "Result: test"
694 | assert processed == ["pre", "main", "post"]
695 |
696 | def test_as_node_with_multiple_parameters(self):
697 | """Test @as_node decorator with function having multiple parameters"""
698 |
699 | @as_node(["output"])
700 | def multi_param_function(param1, param2, param3="default"):
701 | return f"{param1}-{param2}-{param3}"
702 |
703 | state = State()
704 | state["param1"] = "value1"
705 | state["param2"] = "value2"
706 | state["messages"] = []
707 |
708 | result_state = multi_param_function(state, None)
709 | assert result_state["output"] == "value1-value2-default"
710 |
711 | def test_as_node_as_function_flag(self):
712 | """Test @as_node decorator with as_function flag"""
713 |
714 | @as_node(["output"], as_function=True)
715 | def test_function(value):
716 | return f"Processed {value}"
717 |
718 | assert callable(test_function)
719 | assert hasattr(test_function, "node_type")
720 | assert hasattr(test_function, "sink")
721 | assert test_function.node_type == "test_function"
722 | assert test_function.sink == ["output"]
723 |
724 | def test_as_node_error_handling(self):
725 | """Test @as_node decorator error handling for missing parameters"""
726 |
727 | @as_node(["output"])
728 | def error_function(required_param):
729 | return f"Value: {required_param}"
730 |
731 | state = State()
732 | state["messages"] = []
733 |
734 | with pytest.raises(ValueError, match="Required key 'required_param' not found"):
735 | error_function(state, None)
736 |
737 |
738 | # Utility Function Tests
739 | class TestUtilityFunctions:
740 | def test_remove_markdown_blocks(self):
741 | """Test markdown block removal"""
742 | text = "```python\ndef test():\n pass\n```"
743 | result = remove_markdown_blocks_formatting(text)
744 | assert result == "def test():\n pass"
745 |
--------------------------------------------------------------------------------
/nodeology/node.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
3 |
4 | Copyright 2024. UChicago Argonne, LLC. This software was produced
5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
7 | U.S. Department of Energy. The U.S. Government has rights to use,
8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
11 | modified to produce derivative works, such modified software should
12 | be clearly marked, so as not to confuse it with the version available
13 | from ANL.
14 |
15 | Additionally, redistribution and use in source and binary forms, with
16 | or without modification, are permitted provided that the following
17 | conditions are met:
18 |
19 | * Redistributions of source code must retain the above copyright
20 | notice, this list of conditions and the following disclaimer.
21 |
22 | * Redistributions in binary form must reproduce the above copyright
23 | notice, this list of conditions and the following disclaimer in
24 | the documentation and/or other materials provided with the
25 | distribution.
26 |
27 | * Neither the name of UChicago Argonne, LLC, Argonne National
28 | Laboratory, ANL, the U.S. Government, nor the names of its
29 | contributors may be used to endorse or promote products derived
30 | from this software without specific prior written permission.
31 |
32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
43 | POSSIBILITY OF SUCH DAMAGE.
44 | """
45 |
46 | ### Initial Author <2024>: Xiangyu Yin
47 |
48 | import os
49 | from string import Formatter
50 | from inspect import signature
51 | from typing import Optional, Annotated, List, Union, Dict, Callable, Any
52 | import ast
53 |
54 | from nodeology.state import State
55 | from nodeology.log import log_print_color
56 | from nodeology.client import LLM_Client, VLM_Client
57 |
58 |
59 | def _process_state_with_transforms(
60 | state: State, transforms: Dict[str, Callable], client: LLM_Client, **kwargs
61 | ) -> State:
62 | """Helper function to apply transforms to state values.
63 |
64 | Args:
65 | state: Current state
66 | transforms: Dictionary mapping state keys to transformation functions
67 | client: LLM client (unused but kept for signature compatibility)
68 | """
69 | for key, transform in transforms.items():
70 | if key in state:
71 | try:
72 | state[key] = transform(state[key])
73 | except Exception as e:
74 | raise ValueError(f"Error applying transform to {key}: {str(e)}")
75 | return state
76 |
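# A minimal sketch of the dict form accepted by Node's pre_process/post_process
# (the state keys here are hypothetical):
#
#   transforms = {"text": str.strip, "count": int}
#   state = _process_state_with_transforms(state, transforms, client=None)
#   # state["text"] is stripped and state["count"] coerced to int, if present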
77 |
78 | class Node:
79 | """Template for creating node functions that process data using LLMs or custom functions.
80 |
81 | A Node represents a processing unit in a workflow that can:
82 | - Execute LLM/VLM queries or custom functions
83 | - Manage state before and after execution
84 | - Handle pre/post processing steps
85 | - Process both text and image inputs
86 | - Format and validate outputs
87 |
88 | Args:
89 | prompt_template (str): Template string for the LLM prompt. Uses Python string formatting
90 | syntax (e.g., "{variable}"). Empty if using custom_function.
91 | node_type (Optional[str]): Unique identifier for the node.
92 | sink (Optional[Union[List[str], str]]): Where to store results in state. Can be:
93 | - Single string key
94 | - List of keys for multiple outputs
95 | - None (results won't be stored)
96 | sink_format (Optional[str]): Format specification for LLM output (e.g., "json", "list").
97 | Used to ensure consistent response structure.
98 | image_keys (Optional[List[str]]): List of keys for image file paths when using VLM.
99 | Must provide at least one image path in kwargs when these are specified.
100 | pre_process (Optional[Union[Callable, Dict[str, Callable]]]): Either a function to run
101 | before execution or a dictionary mapping state keys to transform functions.
102 | post_process (Optional[Union[Callable, Dict[str, Callable]]]): Either a function to run
103 | after execution or a dictionary mapping state keys to transform functions.
104 | sink_transform (Optional[Union[Callable, List[Callable]]]): Transform(s) to apply to
105 | sink value(s). If sink is a string, must be a single callable. If sink is a list,
106 | can be either a single callable (applied to all sinks) or a list of callables.
107 | custom_function (Optional[Callable]): Custom function to execute instead of LLM query.
108 | Function parameters become required keys for node execution.
109 |
110 | Attributes:
111 | required_keys (List[str]): Keys required from state/kwargs for node execution.
112 | Extracted from either prompt_template or custom_function signature.
113 | prompt_history (List[str]): History of prompt templates used by this node.
114 |
115 | Raises:
116 | ValueError: If required keys are missing or response format is invalid
117 | FileNotFoundError: If specified image files don't exist
118 | ValueError: If VLM operations are attempted without proper client
119 |
120 | Example:
121 | ```python
122 | # Create a simple text processing node
123 | node = Node(
124 | node_type="summarizer",
125 | prompt_template="Summarize this text: {text}",
126 | sink="summary"
127 | )
128 |
129 | # Create a node with custom function
130 | def process_data(x, y):
131 | return x + y
132 |
133 | node = Node(
134 | node_type="calculator",
135 | prompt_template="",
136 | sink="result",
137 | custom_function=process_data
138 | )
139 | ```
140 | """
141 |
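    # A further usage sketch for sink_transform (keys are illustrative): a
    # single callable post-processes the LLM response before it is stored.
    #
    #   node = Node(
    #       node_type="lister",
    #       prompt_template="Name three colors related to {topic}",
    #       sink="colors",
    #       sink_transform=str.strip,
    #   )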
142 | # Simplified set of allowed functions that return values
143 | ALLOWED_FUNCTIONS = {
144 | "len": len, # Length of sequences
145 | "str": str, # String conversion
146 | "int": int, # Integer conversion
147 | "float": float, # Float conversion
148 | "sum": sum, # Sum of numbers
149 | "max": max, # Maximum value
150 | "min": min, # Minimum value
151 | "abs": abs, # Absolute value
152 | }
153 |
154 | DISALLOWED_FUNCTION_NAMES = [
155 | "eval",
156 | "exec",
157 | "compile",
158 | "open",
159 | "print",
160 | "execfile",
161 | "exit",
162 | "quit",
163 | "help",
164 | "dir",
165 | "globals",
166 | "locals",
168 | "type",
169 | "hash",
170 | "repr",
171 | "filter",
172 | "enumerate",
173 | "reversed",
174 | "sorted",
175 | "any",
176 | "all",
177 | ]
178 |
179 | # String methods that return values
180 | ALLOWED_STRING_METHODS = {
181 | "upper": str.upper,
182 | "lower": str.lower,
183 | "strip": str.strip,
184 | "capitalize": str.capitalize,
185 | }
186 |
187 | def __init__(
188 | self,
189 | prompt_template: str,
190 | node_type: Optional[str] = None,
191 | sink: Optional[Union[List[str], str]] = None,
192 | sink_format: Optional[str] = None,
193 | image_keys: Optional[List[str]] = None,
194 | pre_process: Optional[
195 | Union[
196 | Callable[[State, LLM_Client, Any], Optional[State]], Dict[str, Callable]
197 | ]
198 | ] = None,
199 | post_process: Optional[
200 | Union[
201 | Callable[[State, LLM_Client, Any], Optional[State]], Dict[str, Callable]
202 | ]
203 | ] = None,
204 | sink_transform: Optional[Union[Callable, List[Callable]]] = None,
205 | custom_function: Optional[Callable[..., Any]] = None,
206 | use_conversation: Optional[bool] = False,
207 | ):
208 | # Set default node_type based on whether it's prompt or function-based
209 | if node_type is None:
210 | if custom_function:
211 | self.node_type = custom_function.__name__
212 | else:
213 | self.node_type = "prompt"
214 | else:
215 | self.node_type = node_type
216 |
217 | self.prompt_template = prompt_template
218 | self._escaped_sections = [] # Store escaped sections at instance level
219 | self.sink = sink
220 | self.image_keys = image_keys
221 | self.sink_format = sink_format
222 | self.custom_function = custom_function
223 | self.use_conversation = use_conversation
224 |
225 | # Handle pre_process
226 | if isinstance(pre_process, dict):
227 | transforms = pre_process
228 | self.pre_process = (
229 | lambda state, client, **kwargs: _process_state_with_transforms(
230 | state, transforms, client, **kwargs
231 | )
232 | )
233 | else:
234 | self.pre_process = pre_process
235 |
236 | # Handle post_process
237 | if isinstance(post_process, dict):
238 | transforms = post_process
239 | self.post_process = (
240 | lambda state, client, **kwargs: _process_state_with_transforms(
241 | state, transforms, client, **kwargs
242 | )
243 | )
244 | else:
245 | self.post_process = post_process
246 |
247 | # Handle sink_transform
248 | if sink_transform is not None:
249 | if isinstance(sink, str):
250 | if not callable(sink_transform):
251 | raise ValueError(
252 | "sink_transform must be callable when sink is a string"
253 | )
254 | self._sink_transform = sink_transform
255 | elif isinstance(sink, list):
256 | if callable(sink_transform):
257 | # If single transform provided for multiple sinks, apply it to all
258 | self._sink_transform = [sink_transform] * len(sink)
259 | elif len(sink_transform) != len(sink):
260 | raise ValueError("Number of transforms must match number of sinks")
261 | else:
262 | self._sink_transform = sink_transform
263 | else:
264 | raise ValueError("sink must be specified to use sink_transform")
265 | else:
266 | self._sink_transform = None
267 |
268 | # Extract required keys from template or custom function signature
269 | if self.custom_function:
270 | # Get only required keys (those without default values) from function signature
271 | sig = signature(self.custom_function)
272 | self.required_keys = [
273 | param.name
274 | for param in sig.parameters.values()
275 | if param.default is param.empty
276 | and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
277 | and param.name != "self"
278 | ]
279 | else:
280 | # Extract base variable names from expressions, excluding function names and escaped content
281 | self.required_keys = []
282 | # First, temporarily replace escaped content
283 | template = prompt_template
284 |
285 | # Replace {{{ }}} sections first
286 | import re
287 |
288 | triple_brace_pattern = (
289 | r"\{{3}[\s\S]*?\}{3}" # Non-greedy match, including newlines
290 | )
291 | for i, match in enumerate(re.finditer(triple_brace_pattern, template)):
292 | placeholder = f"___ESCAPED_TRIPLE_{i}___"
293 | self._escaped_sections.append((placeholder, match.group(0)))
294 | template = template.replace(match.group(0), placeholder)
295 |
296 | # Then replace {{ }} sections
297 | double_brace_pattern = (
298 | r"\{{2}[\s\S]*?\}{2}" # Non-greedy match, including newlines
299 | )
300 | for i, match in enumerate(re.finditer(double_brace_pattern, template)):
301 | placeholder = f"___ESCAPED_DOUBLE_{i}___"
302 | self._escaped_sections.append((placeholder, match.group(0)))
303 | template = template.replace(match.group(0), placeholder)
304 |
305 | self._template_with_placeholders = template # Store modified template
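306 |             # e.g. in 'Data: {x}. Keep {{"a": 1}} as-is', only "x" becomes a
307 |             # required key; the double-braced section is restored verbatim later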
306 |
307 | # Now parse the template normally
308 | for _, expr, _, _ in Formatter().parse(template):
309 | if expr is not None:
310 | # Parse the expression to identify actual variables
311 | try:
312 | tree = ast.parse(expr, mode="eval")
313 | variables = set()
314 | for node in ast.walk(tree):
315 | if (
316 | isinstance(node, ast.Name)
317 | and node.id not in self.ALLOWED_FUNCTIONS
318 | and node.id not in self.DISALLOWED_FUNCTION_NAMES
319 | ):
320 | variables.add(node.id)
321 | self.required_keys.extend(variables)
322 | except SyntaxError:
323 | # If parsing fails, fall back to basic extraction
324 | base_var = expr.split("[")[0].split(".")[0].split("(")[0]
325 | if (
326 | base_var not in self.ALLOWED_FUNCTIONS
327 | and base_var not in self.DISALLOWED_FUNCTION_NAMES
328 | and base_var not in self.required_keys
329 | ):
330 | self.required_keys.append(base_var)
331 |
332 | # Remove duplicates while preserving order
333 | self.required_keys = list(dict.fromkeys(self.required_keys))
334 |
335 | self._prompt_history = [
336 | prompt_template
337 | ] # Add prompt history as private attribute
338 |
339 | def _eval_expr(self, expr: str, context: dict) -> Any:
340 | """Safely evaluate a Python expression with limited scope."""
341 | try:
342 | # Add allowed functions to the context
343 | eval_context = {
344 | **context,
345 | **self.ALLOWED_FUNCTIONS, # Include built-in functions
346 | }
347 | tree = ast.parse(expr, mode="eval")
348 | return self._eval_node(tree.body, eval_context)
349 | except SyntaxError as e:
350 | raise ValueError(f"Invalid Python syntax in expression: {str(e)}")
351 | except Exception as e:
352 | raise ValueError(f"Invalid expression: {str(e)}")
353 |
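354 |     # Illustrative (names are made up): with ctx = {"scores": [1, 3, 2]},
355 |     #   _eval_expr("max(scores)", ctx)       -> 3
356 |     #   _eval_expr("scores[0]", ctx)         -> 1
357 |     #   _eval_expr("__import__('os')", ctx)  -> ValueError (not whitelisted)
358 |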
354 | def _eval_node(self, node: ast.AST, context: dict) -> Any:
355 | """Recursively evaluate an AST node with security constraints."""
356 | if isinstance(node, ast.Name):
357 | if node.id not in context:
358 | raise ValueError(f"Variable '{node.id}' not found in context")
359 | return context[node.id]
360 |
361 | elif isinstance(node, ast.Constant):
362 | return node.value
363 |
364 | elif isinstance(node, ast.UnaryOp): # Add support for unary operations
365 | if isinstance(node.op, ast.USub): # Handle negative numbers
366 | operand = self._eval_node(node.operand, context)
367 | return -operand
368 | raise ValueError(f"Unsupported unary operator: {type(node.op).__name__}")
369 |
370 | elif isinstance(node, ast.Call):
371 | if isinstance(node.func, ast.Attribute): # Add support for method calls
372 | obj = self._eval_node(node.func.value, context)
373 | method_name = node.func.attr
374 |                 # Use the class-level whitelist so the method-call path stays
375 |                 # consistent with ALLOWED_STRING_METHODS
376 |                 if method_name in self.ALLOWED_STRING_METHODS:
377 | method = getattr(obj, method_name)
378 | args = [self._eval_node(arg, context) for arg in node.args]
379 | return method(*args)
380 | raise ValueError(f"String method not allowed: {method_name}")
381 | elif isinstance(node.func, ast.Name):
382 | func_name = node.func.id
383 | if func_name not in self.ALLOWED_FUNCTIONS:
384 | raise ValueError(f"Function not allowed: {func_name}")
385 | func = context[
386 | func_name
387 | ] # Get function from context instead of globals
388 | args = [self._eval_node(arg, context) for arg in node.args]
389 | return func(*args)
390 | raise ValueError("Only simple function calls are allowed")
391 |
392 | elif isinstance(node, ast.Attribute):
393 | # Handle string methods (e.g., text.upper())
394 | if not isinstance(node.value, ast.Name):
395 | raise ValueError("Only simple string methods are allowed")
396 |
397 | obj = self._eval_node(node.value, context)
398 | if not isinstance(obj, str):
399 | raise ValueError("Methods are only allowed on strings")
400 |
401 | method_name = node.attr
402 | if method_name not in self.ALLOWED_STRING_METHODS:
403 | raise ValueError(f"String method not allowed: {method_name}")
404 |
405 | return self.ALLOWED_STRING_METHODS[method_name](obj)
406 |
407 | elif isinstance(node, (ast.List, ast.Tuple)):
408 | return [self._eval_node(elt, context) for elt in node.elts]
409 |
410 | elif isinstance(node, ast.Subscript):
411 | value = self._eval_node(node.value, context)
412 | if isinstance(node.slice, ast.Slice):
413 | lower = (
414 | self._eval_node(node.slice.lower, context)
415 | if node.slice.lower
416 | else None
417 | )
418 | upper = (
419 | self._eval_node(node.slice.upper, context)
420 | if node.slice.upper
421 | else None
422 | )
423 | step = (
424 | self._eval_node(node.slice.step, context)
425 | if node.slice.step
426 | else None
427 | )
428 | return value[slice(lower, upper, step)]
429 | else:
430 | # Handle both numeric indices and string keys
431 | idx = self._eval_node(node.slice, context)
432 | try:
433 | return value[idx]
434 | except (TypeError, KeyError) as e:
435 | raise ValueError(f"Invalid subscript access: {str(e)}")
436 |
437 |         elif isinstance(node, ast.Str):  # Legacy pre-3.8 string literals (newer Pythons emit ast.Constant)
438 | return node.s
439 |
440 | raise ValueError(f"Unsupported expression type: {type(node).__name__}")
441 |
442 | @property
443 | def func(self):
444 |         """Returns the node function without executing it.
445 |
446 |         The returned callable has signature (state, client, sink=None,
447 |         source=None, **kwargs) -> State, with the node's metadata attached
448 |         as function attributes.
449 |         """
445 |
446 | def node_function(
447 | state: Annotated[State, "The current state"],
448 | client: Annotated[LLM_Client, "The LLM client"],
449 | sink: Optional[Union[List[str], str]] = None,
450 | source: Optional[Dict[str, str]] = None,
451 | **kwargs,
452 | ) -> State:
453 | return self(state, client, sink, source, **kwargs)
454 |
455 | # Attach the attributes to the function
456 | node_function.node_type = self.node_type
457 | node_function.prompt_template = self.prompt_template
458 | node_function.sink = self.sink
459 | node_function.image_keys = self.image_keys
460 | node_function.sink_format = self.sink_format
461 | node_function.pre_process = self.pre_process
462 | node_function.post_process = self.post_process
463 | node_function.required_keys = self.required_keys
464 | return node_function
465 |
466 | def __call__(
467 | self,
468 | state: State,
469 | client: Union[LLM_Client, VLM_Client],
470 | sink: Optional[Union[List[str], str]] = None,
471 | source: Optional[Union[Dict[str, str], str]] = None,
472 | debug: bool = False,
473 | use_conversation: Optional[bool] = None,
474 | **kwargs,
475 | ) -> State:
476 |         """Executes this node: fills the prompt template (or calls the custom
477 |         function) against the current state and stores the results.
478 |
479 |         Args:
480 |             state: Current state object containing variables
481 |             client: LLM or VLM client for making API calls
482 |             sink: Optional override for where to store results
483 |             source: Optional mapping of template keys to state keys
484 |             debug: Reserved flag; not currently used during execution
485 |             use_conversation: Per-call override of the node's conversation mode
486 |             **kwargs: Additional keyword arguments passed to the function
484 |
485 | Returns:
486 | Updated state object with results stored in sink keys
487 |
488 | Raises:
489 | ValueError: If required keys are missing or response format is invalid
490 | FileNotFoundError: If specified image files don't exist
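491 |
492 |         Example:
493 |             ```python
494 |             # Illustrative; assumes a configured client and that the state
495 |             # contains a "raw_text" key (names here are hypothetical)
496 |             state = node(state, client, sink="summary", source={"text": "raw_text"})
497 |             ```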
491 | """
492 | # Update node type
493 | state["previous_node_type"] = state.get("current_node_type", "")
494 | state["current_node_type"] = self.node_type
495 |
496 | # Pre-processing if defined
497 | if self.pre_process:
498 | pre_process_result = self.pre_process(state, client, **kwargs)
499 | if pre_process_result is None:
500 | return state
501 | state = pre_process_result
502 |
503 | # Get values from state or kwargs
504 | if isinstance(source, str):
505 | source = {"source": source}
506 |
507 | message_values = {}
508 | for key in self.required_keys:
509 | if source and key in source:
510 | source_key = source[key]
511 | if source_key not in state:
512 | raise ValueError(
513 | f"Source mapping key '{source_key}' not found in state"
514 | )
515 | message_values[key] = state[source_key]
516 | elif key in state:
517 | message_values[key] = state[key]
518 | elif key in kwargs:
519 | message_values[key] = kwargs[key]
520 | else:
521 | raise ValueError(f"Required key '{key}' not found in state or kwargs")
522 |
523 | # Execute either custom function or LLM call
524 | if self.custom_function:
525 | # Get default values from function signature
526 | sig = signature(self.custom_function)
527 | default_values = {
528 | k: v.default
529 | for k, v in sig.parameters.items()
530 | if v.default is not v.empty
531 | }
532 | # Update message_values with defaults for missing parameters
533 | for key, default in default_values.items():
534 | if key not in message_values:
535 | message_values[key] = default
536 | if "state" in sig.parameters and "state" not in message_values:
537 | message_values["state"] = state
538 | if "client" in sig.parameters and "client" not in message_values:
539 | message_values["client"] = client
540 | response = self.custom_function(**message_values)
541 | else:
542 | # Create a context with state variables for expression evaluation
543 | eval_context = {**message_values}
544 |
545 | # First fill the template with placeholders
546 | message = self._template_with_placeholders
547 | for _, expr, _, _ in Formatter().parse(self._template_with_placeholders):
548 | if expr is not None:
549 | try:
550 | result = self._eval_expr(expr, eval_context)
551 | message = message.replace(f"{{{expr}}}", str(result))
552 | except Exception as e:
553 | raise ValueError(
554 | f"Error evaluating expression '{expr}': {str(e)}"
555 | )
556 |
557 | # Now restore the escaped sections
558 | for placeholder, original in self._escaped_sections:
559 | message = message.replace(placeholder, original)
560 |
561 | # Record the formatted message
562 | if "messages" not in state:
563 | state["messages"] = []
564 | state["messages"].append({"role": "user", "content": message})
565 |
566 | # Determine if we should use conversation mode
567 | should_use_conversation = (
568 | use_conversation
569 | if use_conversation is not None
570 | else self.use_conversation
571 | )
572 | if should_use_conversation:
573 | assert "conversation" in state and isinstance(
574 | state["conversation"], list
575 | ), "Conversation does not exist in state or is not a list of messages"
576 |
577 | # Prepare messages for client call
578 | if should_use_conversation:
579 |             if len(state["conversation"]) == 0 or state.get("end_conversation", False):
580 | state["conversation"].append({"role": "user", "content": message})
581 | messages = state["conversation"]
582 | else:
583 | messages = [{"role": "user", "content": message}]
584 |
585 | # Handle VLM specific requirements
586 | if self.image_keys:
587 | if not isinstance(client, VLM_Client):
588 | raise ValueError("VLM client required for image keys")
589 |
590 | # Check both state and kwargs for image keys
591 | image_paths = []
592 | for key in self.image_keys:
593 | if key in state:
594 | path = state[key]
595 | elif key in kwargs:
596 | path = kwargs[key]
597 | else:
598 | continue
599 |
600 | if path is None:
601 | raise TypeError(
602 | f"Image path for '{key}' should be string, got None"
603 | )
604 | if not isinstance(path, str):
605 | raise TypeError(
606 | f"Image path for '{key}' should be string, got {type(path)}"
607 | )
608 | image_paths.append(path)
609 |
610 | if not image_paths:
611 | raise ValueError(
612 | "At least one image key must be provided in state or kwargs"
613 | )
614 |
615 | # Verify all paths exist
616 | for path in image_paths:
617 | if not os.path.exists(path):
618 | raise FileNotFoundError(f"Image file not found: {path}")
619 |
620 | response = client(
621 | messages=messages,
622 | images=image_paths,
623 | format=self.sink_format,
624 | workflow=kwargs.get("workflow"),
625 | node=self,
626 | previous_node_type=state["previous_node_type"],
627 | )
628 | else:
629 | response = client(
630 | messages=messages,
631 | format=self.sink_format,
632 | workflow=kwargs.get("workflow"),
633 | node=self,
634 | previous_node_type=state["previous_node_type"],
635 | )
636 |
637 | log_print_color(f"Response: {response}", "white", False)
638 |
639 | # Update state with response
640 | if sink is None:
641 | sink = self.sink
642 |
643 | if sink is None:
644 | log_print_color(
645 | f"Warning: No sink specified for {self.node_type} node", "yellow"
646 | )
647 | return state
648 |
649 | if isinstance(sink, str):
650 | state[sink] = (
651 | remove_markdown_blocks_formatting(response)
652 | if not self.custom_function
653 | else response
654 | )
655 | elif isinstance(sink, list):
656 | if not sink:
657 | log_print_color(
658 | f"Warning: Empty sink list for {self.node_type} node", "yellow"
659 | )
660 | return state
661 |
662 | if len(sink) == 1:
663 | state[sink[0]] = (
664 | remove_markdown_blocks_formatting(response)
665 | if not self.custom_function
666 | else response
667 | )
668 | else:
669 | if not isinstance(response, (list, tuple)):
670 | raise ValueError(
671 |                         f"Expected multiple responses for multiple sinks in {self.node_type} node, but got a single response"
672 | )
673 | if len(response) != len(sink):
674 | raise ValueError(
675 |                         f"Number of responses ({len(response)}) doesn't match number of sinks ({len(sink)}) in {self.node_type} node"
676 | )
677 |
678 | for key, value in zip(sink, response):
679 | state[key] = (
680 | remove_markdown_blocks_formatting(value)
681 | if not self.custom_function
682 | else value
683 | )
684 |
685 | # After storing results but before post_process, apply sink transforms
686 | if self._sink_transform is not None:
687 | current_sink = sink or self.sink
688 | if isinstance(current_sink, str):
689 | state[current_sink] = self._sink_transform(state[current_sink])
690 | else:
691 | for key, transform in zip(current_sink, self._sink_transform):
692 | state[key] = transform(state[key])
693 |
694 | # Post-processing if defined
695 | if self.post_process:
696 | post_process_result = self.post_process(state, client, **kwargs)
697 | if post_process_result is None:
698 | return state
699 | state = post_process_result
700 |
701 | return state
702 |
703 | def __str__(self):
704 | MAX_WIDTH = 80
705 |
706 | # Format prompt with highlighted keys
707 | prompt_lines = self.prompt_template.split("\n")
708 | # First make the whole prompt green
709 | prompt_lines = [f"\033[92m{line}\033[0m" for line in prompt_lines] # Green
710 | # Then highlight the keys in red
711 | for key in self.required_keys:
712 | for i, line in enumerate(prompt_lines):
713 | prompt_lines[i] = line.replace(
714 | f"{{{key}}}",
715 | f"\033[91m{{{key}}}\033[0m\033[92m", # Red keys, return to green after
716 | )
717 |
718 | # Calculate width for horizontal line (min of actual width and MAX_WIDTH)
719 | width = min(max(len(line) for line in prompt_lines), MAX_WIDTH)
720 | double_line = "═" * width
721 | horizontal_line = "─" * width
722 |
723 | # Color formatting for keys in info section
724 | required_keys_colored = [
725 | f"\033[91m{key}\033[0m" for key in self.required_keys
726 | ] # Red
727 | if isinstance(self.sink, str):
728 | sink_colored = [f"\033[94m{self.sink}\033[0m"] # Blue
729 | elif isinstance(self.sink, list):
730 | sink_colored = [f"\033[94m{key}\033[0m" for key in self.sink] # Blue
731 | else:
732 | sink_colored = ["None"]
733 |
734 | # Build the string representation
735 | result = [
736 | double_line,
737 | f"{self.node_type}",
738 | horizontal_line,
739 | *prompt_lines,
740 | horizontal_line,
741 | f"Required keys: {', '.join(required_keys_colored)}",
742 | f"Sink keys: {', '.join(sink_colored)}",
743 | f"Format: {self.sink_format or 'None'}",
744 |             f"Image keys: {', '.join(self.image_keys) if self.image_keys else 'None'}",
745 | f"Pre-process: {self.pre_process.__name__ if self.pre_process else 'None'}",
746 | f"Post-process: {self.post_process.__name__ if self.post_process else 'None'}",
747 | f"Custom function: {self.custom_function.__name__ if self.custom_function else 'None'}",
748 | ]
749 |
750 | return "\n".join(result)
751 |
752 | @property
753 | def prompt_history(self) -> list[str]:
754 | """Returns the history of prompt templates.
755 |
756 | Returns:
757 | list[str]: List of prompt templates, oldest to newest
758 | """
759 | return self._prompt_history.copy()
760 |
761 |
762 | def as_node(
763 | sink: List[str],
764 | pre_process: Optional[Callable[[State, LLM_Client, Any], Optional[State]]] = None,
765 | post_process: Optional[Callable[[State, LLM_Client, Any], Optional[State]]] = None,
766 | as_function: bool = False,
767 | ):
768 | """Decorator to transform a regular Python function into a Node function.
769 |
770 | This decorator allows you to convert standard Python functions into Node objects
771 | that can be integrated into a nodeology workflow. The decorated function becomes
772 | the custom_function of the Node, with its parameters becoming required keys.
773 |
774 | Args:
775 | sink (List[str]): List of state keys where the function's results will be stored.
776 | The number of sink keys should match the number of return values from the function.
777 | pre_process (Optional[Callable]): Function to run before main execution.
778 | Signature: (state: State, client: LLM_Client, **kwargs) -> Optional[State]
779 | post_process (Optional[Callable]): Function to run after main execution.
780 | Signature: (state: State, client: LLM_Client, **kwargs) -> Optional[State]
781 | as_function (bool): If True, returns a callable node function. If False, returns
782 | the Node object itself. Default is False.
783 |
784 | Returns:
785 | Union[Node, Callable]: Either a Node object or a node function, depending on
786 | the as_function parameter.
787 |
788 | Example:
789 | ```python
790 | # Basic usage
791 | @as_node(sink=["result"])
792 | def multiply(x: int, y: int) -> int:
793 | return x * y
794 |
795 | # With pre and post processing
796 | def log_start(state, client, **kwargs):
797 | print("Starting calculation...")
798 | return state
799 |
800 | def log_result(state, client, **kwargs):
801 | print(f"Result: {state['result']}")
802 | return state
803 |
804 | @as_node(
805 | sink=["result"],
806 | pre_process=log_start,
807 | post_process=log_result
808 | )
809 | def add(x: int, y: int) -> int:
810 | return x + y
811 |
812 | # Multiple return values
813 | @as_node(sink=["mean", "std"])
814 | def calculate_stats(numbers: List[float]) -> Tuple[float, float]:
815 | return np.mean(numbers), np.std(numbers)
816 | ```
817 |
818 | Notes:
819 | - The decorated function's parameters become required keys for node execution
820 | - The function can access the state and client objects by including them
821 | as optional parameters
822 | - The number of sink keys should match the number of return values
823 | - When as_function=True, the decorator returns a callable that can be used
824 | directly in workflows
825 | """
826 |
827 | def decorator(func):
828 | # Create a Node instance with the custom function
829 | node = Node(
830 | prompt_template="", # Empty template since we're using custom function
831 | node_type=func.__name__,
832 | sink=sink,
833 | pre_process=pre_process,
834 | post_process=post_process,
835 | custom_function=func, # Pass the function to Node
836 | )
837 |
838 |         # Node.__init__ already derives required_keys from the function
839 |         # signature and, unlike a plain default-value filter, also excludes
840 |         # *args/**kwargs, so the keys need no further adjustment here
845 |
846 | return node.func if as_function else node
847 |
848 | return decorator
849 |
850 |
851 | def remove_markdown_blocks_formatting(text: str) -> str:
852 | """Remove common markdown code block delimiters from text.
853 |
854 | Args:
855 | text: Input text containing markdown code blocks
856 |
857 | Returns:
858 | str: Text with code block delimiters removed
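859 |
860 |     Example:
861 |         ```python
862 |         fenced = "```json" + "\n{}\n" + "```"
863 |         remove_markdown_blocks_formatting(fenced)  # -> "{}"
864 |         ```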
859 | """
860 | lines = text.split("\n")
861 | cleaned_lines = []
862 |
863 | for line in lines:
864 | stripped_line = line.strip()
865 | # Check if line starts with backticks (more robust than exact matches)
866 | if stripped_line.startswith("```"):
867 | continue
868 | else:
869 | cleaned_lines.append(line)
870 |
871 | return "\n".join(cleaned_lines)
872 |
--------------------------------------------------------------------------------