├── examples
│   ├── chainlit.md
│   ├── public
│   │   ├── favicon.png
│   │   ├── logo_light.png
│   │   ├── hide-watermark.js
│   │   ├── theme.json
│   │   ├── elements
│   │   │   └── DataDisplay.jsx
│   │   └── logo_light.svg
│   ├── particle_trajectory_analysis.png
│   ├── README.md
│   ├── particle_trajectory_analysis.yaml
│   ├── .chainlit
│   │   └── config.toml
│   ├── writing_improvement.py
│   └── trajectory_analysis.py
├── assets
│   └── logo.jpg
├── pyproject.toml
├── requirements.txt
├── MANIFEST.in
├── CITATION.cff
├── setup.py
├── LICENSE
├── nodeology
│   ├── __init__.py
│   ├── log.py
│   ├── client.py
│   ├── state.py
│   ├── interface.py
│   └── node.py
├── CONTRIBUTING.md
├── .gitignore
├── README.md
└── tests
    ├── test_state.py
    └── test_node.py

/examples/chainlit.md:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/assets/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyin-anl/Nodeology/HEAD/assets/logo.jpg
--------------------------------------------------------------------------------
/examples/public/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyin-anl/Nodeology/HEAD/examples/public/favicon.png
--------------------------------------------------------------------------------
/examples/public/logo_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyin-anl/Nodeology/HEAD/examples/public/logo_light.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "wheel"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/examples/particle_trajectory_analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyin-anl/Nodeology/HEAD/examples/particle_trajectory_analysis.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | pyyaml
3 | typing-extensions
4 | numpy
5 | plotly
6 | kaleido
7 | langgraph<=0.2.45
8 | litellm>=1.0.0
9 | langfuse>=2.0.0
10 | chainlit>=2.0.0
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include README.md
3 | include requirements.txt
4 | include CITATION.cff
5 | include CONTRIBUTING.md
6 | recursive-include examples *
7 | recursive-include tests *
8 | recursive-include nodeology *
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 |   - family-names: "Yin"
5 |     given-names: "Xiangyu"
6 |     orcid: "https://orcid.org/0000-0003-2868-1728"
7 | title: "Nodeology"
8 | version: 0.0.1
9 | date-released: 2024-11-20
10 | url: "https://github.com/xyin-anl/Nodeology"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 |     long_description = fh.read()
5 | 
6 | with open("requirements.txt", "r", encoding="utf-8") as fh:
7 |     requirements = [
8 |         line.strip() for line in fh if line.strip() and not line.startswith("#")
9 |     ]
10 | 
11 | setup(
12 |     name="nodeology",
13 |     version="0.0.2",
14 |     author="Xiangyu Yin",
15 |     author_email="xyin@anl.gov",
16 |     description="Foundation AI-Enhanced Scientific Workflow",
17 |     long_description=long_description,
18 |     long_description_content_type="text/markdown",
19 |     url="https://github.com/xyin-anl/nodeology",
20 |     packages=find_packages(),
21 |     include_package_data=True,
22 |     package_data={
23 |         "": ["*.md", "*.cff", "LICENSE"],
24 |         "nodeology": ["examples/*", "tests/*"],
25 |     },
26 |     classifiers=[
27 |         "Intended Audience :: Science/Research",
28 |         "Programming Language :: Python :: 3",
29 |     ],
30 |     python_requires=">=3.10",
31 |     install_requires=requirements,
32 | )
33 | 
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Nodeology Examples
2 | 
3 | This directory contains example applications built with Nodeology, demonstrating features of the framework.
4 | 
5 | ## Prerequisites
6 | 
7 | Before running the examples, ensure you have `nodeology` installed:
8 | 
9 | ```bash
10 | pip install nodeology
11 | ```
12 | 
13 | ## Directory Structure
14 | 
15 | - `writing_improvement.py` - Text analysis and improvement workflow
16 | - `trajectory_analysis.py` - Particle trajectory simulation and visualization
17 | - `public/` - Static assets for the examples (needed for `nodeology` UI elements)
18 | - `.chainlit/` - Chainlit configuration files (needed for `nodeology` UI settings)
19 | 
20 | ## Available Examples
21 | 
22 | ### 1. Writing Improvement (`writing_improvement.py`)
23 | 
24 | An interactive application that helps users improve their writing through analysis and suggestions. This example demonstrates:
25 | 
26 | - State management with Nodeology
27 | - Interactive user input handling
28 | - Text analysis workflow
29 | - Chainlit UI integration
30 | 
31 | To run this example:
32 | 
33 | ```bash
34 | cd examples
35 | python writing_improvement.py
36 | ```
37 | 
38 | ### 2. Particle Trajectory Analysis (`trajectory_analysis.py`)
39 | 
40 | A scientific application that simulates and visualizes particle trajectories under electromagnetic fields. This example showcases:
41 | 
42 | - Complex scientific calculations
43 | - Interactive parameter input
44 | - Data visualization
45 | - State management for scientific workflows
46 | - Advanced Chainlit UI features
47 | 
48 | To run this example:
49 | 
50 | ```bash
51 | cd examples
52 | python trajectory_analysis.py
53 | ```
54 | 
55 | ## Usage Tips
56 | 
57 | 1. Each example will open in your default web browser when launched
58 | 2. Follow the interactive prompts in the Chainlit UI
59 | 3. You can modify parameters and experiment with different inputs
60 | 4. Use the chat interface to interact with the applications
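## Core Pattern (Sketch)

Both examples follow the same three-step pattern: define a `State`, wrap prompts and plain Python functions as nodes, and connect them in a `Workflow`. The sketch below is adapted from `writing_improvement.py`; the state fields, prompt text, and model name are illustrative placeholders rather than part of either example, and the no-argument `compile()`/`run()` calls assume the defaults shown in the examples hold:

```python
from nodeology.state import State
from nodeology.node import Node
from nodeology.workflow import Workflow


class SummaryState(State):  # hypothetical state, for illustration only
    text: str  # input text to summarize
    summary: str  # output written by the node below


# A prompt-template node: {text} is interpolated from the state,
# and the LLM response is stored in the "summary" field.
summarize = Node(
    prompt_template="Summarize the following text:\n\n{text}",
    sink="summary",
)


class SummaryWorkflow(Workflow):
    state_schema = SummaryState

    def create_workflow(self):
        self.add_node("summarize", summarize)
        self.set_entry("summarize")
        self.compile()


if __name__ == "__main__":
    SummaryWorkflow(llm_name="gemini/gemini-2.0-flash").run()
```

The full examples add what this sketch omits: multi-node flows (`add_flow`), conditional branching (`add_conditional_flow`), human-in-the-loop interrupts, and custom Chainlit UI elements.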
61 | 
62 | ## License
63 | 
64 | These examples are provided under the same license as the main Nodeology project. See the license headers in individual files for details.
65 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
2 | 
3 | Copyright 2024. UChicago Argonne, LLC. This software was produced
4 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
5 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
6 | U.S. Department of Energy. The U.S. Government has rights to use,
7 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
8 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
9 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
10 | modified to produce derivative works, such modified software should
11 | be clearly marked, so as not to confuse it with the version available
12 | from ANL.
13 | 
14 | Additionally, redistribution and use in source and binary forms, with
15 | or without modification, are permitted provided that the following
16 | conditions are met:
17 | 
18 | * Redistributions of source code must retain the above copyright
19 |   notice, this list of conditions and the following disclaimer.
20 | 
21 | * Redistributions in binary form must reproduce the above copyright
22 |   notice, this list of conditions and the following disclaimer in
23 |   the documentation and/or other materials provided with the
24 |   distribution.
25 | 
26 | * Neither the name of UChicago Argonne, LLC, Argonne National
27 |   Laboratory, ANL, the U.S. Government, nor the names of its
28 |   contributors may be used to endorse or promote products derived
29 |   from this software without specific prior written permission.
30 | 
31 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
32 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
33 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
34 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
35 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
36 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
37 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
38 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
39 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
40 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
41 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
42 | POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/nodeology/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
3 | 
4 | Copyright 2024. UChicago Argonne, LLC. This software was produced
5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
7 | U.S. Department of Energy. The U.S. Government has rights to use,
8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE.
If software is 11 | modified to produce derivative works, such modified software should 12 | be clearly marked, so as not to confuse it with the version available 13 | from ANL. 14 | 15 | Additionally, redistribution and use in source and binary forms, with 16 | or without modification, are permitted provided that the following 17 | conditions are met: 18 | 19 | * Redistributions of source code must retain the above copyright 20 | notice, this list of conditions and the following disclaimer. 21 | 22 | * Redistributions in binary form must reproduce the above copyright 23 | notice, this list of conditions and the following disclaimer in 24 | the documentation and/or other materials provided with the 25 | distribution. 26 | 27 | * Neither the name of UChicago Argonne, LLC, Argonne National 28 | Laboratory, ANL, the U.S. Government, nor the names of its 29 | contributors may be used to endorse or promote products derived 30 | from this software without specific prior written permission. 31 | 32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS 33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago 36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 43 | POSSIBILITY OF SUCH DAMAGE. 44 | """ 45 | 46 | ### Initial Author <2024>: Xiangyu Yin 47 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Nodeology 2 | 3 | Thank you for your interest in contributing to Nodeology! This document provides guidelines and information about contributing to the project. 4 | 5 | ## Ways to Contribute 6 | 7 | ### Code Contributions 8 | 9 | 1. **Core Framework** 10 | 11 | - Bug fixes and improvements 12 | - Performance optimizations 13 | - New features 14 | - Test coverage 15 | 16 | 2. **Pre-built Components** 17 | - New node types 18 | - State definitions 19 | - Workflow templates 20 | - Domain-specific tools 21 | 22 | ### Documentation 23 | 24 | - API documentation 25 | - Usage examples 26 | - Tutorials 27 | - Best practices 28 | 29 | ### Research Collaborations 30 | 31 | 1. **Workflow Patterns** 32 | 33 | - Novel automation patterns 34 | - Optimization strategies 35 | - Human-AI interaction interfaces 36 | - Error handling approaches 37 | 38 | 2. **Scientific Integration** 39 | 40 | - Domain-specific applications 41 | - Instrument interfaces 42 | - Data processing pipelines 43 | - Analysis tools 44 | 45 | 3. **Evaluation Methods** 46 | 47 | - Benchmark development 48 | - Performance metrics 49 | - Reliability assessment 50 | - Comparison frameworks 51 | 52 | 4. **AI Integration** 53 | - Prompt optimization 54 | - Model evaluation 55 | - Hybrid processing 56 | - Knowledge integration 57 | 58 | ## Getting Started 59 | 60 | 1. 
**Development Setup** 61 | 62 | ```bash 63 | # Clone repository 64 | git clone https://github.com/xyin-anl/nodeology.git 65 | cd nodeology 66 | 67 | # Create virtual environment using venv or conda 68 | python -m venv venv 69 | source venv/bin/activate 70 | 71 | # Install dependencies 72 | pip install -r requirements.txt 73 | 74 | # Install pytest 75 | pip install pytest 76 | 77 | # Run tests 78 | pytest tests/ 79 | ``` 80 | 81 | ## Guidelines 82 | 83 | 1. **Code Contributions** 84 | 85 | - Fork repository 86 | - Create feature branch 87 | - Submit pull request 88 | - Use `black` formatter 89 | - Write clear commit messages 90 | - Add unit tests 91 | - Update/create documentation if possible 92 | - Include example usage if possible 93 | 94 | 2. **Documentation** 95 | 96 | - Clear, concise writing 97 | - Practical examples 98 | - Proper formatting 99 | - Complete coverage 100 | 101 | 3. **Community** 102 | - Be respectful 103 | - Provide constructive feedback 104 | - Help new users 105 | - Share knowledge 106 | -------------------------------------------------------------------------------- /examples/particle_trajectory_analysis.yaml: -------------------------------------------------------------------------------- 1 | name: TrajectoryWorkflow_03_13_2025_20_06_45 2 | state_defs: 3 | - current_node_type: str 4 | - previous_node_type: str 5 | - human_input: str 6 | - input: str 7 | - output: str 8 | - messages: List[dict] 9 | - mass: float 10 | - charge: float 11 | - initial_velocity: ndarray 12 | - E_field: ndarray 13 | - B_field: ndarray 14 | - confirm_parameters: bool 15 | - parameters_updater_output: str 16 | - positions: List[ndarray] 17 | - trajectory_plot: str 18 | - trajectory_plot_path: str 19 | - analysis_result: dict 20 | - continue_simulation: bool 21 | nodes: 22 | display_parameters: 23 | type: display_parameters 24 | next: ask_confirm_parameters 25 | ask_confirm_parameters: 26 | type: ask_confirm_parameters 27 | sink: confirm_parameters 28 | next: 29 | condition: confirm_parameters 30 | then: calculate_trajectory 31 | otherwise: ask_parameters_input 32 | ask_parameters_input: 33 | type: ask_parameters_input 34 | sink: human_input 35 | next: update_parameters 36 | update_parameters: 37 | type: prompt 38 | template: 'Update the parameters based on the user''s input. Current parameters: 39 | mass: {mass} charge: {charge} initial_velocity: {initial_velocity} E_field: 40 | {E_field} B_field: {B_field} User input: {human_input} Please return the updated 41 | parameters in JSON format. {{ "mass": float, "charge": float, "initial_velocity": 42 | list[float], "E_field": list[float], "B_field": list[float] }}' 43 | sink: parameters_updater_output 44 | next: display_parameters 45 | calculate_trajectory: 46 | type: calculate_trajectory 47 | sink: positions 48 | next: plot_trajectory 49 | plot_trajectory: 50 | type: plot_trajectory 51 | sink: [trajectory_plot, trajectory_plot_path] 52 | next: analyze_trajectory 53 | analyze_trajectory: 54 | type: prompt 55 | template: 'Analyze this particle trajectory plot. Please determine: 1. The type 56 | of motion (linear, circular, helical, or chaotic) 2. Key physical features (radius, 57 | period, pitch angle if applicable) 3. Explanation of the motion 4. 
Anomalies 58 | in the motion Output in JSON format: {{ "trajectory_type": "type_name", "key_features": 59 | { "feature1": value, "feature2": value }, "explanation": "detailed explanation", 60 | "anomalies": "anomaly description" }}' 61 | sink: analysis_result 62 | image_keys: trajectory_plot_path 63 | next: ask_continue_simulation 64 | ask_continue_simulation: 65 | type: ask_continue_simulation 66 | sink: continue_simulation 67 | next: 68 | condition: continue_simulation 69 | then: display_parameters 70 | otherwise: END 71 | entry_point: display_parameters 72 | llm: gemini/gemini-2.0-flash 73 | vlm: gemini/gemini-2.0-flash 74 | exit_commands: [stop workflow, quit workflow, terminate workflow] 75 | -------------------------------------------------------------------------------- /examples/public/hide-watermark.js: -------------------------------------------------------------------------------- 1 | function hideWatermark() { 2 | // Try multiple selector approaches 3 | const selectors = [ 4 | "#chainlit-copilot", 5 | ".cl-copilot-container", 6 | "[data-testid='copilot-container']", 7 | // Add any other potential selectors 8 | ]; 9 | 10 | for (const selector of selectors) { 11 | const elements = document.querySelectorAll(selector); 12 | 13 | elements.forEach(element => { 14 | // Try to access shadow DOM if it exists 15 | if (element.shadowRoot) { 16 | const watermarks = element.shadowRoot.querySelectorAll("a.watermark, .watermark, [class*='watermark']"); 17 | watermarks.forEach(watermark => { 18 | watermark.style.display = "none"; 19 | watermark.style.visibility = "hidden"; 20 | watermark.remove(); // Try to remove it completely 21 | }); 22 | } 23 | 24 | // Also check for watermarks in the regular DOM 25 | const directWatermarks = element.querySelectorAll("a.watermark, .watermark, [class*='watermark']"); 26 | directWatermarks.forEach(watermark => { 27 | watermark.style.display = "none"; 28 | watermark.style.visibility = "hidden"; 29 | watermark.remove(); // Try to remove it completely 30 | }); 31 | }); 32 | } 33 | 34 | // Add CSS to hide watermarks globally 35 | const style = document.createElement('style'); 36 | style.textContent = ` 37 | a.watermark, .watermark, [class*='watermark'] { 38 | display: none !important; 39 | visibility: hidden !important; 40 | opacity: 0 !important; 41 | pointer-events: none !important; 42 | } 43 | `; 44 | document.head.appendChild(style); 45 | } 46 | 47 | // More aggressive approach with mutation observer for the entire document 48 | function setupGlobalObserver() { 49 | const observer = new MutationObserver((mutations) => { 50 | let shouldCheck = false; 51 | 52 | for (const mutation of mutations) { 53 | if (mutation.addedNodes.length > 0) { 54 | shouldCheck = true; 55 | break; 56 | } 57 | } 58 | 59 | if (shouldCheck) { 60 | hideWatermark(); 61 | } 62 | }); 63 | 64 | observer.observe(document.body, { 65 | childList: true, 66 | subtree: true 67 | }); 68 | } 69 | 70 | // Run on page load 71 | document.addEventListener("DOMContentLoaded", function() { 72 | // Try immediately 73 | hideWatermark(); 74 | 75 | // Setup global observer 76 | setupGlobalObserver(); 77 | 78 | // Try again after delays to catch late-loading elements 79 | setTimeout(hideWatermark, 1000); 80 | setTimeout(hideWatermark, 3000); 81 | 82 | // Periodically check 83 | setInterval(hideWatermark, 5000); 84 | }); 85 | 86 | // Also run the script immediately in case the DOM is already loaded 87 | if (document.readyState === "complete" || document.readyState === "interactive") { 88 | hideWatermark(); 89 
| setTimeout(setupGlobalObserver, 0); 90 | } 91 | -------------------------------------------------------------------------------- /examples/public/theme.json: -------------------------------------------------------------------------------- 1 | { 2 | "custom_fonts": [], 3 | "variables": { 4 | "light": { 5 | "--font-sans": "'Inter', sans-serif", 6 | "--font-mono": "source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace", 7 | "--background": "0 0% 100%", 8 | "--foreground": "0 0% 5%", 9 | "--card": "0 0% 100%", 10 | "--card-foreground": "0 0% 5%", 11 | "--popover": "0 0% 100%", 12 | "--popover-foreground": "0 0% 5%", 13 | "--primary": "150 40% 60%", 14 | "--primary-foreground": "0 0% 100%", 15 | "--secondary": "150 30% 90%", 16 | "--secondary-foreground": "150 30% 20%", 17 | "--muted": "0 0% 90%", 18 | "--muted-foreground": "150 15% 30%", 19 | "--accent": "0 0% 95%", 20 | "--accent-foreground": "150 30% 20%", 21 | "--destructive": "0 84.2% 60.2%", 22 | "--destructive-foreground": "210 40% 98%", 23 | "--border": "150 30% 75%", 24 | "--input": "150 30% 75%", 25 | "--ring": "150 40% 60%", 26 | "--radius": "0.75rem", 27 | "--sidebar-background": "0 0% 98%", 28 | "--sidebar-foreground": "240 5.3% 26.1%", 29 | "--sidebar-primary": "150 40% 60%", 30 | "--sidebar-primary-foreground": "0 0% 98%", 31 | "--sidebar-accent": "240 4.8% 95.9%", 32 | "--sidebar-accent-foreground": "240 5.9% 10%", 33 | "--sidebar-border": "220 13% 91%", 34 | "--sidebar-ring": "217.2 91.2% 59.8%" 35 | }, 36 | "dark": { 37 | "--font-sans": "'Inter', sans-serif", 38 | "--font-mono": "source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace", 39 | "--background": "0 0% 13%", 40 | "--foreground": "0 0% 93%", 41 | "--card": "0 0% 18%", 42 | "--card-foreground": "210 40% 98%", 43 | "--popover": "0 0% 18%", 44 | "--popover-foreground": "210 40% 98%", 45 | "--primary": "150 45% 50%", 46 | "--primary-foreground": "0 0% 100%", 47 | "--secondary": "150 35% 25%", 48 | "--secondary-foreground": "0 0% 98%", 49 | "--muted": "150 15% 30%", 50 | "--muted-foreground": "150 10% 80%", 51 | "--accent": "150 40% 40%", 52 | "--accent-foreground": "0 0% 98%", 53 | "--destructive": "0 62.8% 30.6%", 54 | "--destructive-foreground": "210 40% 98%", 55 | "--border": "150 30% 40%", 56 | "--input": "150 30% 40%", 57 | "--ring": "150 45% 50%", 58 | "--sidebar-background": "0 0% 9%", 59 | "--sidebar-foreground": "240 4.8% 95.9%", 60 | "--sidebar-primary": "150 45% 50%", 61 | "--sidebar-primary-foreground": "0 0% 100%", 62 | "--sidebar-accent": "150 25% 20%", 63 | "--sidebar-accent-foreground": "240 4.8% 95.9%", 64 | "--sidebar-border": "240 3.7% 15.9%", 65 | "--sidebar-ring": "217.2 91.2% 59.8%" 66 | } 67 | } 68 | } -------------------------------------------------------------------------------- /examples/.chainlit/config.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | # Whether to enable telemetry (default: true). No personal data is collected. 3 | enable_telemetry = false 4 | 5 | # List of environment variables to be provided by each user to use the app. 
6 | user_env = [] 7 | 8 | # Duration (in seconds) during which the session is saved when the connection is lost 9 | session_timeout = 3600 10 | 11 | # Duration (in seconds) of the user session expiry 12 | user_session_timeout = 1296000 # 15 days 13 | 14 | # Enable third parties caching (e.g LangChain cache) 15 | cache = false 16 | 17 | # Authorized origins 18 | allow_origins = ["*"] 19 | 20 | [features] 21 | # Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript) 22 | unsafe_allow_html = false 23 | 24 | # Process and display mathematical expressions. This can clash with "$" characters in messages. 25 | latex = false 26 | 27 | # Automatically tag threads with the current chat profile (if a chat profile is used) 28 | auto_tag_thread = true 29 | 30 | # Allow users to edit their own messages 31 | edit_message = true 32 | 33 | # Authorize users to spontaneously upload files with messages 34 | [features.spontaneous_file_upload] 35 | enabled = true 36 | # Define accepted file types using MIME types 37 | # Examples: 38 | # 1. For specific file types: 39 | # accept = ["image/jpeg", "image/png", "application/pdf"] 40 | # 2. For all files of certain type: 41 | # accept = ["image/*", "audio/*", "video/*"] 42 | # 3. For specific file extensions: 43 | # accept = { "application/octet-stream" = [".xyz", ".pdb"] } 44 | # Note: Using "*/*" is not recommended as it may cause browser warnings 45 | accept = ["*/*"] 46 | max_files = 20 47 | max_size_mb = 500 48 | 49 | [features.audio] 50 | # Sample rate of the audio 51 | sample_rate = 24000 52 | 53 | [UI] 54 | # Name of the assistant. 55 | name = "Assistant" 56 | 57 | # default_theme = "light" 58 | 59 | # layout = "wide" 60 | 61 | # Description of the assistant. This is used for HTML tags. 62 | # description = "" 63 | 64 | # Chain of Thought (CoT) display mode. Can be "hidden", "tool_call" or "full". 65 | cot = "full" 66 | 67 | # Specify a CSS file that can be used to customize the user interface. 68 | # The CSS file can be served from the public directory or via an external link. 69 | # custom_css = "/public/test.css" 70 | 71 | # Specify a Javascript file that can be used to customize the user interface. 72 | # The Javascript file can be served from the public directory. 73 | custom_js = "/public/hide-watermark.js" 74 | 75 | # Specify a custom meta image url. 76 | # custom_meta_image_url = "https://chainlit-cloud.s3.eu-west-3.amazonaws.com/logo/chainlit_banner.png" 77 | 78 | # Specify a custom build directory for the frontend. 79 | # This can be used to customize the frontend code. 80 | # Be careful: If this is a relative path, it should not start with a slash. 81 | # custom_build = "./public/build" 82 | 83 | # Specify optional one or more custom links in the header. 
84 | # [[UI.header_links]] 85 | # name = "Nodeology" 86 | # icon_url = "https://avatars.githubusercontent.com/u/128686189?s=200&v=4" 87 | # url = "https://github.com/xyin-anl/Nodeology" 88 | 89 | [meta] 90 | generated_by = "2.2.1" 91 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | # Apple 165 | .DS_Store 166 | 167 | # Nodeology 168 | artifacts/ 169 | logs/ 170 | 171 | # PyPI 172 | dist/ 173 | *.egg-info/ 174 | .pypirc 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | > [!IMPORTANT] 2 | > This package is actively in development, and breaking changes may occur. 3 | 4 |
5 | Nodeology Logo 6 |

7 |
8 | 
9 | ## 🤖 Foundation AI-Enhanced Scientific Workflow
10 | 
11 | Foundation AI holds enormous potential for scientific research, especially in analyzing unstructured data, automating complex reasoning tasks, and simplifying human-computer interactions. However, integrating foundation AI models like LLMs and VLMs into scientific workflows poses challenges: handling diverse data types beyond text and images, managing model inaccuracies (hallucinations), and adapting general-purpose models to highly specialized scientific contexts.
12 | 
13 | `nodeology` addresses these challenges by combining the strengths of foundation AI with traditional scientific methods and expert oversight. Built on `langgraph`'s state machine framework, it simplifies creating robust, AI-driven workflows through an intuitive, accessible interface. Originally developed at Argonne National Lab, the framework enables researchers—especially those without extensive programming experience—to quickly design and deploy full-stack AI workflows simply using prompt templates and existing functions as reusable nodes.
14 | 
15 | Key features include:
16 | 
17 | - Easy creation of AI-integrated workflows without complex syntax
18 | - Flexible and composable node architecture for various tasks
19 | - Seamless human-in-the-loop interactions for expert oversight
20 | - Portable workflow templates for collaboration and reproducibility
21 | - Quickly spin up simple chatbots for immediate AI interaction
22 | - Built-in tracing and telemetry for workflow monitoring and optimization
23 | 
24 | ## 🚀 Getting Started
25 | 
26 | ### Install the package
27 | 
28 | To use the latest development version:
29 | 
30 | ```bash
31 | pip install git+https://github.com/xyin-anl/Nodeology.git
32 | ```
33 | 
34 | To use the latest release version:
35 | 
36 | ```bash
37 | pip install nodeology
38 | ```
39 | 
40 | ### Access foundation models
41 | 
42 | Nodeology supports various cloud-based and local foundation models via [LiteLLM](https://docs.litellm.ai/docs/); see the [provider list](https://docs.litellm.ai/docs/providers). Most cloud-based models require setting up an API key. For example:
43 | 
44 | ```bash
45 | # For OpenAI models
46 | export OPENAI_API_KEY='your-api-key'
47 | 
48 | # For Anthropic models
49 | export ANTHROPIC_API_KEY='your-api-key'
50 | 
51 | # For Gemini models
52 | export GEMINI_API_KEY='your-api-key'
53 | 
54 | # For Together AI hosted open weight models
55 | export TOGETHER_API_KEY='your-api-key'
56 | ```
57 | 
58 | > **💡 Tip:** The field of foundation models is evolving rapidly with new and improved models emerging frequently. As of **February 2025**, we recommend the following models based on their strengths:
59 | >
60 | > - **gpt-4o**: Excellent for broad general knowledge, writing tasks, and conversational interactions
61 | > - **o3-mini**: Good balance of math, coding, and reasoning capabilities at a lower price point
62 | > - **anthropic/claude-3.7**: Strong performance in general knowledge, math, science, and coding with well-constrained outputs
63 | > - **gemini/gemini-2.0-flash**: Effective for general knowledge tasks with a large context window for processing substantial information
64 | > - **together_ai/deepseek-ai/DeepSeek-R1**: Exceptional reasoning, math, science, and coding capabilities with transparent thinking processes
65 | 
66 | **For Argonne Users:** if you are within the Argonne network, you will have access to OpenAI's models through Argonne's ARGO inference service and ALCF's open-weight model inference service for free.
Please check this [link](https://gist.github.com/xyin-anl/0cc744a7862e153414857b15fe31b239) to see how to use them 67 | 68 | ### Langfuse Tracing (Optional) 69 | 70 | Nodeology supports [Langfuse](https://langfuse.com/) for observability and tracing of LLM/VLM calls. To use Langfuse: 71 | 72 | 1. Set up a Langfuse account and get your API keys 73 | 2. Configure Langfuse with your keys: 74 | 75 | ```bash 76 | # Set environment variables 77 | export LANGFUSE_PUBLIC_KEY='your-public-key' 78 | export LANGFUSE_SECRET_KEY='your-secret-key' 79 | export LANGFUSE_HOST='https://cloud.langfuse.com' # Or your self-hosted URL 80 | ``` 81 | 82 | Or configure programmatically: 83 | 84 | ```python 85 | from nodeology.client import configure_langfuse 86 | 87 | configure_langfuse( 88 | public_key='your-public-key', 89 | secret_key='your-secret-key', 90 | host='https://cloud.langfuse.com' # Optional 91 | ) 92 | ``` 93 | 94 | ### Chainlit Interface (Optional) 95 | 96 | Nodeology supports [Chainlit](https://docs.chainlit.io/get-started/overview) for creating chat-based user interfaces. To use this feature, simply set `ui=True` when running your workflow: 97 | 98 | ```python 99 | # Create your workflow 100 | workflow = MyWorkflow() 101 | 102 | # Run with UI enabled 103 | workflow.run(ui=True) 104 | ``` 105 | 106 | This will automatically launch a Chainlit server with a chat interface for interacting with your workflow. The interface preserves your workflow's state and configuration, allowing users to interact with it through a user-friendly chat interface. 107 | 108 | When the Chainlit server starts, you can access the interface through your web browser at `http://localhost:8000` by default. 109 | 110 | ## 🧪 Illustrating Examples 111 | 112 | ### [Writing Improvement](https://github.com/xyin-anl/Nodeology/examples/writing_improvement.py) 113 | 114 |
115 | 116 | 117 | 118 |
119 | 120 | ### [Trajectory Analysis](https://github.com/xyin-anl/Nodeology/examples/trajectory_analysis.py) 121 | 122 |
123 | 124 | 125 | 126 |
127 | 
128 | ## 🔬 Scientific Applications
129 | 
130 | - [PEAR: Ptychography automation framework](https://arxiv.org/abs/2410.09034)
131 | - [AutoScriptCopilot: TEM experiment control](https://github.com/xyin-anl/AutoScriptCopilot)
132 | 
133 | ## 👥 Contributing & Collaboration
134 | 
135 | We welcome comments, feedback, bug reports, code contributions, and research collaborations. Please refer to CONTRIBUTING.md.
136 | 
137 | If you find `nodeology` useful or it inspires your research, please use the **Cite this repository** function.
138 | 
--------------------------------------------------------------------------------
/examples/writing_improvement.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved.
3 | 
4 | Copyright 2024. UChicago Argonne, LLC. This software was produced
5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National
6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the
7 | U.S. Department of Energy. The U.S. Government has rights to use,
8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR
9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR
10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is
11 | modified to produce derivative works, such modified software should
12 | be clearly marked, so as not to confuse it with the version available
13 | from ANL.
14 | 
15 | Additionally, redistribution and use in source and binary forms, with
16 | or without modification, are permitted provided that the following
17 | conditions are met:
18 | 
19 | * Redistributions of source code must retain the above copyright
20 |   notice, this list of conditions and the following disclaimer.
21 | 
22 | * Redistributions in binary form must reproduce the above copyright
23 |   notice, this list of conditions and the following disclaimer in
24 |   the documentation and/or other materials provided with the
25 |   distribution.
26 | 
27 | * Neither the name of UChicago Argonne, LLC, Argonne National
28 |   Laboratory, ANL, the U.S. Government, nor the names of its
29 |   contributors may be used to endorse or promote products derived
30 |   from this software without specific prior written permission.
31 | 
32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS
33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago
36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
43 | POSSIBILITY OF SUCH DAMAGE.
44 | """
45 | 
46 | ### Initial Author <2025>: Xiangyu Yin
47 | 
48 | import json
49 | from nodeology.state import State
50 | from nodeology.node import Node, as_node
51 | from nodeology.workflow import Workflow
52 | 
53 | import chainlit as cl
54 | from chainlit import Message, AskActionMessage, run_sync
55 | from langgraph.graph import END
56 | 
57 | 
58 | # 1. 
Define your state 59 | class TextAnalysisState(State): 60 | analysis: dict # Analysis results 61 | text: str # Enhanced text 62 | continue_improving: bool # Whether to continue improving 63 | 64 | 65 | # 2. Create nodes 66 | @as_node(sink="text") 67 | def parse_human_input(human_input: str): 68 | return human_input 69 | 70 | 71 | analyze_text = Node( 72 | prompt_template="""Text to analyze: {text} 73 | 74 | Analyze the above text for: 75 | - Clarity (1-10) 76 | - Grammar (1-10) 77 | - Style (1-10) 78 | - Suggestions for improvement 79 | 80 | Output as JSON: 81 | {{ 82 | "clarity_score": int, 83 | "grammar_score": int, 84 | "style_score": int, 85 | "suggestions": str 86 | }} 87 | """, 88 | sink="analysis", 89 | sink_format="json", 90 | ) 91 | 92 | 93 | def report_analysis(state, client, **kwargs): 94 | analysis = json.loads(state["analysis"]) 95 | run_sync( 96 | Message( 97 | content="Below is the analysis of the text:", 98 | elements=[cl.CustomElement(name="DataDisplay", props={"data": analysis})], 99 | ).send() 100 | ) 101 | return state 102 | 103 | 104 | analyze_text.post_process = report_analysis 105 | 106 | improve_text = Node( 107 | prompt_template="""Text to improve: {text} 108 | 109 | Analysis: {analysis} 110 | 111 | Rewrite the text incorporating the suggestions while maintaining the original meaning. 112 | Focus on clarity, grammar, and style improvements. Return the improved text only.""", 113 | sink="text", 114 | ) 115 | 116 | 117 | def report_improvement(state, client, **kwargs): 118 | text_md = f"{state['text']}" 119 | run_sync( 120 | Message( 121 | content="Below is the improved text:", elements=[cl.Text(content=text_md)] 122 | ).send() 123 | ) 124 | return state 125 | 126 | 127 | improve_text.post_process = report_improvement 128 | 129 | 130 | @as_node(sink="continue_improving") 131 | def ask_continue_improve(): 132 | res = run_sync( 133 | AskActionMessage( 134 | content="Would you like to further improve the text?", 135 | timeout=300, 136 | actions=[ 137 | cl.Action( 138 | name="continue", 139 | payload={"value": "continue"}, 140 | label="Continue Improving", 141 | ), 142 | cl.Action( 143 | name="finish", 144 | payload={"value": "finish"}, 145 | label="Finish", 146 | ), 147 | ], 148 | ).send() 149 | ) 150 | 151 | # Return the user's choice 152 | if res and res.get("payload").get("value") == "continue": 153 | return True 154 | else: 155 | return False 156 | 157 | 158 | # 3. Create workflow 159 | class TextEnhancementWorkflow(Workflow): 160 | state_schema = TextAnalysisState 161 | 162 | def create_workflow(self): 163 | # Add nodes 164 | self.add_node("parse_human_input", parse_human_input) 165 | self.add_node("analyze", analyze_text) 166 | self.add_node("improve", improve_text) 167 | self.add_node("ask_continue", ask_continue_improve) 168 | 169 | # Connect nodes 170 | self.add_flow("parse_human_input", "analyze") 171 | self.add_flow("analyze", "improve") 172 | self.add_flow("improve", "ask_continue") 173 | 174 | # Add conditional flow based on user's choice 175 | self.add_conditional_flow( 176 | "ask_continue", 177 | "continue_improving", 178 | "analyze", 179 | END, 180 | ) 181 | 182 | # Set entry point 183 | self.set_entry("parse_human_input") 184 | 185 | # Compile workflow 186 | self.compile( 187 | interrupt_before=["parse_human_input"], 188 | interrupt_before_phrases={ 189 | "parse_human_input": "Please enter the text to analyze." 190 | }, 191 | ) 192 | 193 | 194 | # 4. 
Run workflow 195 | workflow = TextEnhancementWorkflow( 196 | llm_name="gemini/gemini-2.0-flash", save_artifacts=True 197 | ) 198 | 199 | if __name__ == "__main__": 200 | result = workflow.run(ui=True) 201 | -------------------------------------------------------------------------------- /examples/public/elements/DataDisplay.jsx: -------------------------------------------------------------------------------- 1 | import { Card, CardHeader, CardTitle, CardContent } from "@/components/ui/card" 2 | import { Badge } from "@/components/ui/badge" 3 | import { Separator } from "@/components/ui/separator" 4 | import { ScrollArea } from "@/components/ui/scroll-area" 5 | import { Button } from "@/components/ui/button" 6 | import { Copy, ChevronDown, ChevronRight } from "lucide-react" 7 | import { useState } from "react" 8 | 9 | export default function DataDisplay() { 10 | // Data is passed via props 11 | const data = props.data || {}; 12 | const title = props.title || "Data"; 13 | const badge = props.badge || null; 14 | const maxHeight = props.maxHeight || "300px"; 15 | const showScrollArea = props.showScrollArea !== false; 16 | const collapsible = props.collapsible !== false; 17 | const theme = props.theme || "default"; // default, compact, or expanded 18 | 19 | // State for collapsible sections 20 | const [expandedSections, setExpandedSections] = useState({}); 21 | 22 | // Toggle section expansion 23 | const toggleSection = (key) => { 24 | setExpandedSections(prev => ({ 25 | ...prev, 26 | [key]: !prev[key] 27 | })); 28 | }; 29 | 30 | // Copy value to clipboard 31 | const copyToClipboard = (value) => { 32 | let textToCopy; 33 | 34 | if (Array.isArray(value)) { 35 | // Format arrays properly for copying 36 | textToCopy = JSON.stringify(value); 37 | } else if (typeof value === 'object' && value !== null) { 38 | textToCopy = JSON.stringify(value, null, 2); 39 | } else { 40 | textToCopy = String(value); 41 | } 42 | 43 | navigator.clipboard.writeText(textToCopy); 44 | // Could use sonner for toast notification here 45 | }; 46 | 47 | // Helper to format values with scientific notation for very small/large numbers 48 | const formatValue = (value) => { 49 | if (value === undefined || value === null) return "N/A"; 50 | 51 | if (Array.isArray(value) || (typeof value === 'object' && value.length)) { 52 | return formatArray(value); 53 | } 54 | 55 | if (typeof value === 'object' && value !== null) { 56 | return null; // Handled separately in the render 57 | } 58 | 59 | if (typeof value === 'number') { 60 | // Use scientific notation for very small or large numbers 61 | if (Math.abs(value) < 0.001 || Math.abs(value) > 10000) { 62 | return value.toExponential(4); 63 | } 64 | // Format with appropriate precision 65 | return Number.isInteger(value) ? value.toString() : value.toFixed(4).replace(/\.?0+$/, ''); 66 | } 67 | 68 | if (typeof value === 'boolean') { 69 | return value ? "true" : "false"; 70 | } 71 | 72 | return String(value); 73 | }; 74 | 75 | // Helper function to format arrays nicely 76 | const formatArray = (arr) => { 77 | if (!arr) return "N/A"; 78 | if (typeof arr === 'string') return arr; 79 | 80 | try { 81 | // Handle both array-like objects and actual arrays 82 | if (arr.length > 10) { 83 | const displayed = Array.from(arr).slice(0, 10); 84 | return `[${displayed.map(val => 85 | typeof val === 'object' ? '{...}' : 86 | typeof val === 'number' ? formatValue(val) : 87 | JSON.stringify(val) 88 | ).join(", ")}, ... 
${arr.length - 10} more]`; 89 | } 90 | 91 | return `[${Array.from(arr).map(val => 92 | typeof val === 'object' ? '{...}' : 93 | typeof val === 'number' ? formatValue(val) : 94 | JSON.stringify(val) 95 | ).join(", ")}]`; 96 | } catch (e) { 97 | return String(arr); 98 | } 99 | }; 100 | 101 | // Helper function to render nested objects 102 | const renderNestedObject = (obj, path = "", level = 0) => { 103 | if (!obj || typeof obj !== 'object') return {String(obj || "N/A")}; 104 | 105 | if (Array.isArray(obj)) { 106 | return ( 107 |
108 | {formatArray(obj)} 109 | 117 |
118 | ); 119 | } 120 | 121 | const isExpanded = expandedSections[path] !== false; // Default to expanded 122 | 123 | return ( 124 |
0 ? "pl-4" : ""}`}> 125 | {Object.entries(obj).map(([key, value], index) => { 126 | const currentPath = path ? `${path}.${key}` : key; 127 | const isObject = typeof value === 'object' && value !== null; 128 | 129 | return ( 130 |
131 | {index > 0 && level === 0 && } 132 | 133 |
134 |
135 | {isObject && collapsible && ( 136 | 147 | )} 148 |
{key}
149 |
150 |
151 | 152 |
0 ? "border-l-2 border-gray-200 dark:border-gray-700 pl-2" : ""} 155 | ${isObject && expandedSections[currentPath] === false ? "hidden" : ""} 156 | `}> 157 | {isObject ? 158 | renderNestedObject(value, currentPath, level + 1) : 159 |
160 |
{formatValue(value)}
161 | 169 |
170 | } 171 |
172 |
173 | ); 174 | })} 175 |
176 | ); 177 | }; 178 | 179 | // Apply theme styles 180 | const getThemeStyles = () => { 181 | switch (theme) { 182 | case 'compact': 183 | return "text-xs"; 184 | case 'expanded': 185 | return "text-base"; 186 | default: 187 | return "text-sm"; 188 | } 189 | }; 190 | 191 | return ( 192 | 193 | 194 |
195 | 196 | {title} 197 | 198 |
199 | {badge && ( 200 | {badge} 201 | )} 202 | 211 |
212 |
213 |
214 | 215 | {Object.keys(data).length === 0 ? ( 216 |
No data available
217 | ) : ( 218 | showScrollArea ? ( 219 | 220 | {renderNestedObject(data)} 221 | 222 | ) : ( 223 | renderNestedObject(data) 224 | ) 225 | )} 226 |
227 |
228 | ) 229 | } -------------------------------------------------------------------------------- /nodeology/log.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved. 3 | 4 | Copyright 2024. UChicago Argonne, LLC. This software was produced 5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National 6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the 7 | U.S. Department of Energy. The U.S. Government has rights to use, 8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR 9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR 10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is 11 | modified to produce derivative works, such modified software should 12 | be clearly marked, so as not to confuse it with the version available 13 | from ANL. 14 | 15 | Additionally, redistribution and use in source and binary forms, with 16 | or without modification, are permitted provided that the following 17 | conditions are met: 18 | 19 | * Redistributions of source code must retain the above copyright 20 | notice, this list of conditions and the following disclaimer. 21 | 22 | * Redistributions in binary form must reproduce the above copyright 23 | notice, this list of conditions and the following disclaimer in 24 | the documentation and/or other materials provided with the 25 | distribution. 26 | 27 | * Neither the name of UChicago Argonne, LLC, Argonne National 28 | Laboratory, ANL, the U.S. Government, nor the names of its 29 | contributors may be used to endorse or promote products derived 30 | from this software without specific prior written permission. 31 | 32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS 33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago 36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 43 | POSSIBILITY OF SUCH DAMAGE. 44 | """ 45 | 46 | ### Initial Author <2024>: Xiangyu Yin 47 | 48 | import os, logging, logging.config 49 | import sys 50 | 51 | logger = logging.getLogger(__name__) 52 | 53 | 54 | def setup_logging(log_dir, log_name, debug_mode=False, base_dir=None): 55 | """Configure the logging system with console and/or file handlers. 
56 | 
57 |     Args:
58 |         log_dir (str): Directory where log files will be stored
59 |         log_name (str): Name of the log file (without extension)
60 |         debug_mode (bool): If True, only console logging with debug level is enabled
61 |         base_dir (str, optional): Base directory to prepend to log_dir
62 |     """
63 |     # Get root logger
64 |     root_logger = logging.getLogger()
65 | 
66 |     # Remove any existing handlers first
67 |     for handler in root_logger.handlers[:]:
68 |         handler.close()  # Properly close handlers
69 |         root_logger.removeHandler(handler)
70 | 
71 |     # Create handlers
72 |     console_handler = logging.StreamHandler(sys.stdout)
73 | 
74 |     # Set the root logger level to DEBUG to capture all messages
75 |     root_logger.setLevel(logging.DEBUG)
76 | 
77 |     # Use base_dir if provided, otherwise use log_dir directly
78 |     full_log_dir = os.path.join(base_dir, log_dir) if base_dir else log_dir
79 | 
80 |     if not os.path.exists(full_log_dir):
81 |         os.makedirs(full_log_dir)
82 | 
83 |     log_file_path = f"{full_log_dir}/{log_name}.log"
84 | 
85 |     if os.path.isfile(log_file_path):
86 |         log_print_color(
87 |             f"WARNING: {log_file_path} already exists and will be overwritten.",
88 |             "red",
89 |         )
90 | 
91 |     # Create file handler for both modes
92 |     file_handler = logging.FileHandler(log_file_path, "w")
93 | 
94 |     if debug_mode:
95 |         # Debug mode configuration
96 |         # Console shows DEBUG and above
97 |         console_handler.setLevel(logging.DEBUG)
98 |         console_format = logging.Formatter(
99 |             "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
100 |         )
101 |         console_handler.setFormatter(console_format)
102 | 
103 |         # File handler captures everything
104 |         file_handler.setLevel(logging.DEBUG)
105 |         file_format = logging.Formatter(
106 |             "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
107 |         )
108 |         file_handler.setFormatter(file_format)
109 |     else:
110 |         # Production mode configuration
111 |         # Console only shows PRINTLOG and WARNING+ messages
112 |         console_handler.setLevel(logging.PRINTLOG)
113 |         console_format = logging.Formatter("%(message)s")
114 |         console_handler.setFormatter(console_format)
115 | 
116 |         # File handler captures everything (DEBUG and above)
117 |         file_handler.setLevel(logging.DEBUG)
118 |         file_format = logging.Formatter(
119 |             "%(asctime)s - %(message)s", datefmt="%Y%m%d-%H:%M:%S"
120 |         )
121 |         file_handler.setFormatter(file_format)
122 | 
123 |     # Add handlers to root logger
124 |     root_logger.addHandler(console_handler)
125 |     root_logger.addHandler(file_handler)
126 | 
127 |     # Configure third-party loggers to be less verbose
128 |     # This prevents them from cluttering the console
129 |     for logger_name in logging.root.manager.loggerDict:
130 |         if logger_name != __name__ and not logger_name.startswith("nodeology"):
131 |             third_party_logger = logging.getLogger(logger_name)
132 |             if debug_mode:
133 |                 # In debug mode, third-party loggers show WARNING and above
134 |                 third_party_logger.setLevel(logging.WARNING)
135 |             else:
136 |                 # In production mode, third-party loggers show ERROR and above
137 |                 third_party_logger.setLevel(logging.ERROR)
138 | 
139 |     # Store handlers in logger for later cleanup
140 |     root_logger.handlers_to_close = root_logger.handlers[:]
141 | 
142 | 
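# Illustrative usage sketch (the directory and file names are placeholders,
# not taken from the package): with debug_mode=False, everything from DEBUG
# up is written to <log_dir>/<log_name>.log while the console only shows
# PRINTLOG-and-above messages.
#
#     setup_logging(log_dir="logs", log_name="my_run", debug_mode=False)
#     ...  # run the workflow
#     cleanup_logging()  # defined below; closes file handlers at shutdown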
143 | def cleanup_logging():
144 |     """Properly clean up logging handlers to prevent resource leaks."""
145 |     root_logger = logging.getLogger()
146 | 
147 |     # Close and remove any existing handlers
148 |     if hasattr(root_logger, "handlers_to_close"):
149 |         for handler in root_logger.handlers_to_close:
150 |             try:
151 |                 handler.close()
152 |             except:
153 |                 pass  # Ignore errors during cleanup
154 |             if handler in root_logger.handlers:
155 |                 root_logger.removeHandler(handler)
156 |         root_logger.handlers_to_close = []
157 | 
158 | 
159 | # https://stackoverflow.com/questions/2183233/how-to-add-a-custom-loglevel-to-pythons-logging-facility/35804945#35804945
160 | def add_logging_level(levelName, levelNum, methodName=None):
161 |     """Add a new logging level to the logging module.
162 | 
163 |     Args:
164 |         levelName (str): Name of the new level (e.g., 'TRACE')
165 |         levelNum (int): Numeric value for the level
166 |         methodName (str, optional): Method name to add. Defaults to levelName.lower()
167 | 
168 |     Raises:
169 |         AttributeError: If levelName or methodName already exists
170 |     """
171 |     if not methodName:
172 |         methodName = levelName.lower()
173 | 
174 |     if hasattr(logging, levelName):
175 |         raise AttributeError("{} already defined in logging module".format(levelName))
176 |     if hasattr(logging, methodName):
177 |         raise AttributeError("{} already defined in logging module".format(methodName))
178 |     if hasattr(logging.getLoggerClass(), methodName):
179 |         raise AttributeError("{} already defined in logger class".format(methodName))
180 | 
181 |     # This method was inspired by the answers to Stack Overflow post
182 |     # http://stackoverflow.com/q/2183233/2988730, especially
183 |     # http://stackoverflow.com/a/13638084/2988730
184 |     def logForLevel(self, message, *args, **kwargs):
185 |         if self.isEnabledFor(levelNum):
186 |             self._log(levelNum, message, args, **kwargs)
187 | 
188 |     def logToRoot(message, *args, **kwargs):
189 |         logging.log(levelNum, message, *args, **kwargs)
190 | 
191 |     logging.addLevelName(levelNum, levelName)
192 |     setattr(logging, levelName, levelNum)
193 |     setattr(logging.getLoggerClass(), methodName, logForLevel)
194 |     setattr(logging, methodName, logToRoot)
195 | 
196 | 
197 | def log_print_color(text, color="", print_to_console=True):
198 |     """Print colored text to console and log it to file.
199 | 
200 |     Args:
201 |         text (str): Text to print and log
202 |         color (str): Color name ('green', 'red', 'blue', 'yellow', or '' for white)
203 |         print_to_console (bool): If True, print the text to console
204 |     """
205 |     # Define color codes as constants at the top of the function
206 |     COLOR_CODES = {
207 |         "green": "\033[92m",
208 |         "red": "\033[91m",
209 |         "blue": "\033[94m",
210 |         "yellow": "\033[93m",
211 |         "": "\033[97m",  # default white
212 |     }
213 | 
214 |     # Get color code from dictionary, defaulting to white
215 |     ansi_code = COLOR_CODES.get(color, COLOR_CODES[""])
216 | 
217 |     if print_to_console:
218 |         print(ansi_code + text + "\033[0m")
219 |     logger.logonly(text)
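# Note: the module above logs via `logger.logonly(...)` and filters the
# console at `logging.PRINTLOG`; both are custom levels expected to be
# registered with `add_logging_level` during package initialization. A
# minimal registration sketch follows; the numeric values are assumptions
# chosen only to match the filtering used in setup_logging (LOGONLY below
# PRINTLOG, so "logonly" messages reach the log file but not the console):
#
#     import logging
#
#     add_logging_level("LOGONLY", logging.INFO + 1)
#     add_logging_level("PRINTLOG", logging.INFO + 5)
#
#     logging.getLogger(__name__).printlog("shown on console and in file")
#     logging.getLogger(__name__).logonly("written to the log file only")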
If software is 11 | modified to produce derivative works, such modified software should 12 | be clearly marked, so as not to confuse it with the version available 13 | from ANL. 14 | 15 | Additionally, redistribution and use in source and binary forms, with 16 | or without modification, are permitted provided that the following 17 | conditions are met: 18 | 19 | * Redistributions of source code must retain the above copyright 20 | notice, this list of conditions and the following disclaimer. 21 | 22 | * Redistributions in binary form must reproduce the above copyright 23 | notice, this list of conditions and the following disclaimer in 24 | the documentation and/or other materials provided with the 25 | distribution. 26 | 27 | * Neither the name of UChicago Argonne, LLC, Argonne National 28 | Laboratory, ANL, the U.S. Government, nor the names of its 29 | contributors may be used to endorse or promote products derived 30 | from this software without specific prior written permission. 31 | 32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS 33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago 36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 43 | POSSIBILITY OF SUCH DAMAGE. 
44 | """ 45 | 46 | ### Initial Author <2024>: Xiangyu Yin 47 | 48 | import numpy as np 49 | from typing import List, Dict, Union 50 | import pytest 51 | from nodeology.state import ( 52 | process_state_definitions, 53 | _resolve_state_type, 54 | _type_from_str, 55 | ) 56 | 57 | 58 | class TestTypeResolution: 59 | """Tests for basic type resolution functionality""" 60 | 61 | def test_primitive_types(self): 62 | """Test resolution of primitive types""" 63 | assert _resolve_state_type("str") == str 64 | assert _resolve_state_type("int") == int 65 | assert _resolve_state_type("float") == float 66 | assert _resolve_state_type("bool") == bool 67 | assert _resolve_state_type("ndarray") == np.ndarray 68 | 69 | def test_list_types(self): 70 | """Test resolution of List types""" 71 | assert _resolve_state_type("List[str]") == List[str] 72 | assert _resolve_state_type("List[int]") == List[int] 73 | assert _resolve_state_type("List[bool]") == List[bool] 74 | 75 | def test_dict_types(self): 76 | """Test resolution of Dict types""" 77 | assert _resolve_state_type("Dict[str, int]") == Dict[str, int] 78 | assert _resolve_state_type("Dict[str, List[int]]") == Dict[str, List[int]] 79 | assert ( 80 | _resolve_state_type("Dict[str, Dict[str, bool]]") 81 | == Dict[str, Dict[str, bool]] 82 | ) 83 | 84 | def test_nested_types(self): 85 | """Test resolution of deeply nested types""" 86 | complex_type = "Dict[str, List[Dict[str, Union[int, str]]]]" 87 | expected = Dict[str, List[Dict[str, Union[int, str]]]] 88 | assert _resolve_state_type(complex_type) == expected 89 | 90 | def test_numpy_composite_types(self): 91 | """Test resolution of composite types involving numpy arrays""" 92 | assert _resolve_state_type("List[ndarray]") == List[np.ndarray] 93 | assert _resolve_state_type("Dict[str, ndarray]") == Dict[str, np.ndarray] 94 | assert ( 95 | _resolve_state_type("Dict[str, List[ndarray]]") 96 | == Dict[str, List[np.ndarray]] 97 | ) 98 | assert _resolve_state_type("Union[ndarray, int]") == Union[np.ndarray, int] 99 | 100 | def test_type_conversion_symmetry(self): 101 | """Test that type conversion is symmetrical""" 102 | test_cases = [ 103 | str, 104 | int, 105 | float, 106 | bool, 107 | List[str], 108 | List[int], 109 | Dict[str, int], 110 | Dict[str, List[str]], 111 | Union[str, int], 112 | Union[str, List[int]], 113 | np.ndarray, 114 | List[np.ndarray], 115 | Dict[str, np.ndarray], 116 | Dict[str, List[np.ndarray]], 117 | Union[np.ndarray, int], 118 | Union[List[np.ndarray], Dict[str, np.ndarray]], 119 | ] 120 | 121 | for type_obj in test_cases: 122 | # Convert type to string 123 | type_str = _type_from_str(type_obj) 124 | # Convert string back to type 125 | resolved_type = _resolve_state_type(type_str) 126 | # Verify they're equivalent 127 | assert str(resolved_type) == str( 128 | type_obj 129 | ), f"Type conversion failed for {type_obj}" 130 | 131 | 132 | class TestErrorHandling: 133 | """Tests for error handling in type resolution""" 134 | 135 | def test_invalid_type_names(self): 136 | """Test handling of invalid type names""" 137 | with pytest.raises(ValueError, match="Unknown state type"): 138 | _resolve_state_type("InvalidType") 139 | 140 | with pytest.raises(ValueError, match="Unknown state type"): 141 | _resolve_state_type("List[InvalidType]") 142 | 143 | def test_malformed_type_strings(self): 144 | """Test handling of malformed type strings""" 145 | invalid_types = [ 146 | "List[str", # Missing closing bracket 147 | "Dict[str]", # Missing value type 148 | "Dict[str,]", # Empty value type 149 | 
"Union[]", # Empty union 150 | "List[]", # Empty list type 151 | "[str]", # Invalid format 152 | ] 153 | 154 | for invalid_type in invalid_types: 155 | with pytest.raises(ValueError): 156 | _resolve_state_type(invalid_type) 157 | 158 | def test_invalid_dict_formats(self): 159 | """Test handling of invalid dictionary formats""" 160 | with pytest.raises(ValueError): 161 | _resolve_state_type("Dict") 162 | 163 | with pytest.raises(ValueError): 164 | _resolve_state_type("Dict[str, int, bool]") 165 | 166 | 167 | class TestStateDefinitionProcessing: 168 | """Tests for state definition processing""" 169 | 170 | def test_dict_state_definition(self): 171 | """Test processing of dictionary state definitions""" 172 | state_def = {"name": "test_field", "type": "str"} 173 | result = process_state_definitions([state_def], {}) 174 | assert result == [("test_field", str)] 175 | 176 | def test_custom_type_processing(self): 177 | """Test processing with custom types in registry""" 178 | 179 | class CustomType: 180 | pass 181 | 182 | registry = {"CustomType": CustomType} 183 | 184 | # Test direct custom type reference 185 | assert process_state_definitions(["CustomType"], registry) == [CustomType] 186 | 187 | # Test custom type in dictionary definition 188 | state_def = { 189 | "name": "custom_field", 190 | "type": "str", 191 | } # Can't use CustomType directly in type string 192 | result = process_state_definitions([state_def], registry) 193 | assert result == [("custom_field", str)] 194 | 195 | def test_list_state_definition(self): 196 | """Test processing of list format state definitions""" 197 | # Single list definition 198 | state_def = ["test_field", "List[int]"] 199 | result = process_state_definitions([state_def], {}) 200 | assert result == [("test_field", List[int])] 201 | 202 | # Multiple list definitions 203 | state_defs = [["field1", "str"], ["field2", "List[int]"]] 204 | result = process_state_definitions(state_defs, {}) 205 | assert result == [("field1", str), ("field2", List[int])] 206 | 207 | def test_process_mixed_definitions(self): 208 | """Test processing mixed format state definitions""" 209 | state_defs = [ 210 | {"name": "field1", "type": "str"}, 211 | ["field2", "List[int]"], 212 | {"name": "field3", "type": "Dict[str, bool]"}, 213 | ] 214 | result = process_state_definitions(state_defs, {}) 215 | assert result == [ 216 | ("field1", str), 217 | ("field2", List[int]), 218 | ("field3", Dict[str, bool]), 219 | ] 220 | 221 | def test_numpy_array_state_definition(self): 222 | """Test processing of numpy array state definitions""" 223 | # Test direct ndarray type 224 | state_def = {"name": "array_field", "type": "ndarray"} 225 | result = process_state_definitions([state_def], {}) 226 | assert result == [("array_field", np.ndarray)] 227 | 228 | # Test in list format 229 | state_def = ["array_field2", "ndarray"] 230 | result = process_state_definitions([state_def], {}) 231 | assert result == [("array_field2", np.ndarray)] 232 | 233 | # Test in mixed definitions 234 | state_defs = [ 235 | {"name": "field1", "type": "str"}, 236 | ["array_field", "ndarray"], 237 | {"name": "field3", "type": "Dict[str, bool]"}, 238 | ] 239 | result = process_state_definitions(state_defs, {}) 240 | assert result == [ 241 | ("field1", str), 242 | ("array_field", np.ndarray), 243 | ("field3", Dict[str, bool]), 244 | ] 245 | 246 | def test_numpy_composite_state_definition(self): 247 | """Test processing of composite state definitions with numpy arrays""" 248 | state_defs = [ 249 | {"name": "array_list", "type": 
"List[ndarray]"}, 250 | {"name": "array_dict", "type": "Dict[str, ndarray]"}, 251 | ["nested_arrays", "Dict[str, List[ndarray]]"], 252 | {"name": "mixed_type", "type": "Union[ndarray, int]"}, 253 | ] 254 | 255 | result = process_state_definitions(state_defs, {}) 256 | assert result == [ 257 | ("array_list", List[np.ndarray]), 258 | ("array_dict", Dict[str, np.ndarray]), 259 | ("nested_arrays", Dict[str, List[np.ndarray]]), 260 | ("mixed_type", Union[np.ndarray, int]), 261 | ] 262 | -------------------------------------------------------------------------------- /examples/public/logo_light.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | LangGraph Nodeology 58 | 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /examples/trajectory_analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved. 3 | 4 | Copyright 2024. UChicago Argonne, LLC. This software was produced 5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National 6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the 7 | U.S. Department of Energy. The U.S. Government has rights to use, 8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR 9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR 10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is 11 | modified to produce derivative works, such modified software should 12 | be clearly marked, so as not to confuse it with the version available 13 | from ANL. 14 | 15 | Additionally, redistribution and use in source and binary forms, with 16 | or without modification, are permitted provided that the following 17 | conditions are met: 18 | 19 | * Redistributions of source code must retain the above copyright 20 | notice, this list of conditions and the following disclaimer. 21 | 22 | * Redistributions in binary form must reproduce the above copyright 23 | notice, this list of conditions and the following disclaimer in 24 | the documentation and/or other materials provided with the 25 | distribution. 26 | 27 | * Neither the name of UChicago Argonne, LLC, Argonne National 28 | Laboratory, ANL, the U.S. Government, nor the names of its 29 | contributors may be used to endorse or promote products derived 30 | from this software without specific prior written permission. 31 | 32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS 33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL UChicago 36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 43 | POSSIBILITY OF SUCH DAMAGE. 44 | """ 45 | 46 | ### Initial Author <2025>: Xiangyu Yin 47 | 48 | import json 49 | import tempfile 50 | import numpy as np 51 | from scipy.integrate import solve_ivp 52 | from typing import List, Dict 53 | from langgraph.graph import END 54 | import chainlit as cl 55 | from chainlit import Message, AskUserMessage, AskActionMessage, run_sync 56 | from nodeology.state import State 57 | from nodeology.node import Node, as_node 58 | from nodeology.workflow import Workflow 59 | import plotly.graph_objects as go 60 | 61 | 62 | class TrajectoryState(State): 63 | """State for particle trajectory analysis workflow""" 64 | 65 | # Parameters 66 | mass: float # Particle mass (kg) 67 | charge: float # Particle charge (C) 68 | initial_velocity: np.ndarray # Initial velocity vector [vx, vy, vz] 69 | E_field: np.ndarray # Electric field vector [Ex, Ey, Ez] 70 | B_field: np.ndarray # Magnetic field vector [Bx, By, Bz] 71 | 72 | # Confirm parameters 73 | confirm_parameters: bool 74 | 75 | # Parameters updater 76 | parameters_updater_output: str 77 | 78 | # Calculation results 79 | positions: List[np.ndarray] # Position vectors at each time point 80 | 81 | # Image 82 | trajectory_plot: str 83 | trajectory_plot_path: str 84 | 85 | # Analysis results 86 | analysis_result: Dict 87 | 88 | # Continue simulation 89 | continue_simulation: bool 90 | 91 | 92 | @as_node(sink=[]) 93 | def display_parameters( 94 | mass: float, 95 | charge: float, 96 | initial_velocity: np.ndarray, 97 | E_field: np.ndarray, 98 | B_field: np.ndarray, 99 | ): 100 | # Create a dictionary of parameters for the custom element 101 | parameters = { 102 | "Mass (kg)": mass, 103 | "Charge (C)": charge, 104 | "Initial Velocity (m/s)": initial_velocity.tolist(), 105 | "Electric Field (N/C)": E_field.tolist(), 106 | "Magnetic Field (T)": B_field.tolist(), 107 | } 108 | 109 | # Use the custom element to display parameters 110 | run_sync( 111 | Message( 112 | content="Below are the current simulation parameters:", 113 | elements=[ 114 | cl.CustomElement( 115 | name="DataDisplay", 116 | props={ 117 | "data": parameters, 118 | "title": "Particle Parameters", 119 | "badge": "Configured", 120 | "showScrollArea": False, 121 | }, 122 | ) 123 | ], 124 | ).send() 125 | ) 126 | return 127 | 128 | 129 | @as_node(sink="confirm_parameters") 130 | def ask_confirm_parameters(): 131 | res = run_sync( 132 | AskActionMessage( 133 | content="Are you happy with the parameters?", 134 | timeout=300, 135 | actions=[ 136 | cl.Action( 137 | name="yes", 138 | payload={"value": "yes"}, 139 | label="Yes", 140 | ), 141 | cl.Action( 142 | name="no", 143 | payload={"value": "no"}, 144 | label="No", 145 | ), 146 | ], 147 | ).send() 148 | ) 149 | if res and res.get("payload").get("value") == "yes": 150 | return True 151 | else: 152 | return False 153 | 154 | 155 | @as_node(sink=["human_input"]) 156 | def ask_parameters_input(): 157 | human_input = run_sync( 158 | AskUserMessage( 159 | content="Please let me know how you want to change 
any of the parameters :)", 160 | timeout=300, 161 | ).send() 162 | )["output"] 163 | return human_input 164 | 165 | 166 | parameters_updater = Node( 167 | node_type="parameters_updater", 168 | prompt_template="""Update the parameters based on the user's input. 169 | 170 | Current parameters: 171 | mass: {mass} 172 | charge: {charge} 173 | initial_velocity: {initial_velocity} 174 | E_field: {E_field} 175 | B_field: {B_field} 176 | 177 | User input: 178 | {human_input} 179 | 180 | Please return the updated parameters in JSON format. 181 | {{ 182 | "mass": float, 183 | "charge": float, 184 | "initial_velocity": list[float], 185 | "E_field": list[float], 186 | "B_field": list[float] 187 | }} 188 | """, 189 | sink="parameters_updater_output", 190 | sink_format="json", 191 | ) 192 | 193 | 194 | def parameters_updater_transform(state, client, **kwargs): 195 | params_dict = json.loads(state["parameters_updater_output"]) 196 | state["mass"] = params_dict["mass"] 197 | state["charge"] = params_dict["charge"] 198 | state["initial_velocity"] = np.array(params_dict["initial_velocity"]) 199 | state["E_field"] = np.array(params_dict["E_field"]) 200 | state["B_field"] = np.array(params_dict["B_field"]) 201 | return state 202 | 203 | 204 | parameters_updater.post_process = parameters_updater_transform 205 | 206 | 207 | @as_node(sink=["positions"]) 208 | def calculate_trajectory( 209 | mass: float, 210 | charge: float, 211 | initial_velocity: np.ndarray, 212 | E_field: np.ndarray, 213 | B_field: np.ndarray, 214 | ) -> List[np.ndarray]: 215 | """Calculate particle trajectory under Lorentz force with automatic time steps""" 216 | B_magnitude = np.linalg.norm(B_field) 217 | if B_magnitude == 0 or charge == 0: 218 | # Handle the case where B=0 or charge=0 (no magnetic force) 219 | cyclotron_period = 1e-6 # Arbitrary time scale 220 | else: 221 | cyclotron_frequency = np.abs(charge) * B_magnitude / mass 222 | cyclotron_period = 2 * np.pi / cyclotron_frequency 223 | 224 | # Determine total simulation time and time steps 225 | num_periods = 5 # Simulate over 5 cyclotron periods 226 | num_points_per_period = 100 # At least 100 points per period 227 | total_time = num_periods * cyclotron_period 228 | total_points = int(num_periods * num_points_per_period) 229 | time_points = np.linspace(0, total_time, total_points) 230 | 231 | def lorentz_force(t, state): 232 | """Compute acceleration from Lorentz force""" 233 | vel = state[3:] 234 | force = charge * (E_field + np.cross(vel, B_field)) 235 | acc = force / mass 236 | return np.concatenate([vel, acc]) 237 | 238 | # Initial state vector [x, y, z, vx, vy, vz] 239 | initial_position = np.array([0.0, 0.0, 0.0]) 240 | initial_state = np.concatenate([initial_position, initial_velocity]) 241 | 242 | # Solve equations of motion 243 | solution = solve_ivp( 244 | lorentz_force, 245 | (time_points[0], time_points[-1]), 246 | initial_state, 247 | t_eval=time_points, 248 | method="RK45", 249 | rtol=1e-8, 250 | ) 251 | 252 | if not solution.success: 253 | return [np.zeros(3) for _ in range(len(time_points))] 254 | 255 | return [solution.y[:3, i] for i in range(len(time_points))] 256 | 257 | 258 | @as_node(sink=["trajectory_plot", "trajectory_plot_path"]) 259 | def plot_trajectory(positions: List[np.ndarray]) -> str: 260 | """Plot 3D particle trajectory and save to temp file 261 | 262 | Returns: 263 | tuple: (Plotly figure object, path to saved plot image) 264 | """ 265 | positions = np.array(positions) 266 | 267 | # Create a Plotly 3D scatter plot 268 | fig = go.Figure( 269 | data=[ 
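            # A single 3D line trace; `positions` was stacked into an (N, 3)
            # array above, so columns 0, 1, 2 hold the x, y, z coordinates.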
270 | go.Scatter3d( 271 | x=positions[:, 0], 272 | y=positions[:, 1], 273 | z=positions[:, 2], 274 | mode="lines", 275 | line=dict(width=4, color="green"), 276 | ) 277 | ] 278 | ) 279 | 280 | # Update layout 281 | fig.update_layout( 282 | scene=dict(xaxis_title="X (m)", yaxis_title="Y (m)", zaxis_title="Z (m)"), 283 | ) 284 | 285 | # Save to temp file 286 | temp_path = tempfile.mktemp(suffix=".png") 287 | fig.write_image(temp_path) 288 | 289 | run_sync( 290 | Message( 291 | content="Below is the trajectory plot of the particle:", 292 | elements=[cl.Plotly(figure=fig)], 293 | ).send() 294 | ) 295 | 296 | return fig, temp_path 297 | 298 | 299 | trajectory_analyzer = Node( 300 | node_type="trajectory_analyzer", 301 | prompt_template="""Analyze this particle trajectory plot. 302 | 303 | Please determine: 304 | 1. The type of motion (linear, circular, helical, or chaotic) 305 | 2. Key physical features (radius, period, pitch angle if applicable) 306 | 3. Explanation of the motion 307 | 4. Anomalies in the motion 308 | Output in JSON format: 309 | {{ 310 | "trajectory_type": "type_name", 311 | "key_features": { 312 | "feature1": value, 313 | "feature2": value 314 | }, 315 | "explanation": "detailed explanation", 316 | "anomalies": "anomaly description" 317 | }}""", 318 | sink="analysis_result", 319 | sink_format="json", 320 | image_keys=["trajectory_plot_path"], 321 | ) 322 | 323 | 324 | def display_trajectory_analyzer_result(state, client, **kwargs): 325 | state["analysis_result"] = json.loads(state["analysis_result"]) 326 | 327 | # Use the custom element to display analysis results 328 | run_sync( 329 | Message( 330 | content="Below is the analysis of the particle trajectory:", 331 | elements=[ 332 | cl.CustomElement( 333 | name="DataDisplay", 334 | props={ 335 | "data": state["analysis_result"], 336 | "title": "Trajectory Analysis", 337 | "badge": state["analysis_result"].get( 338 | "trajectory_type", "Unknown" 339 | ), 340 | "maxHeight": "400px", 341 | }, 342 | ) 343 | ], 344 | ).send() 345 | ) 346 | return state 347 | 348 | 349 | trajectory_analyzer.post_process = display_trajectory_analyzer_result 350 | 351 | 352 | @as_node(sink="continue_simulation") 353 | def ask_continue_simulation(): 354 | res = run_sync( 355 | AskActionMessage( 356 | content="Would you like to continue the simulation?", 357 | timeout=300, 358 | actions=[ 359 | cl.Action( 360 | name="continue", 361 | payload={"value": "continue"}, 362 | label="Continue Simulation", 363 | ), 364 | cl.Action( 365 | name="finish", 366 | payload={"value": "finish"}, 367 | label="Finish", 368 | ), 369 | ], 370 | ).send() 371 | ) 372 | 373 | # Return the user's choice 374 | if res and res.get("payload").get("value") == "continue": 375 | return True 376 | else: 377 | return False 378 | 379 | 380 | class TrajectoryWorkflow(Workflow): 381 | """Workflow for particle trajectory analysis""" 382 | 383 | def create_workflow(self): 384 | """Define the workflow graph structure""" 385 | # Add nodes 386 | self.add_node("display_parameters", display_parameters) 387 | self.add_node("ask_confirm_parameters", ask_confirm_parameters) 388 | self.add_node("ask_parameters_input", ask_parameters_input) 389 | self.add_node("update_parameters", parameters_updater) 390 | self.add_node("calculate_trajectory", calculate_trajectory) 391 | self.add_node("plot_trajectory", plot_trajectory) 392 | self.add_node("analyze_trajectory", trajectory_analyzer) 393 | self.add_node("ask_continue_simulation", ask_continue_simulation) 394 | 395 | self.add_flow("display_parameters", 
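        # Unconditional edge: after showing the parameters, always ask for
        # confirmation. The conditional flows below branch on the boolean
        # state keys ("confirm_parameters", "continue_simulation") instead.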
"ask_confirm_parameters") 396 | self.add_conditional_flow( 397 | "ask_confirm_parameters", 398 | "confirm_parameters", 399 | then="calculate_trajectory", 400 | otherwise="ask_parameters_input", 401 | ) 402 | self.add_flow("ask_parameters_input", "update_parameters") 403 | self.add_flow("update_parameters", "display_parameters") 404 | self.add_flow("calculate_trajectory", "plot_trajectory") 405 | self.add_flow("plot_trajectory", "analyze_trajectory") 406 | self.add_flow("analyze_trajectory", "ask_continue_simulation") 407 | self.add_conditional_flow( 408 | "ask_continue_simulation", 409 | "continue_simulation", 410 | then="display_parameters", 411 | otherwise=END, 412 | ) 413 | 414 | # Set entry point 415 | self.set_entry("display_parameters") 416 | 417 | # Compile workflow 418 | self.compile() 419 | 420 | 421 | if __name__ == "__main__": 422 | workflow = TrajectoryWorkflow( 423 | state_defs=TrajectoryState, 424 | llm_name="gemini/gemini-2.0-flash", 425 | vlm_name="gemini/gemini-2.0-flash", 426 | debug_mode=False, 427 | ) 428 | 429 | # # Export workflow to YAML file 430 | # workflow.to_yaml("particle_trajectory_analysis.yaml") 431 | 432 | # # Print workflow graph 433 | # workflow.graph.get_graph().draw_mermaid_png( 434 | # output_file_path="particle_trajectory_analysis.png" 435 | # ) 436 | 437 | initial_state = { 438 | "mass": 9.1093837015e-31, # electron mass in kg 439 | "charge": -1.602176634e-19, # electron charge in C 440 | "initial_velocity": np.array([1e6, 1e6, 1e6]), # 1e6 m/s in each direction 441 | "E_field": np.array([5e6, 1e6, 5e6]), # 1e6 N/C in y-direction 442 | "B_field": np.array( 443 | [0.0, 0.0, 50000.0] 444 | ), # deliberately typo to be caught by validation 445 | } 446 | 447 | result = workflow.run(init_values=initial_state, ui=True) 448 | -------------------------------------------------------------------------------- /nodeology/client.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved. 3 | 4 | Copyright 2024. UChicago Argonne, LLC. This software was produced 5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National 6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the 7 | U.S. Department of Energy. The U.S. Government has rights to use, 8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR 9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR 10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is 11 | modified to produce derivative works, such modified software should 12 | be clearly marked, so as not to confuse it with the version available 13 | from ANL. 14 | 15 | Additionally, redistribution and use in source and binary forms, with 16 | or without modification, are permitted provided that the following 17 | conditions are met: 18 | 19 | * Redistributions of source code must retain the above copyright 20 | notice, this list of conditions and the following disclaimer. 21 | 22 | * Redistributions in binary form must reproduce the above copyright 23 | notice, this list of conditions and the following disclaimer in 24 | the documentation and/or other materials provided with the 25 | distribution. 26 | 27 | * Neither the name of UChicago Argonne, LLC, Argonne National 28 | Laboratory, ANL, the U.S. Government, nor the names of its 29 | contributors may be used to endorse or promote products derived 30 | from this software without specific prior written permission. 
31 | 32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS 33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago 36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 43 | POSSIBILITY OF SUCH DAMAGE. 44 | """ 45 | 46 | ### Initial Author <2024>: Xiangyu Yin 47 | 48 | import os, base64, json, getpass 49 | from abc import ABC, abstractmethod 50 | import litellm 51 | from datetime import datetime 52 | 53 | 54 | def get_client(model_name, **kwargs): 55 | """ 56 | Factory function to create appropriate client based on model name. 57 | 58 | Handles three scenarios: 59 | 1. Just model name (e.g., "gpt-4o") - Let LiteLLM figure out the provider 60 | 2. Model name with provider keyword (e.g., model="gpt-4o", provider="openai") 61 | 3. Provider/name convention (e.g., "openai/gpt-4o") 62 | 63 | Args: 64 | model_name (str): Name of the model to use 65 | **kwargs: Additional arguments including optional 'provider' 66 | 67 | Returns: 68 | LLM_Client or VLM_Client: Appropriate client instance for the requested model 69 | """ 70 | # Handle special clients first 71 | if model_name == "mock": 72 | return Mock_LLM_Client(**kwargs) 73 | elif model_name == "mock_vlm": 74 | return Mock_VLM_Client(**kwargs) 75 | 76 | # Get provider from kwargs if specified (Scenario 2) 77 | provider = kwargs.pop("provider", None) 78 | 79 | # Get tracing_enabled from kwargs 80 | tracing_enabled = kwargs.pop("tracing_enabled", False) 81 | 82 | # Handle provider/model format (Scenario 3) 83 | if "/" in model_name and provider is None: 84 | provider, model_name = model_name.split("/", 1) 85 | 86 | # Create LiteLLM client - for Scenario 1, provider will be None 87 | try: 88 | return LiteLLM_Client( 89 | model_name, provider=provider, tracing_enabled=tracing_enabled, **kwargs 90 | ) 91 | except Exception as e: 92 | raise ValueError(f"Error creating client for model {model_name}: {e}") 93 | 94 | 95 | def configure_langfuse(public_key=None, secret_key=None, host=None, enabled=True): 96 | """ 97 | Configure Langfuse for observability. 98 | 99 | Args: 100 | public_key (str, optional): Langfuse public key. Defaults to LANGFUSE_PUBLIC_KEY env var. 101 | secret_key (str, optional): Langfuse secret key. Defaults to LANGFUSE_SECRET_KEY env var. 102 | host (str, optional): Langfuse host URL. Defaults to LANGFUSE_HOST env var or https://cloud.langfuse.com. 103 | enabled (bool, optional): Whether to enable Langfuse tracing. Defaults to True. 
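    Example (a minimal sketch; the keys shown are placeholders, and host can
    be omitted to fall back to the LANGFUSE_HOST environment variable):

        configure_langfuse(
            public_key="pk-lf-...",
            secret_key="sk-lf-...",
            host="https://cloud.langfuse.com",
        )

        # Tracing can later be switched off for the process:
        configure_langfuse(enabled=False)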
104 | """ 105 | if not enabled: 106 | litellm.success_callback = [] 107 | litellm.failure_callback = [] 108 | return 109 | 110 | # Set environment variables if provided 111 | if public_key: 112 | os.environ["LANGFUSE_PUBLIC_KEY"] = public_key 113 | if secret_key: 114 | os.environ["LANGFUSE_SECRET_KEY"] = secret_key 115 | if host: 116 | os.environ["LANGFUSE_HOST"] = host 117 | 118 | litellm.success_callback = ["langfuse"] 119 | litellm.failure_callback = ["langfuse"] 120 | 121 | 122 | class LLM_Client(ABC): 123 | """Base abstract class for Language Model clients.""" 124 | 125 | def __init__(self) -> None: 126 | pass 127 | 128 | @abstractmethod 129 | def __call__(self, messages, **kwargs) -> str: 130 | """ 131 | Process messages and return model response. 132 | 133 | Args: 134 | messages (list): List of message dictionaries with 'role' and 'content' 135 | **kwargs: Additional model-specific parameters 136 | 137 | Returns: 138 | str: Model's response text 139 | """ 140 | pass 141 | 142 | 143 | class VLM_Client(LLM_Client): 144 | """Base abstract class for Vision Language Model clients.""" 145 | 146 | def __init__(self) -> None: 147 | super().__init__() 148 | 149 | @abstractmethod 150 | def process_images(self, messages, images, **kwargs) -> list: 151 | """ 152 | Process and format images for the model. 153 | 154 | Args: 155 | messages (list): List of message dictionaries 156 | images (list): List of image file paths 157 | **kwargs: Additional processing parameters 158 | 159 | Returns: 160 | list: Updated messages with processed images 161 | """ 162 | pass 163 | 164 | 165 | class Mock_LLM_Client(LLM_Client): 166 | def __init__(self, response=None, **kwargs) -> None: 167 | super().__init__() 168 | self.response = response 169 | self.model_name = "mock" 170 | 171 | def __call__(self, messages, **kwargs) -> str: 172 | response = ( 173 | "\n".join([msg["role"] + ": " + msg["content"] for msg in messages]) 174 | if self.response is None 175 | else self.response 176 | ) 177 | return response 178 | 179 | 180 | class Mock_VLM_Client(VLM_Client): 181 | def __init__(self, response=None, **kwargs) -> None: 182 | super().__init__() 183 | self.response = response 184 | self.model_name = "mock_vlm" 185 | 186 | def __call__(self, messages, images=None, **kwargs) -> str: 187 | if images is not None: 188 | messages = self.process_images(messages, images) 189 | if self.response is None: 190 | message_parts = [] 191 | for msg in messages: 192 | content = msg["content"] 193 | if isinstance(content, str): 194 | message_parts.append(f"{msg['role']}: {content}") 195 | else: # content is already a list of text/image objects 196 | parts = [] 197 | for item in content: 198 | if item["type"] == "text": 199 | parts.append(item["text"]) 200 | elif item["type"] == "image": 201 | parts.append(f"[Image: {item['image_url']['url']}]") 202 | message_parts.append(f"{msg['role']}: {' '.join(parts)}") 203 | return "\n".join(message_parts) 204 | return self.response 205 | 206 | def process_images(self, messages, images, **kwargs) -> list: 207 | # Make a copy to avoid modifying the original 208 | messages = messages.copy() 209 | 210 | # Simply append a placeholder for each image 211 | for img in images: 212 | if isinstance(messages[-1]["content"], str): 213 | messages[-1]["content"] = [ 214 | {"type": "text", "text": messages[-1]["content"]}, 215 | {"type": "image", "image_url": {"url": f"mock_processed_{img}"}}, 216 | ] 217 | elif isinstance(messages[-1]["content"], list): 218 | messages[-1]["content"].append( 219 | {"type": 
"image", "image_url": {"url": f"mock_processed_{img}"}} 220 | ) 221 | return messages 222 | 223 | 224 | class LiteLLM_Client(VLM_Client): 225 | """ 226 | Unified client for all LLM/VLM providers using LiteLLM. 227 | Supports both text and image inputs across multiple providers. 228 | """ 229 | 230 | def __init__( 231 | self, 232 | model_name, 233 | provider=None, 234 | model_options=None, 235 | api_key=None, 236 | tracing_enabled=False, 237 | ) -> None: 238 | """ 239 | Initialize LiteLLM client. 240 | 241 | Args: 242 | model_name (str): Name of the model to use 243 | provider (str, optional): Provider name (openai, anthropic, etc.) 244 | model_options (dict): Model parameters like temperature and top_p 245 | api_key (str, optional): API key for the specified provider 246 | tracing_enabled (bool, optional): Whether to enable Langfuse tracing. Defaults to False. 247 | """ 248 | super().__init__() 249 | self.model_options = model_options if model_options else {} 250 | self.tracing_enabled = tracing_enabled 251 | 252 | # Set API key if provided 253 | if api_key and provider: 254 | os.environ[f"{provider.upper()}_API_KEY"] = api_key 255 | 256 | # Construct the model name for LiteLLM based on whether provider is specified 257 | # If provider is None, LiteLLM will infer the provider from the model name 258 | self.model_name = f"{provider}/{model_name}" if provider else model_name 259 | 260 | def collect_langfuse_metadata( 261 | self, 262 | workflow=None, 263 | node=None, 264 | **kwargs, 265 | ): 266 | """ 267 | Collect metadata for Langfuse tracing from workflow and node information. 268 | 269 | Args: 270 | workflow: The workflow instance (optional) 271 | node: The node instance (optional) 272 | **kwargs: Additional metadata to include 273 | 274 | Returns: 275 | dict: Metadata dictionary formatted for Langfuse 276 | """ 277 | metadata = {} 278 | 279 | timestamp = datetime.now().strftime("%Y%m%d") 280 | user_id = getpass.getuser() 281 | session_id_str = f"{user_id}-{timestamp}" 282 | 283 | metadata["trace_metadata"] = {} 284 | 285 | # Extract workflow metadata if available 286 | if workflow: 287 | # Use workflow class name as generation name 288 | metadata["generation_name"] = workflow.__class__.__name__ 289 | session_id_str += f"-{workflow.__class__.__name__}" 290 | 291 | # Create a generation ID based on workflow name and timestamp 292 | metadata["generation_id"] = f"gen-{workflow.name}-{timestamp}" 293 | 294 | # Add user ID if available 295 | if hasattr(workflow, "user_name"): 296 | metadata["trace_user_id"] = workflow.user_name 297 | else: 298 | metadata["trace_user_id"] = user_id 299 | 300 | # Extract node metadata if available 301 | if node: 302 | # Use node type as trace name 303 | metadata["trace_name"] = node.node_type 304 | session_id_str += f"-{node.node_type}" 305 | 306 | # Add node metadata to trace metadata 307 | metadata["trace_metadata"].update( 308 | { 309 | "required_keys": node.required_keys, 310 | "sink": node.sink, 311 | "sink_format": node.sink_format, 312 | "image_keys": node.image_keys, 313 | "use_conversation": node.use_conversation, 314 | "prompt_template": node.prompt_template, 315 | } 316 | ) 317 | 318 | # Add session ID based on timestamp 319 | metadata["session_id"] = f"session-{session_id_str}" 320 | 321 | # Add any additional metadata from kwargs 322 | metadata["trace_metadata"].update(kwargs) 323 | 324 | return metadata 325 | 326 | def process_images(self, messages, images): 327 | """ 328 | Process and format images for the model using LiteLLM's format. 
329 | 330 | Args: 331 | messages (list): List of message dictionaries 332 | images (list): List of image file paths 333 | 334 | Returns: 335 | list: Updated messages with processed images 336 | """ 337 | # Make a copy to avoid modifying the original 338 | messages = messages.copy() 339 | 340 | # Convert images to base64 341 | image_contents = [] 342 | for img in images: 343 | with open(img, "rb") as image_file: 344 | base64_image = base64.b64encode(image_file.read()).decode("utf-8") 345 | image_contents.append( 346 | { 347 | "type": "image_url", 348 | "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}, 349 | } 350 | ) 351 | 352 | # Add images to the last message 353 | if isinstance(messages[-1]["content"], str): 354 | messages[-1]["content"] = [ 355 | {"type": "text", "text": messages[-1]["content"]} 356 | ] + image_contents 357 | elif isinstance(messages[-1]["content"], list): 358 | messages[-1]["content"] += image_contents 359 | 360 | return messages 361 | 362 | def __call__( 363 | self, messages, images=None, format=None, workflow=None, node=None, **kwargs 364 | ) -> str: 365 | """ 366 | Process messages and return model response using LiteLLM. 367 | 368 | Args: 369 | messages (list): List of message dictionaries 370 | images (list, optional): List of image file paths 371 | format (str, optional): Response format (e.g., 'json') 372 | workflow (optional): The workflow instance for metadata extraction 373 | node (optional): The node instance for metadata extraction 374 | **kwargs: Additional parameters including metadata for Langfuse 375 | 376 | Returns: 377 | str: Model's response text 378 | """ 379 | # Process images if provided 380 | if images is not None: 381 | messages = self.process_images(messages, images) 382 | 383 | # Set up response format if needed 384 | response_format = {"type": "json_object"} if format == "json" else None 385 | 386 | # Extract Langfuse metadata only if tracing is enabled 387 | langfuse_metadata = {} 388 | if self.tracing_enabled: 389 | langfuse_metadata = self.collect_langfuse_metadata( 390 | workflow=workflow, 391 | node=node, 392 | **kwargs, 393 | ) 394 | 395 | try: 396 | # Use LiteLLM's built-in retry mechanism with Langfuse metadata 397 | response = litellm.completion( 398 | model=self.model_name, 399 | messages=messages, 400 | response_format=response_format, 401 | num_retries=3, 402 | metadata=langfuse_metadata if self.tracing_enabled else {}, 403 | **self.model_options, 404 | ) 405 | 406 | content = response.choices[0].message.content 407 | 408 | # Validate JSON if requested 409 | if format == "json": 410 | try: 411 | json.loads(content) 412 | except json.JSONDecodeError: 413 | raise ValueError(f"Invalid JSON response from {self.model_name}") 414 | 415 | return content 416 | 417 | except Exception as e: 418 | raise ValueError( 419 | f"Failed to generate response from {self.model_name}. Error: {str(e)}" 420 | ) 421 | -------------------------------------------------------------------------------- /nodeology/state.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved. 3 | 4 | Copyright 2024. UChicago Argonne, LLC. This software was produced 5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National 6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the 7 | U.S. Department of Energy. The U.S. Government has rights to use, 8 | reproduce, and distribute this software. 
NEITHER THE GOVERNMENT NOR 9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR 10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is 11 | modified to produce derivative works, such modified software should 12 | be clearly marked, so as not to confuse it with the version available 13 | from ANL. 14 | 15 | Additionally, redistribution and use in source and binary forms, with 16 | or without modification, are permitted provided that the following 17 | conditions are met: 18 | 19 | * Redistributions of source code must retain the above copyright 20 | notice, this list of conditions and the following disclaimer. 21 | 22 | * Redistributions in binary form must reproduce the above copyright 23 | notice, this list of conditions and the following disclaimer in 24 | the documentation and/or other materials provided with the 25 | distribution. 26 | 27 | * Neither the name of UChicago Argonne, LLC, Argonne National 28 | Laboratory, ANL, the U.S. Government, nor the names of its 29 | contributors may be used to endorse or promote products derived 30 | from this software without specific prior written permission. 31 | 32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS 33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago 36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 43 | POSSIBILITY OF SUCH DAMAGE. 44 | """ 45 | 46 | ### Initial Author <2024>: Xiangyu Yin 47 | 48 | import json 49 | import logging 50 | import numpy as np 51 | import plotly.graph_objects as go 52 | import plotly.io as pio 53 | 54 | logger = logging.getLogger(__name__) 55 | from typing import TypedDict, List, Dict, Union, Any, get_origin, get_args 56 | from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer 57 | import msgpack 58 | 59 | StateBaseT = Union[str, int, float, bool, np.ndarray] 60 | 61 | """ 62 | State management module for nodeology. 63 | Handles type definitions, state processing, and state registry management. 64 | """ 65 | 66 | 67 | class State(TypedDict): 68 | """ 69 | Base state class representing the core state structure. 70 | Contains node information, input/output data, and message history. 
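    Workflows extend this TypedDict with their own fields. A minimal sketch
    (the field names below are illustrative, not part of the base class):

        class MyState(State):
            temperature: float
            readings: List[dict]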
71 | """ 72 | 73 | current_node_type: str 74 | previous_node_type: str 75 | human_input: str 76 | input: str 77 | output: str 78 | messages: List[dict] 79 | 80 | 81 | def _split_by_top_level_comma(s: str) -> List[str]: 82 | """Helper function to split by comma while respecting brackets""" 83 | parts = [] 84 | current = [] 85 | bracket_count = 0 86 | 87 | for char in s: 88 | if char == "[": 89 | bracket_count += 1 90 | elif char == "]": 91 | bracket_count -= 1 92 | elif char == "," and bracket_count == 0: 93 | parts.append("".join(current).strip()) 94 | current = [] 95 | continue 96 | current.append(char) 97 | 98 | if current: 99 | parts.append("".join(current).strip()) 100 | return parts 101 | 102 | 103 | def _resolve_state_type(type_str: str): 104 | """ 105 | Resolve string representations of types to actual Python types. 106 | """ 107 | if not hasattr(_resolve_state_type, "_cache"): 108 | _resolve_state_type._cache = {} 109 | 110 | if type_str in _resolve_state_type._cache: 111 | return _resolve_state_type._cache[type_str] 112 | 113 | try: 114 | # Handle basic types 115 | if type_str in ( 116 | "str", 117 | "int", 118 | "float", 119 | "bool", 120 | "dict", 121 | "list", 122 | "bytes", 123 | "tuple", 124 | "ndarray", 125 | ): 126 | if type_str == "ndarray": 127 | return np.ndarray 128 | return eval(type_str) 129 | 130 | if type_str.startswith("List[") and type_str.endswith("]"): 131 | inner_type = type_str[5:-1] 132 | return List[_resolve_state_type(inner_type)] 133 | 134 | elif type_str.startswith("Dict[") and type_str.endswith("]"): 135 | inner_str = type_str[5:-1] 136 | parts = _split_by_top_level_comma(inner_str) 137 | if len(parts) != 2: 138 | raise ValueError(f"Invalid Dict type format: {type_str}") 139 | 140 | key_type = _resolve_state_type(parts[0]) 141 | value_type = _resolve_state_type(parts[1]) 142 | return Dict[key_type, value_type] 143 | 144 | elif type_str.startswith("Union[") and type_str.endswith("]"): 145 | inner_str = type_str[6:-1] 146 | types = [ 147 | _resolve_state_type(t) for t in _split_by_top_level_comma(inner_str) 148 | ] 149 | return Union[tuple(types)] 150 | 151 | else: 152 | raise ValueError(f"Unknown state type: {type_str}") 153 | 154 | except Exception as e: 155 | raise ValueError(f"Failed to resolve type '{type_str}': {str(e)}") 156 | 157 | 158 | def _process_dict_state_def(state_def: Dict) -> tuple: 159 | """ 160 | Process a dictionary-format state definition. 161 | 162 | Supports both formats: 163 | - {'name': 'type'} format 164 | - {'name': str, 'type': str} format 165 | 166 | Args: 167 | state_def (Dict): Dictionary containing state definition 168 | 169 | Returns: 170 | tuple: (name, resolved_type) 171 | 172 | Raises: 173 | ValueError: If state definition is missing required fields 174 | """ 175 | if len(state_def) == 1: 176 | # Handle {'name': 'type'} format 177 | name, type_str = next(iter(state_def.items())) 178 | else: 179 | # Handle {'name': str, 'type': str} format 180 | name = state_def.get("name") 181 | type_str = state_def.get("type") 182 | if not name or not type_str: 183 | raise ValueError(f"Invalid state definition: {state_def}") 184 | 185 | state_type = _resolve_state_type(type_str) 186 | return (name, state_type) 187 | 188 | 189 | def _process_list_state_def(state_def: List) -> List: 190 | """ 191 | Process a list-format state definition. 192 | 193 | Supports two formats: 194 | 1. Single definition: [name, type_str] 195 | 2. Multiple definitions: [[name1, type_str1], [name2, type_str2], ...] 
196 | 197 | Args: 198 | state_def (List): List containing state definitions 199 | 200 | Returns: 201 | List[tuple]: List of (name, resolved_type) tuples 202 | 203 | Raises: 204 | ValueError: If state definition format is invalid 205 | """ 206 | if len(state_def) == 2 and isinstance(state_def[0], str): 207 | # Single list format [name, type_str] 208 | name, type_str = state_def 209 | state_type = _resolve_state_type(type_str) 210 | return [(name, state_type)] 211 | else: 212 | processed_lists = [] 213 | for item in state_def: 214 | if not isinstance(item, list) or len(item) != 2: 215 | raise ValueError(f"Invalid state definition item: {item}") 216 | name, type_str = item 217 | state_type = _resolve_state_type(type_str) 218 | processed_lists.append((name, state_type)) 219 | return processed_lists 220 | 221 | 222 | def process_state_definitions(state_defs: List, state_registry: dict): 223 | """ 224 | Process state definitions from template format to internal format. 225 | 226 | Supports multiple input formats: 227 | - Dictionary format: {'name': 'type'} or {'name': str, 'type': str} 228 | - List format: [name, type_str] or [[name1, type_str1], ...] 229 | - String format: References to pre-defined states in state_registry 230 | 231 | Args: 232 | state_defs (List): List of state definitions in various formats 233 | state_registry (dict): Registry of pre-defined states 234 | 235 | Returns: 236 | List[tuple]: List of processed (name, type) pairs 237 | 238 | Raises: 239 | ValueError: If state definition format is invalid or state type is unknown 240 | """ 241 | processed_state_defs = [] 242 | 243 | for state_def in state_defs: 244 | if isinstance(state_def, dict): 245 | processed_state_defs.append(_process_dict_state_def(state_def)) 246 | elif isinstance(state_def, list): 247 | processed_state_defs.extend(_process_list_state_def(state_def)) 248 | elif isinstance(state_def, str): 249 | if state_def in state_registry: 250 | processed_state_defs.append(state_registry[state_def]) 251 | else: 252 | raise ValueError(f"Unknown state type: {state_def}") 253 | else: 254 | raise ValueError( 255 | f"Invalid state definition format: {state_def}. Must be a string, " 256 | "[name, type] list, or {'name': 'type'} dictionary" 257 | ) 258 | 259 | return processed_state_defs 260 | 261 | 262 | def _type_from_str(type_obj: type) -> str: 263 | """ 264 | Convert a Python type object to a string representation that _resolve_state_type can parse. 
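    Despite the name, this is the inverse of _resolve_state_type: it
    serializes a type object to a string. For example:

        _type_from_str(Dict[str, List[int]])  # -> "Dict[str, List[int]]"
        _type_from_str(np.ndarray)            # -> "ndarray"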
265 | """ 266 | # Add handling for numpy arrays 267 | if type_obj is np.ndarray: 268 | return "ndarray" 269 | 270 | # Handle basic types 271 | if type_obj in (str, int, float, bool, dict, list, bytes, tuple): 272 | return type_obj.__name__ 273 | 274 | # Get the origin type 275 | origin = get_origin(type_obj) 276 | if origin is None: 277 | # More explicit handling of unknown types 278 | logger.warning(f"Unknown type {type_obj}, defaulting to None") 279 | return None 280 | 281 | # Handle List types 282 | if origin is list or origin is List: 283 | args = get_args(type_obj) 284 | if not args: 285 | return "list" # Default to list if no type args 286 | inner_type = _type_from_str(args[0]) 287 | if inner_type is None: 288 | return "list" 289 | return f"List[{inner_type}]" 290 | 291 | # Handle Dict types 292 | if origin is dict or origin is Dict: 293 | args = get_args(type_obj) 294 | if not args or len(args) != 2: 295 | return "dict" # Default if no/invalid type args 296 | key_type = _type_from_str(args[0]) # Recursive call for key type 297 | value_type = _type_from_str(args[1]) # Recursive call for value type 298 | if key_type is None or value_type is None: 299 | return "dict" 300 | return f"Dict[{key_type}, {value_type}]" 301 | 302 | # Handle Union types 303 | if origin is Union: 304 | args = get_args(type_obj) 305 | if not args: 306 | return "tuple" 307 | types = [_type_from_str(arg) for arg in args] 308 | if any(t is None for t in types): 309 | return "tuple" 310 | return f"Union[{','.join(types)}]" 311 | 312 | # Default case 313 | return "str" 314 | 315 | 316 | class StateEncoder(json.JSONEncoder): 317 | """Custom JSON encoder for serializing workflow states.""" 318 | 319 | def default(self, obj): 320 | try: 321 | if isinstance(obj, np.ndarray): 322 | return { 323 | "__type__": "ndarray", 324 | "data": obj.tolist(), 325 | "dtype": str(obj.dtype), 326 | } 327 | if isinstance(obj, go.Figure): 328 | return { 329 | "__type__": "plotly_figure", 330 | "data": pio.to_json(obj), 331 | } 332 | if hasattr(obj, "to_dict"): 333 | return obj.to_dict() 334 | if isinstance(obj, bytes): 335 | return obj.decode("utf-8") 336 | if isinstance(obj, set): 337 | return list(obj) 338 | if hasattr(obj, "__dict__"): 339 | return obj.__dict__ 340 | return super().default(obj) 341 | except TypeError as e: 342 | logger.warning(f"Could not serialize object of type {type(obj)}: {str(e)}") 343 | return str(obj) 344 | 345 | 346 | class CustomSerializer(JsonPlusSerializer): 347 | NDARRAY_EXT_TYPE = 42 # Ensure this code doesn't conflict with other ExtTypes 348 | PLOTLY_FIGURE_EXT_TYPE = 43 # New extension type for Plotly figures 349 | 350 | def _default(self, obj: Any) -> Union[str, Dict[str, Any]]: 351 | if isinstance(obj, np.ndarray): 352 | return { 353 | "lc": 2, 354 | "type": "ndarray", 355 | "data": obj.tolist(), 356 | "dtype": str(obj.dtype), 357 | } 358 | if isinstance(obj, go.Figure): 359 | return { 360 | "lc": 2, 361 | "type": "plotly_figure", 362 | "data": pio.to_json(obj), 363 | } 364 | return super()._default(obj) 365 | 366 | def _reviver(self, value: Dict[str, Any]) -> Any: 367 | if value.get("lc", None) == 2: 368 | if value.get("type", None) == "ndarray": 369 | return np.array(value["data"], dtype=value["dtype"]) 370 | elif value.get("type", None) == "plotly_figure": 371 | return pio.from_json(value["data"]) 372 | return super()._reviver(value) 373 | 374 | # Override dumps_typed to use instance method _msgpack_enc 375 | def dumps_typed(self, obj: Any) -> tuple[str, bytes]: 376 | if isinstance(obj, bytes): 377 | 
return "bytes", obj 378 | elif isinstance(obj, bytearray): 379 | return "bytearray", obj 380 | else: 381 | try: 382 | return "msgpack", self._msgpack_enc(obj) 383 | except UnicodeEncodeError: 384 | return "json", self.dumps(obj) 385 | 386 | # Provide instance-level _msgpack_enc 387 | def _msgpack_enc(self, data: Any) -> bytes: 388 | enc = msgpack.Packer(default=self._msgpack_default) 389 | return enc.pack(data) 390 | 391 | # Provide instance-level _msgpack_default 392 | def _msgpack_default(self, obj: Any) -> Any: 393 | if isinstance(obj, np.ndarray): 394 | # Prepare metadata for ndarray 395 | metadata = { 396 | "dtype": str(obj.dtype), 397 | "shape": obj.shape, 398 | } 399 | metadata_packed = msgpack.packb(metadata, use_bin_type=True) 400 | data_packed = obj.tobytes() 401 | combined = metadata_packed + data_packed 402 | return msgpack.ExtType(self.NDARRAY_EXT_TYPE, combined) 403 | elif isinstance(obj, np.number): 404 | # Handle NumPy scalar types 405 | return obj.item() 406 | elif isinstance(obj, go.Figure): 407 | figure_json = pio.to_json(obj) 408 | figure_packed = msgpack.packb(figure_json, use_bin_type=True) 409 | return msgpack.ExtType(self.PLOTLY_FIGURE_EXT_TYPE, figure_packed) 410 | 411 | return super()._msgpack_default(obj) 412 | 413 | # Provide instance-level loads_typed 414 | def loads_typed(self, data: tuple[str, bytes]) -> Any: 415 | type_, data_ = data 416 | if type_ == "bytes": 417 | return data_ 418 | elif type_ == "bytearray": 419 | return bytearray(data_) 420 | elif type_ == "json": 421 | return self.loads(data_) 422 | elif type_ == "msgpack": 423 | return msgpack.unpackb( 424 | data_, ext_hook=self._msgpack_ext_hook, strict_map_key=False 425 | ) 426 | else: 427 | raise NotImplementedError(f"Unknown serialization type: {type_}") 428 | 429 | # Provide instance-level _msgpack_ext_hook 430 | def _msgpack_ext_hook(self, code: int, data: bytes) -> Any: 431 | if code == self.NDARRAY_EXT_TYPE: 432 | # Unpack metadata 433 | unpacker = msgpack.Unpacker(use_list=False, raw=False) 434 | unpacker.feed(data) 435 | metadata = unpacker.unpack() 436 | buffer_offset = unpacker.tell() 437 | array_data = data[buffer_offset:] 438 | array = np.frombuffer(array_data, dtype=metadata["dtype"]) 439 | array = array.reshape(metadata["shape"]) 440 | return array 441 | elif code == self.PLOTLY_FIGURE_EXT_TYPE: 442 | figure_json = msgpack.unpackb( 443 | data, strict_map_key=False, ext_hook=self._msgpack_ext_hook 444 | ) 445 | return pio.from_json(figure_json) 446 | else: 447 | return super()._msgpack_ext_hook(code, data) 448 | 449 | 450 | def convert_serialized_objects(obj): 451 | """ 452 | Convert serialized objects back to their original form. 
453 | Currently handles: 454 | - NumPy arrays (serialized as {"__type__": "ndarray", "data": [...], "dtype": "..."}) 455 | - Plotly figures (serialized as {"__type__": "plotly_figure", "data": "..."}) 456 | 457 | Args: 458 | obj: The object to convert, which may contain serialized objects 459 | 460 | Returns: 461 | The object with any serialized objects converted back to their original form 462 | """ 463 | if isinstance(obj, dict): 464 | if "__type__" in obj: 465 | if obj["__type__"] == "ndarray": 466 | return np.array(obj["data"], dtype=obj["dtype"]) 467 | elif obj["__type__"] == "plotly_figure": 468 | return pio.from_json(obj["data"]) 469 | return {k: convert_serialized_objects(v) for k, v in obj.items()} 470 | elif isinstance(obj, list): 471 | return [convert_serialized_objects(item) for item in obj] 472 | return obj 473 | 474 | 475 | if __name__ == "__main__": 476 | serializer = CustomSerializer() 477 | original_data = { 478 | "array": np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64), 479 | "scalar": np.float32(7.5), 480 | "message": "Test serialization", 481 | "nested": { 482 | "array": np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64), 483 | "scalar": np.float32(7.5), 484 | "list_of_arrays": [ 485 | np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64), 486 | np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float64), 487 | ], 488 | }, 489 | } 490 | 491 | # Create a more complex figure with multiple traces and customization 492 | fig = go.Figure() 493 | 494 | # Add a scatter plot with markers 495 | fig.add_trace( 496 | go.Scatter( 497 | x=[1, 2, 3, 4, 5], 498 | y=[4, 5.2, 6, 3.2, 8], 499 | mode="markers+lines", 500 | name="Series A", 501 | marker=dict(size=10, color="blue", symbol="circle"), 502 | ) 503 | ) 504 | 505 | # Add a bar chart 506 | fig.add_trace( 507 | go.Bar( 508 | x=[1, 2, 3, 4, 5], y=[2, 3, 1, 5, 3], name="Series B", marker_color="green" 509 | ) 510 | ) 511 | 512 | # Add a line plot with different style 513 | fig.add_trace( 514 | go.Scatter( 515 | x=[1, 2, 3, 4, 5], 516 | y=[7, 6, 9, 8, 7], 517 | mode="lines", 518 | name="Series C", 519 | line=dict(width=3, dash="dash", color="red"), 520 | ) 521 | ) 522 | 523 | # Update layout with title and axis labels 524 | fig.update_layout( 525 | title="Complex Test Figure", 526 | xaxis_title="X Axis", 527 | yaxis_title="Y Axis", 528 | legend_title="Legend", 529 | template="plotly_white", 530 | ) 531 | original_data["figure"] = fig 532 | 533 | # Serialize the data 534 | _, serialized = serializer.dumps_typed(original_data) 535 | # Deserialize the data 536 | deserialized_data = serializer.loads_typed(("msgpack", serialized)) 537 | 538 | # Assertions 539 | assert isinstance(deserialized_data["array"], np.ndarray) 540 | assert np.array_equal(deserialized_data["array"], original_data["array"]) 541 | assert isinstance(deserialized_data["scalar"], float) 542 | assert deserialized_data["scalar"] == float(original_data["scalar"]) 543 | assert deserialized_data["message"] == original_data["message"] 544 | assert isinstance(deserialized_data["nested"]["array"], np.ndarray) 545 | assert np.array_equal( 546 | deserialized_data["nested"]["array"], original_data["nested"]["array"] 547 | ) 548 | assert isinstance(deserialized_data["nested"]["scalar"], float) 549 | assert deserialized_data["nested"]["scalar"] == float( 550 | original_data["nested"]["scalar"] 551 | ) 552 | assert isinstance(deserialized_data["nested"]["list_of_arrays"], list) 553 | assert all( 554 | isinstance(arr, np.ndarray) 555 | for arr in 
deserialized_data["nested"]["list_of_arrays"] 556 | ) 557 | assert all( 558 | np.array_equal(arr, original_arr) 559 | for arr, original_arr in zip( 560 | deserialized_data["nested"]["list_of_arrays"], 561 | original_data["nested"]["list_of_arrays"], 562 | ) 563 | ) 564 | 565 | assert isinstance(deserialized_data["figure"], go.Figure) 566 | assert len(deserialized_data["figure"].data) == len(fig.data) 567 | for i, trace in enumerate(fig.data): 568 | assert deserialized_data["figure"].data[i].type == trace.type 569 | # Compare x and y data if they exist 570 | if hasattr(trace, "x") and trace.x is not None: 571 | assert np.array_equal(deserialized_data["figure"].data[i].x, trace.x) 572 | if hasattr(trace, "y") and trace.y is not None: 573 | assert np.array_equal(deserialized_data["figure"].data[i].y, trace.y) 574 | 575 | # Compare layout properties 576 | assert deserialized_data["figure"].layout.title.text == fig.layout.title.text 577 | 578 | print("Serialization and deserialization test passed.") 579 | -------------------------------------------------------------------------------- /nodeology/interface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved. 3 | 4 | Copyright 2024. UChicago Argonne, LLC. This software was produced 5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National 6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the 7 | U.S. Department of Energy. The U.S. Government has rights to use, 8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR 9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR 10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is 11 | modified to produce derivative works, such modified software should 12 | be clearly marked, so as not to confuse it with the version available 13 | from ANL. 14 | 15 | Additionally, redistribution and use in source and binary forms, with 16 | or without modification, are permitted provided that the following 17 | conditions are met: 18 | 19 | * Redistributions of source code must retain the above copyright 20 | notice, this list of conditions and the following disclaimer. 21 | 22 | * Redistributions in binary form must reproduce the above copyright 23 | notice, this list of conditions and the following disclaimer in 24 | the documentation and/or other materials provided with the 25 | distribution. 26 | 27 | * Neither the name of UChicago Argonne, LLC, Argonne National 28 | Laboratory, ANL, the U.S. Government, nor the names of its 29 | contributors may be used to endorse or promote products derived 30 | from this software without specific prior written permission. 31 | 32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS 33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL UChicago 36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 43 | POSSIBILITY OF SUCH DAMAGE. 44 | """ 45 | 46 | ### Initial Author <2025>: Xiangyu Yin 47 | 48 | import os, json, importlib, threading, traceback, contextvars 49 | import logging 50 | import chainlit as cl 51 | from chainlit.cli import run_chainlit 52 | from nodeology.state import StateEncoder, convert_serialized_objects 53 | 54 | 55 | logger = logging.getLogger(__name__) 56 | 57 | 58 | def run_chainlit_for_workflow(workflow, initial_state=None): 59 | """ 60 | Called by workflow.run(ui=True). This function: 61 | 1. stores the workflow class, its init kwargs, and any initial_state in environment variables 62 | 2. starts the chainlit server 63 | 3. returns the final state when the workflow completes 64 | """ 65 | os.environ["NODEOLOGY_WORKFLOW_CLASS"] = ( 66 | workflow.__class__.__module__ + "." + workflow.__class__.__name__ 67 | ) 68 | 69 | logger.info(f"Starting UI for workflow: {workflow.__class__.__name__}") 70 | 71 | # Save the initialization arguments to ensure they're passed when recreating the workflow 72 | if hasattr(workflow, "_init_kwargs"): 73 | logger.info( 74 | f"Found initialization kwargs: {list(workflow._init_kwargs.keys())}" 75 | ) 76 | 77 | # We need to handle non-serializable objects in the kwargs 78 | serializable_kwargs = {} 79 | for key, value in workflow._init_kwargs.items(): 80 | # Handle special cases 81 | if key == "state_defs": 82 | # For state_defs, we need to handle it specially 83 | if value is None: 84 | # If None, we'll use the workflow's state_schema 85 | serializable_kwargs[key] = None 86 | logger.info( 87 | f"Serializing {key} as None (will use workflow's state_schema)" 88 | ) 89 | elif isinstance(value, type) and hasattr(value, "__annotations__"): 90 | # If it's a TypedDict or similar class with annotations, we'll store its dotted module.name path 91 | # The workflow class will handle recreating it 92 | serializable_kwargs["_state_defs_class"] = ( 93 | f"{value.__module__}.{value.__name__}" 94 | ) 95 | logger.info( 96 | f"Serializing state_defs as class reference: {serializable_kwargs['_state_defs_class']}" 97 | ) 98 | elif isinstance(value, list): 99 | # If it's a list of state definitions, we'll try to serialize it 100 | # This is complex and might not work for all cases 101 | try: 102 | # Convert any TypedDict classes to their module.name string 103 | serialized_list = [] 104 | for item in value: 105 | if isinstance(item, type) and hasattr( 106 | item, "__annotations__" 107 | ): 108 | serialized_list.append( 109 | f"{item.__module__}.{item.__name__}" 110 | ) 111 | elif isinstance(item, tuple) and len(item) == 2: 112 | # Handle (name, type) tuples 113 | name, type_hint = item 114 | if isinstance(type_hint, type): 115 | serialized_list.append( 116 | [ 117 | name, 118 | f"{type_hint.__module__}.{type_hint.__name__}", 119 | ] 120 | ) 121 | else: 122 | serialized_list.append([name, str(type_hint)]) 123 | elif isinstance(item, dict) and len(item) == 1: 124 | # Handle {"name": type} dictionaries 125 | name, type_hint = next(iter(item.items())) 126 | if isinstance(type_hint, type): 127 |
serialized_list.append( 128 | { 129 | name: f"{type_hint.__module__}.{type_hint.__name__}" 130 | } 131 | ) 132 | else: 133 | serialized_list.append({name: str(type_hint)}) 134 | else: 135 | # Skip items we can't serialize 136 | logger.info( 137 | f"Skipping non-serializable state_def item: {item}" 138 | ) 139 | 140 | if serialized_list: 141 | serializable_kwargs["_state_defs_list"] = serialized_list 142 | logger.info( 143 | f"Serializing state_defs as list: {serialized_list}" 144 | ) 145 | else: 146 | logger.info("Could not serialize any state_defs items") 147 | except Exception as e: 148 | logger.error(f"Error serializing state_defs list: {str(e)}") 149 | else: 150 | logger.info( 151 | f"Cannot serialize state_defs of type {type(value).__name__}" 152 | ) 153 | elif key == "checkpointer": 154 | # For checkpointer, just store "memory" if it's a string or an object 155 | if isinstance(value, str): 156 | serializable_kwargs[key] = value 157 | else: 158 | serializable_kwargs[key] = "memory" 159 | logger.info(f"Serializing checkpointer as: {serializable_kwargs[key]}") 160 | elif isinstance(value, (str, int, float, bool, type(None))): 161 | serializable_kwargs[key] = value 162 | logger.info( 163 | f"Serializing {key} as primitive type: {type(value).__name__}" 164 | ) 165 | elif isinstance(value, list) and all( 166 | isinstance(item, (str, int, float, bool, type(None))) for item in value 167 | ): 168 | serializable_kwargs[key] = value 169 | logger.info(f"Serializing {key} as list of primitives") 170 | elif isinstance(value, dict) and all( 171 | isinstance(k, str) 172 | and isinstance(v, (str, int, float, bool, type(None))) 173 | for k, v in value.items() 174 | ): 175 | serializable_kwargs[key] = value 176 | logger.info(f"Serializing {key} as dict of primitives") 177 | else: 178 | logger.info( 179 | f"Skipping non-serializable {key} of type {type(value).__name__}" 180 | ) 181 | # Skip other complex objects that can't be easily serialized 182 | 183 | # For client objects, just store their names 184 | if ( 185 | "llm_name" in workflow._init_kwargs 186 | and hasattr(workflow, "llm_client") 187 | and hasattr(workflow.llm_client, "model_name") 188 | ): 189 | serializable_kwargs["llm_name"] = workflow.llm_client.model_name 190 | logger.info( 191 | f"Using llm_client.model_name: {workflow.llm_client.model_name}" 192 | ) 193 | 194 | if ( 195 | "vlm_name" in workflow._init_kwargs 196 | and hasattr(workflow, "vlm_client") 197 | and hasattr(workflow.vlm_client, "model_name") 198 | ): 199 | serializable_kwargs["vlm_name"] = workflow.vlm_client.model_name 200 | logger.info( 201 | f"Using vlm_client.model_name: {workflow.vlm_client.model_name}" 202 | ) 203 | 204 | # Store the workflow's state_schema class name if available 205 | if hasattr(workflow, "state_schema") and hasattr( 206 | workflow.state_schema, "__name__" 207 | ): 208 | serializable_kwargs["_state_schema_class"] = ( 209 | f"{workflow.state_schema.__module__}.{workflow.state_schema.__name__}" 210 | ) 211 | logger.info( 212 | f"Storing state_schema class: {serializable_kwargs['_state_schema_class']}" 213 | ) 214 | 215 | os.environ["NODEOLOGY_WORKFLOW_ARGS"] = json.dumps( 216 | serializable_kwargs, cls=StateEncoder 217 | ) 218 | logger.info(f"Serialized kwargs: {list(serializable_kwargs.keys())}") 219 | else: 220 | logger.info("No initialization kwargs found on workflow") 221 | 222 | # Serialize any initial state if needed 223 | if initial_state: 224 | # Use StateEncoder to handle NumPy arrays 225 | os.environ["NODEOLOGY_INITIAL_STATE"] = json.dumps( 
226 | initial_state, cls=StateEncoder 227 | ) 228 | logger.info("Serialized initial state") 229 | 230 | # Create a shared variable to store the final state 231 | os.environ["NODEOLOGY_FINAL_STATE"] = "{}" 232 | 233 | # This file is nodeology/interface.py, get its path: 234 | this_file = os.path.abspath(__file__) 235 | # Launch the Chainlit server targeting this file 236 | logger.info("Starting Chainlit server") 237 | run_chainlit(target=this_file) 238 | 239 | # Return the final state from the last session 240 | final_state = {} 241 | if ( 242 | "NODEOLOGY_FINAL_STATE" in os.environ 243 | and os.environ["NODEOLOGY_FINAL_STATE"] != "{}" 244 | ): 245 | try: 246 | final_state_json = os.environ["NODEOLOGY_FINAL_STATE"] 247 | final_state_dict = json.loads(final_state_json) 248 | logger.info( 249 | f"Retrieved final state with keys: {list(final_state_dict.keys())}" 250 | ) 251 | 252 | # Convert any serialized NumPy arrays back to arrays 253 | final_state = convert_serialized_objects(final_state_dict) 254 | logger.info("Converted any serialized objects in final state") 255 | 256 | except Exception as e: 257 | logger.error(f"Error parsing final state: {str(e)}") 258 | 259 | return final_state 260 | 261 | 262 | @cl.on_chat_start 263 | async def on_chat_start(): 264 | """ 265 | Called once a new user session is started in the chainlit UI. 266 | We will instantiate a new workflow for this session. 267 | """ 268 | try: 269 | # Get the workflow class from environment variable 270 | workflow_class_path = os.environ.get("NODEOLOGY_WORKFLOW_CLASS") 271 | if not workflow_class_path: 272 | await cl.Message(content="No workflow class specified.").send() 273 | return 274 | 275 | logger.info(f"Creating workflow from class: {workflow_class_path}") 276 | 277 | # Import the workflow class dynamically 278 | module_path, class_name = workflow_class_path.rsplit(".", 1) 279 | module = importlib.import_module(module_path) 280 | WorkflowClass = getattr(module, class_name) 281 | logger.info(f"Successfully imported workflow class: {class_name}") 282 | 283 | # Get the saved initialization arguments 284 | workflow_args = {} 285 | state_defs_processed = False 286 | 287 | if "NODEOLOGY_WORKFLOW_ARGS" in os.environ: 288 | try: 289 | serialized_args = json.loads(os.environ["NODEOLOGY_WORKFLOW_ARGS"]) 290 | logger.info(f"Loaded serialized args: {list(serialized_args.keys())}") 291 | 292 | # Handle special parameters 293 | 294 | # 1.
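# Illustration only (hypothetical module/class names): the serialized forms
# produced by run_chainlit_for_workflow look like
#   "_state_defs_class": "my_pkg.states.MyState"
#   "_state_defs_list": [["count", "builtins.int"], {"notes": "my_pkg.states.Notes"}]
# The branches below (item 1) reverse these mappings back into importable
# classes and (name, type) state definitions.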
Handle state_defs 295 | if "_state_defs_class" in serialized_args: 296 | # We have a class reference for state_defs 297 | state_defs_class_path = serialized_args.pop("_state_defs_class") 298 | try: 299 | module_path, class_name = state_defs_class_path.rsplit(".", 1) 300 | module = importlib.import_module(module_path) 301 | state_defs_class = getattr(module, class_name) 302 | workflow_args["state_defs"] = state_defs_class 303 | logger.info( 304 | f"Imported state_defs class: {state_defs_class_path}" 305 | ) 306 | state_defs_processed = True 307 | except Exception as e: 308 | logger.error(f"Error importing state_defs class: {str(e)}") 309 | elif "_state_defs_list" in serialized_args: 310 | # We have a list of state definitions 311 | state_defs_list = serialized_args.pop("_state_defs_list") 312 | try: 313 | # Process each item in the list 314 | processed_list = [] 315 | for item in state_defs_list: 316 | if isinstance(item, str): 317 | # It's a class reference 318 | try: 319 | module_path, class_name = item.rsplit(".", 1) 320 | module = importlib.import_module(module_path) 321 | class_obj = getattr(module, class_name) 322 | processed_list.append(class_obj) 323 | except Exception as e: 324 | logger.error( 325 | f"Error importing state def class {item}: {str(e)}" 326 | ) 327 | elif isinstance(item, list) and len(item) == 2: 328 | # It's a [name, type] tuple 329 | name, type_str = item 330 | if "." in type_str: 331 | # It's a class reference 332 | try: 333 | module_path, class_name = type_str.rsplit( 334 | ".", 1 335 | ) 336 | module = importlib.import_module(module_path) 337 | type_obj = getattr(module, class_name) 338 | processed_list.append((name, type_obj)) 339 | except Exception as e: 340 | logger.error( 341 | f"Error importing type {type_str}: {str(e)}" 342 | ) 343 | # Fall back to string representation 344 | processed_list.append((name, type_str)) 345 | else: 346 | # It's a primitive type string 347 | processed_list.append((name, type_str)) 348 | elif isinstance(item, dict) and len(item) == 1: 349 | # It's a {name: type} dict 350 | name, type_str = next(iter(item.items())) 351 | if "." in type_str: 352 | # It's a class reference 353 | try: 354 | module_path, class_name = type_str.rsplit( 355 | ".", 1 356 | ) 357 | module = importlib.import_module(module_path) 358 | type_obj = getattr(module, class_name) 359 | processed_list.append({name: type_obj}) 360 | except Exception as e: 361 | logger.error( 362 | f"Error importing type {type_str}: {str(e)}" 363 | ) 364 | # Fall back to string representation 365 | processed_list.append({name: type_str}) 366 | else: 367 | # It's a primitive type string 368 | processed_list.append({name: type_str}) 369 | 370 | if processed_list: 371 | workflow_args["state_defs"] = processed_list 372 | logger.info( 373 | f"Processed state_defs list with {len(processed_list)} items" 374 | ) 375 | state_defs_processed = True 376 | else: 377 | logger.info("No state_defs items could be processed") 378 | except Exception as e: 379 | logger.error(f"Error processing state_defs list: {str(e)}") 380 | elif ( 381 | "state_defs" in serialized_args 382 | and serialized_args["state_defs"] is None 383 | ): 384 | # Explicit None value 385 | workflow_args["state_defs"] = None 386 | serialized_args.pop("state_defs") 387 | logger.info("Using None for state_defs") 388 | state_defs_processed = True 389 | 390 | # 2. 
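# Sketch (hypothetical value): "_state_schema_class" arrives as a dotted path
# such as "my_pkg.states.MyState"; when state_defs could not be rebuilt above,
# the same rsplit + import_module steps below (item 2) recover the class and
# use it as a fallback for state_defs.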
Handle state_schema if needed 391 | if "_state_schema_class" in serialized_args: 392 | # We have a class reference for state_schema 393 | state_schema_class_path = serialized_args.pop("_state_schema_class") 394 | logger.info( 395 | f"Found state_schema class: {state_schema_class_path} (will be handled by workflow)" 396 | ) 397 | 398 | # If we couldn't process state_defs, try to use the state_schema class as a fallback 399 | if not state_defs_processed: 400 | try: 401 | module_path, class_name = state_schema_class_path.rsplit( 402 | ".", 1 403 | ) 404 | module = importlib.import_module(module_path) 405 | state_schema_class = getattr(module, class_name) 406 | workflow_args["state_defs"] = state_schema_class 407 | logger.info( 408 | f"Using state_schema class as fallback for state_defs: {state_schema_class_path}" 409 | ) 410 | state_defs_processed = True 411 | except Exception as e: 412 | logger.error( 413 | f"Error importing state_schema class as fallback: {str(e)}" 414 | ) 415 | 416 | # Add remaining arguments 417 | for key, value in serialized_args.items(): 418 | workflow_args[key] = value 419 | 420 | # Convert any serialized NumPy arrays back to arrays 421 | workflow_args = convert_serialized_objects(workflow_args) 422 | logger.info(f"Final workflow args: {list(workflow_args.keys())}") 423 | except Exception as e: 424 | logger.error(f"Error parsing workflow arguments: {str(e)}") 425 | traceback.print_exc() 426 | # Continue with empty args if there's an error 427 | 428 | # If we couldn't process state_defs, check if the workflow class has a state_schema attribute 429 | if not state_defs_processed and hasattr(WorkflowClass, "state_schema"): 430 | logger.info(f"Using workflow class's state_schema attribute as fallback") 431 | # We don't need to set state_defs explicitly, the workflow will use its state_schema 432 | 433 | # Create a new instance of the workflow with the saved arguments 434 | logger.info( 435 | f"Creating workflow instance with args: {list(workflow_args.keys())}" 436 | ) 437 | workflow = WorkflowClass(**workflow_args) 438 | logger.info(f"Successfully created workflow instance: {workflow.name}") 439 | 440 | # Check if VLM client is available 441 | if hasattr(workflow, "vlm_client") and workflow.vlm_client is not None: 442 | logger.info(f"VLM client is available") 443 | else: 444 | logger.info("VLM client is not available") 445 | 446 | # Get initial state if available 447 | initial_state = None 448 | initial_state_json = os.environ.get("NODEOLOGY_INITIAL_STATE") 449 | if initial_state_json: 450 | logger.info("Found initial state in environment") 451 | # Parse the JSON and convert any serialized NumPy arrays back to arrays 452 | initial_state_dict = json.loads(initial_state_json) 453 | logger.info( 454 | f"Loaded initial state with keys: {list(initial_state_dict.keys())}" 455 | ) 456 | 457 | # Convert any serialized NumPy arrays back to arrays 458 | initial_state = convert_serialized_objects(initial_state_dict) 459 | logger.info("Converted any serialized objects in initial state") 460 | 461 | # Initialize the workflow 462 | if initial_state: 463 | logger.info( 464 | f"Initializing workflow with initial state: {list(initial_state.keys())}" 465 | ) 466 | workflow.initialize(initial_state) 467 | else: 468 | logger.info("Initializing workflow with default state") 469 | workflow.initialize() 470 | logger.info("Workflow initialized successfully") 471 | 472 | # Store in user session 473 | cl.user_session.set("workflow", workflow) 474 | logger.info("Stored workflow in user session") 475 
| 476 | # Capture the current Chainlit context 477 | parent_ctx = contextvars.copy_context() 478 | logger.info("Captured Chainlit context") 479 | 480 | # Create a function to save the final state when workflow completes 481 | def save_final_state(state): 482 | try: 483 | # Use StateEncoder to handle NumPy arrays and other complex objects 484 | serialized_state = json.dumps(state, cls=StateEncoder) 485 | os.environ["NODEOLOGY_FINAL_STATE"] = serialized_state 486 | logger.info(f"Saved final state with keys: {list(state.keys())}") 487 | except Exception as e: 488 | logger.error(f"Error saving final state: {str(e)}") 489 | traceback.print_exc() 490 | 491 | # Start the workflow in a background thread from within the Chainlit context 492 | def run_workflow_with_polling(): 493 | try: 494 | logger.info("Starting workflow execution in background thread") 495 | # Run the workflow inside the captured context 496 | final_state = parent_ctx.run( 497 | lambda: workflow._run(initial_state, ui=True) 498 | ) 499 | logger.info("Workflow execution completed") 500 | 501 | # Save the final state 502 | if final_state: 503 | save_final_state(final_state) 504 | logger.info("Final state saved") 505 | except Exception as e: 506 | logger.error(f"Error in workflow execution: {str(e)}") 507 | traceback.print_exc() 508 | 509 | # Start the workflow thread 510 | workflow_thread = threading.Thread( 511 | target=run_workflow_with_polling, daemon=True 512 | ) 513 | cl.user_session.set("workflow_thread", workflow_thread) 514 | logger.info("Created workflow thread") 515 | workflow_thread.start() 516 | logger.info("Started workflow thread") 517 | 518 | await cl.Message( 519 | content=f"Welcome to the {workflow.__class__.__name__} via Nodeology!" 520 | ).send() 521 | logger.info("Sent welcome message") 522 | except Exception as e: 523 | logger.error(f"Error in on_chat_start: {str(e)}") 524 | traceback.print_exc() 525 | await cl.Message(content=f"Error initializing workflow: {str(e)}").send() 526 | -------------------------------------------------------------------------------- /tests/test_node.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved. 3 | 4 | Copyright 2024. UChicago Argonne, LLC. This software was produced 5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National 6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the 7 | U.S. Department of Energy. The U.S. Government has rights to use, 8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR 9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR 10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is 11 | modified to produce derivative works, such modified software should 12 | be clearly marked, so as not to confuse it with the version available 13 | from ANL. 14 | 15 | Additionally, redistribution and use in source and binary forms, with 16 | or without modification, are permitted provided that the following 17 | conditions are met: 18 | 19 | * Redistributions of source code must retain the above copyright 20 | notice, this list of conditions and the following disclaimer. 21 | 22 | * Redistributions in binary form must reproduce the above copyright 23 | notice, this list of conditions and the following disclaimer in 24 | the documentation and/or other materials provided with the 25 | distribution. 
26 | 27 | * Neither the name of UChicago Argonne, LLC, Argonne National 28 | Laboratory, ANL, the U.S. Government, nor the names of its 29 | contributors may be used to endorse or promote products derived 30 | from this software without specific prior written permission. 31 | 32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS 33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago 36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 43 | POSSIBILITY OF SUCH DAMAGE. 44 | """ 45 | 46 | ### Initial Author <2024>: Xiangyu Yin 47 | 48 | import os, logging 49 | import pytest 50 | from typing import Any, Dict, List 51 | import numpy as np 52 | from nodeology.client import VLM_Client 53 | from nodeology.node import ( 54 | Node, 55 | as_node, 56 | remove_markdown_blocks_formatting, 57 | ) 58 | from nodeology.state import State 59 | from nodeology.log import add_logging_level 60 | 61 | add_logging_level("PRINTLOG", logging.INFO + 5) 62 | add_logging_level("LOGONLY", logging.INFO + 1) 63 | 64 | 65 | # Basic Node Tests 66 | class TestBasicNodeFunctionality: 67 | def test_node_creation(self): 68 | """Test basic node creation and configuration""" 69 | node = Node( 70 | node_type="test_node", 71 | prompt_template="Test prompt with {key1} and {key2}", 72 | sink=["output"], 73 | ) 74 | 75 | assert node.node_type == "test_node" 76 | assert "key1" in node.required_keys 77 | assert "key2" in node.required_keys 78 | assert node.sink == ["output"] 79 | 80 | def test_node_error_handling(self): 81 | """Test node error handling for missing required keys""" 82 | node = Node( 83 | node_type="test_node", prompt_template="Test {required_key}", sink="output" 84 | ) 85 | 86 | state = State() 87 | state["messages"] = [] 88 | 89 | with pytest.raises(ValueError, match="Required key 'required_key' not found"): 90 | node(state, None) 91 | 92 | def test_empty_sink_list(self): 93 | """Test node behavior with empty sink list""" 94 | node = Node(node_type="test_node", prompt_template="Test", sink=[]) 95 | 96 | state = State() 97 | state["messages"] = [] 98 | 99 | result_state = node(state, lambda **k: "response") 100 | assert result_state == state # State should remain unchanged 101 | 102 | def test_state_type_tracking_chain(self): 103 | """Test node type tracking through multiple nodes""" 104 | node1 = Node(node_type="node1", prompt_template="Test", sink="output1") 105 | node2 = Node(node_type="node2", prompt_template="Test", sink="output2") 106 | 107 | state = State() 108 | state["messages"] = [] 109 | 110 | state = node1(state, lambda **k: "response1") 111 | assert state["current_node_type"] == "node1" 112 | assert state["previous_node_type"] == "" 113 | 114 | state = node2(state, lambda **k: "response2") 115 | assert state["current_node_type"] == "node2" 116 | assert state["previous_node_type"] == "node1" 117 | 118 | def test_state_preservation(self): 119 | """Test that unrelated state data is preserved""" 120 | 
node = Node(node_type="test_node", prompt_template="Test", sink="output") 121 | 122 | state = State() 123 | state["messages"] = [] 124 | state["preserved_key"] = "preserved_value" 125 | state["input"] = "test_input" 126 | 127 | result_state = node(state, lambda **k: "response") 128 | assert result_state["preserved_key"] == "preserved_value" 129 | assert result_state["input"] == "test_input" 130 | 131 | def test_state_immutability(self): 132 | """Test that original state is not modified""" 133 | node = Node(node_type="test_node", prompt_template="Test", sink="output") 134 | 135 | original_state = State() 136 | original_state["messages"] = [] 137 | original_state["value"] = "original" 138 | 139 | state_copy = original_state.copy() 140 | result_state = node(state_copy, lambda **k: "response") 141 | 142 | assert original_state["value"] == "original" 143 | assert original_state != result_state 144 | 145 | 146 | # Execution Tests 147 | class TestNodeExecution: 148 | class MockClient: 149 | def __call__(self, messages, **kwargs): 150 | return "Test response" 151 | 152 | @pytest.fixture 153 | def basic_state(self): 154 | state = State() 155 | state["input"] = "value" 156 | state["messages"] = [] 157 | return state 158 | 159 | def test_basic_execution(self, basic_state): 160 | """Test basic node execution""" 161 | node = Node( 162 | node_type="test_node", prompt_template="Test {input}", sink="output" 163 | ) 164 | 165 | result_state = node(basic_state, self.MockClient()) 166 | assert result_state["output"] == "Test response" 167 | 168 | def test_multiple_sinks(self, basic_state): 169 | """Test node with multiple output sinks""" 170 | 171 | class MultiResponseMock: 172 | def __call__(self, messages, **kwargs): 173 | return ["Response 1", "Response 2"] 174 | 175 | node = Node( 176 | node_type="test_node", 177 | prompt_template="Test {input}", 178 | sink=["output1", "output2"], 179 | ) 180 | 181 | result_state = node(basic_state, MultiResponseMock()) 182 | assert result_state["output1"] == "Response 1" 183 | assert result_state["output2"] == "Response 2" 184 | 185 | def test_sink_format(self, basic_state): 186 | """Test node with specific sink format""" 187 | 188 | class FormatMock: 189 | def __call__(self, messages, **kwargs): 190 | assert kwargs.get("format") == "json" 191 | return '{"key": "value"}' 192 | 193 | node = Node( 194 | node_type="test_node", 195 | prompt_template="Test {input}", 196 | sink="output", 197 | sink_format="json", 198 | ) 199 | 200 | result_state = node(basic_state, FormatMock()) 201 | assert result_state["output"] == '{"key": "value"}' 202 | 203 | def test_invalid_sink_format(self): 204 | """Test handling of invalid sink format specification""" 205 | node = Node( 206 | node_type="test_node", 207 | prompt_template="Test", 208 | sink="output", 209 | sink_format="invalid_format", 210 | ) 211 | 212 | state = State() 213 | state["messages"] = [] 214 | 215 | # The client should handle invalid format 216 | class FormatTestClient: 217 | def __call__(self, messages, **kwargs): 218 | assert kwargs.get("format") == "invalid_format" 219 | return "response" 220 | 221 | result_state = node(state, FormatTestClient()) 222 | assert result_state["output"] == "response" 223 | 224 | def test_mismatched_sink_response(self): 225 | """Test handling of mismatched sink and response count""" 226 | node = Node( 227 | node_type="test_node", prompt_template="Test", sink=["output1", "output2"] 228 | ) 229 | 230 | state = State() 231 | state["messages"] = [] 232 | 233 | with pytest.raises( 234 | 
ValueError, match="Number of responses .* doesn't match number of sink" 235 | ): 236 | node(state, lambda **k: ["single_response"]) 237 | 238 | 239 | # Expression Evaluation Tests 240 | class TestExpressionEvaluation: 241 | @pytest.fixture 242 | def sample_state(self): 243 | class SampleState(State): 244 | items: List[str] 245 | numbers: np.ndarray 246 | text: str 247 | data: Dict[str, Any] 248 | key: str 249 | 250 | state = SampleState() 251 | state["messages"] = [] 252 | state["items"] = ["a", "b", "c", "d"] 253 | state["numbers"] = np.array([1, 2, 3, 4, 5]) 254 | state["text"] = "Hello World" 255 | state["data"] = {"name": "John", "age": 30} 256 | state["key"] = "name" 257 | return state 258 | 259 | def test_basic_indexing(self, sample_state): 260 | """Test basic list indexing""" 261 | node = Node( 262 | node_type="test_node", 263 | prompt_template="First: {items[0]}, Last: {items[-1]}", 264 | sink="output", 265 | ) 266 | 267 | result_state = node(sample_state, lambda **k: "response") 268 | assert "First: a, Last: d" in result_state["messages"][-1]["content"] 269 | 270 | def test_list_slicing(self, sample_state): 271 | """Test list slicing operations""" 272 | node = Node( 273 | node_type="test_node", 274 | prompt_template="Slice: {items[1:3]}, Reverse: {items[::-1]}", 275 | sink="output", 276 | ) 277 | 278 | result_state = node(sample_state, lambda **k: "response") 279 | assert ( 280 | "Slice: ['b', 'c'], Reverse: ['d', 'c', 'b', 'a']" 281 | in result_state["messages"][-1]["content"] 282 | ) 283 | 284 | def test_string_methods(self, sample_state): 285 | """Test string method calls""" 286 | node = Node( 287 | node_type="test_node", 288 | prompt_template=""" 289 | Upper: {text.upper()} 290 | Lower: {text.lower()} 291 | Title: {text.title()} 292 | Strip: {text.strip()} 293 | """, 294 | sink="output", 295 | ) 296 | 297 | result_state = node(sample_state, lambda **k: "response") 298 | message = result_state["messages"][-1]["content"] 299 | assert "Upper: HELLO WORLD" in message 300 | assert "Lower: hello world" in message 301 | assert "Title: Hello World" in message 302 | 303 | def test_built_in_functions(self, sample_state): 304 | """Test allowed built-in function calls""" 305 | node = Node( 306 | node_type="test_node", 307 | prompt_template=""" 308 | Length: {len(items)} 309 | Maximum: {max(numbers)} 310 | Minimum: {min(numbers)} 311 | Sum: {sum(numbers)} 312 | Absolute: {abs(-42)} 313 | """, 314 | sink="output", 315 | ) 316 | 317 | result_state = node(sample_state, lambda **k: "response") 318 | message = result_state["messages"][-1]["content"] 319 | assert "Length: 4" in message 320 | assert "Maximum: 5" in message 321 | assert "Minimum: 1" in message 322 | assert "Sum: 15" in message 323 | assert "Absolute: 42" in message 324 | 325 | def test_dict_access(self, sample_state): 326 | """Test dictionary access methods""" 327 | node = Node( 328 | node_type="test_node", 329 | prompt_template=""" 330 | Direct key: {data['name']} 331 | Variable key: {data[key]} 332 | """, 333 | sink="output", 334 | ) 335 | 336 | result_state = node(sample_state, lambda **k: "response") 337 | message = result_state["messages"][-1]["content"] 338 | assert "Direct key: John" in message 339 | assert "Variable key: John" in message 340 | 341 | def test_type_conversions(self, sample_state): 342 | """Test type conversion functions""" 343 | sample_state["value"] = "42" 344 | node = Node( 345 | node_type="test_node", 346 | prompt_template=""" 347 | Integer: {int(value)} 348 | Float: {float(value)} 349 | String: 
{str(numbers[0])} 350 | """, 351 | sink="output", 352 | ) 353 | 354 | result_state = node(sample_state, lambda **k: "response") 355 | message = result_state["messages"][-1]["content"] 356 | assert "Integer: 42" in message 357 | assert "Float: 42.0" in message 358 | assert "String: 1" in message 359 | 360 | def test_invalid_expressions(self, sample_state): 361 | """Test error handling for invalid expressions""" 362 | # Test invalid function 363 | with pytest.raises(ValueError, match="Function not allowed: print"): 364 | node = Node( 365 | node_type="test_node", prompt_template="{print(text)}", sink="output" 366 | ) 367 | node(sample_state, lambda **k: "response") 368 | 369 | # Test invalid method 370 | with pytest.raises(ValueError, match="String method not allowed: split"): 371 | node = Node( 372 | node_type="test_node", prompt_template="{text.split()}", sink="output" 373 | ) 374 | node(sample_state, lambda **k: "response") 375 | 376 | # Test invalid index 377 | with pytest.raises(ValueError): 378 | node = Node( 379 | node_type="test_node", prompt_template="{items[10]}", sink="output" 380 | ) 381 | node(sample_state, lambda **k: "response") 382 | 383 | # Test invalid key 384 | with pytest.raises(ValueError): 385 | node = Node( 386 | node_type="test_node", 387 | prompt_template="{data['invalid_key']}", 388 | sink="output", 389 | ) 390 | node(sample_state, lambda **k: "response") 391 | 392 | 393 | # Pre/Post Processing Tests 394 | class TestPrePostProcessing: 395 | @pytest.fixture 396 | def processed_list(self): 397 | return [] 398 | 399 | def test_pre_post_processing(self, processed_list): 400 | """Test node with pre and post processing functions""" 401 | 402 | def pre_process(state, client, **kwargs): 403 | processed_list.append("pre") 404 | return state 405 | 406 | def post_process(state, client, **kwargs): 407 | processed_list.append("post") 408 | return state 409 | 410 | node = Node( 411 | node_type="test_node", 412 | prompt_template="Test {input}", 413 | sink="output", 414 | pre_process=pre_process, 415 | post_process=post_process, 416 | ) 417 | 418 | state = State() 419 | state["input"] = "value" 420 | state["messages"] = [] 421 | 422 | node(state, lambda **k: "response") 423 | assert processed_list == ["pre", "post"] 424 | 425 | def test_none_pre_post_process(self): 426 | """Test node behavior when pre/post process returns None""" 427 | 428 | def pre_process(state, client, **kwargs): 429 | return None 430 | 431 | def post_process(state, client, **kwargs): 432 | return None 433 | 434 | node = Node( 435 | node_type="test_node", 436 | prompt_template="Test {input}", 437 | sink="output", 438 | pre_process=pre_process, 439 | post_process=post_process, 440 | ) 441 | 442 | state = State() 443 | state["input"] = "value" 444 | state["messages"] = [] 445 | 446 | result_state = node(state, lambda **k: "response") 447 | assert result_state == state 448 | 449 | 450 | # Source Mapping Tests 451 | class TestSourceMapping: 452 | @pytest.fixture 453 | def state_with_mapping(self): 454 | state = State() 455 | state["different_key"] = "mapped value" 456 | state["input_key"] = "value" 457 | state["messages"] = [] 458 | return state 459 | 460 | def test_dict_source_mapping(self, state_with_mapping): 461 | """Test node with source key mapping""" 462 | node = Node( 463 | node_type="test_node", prompt_template="Test {value}", sink="output" 464 | ) 465 | 466 | result_state = node( 467 | state_with_mapping, 468 | lambda **k: "Test response", 469 | source={"value": "different_key"}, 470 | ) 471 | assert 
result_state["output"] == "Test response" 472 | 473 | def test_string_source_mapping(self, state_with_mapping): 474 | """Test node with string source mapping""" 475 | node = Node( 476 | node_type="test_node", prompt_template="Test {source}", sink="output" 477 | ) 478 | 479 | result_state = node( 480 | state_with_mapping, lambda **k: "response", source="input_key" 481 | ) 482 | assert result_state["output"] == "response" 483 | 484 | def test_invalid_source_mapping(self): 485 | """Test handling of invalid source mapping""" 486 | node = Node( 487 | node_type="test_node", prompt_template="Test {value}", sink="output" 488 | ) 489 | 490 | state = State() 491 | state["messages"] = [] 492 | 493 | with pytest.raises(ValueError): 494 | node(state, None, source={"value": "nonexistent_key"}) 495 | 496 | 497 | # VLM Integration Tests 498 | class TestVLMIntegration: 499 | class MockVLMClient(VLM_Client): 500 | def __init__(self): 501 | super().__init__() 502 | 503 | def process_images(self, messages, images, **kwargs): 504 | # Verify images are valid paths (for testing) 505 | assert all(isinstance(img, str) for img in images) 506 | return messages 507 | 508 | def __call__(self, messages, images=None, **kwargs): 509 | if images is not None: 510 | messages = self.process_images(messages, images) 511 | return "Image description response" 512 | 513 | def test_vlm_execution(self): 514 | """Test node execution with VLM client""" 515 | node = Node( 516 | node_type="test_vlm_node", 517 | prompt_template="Describe this image", 518 | sink="output", 519 | image_keys=["image_path"], 520 | ) 521 | 522 | state = State() 523 | state["messages"] = [] 524 | 525 | # Create a temporary test image file 526 | test_image_path = "test_image.jpg" 527 | with open(test_image_path, "w") as f: 528 | f.write("dummy image content") 529 | 530 | try: 531 | result_state = node(state, self.MockVLMClient(), image_path=test_image_path) 532 | assert result_state["output"] == "Image description response" 533 | finally: 534 | # Clean up the test image 535 | if os.path.exists(test_image_path): 536 | os.remove(test_image_path) 537 | 538 | def test_vlm_multiple_images(self): 539 | """Test VLM node with multiple image inputs""" 540 | node = Node( 541 | node_type="test_vlm_node", 542 | prompt_template="Describe these images", 543 | sink="output", 544 | image_keys=["image1", "image2"], 545 | ) 546 | 547 | state = State() 548 | state["messages"] = [] 549 | 550 | # Create temporary test image files 551 | test_images = ["test1.jpg", "test2.jpg"] 552 | for img_path in test_images: 553 | with open(img_path, "w") as f: 554 | f.write(f"dummy image content for {img_path}") 555 | 556 | try: 557 | result_state = node( 558 | state, 559 | self.MockVLMClient(), 560 | image1=test_images[0], 561 | image2=test_images[1], 562 | ) 563 | assert result_state["output"] == "Image description response" 564 | finally: 565 | # Clean up test images 566 | for img_path in test_images: 567 | if os.path.exists(img_path): 568 | os.remove(img_path) 569 | 570 | def test_vlm_missing_image(self): 571 | """Test VLM node execution without required image""" 572 | node = Node( 573 | node_type="test_vlm_node", 574 | prompt_template="Describe this image", 575 | sink="output", 576 | image_keys=["image_path"], 577 | ) 578 | 579 | state = State() 580 | state["messages"] = [] 581 | 582 | with pytest.raises(ValueError, match="At least one image key must be provided"): 583 | node(state, self.MockVLMClient()) 584 | 585 | def test_vlm_invalid_image_path(self): 586 | """Test VLM node with invalid 
image path""" 587 | node = Node( 588 | node_type="test_vlm_node", 589 | prompt_template="Test", 590 | sink="output", 591 | image_keys=["image"], 592 | ) 593 | 594 | state = State() 595 | state["messages"] = [] 596 | 597 | with pytest.raises(TypeError, match="should be string"): 598 | node(state, self.MockVLMClient(), image=None) 599 | 600 | 601 | # Decorator Tests 602 | class TestDecorators: 603 | def test_as_node_decorator(self): 604 | """Test @as_node decorator functionality""" 605 | 606 | @as_node(["output"]) 607 | def test_function(input_value): 608 | return f"Processed {input_value}" 609 | 610 | state = State() 611 | state["input_value"] = "test" 612 | state["messages"] = [] 613 | 614 | result_state = test_function(state, None) 615 | assert result_state["output"] == "Processed test" 616 | 617 | def test_custom_function_defaults(self): 618 | """Test node with custom function having default parameters""" 619 | 620 | def custom_func(required_param, optional_param="default"): 621 | return f"{required_param}-{optional_param}" 622 | 623 | node = Node( 624 | node_type="test_node", 625 | prompt_template="", 626 | sink="output", 627 | custom_function=custom_func, 628 | ) 629 | 630 | state = State() 631 | state["required_param"] = "value" 632 | state["messages"] = [] 633 | 634 | result_state = node(state, None) 635 | assert result_state["output"] == "value-default" 636 | 637 | def test_invalid_custom_function_args(self): 638 | """Test handling of custom function with missing required arguments""" 639 | 640 | def custom_func(required_arg): 641 | return required_arg 642 | 643 | node = Node( 644 | node_type="test_node", 645 | prompt_template="", 646 | sink="output", 647 | custom_function=custom_func, 648 | ) 649 | 650 | state = State() 651 | state["messages"] = [] 652 | 653 | with pytest.raises(ValueError, match="Required key 'required_arg' not found"): 654 | node(state, None) 655 | 656 | def test_as_node_with_multiple_sinks(self): 657 | """Test @as_node decorator with multiple output sinks""" 658 | 659 | @as_node(["output1", "output2"]) 660 | def multi_output_function(value): 661 | return [f"First {value}", f"Second {value}"] 662 | 663 | state = State() 664 | state["value"] = "test" 665 | state["messages"] = [] 666 | 667 | result_state = multi_output_function(state, None) 668 | assert result_state["output1"] == "First test" 669 | assert result_state["output2"] == "Second test" 670 | 671 | def test_as_node_with_pre_post_processing(self): 672 | """Test @as_node decorator with pre and post processing""" 673 | processed = [] 674 | 675 | def pre_process(state, client, **kwargs): 676 | processed.append("pre") 677 | return state 678 | 679 | def post_process(state, client, **kwargs): 680 | processed.append("post") 681 | return state 682 | 683 | @as_node(["output"], pre_process=pre_process, post_process=post_process) 684 | def test_function(value): 685 | processed.append("main") 686 | return f"Result: {value}" 687 | 688 | state = State() 689 | state["value"] = "test" 690 | state["messages"] = [] 691 | 692 | result_state = test_function(state, None) 693 | assert result_state["output"] == "Result: test" 694 | assert processed == ["pre", "main", "post"] 695 | 696 | def test_as_node_with_multiple_parameters(self): 697 | """Test @as_node decorator with function having multiple parameters""" 698 | 699 | @as_node(["output"]) 700 | def multi_param_function(param1, param2, param3="default"): 701 | return f"{param1}-{param2}-{param3}" 702 | 703 | state = State() 704 | state["param1"] = "value1" 705 | 
state["param2"] = "value2" 706 | state["messages"] = [] 707 | 708 | result_state = multi_param_function(state, None) 709 | assert result_state["output"] == "value1-value2-default" 710 | 711 | def test_as_node_as_function_flag(self): 712 | """Test @as_node decorator with as_function flag""" 713 | 714 | @as_node(["output"], as_function=True) 715 | def test_function(value): 716 | return f"Processed {value}" 717 | 718 | assert callable(test_function) 719 | assert hasattr(test_function, "node_type") 720 | assert hasattr(test_function, "sink") 721 | assert test_function.node_type == "test_function" 722 | assert test_function.sink == ["output"] 723 | 724 | def test_as_node_error_handling(self): 725 | """Test @as_node decorator error handling for missing parameters""" 726 | 727 | @as_node(["output"]) 728 | def error_function(required_param): 729 | return f"Value: {required_param}" 730 | 731 | state = State() 732 | state["messages"] = [] 733 | 734 | with pytest.raises(ValueError, match="Required key 'required_param' not found"): 735 | error_function(state, None) 736 | 737 | 738 | # Utility Function Tests 739 | class TestUtilityFunctions: 740 | def test_remove_markdown_blocks(self): 741 | """Test markdown block removal""" 742 | text = "```python\ndef test():\n pass\n```" 743 | result = remove_markdown_blocks_formatting(text) 744 | assert result == "def test():\n pass" 745 | -------------------------------------------------------------------------------- /nodeology/node.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024, UChicago Argonne, LLC. All rights reserved. 3 | 4 | Copyright 2024. UChicago Argonne, LLC. This software was produced 5 | under U.S. Government contract DE-AC02-06CH11357 for Argonne National 6 | Laboratory (ANL), which is operated by UChicago Argonne, LLC for the 7 | U.S. Department of Energy. The U.S. Government has rights to use, 8 | reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR 9 | UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR 10 | ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is 11 | modified to produce derivative works, such modified software should 12 | be clearly marked, so as not to confuse it with the version available 13 | from ANL. 14 | 15 | Additionally, redistribution and use in source and binary forms, with 16 | or without modification, are permitted provided that the following 17 | conditions are met: 18 | 19 | * Redistributions of source code must retain the above copyright 20 | notice, this list of conditions and the following disclaimer. 21 | 22 | * Redistributions in binary form must reproduce the above copyright 23 | notice, this list of conditions and the following disclaimer in 24 | the documentation and/or other materials provided with the 25 | distribution. 26 | 27 | * Neither the name of UChicago Argonne, LLC, Argonne National 28 | Laboratory, ANL, the U.S. Government, nor the names of its 29 | contributors may be used to endorse or promote products derived 30 | from this software without specific prior written permission. 31 | 32 | THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS 33 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 34 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 35 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL UChicago 36 | Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 37 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 38 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 40 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 41 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 42 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 43 | POSSIBILITY OF SUCH DAMAGE. 44 | """ 45 | 46 | ### Initial Author <2024>: Xiangyu Yin 47 | 48 | import os 49 | from string import Formatter 50 | from inspect import signature 51 | from typing import Optional, Annotated, List, Union, Dict, Callable, Any 52 | import ast 53 | 54 | from nodeology.state import State 55 | from nodeology.log import log_print_color 56 | from nodeology.client import LLM_Client, VLM_Client 57 | 58 | 59 | def _process_state_with_transforms( 60 | state: State, transforms: Dict[str, Callable], client: LLM_Client, **kwargs 61 | ) -> State: 62 | """Helper function to apply transforms to state values. 63 | 64 | Args: 65 | state: Current state 66 | transforms: Dictionary mapping state keys to transformation functions 67 | client: LLM client (unused but kept for signature compatibility) 68 | """ 69 | for key, transform in transforms.items(): 70 | if key in state: 71 | try: 72 | state[key] = transform(state[key]) 73 | except Exception as e: 74 | raise ValueError(f"Error applying transform to {key}: {str(e)}") 75 | return state 76 | 77 | 78 | class Node: 79 | """Template for creating node functions that process data using LLMs or custom functions. 80 | 81 | A Node represents a processing unit in a workflow that can: 82 | - Execute LLM/VLM queries or custom functions 83 | - Manage state before and after execution 84 | - Handle pre/post processing steps 85 | - Process both text and image inputs 86 | - Format and validate outputs 87 | 88 | Args: 89 | prompt_template (str): Template string for the LLM prompt. Uses Python string formatting 90 | syntax (e.g., "{variable}"). Empty if using custom_function. 91 | node_type (Optional[str]): Unique identifier for the node. 92 | sink (Optional[Union[List[str], str]]): Where to store results in state. Can be: 93 | - Single string key 94 | - List of keys for multiple outputs 95 | - None (results won't be stored) 96 | sink_format (Optional[str]): Format specification for LLM output (e.g., "json", "list"). 97 | Used to ensure consistent response structure. 98 | image_keys (Optional[List[str]]): List of keys for image file paths when using VLM. 99 | Must provide at least one image path in kwargs when these are specified. 100 | pre_process (Optional[Union[Callable, Dict[str, Callable]]]): Either a function to run 101 | before execution or a dictionary mapping state keys to transform functions. 102 | post_process (Optional[Union[Callable, Dict[str, Callable]]]): Either a function to run 103 | after execution or a dictionary mapping state keys to transform functions. 104 | sink_transform (Optional[Union[Callable, List[Callable]]]): Transform(s) to apply to 105 | sink value(s). If sink is a string, must be a single callable. If sink is a list, 106 | can be either a single callable (applied to all sinks) or a list of callables. 107 | custom_function (Optional[Callable]): Custom function to execute instead of LLM query. 108 | Function parameters become required keys for node execution. 
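        use_conversation (Optional[bool]): As the name suggests, whether the node
            should pass the accumulated conversation messages to the client
            instead of only the freshly formatted prompt. Defaults to False.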
109 | 110 | Attributes: 111 | required_keys (List[str]): Keys required from state/kwargs for node execution. 112 | Extracted from either prompt_template or custom_function signature. 113 | prompt_history (List[str]): History of prompt templates used by this node. 114 | 115 | Raises: 116 | ValueError: If required keys are missing or response format is invalid 117 | FileNotFoundError: If specified image files don't exist 118 | ValueError: If VLM operations are attempted without proper client 119 | 120 | Example: 121 | ```python 122 | # Create a simple text processing node 123 | node = Node( 124 | node_type="summarizer", 125 | prompt_template="Summarize this text: {text}", 126 | sink="summary" 127 | ) 128 | 129 | # Create a node with custom function 130 | def process_data(x, y): 131 | return x + y 132 | 133 | node = Node( 134 | node_type="calculator", 135 | prompt_template="", 136 | sink="result", 137 | custom_function=process_data 138 | ) 139 | ``` 140 | """ 141 | 142 | # Simplified set of allowed functions that return values 143 | ALLOWED_FUNCTIONS = { 144 | "len": len, # Length of sequences 145 | "str": str, # String conversion 146 | "int": int, # Integer conversion 147 | "float": float, # Float conversion 148 | "sum": sum, # Sum of numbers 149 | "max": max, # Maximum value 150 | "min": min, # Minimum value 151 | "abs": abs, # Absolute value 152 | } 153 | 154 | DISALLOWED_FUNCTION_NAMES = [ 155 | "eval", 156 | "exec", 157 | "compile", 158 | "open", 159 | "print", 160 | "execfile", 161 | "exit", 162 | "quit", 163 | "help", 164 | "dir", 165 | "globals", 166 | "locals", 167 | "dir", 168 | "type", 169 | "hash", 170 | "repr", 171 | "filter", 172 | "enumerate", 173 | "reversed", 174 | "sorted", 175 | "any", 176 | "all", 177 | ] 178 | 179 | # String methods that return values 180 | ALLOWED_STRING_METHODS = { 181 | "upper": str.upper, 182 | "lower": str.lower, 183 | "strip": str.strip, 184 | "capitalize": str.capitalize, 185 | } 186 | 187 | def __init__( 188 | self, 189 | prompt_template: str, 190 | node_type: Optional[str] = None, 191 | sink: Optional[Union[List[str], str]] = None, 192 | sink_format: Optional[str] = None, 193 | image_keys: Optional[List[str]] = None, 194 | pre_process: Optional[ 195 | Union[ 196 | Callable[[State, LLM_Client, Any], Optional[State]], Dict[str, Callable] 197 | ] 198 | ] = None, 199 | post_process: Optional[ 200 | Union[ 201 | Callable[[State, LLM_Client, Any], Optional[State]], Dict[str, Callable] 202 | ] 203 | ] = None, 204 | sink_transform: Optional[Union[Callable, List[Callable]]] = None, 205 | custom_function: Optional[Callable[..., Any]] = None, 206 | use_conversation: Optional[bool] = False, 207 | ): 208 | # Set default node_type based on whether it's prompt or function-based 209 | if node_type is None: 210 | if custom_function: 211 | self.node_type = custom_function.__name__ 212 | else: 213 | self.node_type = "prompt" 214 | else: 215 | self.node_type = node_type 216 | 217 | self.prompt_template = prompt_template 218 | self._escaped_sections = [] # Store escaped sections at instance level 219 | self.sink = sink 220 | self.image_keys = image_keys 221 | self.sink_format = sink_format 222 | self.custom_function = custom_function 223 | self.use_conversation = use_conversation 224 | 225 | # Handle pre_process 226 | if isinstance(pre_process, dict): 227 | transforms = pre_process 228 | self.pre_process = ( 229 | lambda state, client, **kwargs: _process_state_with_transforms( 230 | state, transforms, client, **kwargs 231 | ) 232 | ) 233 | else: 234 | 
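            # Illustrative usage (hypothetical node): pre_process={"text": str.strip}
            # takes the dict branch above, so state["text"] is stripped via
            # _process_state_with_transforms before the node runs; a plain callable
            # reaches this else branch and is stored to be invoked as-is with
            # (state, client, **kwargs).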
self.pre_process = pre_process 235 | 236 | # Handle post_process 237 | if isinstance(post_process, dict): 238 | transforms = post_process 239 | self.post_process = ( 240 | lambda state, client, **kwargs: _process_state_with_transforms( 241 | state, transforms, client, **kwargs 242 | ) 243 | ) 244 | else: 245 | self.post_process = post_process 246 | 247 | # Handle sink_transform 248 | if sink_transform is not None: 249 | if isinstance(sink, str): 250 | if not callable(sink_transform): 251 | raise ValueError( 252 | "sink_transform must be callable when sink is a string" 253 | ) 254 | self._sink_transform = sink_transform 255 | elif isinstance(sink, list): 256 | if callable(sink_transform): 257 | # If single transform provided for multiple sinks, apply it to all 258 | self._sink_transform = [sink_transform] * len(sink) 259 | elif len(sink_transform) != len(sink): 260 | raise ValueError("Number of transforms must match number of sinks") 261 | else: 262 | self._sink_transform = sink_transform 263 | else: 264 | raise ValueError("sink must be specified to use sink_transform") 265 | else: 266 | self._sink_transform = None 267 | 268 | # Extract required keys from template or custom function signature 269 | if self.custom_function: 270 | # Get only required keys (those without default values) from function signature 271 | sig = signature(self.custom_function) 272 | self.required_keys = [ 273 | param.name 274 | for param in sig.parameters.values() 275 | if param.default is param.empty 276 | and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) 277 | and param.name != "self" 278 | ] 279 | else: 280 | # Extract base variable names from expressions, excluding function names and escaped content 281 | self.required_keys = [] 282 | # First, temporarily replace escaped content 283 | template = prompt_template 284 | 285 | # Replace {{{ }}} sections first 286 | import re 287 | 288 | triple_brace_pattern = ( 289 | r"\{{3}[\s\S]*?\}{3}" # Non-greedy match, including newlines 290 | ) 291 | for i, match in enumerate(re.finditer(triple_brace_pattern, template)): 292 | placeholder = f"___ESCAPED_TRIPLE_{i}___" 293 | self._escaped_sections.append((placeholder, match.group(0))) 294 | template = template.replace(match.group(0), placeholder) 295 | 296 | # Then replace {{ }} sections 297 | double_brace_pattern = ( 298 | r"\{{2}[\s\S]*?\}{2}" # Non-greedy match, including newlines 299 | ) 300 | for i, match in enumerate(re.finditer(double_brace_pattern, template)): 301 | placeholder = f"___ESCAPED_DOUBLE_{i}___" 302 | self._escaped_sections.append((placeholder, match.group(0))) 303 | template = template.replace(match.group(0), placeholder) 304 | 305 | self._template_with_placeholders = template # Store modified template 306 | 307 | # Now parse the template normally 308 | for _, expr, _, _ in Formatter().parse(template): 309 | if expr is not None: 310 | # Parse the expression to identify actual variables 311 | try: 312 | tree = ast.parse(expr, mode="eval") 313 | variables = set() 314 | for node in ast.walk(tree): 315 | if ( 316 | isinstance(node, ast.Name) 317 | and node.id not in self.ALLOWED_FUNCTIONS 318 | and node.id not in self.DISALLOWED_FUNCTION_NAMES 319 | ): 320 | variables.add(node.id) 321 | self.required_keys.extend(variables) 322 | except SyntaxError: 323 | # If parsing fails, fall back to basic extraction 324 | base_var = expr.split("[")[0].split(".")[0].split("(")[0] 325 | if ( 326 | base_var not in self.ALLOWED_FUNCTIONS 327 | and base_var not in self.DISALLOWED_FUNCTION_NAMES 328 | and base_var 
311 |             # Now parse the template normally
312 |             for _, expr, _, _ in Formatter().parse(template):
313 |                 if expr is not None:
314 |                     # Parse the expression to identify actual variables
315 |                     try:
316 |                         tree = ast.parse(expr, mode="eval")
317 |                         variables = set()
318 |                         for node in ast.walk(tree):
319 |                             if (
320 |                                 isinstance(node, ast.Name)
321 |                                 and node.id not in self.ALLOWED_FUNCTIONS
322 |                                 and node.id not in self.DISALLOWED_FUNCTION_NAMES
323 |                             ):
324 |                                 variables.add(node.id)
325 |                         self.required_keys.extend(variables)
326 |                     except SyntaxError:
327 |                         # If parsing fails, fall back to basic extraction
328 |                         base_var = expr.split("[")[0].split(".")[0].split("(")[0]
329 |                         if (
330 |                             base_var not in self.ALLOWED_FUNCTIONS
331 |                             and base_var not in self.DISALLOWED_FUNCTION_NAMES
332 |                             and base_var not in self.required_keys
333 |                         ):
334 |                             self.required_keys.append(base_var)
335 | 
336 |         # Remove duplicates while preserving order
337 |         self.required_keys = list(dict.fromkeys(self.required_keys))
338 | 
339 |         self._prompt_history = [
340 |             prompt_template
341 |         ]  # Add prompt history as a private attribute
342 | 
343 |     def _eval_expr(self, expr: str, context: dict) -> Any:
344 |         """Safely evaluate a Python expression with limited scope."""
345 |         try:
346 |             # Add allowed functions to the context
347 |             eval_context = {
348 |                 **context,
349 |                 **self.ALLOWED_FUNCTIONS,  # Include built-in functions
350 |             }
351 |             tree = ast.parse(expr, mode="eval")
352 |             return self._eval_node(tree.body, eval_context)
353 |         except SyntaxError as e:
354 |             raise ValueError(f"Invalid Python syntax in expression: {str(e)}")
355 |         except Exception as e:
356 |             raise ValueError(f"Invalid expression: {str(e)}")
357 | 
358 |     def _eval_node(self, node: ast.AST, context: dict) -> Any:
359 |         """Recursively evaluate an AST node with security constraints."""
360 |         if isinstance(node, ast.Name):
361 |             if node.id not in context:
362 |                 raise ValueError(f"Variable '{node.id}' not found in context")
363 |             return context[node.id]
364 | 
365 |         elif isinstance(node, ast.Constant):
366 |             return node.value
367 | 
368 |         elif isinstance(node, ast.UnaryOp):  # Support unary operations
369 |             if isinstance(node.op, ast.USub):  # Handle negative numbers
370 |                 operand = self._eval_node(node.operand, context)
371 |                 return -operand
372 |             raise ValueError(f"Unsupported unary operator: {type(node.op).__name__}")
373 | 
374 |         elif isinstance(node, ast.Call):
375 |             if isinstance(node.func, ast.Attribute):  # Support method calls
376 |                 obj = self._eval_node(node.func.value, context)
377 |                 method_name = node.func.attr
378 |                 # Use the class-level whitelist so method calls and bare
379 |                 # attribute access permit the same set of string methods
380 |                 if method_name in self.ALLOWED_STRING_METHODS:
381 |                     method = getattr(obj, method_name)
382 |                     args = [self._eval_node(arg, context) for arg in node.args]
383 |                     return method(*args)
384 |                 raise ValueError(f"String method not allowed: {method_name}")
385 |             elif isinstance(node.func, ast.Name):
386 |                 func_name = node.func.id
387 |                 if func_name not in self.ALLOWED_FUNCTIONS:
388 |                     raise ValueError(f"Function not allowed: {func_name}")
389 |                 func = context[
390 |                     func_name
391 |                 ]  # Get function from context instead of globals
392 |                 args = [self._eval_node(arg, context) for arg in node.args]
393 |                 return func(*args)
394 |             raise ValueError("Only simple function calls are allowed")
395 | 
396 |         elif isinstance(node, ast.Attribute):
397 |             # Handle string methods (e.g., text.upper())
398 |             if not isinstance(node.value, ast.Name):
399 |                 raise ValueError("Only simple string methods are allowed")
400 | 
401 |             obj = self._eval_node(node.value, context)
402 |             if not isinstance(obj, str):
403 |                 raise ValueError("Methods are only allowed on strings")
404 | 
405 |             method_name = node.attr
406 |             if method_name not in self.ALLOWED_STRING_METHODS:
407 |                 raise ValueError(f"String method not allowed: {method_name}")
408 | 
409 |             return self.ALLOWED_STRING_METHODS[method_name](obj)
410 | 
411 |         elif isinstance(node, (ast.List, ast.Tuple)):
412 |             return [self._eval_node(elt, context) for elt in node.elts]
413 | 
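414 |         # Subscript access below supports slices (value[a:b:c]) as well as
415 |         # plain indices and string keys (value[0], value["key"])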
416 |         elif isinstance(node, ast.Subscript):
417 |             value = self._eval_node(node.value, context)
418 |             if isinstance(node.slice, ast.Slice):
419 |                 lower = (
420 |                     self._eval_node(node.slice.lower, context)
421 |                     if node.slice.lower
422 |                     else None
423 |                 )
424 |                 upper = (
425 |                     self._eval_node(node.slice.upper, context)
426 |                     if node.slice.upper
427 |                     else None
428 |                 )
429 |                 step = (
430 |                     self._eval_node(node.slice.step, context)
431 |                     if node.slice.step
432 |                     else None
433 |                 )
434 |                 return value[slice(lower, upper, step)]
435 |             else:
436 |                 # Handle both numeric indices and string keys
437 |                 idx = self._eval_node(node.slice, context)
438 |                 try:
439 |                     return value[idx]
440 |                 except (TypeError, KeyError) as e:
441 |                     raise ValueError(f"Invalid subscript access: {str(e)}")
442 | 
443 |         # String literals are covered by the ast.Constant branch above; the
444 |         # deprecated ast.Str node was removed in Python 3.12, so it is not
445 |         # checked here
446 |         raise ValueError(f"Unsupported expression type: {type(node).__name__}")
447 | 
448 |     @property
449 |     def func(self):
450 |         """Returns the node function without executing it"""
451 | 
452 |         def node_function(
453 |             state: Annotated[State, "The current state"],
454 |             client: Annotated[LLM_Client, "The LLM client"],
455 |             sink: Optional[Union[List[str], str]] = None,
456 |             source: Optional[Dict[str, str]] = None,
457 |             **kwargs,
458 |         ) -> State:
459 |             return self(state, client, sink, source, **kwargs)
460 | 
461 |         # Attach the attributes to the function
462 |         node_function.node_type = self.node_type
463 |         node_function.prompt_template = self.prompt_template
464 |         node_function.sink = self.sink
465 |         node_function.image_keys = self.image_keys
466 |         node_function.sink_format = self.sink_format
467 |         node_function.pre_process = self.pre_process
468 |         node_function.post_process = self.post_process
469 |         node_function.required_keys = self.required_keys
470 |         return node_function
471 | 
472 |     def __call__(
473 |         self,
474 |         state: State,
475 |         client: Union[LLM_Client, VLM_Client],
476 |         sink: Optional[Union[List[str], str]] = None,
477 |         source: Optional[Union[Dict[str, str], str]] = None,
478 |         debug: bool = False,
479 |         use_conversation: Optional[bool] = None,
480 |         **kwargs,
481 |     ) -> State:
482 |         """Executes the node, filling the prompt template or calling the custom function.
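483 | 
484 |         Values for required keys are resolved in order: an explicit source
485 |         mapping first, then the state itself, then kwargs. Prompt expressions
486 |         are evaluated with the restricted evaluator before the client is called.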
487 | 
488 |         Args:
489 |             state: Current state object containing variables
490 |             client: LLM or VLM client for making API calls
491 |             sink: Optional override for where to store results
492 |             source: Optional mapping of template keys to state keys
493 |             debug: Debug flag (currently unused by the base implementation)
494 |             use_conversation: Optional override for conversation mode
495 |             **kwargs: Additional keyword arguments passed to function
496 | 
497 |         Returns:
498 |             Updated state object with results stored in sink keys
499 | 
500 |         Raises:
501 |             ValueError: If required keys are missing or response format is invalid
502 |             FileNotFoundError: If specified image files don't exist
503 |         """
504 |         # Update node type
505 |         state["previous_node_type"] = state.get("current_node_type", "")
506 |         state["current_node_type"] = self.node_type
507 | 
508 |         # Pre-processing if defined
509 |         if self.pre_process:
510 |             pre_process_result = self.pre_process(state, client, **kwargs)
511 |             if pre_process_result is None:
512 |                 return state
513 |             state = pre_process_result
514 | 
515 |         # Get values from state or kwargs
516 |         if isinstance(source, str):
517 |             source = {"source": source}
518 | 
519 |         message_values = {}
520 |         for key in self.required_keys:
521 |             if source and key in source:
522 |                 source_key = source[key]
523 |                 if source_key not in state:
524 |                     raise ValueError(
525 |                         f"Source mapping key '{source_key}' not found in state"
526 |                     )
527 |                 message_values[key] = state[source_key]
528 |             elif key in state:
529 |                 message_values[key] = state[key]
530 |             elif key in kwargs:
531 |                 message_values[key] = kwargs[key]
532 |             else:
533 |                 raise ValueError(f"Required key '{key}' not found in state or kwargs")
534 | 
535 |         # Execute either custom function or LLM call
536 |         if self.custom_function:
537 |             # Get default values from function signature
538 |             sig = signature(self.custom_function)
539 |             default_values = {
540 |                 k: v.default
541 |                 for k, v in sig.parameters.items()
542 |                 if v.default is not v.empty
543 |             }
544 |             # Update message_values with defaults for missing parameters
545 |             for key, default in default_values.items():
546 |                 if key not in message_values:
547 |                     message_values[key] = default
548 |             if "state" in sig.parameters and "state" not in message_values:
549 |                 message_values["state"] = state
550 |             if "client" in sig.parameters and "client" not in message_values:
551 |                 message_values["client"] = client
552 |             response = self.custom_function(**message_values)
553 |         else:
554 |             # Create a context with state variables for expression evaluation
555 |             eval_context = {**message_values}
556 | 
557 |             # First fill the template with placeholders
558 |             message = self._template_with_placeholders
559 |             for _, expr, _, _ in Formatter().parse(self._template_with_placeholders):
560 |                 if expr is not None:
561 |                     try:
562 |                         result = self._eval_expr(expr, eval_context)
563 |                         message = message.replace(f"{{{expr}}}", str(result))
564 |                     except Exception as e:
565 |                         raise ValueError(
566 |                             f"Error evaluating expression '{expr}': {str(e)}"
567 |                         )
568 | 
569 |             # Now restore the escaped sections
570 |             for placeholder, original in self._escaped_sections:
571 |                 message = message.replace(placeholder, original)
572 | 
573 |             # Record the formatted message
574 |             if "messages" not in state:
575 |                 state["messages"] = []
576 |             state["messages"].append({"role": "user", "content": message})
577 | 
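578 |             # Conversation mode keeps a running message list in
579 |             # state["conversation"] so successive calls share context; a
580 |             # per-call override takes precedence over the node-level default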
state["conversation"], list 575 | ), "Conversation does not exist in state or is not a list of messages" 576 | 577 | # Prepare messages for client call 578 | if should_use_conversation: 579 | if len(state["conversation"]) == 0 or state["end_conversation"]: 580 | state["conversation"].append({"role": "user", "content": message}) 581 | messages = state["conversation"] 582 | else: 583 | messages = [{"role": "user", "content": message}] 584 | 585 | # Handle VLM specific requirements 586 | if self.image_keys: 587 | if not isinstance(client, VLM_Client): 588 | raise ValueError("VLM client required for image keys") 589 | 590 | # Check both state and kwargs for image keys 591 | image_paths = [] 592 | for key in self.image_keys: 593 | if key in state: 594 | path = state[key] 595 | elif key in kwargs: 596 | path = kwargs[key] 597 | else: 598 | continue 599 | 600 | if path is None: 601 | raise TypeError( 602 | f"Image path for '{key}' should be string, got None" 603 | ) 604 | if not isinstance(path, str): 605 | raise TypeError( 606 | f"Image path for '{key}' should be string, got {type(path)}" 607 | ) 608 | image_paths.append(path) 609 | 610 | if not image_paths: 611 | raise ValueError( 612 | "At least one image key must be provided in state or kwargs" 613 | ) 614 | 615 | # Verify all paths exist 616 | for path in image_paths: 617 | if not os.path.exists(path): 618 | raise FileNotFoundError(f"Image file not found: {path}") 619 | 620 | response = client( 621 | messages=messages, 622 | images=image_paths, 623 | format=self.sink_format, 624 | workflow=kwargs.get("workflow"), 625 | node=self, 626 | previous_node_type=state["previous_node_type"], 627 | ) 628 | else: 629 | response = client( 630 | messages=messages, 631 | format=self.sink_format, 632 | workflow=kwargs.get("workflow"), 633 | node=self, 634 | previous_node_type=state["previous_node_type"], 635 | ) 636 | 637 | log_print_color(f"Response: {response}", "white", False) 638 | 639 | # Update state with response 640 | if sink is None: 641 | sink = self.sink 642 | 643 | if sink is None: 644 | log_print_color( 645 | f"Warning: No sink specified for {self.node_type} node", "yellow" 646 | ) 647 | return state 648 | 649 | if isinstance(sink, str): 650 | state[sink] = ( 651 | remove_markdown_blocks_formatting(response) 652 | if not self.custom_function 653 | else response 654 | ) 655 | elif isinstance(sink, list): 656 | if not sink: 657 | log_print_color( 658 | f"Warning: Empty sink list for {self.node_type} node", "yellow" 659 | ) 660 | return state 661 | 662 | if len(sink) == 1: 663 | state[sink[0]] = ( 664 | remove_markdown_blocks_formatting(response) 665 | if not self.custom_function 666 | else response 667 | ) 668 | else: 669 | if not isinstance(response, (list, tuple)): 670 | raise ValueError( 671 | f"Expected multiple responses for multiple sink in {self.node_type} node, but got a single response" 672 | ) 673 | if len(response) != len(sink): 674 | raise ValueError( 675 | f"Number of responses ({len(response)}) doesn't match number of sink ({len(sink)}) in {self.node_type} node" 676 | ) 677 | 678 | for key, value in zip(sink, response): 679 | state[key] = ( 680 | remove_markdown_blocks_formatting(value) 681 | if not self.custom_function 682 | else value 683 | ) 684 | 685 | # After storing results but before post_process, apply sink transforms 686 | if self._sink_transform is not None: 687 | current_sink = sink or self.sink 688 | if isinstance(current_sink, str): 689 | state[current_sink] = self._sink_transform(state[current_sink]) 690 | else: 691 | 
707 |                 for key, transform in zip(current_sink, self._sink_transform):
708 |                     state[key] = transform(state[key])
709 | 
710 |         # Post-processing if defined
711 |         if self.post_process:
712 |             post_process_result = self.post_process(state, client, **kwargs)
713 |             if post_process_result is None:
714 |                 return state
715 |             state = post_process_result
716 | 
717 |         return state
718 | 
719 |     def __str__(self):
720 |         MAX_WIDTH = 80
721 | 
722 |         # Measure width from the raw prompt before adding ANSI colors, since
723 |         # escape sequences would otherwise inflate the measured line lengths
724 |         raw_lines = self.prompt_template.split("\n")
725 |         width = min(max(len(line) for line in raw_lines), MAX_WIDTH)
726 |         double_line = "═" * width
727 |         horizontal_line = "─" * width
728 | 
729 |         # Format prompt with highlighted keys: first make the whole prompt green
730 |         prompt_lines = [f"\033[92m{line}\033[0m" for line in raw_lines]  # Green
731 |         # Then highlight the keys in red
732 |         for key in self.required_keys:
733 |             for i, line in enumerate(prompt_lines):
734 |                 prompt_lines[i] = line.replace(
735 |                     f"{{{key}}}",
736 |                     f"\033[91m{{{key}}}\033[0m\033[92m",  # Red keys, return to green after
737 |                 )
738 | 
739 |         # Color formatting for keys in info section
740 |         required_keys_colored = [
741 |             f"\033[91m{key}\033[0m" for key in self.required_keys
742 |         ]  # Red
743 |         if isinstance(self.sink, str):
744 |             sink_colored = [f"\033[94m{self.sink}\033[0m"]  # Blue
745 |         elif isinstance(self.sink, list):
746 |             sink_colored = [f"\033[94m{key}\033[0m" for key in self.sink]  # Blue
747 |         else:
748 |             sink_colored = ["None"]
749 | 
750 |         # Build the string representation
751 |         result = [
752 |             double_line,
753 |             f"{self.node_type}",
754 |             horizontal_line,
755 |             *prompt_lines,
756 |             horizontal_line,
757 |             f"Required keys: {', '.join(required_keys_colored)}",
758 |             f"Sink keys: {', '.join(sink_colored)}",
759 |             f"Format: {self.sink_format or 'None'}",
760 |             f"Image keys: {', '.join(self.image_keys) if self.image_keys else 'None'}",
761 |             f"Pre-process: {self.pre_process.__name__ if self.pre_process else 'None'}",
762 |             f"Post-process: {self.post_process.__name__ if self.post_process else 'None'}",
763 |             f"Custom function: {self.custom_function.__name__ if self.custom_function else 'None'}",
764 |         ]
765 | 
766 |         return "\n".join(result)
767 | 
768 |     @property
769 |     def prompt_history(self) -> list[str]:
770 |         """Returns the history of prompt templates.
771 | 
772 |         Returns:
773 |             list[str]: List of prompt templates, oldest to newest
774 |         """
775 |         return self._prompt_history.copy()
776 | 
777 | 
778 | def as_node(
779 |     sink: List[str],
780 |     pre_process: Optional[Callable[[State, LLM_Client, Any], Optional[State]]] = None,
781 |     post_process: Optional[Callable[[State, LLM_Client, Any], Optional[State]]] = None,
782 |     as_function: bool = False,
783 | ):
784 |     """Decorator to transform a regular Python function into a Node function.
785 | 
786 |     This decorator allows you to convert standard Python functions into Node objects
787 |     that can be integrated into a nodeology workflow. The decorated function becomes
788 |     the custom_function of the Node, with its parameters becoming required keys.
789 | 
790 |     Args:
791 |         sink (List[str]): List of state keys where the function's results will be stored.
792 |             The number of sink keys should match the number of return values from the function.
793 |         pre_process (Optional[Callable]): Function to run before main execution.
794 |             Signature: (state: State, client: LLM_Client, **kwargs) -> Optional[State]
795 |         post_process (Optional[Callable]): Function to run after main execution.
796 |             Signature: (state: State, client: LLM_Client, **kwargs) -> Optional[State]
797 |         as_function (bool): If True, returns a callable node function. If False, returns
798 |             the Node object itself. Default is False.
799 | 
800 |     Returns:
801 |         Union[Node, Callable]: Either a Node object or a node function, depending on
802 |             the as_function parameter.
803 | 
804 |     Example:
805 |         ```python
806 |         # Basic usage
807 |         @as_node(sink=["result"])
808 |         def multiply(x: int, y: int) -> int:
809 |             return x * y
810 | 
811 |         # With pre and post processing
812 |         def log_start(state, client, **kwargs):
813 |             print("Starting calculation...")
814 |             return state
815 | 
816 |         def log_result(state, client, **kwargs):
817 |             print(f"Result: {state['result']}")
818 |             return state
819 | 
820 |         @as_node(
821 |             sink=["result"],
822 |             pre_process=log_start,
823 |             post_process=log_result
824 |         )
825 |         def add(x: int, y: int) -> int:
826 |             return x + y
827 | 
828 |         # Multiple return values
829 |         @as_node(sink=["mean", "std"])
830 |         def calculate_stats(numbers: List[float]) -> Tuple[float, float]:
831 |             return np.mean(numbers), np.std(numbers)
832 |         ```
833 | 
834 |     Notes:
835 |         - The decorated function's parameters become required keys for node execution
836 |         - The function can access the state and client objects by including them
837 |           as optional parameters
838 |         - The number of sink keys should match the number of return values
839 |         - When as_function=True, the decorator returns a callable that can be used
840 |           directly in workflows
841 |     """
842 | 
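843 |     # The wrapped function becomes the Node's custom_function; required keys
844 |     # are re-derived from its signature below, mirroring Node.__init__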
845 |     def decorator(func):
846 |         # Create a Node instance with the custom function
847 |         node = Node(
848 |             prompt_template="",  # Empty template since we're using custom function
849 |             node_type=func.__name__,
850 |             sink=sink,
851 |             pre_process=pre_process,
852 |             post_process=post_process,
853 |             custom_function=func,  # Pass the function to Node
854 |         )
855 | 
856 |         # Get only required parameters (those without default values), using
857 |         # the same exclusions as Node.__init__
858 |         sig = signature(func)
859 |         node.required_keys = [
860 |             param.name
861 |             for param in sig.parameters.values()
862 |             if param.default is param.empty
863 |             and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
864 |             and param.name != "self"
865 |         ]
866 | 
867 |         return node.func if as_function else node
868 | 
869 |     return decorator
870 | 
871 | 
872 | def remove_markdown_blocks_formatting(text: str) -> str:
873 |     """Remove common markdown code block delimiters from text.
874 | 
875 |     Args:
876 |         text: Input text containing markdown code blocks
877 | 
878 |     Returns:
879 |         str: Text with code block delimiters removed
880 |     """
881 |     lines = text.split("\n")
882 |     cleaned_lines = []
883 | 
884 |     for line in lines:
885 |         # Skip code fence delimiters; startswith is more robust than exact matches
886 |         if line.strip().startswith("```"):
887 |             continue
888 |         cleaned_lines.append(line)
889 | 
890 |     return "\n".join(cleaned_lines)
891 | 
--------------------------------------------------------------------------------