├── LICENSE ├── CMakeLists.txt ├── .gitignore ├── README.md ├── main.cpp └── pocketflow.h /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Zachary Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) # C++17 (CMAKE_CXX_STANDARD 17 / cxx_std_17) needs CMake >= 3.8; 3.10 is the declared minimum 2 | 3 | project(PocketFlowCpp VERSION 1.0.0 LANGUAGES CXX) 4 | 5 | # Require C++17 for std::any, std::optional 6 | set(CMAKE_CXX_STANDARD 17) 7 | set(CMAKE_CXX_STANDARD_REQUIRED True) 8 | set(CMAKE_CXX_EXTENSIONS OFF) # Prefer standard C++ 9 | 10 | # --- Library Target (Optional: expose the header-only library to other targets) --- 11 | # add_library(pocketflow INTERFACE) # pocketflow.h is header-only, so an INTERFACE target (no sources) is enough 12 | # target_include_directories(pocketflow INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) 13 | # target_compile_features(pocketflow INTERFACE cxx_std_17) # Consumers are compiled as C++17 too 14 | 15 | # --- Executable Example --- 16 | add_executable(pocketflow_example main.cpp) 17 | 18 | # Link the executable to the INTERFACE target instead, if it is enabled above: 19 | # target_link_libraries(pocketflow_example PRIVATE pocketflow) 20 | 21 | # If header-only, just need to ensure includes are found and C++17 is used 22 | target_include_directories(pocketflow_example PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) # PRIVATE: an executable has no consumers 23 | target_compile_features(pocketflow_example PRIVATE cxx_std_17) # Ensure C++17 for the executable 24 | 25 | # --- Testing (Example using GoogleTest - requires GTest setup) --- 26 | # enable_testing() 27 | # find_package(GTest REQUIRED) 28 | # add_executable(pocketflow_tests PocketFlow_test.cpp) # Your test file 29 | # target_link_libraries(pocketflow_tests PRIVATE GTest::gtest GTest::gtest_main) 30 | # target_include_directories(pocketflow_tests PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) 31 | # target_compile_features(pocketflow_tests PRIVATE cxx_std_17) 32 | # include(GoogleTest) 33 | # gtest_discover_tests(pocketflow_tests) -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # CMake generated files and directories 2 | /[Bb]uild/ 3 | /build-*/ 4 | /out/ 5 | /install/ 6 | CMakeCache.txt 7 | CMakeFiles/ 8 | CMakeScripts/ 9 | cmake_install.cmake 10 | CTestTestfile.cmake 11 | CTestConfig.cmake 12 | install_manifest.txt 13 | compile_commands.json 14 | Makefile 15 | *.make 16 | 17 | # Object files, libraries, executables, preprocessed C (*.i) and assembly (*.s) output 18 | *.o 19 | *.obj 20 | *.lo 21 | *.la 22 | *.a 23 | *.lib 24 | *.so 25 | *.so.* 26 | *.dylib 27 | *.dll 28 | *.exe 29 | *.out 30 | *.app 31 | *.i 32 | *.s 33 | 34 | # Precompiled Headers 35 | *.gch 36 | *.pch 37 | 38 | # Dependency files 39 | *.d 40 | 41 | # ========================================== 42 | # IDE / Editor specific files 43 | # ========================================== 44 | 45 | # Visual Studio Code 46 | .vscode/ 47 | *.code-workspace 48 | 49 | # Visual Studio 50 | .vs/ 51 | *.sln.docstates 52 | *.suo 53 | *.user 54 | *.vcxproj.user 55 | *.vcxproj.filters 56 | ipch/ 57 | obj/ 58 | bin/ 59 | [Dd]ebug/ 60 | [Rr]elease/ 61 | x64/ 62 | Win32/ 63 | ARM/ 64 | ARM64/ 65 | 66 | # CLion 67 | .idea/ 68 | cmake-build-*/ 69 | 70 | # Qt Creator 71 | *.pro.user 72 | *.pro.user.* 73 | *.qbs.user 74 | *.qbs.user.* 75 | 76 | # Xcode 77 | *.xcodeproj/ 78 | *.xcworkspace/ 79 | xcuserdata/ 80 | *.xccheckout 81 | 82 | # Eclipse CDT 83 | .cproject 84 | .project 85 | .settings/ 86 | 87 | # NetBeans 88 | nbproject/ 89 | 90 | # Sublime Text 91 | *.sublime-project 92 | *.sublime-workspace 93 | 94 | # Vim / Emacs / TextMate / etc. 
95 | *.swp 96 | *~ 97 | .*.swp 98 | .*.swo 99 | Session.vim 100 | .DS_Store 101 | *.tm_properties 102 | *.bak 103 | *.orig 104 | 105 | # ========================================== 106 | # Operating System specific files 107 | # ========================================== 108 | 109 | # macOS 110 | .DS_Store 111 | .AppleDouble 112 | .LSOverride 113 | ._* 114 | .Spotlight-V100 115 | .Trashes 116 | Network Trash Folder 117 | Temporary Items 118 | Icon? 119 | 120 | # Windows 121 | Thumbs.db 122 | ehthumbs.db 123 | Desktop.ini 124 | $RECYCLE.BIN/ 125 | *.lnk 126 | 127 | # Linux 128 | *~ 129 | .directory 130 | 131 | # ========================================== 132 | # Testing / Coverage / Logging (incl. JUnit-style XML test reports, if generated) 133 | # ========================================== 134 | Testing/ 135 | coverage/ 136 | *.gcno 137 | *.gcda 138 | *.gcov 139 | *.lcov 140 | *.profraw 141 | *.log 142 | junit*.xml 143 | 144 | # ========================================== 145 | # Dependency Management (add if using) 146 | # ========================================== 147 | # Conan 148 | # conanbuildinfo.* 149 | # conaninfo.txt 150 | # _build/ 151 | 152 | # vcpkg 153 | # vcpkg_installed/ 154 | 155 | # ========================================== 156 | # Project Specific (adjust as needed) 157 | # ========================================== 158 | # Example executable name from main.cpp 159 | /pocketflow_example 160 | /pocketflow_example.exe 161 | 162 | # Any local configuration files 163 | # config.local 164 | 165 | # Any generated documentation you don't commit 166 | # docs/html 167 | # docs/latex 168 | 169 | # Package files (if built locally) 170 | # *.tar.gz 171 | # *.zip 172 | # *.deb 173 | # *.rpm 174 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PocketFlow C++ 2 | 3 | A minimalist LLM framework, ported from Python to C++. 
4 | 5 | ## Overview 6 | 7 | PocketFlow C++ is a port of the original [Python PocketFlow](https://github.com/The-Pocket/PocketFlow). It provides a lightweight, flexible system for building and executing workflows through a simple node-based architecture using modern C++. 8 | 9 | > **Note:** This is an initial implementation ported from the Python version. It currently only supports synchronous operations. Community contributions are welcome to enhance and maintain this project, particularly in adding asynchronous capabilities and comprehensive testing. 10 | 11 | ## Features 12 | 13 | * **Node-Based Architecture:** Define workflows by connecting distinct processing units (nodes). 14 | * **Type-Safe (within C++ limits):** Uses C++ templates for node input/output types. `std::any` is used for flexible context and parameters. 15 | * **Synchronous Execution:** Simple, predictable execution flow (async is a future goal). 16 | * **Context Propagation:** Share data between nodes using a `Context` map (`std::map<std::string, std::any>`). 17 | * **Configurable Nodes:** Pass parameters to nodes using a `Params` map (also `std::map<std::string, std::any>`). 18 | * **Retry Logic:** Built-in optional retry mechanism for `Node` operations. 19 | * **Batch Processing:** Includes `BatchNode` and `BatchFlow` for processing lists of items or parameter sets. 20 | * **Header-Only:** The core library is provided in `pocketflow.h` for easy integration. 21 | 22 | ## Requirements 23 | 24 | * **C++17-Compliant Compiler:** Required for `std::any` and `std::optional`. (e.g., GCC 7+, Clang 5+, MSVC 19.14+) 25 | * **CMake:** Version 3.10 or higher (for C++17 support). 26 | 27 | ## Building 28 | 29 | The library itself is header-only (`pocketflow.h`). To build the example provided (`main.cpp`): 30 | 31 | 1. Ensure you have CMake and a C++17 compiler installed. 32 | 2. Create a build directory: 33 | ```bash 34 | mkdir build 35 | cd build 36 | ``` 37 | 3. Run CMake to configure the project: 38 | ```bash 39 | cmake .. 40 | ``` 41 | 4. 
Build the executable: 42 | ```bash 43 | cmake --build . 44 | # Or use make, ninja, etc. depending on your generator 45 | # make 46 | ``` 47 | 5. The example executable (e.g., `pocketflow_example`) will be inside the `build` directory. 48 | ```bash 49 | ./pocketflow_example 50 | ``` 51 | 52 | ## Usage 53 | 54 | Here's a simple example demonstrating how to define and run a workflow: 55 | 56 | ```cpp 57 | #include "pocketflow.h" // Include the library header 58 | #include <iostream> 59 | #include <string> 60 | #include <optional> 61 | #include <cstddef> // For std::nullptr_t 62 | #include <memory> // For std::make_shared 63 | 64 | // Use the namespace 65 | using namespace pocketflow; 66 | 67 | // --- Define Custom Nodes --- 68 | 69 | // Start Node: Takes no input (nullptr_t), returns a string action 70 | class MyStartNode : public Node<std::nullptr_t, std::string> { 71 | public: 72 | std::string exec(std::nullptr_t /*prepResult*/) override { 73 | std::cout << "Starting workflow..." << std::endl; 74 | return "started"; // This string determines the next node 75 | } 76 | // Optional: override post to return the action explicitly 77 | std::optional<std::string> post(Context& ctx, const std::nullptr_t& p, const std::string& e) override { 78 | return e; // Return exec result as the action 79 | } 80 | }; 81 | 82 | // End Node: Takes a string from prep, returns nothing (nullptr_t) 83 | class MyEndNode : public Node<std::string, std::nullptr_t> { 84 | public: 85 | // Optional: Prepare input for exec, potentially using context 86 | std::string prep(Context& ctx) override { 87 | return "Preparing to end workflow"; 88 | } 89 | 90 | // Execute the node's main logic 91 | std::nullptr_t exec(std::string prepResult) override { 92 | std::cout << "Ending workflow with: " << prepResult << std::endl; 93 | return nullptr; // Placeholder return value: nullptr_t stands in for void 94 | } 95 | // Optional: post can modify context; the default returns std::nullopt (default action), so the flow ends here unless a default successor is connected 96 | }; 97 | 98 | 99 | int main() { 100 | // --- Create Node Instances (use std::shared_ptr) --- 101 | auto startNode = std::make_shared<MyStartNode>(); 102 | auto endNode = std::make_shared<MyEndNode>(); 103 | 
104 | // --- Connect Nodes --- 105 | // When startNode returns the action "started", execute endNode next. 106 | startNode->next(endNode, "started"); 107 | 108 | // --- Create and Configure Flow --- 109 | Flow flow(startNode); // Initialize the flow with the starting node 110 | 111 | // --- Prepare Context and Run --- 112 | Context sharedContext; // Map to share data between nodes 113 | std::cout << "Executing workflow..." << std::endl; 114 | flow.run(sharedContext); // Execute the workflow 115 | std::cout << "Workflow completed successfully." << std::endl; 116 | 117 | return 0; 118 | } 119 | ``` 120 | 121 | ## Core Concepts 122 | 123 | * **`IBaseNode` / `BaseNode<P, E>`:** The fundamental building block (`IBaseNode` is the non-templated interface that lets nodes of different types be linked together). 124 | * `P`: The type returned by the `prep` method (input to `exec`). Use `std::nullptr_t` for no input. 125 | * `E`: The type returned by the `exec` method (input to `post`). Use `std::nullptr_t` for no return value. 126 | * `prep(Context&)`: Prepare data needed for `exec`. Can use/modify the shared context. 127 | * `exec(P)`: Execute the core logic of the node. 128 | * `post(Context&, P, E)`: Process results after `exec`. Can use/modify context. Returns `std::optional<std::string>`, which is the *action* determining the next node. `std::nullopt` triggers the default transition. 129 | * `next(node, action)`: Registers `node` as the successor to run when `post` returns the `action` string. `next(node)` connects via the default action. 130 | * **`Node<P, E>`:** A `BaseNode<P, E>` with added retry logic (`maxRetries`, `waitMillis`, `execFallback`). 131 | * **`BatchNode`:** A `Node` that processes a `std::vector` of input items and produces a `std::vector` of results, handling retries per item via `execItem` and `execItemFallback`. 132 | * **`Flow`:** Orchestrates the execution of connected nodes starting from a designated `startNode`. 133 | * **`BatchFlow`:** A `Flow` that runs its entire sequence for multiple parameter sets generated by `prepBatch`. 
134 | * **`Context` (`std::map<std::string, std::any>`):** A shared data store passed through the workflow, allowing nodes to communicate indirectly. Requires careful type casting (`std::any_cast`). 135 | * **`Params` (`std::map<std::string, std::any>`):** Configuration parameters passed to a node instance, typically set before execution or by a `BatchFlow`. 136 | 137 | ## C++ Specifics (vs. Java/Python) 138 | 139 | * **Memory Management:** Uses `std::shared_ptr` for managing node object lifetimes within the workflow graph. 140 | * **Type Erasure:** `std::any` is used for `Context` and `Params`, requiring explicit `std::any_cast` and careful handling of potential `std::bad_any_cast` exceptions. 141 | * **`void` Types:** `void` cannot be used for `P` or `E` (it is not a valid parameter or variable type), so `std::nullptr_t` is used as the placeholder template argument when a node doesn't logically take input (`prep` returns nothing) or produce output (`exec` returns nothing). 142 | * **Actions:** Node transitions are determined by the `std::optional<std::string>` returned from `post`. `std::nullopt` signifies the default transition (if one is defined using `next(node)`). 143 | * **Header-Only:** Simplifies integration – just include `pocketflow.h`. 144 | 145 | ## Development 146 | 147 | ### Building the Example/Tests 148 | 149 | Use the CMake instructions provided in the "Building" section. 150 | 151 | ### Running Tests 152 | 153 | (Currently, no automated test suite is included. This is a key area for contribution!) 154 | You would typically integrate a testing framework like GoogleTest: 155 | 1. Set up GoogleTest in your project. 156 | 2. Write test cases in a separate `.cpp` file (e.g., `PocketFlow_test.cpp`). 157 | 3. Configure CMake (as shown commented out in the example `CMakeLists.txt`) to build and run the tests. 158 | 159 | ```bash 160 | # Example CMake commands after GoogleTest setup 161 | cd build 162 | cmake .. 163 | cmake --build . 164 | ctest # Run tests 165 | ``` 166 | 167 | ## Contributing 168 | 169 | Contributions are highly welcome! 
We are particularly looking for help with: 170 | 171 | 1. **Asynchronous Operations:** Implementing non-blocking node execution (e.g., using `std::async`, `std::thread`, futures, or an external async library). 172 | 2. **Testing:** Adding comprehensive unit and integration tests using a framework like GoogleTest. 173 | 3. **Documentation:** Improving explanations, adding more examples, and documenting edge cases. 174 | 4. **Error Handling:** Refining exception types and context propagation for errors. 175 | 5. **Examples:** Providing more practical examples, potentially related to LLM interactions. 176 | 177 | Please feel free to submit pull requests or open issues for discussion. 178 | 179 | ## License 180 | 181 | [MIT License](LICENSE) -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include "pocketflow.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include // For std::make_shared 7 | 8 | // Use the namespace 9 | using namespace pocketflow; 10 | 11 | // --- Test Node Implementations --- 12 | 13 | // Define a custom start node 14 | // Template arguments: P=void (no prep input), E=std::string (exec returns string) 15 | class MyStartNode : public Node { 16 | public: 17 | // Override exec (pure virtual in BaseNode) 18 | std::string exec(std::nullptr_t /*prepResult*/) override { 19 | std::cout << "Starting workflow..." 
<< std::endl; 20 | return "started"; 21 | } 22 | // Override post (optional) 23 | std::optional post(Context& ctx, const std::nullptr_t& p, const std::string& e) override { 24 | // Return the exec result directly as the action 25 | return e; 26 | } 27 | }; 28 | 29 | // Define a custom end node 30 | // Template arguments: P=std::string (prep returns string), E=void (exec returns nothing) 31 | class MyEndNode : public Node { 32 | public: 33 | // Override prep (optional) 34 | std::string prep(Context& ctx) override { 35 | // Example: Read something from context if needed 36 | // return std::any_cast(ctx.at("some_key")); 37 | return "Preparing to end workflow"; 38 | } 39 | 40 | // Override exec 41 | std::nullptr_t exec(std::string prepResult) override { 42 | std::cout << "Ending workflow with: " << prepResult << std::endl; 43 | // Since E is void (represented by nullptr_t), we don't return anything meaningful 44 | return nullptr; 45 | } 46 | 47 | // Override post (optional, default returns nullopt/default action) 48 | std::optional post(Context& ctx, const std::string& p, const std::nullptr_t& e) override { 49 | ctx["end_node_prep_result"] = p; // Example: Store something in context 50 | return std::nullopt; // No further action needed 51 | } 52 | }; 53 | 54 | 55 | // --- Example Test Nodes mirroring Java Test --- 56 | 57 | // P=nullptr_t, E=int 58 | class SetNumberNode : public Node { 59 | int number; 60 | public: 61 | SetNumberNode(int num) : number(num) {} 62 | 63 | int exec(std::nullptr_t) override { 64 | int multiplier = getParamOrDefault("multiplier", 1); 65 | return number * multiplier; 66 | } 67 | 68 | std::optional post(Context& ctx, const std::nullptr_t&, const int& e) override { 69 | ctx["currentValue"] = e; // Store result in context 70 | return e > 20 ? 
std::make_optional("over_20") : std::nullopt; // Branching action 71 | } 72 | }; 73 | 74 | // P=int, E=int 75 | class AddNumberNode : public Node { 76 | int numberToAdd; 77 | public: 78 | AddNumberNode(int num) : numberToAdd(num) {} 79 | 80 | int prep(Context& ctx) override { 81 | // Get value from context, throw if not found or wrong type 82 | try { 83 | return std::any_cast(ctx.at("currentValue")); 84 | } catch (const std::out_of_range& oor) { 85 | throw PocketFlowException("Context missing 'currentValue' for AddNumberNode"); 86 | } catch (const std::bad_any_cast& bac) { 87 | throw PocketFlowException("'currentValue' in context is not an int for AddNumberNode"); 88 | } 89 | } 90 | 91 | int exec(int currentValue) override { 92 | return currentValue + numberToAdd; 93 | } 94 | 95 | std::optional post(Context& ctx, const int&, const int& e) override { 96 | ctx["currentValue"] = e; // Update context 97 | return "added"; // Fixed action 98 | } 99 | }; 100 | 101 | // P=int, E=nullptr_t 102 | class ResultCaptureNode : public Node { 103 | public: 104 | int capturedValue = -999; // Store result locally for testing 105 | 106 | int prep(Context& ctx) override { 107 | // Get value from context, provide default if missing 108 | auto it = ctx.find("currentValue"); 109 | if (it != ctx.end()) { 110 | try { 111 | return std::any_cast(it->second); 112 | } catch(const std::bad_any_cast& ) { 113 | // Handle error or return default 114 | } 115 | } 116 | return -999; // Default if not found or bad cast 117 | } 118 | 119 | std::nullptr_t exec(int prepResult) override { 120 | capturedValue = prepResult; // Capture the value 121 | // Also store in params map like the Java example 122 | this->params["capturedValue"] = prepResult; 123 | return nullptr; 124 | } 125 | // No post needed, default action (nullopt) is fine 126 | }; 127 | 128 | 129 | int main() { 130 | // --- Simple Workflow Example --- 131 | std::cout << "--- Running Simple Workflow ---" << std::endl; 132 | auto startNode = 
std::make_shared(); 133 | auto endNode = std::make_shared(); 134 | 135 | // Connect the nodes: startNode transitions to endNode on the "started" action 136 | startNode->next(endNode, "started"); 137 | 138 | // Create a flow with the start node 139 | Flow flow(startNode); 140 | 141 | // Create a context and run the flow 142 | Context context; 143 | std::cout << "Executing workflow..." << std::endl; 144 | flow.run(context); // Returns the final action, ignored here 145 | std::cout << "Workflow completed successfully." << std::endl; 146 | // Check context if endNode modified it 147 | if (context.count("end_node_prep_result")) { 148 | std::cout << "End node stored in context: " 149 | << std::any_cast(context["end_node_prep_result"]) 150 | << std::endl; 151 | } 152 | std::cout << std::endl; 153 | 154 | 155 | // --- Linear Flow Test Example (like Java test) --- 156 | std::cout << "--- Running Linear Test Workflow ---" << std::endl; 157 | auto setNum = std::make_shared(10); 158 | auto addNum = std::make_shared(5); 159 | auto capture = std::make_shared(); 160 | 161 | setNum->next(addNum) // Default action connects to addNum 162 | ->next(capture, "added"); // addNum's "added" action connects to capture 163 | 164 | Flow linearFlow(setNum); 165 | Context linearContext; 166 | linearFlow.run(linearContext); 167 | 168 | // Assertions (manual checks here) 169 | std::cout << "Linear Test: Final Context 'currentValue': " 170 | << (linearContext.count("currentValue") ? 
std::any_cast(linearContext["currentValue"]) : -1) 171 | << " (Expected: 15)" << std::endl; 172 | std::cout << "Linear Test: Captured Value in Node: " 173 | << capture->capturedValue 174 | << " (Expected: 15)" << std::endl; 175 | // Check params map in capture node 176 | if (capture->getParams().count("capturedValue")) { 177 | std::cout << "Linear Test: Captured Value in Params: " 178 | << std::any_cast(capture->getParams().at("capturedValue")) 179 | << " (Expected: 15)" << std::endl; 180 | } 181 | std::cout << std::endl; 182 | 183 | 184 | // --- Branching Flow Test Example --- 185 | std::cout << "--- Running Branching Test Workflow ---" << std::endl; 186 | auto setNumBranch = std::make_shared(10); // Multiplier will make it > 20 187 | auto addNumBranch = std::make_shared(5); 188 | auto captureDefault = std::make_shared(); 189 | auto captureOver20 = std::make_shared(); 190 | 191 | // Connections 192 | setNumBranch->next(addNumBranch); // Default action 193 | setNumBranch->next(captureOver20, "over_20"); // "over_20" action 194 | addNumBranch->next(captureDefault, "added"); // "added" action from addNum 195 | 196 | Flow branchingFlow(setNumBranch); 197 | Context branchingContext; 198 | // Set params for the flow (which get passed to the first node) 199 | branchingFlow.setParams({{"multiplier", 3}}); // Make initial value 30 (> 20) 200 | 201 | branchingFlow.run(branchingContext); 202 | 203 | // Assertions (manual checks) 204 | std::cout << "Branching Test: Final Context 'currentValue': " 205 | << (branchingContext.count("currentValue") ? 
std::any_cast(branchingContext["currentValue"]) : -1) 206 | << " (Expected: 30)" << std::endl; // From SetNumberNode's post 207 | std::cout << "Branching Test: Default Capture Node Value: " 208 | << captureDefault->capturedValue 209 | << " (Expected: -999 - not executed)" << std::endl; 210 | std::cout << "Branching Test: Over_20 Capture Node Value: " 211 | << captureOver20->capturedValue 212 | << " (Expected: 30)" << std::endl; 213 | std::cout << std::endl; 214 | 215 | 216 | return 0; 217 | } -------------------------------------------------------------------------------- /pocketflow.h: -------------------------------------------------------------------------------- 1 | #ifndef POCKETFLOW_H 2 | #define POCKETFLOW_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include // For std::shared_ptr 9 | #include // For type erasure (like Java Object) - C++17 10 | #include // For optional return values (like Java null) - C++17 11 | #include 12 | #include 13 | #include 14 | #include // For std::function if needed (not strictly used here yet) 15 | #include // For std::move 16 | 17 | namespace pocketflow { 18 | 19 | // --- Type Definitions --- 20 | using Context = std::map; 21 | using Params = std::map; 22 | 23 | // --- Constants --- 24 | // Use std::nullopt to represent the default action instead of a magic string 25 | // static const std::string DEFAULT_ACTION = "default"; // Optional: if you prefer explicit string 26 | 27 | // --- Utility Functions --- 28 | inline void logWarn(const std::string& message) { 29 | std::cerr << "WARN: PocketFlow - " << message << std::endl; 30 | } 31 | 32 | // --- Custom Exception --- 33 | class PocketFlowException : public std::runtime_error { 34 | public: 35 | PocketFlowException(const std::string& message) : std::runtime_error(message) {} 36 | PocketFlowException(const std::string& message, const std::exception& cause) 37 | : std::runtime_error(message + " (Caused by: " + cause.what() + ")") {} // Simple cause handling 38 | }; 
39 | 40 | 41 | // --- Forward Declarations --- 42 | class IBaseNode; // Non-templated base interface 43 | 44 | // --- Base Node Interface (Non-Templated) --- 45 | // Needed to store heterogeneous node types in successors map 46 | class IBaseNode { 47 | public: 48 | virtual ~IBaseNode() = default; // IMPORTANT: Virtual destructor 49 | 50 | virtual void setParamsInternal(const Params& params) = 0; 51 | virtual std::optional internalRun(Context& sharedContext) = 0; 52 | virtual std::shared_ptr getNextNode(const std::optional& action) const = 0; 53 | virtual bool hasSuccessors() const = 0; 54 | virtual const std::string& getClassName() const = 0; // For logging 55 | 56 | // Simplified next chaining accepting any IBaseNode. More type-safe versions 57 | // can be added in the templated BaseNode. 58 | virtual std::shared_ptr next(std::shared_ptr node, const std::string& action) = 0; 59 | virtual std::shared_ptr next(std::shared_ptr node) = 0; // Default action 60 | 61 | // Allow getting params (e.g., for result capture in tests) 62 | virtual const Params& getParams() const = 0; 63 | }; 64 | 65 | 66 | // --- Base Node Template --- 67 | template 68 | class BaseNode : public IBaseNode { 69 | protected: 70 | Params params; 71 | std::map> successors; 72 | std::string className = typeid(*this).name(); // Store class name approximation 73 | 74 | public: 75 | virtual ~BaseNode() override = default; 76 | 77 | // --- Configuration --- 78 | // Returns *this reference to allow chaining like Java, but less common in C++ 79 | BaseNode& setParams(const Params& newParams) { 80 | params = newParams; // Creates a copy 81 | return *this; 82 | } 83 | 84 | // Override from IBaseNode 85 | void setParamsInternal(const Params& newParams) override { 86 | params = newParams; 87 | } 88 | 89 | const Params& getParams() const override { 90 | return params; 91 | } 92 | 93 | // --- Chaining --- 94 | // Templated next for potential type checking at connection time if needed, 95 | // but primarily 
delegates to the IBaseNode interface version. 96 | template 97 | std::shared_ptr> next(std::shared_ptr> node, const std::string& action) { 98 | if (!node) { 99 | throw std::invalid_argument("Successor node cannot be null"); 100 | } 101 | if (successors.count(action)) { 102 | logWarn("Overwriting successor for action '" + action + "' in node " + getClassName()); 103 | } 104 | successors[action] = node; // Implicit cast to shared_ptr 105 | return node; 106 | } 107 | 108 | template 109 | std::shared_ptr> next(std::shared_ptr> node) { 110 | // Use "" or a specific constant internally to represent the default action key 111 | return next(node, ""); // Empty string as internal key for default 112 | } 113 | 114 | // IBaseNode implementation for next (needed for polymorphism) 115 | std::shared_ptr next(std::shared_ptr node, const std::string& action) override { 116 | if (!node) { 117 | throw std::invalid_argument("Successor node cannot be null"); 118 | } 119 | if (successors.count(action)) { 120 | logWarn("Overwriting successor for action '" + action + "' in node " + getClassName()); 121 | } 122 | successors[action] = node; 123 | return node; // Return the base interface pointer 124 | } 125 | 126 | std::shared_ptr next(std::shared_ptr node) override { 127 | return next(node, ""); // Empty string for default 128 | } 129 | 130 | 131 | // --- Core Methods (to be implemented by subclasses) --- 132 | virtual P prep(Context& sharedContext) { 133 | // Default implementation returns default-constructed P 134 | // Handle void case explicitly if needed, though default works ok. 
135 | if constexpr (std::is_same_v) { 136 | return; // Or handle differently if void needs special logic 137 | } else { 138 | return P{}; 139 | } 140 | } 141 | 142 | virtual E exec(P prepResult) = 0; // Pure virtual 143 | 144 | virtual std::optional post(Context& sharedContext, const P& prepResult, const E& execResult) { 145 | // Default implementation returns nullopt (default action) 146 | return std::nullopt; 147 | } 148 | 149 | // --- Internal Execution Logic --- 150 | protected: 151 | // This internal method allows Node to override execution with retries 152 | virtual E internalExec(P prepResult) { 153 | return exec(std::move(prepResult)); // Use move if P is movable 154 | } 155 | 156 | public: 157 | // IBaseNode implementation 158 | std::optional internalRun(Context& sharedContext) override { 159 | P prepRes = prep(sharedContext); 160 | E execRes = internalExec(std::move(prepRes)); // Use move if P is movable 161 | // Need to handle void return type E potentially 162 | if constexpr (std::is_same_v) { 163 | return post(sharedContext, prepRes, {}); // Pass dummy value for void E 164 | } else { 165 | return post(sharedContext, prepRes, execRes); 166 | } 167 | } 168 | 169 | 170 | // --- Standalone Run --- 171 | // Note: Return type matches internalRun now. The Java version returned String. 172 | // Returning optional seems more consistent here. 173 | std::optional run(Context& sharedContext) { 174 | if (!successors.empty()) { 175 | logWarn("Node " + getClassName() + " has successors, but run() was called. Successors won't be executed. 
Use Flow."); 176 | } 177 | return internalRun(sharedContext); 178 | } 179 | 180 | // --- Successor Retrieval (IBaseNode implementation) --- 181 | std::shared_ptr getNextNode(const std::optional& action) const override { 182 | std::string actionKey = action.value_or(""); // Use "" for default action key 183 | auto it = successors.find(actionKey); 184 | if (it != successors.end()) { 185 | return it->second; 186 | } else { 187 | if (!successors.empty()) { 188 | std::string requestedAction = action.has_value() ? "'" + action.value() + "'" : "default"; 189 | std::string availableActions; 190 | for(const auto& pair : successors) { 191 | availableActions += "'" + (pair.first.empty() ? "" : pair.first) + "' "; 192 | } 193 | logWarn("Flow might end: Action " + requestedAction + " not found in successors [" 194 | + availableActions + "] of node " + getClassName()); 195 | } 196 | return nullptr; // No successor found 197 | } 198 | } 199 | 200 | bool hasSuccessors() const override { 201 | return !successors.empty(); 202 | } 203 | 204 | const std::string& getClassName() const override { 205 | // Return a potentially mangled name. Provide a way to set a clean name if needed. 206 | // For now, use the stored approximation. 207 | // A better way: Add a virtual getName() method overridden by each concrete node. 208 | return className; 209 | } 210 | 211 | protected: 212 | // Helper to safely get from map with default 213 | template 214 | T getParamOrDefault(const std::string& key, T defaultValue) const { 215 | auto it = params.find(key); 216 | if (it != params.end()) { 217 | try { 218 | return std::any_cast(it->second); 219 | } catch (const std::bad_any_cast& e) { 220 | // Log or handle cast error - return default for now 221 | logWarn("Bad any_cast for param '" + key + "' in node " + getClassName() + ". 
Expected different type."); 222 | return defaultValue; 223 | } 224 | } 225 | return defaultValue; 226 | } 227 | }; 228 | 229 | 230 | // --- Synchronous Node with Retries --- 231 | template 232 | class Node : public BaseNode { 233 | protected: 234 | int maxRetries; 235 | long long waitMillis; // Use long long for milliseconds 236 | int currentRetry = 0; 237 | 238 | public: 239 | Node(int retries = 1, long long waitMilliseconds = 0) 240 | : maxRetries(retries), waitMillis(waitMilliseconds) { 241 | if (maxRetries < 1) throw std::invalid_argument("maxRetries must be at least 1"); 242 | if (waitMillis < 0) throw std::invalid_argument("waitMillis cannot be negative"); 243 | } 244 | 245 | virtual ~Node() override = default; 246 | 247 | // Fallback method to be overridden if needed 248 | virtual E execFallback(P prepResult, const std::exception& lastException) { 249 | // Default behavior is to re-throw the last exception 250 | throw PocketFlowException("Node execution failed after " + std::to_string(maxRetries) + " retries, and fallback was not implemented or also failed.", lastException); 251 | } 252 | 253 | protected: 254 | // Override internalExec to add retry logic 255 | E internalExec(P prepResult) override { 256 | std::unique_ptr lastExceptionPtr; // Store last exception 257 | 258 | for (currentRetry = 0; currentRetry < maxRetries; ++currentRetry) { 259 | try { 260 | // Need to copy or move prepResult carefully if exec might modify it 261 | // Assuming exec takes by value or const ref for simplicity here 262 | // If P is expensive to copy, consider passing by ref and ensuring exec handles it. 263 | return this->exec(prepResult); // Call the user-defined exec 264 | } catch (const std::exception& e) { 265 | // Using unique_ptr to manage exception polymorphism if needed, 266 | // but storing a copy of the base std::exception might suffice. 267 | // Let's store the last exception message simply for now. 268 | // A better approach might involve exception_ptr. 
269 | // lastException = e; // Direct copy loses polymorphic type 270 | 271 | // Store the exception *type* to rethrow properly or use std::exception_ptr 272 | try { throw; } // Rethrow to capture current exception 273 | catch (const std::exception& current_e) { 274 | // Store the exception to be used in fallback 275 | lastExceptionPtr = std::make_unique(current_e.what()); // Store message 276 | } 277 | 278 | 279 | if (currentRetry < maxRetries - 1 && waitMillis > 0) { 280 | try { 281 | std::this_thread::sleep_for(std::chrono::milliseconds(waitMillis)); 282 | } catch (...) { 283 | // Handle potential exceptions during sleep? Unlikely but possible. 284 | throw PocketFlowException("Thread interrupted during retry wait", std::runtime_error("sleep interrupted")); 285 | } 286 | } 287 | } catch (...) { // Catch non-std exceptions if necessary 288 | lastExceptionPtr = std::make_unique("Non-standard exception caught during exec"); 289 | if (currentRetry < maxRetries - 1 && waitMillis > 0) { 290 | std::this_thread::sleep_for(std::chrono::milliseconds(waitMillis)); 291 | } 292 | } 293 | } 294 | 295 | // If loop finishes, all retries failed 296 | try { 297 | if (!lastExceptionPtr) { 298 | throw PocketFlowException("Execution failed after retries, but no exception was captured."); 299 | } 300 | // Call fallback, passing a reference to the stored exception approximation 301 | return execFallback(std::move(prepResult), *lastExceptionPtr); 302 | } catch (const std::exception& fallbackException) { 303 | // If fallback fails, throw appropriate exception 304 | throw PocketFlowException("Fallback execution failed after main exec retries failed.", fallbackException); 305 | } catch (...) 
{ 306 | throw PocketFlowException("Fallback execution failed with non-standard exception.", std::runtime_error("Unknown fallback error")); 307 | } 308 | } 309 | }; 310 | 311 | 312 | // --- Synchronous Batch Node --- 313 | template 314 | class BatchNode : public Node, std::vector> { 315 | public: 316 | BatchNode(int retries = 1, long long waitMilliseconds = 0) 317 | : Node, std::vector>(retries, waitMilliseconds) {} 318 | 319 | virtual ~BatchNode() override = default; 320 | 321 | // --- Methods for subclasses to implement --- 322 | virtual OUT_ITEM execItem(const IN_ITEM& item) = 0; // Process a single item 323 | 324 | virtual OUT_ITEM execItemFallback(const IN_ITEM& item, const std::exception& lastException) { 325 | // Default fallback re-throws 326 | throw PocketFlowException("Batch item execution failed after retries, and fallback was not implemented or also failed.", lastException); 327 | } 328 | 329 | // --- Base class methods that MUST NOT be overridden by user --- 330 | // Make exec final to prevent accidental override. User should implement execItem. 331 | std::vector exec(std::vector prepResult) final { 332 | // This method is conceptually hidden by internalExec override below. 333 | // We override internalExec directly. 334 | throw std::logic_error("BatchNode::exec should not be called directly."); 335 | } 336 | 337 | // Fallback for the whole batch (rarely needed if item fallback exists) 338 | std::vector execFallback(std::vector prepResult, const std::exception& lastException) final { 339 | // This fallback applies if the *looping* itself fails, not individual items. 
340 | throw PocketFlowException("BatchNode internal execution loop failed.", lastException); 341 | } 342 | 343 | 344 | protected: 345 | // Override internalExec for batch processing logic 346 | std::vector internalExec(std::vector batchPrepResult) override { 347 | if (batchPrepResult.empty()) { 348 | return {}; 349 | } 350 | 351 | std::vector results; 352 | results.reserve(batchPrepResult.size()); 353 | 354 | for (const auto& item : batchPrepResult) { 355 | std::unique_ptr lastItemExceptionPtr; 356 | bool itemSuccess = false; 357 | OUT_ITEM itemResult{}; // Default construct 358 | 359 | for (this->currentRetry = 0; this->currentRetry < this->maxRetries; ++this->currentRetry) { 360 | try { 361 | itemResult = execItem(item); // Call user implementation 362 | itemSuccess = true; 363 | break; // Success, exit retry loop for this item 364 | } catch (const std::exception& e) { 365 | try { throw; } catch(const std::exception& current_e) { 366 | lastItemExceptionPtr = std::make_unique(current_e.what()); 367 | } 368 | if (this->currentRetry < this->maxRetries - 1 && this->waitMillis > 0) { 369 | std::this_thread::sleep_for(std::chrono::milliseconds(this->waitMillis)); 370 | } 371 | } catch (...) { 372 | lastItemExceptionPtr = std::make_unique("Non-standard exception during execItem"); 373 | if (this->currentRetry < this->maxRetries - 1 && this->waitMillis > 0) { 374 | std::this_thread::sleep_for(std::chrono::milliseconds(this->waitMillis)); 375 | } 376 | } 377 | } // End retry loop for item 378 | 379 | if (!itemSuccess) { 380 | try { 381 | if (!lastItemExceptionPtr) { 382 | throw PocketFlowException("Item execution failed without exception for item."); // Add item info if possible 383 | } 384 | itemResult = execItemFallback(item, *lastItemExceptionPtr); // Call user fallback 385 | } catch (const std::exception& fallbackEx) { 386 | throw PocketFlowException("Item fallback execution failed.", fallbackEx); // Add item info if possible 387 | } catch (...) 
{ 388 | throw PocketFlowException("Item fallback failed with non-standard exception.", std::runtime_error("Unknown item fallback error")); 389 | } 390 | } 391 | results.push_back(std::move(itemResult)); // Move if possible 392 | } // End loop over items 393 | 394 | return results; 395 | } 396 | }; 397 | 398 | 399 | // --- Flow Orchestrator --- 400 | // Inherits from BaseNode with dummy types for consistency, but overrides run logic. 401 | // Using std::nullptr_t for unused P type. 402 | class Flow : public BaseNode> { 403 | protected: 404 | std::shared_ptr startNode = nullptr; 405 | 406 | public: 407 | Flow() = default; 408 | explicit Flow(std::shared_ptr start) { this->start(std::move(start)); } 409 | virtual ~Flow() override = default; 410 | 411 | template 412 | std::shared_ptr> start(std::shared_ptr> node) { 413 | if (!node) { 414 | throw std::invalid_argument("Start node cannot be null"); 415 | } 416 | startNode = node; // Implicit cast to shared_ptr 417 | return node; 418 | } 419 | // Overload for IBaseNode pointer directly 420 | std::shared_ptr start(std::shared_ptr node) { 421 | if (!node) { 422 | throw std::invalid_argument("Start node cannot be null"); 423 | } 424 | startNode = node; 425 | return node; 426 | } 427 | 428 | // Prevent direct calls to Flow's exec - logic is in orchestrate/internalRun 429 | std::optional exec(std::nullptr_t /*prepResult*/) final override { 430 | throw std::logic_error("Flow::exec() is internal and should not be called directly. 
Use run()."); 431 | } 432 | 433 | protected: 434 | // The core orchestration logic 435 | virtual std::optional orchestrate(Context& sharedContext, const Params& initialParams) { 436 | if (!startNode) { 437 | logWarn("Flow started with no start node."); 438 | return std::nullopt; 439 | } 440 | 441 | std::shared_ptr currentNode = startNode; 442 | std::optional lastAction = std::nullopt; 443 | 444 | // Combine flow's base params with initial params for this run 445 | Params currentRunParams = this->params; // Start with Flow's own params 446 | currentRunParams.insert(initialParams.begin(), initialParams.end()); // Add/overwrite with initialParams 447 | 448 | 449 | while (currentNode != nullptr) { 450 | currentNode->setParamsInternal(currentRunParams); // Set params for the current node 451 | lastAction = currentNode->internalRun(sharedContext); // Execute the node 452 | currentNode = currentNode->getNextNode(lastAction); // Find the next node based on action 453 | } 454 | 455 | return lastAction; // Return the action that led to termination (or nullopt if last node had no action) 456 | } 457 | 458 | // Override BaseNode's internal run 459 | std::optional internalRun(Context& sharedContext) override { 460 | // Flow's prep is usually no-op unless overridden 461 | [[maybe_unused]] std::nullptr_t prepRes = prep(sharedContext); // Call prep, ignore result 462 | 463 | // Orchestrate starting with empty initial params (can be overridden by BatchFlow) 464 | std::optional orchRes = orchestrate(sharedContext, {}); 465 | 466 | // Flow's post processes the *result* of the orchestration 467 | return post(sharedContext, nullptr, orchRes); 468 | } 469 | 470 | public: 471 | // Override post for Flow. Default returns the final action from orchestration. 
472 | std::optional post(Context& sharedContext, const std::nullptr_t& /*prepResult*/, const std::optional& execResult) override { 473 | return execResult; // Simply pass through the last action 474 | } 475 | 476 | // run() method inherited from BaseNode calls internalRun() correctly. 477 | }; 478 | 479 | 480 | // --- Batch Flow --- 481 | class BatchFlow : public Flow { 482 | public: 483 | BatchFlow() = default; 484 | explicit BatchFlow(std::shared_ptr start) : Flow(std::move(start)) {} 485 | virtual ~BatchFlow() override = default; 486 | 487 | // --- Methods for subclasses to implement --- 488 | virtual std::vector prepBatch(Context& sharedContext) = 0; 489 | 490 | // Post method after all batches have run 491 | virtual std::optional postBatch(Context& sharedContext, const std::vector& batchPrepResult) = 0; 492 | 493 | protected: 494 | // Override internalRun to handle batch execution 495 | std::optional internalRun(Context& sharedContext) override { 496 | std::vector batchParamsList = prepBatch(sharedContext); 497 | 498 | if (batchParamsList.empty()) { 499 | logWarn("BatchFlow prepBatch returned empty list."); 500 | // Still call postBatch even if empty 501 | } 502 | 503 | for (const auto& batchParams : batchParamsList) { 504 | // Run the orchestration for each parameter set. 505 | // Result of individual orchestrations is ignored here; focus is on side effects. 
506 | orchestrate(sharedContext, batchParams); 507 | } 508 | 509 | // After all batches, call postBatch 510 | return postBatch(sharedContext, batchParamsList); 511 | } 512 | 513 | public: 514 | // Prevent calling the regular post method directly for BatchFlow 515 | std::optional post(Context& /*sharedContext*/, const std::nullptr_t& /*prepResult*/, const std::optional& /*execResult*/) final override { 516 | throw std::logic_error("Use postBatch for BatchFlow, not post."); 517 | } 518 | }; 519 | 520 | 521 | } // namespace pocketflow 522 | 523 | #endif // POCKETFLOW_H --------------------------------------------------------------------------------