├── go.mod ├── .gitignore ├── LICENSE ├── go.sum ├── README.md ├── pocketflow_test.go └── pocketflow.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/The-Pocket/PocketFlow-Go 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/stretchr/testify v1.8.4 // For assertions in tests 7 | ) 8 | 9 | require ( 10 | github.com/davecgh/go-spew v1.1.1 // indirect 11 | github.com/pmezard/go-difflib v1.0.0 // indirect 12 | gopkg.in/yaml.v3 v3.0.1 // indirect 13 | ) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | *.test 8 | 9 | # Output of 'go build' 10 | # Usually the project name, like 'PocketFlow-Go' if built locally without specifying output 11 | PocketFlow-Go 12 | app 13 | main 14 | bin/ 15 | cmd/ 16 | 17 | # Test binary, build with `go test -c` 18 | *.test 19 | 20 | # Output of the go coverage tool, e.g., coverage.out 21 | *.out 22 | *.prof 23 | 24 | # Dependency directories (remove the comment below if you vendor) 25 | # vendor/ 26 | 27 | # Go workspace files (rarely used now) 28 | # Gopkg.lock 29 | # Gopkg.toml 30 | 31 | # Environment configuration files 32 | .env* 33 | *.env 34 | 35 | # IDE/Editor directories and files 36 | .idea/ 37 | .vscode/ 38 | *.iml 39 | *.ipr 40 | *.iws 41 | *~ 42 | *.swp 43 | *.swo 44 | 45 | # OS generated files 46 | .DS_Store 47 | .DS_Store? 48 | ._* 49 | .Spotlight-V100 50 | .Trashes 51 | ehthumbs.db 52 | Thumbs.db 53 | desktop.ini 54 | 55 | # Log files 56 | *.log 57 | 58 | # Add any other project-specific files or directories to ignore below 59 | # e.g., local data files, temporary build artifacts, etc. 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Zachary Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 5 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 6 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 7 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 8 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 9 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 10 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 11 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 12 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 13 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 14 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 15 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 16 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PocketFlow Go 2 | 3 | A minimalist LLM framework concept, ported from Python to Go. 
4 | 5 | ## Overview 6 | 7 | PocketFlow Go is a port of the original [Python PocketFlow](https://github.com/The-Pocket/PocketFlow). It provides a lightweight, flexible system for building and executing LLM-based (or other sequential) workflows through a simple node-based architecture using Go interfaces and functions. 8 | 9 | > **Note:** This is an initial synchronous implementation mirroring the PocketFlow Java port. It currently does not support asynchronous operations (nodes are not executed in goroutines). Community contributors are welcome to help enhance and maintain this project, particularly by adding robust concurrency patterns. 10 | 11 | ## Installation 12 | 13 | Ensure you have Go 1.18 or later installed (the minimum version declared in `go.mod`). 14 | 15 | ```bash 16 | go get github.com/The-Pocket/PocketFlow-Go 17 | ``` 18 | 19 | ## Usage 20 | 21 | Here's a simple example of how to use PocketFlow Go in your application: 22 | 23 | ```go 24 | package main 25 | 26 | import ( 27 | "context" 28 | "fmt" 29 | "log" 30 | 31 | pf "github.com/The-Pocket/PocketFlow-Go" // Adjust import path 32 | ) 33 | 34 | // Define node logic using PocketFlow's functional style 35 | 36 | // myStartNode creates a node that starts the workflow. 37 | func myStartNode() pf.BaseNode { 38 | return pf.NewNode(). 39 | SetExec(func(ctx *pf.PfContext, params map[string]any, prepResult any) (any, error) { 40 | log.Println("Starting workflow...") 41 | // Exec result can be used by Post to determine action 42 | return "started_data", nil 43 | }). 44 | SetPost(func(ctx *pf.PfContext, params map[string]any, prepResult any, execResult any) (string, error) { 45 | // Use execResult to decide the next step 46 | log.Printf("Start node finished with data: %v\n", execResult) 47 | ctx.SetValue("start_result", execResult) // Optional: Update shared context 48 | return "started", nil // Action name to trigger the next node 49 | }) 50 | } 51 | 52 | // myEndNode creates a node that ends the workflow.
53 | func myEndNode() pf.BaseNode { 54 | return pf.NewNode(). 55 | SetPrep(func(ctx *pf.PfContext, params map[string]any) (any, error) { 56 | // Prep can access the shared context 57 | startData := ctx.Value("start_result") 58 | prepMsg := fmt.Sprintf("Preparing to end workflow, received: %v", startData) 59 | log.Println(prepMsg) 60 | return prepMsg, nil // Prep result passed to Exec 61 | }). 62 | SetExec(func(ctx *pf.PfContext, params map[string]any, prepResult any) (any, error) { 63 | prepMsg := prepResult.(string) // Assume prep result is string 64 | log.Printf("Ending workflow with: %s\n", prepMsg) 65 | // End nodes often don't need to return data 66 | return nil, nil 67 | }) 68 | // Default Post (returns DefaultAction) is fine here 69 | } 70 | 71 | func main() { 72 | // Create instances of your nodes 73 | startNode := myStartNode() 74 | endNode := myEndNode() 75 | 76 | // Connect the nodes: start -> end (when action is "started") 77 | startNode.Next("started", endNode) 78 | 79 | // Create a flow with the start node 80 | flow := pf.NewFlow(startNode) 81 | 82 | // Create a shared context and run the flow 83 | sharedCtx := pf.WithParam(context.Background(), nil) 84 | log.Println("Executing workflow...") 85 | finalAction, err := flow.Run(sharedCtx) 86 | if err != nil { 87 | log.Fatalf("Workflow failed: %v\n", err) 88 | } 89 | 90 | log.Printf("Workflow completed successfully. Final action: %s\n", finalAction) 91 | log.Printf("Start result in shared context: %v\n", sharedCtx.Value("start_result")) 92 | } 93 | ``` 94 | 95 | ## Development 96 | 97 | ### Building the Project 98 | 99 | ```bash 100 | go build ./... 101 | ``` 102 | 103 | ### Running Tests 104 | 105 | ```bash 106 | go test ./... 107 | ``` 108 | Or with coverage: 109 | ```bash 110 | go test -coverprofile=coverage.out ./... && go tool cover -html=coverage.out 111 | ``` 112 | 113 | ## Contributing 114 | 115 | Contributions are welcome! We're particularly looking for volunteers to: 116 | 117 | 1. Implement asynchronous operation support (e.g., using goroutines, channels, `context.Context`).
118 | 2. Add more comprehensive test coverage, including edge cases and error handling. 119 | 3. Improve documentation and provide more complex examples (e.g., LLM integration stubs). 120 | 4. Refine the API for better Go idiomatic usage if applicable. 121 | 122 | Please feel free to submit pull requests or open issues for discussion. 123 | 124 | ## License 125 | 126 | [MIT License](LICENSE) -------------------------------------------------------------------------------- /pocketflow_test.go: -------------------------------------------------------------------------------- 1 | package pocketflow_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | // "strconv" // Removed unused import 8 | "testing" 9 | "time" // Added missing import 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | 14 | pf "github.com/The-Pocket/PocketFlow-Go" // Assuming module path 15 | ) 16 | 17 | // --- Test Node Implementations using Functional Style --- 18 | 19 | // setNumberNode creates a node that sets a number in the context. 20 | func setNumberNode(number int) pf.BaseNode { 21 | n := pf.NewNode(). 22 | SetExec(func(ctx *pf.PfContext, params map[string]any, prepResult any) (any, error) { 23 | multiplier := 1 24 | if m, ok := params["multiplier"].(int); ok { 25 | multiplier = m 26 | } 27 | return number * multiplier, nil // Exec result is the number 28 | }). 29 | SetPost(func(ctx *pf.PfContext, params map[string]any, prepResult any, execResult any) (string, error) { 30 | num := execResult.(int) // Assume execResult is int 31 | ctx.SetValue("currentValue", num) 32 | if num > 20 { 33 | return "over_20", nil 34 | } 35 | return pf.DefaultAction, nil 36 | }) 37 | return n 38 | } 39 | 40 | // addNumberNode creates a node that adds a number based on context. 41 | func addNumberNode(numberToAdd int) pf.BaseNode { 42 | n := pf.NewNode().
43 | SetPrep(func(ctx *pf.PfContext, params map[string]any) (any, error) { 44 | // Original (map-based context): current, ok := ctx["currentValue"].(int) 45 | current, ok := ctx.Value("currentValue").(int) 46 | if !ok { 47 | return nil, fmt.Errorf("currentValue not found or not an int in context") 48 | } 49 | return current, nil // Prep result is the current value 50 | }). 51 | SetExec(func(ctx *pf.PfContext, params map[string]any, prepResult any) (any, error) { 52 | current := prepResult.(int) // Assume prepResult is int 53 | return current + numberToAdd, nil 54 | }). 55 | SetPost(func(ctx *pf.PfContext, params map[string]any, prepResult any, execResult any) (string, error) { 56 | num := execResult.(int) // Assume execResult is int 57 | // Original (map-based context): ctx["currentValue"] = num 58 | ctx.SetValue("currentValue", num) 59 | return "added", nil // Action to trigger next node 60 | }) 61 | return n 62 | } 63 | 64 | // resultCaptureNode creates a node that captures the context value into its own params. 65 | func resultCaptureNode() pf.BaseNode { 66 | n := pf.NewNode(). 67 | SetPrep(func(ctx *pf.PfContext, params map[string]any) (any, error) { 68 | // Original (map-based context): val, ok := ctx["currentValue"] 69 | val := ctx.Value("currentValue") 70 | if val == nil { 71 | // Provide a default if not found, mirroring Java test 72 | return -999, nil 73 | } 74 | // Ensure the value is an int before returning 75 | intVal, ok := val.(int) 76 | if !ok { 77 | return -999, fmt.Errorf("currentValue was not an int: %T", val) 78 | } 79 | return intVal, nil 80 | }). 81 | SetExec(func(ctx *pf.PfContext, params map[string]any, prepResult any) (any, error) { 82 | capturedVal := prepResult.(int) 83 | params["capturedValue"] = capturedVal // Store in node's *own* params 84 | return nil, nil // No meaningful exec result needed 85 | }) 86 | // Default Post is sufficient (returns DefaultAction) 87 | return n 88 | } 89 | 90 | // simpleLogNode creates a node for BatchFlow testing. 91 | func simpleLogNode() pf.BaseNode { 92 | n := pf.NewNode().
93 | SetExec(func(ctx *pf.PfContext, params map[string]any, prepResult any) (any, error) { 94 | multi := params["multiplier"] // Get multiplier from params set by BatchFlow 95 | message := fmt.Sprintf("SimpleLogNode executed with multiplier: %v", multi) 96 | return message, nil 97 | }). 98 | SetPost(func(ctx *pf.PfContext, params map[string]any, prepResult any, execResult any) (string, error) { 99 | message := execResult.(string) 100 | key := fmt.Sprintf("last_message_from_batch_%v", params["multiplier"]) 101 | // Original (map-based context): ctx[key] = message 102 | ctx.SetValue(key, message) 103 | return pf.DefaultAction, nil 104 | }) 105 | return n 106 | } 107 | 108 | // --- Test Methods --- 109 | 110 | func TestSimpleLinearFlow(t *testing.T) { 111 | start := setNumberNode(10) 112 | add := addNumberNode(5) 113 | capture := resultCaptureNode() 114 | 115 | // Connect nodes: start -> add (on default) -> capture (on "added") 116 | start.Next(pf.DefaultAction, add).Next("added", capture) 117 | 118 | flow := pf.NewFlow(start) 119 | sharedContext := pf.WithParam(context.Background(), nil) 120 | 121 | lastAction, err := flow.Run(sharedContext) 122 | require.NoError(t, err) 123 | 124 | // Capture node is the last one, its default post returns "default" 125 | assert.Equal(t, pf.DefaultAction, lastAction) // Flow's post returns last node's action 126 | // Original (map-based context): assert.Equal(t, 15, sharedContext["currentValue"]) 127 | assert.Equal(t, 15, sharedContext.Value("currentValue")) 128 | 129 | // Check the captured value in the capture node's *own* parameters 130 | captureParams := capture.GetParams() 131 | assert.Equal(t, 15, captureParams["capturedValue"]) 132 | } 133 | 134 | func TestBranchingFlow(t *testing.T) { 135 | start := setNumberNode(10) 136 | add := addNumberNode(5) 137 | captureDefault := resultCaptureNode() 138 | captureOver20 := resultCaptureNode() 139 | 140 | // Connections: 141 | // start -> add (on default) -> captureDefault (on "added") 142 | // start -> captureOver20 (on "over_20")
143 | start.Next(pf.DefaultAction, add).Next("added", captureDefault) 144 | start.Next("over_20", captureOver20) 145 | 146 | flow := pf.NewFlow(start) 147 | sharedContext := pf.WithParam(context.Background(), nil) 148 | 149 | // Set parameters on the flow, which will be passed to the start node 150 | flow.SetParams(map[string]any{"multiplier": 3}) 151 | 152 | lastAction, err := flow.Run(sharedContext) 153 | require.NoError(t, err) 154 | 155 | // The flow should take the "over_20" branch to captureOver20, which returns "default" 156 | assert.Equal(t, pf.DefaultAction, lastAction) 157 | // Original (map-based context): assert.Equal(t, 30, sharedContext["currentValue"]) 158 | assert.Equal(t, 30, sharedContext.Value("currentValue")) 159 | 160 | // Check the correct capture node got the value 161 | captureOver20Params := captureOver20.GetParams() 162 | captureDefaultParams := captureDefault.GetParams() 163 | 164 | assert.Equal(t, 30, captureOver20Params["capturedValue"]) 165 | _, existsDefault := captureDefaultParams["capturedValue"] 166 | assert.False(t, existsDefault, "captureDefault should not have captured a value") 167 | // Check default value wasn't accidentally set if GetParams() returns nil map initially 168 | if defaultVal, ok := captureDefaultParams["capturedValue"]; ok { 169 | assert.NotEqual(t, -999, defaultVal, "Default prep value should not be in params") 170 | } 171 | 172 | } 173 | 174 | func TestBatchFlowExecution(t *testing.T) { 175 | batchFlow := pf.NewBatchFlow(simpleLogNode()) // Start node logs based on params 176 | 177 | batchFlow.SetPrepBatch(func(ctx *pf.PfContext, params map[string]any) ([]map[string]any, error) { 178 | // Generate parameter sets for each batch run 179 | return []map[string]any{ 180 | {"multiplier": 2}, 181 | {"multiplier": 4}, 182 | }, nil 183 | }) 184 | 185 | batchFlow.SetPostBatch(func(ctx *pf.PfContext, params map[string]any, batchPrepResult []map[string]any) (string, error) { 186 | // Original (map-based context): ctx["postBatchCalled"] = true 187 | ctx.SetValue("postBatchCalled", true)
188 | assert.Len(t, batchPrepResult, 2, "PostBatch should receive the original prep result") 189 | return "batch_complete", nil 190 | }) 191 | 192 | batchContext := pf.WithParam(context.Background(), nil) 193 | resultAction, err := batchFlow.Run(batchContext) 194 | require.NoError(t, err) 195 | 196 | assert.Equal(t, "batch_complete", resultAction) 197 | assert.True(t, batchContext.Value("postBatchCalled").(bool)) 198 | 199 | // Check that the log messages were stored in the shared context by the simpleLogNode's PostFunc 200 | assert.Equal(t, "SimpleLogNode executed with multiplier: 2", batchContext.Value("last_message_from_batch_2")) 201 | assert.Equal(t, "SimpleLogNode executed with multiplier: 4", batchContext.Value("last_message_from_batch_4")) 202 | } 203 | 204 | // --- Additional Tests --- 205 | 206 | func TestNodeRetrySuccess(t *testing.T) { 207 | execCount := 0 208 | node := pf.NewNode(). 209 | SetRetry(3, 1*time.Millisecond). // Use time.Millisecond 210 | SetExec(func(ctx *pf.PfContext, params map[string]any, prepResult any) (any, error) { 211 | execCount++ 212 | if execCount < 3 { 213 | return nil, fmt.Errorf("temporary failure %d", execCount) 214 | } 215 | return "success", nil // Succeeds on 3rd try 216 | }) 217 | ctx := pf.WithParam(context.Background(), nil) 218 | 219 | _, err := node.Run(ctx) 220 | require.NoError(t, err) 221 | assert.Equal(t, 3, execCount, "Exec should have been called 3 times") 222 | } 223 | 224 | func TestNodeRetryFailureWithFallback(t *testing.T) { 225 | execCount := 0 226 | fallbackCalled := false 227 | node := pf.NewNode(). 228 | SetRetry(2, 1*time.Millisecond). // Use time.Millisecond 229 | SetExec(func(ctx *pf.PfContext, params map[string]any, prepResult any) (any, error) { 230 | execCount++ 231 | return nil, fmt.Errorf("permanent failure %d", execCount) // Always fail 232 | }).
233 | SetFallback(func(ctx *pf.PfContext, params map[string]any, prepResult any, lastErr error) (any, error) { 234 | fallbackCalled = true 235 | assert.ErrorContains(t, lastErr, "permanent failure 2") 236 | return "fallback_success", nil // Fallback succeeds 237 | }) 238 | ctx := pf.WithParam(context.Background(), nil) 239 | 240 | action, err := node.Run(ctx) 241 | require.NoError(t, err) 242 | assert.Equal(t, 2, execCount, "Exec should have been called 2 times") 243 | assert.True(t, fallbackCalled, "Fallback should have been called") 244 | assert.Equal(t, pf.DefaultAction, action) // Default post action 245 | } 246 | 247 | func TestNodeRetryFailureWithoutFallback(t *testing.T) { 248 | execCount := 0 249 | node := pf.NewNode(). 250 | SetRetry(2, 1*time.Millisecond). // Use time.Millisecond 251 | SetExec(func(ctx *pf.PfContext, params map[string]any, prepResult any) (any, error) { 252 | execCount++ 253 | return nil, fmt.Errorf("permanent failure %d", execCount) // Always fail 254 | }) 255 | // No fallback set 256 | ctx := pf.WithParam(context.Background(), nil) 257 | 258 | _, err := node.Run(ctx) 259 | require.Error(t, err) 260 | assert.ErrorContains(t, err, "Exec phase failed") 261 | assert.ErrorContains(t, err, "permanent failure 2") // Check cause 262 | assert.Equal(t, 2, execCount, "Exec should have been called 2 times") 263 | } 264 | 265 | func TestBatchNodeItemRetryAndFallback(t *testing.T) { 266 | itemExecCounts := make(map[string]int) 267 | itemFallbackCalled := make(map[string]bool) 268 | 269 | bnode := pf.NewBatchNode(). 270 | SetRetry(3, 1*time.Millisecond). // Use time.Millisecond - Retries per item 271 | SetPrep(func(ctx *pf.PfContext, params map[string]any) ([]any, error) { 272 | return []any{"ok", "fail_once", "fail_always"}, nil 273 | }).
274 | SetExecItem(func(ctx *pf.PfContext, params map[string]any, item any) (any, error) { 275 | key := item.(string) 276 | itemExecCounts[key]++ 277 | switch key { 278 | case "ok": 279 | return "OK_RES", nil 280 | case "fail_once": 281 | if itemExecCounts[key] < 2 { 282 | return nil, fmt.Errorf("temp fail %s", key) 283 | } 284 | return "FAIL_ONCE_RES", nil // Success on retry 285 | case "fail_always": 286 | return nil, fmt.Errorf("perm fail %s", key) // Always fail 287 | } 288 | return nil, fmt.Errorf("unknown item") 289 | }). 290 | SetItemFallback(func(ctx *pf.PfContext, params map[string]any, item any, lastErr error) (any, error) { 291 | key := item.(string) 292 | if key == "fail_always" { 293 | itemFallbackCalled[key] = true 294 | assert.ErrorContains(t, lastErr, "perm fail fail_always") 295 | return "FAIL_ALWAYS_FALLBACK_RES", nil // Fallback success 296 | } 297 | // Fallback should not be called for others 298 | return nil, fmt.Errorf("unexpected fallback for %s", key) 299 | }). 300 | SetPost(func(ctx *pf.PfContext, params map[string]any, prepResult []any, execResult []any) (string, error) { 301 | // Store results in context for assertion 302 | ctx.SetValue("results", execResult) 303 | return "batch_done", nil 304 | }) 305 | 306 | ctx := pf.WithParam(context.Background(), nil) 307 | //ctx := context.Background() 308 | 309 | action, err := bnode.Run(ctx) 310 | 311 | require.NoError(t, err) 312 | assert.Equal(t, "batch_done", action) 313 | 314 | // Check execution counts 315 | assert.Equal(t, 1, itemExecCounts["ok"]) 316 | assert.Equal(t, 2, itemExecCounts["fail_once"]) 317 | assert.Equal(t, 3, itemExecCounts["fail_always"]) // All retries used 318 | 319 | // Check fallback calls 320 | assert.False(t, itemFallbackCalled["ok"]) 321 | assert.False(t, itemFallbackCalled["fail_once"]) 322 | assert.True(t, itemFallbackCalled["fail_always"]) 323 | 324 | // Check final results passed to Post 325 | results := ctx.Value("results").([]any) 326 | require.Len(t, results, 3)
327 | assert.Equal(t, "OK_RES", results[0]) 328 | assert.Equal(t, "FAIL_ONCE_RES", results[1]) 329 | assert.Equal(t, "FAIL_ALWAYS_FALLBACK_RES", results[2]) 330 | } 331 | -------------------------------------------------------------------------------- /pocketflow.go: -------------------------------------------------------------------------------- 1 | package pocketflow 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "time" 8 | ) 9 | // PfContext is a context.Context that also carries a key/value store shared by the nodes of a flow. 10 | type PfContext struct { 11 | context.Context 12 | param map[any]any 13 | } 14 | // Value returns the value stored via SetValue for key, falling back to the embedded Context (nil if there is none). 15 | func (c *PfContext) Value(key any) any { 16 | if v, ok := c.param[key]; ok { // indexing a nil map is safe, so reads never allocate or mutate c 17 | return v 18 | } 19 | if c.Context == nil { 20 | return nil // zero-value PfContext: nothing to fall back to 21 | } 22 | return c.Context.Value(key) 23 | } 24 | // SetValue stores value under key in this PfContext's own map, shadowing the embedded Context for that key. 25 | func (c *PfContext) SetValue(key any, value any) { 26 | if c.param == nil { 27 | c.param = make(map[any]any) 28 | } 29 | c.param[key] = value 30 | } 31 | // WithParam wraps parent in a PfContext whose store starts as param (used directly, not copied); a nil parent becomes context.Background(). 32 | func WithParam(parent context.Context, param map[any]any) *PfContext { 33 | // A nil parent would make the promoted Done/Err/Deadline methods and the Value fallback panic. 34 | if parent == nil { 35 | parent = context.Background() 36 | } 37 | return &PfContext{Context: parent, param: param} 38 | } 39 | 40 | // DefaultAction is the action name used when a node's Post returns an empty string. 41 | const DefaultAction = "default" 42 | 43 | // PocketFlowError represents an error originating from the PocketFlow library. 44 | type PocketFlowError struct { 45 | Message string 46 | Cause error 47 | } 48 | // Error returns "PocketFlow error: <Message>", followed by the cause when one is set. 49 | func (e *PocketFlowError) Error() string { 50 | if e.Cause != nil { 51 | // Include the cause so a single message shows the whole chain. 52 | return fmt.Sprintf("PocketFlow error: %s (caused by: %v)", e.Message, e.Cause) 53 | } 54 | return fmt.Sprintf("PocketFlow error: %s", e.Message) 55 | } 56 | 57 | // Unwrap allows PocketFlowError to work with errors.Is and errors.As.
58 | func (e *PocketFlowError) Unwrap() error { 59 | return e.Cause 60 | } 61 | // newPocketFlowError wraps cause with msg in a *PocketFlowError, returned as an error. 62 | func newPocketFlowError(msg string, cause error) error { 63 | return &PocketFlowError{Message: msg, Cause: cause} 64 | } 65 | // logWarn logs format/v through the standard log package with a "WARN: PocketFlow - " prefix. 66 | func logWarn(format string, v ...any) { 67 | log.Printf("WARN: PocketFlow - "+format, v...) 68 | } 69 | 70 | // --- Base Node --- 71 | 72 | // BaseNode defines the interface for all nodes in a workflow. 73 | type BaseNode interface { 74 | // Prep prepares input for Exec using the shared context. 75 | // Returns the prepared data (can be nil) and an error. 76 | Prep(ctx *PfContext) (any, error) 77 | 78 | // Exec performs the main work using the result from Prep. 79 | // Returns the execution result (can be nil) and an error. 80 | Exec(ctx *PfContext, prepResult any) (any, error) 81 | 82 | // Post processes results, updates context, and returns the next action string. 83 | // An empty string implies DefaultAction. Returns an error if post-processing fails. 84 | Post(ctx *PfContext, prepResult any, execResult any) (string, error) 85 | 86 | // SetParams sets node-specific parameters. Returns the node for chaining. 87 | SetParams(params map[string]any) BaseNode 88 | 89 | // GetParams returns the node's current parameters. 90 | GetParams() map[string]any 91 | 92 | // Next connects this node to another node for a specific action. Returns the *next* node for chaining. 93 | Next(action string, node BaseNode) BaseNode 94 | 95 | // GetSuccessors returns the map of action->node successors. 96 | GetSuccessors() map[string]BaseNode 97 | 98 | // GetNextNode retrieves the successor node for a given action (or DefaultAction). 99 | GetNextNode(action string) BaseNode 100 | 101 | // Run executes a single node's lifecycle (prep, exec, post). Useful for standalone execution. 102 | // Returns the resulting action and error. 103 | Run(ctx *PfContext) (string, error) 104 | 105 | // InternalRun is used by Flow orchestration to execute the node lifecycle.
106 | // Separated from Run to prevent potential issues if Run is overridden incorrectly. 107 | InternalRun(ctx *PfContext) (string, error) 108 | } 109 | 110 | // --- Common Node Implementation Details --- 111 | // nodeCore holds the params and action->successor links shared by the node types that embed it. 112 | type nodeCore struct { 113 | params map[string]any 114 | successors map[string]BaseNode 115 | } 116 | // initCore allocates whichever of the two maps is still nil. 117 | func (n *nodeCore) initCore() { 118 | if n.params == nil { 119 | n.params = make(map[string]any) 120 | } 121 | if n.successors == nil { 122 | n.successors = make(map[string]BaseNode) 123 | } 124 | } 125 | // SetParams replaces the params with a shallow copy of params; nil yields an empty map. 126 | func (n *nodeCore) SetParams(params map[string]any) { 127 | n.initCore() 128 | if params != nil { 129 | // Copy so later changes to the caller's map do not leak into the node; 130 | // a manual loop (rather than maps.Clone, Go 1.21+) keeps the module buildable with Go 1.18. 131 | n.params = make(map[string]any, len(params)) 132 | for k, v := range params { 133 | n.params[k] = v 134 | } 135 | } else { 136 | n.params = make(map[string]any) 137 | } 138 | } 139 | // GetParams returns the node's live params map (not a copy). 140 | func (n *nodeCore) GetParams() map[string]any { 141 | n.initCore() 142 | // Deliberately uncopied: writes through the returned map, or by callbacks handed the same map, are shared.
143 | return n.params 144 | } 145 | // Next registers node as the successor for action ("" means DefaultAction) and returns node for chaining; it panics if node is nil. 146 | func (n *nodeCore) Next(action string, node BaseNode) BaseNode { 147 | n.initCore() 148 | if node == nil { 149 | panic("Successor node cannot be nil") // Panic mirrors Java's NullPointerException 150 | } 151 | if action == "" { 152 | action = DefaultAction 153 | } 154 | if _, exists := n.successors[action]; exists { 155 | logWarn("Overwriting successor for action '%s' in node %p", action, n) // n is the embedded *nodeCore, so %T always printed *pocketflow.nodeCore; %p shows the node's address (nodeCore is the first field of Node and BatchNode) 156 | } 157 | n.successors[action] = node 158 | return node // Return the next node for chaining 159 | } 160 | // GetSuccessors returns the live action->successor map. 161 | func (n *nodeCore) GetSuccessors() map[string]BaseNode { 162 | n.initCore() 163 | return n.successors 164 | } 165 | // GetNextNode returns the successor registered for action ("" means DefaultAction), or nil; it warns when the action is missing but other successors exist. 166 | func (n *nodeCore) GetNextNode(action string) BaseNode { 167 | n.initCore() 168 | if action == "" { 169 | action = DefaultAction 170 | } 171 | nextNode, exists := n.successors[action] 172 | if !exists && len(n.successors) > 0 { 173 | keys := make([]string, 0, len(n.successors)) 174 | for k := range n.successors { 175 | keys = append(keys, k) 176 | } 177 | logWarn("Flow might end: Action '%s' not found in successors %v of node %p", action, keys, n) // %p: see Next 178 | } 179 | return nextNode 180 | } 181 | 182 | // --- Standard Node (with Retry) --- 183 | 184 | // Node implements BaseNode with retry logic.
185 | type Node struct { 186 | nodeCore 187 | MaxRetries int // total Exec attempts per run (>= 1 when set via SetRetry) 188 | WaitMilliseconds time.Duration // pause between attempts; despite the name this is a time.Duration 189 | 190 | // Phase callbacks; each is handed the node's params map. 191 | PrepFunc func(ctx *PfContext, params map[string]any) (any, error) 192 | ExecFunc func(ctx *PfContext, params map[string]any, prepResult any) (any, error) 193 | PostFunc func(ctx *PfContext, params map[string]any, prepResult any, execResult any) (string, error) 194 | 195 | // ExecFallbackFunc, when non-nil, produces the exec result after every attempt has failed. 196 | ExecFallbackFunc func(ctx *PfContext, params map[string]any, prepResult any, lastErr error) (any, error) 197 | } 198 | 199 | // NewNode returns a Node that makes a single attempt with no wait, has no-op Prep/Exec, and whose Post returns DefaultAction. 200 | func NewNode() *Node { 201 | // Start from the zero value, then install the defaults one field at a time. 202 | n := new(Node) 203 | n.MaxRetries = 1 204 | n.WaitMilliseconds = 0 205 | n.PrepFunc = func(ctx *PfContext, params map[string]any) (any, error) { return nil, nil } 206 | n.ExecFunc = func(ctx *PfContext, params map[string]any, prepResult any) (any, error) { return nil, nil } 207 | n.PostFunc = func(ctx *PfContext, params map[string]any, prepResult any, execResult any) (string, error) { 208 | return DefaultAction, nil 209 | } 210 | // initCore allocates the params and successors maps. 211 | n.initCore() 212 | return n 213 | } 214 | 215 | // SetRetry sets how many times Exec is attempted and the pause between attempts; it panics on maxRetries < 1 or a negative wait. 216 | func (n *Node) SetRetry(maxRetries int, waitMilliseconds time.Duration) *Node { 217 | switch { 218 | case maxRetries < 1: 219 | panic("maxRetries must be at least 1") 220 | case waitMilliseconds < 0: 221 | panic("waitMilliseconds cannot be negative") 222 | } 223 | // Both values are valid; store them together. 224 | n.MaxRetries, n.WaitMilliseconds = maxRetries, waitMilliseconds 225 | return n 226 | } 227 | 228 | // SetPrep installs f as the Prep callback and returns n so calls can be chained. 229 | func (n *Node) SetPrep(f func(ctx *PfContext, params map[string]any) (any, error)) *Node { 230 | n.PrepFunc = f 231 | return n 232 | } 233 | 234 | // SetExec sets the ExecFunc.
235 | func (n *Node) SetExec(f func(ctx *PfContext, params map[string]any, prepResult any) (any, error)) *Node { 236 | n.ExecFunc = f 237 | return n 238 | } 239 | 240 | // SetPost sets the PostFunc. 241 | func (n *Node) SetPost(f func(ctx *PfContext, params map[string]any, prepResult any, execResult any) (string, error)) *Node { 242 | n.PostFunc = f 243 | return n 244 | } 245 | 246 | // SetFallback sets the ExecFallbackFunc. 247 | func (n *Node) SetFallback(f func(ctx *PfContext, params map[string]any, prepResult any, lastErr error) (any, error)) *Node { 248 | n.ExecFallbackFunc = f 249 | return n 250 | } 251 | 252 | // --- BaseNode Implementation for Node --- 253 | 254 | func (n *Node) SetParams(params map[string]any) BaseNode { 255 | n.nodeCore.SetParams(params) 256 | return n 257 | } 258 | 259 | func (n *Node) Next(action string, node BaseNode) BaseNode { 260 | return n.nodeCore.Next(action, node) 261 | } 262 | 263 | func (n *Node) Prep(ctx *PfContext) (any, error) { 264 | if n.PrepFunc == nil { 265 | return nil, nil // Default behavior 266 | } 267 | return n.PrepFunc(ctx, n.params) 268 | } 269 | 270 | func (n *Node) Exec(ctx *PfContext, prepResult any) (any, error) { 271 | // This is the public Exec, usually called via InternalRun which handles retry 272 | if n.ExecFunc == nil { 273 | return nil, nil 274 | } 275 | return n.ExecFunc(ctx, n.params, prepResult) 276 | } 277 | 278 | func (n *Node) Post(ctx *PfContext, prepResult any, execResult any) (string, error) { 279 | if n.PostFunc == nil { 280 | return DefaultAction, nil 281 | } 282 | action, err := n.PostFunc(ctx, n.params, prepResult, execResult) 283 | if err == nil && action == "" { 284 | action = DefaultAction 285 | } 286 | return action, err 287 | } 288 | 289 | func (n *Node) Run(ctx *PfContext) (string, error) { 290 | if len(n.successors) > 0 { 291 | logWarn("Node %T has successors, but Run() was called directly. Successors won't be executed by this call. 
Use Flow.Run() for orchestration.", n) 292 | } 293 | return n.InternalRun(ctx) 294 | } 295 | 296 | func (n *Node) InternalRun(ctx *PfContext) (string, error) { 297 | prepRes, err := n.Prep(ctx) 298 | if err != nil { 299 | return "", newPocketFlowError(fmt.Sprintf("Prep phase failed in %T", n), err) 300 | } 301 | 302 | var execRes any 303 | var lastExecErr error 304 | currentRetry := 0 305 | 306 | for currentRetry = 0; currentRetry < n.MaxRetries; currentRetry++ { 307 | execRes, lastExecErr = n.Exec(ctx, prepRes) // Call the non-retry Exec 308 | if lastExecErr == nil { 309 | break // Success 310 | } 311 | if currentRetry < n.MaxRetries-1 && n.WaitMilliseconds > 0 { 312 | time.Sleep(n.WaitMilliseconds) 313 | } 314 | } 315 | 316 | // If all retries failed 317 | if lastExecErr != nil { 318 | if n.ExecFallbackFunc != nil { 319 | execRes, err = n.ExecFallbackFunc(ctx, n.params, prepRes, lastExecErr) 320 | if err != nil { 321 | // Wrap the fallback error, potentially including the original execution error 322 | return "", newPocketFlowError(fmt.Sprintf("ExecFallback phase failed in %T after %d retries", n, n.MaxRetries), err) 323 | } 324 | lastExecErr = nil // Fallback succeeded, clear the error 325 | } else { 326 | // No fallback, return the last execution error 327 | return "", newPocketFlowError(fmt.Sprintf("Exec phase failed in %T after %d retries", n, n.MaxRetries), lastExecErr) 328 | } 329 | } 330 | 331 | // Post phase 332 | action, err := n.Post(ctx, prepRes, execRes) 333 | if err != nil { 334 | return "", newPocketFlowError(fmt.Sprintf("Post phase failed in %T", n), err) 335 | } 336 | 337 | return action, nil 338 | } 339 | 340 | // --- Batch Node (Processes items individually) --- 341 | 342 | // BatchNode implements BaseNode to process slices of items. 
343 | type BatchNode struct { 344 | nodeCore 345 | MaxRetries int 346 | WaitMilliseconds time.Duration 347 | 348 | // User-defined functions 349 | // Prep returns a slice (or error) 350 | PrepFunc func(ctx *PfContext, params map[string]any) ([]any, error) 351 | // ExecItem operates on a single item from the Prep slice 352 | ExecItemFunc func(ctx *PfContext, params map[string]any, item any) (any, error) 353 | // Post receives the original prep slice and the slice of exec results 354 | PostFunc func(ctx *PfContext, params map[string]any, prepResult []any, execResult []any) (string, error) 355 | 356 | // Optional fallback for individual item processing 357 | ExecItemFallbackFunc func(ctx *PfContext, params map[string]any, item any, lastErr error) (any, error) 358 | } 359 | 360 | // NewBatchNode creates a new BatchNode with default settings. 361 | func NewBatchNode() *BatchNode { 362 | bn := &BatchNode{ 363 | MaxRetries: 1, 364 | WaitMilliseconds: 0, 365 | PrepFunc: func(ctx *PfContext, params map[string]any) ([]any, error) { return nil, nil }, 366 | ExecItemFunc: func(ctx *PfContext, params map[string]any, item any) (any, error) { return item, nil }, // Default pass-through 367 | PostFunc: func(ctx *PfContext, params map[string]any, prepResult []any, execResult []any) (string, error) { 368 | return DefaultAction, nil 369 | }, 370 | } 371 | bn.initCore() 372 | return bn 373 | } 374 | 375 | // SetRetry configures retry behaviour. 376 | func (bn *BatchNode) SetRetry(maxRetries int, waitMilliseconds time.Duration) *BatchNode { 377 | if maxRetries < 1 { 378 | panic("maxRetries must be at least 1") 379 | } 380 | if waitMilliseconds < 0 { 381 | panic("waitMilliseconds cannot be negative") 382 | } 383 | bn.MaxRetries = maxRetries 384 | bn.WaitMilliseconds = waitMilliseconds 385 | return bn 386 | } 387 | 388 | // SetPrep sets the PrepFunc. Expects a function returning []any. 
389 | func (bn *BatchNode) SetPrep(f func(ctx *PfContext, params map[string]any) ([]any, error)) *BatchNode { 390 | bn.PrepFunc = f 391 | return bn 392 | } 393 | 394 | // SetExecItem sets the ExecItemFunc for processing individual items. 395 | func (bn *BatchNode) SetExecItem(f func(ctx *PfContext, params map[string]any, item any) (any, error)) *BatchNode { 396 | bn.ExecItemFunc = f 397 | return bn 398 | } 399 | 400 | // SetPost sets the PostFunc. Receives []any prep and []any exec results. 401 | func (bn *BatchNode) SetPost(f func(ctx *PfContext, params map[string]any, prepResult []any, execResult []any) (string, error)) *BatchNode { 402 | bn.PostFunc = f 403 | return bn 404 | } 405 | 406 | // SetItemFallback sets the ExecItemFallbackFunc. 407 | func (bn *BatchNode) SetItemFallback(f func(ctx *PfContext, params map[string]any, item any, lastErr error) (any, error)) *BatchNode { 408 | bn.ExecItemFallbackFunc = f 409 | return bn 410 | } 411 | 412 | // --- BaseNode Implementation for BatchNode --- 413 | 414 | func (bn *BatchNode) SetParams(params map[string]any) BaseNode { 415 | bn.nodeCore.SetParams(params) 416 | return bn 417 | } 418 | 419 | func (bn *BatchNode) Next(action string, node BaseNode) BaseNode { 420 | return bn.nodeCore.Next(action, node) 421 | } 422 | 423 | // Prep calls the user-defined PrepFunc. 424 | func (bn *BatchNode) Prep(ctx *PfContext) (any, error) { 425 | if bn.PrepFunc == nil { 426 | return nil, nil 427 | } 428 | // Prep returns the slice directly (as 'any') 429 | return bn.PrepFunc(ctx, bn.params) 430 | } 431 | 432 | // Exec iterates through the prepResult slice, calling ExecItemFunc for each item with retries. 
433 | func (bn *BatchNode) Exec(ctx *PfContext, prepResult any) (any, error) { 434 | if prepResult == nil { 435 | return []any{}, nil // Return empty slice if prep was nil 436 | } 437 | 438 | // Type assertion to get the slice from Prep result 439 | items, ok := prepResult.([]any) 440 | if !ok { 441 | return nil, newPocketFlowError(fmt.Sprintf("Prep phase of BatchNode %T did not return []any, got %T", bn, prepResult), nil) 442 | } 443 | 444 | if len(items) == 0 { 445 | return []any{}, nil // Return empty slice for empty input 446 | } 447 | 448 | results := make([]any, len(items)) 449 | var itemResult any 450 | var lastItemErr error 451 | currentRetry := 0 452 | 453 | for i, item := range items { 454 | lastItemErr = nil // Reset error for each item 455 | itemSuccess := false 456 | for currentRetry = 0; currentRetry < bn.MaxRetries; currentRetry++ { 457 | itemResult, lastItemErr = bn.ExecItemFunc(ctx, bn.params, item) 458 | if lastItemErr == nil { 459 | itemSuccess = true 460 | break // Success for this item 461 | } 462 | if currentRetry < bn.MaxRetries-1 && bn.WaitMilliseconds > 0 { 463 | time.Sleep(bn.WaitMilliseconds) 464 | } 465 | } 466 | 467 | // If all retries failed for this item 468 | if !itemSuccess { 469 | if bn.ExecItemFallbackFunc != nil { 470 | fallbackResult, fallbackErr := bn.ExecItemFallbackFunc(ctx, bn.params, item, lastItemErr) 471 | if fallbackErr != nil { 472 | // Fallback failed, return error for the whole batch 473 | return nil, newPocketFlowError(fmt.Sprintf("ExecItemFallback failed for item %d (%v) in %T after %d retries", i, item, bn, bn.MaxRetries), fallbackErr) 474 | } 475 | itemResult = fallbackResult // Use fallback result 476 | lastItemErr = nil // Mark as success via fallback 477 | } else { 478 | // No fallback, fail the whole batch 479 | return nil, newPocketFlowError(fmt.Sprintf("ExecItem failed for item %d (%v) in %T after %d retries", i, item, bn, bn.MaxRetries), lastItemErr) 480 | } 481 | } 482 | results[i] = itemResult 483 | } 484 
| 485 | return results, nil // Return the slice of results 486 | } 487 | 488 | // Post calls the user-defined PostFunc. 489 | func (bn *BatchNode) Post(ctx *PfContext, prepResult any, execResult any) (string, error) { 490 | // Type assertions needed as interface methods deal with 'any' 491 | prepSlice, okPrep := prepResult.([]any) 492 | if prepResult != nil && !okPrep { // Allow nil prepResult 493 | return "", newPocketFlowError(fmt.Sprintf("Internal error: prepResult in BatchNode %T Post was not []any (%T)", bn, prepResult), nil) 494 | } 495 | 496 | execSlice, okExec := execResult.([]any) 497 | if execResult != nil && !okExec { // Allow nil execResult (e.g., if prep was empty) 498 | return "", newPocketFlowError(fmt.Sprintf("Internal error: execResult in BatchNode %T Post was not []any (%T)", bn, execResult), nil) 499 | } 500 | // Ensure slices are not nil if they were originally nil/empty, matching Java behaviour somewhat 501 | if prepSlice == nil { 502 | prepSlice = []any{} 503 | } 504 | if execSlice == nil { 505 | execSlice = []any{} 506 | } 507 | 508 | if bn.PostFunc == nil { 509 | return DefaultAction, nil 510 | } 511 | action, err := bn.PostFunc(ctx, bn.params, prepSlice, execSlice) 512 | if err == nil && action == "" { 513 | action = DefaultAction 514 | } 515 | return action, err 516 | } 517 | 518 | func (bn *BatchNode) Run(ctx *PfContext) (string, error) { 519 | if len(bn.successors) > 0 { 520 | logWarn("Node %T has successors, but Run() was called directly. Successors won't be executed by this call. Use Flow.Run() for orchestration.", bn) 521 | } 522 | return bn.InternalRun(ctx) 523 | } 524 | 525 | // InternalRun implements the retry logic at the item level within Exec. 
526 | func (bn *BatchNode) InternalRun(ctx *PfContext) (string, error) { 527 | prepRes, err := bn.Prep(ctx) // prepRes should be []any 528 | if err != nil { 529 | return "", newPocketFlowError(fmt.Sprintf("Prep phase failed in %T", bn), err) 530 | } 531 | 532 | // Exec handles its own item-level retry/fallback 533 | execRes, err := bn.Exec(ctx, prepRes) // execRes should be []any 534 | if err != nil { 535 | // Error from Exec already includes context about retries/fallbacks 536 | return "", err // Don't wrap again 537 | } 538 | 539 | // Post phase 540 | action, err := bn.Post(ctx, prepRes, execRes) // Post expects []any, []any 541 | if err != nil { 542 | return "", newPocketFlowError(fmt.Sprintf("Post phase failed in %T", bn), err) 543 | } 544 | 545 | return action, nil 546 | } 547 | 548 | // --- Flow --- 549 | 550 | // Flow orchestrates the execution of connected nodes. 551 | type Flow struct { 552 | nodeCore // Flow itself can have params, though less common for successors here 553 | startNode BaseNode 554 | } 555 | 556 | // NewFlow creates a new Flow, optionally with a starting node. 557 | func NewFlow(startNode BaseNode) *Flow { 558 | f := &Flow{ 559 | startNode: startNode, 560 | } 561 | f.initCore() 562 | return f 563 | } 564 | 565 | // Start sets the initial node for the flow. Returns the start node for chaining setup. 566 | func (f *Flow) Start(node BaseNode) BaseNode { 567 | if node == nil { 568 | panic("Start node cannot be nil") 569 | } 570 | f.startNode = node 571 | return node 572 | } 573 | 574 | // --- BaseNode Implementation for Flow --- 575 | // Most BaseNode methods are less relevant for Flow itself, focused on orchestration. 576 | 577 | func (f *Flow) SetParams(params map[string]any) BaseNode { 578 | f.nodeCore.SetParams(params) 579 | return f 580 | } 581 | 582 | // Next for a Flow doesn't make logical sense in the standard execution model. 
583 | func (f *Flow) Next(action string, node BaseNode) BaseNode { 584 | logWarn("Calling Next() on a Flow is unusual. Successors set here are not used by standard Run() orchestration.") 585 | return f.nodeCore.Next(action, node) 586 | } 587 | 588 | // Prep for the Flow itself. Default is no-op. Can be overridden if needed. 589 | func (f *Flow) Prep(ctx *PfContext) (any, error) { 590 | // Typically Flow prep is about setting up the context before orchestration starts 591 | return nil, nil 592 | } 593 | 594 | // Exec for the Flow initiates the orchestration. Should not be called directly by user. 595 | func (f *Flow) Exec(ctx *PfContext, prepResult any) (any, error) { 596 | // This is called internally by InternalRun after Flow's Prep. 597 | // The 'prepResult' here is the result of Flow.Prep, not a node's prep. 598 | // The 'execResult' of a Flow is the final action string from orchestration. 599 | // We need the context here for orchestrate, assume prepResult is the context for simplicity 600 | // although Flow's Prep doesn't *have* to return the context. Let's pass ctx directly. 601 | // This requires changing the call site in InternalRun. 602 | sharedCtx, _ := prepResult.(map[string]any) 603 | finalAction, err := f.orchestrate(ctx, sharedCtx) // Run orchestration with the context 604 | if err != nil { 605 | return "", err // Return error, action is irrelevant if orchestration failed 606 | } 607 | return finalAction, nil // Return the final action as the result 608 | } 609 | 610 | // Post for the Flow runs after orchestration completes. Default returns the final action. 611 | func (f *Flow) Post(ctx *PfContext, prepResult any, execResult any) (string, error) { 612 | // prepResult is from Flow.Prep, execResult is the final action string from Exec/orchestrate. 613 | finalAction, _ := execResult.(string) // Ignore error, default to "" if cast fails 614 | if finalAction == "" { 615 | finalAction = DefaultAction // Or maybe keep it empty? Let's default. 
616 | } 617 | return finalAction, nil 618 | } 619 | 620 | // Run starts the flow execution. 621 | func (f *Flow) Run(ctx *PfContext) (string, error) { 622 | // Use InternalRun to perform the standard Flow lifecycle (Prep, Exec(orchestrate), Post) 623 | return f.InternalRun(ctx) 624 | } 625 | 626 | // InternalRun executes the flow's lifecycle: Prep, Orchestrate (via Exec), Post. 627 | func (f *Flow) InternalRun(ctx *PfContext) (string, error) { 628 | // 1. Run Flow's Prep phase 629 | flowPrepResult, err := f.Prep(ctx) 630 | if err != nil { 631 | return "", newPocketFlowError(fmt.Sprintf("Prep phase failed for Flow %T", f), err) 632 | } 633 | 634 | // 2. Run Flow's Exec phase (which triggers orchestration) 635 | // Pass the *original* shared context to Exec, as Exec now expects it. 636 | flowExecResult, err := f.Exec(ctx, flowPrepResult) // Exec calls orchestrate 637 | if err != nil { 638 | // Error likely came from a node within orchestrate 639 | return "", err // Don't wrap again, error should be informative 640 | } 641 | 642 | // 3. Run Flow's Post phase 643 | finalAction, err := f.Post(ctx, flowPrepResult, flowExecResult) 644 | if err != nil { 645 | return "", newPocketFlowError(fmt.Sprintf("Post phase failed for Flow %T", f), err) 646 | } 647 | 648 | return finalAction, nil 649 | } 650 | 651 | // orchestrate executes the node chain starting from startNode. 652 | // initialParams are merged with the flow's own params for the *first* node. 653 | // Returns the last action string and any error encountered. 
654 | func (f *Flow) orchestrate(ctx *PfContext, initialParams map[string]any) (string, error) { 655 | if f.startNode == nil { 656 | logWarn("Flow started with no start node.") 657 | return "", nil // No error, just nothing to run 658 | } 659 | 660 | currentNode := f.startNode 661 | lastAction := "" 662 | var err error 663 | 664 | // Prepare initial parameters for the first node run 665 | // Combine Flow's params and any specific initialParams for this run 666 | // Replace with manual copy loop: 667 | combinedParams := make(map[string]any, len(f.params)) 668 | for k, v := range f.params { 669 | combinedParams[k] = v 670 | } 671 | 672 | if initialParams != nil { 673 | // Replace with manual copy loop: 674 | for k, v := range initialParams { 675 | combinedParams[k] = v // Add or overwrite keys from initialParams 676 | } 677 | } 678 | 679 | for currentNode != nil { 680 | // Set the combined params *before* running the node 681 | // Only apply combinedParams on the *first* iteration 682 | if combinedParams != nil { 683 | currentNode.SetParams(combinedParams) 684 | combinedParams = nil // Clear after first use 685 | } else { 686 | // Ensure subsequent nodes get at least the Flow's base params if theirs are unset. 687 | if len(currentNode.GetParams()) == 0 && len(f.params) > 0 { 688 | currentNode.SetParams(f.params) // Give it the flow's base params if it has none 689 | } 690 | } 691 | 692 | // Execute the node's full lifecycle (Prep, Exec, Post) 693 | lastAction, err = currentNode.InternalRun(ctx) 694 | if err != nil { 695 | // Error occurred within the node's execution 696 | return "", err // Return the error immediately 697 | } 698 | 699 | // Get the next node based on the action returned by Post 700 | currentNode = currentNode.GetNextNode(lastAction) 701 | 702 | // Parameter propagation logic for subsequent nodes (revisit if needed) 703 | // The current logic sets params once at the start or uses node's existing/flow base. 
704 | } 705 | 706 | // Orchestration finished successfully, return the last action determined 707 | return lastAction, nil 708 | } 709 | 710 | // --- Batch Flow --- 711 | 712 | // BatchFlow runs the entire flow sequence for each parameter set generated by PrepBatch. 713 | type BatchFlow struct { 714 | Flow // Embed Flow to inherit its structure and orchestration logic 715 | 716 | // User-defined functions for batch behavior 717 | PrepBatchFunc func(ctx *PfContext, params map[string]any) ([]map[string]any, error) 718 | PostBatchFunc func(ctx *PfContext, params map[string]any, batchPrepResult []map[string]any) (string, error) 719 | } 720 | 721 | // NewBatchFlow creates a new BatchFlow. 722 | func NewBatchFlow(startNode BaseNode) *BatchFlow { 723 | bf := &BatchFlow{ 724 | Flow: Flow{ // Initialize embedded Flow 725 | startNode: startNode, 726 | }, 727 | // Provide sensible defaults? 728 | PrepBatchFunc: func(ctx *PfContext, params map[string]any) ([]map[string]any, error) { return nil, nil }, 729 | PostBatchFunc: func(ctx *PfContext, params map[string]any, batchPrepResult []map[string]any) (string, error) { 730 | return DefaultAction, nil 731 | }, 732 | } 733 | bf.initCore() // Initialize nodeCore for the BatchFlow itself 734 | bf.Flow.initCore() // Ensure embedded Flow's core is also initialized 735 | return bf 736 | } 737 | 738 | // SetPrepBatch sets the function to generate batch parameters. 739 | func (bf *BatchFlow) SetPrepBatch(f func(ctx *PfContext, params map[string]any) ([]map[string]any, error)) *BatchFlow { 740 | bf.PrepBatchFunc = f 741 | return bf 742 | } 743 | 744 | // SetPostBatch sets the function to run after all batches complete. 
745 | func (bf *BatchFlow) SetPostBatch(f func(ctx *PfContext, params map[string]any, batchPrepResult []map[string]any) (string, error)) *BatchFlow { 746 | bf.PostBatchFunc = f 747 | return bf 748 | } 749 | 750 | // --- BaseNode Implementation Overrides for BatchFlow --- 751 | 752 | // Prep for BatchFlow runs its PrepBatchFunc. 753 | func (bf *BatchFlow) Prep(ctx *PfContext) (any, error) { 754 | if bf.PrepBatchFunc == nil { 755 | return nil, nil 756 | } 757 | // Returns []*pfContext 758 | return bf.PrepBatchFunc(ctx, bf.params) 759 | } 760 | 761 | // Exec for BatchFlow runs the orchestration for each batch item. 762 | // The 'prepResult' here is the []*pfContext from BatchFlow.Prep. 763 | func (bf *BatchFlow) Exec(prepResult any) (any, error) { 764 | // We need the original context for the orchestrate calls. 765 | // InternalRun should pass it. For now, let's assume prepResult contains it implicitly 766 | // or redesign how context is passed through BatchFlow's Exec. 767 | // Safest: Assume InternalRun passes the context correctly and prepResult is the list. 768 | // Let's adjust the call site in InternalRun. 769 | 770 | batchParamsList, ok := prepResult.([]*PfContext) 771 | if prepResult != nil && !ok { 772 | return "", newPocketFlowError(fmt.Sprintf("Internal error: prepResult in BatchFlow %T Exec was not []*pfContext (%T)", bf, prepResult), nil) 773 | } 774 | if batchParamsList == nil { 775 | batchParamsList = []*PfContext{} 776 | } 777 | 778 | // We need the actual *pfContext. Where does it come from? 779 | // It should be passed *alongside* the prepResult by InternalRun. 780 | // Let's redefine Exec slightly to accept it, or rely on a field. 781 | // Simpler: Let InternalRun handle context passing to orchestrate directly. 782 | // Exec just needs to return the batchParamsList for Post. 783 | 784 | // The actual orchestration happens in InternalRun using this list. 785 | // This function's role is primarily semantic within the BaseNode interface call chain. 
786 | // It returns the data needed for Post. 787 | 788 | return batchParamsList, nil 789 | } 790 | 791 | // Post for BatchFlow runs its PostBatchFunc. 792 | func (bf *BatchFlow) Post(ctx *PfContext, prepResult any, execResult any) (string, error) { 793 | // prepResult is the result of BatchFlow.Prep ([]*pfContext) 794 | // execResult is the result of BatchFlow.Exec (which we defined as the same []*pfContext) 795 | 796 | batchPrepResult, okPrep := prepResult.([]map[string]any) 797 | if prepResult != nil && !okPrep { 798 | return "", newPocketFlowError(fmt.Sprintf("Internal error: prepResult in BatchFlow %T Post was not []*pfContext (%T)", bf, prepResult), nil) 799 | } 800 | if batchPrepResult == nil { 801 | batchPrepResult = []map[string]any{} 802 | } 803 | 804 | if bf.PostBatchFunc == nil { 805 | return DefaultAction, nil 806 | } 807 | 808 | action, err := bf.PostBatchFunc(ctx, bf.params, batchPrepResult) 809 | if err == nil && action == "" { 810 | action = DefaultAction 811 | } 812 | return action, err 813 | } 814 | 815 | // Run starts the BatchFlow execution. 816 | func (bf *BatchFlow) Run(ctx *PfContext) (string, error) { 817 | // Use InternalRun to perform the standard lifecycle (PrepBatch, Exec Batches, PostBatch) 818 | return bf.InternalRun(ctx) 819 | } 820 | 821 | // InternalRun executes the BatchFlow lifecycle: PrepBatch, Exec(orchestrate per batch), PostBatch. 822 | func (bf *BatchFlow) InternalRun(ctx *PfContext) (string, error) { 823 | // 1. 
Run BatchFlow's Prep phase (PrepBatchFunc) 824 | // Should return []*pfContext 825 | prepBatchResultAny, err := bf.Prep(ctx) 826 | if err != nil { 827 | return "", newPocketFlowError(fmt.Sprintf("PrepBatch phase failed for BatchFlow %T", bf), err) 828 | } 829 | 830 | batchParamsList, ok := prepBatchResultAny.([]map[string]any) 831 | if prepBatchResultAny != nil && !ok { 832 | return "", newPocketFlowError(fmt.Sprintf("Internal error: PrepBatch phase in BatchFlow %T did not return []*pfContext (%T)", bf, prepBatchResultAny), nil) 833 | } 834 | if batchParamsList == nil { 835 | batchParamsList = []map[string]any{} 836 | } 837 | 838 | // 2. Run the orchestration for each item in batchParamsList 839 | for i, batchParams := range batchParamsList { 840 | // Run the embedded Flow's orchestration logic for each parameter set. 841 | // Pass the *original* shared context and current batchParams. 842 | _, err := bf.Flow.orchestrate(ctx, batchParams) 843 | if err != nil { 844 | // If one batch run fails, fail the whole BatchFlow execution 845 | return "", newPocketFlowError(fmt.Sprintf("Orchestration failed for batch item %d in %T", i, bf), err) 846 | } 847 | // Result (lastAction) of individual orchestrate runs is ignored here; side effects matter. 848 | } 849 | 850 | // 3. Run BatchFlow's Post phase (PostBatchFunc) 851 | // The result of the "Exec" phase semantically is the list itself. 852 | execResult := batchParamsList 853 | finalAction, err := bf.Post(ctx, prepBatchResultAny, execResult) 854 | if err != nil { 855 | return "", newPocketFlowError(fmt.Sprintf("PostBatch phase failed for BatchFlow %T", bf), err) 856 | } 857 | 858 | return finalAction, nil 859 | } 860 | --------------------------------------------------------------------------------