├── .codecov.yml ├── .github ├── dependabot.yml └── workflows │ ├── doc.yml │ ├── go.yml │ ├── lint.yml │ ├── release.yml │ └── tag.yml ├── .golangci.yml ├── LICENSE ├── README.md ├── doc.go ├── examples ├── agents │ ├── main.go │ └── util.go ├── gsm8k │ └── main.go ├── hotpotqa │ ├── main.go │ └── monitor.go ├── others │ ├── html │ │ └── main.go │ ├── mcp │ │ └── main.go │ └── mipro │ │ └── main.go └── utils │ └── setup.go ├── go.mod ├── go.sum ├── internal └── testutil │ └── mocks.go └── pkg ├── agents ├── README.md ├── agent.go ├── agent_test.go ├── common.go ├── common_test.go ├── memory.go ├── memory │ ├── buffered_memory.go │ ├── buffered_memory_test.go │ ├── sqlite_memory.go │ └── sqlite_memory_test.go ├── memory_test.go ├── orchestrator.go ├── orchestrator_test.go └── workflows │ ├── chain.go │ ├── chain_test.go │ ├── errors.go │ ├── errors_test.go │ ├── parallel.go │ ├── parallel_test.go │ ├── router.go │ ├── router_test.go │ ├── step.go │ ├── step_test.go │ ├── workflow.go │ └── workflow_test.go ├── core ├── config.go ├── decorators.go ├── decorators_test.go ├── execution_context.go ├── execution_context_test.go ├── factory.go ├── llm.go ├── llm_test.go ├── module.go ├── module_test.go ├── optimizer.go ├── optimizer_test.go ├── program.go ├── program_test.go ├── signature.go ├── signature_test.go ├── state.go ├── state_test.go └── tool.go ├── datasets ├── dataset.go ├── dataset_test.go ├── gsm8k.go └── hotpot_qa.go ├── errors ├── errors.go └── errors_test.go ├── llms ├── anthrophic.go ├── anthrophic_test.go ├── factory.go ├── factory_test.go ├── gemini.go ├── gemini_test.go ├── llamacpp.go ├── llamacpp_test.go ├── ollama.go └── ollama_test.go ├── logging ├── log_entry.go ├── log_entry_test.go ├── logger.go ├── logger_test.go ├── outputs.go ├── outputs_test.go ├── severity.go └── severity_test.go ├── metrics ├── accuracy.go └── accuracy_test.go ├── modules ├── chain_of_thought.go ├── chain_of_thought_test.go ├── module.go ├── predict.go ├── predict_test.go 
├── react.go └── react_test.go ├── optimizers ├── bootstrap_fewshot.go ├── bootstrap_fewshot_test.go ├── copro.go ├── copro_test.go ├── mipro.go ├── mipro_test.go ├── tpe_optimizer.go └── tpe_optimizer_test.go ├── tools ├── func.go ├── func_test.go ├── mcp.go ├── mcp_test.go ├── registry.go ├── registry_test.go ├── tool.go ├── tool_test.go ├── util.go └── util_test.go └── utils ├── util.go └── util_test.go /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | # Define the overall project coverage status 6 | target: auto 7 | threshold: 1% 8 | paths: 9 | - "pkg/" 10 | patch: 11 | default: 12 | # Define the patch coverage status (for pull requests) 13 | target: auto 14 | threshold: 1% 15 | paths: 16 | - "pkg/" 17 | 18 | # Ignore coverage for files and pkg 19 | ignore: 20 | - "pkg/datasets/gsm8k.go" 21 | - "pkg/datasets/hotpot_qa.go" 22 | - "examples/**" 23 | - "cmd/**" 24 | - "internal/**" 25 | - "test/**" 26 | - "*.md" 27 | - "*.yml" 28 | - "*.yaml" 29 | - "Makefile" 30 | - "LICENSE" 31 | 32 | # Configure Codecov to only comment on pull requests if coverage decreases 33 | comment: 34 | require_changes: false 35 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | updates: 4 | - package-ecosystem: gomod 5 | directory: "/" 6 | schedule: 7 | interval: weekly 8 | ignore: 9 | - dependency-name: "*" 10 | update-types: 11 | - version-update:semver-minor 12 | - version-update:semver-major 13 | - package-ecosystem: "github-actions" 14 | directory: "/" 15 | schedule: 16 | interval: weekly 17 | -------------------------------------------------------------------------------- /.github/workflows/doc.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | on: 3 | push: 4 | 
branches: [main] 5 | 6 | permissions: 7 | contents: read 8 | pages: write 9 | id-token: write 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Go 17 | uses: actions/setup-go@v5 18 | with: 19 | go-version-file: "go.mod" 20 | - name: Check documentation syntax 21 | run: go doc . > /dev/null 22 | - name: Run golangci-lint 23 | uses: golangci/golangci-lint-action@v7 24 | with: 25 | args: --timeout=5m --enable=godot 26 | - name: Generate documentation 27 | run: | 28 | mkdir -p _site 29 | go doc -all . > _site/index.html 30 | - name: Upload artifact 31 | uses: actions/upload-pages-artifact@v3 32 | 33 | deploy: 34 | needs: build 35 | runs-on: ubuntu-latest 36 | environment: 37 | name: github-pages 38 | url: ${{ steps.deployment.outputs.page_url }} 39 | steps: 40 | - name: Deploy to GitHub Pages 41 | id: deployment 42 | uses: actions/deploy-pages@v4 43 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Unit and Integration Tests 2 | 3 | on: [push, pull_request] 4 | 5 | permissions: 6 | contents: read 7 | 8 | jobs: 9 | unit-tests: 10 | name: Run Unit Tests 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | os: [macos-latest] 15 | runs-on: ${{ matrix.os }} 16 | 17 | steps: 18 | - name: Check out code 19 | uses: actions/checkout@v4 20 | 21 | - name: Set up Go 22 | uses: actions/setup-go@v5 23 | with: 24 | go-version-file: "go.mod" 25 | 26 | - name: Download dependencies 27 | run: go mod download 28 | 29 | - name: Run unit tests 30 | run: go test -tags skip -race -v ./pkg/... 
-coverprofile ./coverage.txt 31 | 32 | - name: Upload coverage to Codecov 33 | uses: codecov/codecov-action@v5.4.2 34 | with: 35 | token: ${{ secrets.CODECOV_TOKEN }} 36 | files: ./coverage.txt 37 | flags: pkg 38 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: 3 | push: 4 | paths: 5 | - "**.go" 6 | - go.mod 7 | - go.sum 8 | pull_request: 9 | paths: 10 | - "**.go" 11 | - go.mod 12 | - go.sum 13 | 14 | env: 15 | GO111MODULE: on 16 | 17 | jobs: 18 | golangci-lint: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v4 22 | - uses: actions/setup-go@v5 23 | with: 24 | go-version-file: "go.mod" 25 | 26 | - uses: golangci/golangci-lint-action@v7 27 | with: 28 | version: latest 29 | args: --verbose --timeout=5m 30 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | push: 4 | tags: 5 | - "v[0-9]+.[0-9]+.[0-9]+" 6 | repository_dispatch: 7 | types: [trigger-release] 8 | permissions: 9 | contents: write 10 | packages: write 11 | jobs: 12 | goreleaser: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v4 17 | with: 18 | fetch-depth: 0 19 | - name: Set up Go 20 | uses: actions/setup-go@v5 21 | with: 22 | go-version-file: "go.mod" 23 | - name: Get the version 24 | id: get_version 25 | run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT 26 | if: github.event_name == 'push' 27 | - name: Set version from dispatch 28 | if: github.event_name == 'repository_dispatch' 29 | run: echo "VERSION=${{ github.event.client_payload.tag }}" >> $GITHUB_OUTPUT 30 | - name: Run GoReleaser 31 | uses: goreleaser/goreleaser-action@v6 32 | with: 33 | distribution: goreleaser 34 | version: latest 35 | 
args: release --clean 36 | env: 37 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 38 | -------------------------------------------------------------------------------- /.github/workflows/tag.yml: -------------------------------------------------------------------------------- 1 | name: Automated Semantic Versioning 2 | 3 | on: 4 | workflow_dispatch: # Allows manual trigger from GitHub UI 5 | schedule: 6 | - cron: "0 0 * * *" # Run daily at midnight UTC 7 | push: 8 | branches: 9 | - main # Adjust this as necessary 10 | paths: 11 | - "**.go" 12 | 13 | jobs: 14 | tag: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@v4 19 | with: 20 | fetch-depth: 0 # Fetch all history for all tags 21 | 22 | - name: Fetch latest tags 23 | run: git fetch --tags origin 24 | 25 | - name: Determine New Tag 26 | id: newtag 27 | run: | 28 | # Get the latest tag 29 | LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0") 30 | echo "Latest tag: $LATEST_TAG" 31 | 32 | # Extract major, minor, patch 33 | IFS='.' 
read -r MAJOR MINOR PATCH <<< "${LATEST_TAG#v}" 34 | 35 | # Check commit messages since last tag 36 | COMMITS=$(git log $LATEST_TAG..HEAD --pretty=format:"%s") 37 | if echo "$COMMITS" | grep -qE "^BREAKING CHANGE"; then 38 | NEW_MAJOR=$((MAJOR + 1)) 39 | NEW_MINOR=0 40 | NEW_PATCH=0 41 | elif echo "$COMMITS" | grep -qE "^feat:"; then 42 | NEW_MAJOR=$MAJOR 43 | NEW_MINOR=$((MINOR + 1)) 44 | NEW_PATCH=0 45 | elif echo "$COMMITS" | grep -qE "^fix:"; then 46 | NEW_MAJOR=$MAJOR 47 | NEW_MINOR=$MINOR 48 | NEW_PATCH=$((PATCH + 1)) 49 | else 50 | echo "No version bump needed" 51 | exit 0 52 | fi 53 | 54 | NEW_TAG="v$NEW_MAJOR.$NEW_MINOR.$NEW_PATCH" 55 | echo "New tag: $NEW_TAG" 56 | echo "NEW_TAG=$NEW_TAG" >> $GITHUB_ENV 57 | 58 | - name: Check if tag exists 59 | id: check_tag 60 | if: env.NEW_TAG != '' 61 | run: | 62 | if git rev-parse ${{ env.NEW_TAG }} >/dev/null 2>&1; then 63 | echo "Tag ${{ env.NEW_TAG }} already exists" 64 | echo "TAG_EXISTS=true" >> $GITHUB_ENV 65 | else 66 | echo "TAG_EXISTS=false" >> $GITHUB_ENV 67 | fi 68 | 69 | - name: Create and Push New Tag 70 | if: env.NEW_TAG != '' && env.TAG_EXISTS == 'false' 71 | run: | 72 | git config user.name github-actions 73 | git config user.email github-actions@github.com 74 | git tag -a ${{ env.NEW_TAG }} -m "Release ${{ env.NEW_TAG }}" 75 | git push origin ${{ env.NEW_TAG }} 76 | echo "TAG_CREATED=true" >> $GITHUB_ENV 77 | 78 | - name: Trigger Release Workflow 79 | if: env.TAG_CREATED == 'true' 80 | uses: peter-evans/repository-dispatch@v3 81 | with: 82 | token: ${{ secrets.GITHUB_TOKEN }} 83 | event-type: trigger-release 84 | client-payload: '{"tag": "${{ env.NEW_TAG }}"}' 85 | 86 | - name: Notify on Skipped Tag 87 | if: env.NEW_TAG == '' || env.TAG_EXISTS == 'true' 88 | run: | 89 | if [ -z "${{ env.NEW_TAG }}" ]; then 90 | echo "No new tag was created as no version bump was needed." 91 | else 92 | echo "Tag ${{ env.NEW_TAG }} already exists. Skipping tag creation." 
93 | fi 94 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | linters: 3 | default: none 4 | enable: 5 | - errcheck 6 | - godot 7 | - govet 8 | - ineffassign 9 | - staticcheck 10 | - unused 11 | exclusions: 12 | generated: lax 13 | presets: 14 | - comments 15 | - common-false-positives 16 | - legacy 17 | - std-error-handling 18 | paths: 19 | - third_party$ 20 | - builtin$ 21 | - examples$ 22 | formatters: 23 | exclusions: 24 | generated: lax 25 | paths: 26 | - third_party$ 27 | - builtin$ 28 | - examples$ 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2024] [Xiao Cui] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package dspy is a Go implementation of the DSPy framework for using language models 2 | // to solve complex tasks through composable steps and prompting techniques. 3 | // 4 | // DSPy-Go provides a collection of modules, optimizers, and tools for building 5 | // reliable LLM-powered applications. It focuses on making it easy to: 6 | // - Break down complex tasks into modular steps 7 | // - Optimize prompts and chain-of-thought reasoning 8 | // - Build flexible agent-based systems 9 | // - Handle common LLM interaction patterns 10 | // - Evaluate and improve system performance 11 | // 12 | // Key Components: 13 | // 14 | // - Core: Fundamental abstractions like Module, Signature, LLM and Program 15 | // for defining and executing LLM-based workflows. 16 | // 17 | // - Modules: Building blocks for composing LLM workflows: 18 | // * Predict: Basic prediction module for simple LLM interactions 19 | // * ChainOfThought: Implements step-by-step reasoning with rationale tracking 20 | // * ReAct: Implements Reasoning and Acting with tool integration 21 | // 22 | // - Optimizers: Tools for improving prompt effectiveness: 23 | // * BootstrapFewShot: Automatically selects high-quality examples for few-shot learning 24 | // * MIPRO: Multi-step interactive prompt optimization 25 | // * Copro: Collaborative prompt optimization 26 | // 27 | // - Agents: Advanced patterns for building sophisticated AI systems: 28 | // * Memory: Different memory implementations for tracking conversation history 29 | // * Tools: Integration with external tools and APIs 30 | // * Workflows: 31 | // - Chain: Sequential execution of steps 32 | // - Parallel: Concurrent execution with controlled parallelism 33 | // - Router: Dynamic routing based on classification 34 | // * Orchestrator: Flexible task decomposition and 
execution 35 | // 36 | // - Integration with multiple LLM providers: 37 | // * Anthropic Claude 38 | // * Google Gemini 39 | // * Ollama 40 | // * LlamaCPP 41 | // 42 | // Simple Example: 43 | // 44 | // import ( 45 | // "context" 46 | // "fmt" 47 | // "log" 48 | // 49 | // "github.com/XiaoConstantine/dspy-go/pkg/core" 50 | // "github.com/XiaoConstantine/dspy-go/pkg/llms" 51 | // "github.com/XiaoConstantine/dspy-go/pkg/modules" 52 | // ) 53 | // 54 | // func main() { 55 | // // Configure the default LLM 56 | // llms.EnsureFactory() 57 | // err := core.ConfigureDefaultLLM("your-api-key", core.ModelAnthropicSonnet) 58 | // if err != nil { 59 | // log.Fatalf("Failed to configure LLM: %v", err) 60 | // } 61 | // 62 | // // Create a signature for question answering 63 | // signature := core.NewSignature( 64 | // []core.InputField{{Field: core.Field{Name: "question"}}}, 65 | // []core.OutputField{{Field: core.Field{Name: "answer"}}}, 66 | // ) 67 | // 68 | // // Create a ChainOfThought module 69 | // cot := modules.NewChainOfThought(signature) 70 | // 71 | // // Create a program 72 | // program := core.NewProgram( 73 | // map[string]core.Module{"cot": cot}, 74 | // func(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) { 75 | // return cot.Process(ctx, inputs) 76 | // }, 77 | // ) 78 | // 79 | // // Execute the program 80 | // result, err := program.Execute(context.Background(), map[string]interface{}{ 81 | // "question": "What is the capital of France?", 82 | // }) 83 | // if err != nil { 84 | // log.Fatalf("Error executing program: %v", err) 85 | // } 86 | // 87 | // fmt.Printf("Answer: %s\n", result["answer"]) 88 | // } 89 | // 90 | // Advanced Features: 91 | // 92 | // - Tracing and Logging: Detailed tracing and structured logging for debugging and optimization 93 | // Execution context is tracked and passed through the pipeline for debugging and analysis.
94 | // 95 | // - Error Handling: Comprehensive error management with custom error types and centralized handling 96 | // 97 | // - Metric-Based Optimization: Improve module performance based on custom evaluation metrics 98 | // 99 | // - Custom Tool Integration: Extend ReAct modules with domain-specific tools 100 | // 101 | // - Workflow Retry Logic: Resilient execution with configurable retry mechanisms and backoff strategies 102 | // 103 | // - Streaming Support: Process LLM outputs incrementally as they're generated 104 | // 105 | // - Data Storage: Integration with various storage backends for persistence of examples and results 106 | // 107 | // - Arrow Support: Integration with Apache Arrow for efficient data handling and processing 108 | // 109 | // Working with Workflows: 110 | // 111 | // // Chain workflow example 112 | // workflow := workflows.NewChainWorkflow(store) 113 | // workflow.AddStep(&workflows.Step{ 114 | // ID: "step1", 115 | // Module: modules.NewPredict(signature1), 116 | // }) 117 | // workflow.AddStep(&workflows.Step{ 118 | // ID: "step2", 119 | // Module: modules.NewPredict(signature2), 120 | // // Configurable retry logic 121 | // RetryConfig: &workflows.RetryConfig{ 122 | // MaxAttempts: 3, 123 | // BackoffMultiplier: 2.0, 124 | // }, 125 | // // Conditional execution 126 | // Condition: func(state map[string]interface{}) bool { 127 | // return someCondition(state) 128 | // }, 129 | // }) 130 | // 131 | // For more examples and detailed documentation, visit: 132 | // https://github.com/XiaoConstantine/dspy-go 133 | // 134 | // DSPy-Go is released under the MIT License. 
135 | package dspy 136 | -------------------------------------------------------------------------------- /examples/gsm8k/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "log" 7 | 8 | "github.com/XiaoConstantine/dspy-go/pkg/core" 9 | "github.com/XiaoConstantine/dspy-go/pkg/datasets" 10 | "github.com/XiaoConstantine/dspy-go/pkg/llms" 11 | "github.com/XiaoConstantine/dspy-go/pkg/logging" 12 | "github.com/XiaoConstantine/dspy-go/pkg/modules" 13 | "github.com/XiaoConstantine/dspy-go/pkg/optimizers" 14 | ) 15 | 16 | func RunGSM8KExample(apiKey string) { 17 | output := logging.NewConsoleOutput(true, logging.WithColor(true)) 18 | 19 | logger := logging.NewLogger(logging.Config{ 20 | Severity: logging.INFO, 21 | Outputs: []logging.Output{output}, 22 | }) 23 | logging.SetLogger(logger) 24 | 25 | ctx := core.WithExecutionState(context.Background()) 26 | // Setup LLM 27 | llms.EnsureFactory() 28 | err := core.ConfigureDefaultLLM(apiKey, core.ModelGoogleGeminiFlash) 29 | if err != nil { 30 | logger.Fatalf(ctx, "Failed to setup llm") 31 | } 32 | 33 | // Load GSM8K dataset 34 | examples, err := datasets.LoadGSM8K() 35 | if err != nil { 36 | log.Fatalf("Failed to load GSM8K dataset: %v", err) 37 | } 38 | 39 | // Create signature for ChainOfThought 40 | signature := core.NewSignature( 41 | []core.InputField{{Field: core.Field{Name: "question"}}}, 42 | []core.OutputField{{Field: core.NewField("answer")}}, 43 | ) 44 | 45 | // Create ChainOfThought module 46 | cot := modules.NewChainOfThought(signature) 47 | 48 | // Create program 49 | program := core.NewProgram(map[string]core.Module{"cot": cot}, func(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) { 50 | return cot.Process(ctx, inputs, core.WithGenerateOptions( 51 | core.WithTemperature(0.7), 52 | core.WithMaxTokens(8192), 53 | )) 54 | }) 55 | 56 | // Create optimizer 57 | optimizer := 
optimizers.NewBootstrapFewShot(func(example, prediction map[string]interface{}, ctx context.Context) bool { 58 | return example["answer"] == prediction["answer"] 59 | }, 5) 60 | 61 | // Prepare training set 62 | trainset := make([]map[string]interface{}, len(examples[:10])) 63 | for i, ex := range examples[:10] { 64 | trainset[i] = map[string]interface{}{ 65 | "question": ex.Question, 66 | "answer": ex.Answer, 67 | } 68 | } 69 | 70 | // Compile the program 71 | compiledProgram, err := optimizer.Compile(ctx, program, program, trainset) 72 | if err != nil { 73 | logger.Fatalf(ctx, "Failed to compile program: %v", err) 74 | } 75 | 76 | // Test the compiled program 77 | for _, ex := range examples[10:15] { 78 | result, err := compiledProgram.Execute(ctx, map[string]interface{}{"question": ex.Question}) 79 | if err != nil { 80 | log.Printf("Error executing program: %v", err) 81 | continue 82 | } 83 | 84 | logger.Info(ctx, "Question: %s\n", ex.Question) 85 | logger.Info(ctx, "Predicted Answer: %s\n", result["answer"]) 86 | logger.Info(ctx, "Actual Answer: %s\n\n", ex.Answer) 87 | } 88 | } 89 | 90 | func main() { 91 | apiKey := flag.String("api-key", "", "Anthropic API Key") 92 | flag.Parse() 93 | 94 | RunGSM8KExample(*apiKey) 95 | } 96 | -------------------------------------------------------------------------------- /examples/hotpotqa/monitor.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "runtime" 6 | "sync/atomic" 7 | "time" 8 | 9 | "github.com/XiaoConstantine/dspy-go/pkg/logging" 10 | ) 11 | 12 | // GoroutineStats tracks goroutine statistics. 13 | type GoroutineStats struct { 14 | Current int32 15 | Peak int32 16 | Started int64 17 | Completed int64 18 | } 19 | 20 | // GoroutineMonitor tracks goroutine usage. 
21 | type GoroutineMonitor struct { 22 | stats GoroutineStats 23 | logger *logging.Logger 24 | ctx context.Context 25 | cancel context.CancelFunc 26 | interval time.Duration 27 | } 28 | 29 | // NewGoroutineMonitor creates a new monitor with the given logging interval. 30 | func NewGoroutineMonitor(interval time.Duration, ctx context.Context) *GoroutineMonitor { 31 | ctx, cancel := context.WithCancel(ctx) 32 | return &GoroutineMonitor{ 33 | stats: GoroutineStats{}, 34 | logger: logging.GetLogger(), 35 | ctx: ctx, 36 | cancel: cancel, 37 | interval: interval, 38 | } 39 | } 40 | 41 | // Start begins monitoring goroutine usage. 42 | func (m *GoroutineMonitor) Start() { 43 | go func() { 44 | ticker := time.NewTicker(m.interval) 45 | defer ticker.Stop() 46 | 47 | for { 48 | select { 49 | case <-m.ctx.Done(): 50 | return 51 | case <-ticker.C: 52 | current := runtime.NumGoroutine() 53 | 54 | // Update peak if current count is higher 55 | for { 56 | peak := atomic.LoadInt32(&m.stats.Peak) 57 | if int32(current) <= peak { 58 | break 59 | } 60 | if atomic.CompareAndSwapInt32(&m.stats.Peak, peak, int32(current)) { 61 | break 62 | } 63 | } 64 | 65 | m.logger.Info(m.ctx, "Goroutine Stats - Current: %d, Peak: %d, Started: %d, Completed: %d", 66 | current, 67 | atomic.LoadInt32(&m.stats.Peak), 68 | atomic.LoadInt64(&m.stats.Started), 69 | atomic.LoadInt64(&m.stats.Completed)) 70 | } 71 | } 72 | }() 73 | } 74 | 75 | // Stop terminates the monitoring. 76 | func (m *GoroutineMonitor) Stop() { 77 | m.cancel() 78 | } 79 | 80 | // TrackGoroutine increments the Started counter when a goroutine starts. 81 | func (m *GoroutineMonitor) TrackGoroutine() { 82 | atomic.AddInt64(&m.stats.Started, 1) 83 | } 84 | 85 | // ReleaseGoroutine increments the Completed counter when a goroutine completes. 86 | func (m *GoroutineMonitor) ReleaseGoroutine() { 87 | atomic.AddInt64(&m.stats.Completed, 1) 88 | } 89 | 90 | // GetStats returns a copy of the current statistics.
91 | func (m *GoroutineMonitor) GetStats() GoroutineStats { 92 | return GoroutineStats{ 93 | Current: int32(runtime.NumGoroutine()), 94 | Peak: atomic.LoadInt32(&m.stats.Peak), 95 | Started: atomic.LoadInt64(&m.stats.Started), 96 | Completed: atomic.LoadInt64(&m.stats.Completed), 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /examples/others/mcp/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "os/exec" 8 | "time" 9 | 10 | "github.com/XiaoConstantine/dspy-go/pkg/core" 11 | "github.com/XiaoConstantine/dspy-go/pkg/llms" 12 | dspyLogging "github.com/XiaoConstantine/dspy-go/pkg/logging" 13 | "github.com/XiaoConstantine/dspy-go/pkg/modules" 14 | "github.com/XiaoConstantine/dspy-go/pkg/tools" 15 | mcpLogging "github.com/XiaoConstantine/mcp-go/pkg/logging" 16 | ) 17 | 18 | // LoggerAdapter adapts dspy-go logger to mcp-go logger interface. 19 | type LoggerAdapter struct { 20 | dspyLogger *dspyLogging.Logger 21 | ctx context.Context 22 | } 23 | 24 | func NewLoggerAdapter(dspyLogger *dspyLogging.Logger) mcpLogging.Logger { 25 | return &LoggerAdapter{ 26 | dspyLogger: dspyLogger, 27 | ctx: context.Background(), 28 | } 29 | } 30 | 31 | // Debug implements mcp-go/pkg/logging.Logger interface. 32 | func (a *LoggerAdapter) Debug(msg string, args ...interface{}) { 33 | a.dspyLogger.Debug(a.ctx, msg, args...) 34 | } 35 | 36 | // Info implements mcp-go/pkg/logging.Logger interface. 37 | func (a *LoggerAdapter) Info(msg string, args ...interface{}) { 38 | a.dspyLogger.Info(a.ctx, msg, args...) 39 | } 40 | 41 | // Warn implements mcp-go/pkg/logging.Logger interface. 42 | func (a *LoggerAdapter) Warn(msg string, args ...interface{}) { 43 | a.dspyLogger.Warn(a.ctx, msg, args...) 44 | } 45 | 46 | // Error implements mcp-go/pkg/logging.Logger interface. 
47 | func (a *LoggerAdapter) Error(msg string, args ...interface{}) { 48 | a.dspyLogger.Error(a.ctx, msg, args...) 49 | } 50 | 51 | func main() { 52 | // Setup logging 53 | ctx := core.WithExecutionState(context.Background()) 54 | output := dspyLogging.NewConsoleOutput(true, dspyLogging.WithColor(true)) 55 | 56 | logger := dspyLogging.NewLogger(dspyLogging.Config{ 57 | Severity: dspyLogging.INFO, 58 | Outputs: []dspyLogging.Output{output}, 59 | }) 60 | dspyLogging.SetLogger(logger) 61 | 62 | loggerAdapter := NewLoggerAdapter(logger) 63 | // 1. Start MCP server as a subprocess (e.g., Git MCP server) 64 | cmd := exec.Command("./git-mcp-server") 65 | 66 | // Set up stdio for communication 67 | serverIn, err := cmd.StdinPipe() 68 | if err != nil { 69 | logger.Fatal(ctx, fmt.Sprintf("Failed to create stdin pipe: %v", err)) 70 | } 71 | 72 | serverOut, err := cmd.StdoutPipe() 73 | if err != nil { 74 | logger.Fatal(ctx, fmt.Sprintf("Failed to create stdout pipe: %v", err)) 75 | } 76 | 77 | cmd.Stderr = os.Stderr 78 | 79 | // Start the server 80 | if err := cmd.Start(); err != nil { 81 | logger.Fatal(ctx, fmt.Sprintf("Failed to start server: %v", err)) 82 | } 83 | 84 | // Give the server a moment to initialize 85 | time.Sleep(1 * time.Second) 86 | 87 | // 2. Create MCP client 88 | mcpClient, err := tools.NewMCPClientFromStdio( 89 | serverOut, 90 | serverIn, 91 | tools.MCPClientOptions{ 92 | ClientName: "react-example", 93 | ClientVersion: "0.1.0", 94 | Logger: loggerAdapter, 95 | }, 96 | ) 97 | if err != nil { 98 | logger.Fatal(ctx, fmt.Sprintf("Failed to create MCP client: %v", err)) 99 | } 100 | 101 | // 3. Create tool registry and register MCP tools 102 | registry := tools.NewInMemoryToolRegistry() 103 | err = tools.RegisterMCPTools(registry, mcpClient) 104 | if err != nil { 105 | logger.Fatal(ctx, fmt.Sprintf("Failed to register MCP tools: %v", err)) 106 | } 107 | 108 | // 4. 
Create and configure ReAct module 109 | signature := core.NewSignature( 110 | []core.InputField{{Field: core.Field{Name: "query"}}}, 111 | []core.OutputField{ 112 | {Field: core.NewField("answer")}, 113 | }, 114 | ).WithInstruction(`Answer the query about Git repositories by using the available Git tools. 115 | Identify which Git tool is appropriate for the question and use it with the correct arguments. 116 | Always use the git_ prefixed tools like git_blame, git_log, git_status, etc. 117 | `) 118 | 119 | maxIters := 5 120 | reactModule := modules.NewReAct(signature, registry, maxIters) 121 | llms.EnsureFactory() 122 | 123 | // 5. Set up LLM 124 | // This assumes you've set the API_KEY environment variable 125 | // Configure the default LLM 126 | err = core.ConfigureDefaultLLM("", core.ModelGoogleGeminiFlash) 127 | if err != nil { 128 | logger.Fatal(ctx, fmt.Sprintf("Failed to configure LLM: %v", err)) 129 | } 130 | 131 | // Set the LLM for the ReAct module 132 | reactModule.SetLLM(core.GetDefaultLLM()) 133 | 134 | // 6. Execute query with ReAct 135 | result, err := reactModule.Process(ctx, map[string]interface{}{ 136 | "query": "Show me the details of latest 5 commit", 137 | }) 138 | 139 | if err != nil { 140 | logger.Error(ctx, "Error executing ReAct: %v", err) 141 | return 142 | } 143 | 144 | // 7. Print the result 145 | logger.Info(ctx, "\nReAct Result: %v", result) 146 | logger.Info(ctx, "Thought: %v\n", result["thought"]) 147 | logger.Info(ctx, "Action: %v\n", result["action"]) 148 | logger.Info(ctx, "Observation: %v\n", result["observation"]) 149 | logger.Info(ctx, "Answer: %v\n", result["answer"]) 150 | 151 | // 8. 
Clean up 152 | logger.Info(ctx, "\nShutting down...") 153 | if err := cmd.Process.Signal(os.Interrupt); err != nil { 154 | logger.Error(ctx, "Failed to send interrupt signal: %v", err) 155 | if err := cmd.Process.Kill(); err != nil { 156 | logger.Error(ctx, "Failed to kill process: %v", err) 157 | } 158 | } 159 | 160 | if err := cmd.Wait(); err != nil { 161 | logger.Error(ctx, "Server exited with error: %v", err) 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /examples/utils/setup.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/XiaoConstantine/dspy-go/pkg/core" 7 | ) 8 | 9 | func SetupLLM(apiKey string, modelID core.ModelID) { 10 | err := core.ConfigureDefaultLLM(apiKey, modelID) 11 | if err != nil { 12 | log.Fatalf("Failed to configure default LLM: %v", err) 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/XiaoConstantine/dspy-go 2 | 3 | go 1.24.1 4 | 5 | require ( 6 | github.com/XiaoConstantine/anthropic-go v0.0.8 7 | github.com/XiaoConstantine/mcp-go v0.2.1 8 | github.com/apache/arrow/go/v13 v13.0.0 9 | github.com/mattn/go-sqlite3 v1.14.28 10 | github.com/sourcegraph/conc v0.3.0 11 | github.com/stretchr/testify v1.10.0 12 | ) 13 | 14 | replace google.golang.org/genproto => google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 15 | 16 | require ( 17 | github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c // indirect 18 | github.com/andybalholm/brotli v1.0.4 // indirect 19 | github.com/apache/thrift v0.16.0 // indirect 20 | github.com/davecgh/go-spew v1.1.1 // indirect 21 | github.com/goccy/go-json v0.10.2 // indirect 22 | github.com/golang/snappy v0.0.4 // indirect 23 | github.com/google/flatbuffers v24.3.25+incompatible // 
indirect 24 | github.com/klauspost/asmfmt v1.3.2 // indirect 25 | github.com/klauspost/compress v1.15.15 // indirect 26 | github.com/klauspost/cpuid/v2 v2.2.7 // indirect 27 | github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 // indirect 28 | github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect 29 | github.com/pierrec/lz4/v4 v4.1.17 // indirect 30 | github.com/pmezard/go-difflib v1.0.0 // indirect 31 | github.com/stretchr/objx v0.5.2 // indirect 32 | github.com/zeebo/xxh3 v1.0.2 // indirect 33 | go.uber.org/atomic v1.7.0 // indirect 34 | go.uber.org/multierr v1.9.0 // indirect 35 | golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect 36 | golang.org/x/mod v0.17.0 // indirect 37 | golang.org/x/net v0.36.0 // indirect 38 | golang.org/x/sync v0.11.0 // indirect 39 | golang.org/x/sys v0.30.0 // indirect 40 | golang.org/x/text v0.22.0 // indirect 41 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect 42 | golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect 43 | gonum.org/v1/gonum v0.15.0 // indirect 44 | google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 // indirect 45 | google.golang.org/grpc v1.71.0 // indirect 46 | google.golang.org/protobuf v1.36.5 // indirect 47 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 48 | gopkg.in/yaml.v3 v3.0.1 // indirect 49 | ) 50 | -------------------------------------------------------------------------------- /pkg/agents/README.md: -------------------------------------------------------------------------------- 1 | 2 | How Orchestrator works high level 3 | --------------------------------- 4 | 5 | ```mermaid 6 | sequenceDiagram 7 | participant Client 8 | participant Orchestrator 9 | participant Analyzer 10 | participant TaskParser 11 | participant PlanCreator 12 | participant Processor 13 | 14 | Client->>Orchestrator: Process(task, context) 15 | activate Orchestrator 16 | 17 | Note over Orchestrator,Analyzer: 
Phase 1: Task Analysis 18 | Orchestrator->>Analyzer: Analyze task breakdown 19 | activate Analyzer 20 | Analyzer-->>Orchestrator: Raw analysis output (XML format) 21 | deactivate Analyzer 22 | 23 | Note over Orchestrator,TaskParser: Phase 2: Task Parsing 24 | Orchestrator->>TaskParser: Parse(analyzerOutput) 25 | activate TaskParser 26 | TaskParser-->>Orchestrator: Structured Task objects 27 | deactivate TaskParser 28 | 29 | Note over Orchestrator,PlanCreator: Phase 3: Plan Creation 30 | Orchestrator->>PlanCreator: CreatePlan(tasks) 31 | activate PlanCreator 32 | PlanCreator-->>Orchestrator: Execution phases 33 | deactivate PlanCreator 34 | 35 | Note over Orchestrator,Processor: Phase 4: Execution 36 | loop For each phase 37 | loop For each task in phase (parallel) 38 | Orchestrator->>Processor: Process(task, context) 39 | activate Processor 40 | Processor-->>Orchestrator: Task result 41 | deactivate Processor 42 | end 43 | end 44 | 45 | Orchestrator-->>Client: OrchestratorResult 46 | deactivate Orchestrator 47 | ``` 48 | 49 | Task dependency resolution 50 | -------------------------- 51 | 52 | ```mermaid 53 | 54 | graph TD 55 | subgraph Task Structure 56 | A[Task] --> B[ID] 57 | A --> C[Type] 58 | A --> D[ProcessorType] 59 | A --> E[Dependencies] 60 | A --> F[Priority] 61 | A --> G[Metadata] 62 | end 63 | 64 | subgraph Plan Creation 65 | H[Input Tasks] --> I[Build Dependency Graph] 66 | I --> J[Detect Cycles] 67 | J --> K[Create Phases] 68 | K --> L[Sort by Priority] 69 | L --> M[Apply Max Concurrent] 70 | end 71 | 72 | subgraph Execution 73 | N[Phase Execution] --> O[Parallel Task Pool] 74 | O --> P[Process Task 1] 75 | O --> Q[Process Task 2] 76 | O --> R[Process Task N] 77 | P --> S[Collect Results] 78 | Q --> S 79 | R --> S 80 | end 81 | ``` 82 | 83 | 84 | Error handling and retry flow explained 85 | --------------------------------------- 86 | 87 | ```mermaid 88 | stateDiagram-v2 89 | [*] --> TaskReceived 90 | TaskReceived --> Analyzing 91 | 92 | state 
Analyzing { 93 | [*] --> AttemptAnalysis 94 | AttemptAnalysis --> AnalysisSuccess 95 | AttemptAnalysis --> AnalysisFailure 96 | AnalysisFailure --> RetryAnalysis: Retry < MaxAttempts 97 | RetryAnalysis --> AttemptAnalysis 98 | AnalysisFailure --> AnalysisFailed: Retry >= MaxAttempts 99 | } 100 | 101 | state Execution { 102 | [*] --> ExecuteTask 103 | ExecuteTask --> TaskSuccess 104 | ExecuteTask --> TaskFailure 105 | TaskFailure --> RetryTask: Retry < MaxAttempts 106 | RetryTask --> ExecuteTask 107 | TaskFailure --> TaskFailed: Retry >= MaxAttempts 108 | } 109 | 110 | Analyzing --> Execution: Analysis Success 111 | Analyzing --> [*]: Analysis Failed 112 | Execution --> [*]: All Tasks Complete/Failed 113 | ``` 114 | -------------------------------------------------------------------------------- /pkg/agents/agent.go: -------------------------------------------------------------------------------- 1 | package agents 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/XiaoConstantine/dspy-go/pkg/core" 7 | ) 8 | 9 | type Agent interface { 10 | // Execute runs the agent's task with given input and returns output 11 | Execute(ctx context.Context, input map[string]interface{}) (map[string]interface{}, error) 12 | 13 | // GetCapabilities returns the tools/capabilities available to this agent 14 | GetCapabilities() []core.Tool 15 | 16 | // GetMemory returns the agent's memory store 17 | GetMemory() Memory 18 | } 19 | -------------------------------------------------------------------------------- /pkg/agents/agent_test.go: -------------------------------------------------------------------------------- 1 | package agents 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | testutil "github.com/XiaoConstantine/dspy-go/internal/testutil" 8 | "github.com/XiaoConstantine/dspy-go/pkg/core" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/mock" 11 | ) 12 | 13 | // MockTool implements Tool interface for testing. 
14 | type MockTool struct { 15 | mock.Mock 16 | } 17 | 18 | func (m *MockTool) Name() string { 19 | args := m.Called() 20 | return args.String(0) 21 | } 22 | 23 | func (m *MockTool) Description() string { 24 | args := m.Called() 25 | return args.String(0) 26 | } 27 | 28 | func (m *MockTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) { 29 | args := m.Called(ctx, params) 30 | return args.Get(0), args.Error(1) 31 | } 32 | 33 | func (m *MockTool) ValidateParams(params map[string]interface{}) error { 34 | args := m.Called(params) 35 | return args.Error(0) 36 | } 37 | 38 | // MockAgent implements Agent interface for testing. 39 | type MockAgent struct { 40 | mock.Mock 41 | } 42 | 43 | func (m *MockAgent) Execute(ctx context.Context, input map[string]interface{}) (map[string]interface{}, error) { 44 | args := m.Called(ctx, input) 45 | return args.Get(0).(map[string]interface{}), args.Error(1) 46 | } 47 | 48 | func (m *MockAgent) GetCapabilities() []core.Tool { 49 | args := m.Called() 50 | return args.Get(0).([]core.Tool) 51 | } 52 | 53 | func (m *MockAgent) GetMemory() Memory { 54 | args := m.Called() 55 | return args.Get(0).(Memory) 56 | } 57 | 58 | func TestAgentInterface(t *testing.T) { 59 | t.Run("MockAgent Implementation", func(t *testing.T) { 60 | mockAgent := new(MockAgent) 61 | mockTool := testutil.NewMockTool("test_tool") 62 | mockMemory := NewInMemoryStore() 63 | 64 | // Setup expectations 65 | expectedOutput := map[string]interface{}{ 66 | "result": "success", 67 | } 68 | mockAgent.On("Execute", mock.Anything, mock.Anything).Return(expectedOutput, nil) 69 | mockAgent.On("GetCapabilities").Return([]core.Tool{mockTool}) 70 | mockAgent.On("GetMemory").Return(mockMemory) 71 | 72 | // Test Execute 73 | ctx := context.Background() 74 | input := map[string]interface{}{ 75 | "test": "input", 76 | } 77 | output, err := mockAgent.Execute(ctx, input) 78 | assert.NoError(t, err) 79 | assert.Equal(t, expectedOutput, output) 80 | 81 | // 
Test GetCapabilities 82 | capabilities := mockAgent.GetCapabilities() 83 | assert.Len(t, capabilities, 1) 84 | assert.Equal(t, mockTool, capabilities[0]) 85 | 86 | // Test GetMemory 87 | memory := mockAgent.GetMemory() 88 | assert.Equal(t, mockMemory, memory) 89 | 90 | // Verify all expectations were met 91 | mockAgent.AssertExpectations(t) 92 | }) 93 | 94 | t.Run("MockTool Implementation", func(t *testing.T) { 95 | mockTool := new(MockTool) 96 | 97 | // Setup expectations 98 | mockTool.On("Name").Return("TestTool") 99 | mockTool.On("Description").Return("Test tool description") 100 | mockTool.On("Execute", mock.Anything, mock.Anything).Return("result", nil) 101 | mockTool.On("ValidateParams", mock.Anything).Return(nil) 102 | 103 | // Test Name 104 | name := mockTool.Name() 105 | assert.Equal(t, "TestTool", name) 106 | 107 | // Test Description 108 | desc := mockTool.Description() 109 | assert.Equal(t, "Test tool description", desc) 110 | 111 | // Test Execute 112 | ctx := context.Background() 113 | params := map[string]interface{}{ 114 | "param": "value", 115 | } 116 | result, err := mockTool.Execute(ctx, params) 117 | assert.NoError(t, err) 118 | assert.Equal(t, "result", result) 119 | 120 | // Test ValidateParams 121 | err = mockTool.ValidateParams(params) 122 | assert.NoError(t, err) 123 | 124 | // Verify all expectations were met 125 | mockTool.AssertExpectations(t) 126 | }) 127 | } 128 | -------------------------------------------------------------------------------- /pkg/agents/memory.go: -------------------------------------------------------------------------------- 1 | package agents 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | "github.com/XiaoConstantine/dspy-go/pkg/errors" 8 | ) 9 | 10 | // Memory provides storage capabilities for agents. 
11 | type Memory interface { 12 | // Store saves a value with a given key 13 | Store(key string, value interface{}) error 14 | 15 | // Retrieve gets a value by key 16 | Retrieve(key string) (interface{}, error) 17 | 18 | // List returns all stored keys 19 | List() ([]string, error) 20 | 21 | // Clear removes all stored values 22 | Clear() error 23 | } 24 | 25 | // Simple in-memory implementation. 26 | type InMemoryStore struct { 27 | data map[string]interface{} 28 | mu sync.RWMutex 29 | } 30 | 31 | func NewInMemoryStore() *InMemoryStore { 32 | return &InMemoryStore{ 33 | data: make(map[string]interface{}), 34 | } 35 | } 36 | 37 | func (s *InMemoryStore) Store(key string, value interface{}) error { 38 | s.mu.Lock() 39 | defer s.mu.Unlock() 40 | s.data[key] = value 41 | return nil 42 | } 43 | 44 | func (s *InMemoryStore) Retrieve(key string) (interface{}, error) { 45 | s.mu.RLock() 46 | defer s.mu.RUnlock() 47 | 48 | value, exists := s.data[key] 49 | if !exists { 50 | return nil, errors.WithFields( 51 | errors.New(errors.ResourceNotFound, "key not found in memory store"), 52 | errors.Fields{ 53 | "key": key, 54 | "access_time": time.Now().UTC(), 55 | }) 56 | } 57 | return value, nil 58 | } 59 | 60 | func (s *InMemoryStore) List() ([]string, error) { 61 | s.mu.RLock() 62 | defer s.mu.RUnlock() 63 | 64 | keys := make([]string, 0, len(s.data)) 65 | for k := range s.data { 66 | keys = append(keys, k) 67 | } 68 | 69 | return keys, nil 70 | } 71 | 72 | func (s *InMemoryStore) Clear() error { 73 | s.mu.Lock() 74 | defer s.mu.Unlock() 75 | 76 | // Create a new map rather than ranging and deleting 77 | // This is more efficient for clearing everything 78 | s.data = make(map[string]interface{}) 79 | return nil 80 | } 81 | -------------------------------------------------------------------------------- /pkg/agents/memory/buffered_memory.go: -------------------------------------------------------------------------------- 1 | package memory 2 | 3 | import ( 4 | "context" 5 | 
"encoding/json" 6 | goerrors "errors" 7 | "sync" 8 | 9 | "github.com/XiaoConstantine/dspy-go/pkg/agents" 10 | "github.com/XiaoConstantine/dspy-go/pkg/errors" 11 | ) 12 | 13 | // defaultHistoryKey is the default key used to store history in the underlying store. 14 | const defaultHistoryKey = "conversation_log" 15 | 16 | // Message represents a single entry in the conversation history. 17 | type Message struct { 18 | Role string `json:"role"` 19 | Content string `json:"content"` 20 | } 21 | 22 | // BufferedMemory provides an in-memory store that keeps only the last N messages. 23 | // It wraps an underlying agents.Memory implementation. 24 | type BufferedMemory struct { 25 | store agents.Memory 26 | maxSize int 27 | historyKey string // Keep internal field, but set to default 28 | mu sync.RWMutex 29 | } 30 | 31 | // NewBufferedMemory creates a new BufferedMemory instance with a default history key. 32 | // It initializes an InMemoryStore as the underlying storage. 33 | func NewBufferedMemory(maxSize int) *BufferedMemory { 34 | if maxSize <= 0 { 35 | maxSize = 1 // Ensure maxSize is at least 1 36 | } 37 | // Use the existing NewInMemoryStore for the underlying storage 38 | underlyingStore := agents.NewInMemoryStore() 39 | return &BufferedMemory{ 40 | store: underlyingStore, 41 | maxSize: maxSize, 42 | historyKey: defaultHistoryKey, // Use the default key here 43 | } 44 | } 45 | 46 | // Add appends a new message to the history, ensuring the buffer size limit is maintained. 
47 | func (m *BufferedMemory) Add(ctx context.Context, role string, content string) error { 48 | m.mu.Lock() 49 | defer m.mu.Unlock() 50 | 51 | history, err := m.getHistoryInternal(ctx) 52 | if err != nil { 53 | // Check if the error is specifically ResourceNotFound using errors.As 54 | var dspyErr *errors.Error 55 | if goerrors.As(err, &dspyErr) && dspyErr.Code() == errors.ResourceNotFound { 56 | // It's a ResourceNotFound error, initialize history 57 | history = make([]Message, 0) 58 | } else { 59 | // It's some other error, wrap and return it 60 | return errors.Wrap(err, errors.Unknown, "failed to retrieve history for adding") 61 | } 62 | } 63 | 64 | // Append the new message 65 | history = append(history, Message{Role: role, Content: content}) 66 | 67 | // Enforce maxSize limit 68 | if len(history) > m.maxSize { 69 | startIndex := len(history) - m.maxSize 70 | history = history[startIndex:] 71 | } 72 | 73 | // Marshal and store the updated history 74 | historyBytes, err := json.Marshal(history) 75 | if err != nil { 76 | // Use Unknown or a more appropriate general code if SerializationFailed doesn't exist 77 | return errors.Wrap(err, errors.Unknown, "failed to marshal history") // Changed to Unknown 78 | } 79 | 80 | // Use m.historyKey which is now set to the default 81 | return m.store.Store(m.historyKey, historyBytes) 82 | } 83 | 84 | // Get retrieves the conversation history. 85 | func (m *BufferedMemory) Get(ctx context.Context) ([]Message, error) { 86 | m.mu.RLock() 87 | defer m.mu.RUnlock() 88 | // Use m.historyKey which is now set to the default 89 | return m.getHistoryInternal(ctx) 90 | } 91 | 92 | // getHistoryInternal retrieves and unmarshals the history from the store. 93 | // This internal version doesn't lock, assuming the caller handles locking. 
94 | func (m *BufferedMemory) getHistoryInternal(ctx context.Context) ([]Message, error) { 95 | // Use m.historyKey which is now set to the default 96 | value, err := m.store.Retrieve(m.historyKey) 97 | if err != nil { 98 | // Check if the error is specifically ResourceNotFound using errors.As 99 | var dspyErr *errors.Error 100 | if goerrors.As(err, &dspyErr) && dspyErr.Code() == errors.ResourceNotFound { 101 | // Key not found means empty history, return empty slice and nil error 102 | return make([]Message, 0), nil 103 | } else { 104 | // It's some other error retrieving from the store 105 | return nil, err 106 | } 107 | } 108 | 109 | // If we got here, the key was found. Proceed with type assertion/unmarshalling 110 | historyBytes, ok := value.([]byte) 111 | if !ok { 112 | // Attempt conversion if it was stored as string initially (less likely with Store) 113 | if historyString, okStr := value.(string); okStr { 114 | historyBytes = []byte(historyString) 115 | } else { 116 | // Use InvalidResponse or Unknown if TypeAssertionFailed doesn't exist 117 | return nil, errors.New(errors.InvalidResponse, "stored history is not []byte or string") // Changed to InvalidResponse 118 | } 119 | } 120 | 121 | if len(historyBytes) == 0 { 122 | return make([]Message, 0), nil // Return empty slice if stored value is empty bytes 123 | } 124 | 125 | var history []Message 126 | if err := json.Unmarshal(historyBytes, &history); err != nil { 127 | // Use InvalidResponse or Unknown if DeserializationFailed doesn't exist 128 | return nil, errors.Wrap(err, errors.InvalidResponse, "failed to unmarshal history") // Changed to InvalidResponse 129 | } 130 | 131 | return history, nil 132 | } 133 | 134 | // Clear removes the conversation history from the store. 
135 | func (m *BufferedMemory) Clear(ctx context.Context) error { 136 | m.mu.Lock() 137 | defer m.mu.Unlock() 138 | // Use m.historyKey which is now set to the default 139 | emptyHistory := make([]Message, 0) 140 | historyBytes, _ := json.Marshal(emptyHistory) // Error handling omitted for brevity 141 | return m.store.Store(m.historyKey, historyBytes) 142 | } 143 | -------------------------------------------------------------------------------- /pkg/agents/memory/sqlite_memory_test.go: -------------------------------------------------------------------------------- 1 | package memory 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "math/rand" 7 | "os" 8 | "sync" 9 | "testing" 10 | "time" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestSQLiteStore(t *testing.T) { 17 | // Create an in-memory database for testing 18 | store, err := NewSQLiteStore(":memory:") 19 | require.NoError(t, err) 20 | defer store.Close() 21 | 22 | t.Run("Basic Store and Retrieve", func(t *testing.T) { 23 | testData := map[string]interface{}{ 24 | "string": "test value", 25 | "number": 42, 26 | "bool": true, 27 | "array": []string{"a", "b", "c"}, 28 | "map": map[string]int{"one": 1, "two": 2}, 29 | } 30 | 31 | for key, value := range testData { 32 | err := store.Store(key, value) 33 | assert.NoError(t, err) 34 | 35 | retrieved, err := store.Retrieve(key) 36 | assert.NoError(t, err) 37 | assert.Equal(t, value, retrieved) 38 | } 39 | }) 40 | 41 | t.Run("List Keys", func(t *testing.T) { 42 | // Clear existing data 43 | err := store.Clear() 44 | require.NoError(t, err) 45 | 46 | // Store some test data 47 | testKeys := []string{"key1", "key2", "key3"} 48 | for _, key := range testKeys { 49 | err := store.Store(key, "value") 50 | require.NoError(t, err) 51 | } 52 | 53 | // List keys 54 | keys, err := store.List() 55 | assert.NoError(t, err) 56 | assert.ElementsMatch(t, testKeys, keys) 57 | }) 58 | 59 | t.Run("Clear Store", func(t 
*testing.T) { 60 | // Store some data 61 | err := store.Store("test", "value") 62 | require.NoError(t, err) 63 | 64 | // Clear the store 65 | err = store.Clear() 66 | assert.NoError(t, err) 67 | 68 | // Verify store is empty 69 | keys, err := store.List() 70 | assert.NoError(t, err) 71 | assert.Empty(t, keys) 72 | }) 73 | 74 | t.Run("Non-existent Key", func(t *testing.T) { 75 | _, err := store.Retrieve("nonexistent") 76 | assert.Error(t, err) 77 | }) 78 | 79 | t.Run("Concurrent Access", func(t *testing.T) { 80 | dbPath := "/tmp/test_ttl.db" 81 | os.Remove(dbPath) // Clean up before test 82 | 83 | store, err := NewSQLiteStore(dbPath) 84 | require.NoError(t, err) 85 | defer func() { 86 | store.Close() 87 | os.Remove(dbPath) 88 | }() 89 | const numGoroutines = 10 90 | done := make(chan bool, numGoroutines) 91 | var wg sync.WaitGroup 92 | 93 | // Clear any existing data 94 | err = store.Clear() 95 | require.NoError(t, err) 96 | 97 | wg.Add(numGoroutines) 98 | for i := 0; i < numGoroutines; i++ { 99 | go func(n int) { 100 | defer wg.Done() 101 | 102 | key := fmt.Sprintf("concurrent_key_%d", n) 103 | err := store.Store(key, n) 104 | if !assert.NoError(t, err, "Store failed for key: %s", key) { 105 | done <- false 106 | return 107 | } 108 | 109 | // Small delay to increase chance of concurrent access 110 | time.Sleep(time.Millisecond * time.Duration(rand.Intn(10))) 111 | 112 | retrieved, err := store.Retrieve(key) 113 | if !assert.NoError(t, err, "Retrieve failed for key: %s", key) { 114 | done <- false 115 | return 116 | } 117 | 118 | if !assert.Equal(t, n, retrieved, "Value mismatch for key: %s", key) { 119 | done <- false 120 | return 121 | } 122 | 123 | done <- true 124 | }(i) 125 | } 126 | 127 | // Wait for all goroutines to finish 128 | wg.Wait() 129 | close(done) 130 | 131 | // Check if any goroutine failed 132 | for success := range done { 133 | assert.True(t, success, "One or more goroutines failed") 134 | } 135 | }) 136 | 137 | t.Run("TTL Storage", func(t 
*testing.T) { 138 | ctx := context.Background() 139 | 140 | // Store with short TTL 141 | err := store.StoreWithTTL(ctx, "ttl_key", "ttl_value", 1*time.Second) 142 | require.NoError(t, err) 143 | 144 | // Verify value exists 145 | value, err := store.Retrieve("ttl_key") 146 | assert.NoError(t, err) 147 | assert.Equal(t, "ttl_value", value) 148 | t.Logf("Current time (UTC): %s\n", time.Now().UTC().Format(time.RFC3339)) 149 | // Wait for TTL to expire 150 | time.Sleep(3 * time.Second) 151 | t.Logf("Current time (UTC): %s\n", time.Now().UTC().Format(time.RFC3339)) 152 | 153 | cleaned, err := store.CleanExpired(ctx) 154 | assert.NoError(t, err) 155 | assert.Equal(t, int64(1), cleaned, "Expected one entry to be cleaned") 156 | // Verify value is gone 157 | _, err = store.Retrieve("ttl_key") 158 | assert.Error(t, err) 159 | }) 160 | 161 | t.Run("Invalid JSON", func(t *testing.T) { 162 | // Try to store a channel (cannot be marshaled to JSON) 163 | ch := make(chan bool) 164 | err := store.Store("invalid", ch) 165 | assert.Error(t, err) 166 | }) 167 | 168 | t.Run("Database Connection", func(t *testing.T) { 169 | // Test invalid database path 170 | _, err := NewSQLiteStore("/root/forbidden/db.sqlite") 171 | assert.Error(t, err) 172 | 173 | // Test closing database 174 | tempStore, err := NewSQLiteStore(":memory:") 175 | require.NoError(t, err) 176 | assert.NoError(t, tempStore.Close()) 177 | 178 | // Try operations after closing 179 | err = tempStore.Store("key", "value") 180 | assert.Error(t, err) 181 | }) 182 | } 183 | -------------------------------------------------------------------------------- /pkg/agents/memory_test.go: -------------------------------------------------------------------------------- 1 | package agents 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestInMemoryStore(t *testing.T) { 13 | t.Run("Basic Operations", func(t *testing.T) { 14 | 
store := NewInMemoryStore() 15 | 16 | // Test Store 17 | err := store.Store("key1", "value1") 18 | require.NoError(t, err) 19 | 20 | // Test Retrieve 21 | value, err := store.Retrieve("key1") 22 | require.NoError(t, err) 23 | assert.Equal(t, "value1", value) 24 | 25 | // Test non-existent key 26 | _, err = store.Retrieve("nonexistent") 27 | assert.Error(t, err) 28 | assert.Contains(t, err.Error(), "not found") 29 | 30 | // Test List 31 | keys, err := store.List() 32 | require.NoError(t, err) 33 | assert.Contains(t, keys, "key1") 34 | 35 | // Test Clear 36 | err = store.Clear() 37 | require.NoError(t, err) 38 | keys, err = store.List() 39 | require.NoError(t, err) 40 | assert.Empty(t, keys) 41 | }) 42 | 43 | t.Run("Concurrent Operations", func(t *testing.T) { 44 | store := NewInMemoryStore() 45 | var wg sync.WaitGroup 46 | iterations := 100 47 | 48 | // Concurrent writes 49 | for i := 0; i < iterations; i++ { 50 | wg.Add(1) 51 | go func(i int) { 52 | defer wg.Done() 53 | err := store.Store(fmt.Sprintf("key%d", i), i) 54 | assert.NoError(t, err) 55 | }(i) 56 | } 57 | 58 | // Concurrent reads 59 | for i := 0; i < iterations; i++ { 60 | wg.Add(1) 61 | go func(i int) { 62 | defer wg.Done() 63 | _, err := store.Retrieve(fmt.Sprintf("key%d", i)) 64 | // Error is acceptable since we're reading concurrently with writes 65 | if err != nil { 66 | assert.Contains(t, err.Error(), "not found") 67 | } 68 | }(i) 69 | } 70 | 71 | wg.Wait() 72 | 73 | // Verify final state 74 | keys, err := store.List() 75 | require.NoError(t, err) 76 | assert.Len(t, keys, iterations) 77 | }) 78 | 79 | t.Run("Store Different Types", func(t *testing.T) { 80 | store := NewInMemoryStore() 81 | 82 | testCases := []struct { 83 | key string 84 | value interface{} 85 | }{ 86 | {"string", "test"}, 87 | {"int", 42}, 88 | {"float", 3.14}, 89 | {"bool", true}, 90 | {"slice", []string{"a", "b", "c"}}, 91 | {"map", map[string]int{"a": 1, "b": 2}}, 92 | {"struct", struct{ Name string }{"test"}}, 93 | } 94 | 95 | 
for _, tc := range testCases { 96 | t.Run(tc.key, func(t *testing.T) { 97 | err := store.Store(tc.key, tc.value) 98 | require.NoError(t, err) 99 | 100 | retrieved, err := store.Retrieve(tc.key) 101 | require.NoError(t, err) 102 | assert.Equal(t, tc.value, retrieved) 103 | }) 104 | } 105 | }) 106 | } 107 | -------------------------------------------------------------------------------- /pkg/agents/workflows/chain.go: -------------------------------------------------------------------------------- 1 | package workflows 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/XiaoConstantine/dspy-go/pkg/agents" 8 | "github.com/XiaoConstantine/dspy-go/pkg/core" 9 | "github.com/XiaoConstantine/dspy-go/pkg/errors" 10 | ) 11 | 12 | // ChainWorkflow executes steps in a linear sequence, where each step's output 13 | // can be used as input for subsequent steps. 14 | type ChainWorkflow struct { 15 | *BaseWorkflow 16 | } 17 | 18 | func NewChainWorkflow(memory agents.Memory) *ChainWorkflow { 19 | return &ChainWorkflow{ 20 | BaseWorkflow: NewBaseWorkflow(memory), 21 | } 22 | } 23 | 24 | // Execute runs steps sequentially, passing state from one step to the next. 
25 | func (w *ChainWorkflow) Execute(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) { 26 | // Initialize workflow state with input values 27 | state := make(map[string]interface{}) 28 | for k, v := range inputs { 29 | state[k] = v 30 | } 31 | 32 | totalSteps := len(w.steps) 33 | // Execute steps in sequence 34 | for i, step := range w.steps { 35 | 36 | stepCtx, stepSpan := core.StartSpan(ctx, fmt.Sprintf("ChainStep_%d", i)) 37 | 38 | stepSpan.WithAnnotation("chain_step", map[string]interface{}{ 39 | "name": step.ID, 40 | "index": i, 41 | "total": totalSteps, 42 | }) 43 | 44 | signature := step.Module.GetSignature() 45 | 46 | // Create subset of state containing only the fields this step needs 47 | stepInputs := make(map[string]interface{}) 48 | for _, field := range signature.Inputs { 49 | if val, ok := state[field.Name]; ok { 50 | stepInputs[field.Name] = val 51 | } 52 | } 53 | 54 | // Execute the step 55 | result, err := step.Execute(stepCtx, stepInputs) 56 | 57 | core.EndSpan(stepCtx) 58 | if err != nil { 59 | return nil, errors.WithFields( 60 | errors.Wrap(err, errors.StepExecutionFailed, "step execution failed"), 61 | errors.Fields{ 62 | "step_id": step.ID, 63 | "step": i + 1, 64 | "inputs": stepInputs, 65 | "total": totalSteps, 66 | }) 67 | } 68 | 69 | // Update state with step outputs 70 | for k, v := range result.Outputs { 71 | state[k] = v 72 | } 73 | } 74 | 75 | return state, nil 76 | } 77 | -------------------------------------------------------------------------------- /pkg/agents/workflows/chain_test.go: -------------------------------------------------------------------------------- 1 | package workflows 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | 8 | "github.com/XiaoConstantine/dspy-go/pkg/core" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/mock" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestChainWorkflow(t *testing.T) { 15 | t.Run("Execute 
steps in sequence", func(t *testing.T) { 16 | memory := new(MockMemory) 17 | workflow := NewChainWorkflow(memory) 18 | 19 | // Create mock modules with expected behavior 20 | module1 := new(MockModule) 21 | module2 := new(MockModule) 22 | 23 | // Setup mock responses 24 | module1.On("GetSignature").Return(core.Signature{ 25 | Inputs: []core.InputField{{Field: core.Field{Name: "input1"}}}, 26 | Outputs: []core.OutputField{{Field: core.Field{Name: "output1"}}}, 27 | }) 28 | module1.On("Process", mock.Anything, mock.Anything, mock.Anything).Return( 29 | map[string]any{"output1": "intermediate"}, nil, 30 | ) 31 | 32 | module2.On("GetSignature").Return(core.Signature{ 33 | Inputs: []core.InputField{{Field: core.Field{Name: "output1"}}}, 34 | Outputs: []core.OutputField{{Field: core.Field{Name: "final"}}}, 35 | }) 36 | module2.On("Process", mock.Anything, mock.Anything, mock.Anything).Return( 37 | map[string]any{"final": "result"}, nil, 38 | ) 39 | 40 | // Add steps to workflow 41 | err := workflow.AddStep(&Step{ID: "step1", Module: module1}) 42 | require.NoError(t, err, "Failed to add step1") 43 | 44 | err = workflow.AddStep(&Step{ID: "step2", Module: module2}) 45 | require.NoError(t, err, "Failed to add step2") 46 | // Execute workflow 47 | ctx := context.Background() 48 | result, err := workflow.Execute(ctx, map[string]interface{}{ 49 | "input1": "initial", 50 | }) 51 | 52 | assert.NoError(t, err) 53 | assert.Equal(t, "result", result["final"]) 54 | 55 | // Verify mocks 56 | module1.AssertExpectations(t) 57 | module2.AssertExpectations(t) 58 | }) 59 | 60 | t.Run("Step failure", func(t *testing.T) { 61 | memory := new(MockMemory) 62 | workflow := NewChainWorkflow(memory) 63 | 64 | module := new(MockModule) 65 | module.On("GetSignature").Return(core.Signature{ 66 | Inputs: []core.InputField{{Field: core.Field{Name: "input"}}}, 67 | Outputs: []core.OutputField{{Field: core.Field{Name: "output"}}}, 68 | }) 69 | module.On("Process", mock.Anything, mock.Anything).Return( 70 
| map[string]any{}, errors.New("step failed"), 71 | ) 72 | 73 | err := workflow.AddStep(&Step{ID: "step1", Module: module}) 74 | require.NoError(t, err, "Failed to add step1") 75 | ctx := context.Background() 76 | _, err = workflow.Execute(ctx, map[string]interface{}{ 77 | "input": "value", 78 | }) 79 | 80 | assert.Error(t, err) 81 | assert.Contains(t, err.Error(), "step failed") 82 | module.AssertExpectations(t) 83 | }) 84 | } 85 | -------------------------------------------------------------------------------- /pkg/agents/workflows/errors.go: -------------------------------------------------------------------------------- 1 | package workflows 2 | 3 | import "github.com/XiaoConstantine/dspy-go/pkg/errors" 4 | 5 | var ( 6 | // ErrStepConditionFailed indicates a step's condition check failed. 7 | ErrStepConditionFailed = errors.New(errors.InvalidWorkflowState, "step condition check failed") 8 | 9 | // ErrStepNotFound indicates a referenced step doesn't exist in workflow. 10 | ErrStepNotFound = errors.New(errors.ResourceNotFound, "step not found in workflow") 11 | 12 | // ErrInvalidInput indicates missing or invalid input parameters. 13 | ErrInvalidInput = errors.New(errors.InvalidInput, "invalid input parameters") 14 | 15 | // ErrDuplicateStepID indicates attempt to add step with existing ID. 16 | ErrDuplicateStepID = errors.New(errors.ValidationFailed, "duplicate step ID") 17 | 18 | // ErrCyclicDependency indicates circular dependencies between steps. 
19 | ErrCyclicDependency = errors.New(errors.WorkflowExecutionFailed, "cyclic dependency detected in workflow") 20 | ) 21 | 22 | func WrapWorkflowError(err error, fields map[string]interface{}) error { 23 | if err == nil { 24 | return nil 25 | } 26 | return errors.WithFields(err, fields) 27 | } 28 | -------------------------------------------------------------------------------- /pkg/agents/workflows/errors_test.go: -------------------------------------------------------------------------------- 1 | package workflows 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestErrors(t *testing.T) { 12 | // Test error variables are properly defined 13 | t.Run("Error definitions", func(t *testing.T) { 14 | assert.NotNil(t, ErrStepConditionFailed) 15 | assert.NotNil(t, ErrStepNotFound) 16 | assert.NotNil(t, ErrInvalidInput) 17 | assert.NotNil(t, ErrDuplicateStepID) 18 | assert.NotNil(t, ErrCyclicDependency) 19 | }) 20 | 21 | // Test error messages are as expected 22 | t.Run("Error messages", func(t *testing.T) { 23 | assert.Equal(t, "step condition check failed", ErrStepConditionFailed.Error()) 24 | assert.Equal(t, "step not found in workflow", ErrStepNotFound.Error()) 25 | assert.Equal(t, "invalid input parameters", ErrInvalidInput.Error()) 26 | assert.Equal(t, "duplicate step ID", ErrDuplicateStepID.Error()) 27 | assert.Equal(t, "cyclic dependency detected in workflow", ErrCyclicDependency.Error()) 28 | }) 29 | 30 | // Test errors can be used with errors.Is 31 | t.Run("Error comparison", func(t *testing.T) { 32 | err := ErrStepConditionFailed 33 | assert.True(t, errors.Is(err, ErrStepConditionFailed)) 34 | assert.False(t, errors.Is(err, ErrStepNotFound)) 35 | }) 36 | 37 | // Test error wrapping 38 | t.Run("Error wrapping", func(t *testing.T) { 39 | wrapped := fmt.Errorf("failed to execute step: %w", ErrStepConditionFailed) 40 | assert.True(t, errors.Is(wrapped, ErrStepConditionFailed)) 41 | 
assert.Contains(t, wrapped.Error(), "failed to execute step") 42 | }) 43 | } 44 | -------------------------------------------------------------------------------- /pkg/agents/workflows/parallel.go: -------------------------------------------------------------------------------- 1 | package workflows 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | 8 | "github.com/XiaoConstantine/dspy-go/pkg/agents" 9 | ) 10 | 11 | // ParallelWorkflow executes multiple steps concurrently. 12 | type ParallelWorkflow struct { 13 | *BaseWorkflow 14 | // Maximum number of concurrent steps 15 | maxConcurrent int 16 | } 17 | 18 | func NewParallelWorkflow(memory agents.Memory, maxConcurrent int) *ParallelWorkflow { 19 | return &ParallelWorkflow{ 20 | BaseWorkflow: NewBaseWorkflow(memory), 21 | maxConcurrent: maxConcurrent, 22 | } 23 | } 24 | 25 | func (w *ParallelWorkflow) Execute(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) { 26 | state := make(map[string]interface{}) 27 | for k, v := range inputs { 28 | state[k] = v 29 | } 30 | 31 | // Create channel for collecting results 32 | results := make(chan *StepResult, len(w.steps)) 33 | errors := make(chan error, len(w.steps)) 34 | 35 | // Create semaphore to limit concurrency 36 | sem := make(chan struct{}, w.maxConcurrent) 37 | 38 | // Launch goroutine for each step 39 | var wg sync.WaitGroup 40 | for _, step := range w.steps { 41 | wg.Add(1) 42 | go func(s *Step) { 43 | defer wg.Done() 44 | 45 | // Acquire semaphore 46 | sem <- struct{}{} 47 | defer func() { <-sem }() 48 | 49 | // Prepare inputs for this step 50 | stepInputs := make(map[string]interface{}) 51 | signature := step.Module.GetSignature() 52 | 53 | for _, field := range signature.Inputs { 54 | if val, ok := inputs[field.Name]; ok { 55 | stepInputs[field.Name] = val 56 | } 57 | } 58 | 59 | // Execute step 60 | result, err := s.Execute(ctx, stepInputs) 61 | if err != nil { 62 | errors <- fmt.Errorf("step %s failed: %w", s.ID, err) 63 | 
return 64 | } 65 | results <- result 66 | }(step) 67 | } 68 | 69 | // Wait for all steps to complete 70 | go func() { 71 | wg.Wait() 72 | close(results) 73 | close(errors) 74 | }() 75 | 76 | // Collect results and errors 77 | var errs []error 78 | for err := range errors { 79 | errs = append(errs, err) 80 | } 81 | if len(errs) > 0 { 82 | return nil, fmt.Errorf("parallel execution failed with %d errors: %v", len(errs), errs) 83 | } 84 | 85 | // Merge results into final state 86 | for result := range results { 87 | for k, v := range result.Outputs { 88 | state[k] = v 89 | } 90 | } 91 | 92 | return state, nil 93 | } 94 | -------------------------------------------------------------------------------- /pkg/agents/workflows/parallel_test.go: -------------------------------------------------------------------------------- 1 | package workflows 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/XiaoConstantine/dspy-go/pkg/core" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/mock" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestParallelWorkflow(t *testing.T) { 17 | t.Run("Execute steps in parallel", func(t *testing.T) { 18 | memory := new(MockMemory) 19 | workflow := NewParallelWorkflow(memory, 2) 20 | 21 | var wg sync.WaitGroup 22 | procCount := 0 23 | var mu sync.Mutex 24 | 25 | // Create mock modules that track concurrent execution 26 | createModule := func(id string, delay time.Duration) *MockModule { 27 | module := new(MockModule) 28 | module.On("GetSignature").Return(core.Signature{ 29 | Inputs: []core.InputField{{Field: core.Field{Name: "input"}}}, 30 | Outputs: []core.OutputField{{Field: core.Field{Name: id}}}, 31 | }) 32 | module.On("Process", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { 33 | mu.Lock() 34 | procCount++ 35 | current := procCount 36 | mu.Unlock() 37 | 38 | // Ensure we don't exceed max concurrent processes 39 | assert.LessOrEqual(t, 
current, 2) 40 | time.Sleep(delay) 41 | 42 | mu.Lock() 43 | procCount-- 44 | mu.Unlock() 45 | wg.Done() 46 | }).Return(map[string]any{id: "done"}, nil) 47 | return module 48 | } 49 | 50 | // Add three steps with different delays 51 | wg.Add(3) 52 | err := workflow.AddStep(&Step{ID: "step1", Module: createModule("output1", 100*time.Millisecond)}) 53 | require.NoError(t, err, "Failed to add step1") 54 | 55 | err = workflow.AddStep(&Step{ID: "step2", Module: createModule("output2", 50*time.Millisecond)}) 56 | require.NoError(t, err, "Failed to add step2") 57 | 58 | err = workflow.AddStep(&Step{ID: "step3", Module: createModule("output3", 75*time.Millisecond)}) 59 | require.NoError(t, err, "Failed to add step3") 60 | 61 | // Execute workflow 62 | ctx := context.Background() 63 | result, err := workflow.Execute(ctx, map[string]interface{}{ 64 | "input": "value", 65 | }) 66 | 67 | // Wait for all goroutines to complete 68 | wg.Wait() 69 | 70 | assert.NoError(t, err) 71 | assert.Equal(t, "done", result["output1"]) 72 | assert.Equal(t, "done", result["output2"]) 73 | assert.Equal(t, "done", result["output3"]) 74 | }) 75 | 76 | t.Run("Handle step failure", func(t *testing.T) { 77 | memory := new(MockMemory) 78 | workflow := NewParallelWorkflow(memory, 2) 79 | 80 | failingModule := new(MockModule) 81 | failingModule.On("GetSignature").Return(core.Signature{ 82 | Inputs: []core.InputField{{Field: core.Field{Name: "input"}}}, 83 | Outputs: []core.OutputField{{Field: core.Field{Name: "output"}}}, 84 | }) 85 | failingModule.On("Process", mock.Anything, mock.Anything).Return( 86 | map[string]any{}, errors.New("step failed"), 87 | ) 88 | 89 | err := workflow.AddStep(&Step{ID: "step1", Module: failingModule}) 90 | require.NoError(t, err, "Failed to add step1") 91 | 92 | ctx := context.Background() 93 | _, err = workflow.Execute(ctx, map[string]interface{}{ 94 | "input": "value", 95 | }) 96 | 97 | assert.Error(t, err) 98 | assert.Contains(t, err.Error(), "step failed") 99 | 
failingModule.AssertExpectations(t) 100 | }) 101 | } 102 | -------------------------------------------------------------------------------- /pkg/agents/workflows/router.go: -------------------------------------------------------------------------------- 1 | package workflows 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/XiaoConstantine/dspy-go/pkg/agents" 8 | 9 | "github.com/XiaoConstantine/dspy-go/pkg/errors" 10 | ) 11 | 12 | // RouterWorkflow directs inputs to different processing paths based on 13 | // a classification step. 14 | type RouterWorkflow struct { 15 | *BaseWorkflow 16 | // The step that determines which path to take 17 | classifierStep *Step 18 | // Maps classification outputs to step sequences 19 | routes map[string][]*Step 20 | } 21 | 22 | func NewRouterWorkflow(memory agents.Memory, classifierStep *Step) *RouterWorkflow { 23 | return &RouterWorkflow{ 24 | BaseWorkflow: NewBaseWorkflow(memory), 25 | classifierStep: classifierStep, 26 | routes: make(map[string][]*Step), 27 | } 28 | } 29 | 30 | // AddRoute associates a classification value with a sequence of steps. 
31 | func (w *RouterWorkflow) AddRoute(classification string, steps []*Step) error { 32 | // Validate steps exist in workflow 33 | for _, step := range steps { 34 | if _, exists := w.stepIndex[step.ID]; !exists { 35 | return fmt.Errorf("step %s not found in workflow", step.ID) 36 | } 37 | } 38 | w.routes[classification] = steps 39 | return nil 40 | } 41 | 42 | func (w *RouterWorkflow) Execute(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) { 43 | // Initialize state 44 | state := make(map[string]interface{}) 45 | for k, v := range inputs { 46 | state[k] = v 47 | } 48 | 49 | // Execute classifier step 50 | result, err := w.classifierStep.Execute(ctx, inputs) 51 | if err != nil { 52 | return nil, errors.WithFields( 53 | errors.Wrap(err, errors.WorkflowExecutionFailed, "classifier step failed"), 54 | errors.Fields{ 55 | "step_id": w.classifierStep.ID, 56 | "inputs": inputs, 57 | }) 58 | } 59 | 60 | // Get classification from result 61 | classification, ok := result.Outputs["classification"].(string) 62 | if !ok { 63 | return nil, errors.WithFields( 64 | errors.New(errors.InvalidResponse, "classifier did not return a string classification"), 65 | errors.Fields{ 66 | "actual_type": fmt.Sprintf("%T", result.Outputs["classification"]), 67 | }) 68 | } 69 | 70 | // Get route for this classification 71 | route, exists := w.routes[classification] 72 | if !exists { 73 | return nil, errors.WithFields( 74 | errors.New(errors.ResourceNotFound, "no route defined for classification"), 75 | errors.Fields{ 76 | "classification": classification, 77 | }) 78 | 79 | } 80 | 81 | // Execute steps in the selected route 82 | for _, step := range route { 83 | signature := step.Module.GetSignature() 84 | 85 | stepInputs := make(map[string]interface{}) 86 | for _, field := range signature.Inputs { 87 | if val, ok := state[field.Name]; ok { 88 | stepInputs[field.Name] = val 89 | } 90 | } 91 | 92 | result, err := step.Execute(ctx, stepInputs) 93 | if err != nil 
{ 94 | return nil, errors.WithFields( 95 | errors.Wrap(err, errors.StepExecutionFailed, "step execution failed"), 96 | errors.Fields{ 97 | "step_id": step.ID, 98 | "inputs": stepInputs, 99 | }) 100 | } 101 | 102 | for k, v := range result.Outputs { 103 | state[k] = v 104 | } 105 | } 106 | 107 | return state, nil 108 | } 109 | -------------------------------------------------------------------------------- /pkg/agents/workflows/step_test.go: -------------------------------------------------------------------------------- 1 | package workflows 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "sync/atomic" 7 | "testing" 8 | "time" 9 | 10 | "github.com/XiaoConstantine/dspy-go/pkg/core" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/mock" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestStep(t *testing.T) { 17 | // Helper function to create a basic step with mocks 18 | setupStep := func() (*Step, *MockModule) { 19 | module := new(MockModule) 20 | module.On("GetSignature").Return(core.Signature{ 21 | Inputs: []core.InputField{{Field: core.Field{Name: "input"}}}, 22 | Outputs: []core.OutputField{{Field: core.Field{Name: "output"}}}, 23 | }) 24 | 25 | return &Step{ 26 | ID: "test_step", 27 | Module: module, 28 | }, module 29 | } 30 | 31 | t.Run("Basic execution", func(t *testing.T) { 32 | step, module := setupStep() 33 | 34 | // Set up expected behavior 35 | module.On("Process", mock.Anything, mock.Anything).Return( 36 | map[string]any{"output": "success"}, nil, 37 | ) 38 | 39 | // Execute step 40 | ctx := context.Background() 41 | result, err := step.Execute(ctx, map[string]interface{}{ 42 | "input": "test", 43 | }) 44 | 45 | // Verify results 46 | require.NoError(t, err) 47 | assert.Equal(t, "success", result.Outputs["output"]) 48 | assert.Equal(t, step.ID, result.StepID) 49 | module.AssertExpectations(t) 50 | }) 51 | 52 | t.Run("Input validation", func(t *testing.T) { 53 | step, module := setupStep() 54 | 55 | ctx := 
context.Background() 56 | _, err := step.Execute(ctx, map[string]interface{}{ 57 | "wrong_input": "test", // Missing required 'input' field 58 | }) 59 | 60 | assert.Error(t, err) 61 | assert.Contains(t, err.Error(), "input validation failed") 62 | module.AssertNotCalled(t, "Process") 63 | }) 64 | 65 | t.Run("Output validation", func(t *testing.T) { 66 | step, module := setupStep() 67 | 68 | module.On("Process", mock.Anything, mock.Anything).Return( 69 | map[string]any{"wrong_output": "value"}, nil, // Missing required 'output' field 70 | ) 71 | 72 | ctx := context.Background() 73 | _, err := step.Execute(ctx, map[string]interface{}{ 74 | "input": "test", 75 | }) 76 | 77 | assert.Error(t, err) 78 | assert.Contains(t, err.Error(), "output validation failed") 79 | module.AssertExpectations(t) 80 | }) 81 | 82 | t.Run("Condition check", func(t *testing.T) { 83 | step, module := setupStep() 84 | 85 | // Add condition that always fails 86 | step.Condition = func(state map[string]interface{}) bool { 87 | return false 88 | } 89 | 90 | ctx := context.Background() 91 | _, err := step.Execute(ctx, map[string]interface{}{ 92 | "input": "test", 93 | }) 94 | 95 | assert.ErrorIs(t, err, ErrStepConditionFailed) 96 | module.AssertNotCalled(t, "Process") 97 | }) 98 | 99 | t.Run("Retry logic", func(t *testing.T) { 100 | step, module := setupStep() 101 | 102 | // Configure retry 103 | step.RetryConfig = &RetryConfig{ 104 | MaxAttempts: 3, 105 | BackoffMultiplier: 1.5, 106 | } 107 | var attempts int32 108 | 109 | module.On("Process", mock.Anything, mock.Anything). 110 | Run(func(args mock.Arguments) { 111 | current := atomic.AddInt32(&attempts, 1) 112 | t.Logf("Attempt #%d", current) 113 | }). 114 | Return(make(map[string]interface{}), assert.AnError). 115 | Times(2) 116 | 117 | // Third call will succeed 118 | module.On("Process", mock.Anything, mock.Anything). 119 | Run(func(args mock.Arguments) { 120 | atomic.AddInt32(&attempts, 1) 121 | t.Logf("Final successful attempt") 122 | }). 
123 | Return(map[string]interface{}{"output": "success"}, nil). 124 | Once() 125 | 126 | ctx := context.Background() 127 | result, err := step.Execute(ctx, map[string]interface{}{ 128 | "input": "test", 129 | }) 130 | 131 | // Verify results 132 | require.NoError(t, err) 133 | assert.Equal(t, "success", result.Outputs["output"]) 134 | assert.Equal(t, int32(3), atomic.LoadInt32(&attempts), 135 | "Should have attempted exactly 3 times") 136 | 137 | module.AssertExpectations(t) 138 | }) 139 | t.Run("Retry backoff timing", func(t *testing.T) { 140 | step, module := setupStep() 141 | 142 | // Configure retry with very small intervals for testing 143 | step.RetryConfig = &RetryConfig{ 144 | MaxAttempts: 2, 145 | BackoffMultiplier: 2.0, 146 | } 147 | 148 | // Record attempt times in a thread-safe way 149 | var mu sync.Mutex 150 | attemptTimes := make([]time.Time, 0, 2) 151 | 152 | module.On("Process", mock.Anything, mock.Anything). 153 | Run(func(args mock.Arguments) { 154 | mu.Lock() 155 | attemptTimes = append(attemptTimes, time.Now()) 156 | mu.Unlock() 157 | }). 158 | Return(make(map[string]interface{}), assert.AnError). 
159 | Times(2) 160 | 161 | ctx := context.Background() 162 | _, err := step.Execute(ctx, map[string]interface{}{ 163 | "input": "test", 164 | }) 165 | 166 | // Should fail after max attempts 167 | require.Error(t, err) 168 | 169 | mu.Lock() 170 | recordedAttempts := attemptTimes 171 | mu.Unlock() 172 | 173 | // Verify we got the expected number of attempts 174 | assert.Equal(t, 2, len(recordedAttempts), 175 | "Should have recorded exactly 2 attempts") 176 | 177 | // Check intervals between attempts 178 | if len(recordedAttempts) >= 2 { 179 | firstInterval := recordedAttempts[1].Sub(recordedAttempts[0]) 180 | t.Logf("Interval between retries: %v", firstInterval) 181 | // Just verify the ordering is preserved and delays are positive 182 | assert.True(t, firstInterval > 0, 183 | "Time should progress forward between attempts") 184 | } 185 | 186 | module.AssertExpectations(t) 187 | }) 188 | 189 | t.Run("Context cancellation", func(t *testing.T) { 190 | step, module := setupStep() 191 | 192 | // Create cancellable context 193 | ctx, cancel := context.WithCancel(context.Background()) 194 | 195 | // Mock module to wait before responding 196 | module.On("Process", mock.Anything, mock.Anything). 197 | Run(func(args mock.Arguments) { 198 | time.Sleep(100 * time.Millisecond) 199 | }). 
200 | Return(map[string]any{"output": "success"}, nil) 201 | 202 | // Cancel context shortly after starting 203 | go func() { 204 | time.Sleep(50 * time.Millisecond) 205 | cancel() 206 | }() 207 | 208 | _, err := step.Execute(ctx, map[string]interface{}{ 209 | "input": "test", 210 | }) 211 | 212 | assert.ErrorIs(t, err, context.Canceled) 213 | module.AssertExpectations(t) 214 | }) 215 | 216 | t.Run("Next steps propagation", func(t *testing.T) { 217 | step, module := setupStep() 218 | 219 | // Configure next steps 220 | step.NextSteps = []string{"step2", "step3"} 221 | 222 | module.On("Process", mock.Anything, mock.Anything).Return( 223 | map[string]any{"output": "success"}, nil, 224 | ) 225 | 226 | ctx := context.Background() 227 | result, err := step.Execute(ctx, map[string]interface{}{ 228 | "input": "test", 229 | }) 230 | 231 | require.NoError(t, err) 232 | assert.Equal(t, step.NextSteps, result.NextSteps) 233 | module.AssertExpectations(t) 234 | }) 235 | } 236 | -------------------------------------------------------------------------------- /pkg/agents/workflows/workflow.go: -------------------------------------------------------------------------------- 1 | package workflows 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/XiaoConstantine/dspy-go/pkg/agents" 8 | ) 9 | 10 | // Workflow represents a sequence of steps that accomplish a task. 11 | type Workflow interface { 12 | // Execute runs the workflow with the provided inputs 13 | Execute(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) 14 | 15 | // GetSteps returns all steps in this workflow 16 | GetSteps() []*Step 17 | 18 | // AddStep adds a new step to the workflow 19 | AddStep(step *Step) error 20 | } 21 | 22 | // BaseWorkflow provides common workflow functionality. 
23 | type BaseWorkflow struct { 24 | // steps stores all steps in the workflow 25 | steps []*Step 26 | 27 | // stepIndex provides quick lookup of steps by ID 28 | stepIndex map[string]*Step 29 | 30 | // memory provides persistence between workflow runs 31 | memory agents.Memory 32 | } 33 | 34 | func NewBaseWorkflow(memory agents.Memory) *BaseWorkflow { 35 | return &BaseWorkflow{ 36 | steps: make([]*Step, 0), 37 | stepIndex: make(map[string]*Step), 38 | memory: memory, 39 | } 40 | } 41 | 42 | func (w *BaseWorkflow) AddStep(step *Step) error { 43 | // Validate step ID is unique 44 | if _, exists := w.stepIndex[step.ID]; exists { 45 | return fmt.Errorf("step with ID %s already exists", step.ID) 46 | } 47 | 48 | // Add step to workflow 49 | w.steps = append(w.steps, step) 50 | w.stepIndex[step.ID] = step 51 | 52 | return nil 53 | } 54 | 55 | func (w *BaseWorkflow) GetSteps() []*Step { 56 | return w.steps 57 | } 58 | 59 | // ValidateWorkflow checks if the workflow structure is valid. 60 | func (w *BaseWorkflow) ValidateWorkflow() error { 61 | // Check for cycles in step dependencies 62 | visited := make(map[string]bool) 63 | path := make(map[string]bool) 64 | 65 | var checkCycle func(stepID string) error 66 | checkCycle = func(stepID string) error { 67 | visited[stepID] = true 68 | path[stepID] = true 69 | 70 | step := w.stepIndex[stepID] 71 | for _, nextID := range step.NextSteps { 72 | if !visited[nextID] { 73 | if err := checkCycle(nextID); err != nil { 74 | return err 75 | } 76 | } else if path[nextID] { 77 | return fmt.Errorf("cycle detected in workflow") 78 | } 79 | } 80 | 81 | path[stepID] = false 82 | return nil 83 | } 84 | 85 | for _, step := range w.steps { 86 | if !visited[step.ID] { 87 | if err := checkCycle(step.ID); err != nil { 88 | return err 89 | } 90 | } 91 | } 92 | 93 | return nil 94 | } 95 | -------------------------------------------------------------------------------- /pkg/agents/workflows/workflow_test.go: 
-------------------------------------------------------------------------------- 1 | package workflows 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/XiaoConstantine/dspy-go/pkg/core" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/mock" 10 | ) 11 | 12 | // MockMemory implements agents.Memory interface for testing. 13 | type MockMemory struct { 14 | mock.Mock 15 | } 16 | 17 | func (m *MockMemory) Store(key string, value interface{}) error { 18 | args := m.Called(key, value) 19 | return args.Error(0) 20 | } 21 | 22 | func (m *MockMemory) Retrieve(key string) (interface{}, error) { 23 | args := m.Called(key) 24 | return args.Get(0), args.Error(1) 25 | } 26 | 27 | func (m *MockMemory) List() ([]string, error) { 28 | args := m.Called() 29 | return args.Get(0).([]string), args.Error(1) 30 | } 31 | 32 | func (m *MockMemory) Clear() error { 33 | args := m.Called() 34 | return args.Error(0) 35 | } 36 | 37 | // MockModule implements core.Module interface for testing. 
38 | type MockModule struct { 39 | mock.Mock 40 | } 41 | 42 | func (m *MockModule) Process(ctx context.Context, inputs map[string]any, opts ...core.Option) (map[string]any, error) { 43 | args := m.Called(ctx, inputs) 44 | // Handle nil return case properly 45 | if args.Get(0) == nil { 46 | return make(map[string]any), args.Error(1) 47 | } 48 | return args.Get(0).(map[string]any), args.Error(1) 49 | } 50 | 51 | func (m *MockModule) GetSignature() core.Signature { 52 | args := m.Called() 53 | return args.Get(0).(core.Signature) 54 | } 55 | 56 | func (m *MockModule) SetSignature(signature core.Signature) { 57 | m.Called(signature) 58 | } 59 | 60 | func (m *MockModule) SetLLM(llm core.LLM) { 61 | m.Called(llm) 62 | } 63 | 64 | func (m *MockModule) Clone() core.Module { 65 | args := m.Called() 66 | return args.Get(0).(core.Module) 67 | } 68 | 69 | func TestBaseWorkflow(t *testing.T) { 70 | t.Run("AddStep success", func(t *testing.T) { 71 | memory := new(MockMemory) 72 | workflow := NewBaseWorkflow(memory) 73 | 74 | step := &Step{ 75 | ID: "test_step", 76 | Module: new(MockModule), 77 | } 78 | 79 | err := workflow.AddStep(step) 80 | assert.NoError(t, err) 81 | assert.Len(t, workflow.GetSteps(), 1) 82 | assert.Equal(t, step, workflow.stepIndex["test_step"]) 83 | }) 84 | 85 | t.Run("AddStep duplicate ID", func(t *testing.T) { 86 | memory := new(MockMemory) 87 | workflow := NewBaseWorkflow(memory) 88 | 89 | step := &Step{ 90 | ID: "test_step", 91 | Module: new(MockModule), 92 | } 93 | 94 | _ = workflow.AddStep(step) 95 | err := workflow.AddStep(step) 96 | assert.Error(t, err) 97 | assert.Contains(t, err.Error(), "already exists") 98 | }) 99 | 100 | t.Run("ValidateWorkflow success", func(t *testing.T) { 101 | memory := new(MockMemory) 102 | workflow := NewBaseWorkflow(memory) 103 | 104 | step1 := &Step{ 105 | ID: "step1", 106 | Module: new(MockModule), 107 | NextSteps: []string{"step2"}, 108 | } 109 | 110 | step2 := &Step{ 111 | ID: "step2", 112 | Module: new(MockModule), 
113 | } 114 | 115 | _ = workflow.AddStep(step1) 116 | _ = workflow.AddStep(step2) 117 | 118 | err := workflow.ValidateWorkflow() 119 | assert.NoError(t, err) 120 | }) 121 | 122 | t.Run("ValidateWorkflow cyclic dependency", func(t *testing.T) { 123 | memory := new(MockMemory) 124 | workflow := NewBaseWorkflow(memory) 125 | 126 | step1 := &Step{ 127 | ID: "step1", 128 | Module: new(MockModule), 129 | NextSteps: []string{"step2"}, 130 | } 131 | 132 | step2 := &Step{ 133 | ID: "step2", 134 | Module: new(MockModule), 135 | NextSteps: []string{"step1"}, 136 | } 137 | 138 | _ = workflow.AddStep(step1) 139 | _ = workflow.AddStep(step2) 140 | 141 | err := workflow.ValidateWorkflow() 142 | assert.Error(t, err) 143 | assert.Contains(t, err.Error(), "cycle detected") 144 | }) 145 | } 146 | -------------------------------------------------------------------------------- /pkg/core/config.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | type Config struct { 8 | DefaultLLM LLM 9 | TeacherLLM LLM 10 | ConcurrencyLevel int 11 | } 12 | 13 | var GlobalConfig = &Config{ 14 | // default concurrency 1 15 | ConcurrencyLevel: 1, 16 | } 17 | 18 | // ConfigureDefaultLLM sets up the default LLM to be used across the package. 19 | func ConfigureDefaultLLM(apiKey string, modelID ModelID) error { 20 | llmInstance, err := DefaultFactory.CreateLLM(apiKey, modelID) 21 | if err != nil { 22 | return fmt.Errorf("failed to configure default LLM: %w", err) 23 | } 24 | GlobalConfig.DefaultLLM = llmInstance 25 | return nil 26 | } 27 | 28 | // ConfigureTeacherLLM sets up the teacher LLM. 
29 | func ConfigureTeacherLLM(apiKey string, modelID ModelID) error { 30 | llmInstance, err := DefaultFactory.CreateLLM(apiKey, modelID) 31 | if err != nil { 32 | return fmt.Errorf("failed to configure teacher LLM: %w", err) 33 | } 34 | GlobalConfig.TeacherLLM = llmInstance 35 | return nil 36 | } 37 | 38 | // GetDefaultLLM returns the default LLM. 39 | func GetDefaultLLM() LLM { 40 | return GlobalConfig.DefaultLLM 41 | } 42 | 43 | // GetTeacherLLM returns the teacher LLM. 44 | func GetTeacherLLM() LLM { 45 | return GlobalConfig.TeacherLLM 46 | } 47 | 48 | func SetConcurrencyOptions(level int) { 49 | if level > 0 { 50 | GlobalConfig.ConcurrencyLevel = level 51 | } else { 52 | GlobalConfig.ConcurrencyLevel = 1 // Reset to default value for invalid inputs 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /pkg/core/decorators.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import "context" 4 | 5 | // BaseDecorator provides common functionality for all LLM decorators. 6 | type BaseDecorator struct { 7 | LLM 8 | } 9 | 10 | // ModelContextDecorator adds model context tracking. 11 | type ModelContextDecorator struct { 12 | BaseDecorator 13 | } 14 | 15 | func NewModelContextDecorator(base LLM) *ModelContextDecorator { 16 | return &ModelContextDecorator{ 17 | BaseDecorator: BaseDecorator{LLM: base}, 18 | } 19 | } 20 | 21 | func (d *BaseDecorator) Unwrap() LLM { 22 | return d.LLM 23 | } 24 | 25 | func (d *ModelContextDecorator) Generate(ctx context.Context, prompt string, options ...GenerateOption) (*LLMResponse, error) { 26 | if state := GetExecutionState(ctx); state != nil { 27 | state.WithModelID(d.ModelID()) 28 | } 29 | return d.LLM.Generate(ctx, prompt, options...) 30 | } 31 | 32 | // Helper function to compose multiple decorators. 
33 | func Chain(base LLM, decorators ...func(LLM) LLM) LLM { 34 | result := base 35 | for _, d := range decorators { 36 | result = d(result) 37 | } 38 | return result 39 | } 40 | -------------------------------------------------------------------------------- /pkg/core/execution_context.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "crypto/rand" 6 | "encoding/hex" 7 | "fmt" 8 | "sync" 9 | "sync/atomic" 10 | "time" 11 | ) 12 | 13 | // ExecutionState holds the mutable state for an execution context. 14 | type ExecutionState struct { 15 | mu sync.RWMutex 16 | 17 | // Execution metadata 18 | traceID string 19 | spans []*Span 20 | activeSpan *Span 21 | 22 | // LLM-specific state 23 | modelID string 24 | tokenUsage *TokenUsage 25 | 26 | // Custom annotations 27 | annotations map[string]interface{} 28 | } 29 | 30 | // Span represents a single operation within the execution. 31 | type Span struct { 32 | ID string 33 | ParentID string 34 | Operation string 35 | StartTime time.Time 36 | EndTime time.Time 37 | Error error 38 | Annotations map[string]interface{} 39 | } 40 | 41 | // TokenUsage tracks token consumption. 42 | type TokenUsage struct { 43 | PromptTokens int 44 | CompletionTokens int 45 | TotalTokens int 46 | Cost float64 47 | } 48 | 49 | type spanIDGenerator struct { 50 | // counter ensures uniqueness even with identical timestamps 51 | counter uint64 52 | // lastTimestamp helps detect time backwards movement 53 | lastTimestamp int64 54 | } 55 | 56 | // ExecutionContextKey is the type for context keys specific to dspy-go. 57 | type ExecutionContextKey struct { 58 | name string 59 | } 60 | 61 | var ( 62 | stateKey = &ExecutionContextKey{"dspy-state"} 63 | defaultGenerator = &spanIDGenerator{} 64 | ) 65 | 66 | // WithExecutionState creates a new context with dspy-go execution state. 
67 | func WithExecutionState(ctx context.Context) context.Context { 68 | if GetExecutionState(ctx) != nil { 69 | return ctx // State already exists 70 | } 71 | return context.WithValue(ctx, stateKey, &ExecutionState{ 72 | traceID: generateTraceID(), 73 | annotations: make(map[string]interface{}), 74 | spans: make([]*Span, 0), 75 | }) 76 | } 77 | 78 | // GetExecutionState retrieves the execution state from a context. 79 | func GetExecutionState(ctx context.Context) *ExecutionState { 80 | if state, ok := ctx.Value(stateKey).(*ExecutionState); ok { 81 | return state 82 | } 83 | return nil 84 | } 85 | 86 | // StartSpan begins a new operation span. 87 | func StartSpan(ctx context.Context, operation string) (context.Context, *Span) { 88 | state := GetExecutionState(ctx) 89 | if state == nil { 90 | ctx = WithExecutionState(ctx) 91 | state = GetExecutionState(ctx) 92 | } 93 | 94 | state.mu.Lock() 95 | defer state.mu.Unlock() 96 | 97 | span := &Span{ 98 | ID: generateSpanID(), // Implementation needed 99 | Operation: operation, 100 | StartTime: time.Now(), 101 | Annotations: make(map[string]interface{}), 102 | } 103 | 104 | if state.activeSpan != nil { 105 | span.ParentID = state.activeSpan.ID 106 | } 107 | 108 | state.spans = append(state.spans, span) 109 | state.activeSpan = span 110 | 111 | return ctx, span 112 | } 113 | 114 | // EndSpan completes the current span. 115 | func EndSpan(ctx context.Context) { 116 | if state := GetExecutionState(ctx); state != nil { 117 | state.mu.Lock() 118 | defer state.mu.Unlock() 119 | 120 | if state.activeSpan != nil { 121 | state.activeSpan.EndTime = time.Now() 122 | state.activeSpan = nil 123 | } 124 | } 125 | } 126 | 127 | // State modification methods. 
128 | func (s *ExecutionState) WithModelID(modelID string) { 129 | s.mu.Lock() 130 | defer s.mu.Unlock() 131 | s.modelID = modelID 132 | } 133 | 134 | func (s *ExecutionState) WithTokenUsage(usage *TokenUsage) { 135 | s.mu.Lock() 136 | defer s.mu.Unlock() 137 | s.tokenUsage = usage 138 | } 139 | 140 | // State access methods. 141 | func (s *ExecutionState) GetModelID() string { 142 | s.mu.RLock() 143 | defer s.mu.RUnlock() 144 | return s.modelID 145 | } 146 | 147 | func (s *ExecutionState) GetTokenUsage() *TokenUsage { 148 | s.mu.RLock() 149 | defer s.mu.RUnlock() 150 | return s.tokenUsage 151 | } 152 | 153 | // Span methods. 154 | func (s *Span) WithError(err error) { 155 | s.Error = err 156 | } 157 | 158 | func (s *Span) WithAnnotation(key string, value interface{}) { 159 | s.Annotations[key] = value 160 | } 161 | 162 | // Helper method to collect all spans. 163 | func CollectSpans(ctx context.Context) []*Span { 164 | if state := GetExecutionState(ctx); state != nil { 165 | state.mu.RLock() 166 | defer state.mu.RUnlock() 167 | 168 | spans := make([]*Span, len(state.spans)) 169 | copy(spans, state.spans) 170 | return spans 171 | } 172 | return nil 173 | } 174 | 175 | // generateSpanID creates a new unique span identifier. 176 | // The format is: 8 bytes total 177 | // - 4 bytes: timestamp (seconds since epoch) 178 | // - 2 bytes: counter 179 | // - 2 bytes: random data 180 | // This provides a good balance of: 181 | // - Temporal ordering (timestamp component) 182 | // - Uniqueness guarantee (counter component) 183 | // - Collision resistance (random component) 184 | // 185 | // Example: 186 | // 63f51a2a01ab9c8d 187 | // │ │ └─┴─ Random component (2 bytes) 188 | // │ └─┴─ Counter (2 bytes) 189 | // └─┴─┴─┴─ Timestamp (4 bytes). 
190 | func generateSpanID() string { 191 | // Get current timestamp 192 | now := time.Now().Unix() 193 | 194 | // Increment counter atomically 195 | counter := atomic.AddUint64(&defaultGenerator.counter, 1) 196 | 197 | // Create buffer for our ID components 198 | id := make([]byte, 8) 199 | 200 | // Add timestamp (4 bytes) 201 | id[0] = byte(now >> 24) 202 | id[1] = byte(now >> 16) 203 | id[2] = byte(now >> 8) 204 | id[3] = byte(now) 205 | 206 | // Add counter (2 bytes) 207 | id[4] = byte(counter >> 8) 208 | id[5] = byte(counter) 209 | 210 | // Add random component (2 bytes) 211 | if _, err := rand.Read(id[6:]); err != nil { 212 | // Fallback to using more counter bits if random fails 213 | id[6] = byte(counter >> 16) 214 | id[7] = byte(counter >> 24) 215 | } 216 | 217 | // Return hex-encoded string 218 | return hex.EncodeToString(id) 219 | } 220 | 221 | // For testing and debugging. 222 | func resetSpanIDGenerator() { 223 | atomic.StoreUint64(&defaultGenerator.counter, 0) 224 | defaultGenerator.lastTimestamp = 0 225 | } 226 | 227 | func (s *ExecutionState) GetCurrentSpan() *Span { 228 | s.mu.RLock() 229 | defer s.mu.RUnlock() 230 | return s.activeSpan 231 | } 232 | 233 | func generateTraceID() string { 234 | // Generate 16 random bytes for trace ID 235 | b := make([]byte, 16) 236 | if _, err := rand.Read(b); err != nil { 237 | // Fallback to timestamp-based ID if random generation fails 238 | now := time.Now().UnixNano() 239 | return fmt.Sprintf("trace-%d", now) 240 | } 241 | 242 | // Format as hex string 243 | return hex.EncodeToString(b) 244 | } 245 | 246 | func (s *ExecutionState) GetTraceID() string { 247 | s.mu.RLock() 248 | defer s.mu.RUnlock() 249 | return s.traceID 250 | } 251 | -------------------------------------------------------------------------------- /pkg/core/execution_context_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "encoding/hex" 5 | "sync" 6 | "testing" 7 | 
"time" 8 | ) 9 | 10 | func TestGenerateSpanID(t *testing.T) { 11 | // Reset generator state 12 | resetSpanIDGenerator() 13 | 14 | // Test basic functionality 15 | t.Run("Basic Generation", func(t *testing.T) { 16 | id := generateSpanID() 17 | 18 | // Verify length (16 characters for 8 bytes hex-encoded) 19 | if len(id) != 16 { 20 | t.Errorf("Expected ID length of 16, got %d", len(id)) 21 | } 22 | 23 | // Verify it's valid hex 24 | _, err := hex.DecodeString(id) 25 | if err != nil { 26 | t.Errorf("Invalid hex string: %v", err) 27 | } 28 | }) 29 | 30 | // Test uniqueness 31 | t.Run("Uniqueness", func(t *testing.T) { 32 | const iterations = 10000 33 | ids := make(map[string]bool) 34 | 35 | for i := 0; i < iterations; i++ { 36 | id := generateSpanID() 37 | if ids[id] { 38 | t.Errorf("Duplicate ID generated: %s", id) 39 | } 40 | ids[id] = true 41 | } 42 | }) 43 | 44 | // Test concurrent generation 45 | t.Run("Concurrent Generation", func(t *testing.T) { 46 | const goroutines = 10 47 | const idsPerRoutine = 1000 48 | 49 | var wg sync.WaitGroup 50 | ids := make(chan string, goroutines*idsPerRoutine) 51 | 52 | for i := 0; i < goroutines; i++ { 53 | wg.Add(1) 54 | go func() { 55 | defer wg.Done() 56 | for j := 0; j < idsPerRoutine; j++ { 57 | ids <- generateSpanID() 58 | } 59 | }() 60 | } 61 | 62 | wg.Wait() 63 | close(ids) 64 | 65 | // Check for duplicates 66 | seen := make(map[string]bool) 67 | for id := range ids { 68 | if seen[id] { 69 | t.Errorf("Duplicate ID generated in concurrent test: %s", id) 70 | } 71 | seen[id] = true 72 | } 73 | }) 74 | 75 | // Test timestamp component 76 | t.Run("Timestamp Component", func(t *testing.T) { 77 | // Generate two IDs with a time gap 78 | id1 := generateSpanID() 79 | time.Sleep(2 * time.Second) 80 | id2 := generateSpanID() 81 | 82 | // Convert hex to bytes 83 | bytes1, _ := hex.DecodeString(id1) 84 | bytes2, _ := hex.DecodeString(id2) 85 | 86 | // Extract timestamps (first 4 bytes) 87 | timestamp1 := uint32(bytes1[0])<<24 | 
uint32(bytes1[1])<<16 | uint32(bytes1[2])<<8 | uint32(bytes1[3]) 88 | timestamp2 := uint32(bytes2[0])<<24 | uint32(bytes2[1])<<16 | uint32(bytes2[2])<<8 | uint32(bytes2[3]) 89 | 90 | if timestamp2 <= timestamp1 { 91 | t.Errorf("Second timestamp not greater than first: %d <= %d", timestamp2, timestamp1) 92 | } 93 | }) 94 | } 95 | 96 | func BenchmarkGenerateSpanID(b *testing.B) { 97 | for i := 0; i < b.N; i++ { 98 | generateSpanID() 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /pkg/core/factory.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | // LLMFactory defines a simple interface for creating LLM instances. 4 | // This maintains compatibility with existing code while allowing for configuration. 5 | type LLMFactory interface { 6 | // CreateLLM creates a new LLM instance. It uses the global configuration 7 | // from core.GlobalConfig for client settings. 8 | CreateLLM(apiKey string, modelID ModelID) (LLM, error) 9 | } 10 | 11 | // DefaultFactory is the global factory instance used by the configuration system. 12 | var DefaultFactory LLMFactory 13 | -------------------------------------------------------------------------------- /pkg/core/llm_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | ) 7 | 8 | // TestGenerateOptions tests the GenerateOptions and related functions. 
9 | func TestGenerateOptions(t *testing.T) { 10 | opts := &GenerateOptions{} 11 | 12 | WithMaxTokens(100)(opts) 13 | if opts.MaxTokens != 100 { 14 | t.Errorf("Expected MaxTokens 100, got %d", opts.MaxTokens) 15 | } 16 | 17 | WithTemperature(0.7)(opts) 18 | if opts.Temperature != 0.7 { 19 | t.Errorf("Expected Temperature 0.7, got %f", opts.Temperature) 20 | } 21 | 22 | WithTopP(0.9)(opts) 23 | if opts.TopP != 0.9 { 24 | t.Errorf("Expected TopP 0.9, got %f", opts.TopP) 25 | } 26 | 27 | WithPresencePenalty(1.0)(opts) 28 | if opts.PresencePenalty != 1.0 { 29 | t.Errorf("Expected PresencePenalty 1.0, got %f", opts.PresencePenalty) 30 | } 31 | 32 | WithFrequencyPenalty(1.5)(opts) 33 | if opts.FrequencyPenalty != 1.5 { 34 | t.Errorf("Expected FrequencyPenalty 1.5, got %f", opts.FrequencyPenalty) 35 | } 36 | 37 | WithStopSequences("stop1", "stop2")(opts) 38 | if len(opts.Stop) != 2 || opts.Stop[0] != "stop1" || opts.Stop[1] != "stop2" { 39 | t.Errorf("Expected Stop sequences [stop1 stop2], got %v", opts.Stop) 40 | } 41 | } 42 | 43 | // TestMockLLM tests the MockLLM implementation. 
44 | func TestMockLLM(t *testing.T) { 45 | llm := &MockLLM{} 46 | 47 | response, err := llm.Generate(context.Background(), "test prompt") 48 | if err != nil { 49 | t.Errorf("Unexpected error: %v", err) 50 | } 51 | if response.Content != "mock response" { 52 | t.Errorf("Expected 'mock response', got '%s'", response.Content) 53 | } 54 | 55 | jsonResponse, err := llm.GenerateWithJSON(context.Background(), "test prompt") 56 | if err != nil { 57 | t.Errorf("Unexpected error: %v", err) 58 | } 59 | if jsonResponse["response"] != "mock response" { 60 | t.Errorf("Expected {'response': 'mock response'}, got %v", jsonResponse) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /pkg/core/module.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | ) 7 | 8 | // Module represents a basic unit of computation in DSPy. 9 | type Module interface { 10 | // Process executes the module's logic 11 | Process(ctx context.Context, inputs map[string]any, opts ...Option) (map[string]any, error) 12 | 13 | // GetSignature returns the module's input and output signature 14 | GetSignature() Signature 15 | 16 | SetSignature(signature Signature) 17 | 18 | // SetLLM sets the language model for the module 19 | SetLLM(llm LLM) 20 | 21 | // Clone creates a deep copy of the module 22 | Clone() Module 23 | } 24 | 25 | type Option func(*ModuleOptions) 26 | 27 | type StreamHandler func(chunk StreamChunk) error 28 | 29 | // WithStreamHandler returns an option to enable streaming. 30 | func WithStreamHandler(handler StreamHandler) Option { 31 | return func(o *ModuleOptions) { 32 | o.StreamHandler = handler 33 | } 34 | } 35 | 36 | // ModuleOptions holds configuration that can be passed to modules. 
37 | type ModuleOptions struct { 38 | // LLM generation options 39 | GenerateOptions []GenerateOption 40 | 41 | StreamHandler StreamHandler 42 | } 43 | 44 | // WithGenerateOptions adds LLM generation options. 45 | func WithGenerateOptions(opts ...GenerateOption) Option { 46 | return func(o *ModuleOptions) { 47 | o.GenerateOptions = append(o.GenerateOptions, opts...) 48 | } 49 | } 50 | 51 | // Clone creates a copy of ModuleOptions. 52 | func (o *ModuleOptions) Clone() *ModuleOptions { 53 | if o == nil { 54 | return nil 55 | } 56 | return &ModuleOptions{ 57 | GenerateOptions: append([]GenerateOption{}, o.GenerateOptions...), 58 | StreamHandler: o.StreamHandler, // Copy the StreamHandler reference 59 | } 60 | } 61 | 62 | // MergeWith merges this options with other options, with other taking precedence. 63 | func (o *ModuleOptions) MergeWith(other *ModuleOptions) *ModuleOptions { 64 | if other == nil { 65 | return o.Clone() 66 | } 67 | merged := o.Clone() 68 | if merged == nil { 69 | merged = &ModuleOptions{} 70 | } 71 | merged.GenerateOptions = append(merged.GenerateOptions, other.GenerateOptions...) 72 | // If other has a StreamHandler, it takes precedence 73 | if other.StreamHandler != nil { 74 | merged.StreamHandler = other.StreamHandler 75 | } 76 | return merged 77 | } 78 | 79 | func WithOptions(opts ...Option) Option { 80 | return func(o *ModuleOptions) { 81 | for _, opt := range opts { 82 | opt(o) 83 | } 84 | } 85 | } 86 | 87 | // BaseModule provides a basic implementation of the Module interface. 88 | type BaseModule struct { 89 | Signature Signature 90 | LLM LLM 91 | } 92 | 93 | // GetSignature returns the module's signature. 94 | func (bm *BaseModule) GetSignature() Signature { 95 | return bm.Signature 96 | } 97 | 98 | // SetLLM sets the language model for the module. 
99 | func (bm *BaseModule) SetLLM(llm LLM) { 100 | bm.LLM = llm 101 | } 102 | 103 | func (bm *BaseModule) SetSignature(signature Signature) { 104 | bm.Signature = signature 105 | } 106 | 107 | // Process is a placeholder implementation and should be overridden by specific modules. 108 | func (bm *BaseModule) Process(ctx context.Context, inputs map[string]any, opts ...Option) (map[string]any, error) { 109 | return nil, errors.New("Process method not implemented") 110 | } 111 | 112 | // Clone creates a deep copy of the BaseModule. 113 | func (bm *BaseModule) Clone() Module { 114 | return &BaseModule{ 115 | Signature: bm.Signature, 116 | LLM: bm.LLM, // Note: This is a shallow copy of the LLM 117 | } 118 | } 119 | 120 | // NewModule creates a new base module with the given signature. 121 | func NewModule(signature Signature) *BaseModule { 122 | return &BaseModule{ 123 | Signature: signature, 124 | } 125 | } 126 | 127 | // ValidateInputs checks if the provided inputs match the module's input signature. 128 | func (bm *BaseModule) ValidateInputs(inputs map[string]any) error { 129 | for _, field := range bm.Signature.Inputs { 130 | if _, ok := inputs[field.Name]; !ok { 131 | return errors.New("missing required input: " + field.Name) 132 | } 133 | } 134 | return nil 135 | } 136 | 137 | // FormatOutputs ensures that the output map contains all fields specified in the output signature. 138 | func (bm *BaseModule) FormatOutputs(outputs map[string]any) map[string]any { 139 | formattedOutputs := make(map[string]any) 140 | for _, field := range bm.Signature.Outputs { 141 | if value, ok := outputs[field.Name]; ok { 142 | formattedOutputs[field.Name] = value 143 | } else { 144 | formattedOutputs[field.Name] = nil 145 | } 146 | } 147 | return formattedOutputs 148 | } 149 | 150 | // Composable is an interface for modules that can be composed with other modules. 
151 | type Composable interface { 152 | Module 153 | Compose(next Module) Module 154 | GetSubModules() []Module 155 | SetSubModules([]Module) 156 | } 157 | 158 | // ModuleChain represents a chain of modules. 159 | type ModuleChain struct { 160 | BaseModule 161 | Modules []Module 162 | } 163 | 164 | // NewModuleChain creates a new module chain. 165 | func NewModuleChain(modules ...Module) *ModuleChain { 166 | // Compute the combined signature 167 | var inputs []InputField 168 | var outputs []OutputField 169 | for i, m := range modules { 170 | sig := m.GetSignature() 171 | if i == 0 { 172 | inputs = sig.Inputs 173 | } 174 | if i == len(modules)-1 { 175 | outputs = sig.Outputs 176 | } 177 | } 178 | 179 | return &ModuleChain{ 180 | BaseModule: BaseModule{ 181 | Signature: Signature{Inputs: inputs, Outputs: outputs}, 182 | }, 183 | Modules: modules, 184 | } 185 | } 186 | -------------------------------------------------------------------------------- /pkg/core/module_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "reflect" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | // // MockLLM is a mock implementation of the LLM interface for testing. 
11 | type MockLLM struct{} 12 | 13 | func (m *MockLLM) Generate(ctx context.Context, prompt string, options ...GenerateOption) (*LLMResponse, error) { 14 | return &LLMResponse{Content: "mock response"}, nil 15 | } 16 | 17 | func (m *MockLLM) GenerateWithJSON(ctx context.Context, prompt string, options ...GenerateOption) (map[string]interface{}, error) { 18 | return map[string]interface{}{"response": "mock response"}, nil 19 | } 20 | 21 | func (m *MockLLM) GenerateWithFunctions(ctx context.Context, prompt string, functions []map[string]interface{}, options ...GenerateOption) (map[string]interface{}, error) { 22 | return nil, nil 23 | } 24 | 25 | func (m *MockLLM) CreateEmbedding(ctx context.Context, input string, options ...EmbeddingOption) (*EmbeddingResult, error) { 26 | return &EmbeddingResult{ 27 | // Using float32 for the vector as embeddings are typically floating point numbers 28 | Vector: []float32{0.1, 0.2, 0.3}, 29 | // Include token count to simulate real embedding behavior 30 | TokenCount: len(strings.Fields(input)), 31 | // Add metadata to simulate real response 32 | Metadata: map[string]interface{}{}, 33 | }, nil 34 | } 35 | 36 | func (m *MockLLM) CreateEmbeddings(ctx context.Context, inputs []string, options ...EmbeddingOption) (*BatchEmbeddingResult, error) { 37 | opts := NewEmbeddingOptions() 38 | for _, opt := range options { 39 | opt(opts) 40 | } 41 | 42 | // Create mock results for each input 43 | embeddings := make([]EmbeddingResult, len(inputs)) 44 | for i, input := range inputs { 45 | embeddings[i] = EmbeddingResult{ 46 | // Each embedding gets slightly different values to simulate real behavior 47 | Vector: []float32{0.1 * float32(i+1), 0.2 * float32(i+1), 0.3 * float32(i+1)}, 48 | TokenCount: len(strings.Fields(input)), 49 | Metadata: map[string]interface{}{ 50 | "model": opts.Model, 51 | "input_length": len(input), 52 | "batch_index": i, 53 | }, 54 | } 55 | } 56 | 57 | // Return the batch result 58 | return &BatchEmbeddingResult{ 59 | 
Embeddings: embeddings, 60 | Error: nil, 61 | ErrorIndex: -1, // -1 indicates no error 62 | }, nil 63 | } 64 | 65 | func (m *MockLLM) StreamGenerate(ctx context.Context, prompt string, opts ...GenerateOption) (*StreamResponse, error) { 66 | return nil, nil 67 | } 68 | 69 | func (m *MockLLM) ProviderName() string { 70 | return "mock" 71 | } 72 | 73 | func (m *MockLLM) ModelID() string { 74 | return "mock" 75 | } 76 | 77 | func (m *MockLLM) Capabilities() []Capability { 78 | return []Capability{} 79 | } 80 | 81 | // // TestBaseModule tests the BaseModule struct and its methods. 82 | func TestBaseModule(t *testing.T) { 83 | sig := NewSignature( 84 | []InputField{{Field: Field{Name: "input"}}}, 85 | []OutputField{{Field: Field{Name: "output"}}}, 86 | ) 87 | bm := NewModule(sig) 88 | 89 | if !reflect.DeepEqual(bm.GetSignature(), sig) { 90 | t.Error("GetSignature did not return the correct signature") 91 | } 92 | 93 | mockLLM := &MockLLM{} 94 | bm.SetLLM(mockLLM) 95 | if bm.LLM != mockLLM { 96 | t.Error("SetLLM did not set the LLM correctly") 97 | } 98 | 99 | _, err := bm.Process(context.Background(), map[string]any{"input": "test"}) 100 | if err == nil || err.Error() != "Process method not implemented" { 101 | t.Error("Expected 'Process method not implemented' error") 102 | } 103 | 104 | clone := bm.Clone() 105 | if !reflect.DeepEqual(clone.GetSignature(), bm.GetSignature()) { 106 | t.Error("Cloned module does not have the same signature") 107 | } 108 | } 109 | 110 | // TestModuleChain tests the ModuleChain struct and its methods. 
111 | func TestModuleChain(t *testing.T) { 112 | module1 := NewModule(NewSignature( 113 | []InputField{{Field: Field{Name: "input1"}}}, 114 | []OutputField{{Field: Field{Name: "output1"}}}, 115 | )) 116 | module2 := NewModule(NewSignature( 117 | []InputField{{Field: Field{Name: "input2"}}}, 118 | []OutputField{{Field: Field{Name: "output2"}}}, 119 | )) 120 | 121 | chain := NewModuleChain(module1, module2) 122 | 123 | if len(chain.Modules) != 2 { 124 | t.Errorf("Expected 2 modules in chain, got %d", len(chain.Modules)) 125 | } 126 | 127 | sig := chain.GetSignature() 128 | if len(sig.Inputs) != 1 || sig.Inputs[0].Name != "input1" { 129 | t.Error("Chain signature inputs are incorrect") 130 | } 131 | if len(sig.Outputs) != 1 || sig.Outputs[0].Name != "output2" { 132 | t.Error("Chain signature outputs are incorrect") 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /pkg/core/optimizer.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | ) 7 | 8 | // Optimizer represents an interface for optimizing DSPy programs. 9 | type Optimizer interface { 10 | // Compile optimizes a given program using the provided dataset and metric 11 | Compile(ctx context.Context, program Program, dataset Dataset, metric Metric) (Program, error) 12 | } 13 | 14 | // Metric is a function type that evaluates the performance of a program. 15 | type Metric func(expected, actual map[string]interface{}) float64 16 | 17 | // Dataset represents a collection of examples for training/evaluation. 18 | type Dataset interface { 19 | // Next returns the next example in the dataset 20 | Next() (Example, bool) 21 | // Reset resets the dataset iterator 22 | Reset() 23 | } 24 | 25 | // Example represents a single training/evaluation example. 
26 | type Example struct { 27 | Inputs map[string]interface{} 28 | Outputs map[string]interface{} 29 | } 30 | 31 | // BaseOptimizer provides a basic implementation of the Optimizer interface. 32 | type BaseOptimizer struct { 33 | Name string 34 | } 35 | 36 | // Compile is a placeholder implementation and should be overridden by specific optimizer implementations. 37 | func (bo *BaseOptimizer) Compile(ctx context.Context, program Program, dataset Dataset, metric Metric) (Program, error) { 38 | return Program{}, errors.New("Compile method not implemented") 39 | } 40 | 41 | // OptimizerFactory is a function type for creating Optimizer instances. 42 | type OptimizerFactory func() (Optimizer, error) 43 | 44 | // OptimizerRegistry maintains a registry of available Optimizer implementations. 45 | type OptimizerRegistry struct { 46 | factories map[string]OptimizerFactory 47 | } 48 | 49 | // NewOptimizerRegistry creates a new OptimizerRegistry. 50 | func NewOptimizerRegistry() *OptimizerRegistry { 51 | return &OptimizerRegistry{ 52 | factories: make(map[string]OptimizerFactory), 53 | } 54 | } 55 | 56 | // Register adds a new Optimizer factory to the registry. 57 | func (r *OptimizerRegistry) Register(name string, factory OptimizerFactory) { 58 | r.factories[name] = factory 59 | } 60 | 61 | // Create instantiates a new Optimizer based on the given name. 62 | func (r *OptimizerRegistry) Create(name string) (Optimizer, error) { 63 | factory, exists := r.factories[name] 64 | if !exists { 65 | return nil, errors.New("unknown Optimizer type: " + name) 66 | } 67 | return factory() 68 | } 69 | 70 | // CompileOptions represents options for the compilation process. 71 | type CompileOptions struct { 72 | MaxTrials int 73 | Teacher *Program 74 | } 75 | 76 | // WithMaxTrials sets the maximum number of trials for optimization. 
77 | func WithMaxTrials(n int) func(*CompileOptions) { 78 | return func(o *CompileOptions) { 79 | o.MaxTrials = n 80 | } 81 | } 82 | 83 | // WithTeacher sets a teacher program for optimization. 84 | func WithTeacher(teacher *Program) func(*CompileOptions) { 85 | return func(o *CompileOptions) { 86 | o.Teacher = teacher 87 | } 88 | } 89 | 90 | // BootstrapFewShot implements a basic few-shot learning optimizer. 91 | type BootstrapFewShot struct { 92 | BaseOptimizer 93 | MaxExamples int 94 | } 95 | 96 | // NewBootstrapFewShot creates a new BootstrapFewShot optimizer. 97 | func NewBootstrapFewShot(maxExamples int) *BootstrapFewShot { 98 | return &BootstrapFewShot{ 99 | BaseOptimizer: BaseOptimizer{Name: "BootstrapFewShot"}, 100 | MaxExamples: maxExamples, 101 | } 102 | } 103 | 104 | // Compile implements the optimization logic for BootstrapFewShot. 105 | func (bfs *BootstrapFewShot) Compile(ctx context.Context, program Program, dataset Dataset, metric Metric) (Program, error) { 106 | // Implementation of bootstrap few-shot learning 107 | // This is a placeholder and should be implemented based on the DSPy paper's description 108 | return program, nil 109 | } 110 | 111 | type ProgressReporter interface { 112 | Report(stage string, processed, total int) 113 | } 114 | -------------------------------------------------------------------------------- /pkg/core/optimizer_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | ) 7 | 8 | // TestOptimizerRegistry tests the OptimizerRegistry. 
9 | func TestOptimizerRegistry(t *testing.T) { 10 | registry := NewOptimizerRegistry() 11 | 12 | // Test registering an Optimizer 13 | registry.Register("test", func() (Optimizer, error) { 14 | return &MockOptimizer{}, nil 15 | }) 16 | 17 | // Test creating a registered Optimizer 18 | optimizer, err := registry.Create("test") 19 | if err != nil { 20 | t.Errorf("Unexpected error creating Optimizer: %v", err) 21 | } 22 | if _, ok := optimizer.(*MockOptimizer); !ok { 23 | t.Error("Created Optimizer is not of expected type") 24 | } 25 | 26 | // Test creating an unregistered Optimizer 27 | _, err = registry.Create("nonexistent") 28 | if err == nil { 29 | t.Error("Expected error when creating unregistered Optimizer, got nil") 30 | } 31 | } 32 | 33 | // TestCompileOptions tests the CompileOptions and related functions. 34 | func TestCompileOptions(t *testing.T) { 35 | opts := &CompileOptions{} 36 | 37 | WithMaxTrials(10)(opts) 38 | if opts.MaxTrials != 10 { 39 | t.Errorf("Expected MaxTrials 10, got %d", opts.MaxTrials) 40 | } 41 | 42 | teacherProgram := &Program{ 43 | Modules: map[string]Module{ 44 | "test": NewModule(NewSignature( 45 | []InputField{{Field: Field{Name: "input"}}}, 46 | []OutputField{{Field: Field{Name: "output"}}}, 47 | )), 48 | }, 49 | Forward: func(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) { 50 | return inputs, nil 51 | }, 52 | } 53 | 54 | WithTeacher(teacherProgram)(opts) 55 | if opts.Teacher == nil { 56 | t.Error("Expected Teacher program to be set") 57 | } else { 58 | if len(opts.Teacher.Modules) != 1 { 59 | t.Errorf("Expected 1 module in Teacher program, got %d", len(opts.Teacher.Modules)) 60 | } 61 | if opts.Teacher.Forward == nil { 62 | t.Error("Expected Forward function to be set in Teacher program") 63 | } 64 | } 65 | } 66 | 67 | // TestBootstrapFewShot tests the BootstrapFewShot optimizer. 
68 | func TestBootstrapFewShot(t *testing.T) { 69 | optimizer := NewBootstrapFewShot(5) 70 | 71 | if optimizer.MaxExamples != 5 { 72 | t.Errorf("Expected MaxExamples 5, got %d", optimizer.MaxExamples) 73 | } 74 | 75 | // Create a simple program for testing 76 | program := NewProgram(map[string]Module{ 77 | "test": NewModule(NewSignature( 78 | []InputField{{Field: Field{Name: "input"}}}, 79 | []OutputField{{Field: Field{Name: "output"}}}, 80 | )), 81 | }, nil) 82 | 83 | // Create a simple dataset for testing 84 | dataset := &MockDataset{} 85 | 86 | // Create a simple metric for testing 87 | metric := func(expected, actual map[string]interface{}) float64 { 88 | return 1.0 // Always return 1.0 for this test 89 | } 90 | 91 | optimizedProgram, err := optimizer.Compile(context.Background(), program, dataset, metric) 92 | if err != nil { 93 | t.Errorf("Unexpected error: %v", err) 94 | } 95 | 96 | if len(optimizedProgram.Modules) != 1 { 97 | t.Errorf("Expected 1 module in optimized program, got %d", len(optimizedProgram.Modules)) 98 | } 99 | } 100 | 101 | // MockOptimizer is a mock implementation of the Optimizer interface for testing. 102 | type MockOptimizer struct{} 103 | 104 | func (m *MockOptimizer) Compile(ctx context.Context, program Program, dataset Dataset, metric Metric) (Program, error) { 105 | return program, nil 106 | } 107 | 108 | // MockDataset is a mock implementation of the Dataset interface for testing. 
109 | type MockDataset struct { 110 | data []Example 111 | index int 112 | } 113 | 114 | func (m *MockDataset) Next() (Example, bool) { 115 | if m.index >= len(m.data) { 116 | return Example{}, false 117 | } 118 | example := m.data[m.index] 119 | m.index++ 120 | return example, true 121 | } 122 | 123 | func (m *MockDataset) Reset() { 124 | m.index = 0 125 | } 126 | -------------------------------------------------------------------------------- /pkg/core/program.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "reflect" 7 | "sort" 8 | ) 9 | 10 | // Program represents a complete DSPy pipeline or workflow. 11 | type Program struct { 12 | Modules map[string]Module 13 | Forward func(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) 14 | } 15 | 16 | // NewProgram creates a new Program with the given modules and forward function. 17 | func NewProgram(modules map[string]Module, forward func(context.Context, map[string]interface{}) (map[string]interface{}, error)) Program { 18 | return Program{ 19 | Modules: modules, 20 | Forward: forward, 21 | } 22 | } 23 | 24 | // Execute runs the program with the given inputs. 
25 | func (p Program) Execute(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) { 26 | if p.Forward == nil { 27 | return nil, errors.New("forward function is not defined") 28 | } 29 | // Ensure we have execution state 30 | if GetExecutionState(ctx) == nil { 31 | ctx = WithExecutionState(ctx) 32 | } 33 | 34 | ctx, span := StartSpan(ctx, "Program") 35 | defer EndSpan(ctx) 36 | 37 | span.WithAnnotation("inputs", inputs) 38 | outputs, err := p.Forward(ctx, inputs) 39 | if err != nil { 40 | span.WithError(err) 41 | return nil, err 42 | } 43 | 44 | span.WithAnnotation("outputs", outputs) 45 | return outputs, nil 46 | } 47 | 48 | // GetSignature returns the overall signature of the program 49 | // This would need to be defined based on the Forward function's expected inputs and outputs. 50 | func (p Program) GetSignature() Signature { 51 | var inputs []InputField 52 | var outputs []OutputField 53 | 54 | // Since modules are in a map, we can't rely on order. 55 | // We'll use the first module we find for inputs and outputs. 56 | for _, module := range p.Modules { 57 | sig := module.GetSignature() 58 | inputs = sig.Inputs 59 | outputs = sig.Outputs 60 | break 61 | } 62 | return NewSignature(inputs, outputs) 63 | } 64 | 65 | // Clone creates a deep copy of the Program. 66 | func (p Program) Clone() Program { 67 | modulesCopy := make(map[string]Module) 68 | for name, module := range p.Modules { 69 | modulesCopy[name] = module.Clone() 70 | } 71 | 72 | return Program{ 73 | Modules: modulesCopy, 74 | Forward: p.Forward, // Note: We're copying the pointer to the forward function 75 | } 76 | } 77 | 78 | // Equal checks if two Programs are equivalent. 
79 | func (p Program) Equal(other Program) bool { 80 | if p.Forward == nil && other.Forward != nil || p.Forward != nil && other.Forward == nil { 81 | return false 82 | } 83 | if len(p.Modules) != len(other.Modules) { 84 | return false 85 | } 86 | for name, module := range p.Modules { 87 | otherModule, exists := other.Modules[name] 88 | if !exists { 89 | return false 90 | } 91 | if !reflect.DeepEqual(module.GetSignature(), otherModule.GetSignature()) { 92 | return false 93 | } 94 | } 95 | return true 96 | } 97 | 98 | // AddModule adds a new module to the Program. 99 | func (p *Program) AddModule(name string, module Module) { 100 | p.Modules[name] = module 101 | } 102 | 103 | // SetForward sets the forward function for the Program. 104 | func (p *Program) SetForward(forward func(context.Context, map[string]interface{}) (map[string]interface{}, error)) { 105 | p.Forward = forward 106 | } 107 | 108 | func (p *Program) GetModules() []Module { 109 | moduleNames := make([]string, 0, len(p.Modules)) 110 | for name := range p.Modules { 111 | moduleNames = append(moduleNames, name) 112 | } 113 | sort.Strings(moduleNames) 114 | 115 | // Build ordered module slice 116 | modules := make([]Module, len(moduleNames)) 117 | for i, name := range moduleNames { 118 | modules[i] = p.Modules[name] 119 | } 120 | return modules 121 | 122 | } 123 | 124 | func (p *Program) Predictors() []Module { 125 | modules := make([]Module, 0, len(p.Modules)) 126 | for _, module := range p.Modules { 127 | modules = append(modules, module) 128 | } 129 | return modules 130 | } 131 | -------------------------------------------------------------------------------- /pkg/core/signature.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // Field represents a single field in a signature. 
type Field struct {
	// Name identifies the field.
	Name string
	// Description is free-form text describing the field.
	Description string
	// Prefix is the field's label; NewField defaults it to Name + ":".
	Prefix string
}

// NewField builds a Field called name whose Prefix defaults to name + ":".
// The options are then applied in order, so later options win.
func NewField(name string, opts ...FieldOption) Field {
	field := Field{
		Name:   name,
		Prefix: name + ":",
	}
	for i := range opts {
		opts[i](&field)
	}
	return field
}

// FieldOption customizes a Field while NewField constructs it.
type FieldOption func(*Field)

// WithDescription returns a FieldOption that sets Description to desc.
func WithDescription(desc string) FieldOption {
	return func(target *Field) { target.Description = desc }
}

// WithCustomPrefix returns a FieldOption that replaces the default prefix.
func WithCustomPrefix(prefix string) FieldOption {
	return func(target *Field) { target.Prefix = prefix }
}

// WithNoPrefix returns a FieldOption that clears the prefix.
func WithNoPrefix() FieldOption {
	return func(target *Field) { target.Prefix = "" }
}

// InputField wraps a Field that describes a signature input.
type InputField struct {
	Field
}

// OutputField wraps a Field that describes a signature output.
type OutputField struct {
	Field
}

// Signature represents the input and output specification of a module,
// optionally accompanied by an instruction.
type Signature struct {
	Inputs      []InputField
	Outputs     []OutputField
	Instruction string
}

// NewSignature returns a Signature with the given inputs and outputs and no
// instruction. The slices are stored as-is, not copied.
func NewSignature(inputs []InputField, outputs []OutputField) Signature {
	var sig Signature
	sig.Inputs, sig.Outputs = inputs, outputs
	return sig
}

// WithInstruction returns a copy of s with Instruction set; s is unchanged.
func (s Signature) WithInstruction(instruction string) Signature {
	updated := s
	updated.Instruction = instruction
	return updated
}

// String returns a string representation of the Signature.
87 | func (s Signature) String() string { 88 | var sb strings.Builder 89 | sb.WriteString("Inputs:\n") 90 | for _, input := range s.Inputs { 91 | sb.WriteString(fmt.Sprintf(" - %s (%s)\n", input.Name, input.Description)) 92 | } 93 | sb.WriteString("Outputs:\n") 94 | for _, output := range s.Outputs { 95 | sb.WriteString(fmt.Sprintf(" - %s (%s)\n", output.Name, output.Description)) 96 | } 97 | if s.Instruction != "" { 98 | sb.WriteString(fmt.Sprintf("Instruction: %s\n", s.Instruction)) 99 | } 100 | return sb.String() 101 | } 102 | 103 | // ParseSignature parses a signature string into a Signature struct. 104 | func ParseSignature(signatureStr string) (Signature, error) { 105 | parts := strings.Split(signatureStr, "->") 106 | if len(parts) != 2 { 107 | return Signature{}, fmt.Errorf("invalid signature format: %s", signatureStr) 108 | } 109 | 110 | inputs := parseInputFields(strings.TrimSpace(parts[0])) 111 | outputs := parseOutputFields(strings.TrimSpace(parts[1])) 112 | 113 | return NewSignature(inputs, outputs), nil 114 | } 115 | 116 | func parseInputFields(fieldsStr string) []InputField { 117 | fieldStrs := strings.Split(fieldsStr, ",") 118 | fields := make([]InputField, len(fieldStrs)) 119 | for i, fieldStr := range fieldStrs { 120 | fieldStr = strings.TrimSpace(fieldStr) 121 | fields[i] = InputField{Field: Field{Name: fieldStr}} 122 | } 123 | return fields 124 | } 125 | 126 | func parseOutputFields(fieldsStr string) []OutputField { 127 | fieldStrs := strings.Split(fieldsStr, ",") 128 | fields := make([]OutputField, len(fieldStrs)) 129 | for i, fieldStr := range fieldStrs { 130 | fieldStr = strings.TrimSpace(fieldStr) 131 | fields[i] = OutputField{Field: Field{Name: fieldStr}} 132 | } 133 | return fields 134 | } 135 | 136 | // ShorthandNotation creates a Signature from a shorthand notation string. 
func ShorthandNotation(notation string) (Signature, error) {
	// Delegates directly to ParseSignature; both accept the same syntax.
	return ParseSignature(notation)
}

--------------------------------------------------------------------------------
/pkg/core/signature_test.go:
--------------------------------------------------------------------------------
package core

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestField covers NewField's default prefix ("<name>:") and the
// description/prefix options.
func TestField(t *testing.T) {
	t.Run("NewField with defaults", func(t *testing.T) {
		field := NewField("test")
		assert.Equal(t, "test", field.Name)
		assert.Equal(t, "test:", field.Prefix)
		assert.Empty(t, field.Description)
	})

	t.Run("NewField with options", func(t *testing.T) {
		field := NewField("test",
			WithDescription("test description"),
			WithCustomPrefix("custom:"),
		)
		assert.Equal(t, "test", field.Name)
		assert.Equal(t, "custom:", field.Prefix)
		assert.Equal(t, "test description", field.Description)
	})

	t.Run("NewField with no prefix", func(t *testing.T) {
		field := NewField("test", WithNoPrefix())
		assert.Equal(t, "test", field.Name)
		assert.Empty(t, field.Prefix)
	})
}

// TestSignature covers construction, WithInstruction and the String rendering.
func TestSignature(t *testing.T) {
	t.Run("NewSignature", func(t *testing.T) {
		inputs := []InputField{
			{Field: Field{Name: "input1"}},
			{Field: Field{Name: "input2"}},
		}
		outputs := []OutputField{
			{Field: Field{Name: "output1"}},
			{Field: Field{Name: "output2"}},
		}

		sig := NewSignature(inputs, outputs)
		assert.Equal(t, inputs, sig.Inputs)
		assert.Equal(t, outputs, sig.Outputs)
		assert.Empty(t, sig.Instruction)
	})

	t.Run("WithInstruction", func(t *testing.T) {
		sig := NewSignature(nil, nil)
		sigWithInst := sig.WithInstruction("test instruction")
		assert.Equal(t, "test instruction", sigWithInst.Instruction)
	})

	t.Run("String representation", func(t *testing.T) {
		sig := NewSignature(
			[]InputField{{Field: Field{Name: "input", Description: "input desc"}}},
			[]OutputField{{Field: Field{Name: "output", Description: "output desc"}}},
		).WithInstruction("test instruction")

		str := sig.String()
		assert.Contains(t, str, "Inputs:")
		assert.Contains(t, str, "input (input desc)")
		assert.Contains(t, str, "Outputs:")
		assert.Contains(t, str, "output (output desc)")
		assert.Contains(t, str, "Instruction: test instruction")
	})
}

// TestSignatureParser covers the "inputs -> outputs" shorthand parser.
func TestSignatureParser(t *testing.T) {
	t.Run("ParseSignature valid", func(t *testing.T) {
		signatureStr := "input1, input2 -> output1, output2"
		sig, err := ParseSignature(signatureStr)
		assert.NoError(t, err)
		assert.Len(t, sig.Inputs, 2)
		assert.Len(t, sig.Outputs, 2)
		assert.Equal(t, "input1", sig.Inputs[0].Name)
		assert.Equal(t, "output1", sig.Outputs[0].Name)
	})

	t.Run("ParseSignature invalid", func(t *testing.T) {
		// No "->" separator at all.
		invalidStr := "invalid signature format"
		_, err := ParseSignature(invalidStr)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid signature format")
	})

	t.Run("ShorthandNotation", func(t *testing.T) {
		sig, err := ShorthandNotation("in1, in2 -> out1, out2")
		assert.NoError(t, err)
		assert.Len(t, sig.Inputs, 2)
		assert.Len(t, sig.Outputs, 2)
	})

	t.Run("ParseSignature with whitespace", func(t *testing.T) {
		signatureStr := " input1 , input2 -> output1 , output2 "
		sig, err := ParseSignature(signatureStr)
		assert.NoError(t, err)
		assert.Len(t, sig.Inputs, 2)
		assert.Len(t, sig.Outputs, 2)
		assert.Equal(t, "input1", sig.Inputs[0].Name)
		assert.Equal(t, "output1", sig.Outputs[0].Name)
	})
}

// TestHelperFunctions exercises the unexported comma-list parsers directly.
func TestHelperFunctions(t *testing.T) {
	t.Run("parseInputFields", func(t *testing.T) {
		fields := parseInputFields("field1, field2")
		assert.Len(t, fields, 2)
		assert.Equal(t, "field1", fields[0].Name)
		assert.Equal(t, "field2", fields[1].Name)
	})

	t.Run("parseOutputFields", func(t *testing.T) {
		fields := parseOutputFields("field1, field2")
		assert.Len(t, fields, 2)
		assert.Equal(t, "field1", fields[0].Name)
		assert.Equal(t, "field2", fields[1].Name)
	})
}

--------------------------------------------------------------------------------
/pkg/core/tool.go:
--------------------------------------------------------------------------------
package core

import (
	"context"

	// NOTE(review): used below as `models.`; presumably that is the
	// package's declared name — confirm.
	"github.com/XiaoConstantine/mcp-go/pkg/model"
)

// ToolMetadata contains information about a tool's capabilities and requirements.
type ToolMetadata struct {
	Name          string             // Unique identifier for the tool
	Description   string             // Human-readable description
	InputSchema   models.InputSchema // Rich schema from MCP-Go - now consistent with Tool interface
	OutputSchema  map[string]string  // Keep this field for backward compatibility
	Capabilities  []string           // List of supported capabilities
	ContextNeeded []string           // Required context keys
	Version       string             // Tool version for compatibility
}

// Tool represents a capability that can be used by both agents and modules.
type Tool interface {
	// Name returns the tool's identifier
	Name() string

	// Description returns human-readable explanation of the tool's purpose
	Description() string

	// Metadata returns the tool's metadata
	Metadata() *ToolMetadata

	// CanHandle checks if the tool can handle a specific action/intent
	CanHandle(ctx context.Context, intent string) bool

	// Execute runs the tool with provided parameters
	Execute(ctx context.Context, params map[string]interface{}) (ToolResult, error)

	// Validate checks if the parameters match the expected schema
	Validate(params map[string]interface{}) error

	// InputSchema returns the expected parameter structure
	InputSchema() models.InputSchema
}

// ToolResult wraps tool execution results with metadata.
type ToolResult struct {
	Data        interface{}            // The actual result data
	Metadata    map[string]interface{} // Execution metadata (timing, resources used, etc)
	Annotations map[string]interface{} // Additional context for result interpretation
}

// ToolRegistry manages available tools.
//
// Only the contract is declared here. NOTE(review): the semantics below
// (duplicate-name handling in Register, how Match interprets an intent —
// presumably via Tool.CanHandle) live in the implementations; confirm there.
type ToolRegistry interface {
	Register(tool Tool) error
	Get(name string) (Tool, error)
	List() []Tool
	Match(intent string) []Tool
}

--------------------------------------------------------------------------------
/pkg/datasets/dataset.go:
--------------------------------------------------------------------------------
package datasets

import (
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
)

// Download URLs for the supported datasets. They are package variables
// rather than constants so setTestURLs can point them at a test server.
var (
	gsm8kDatasetURL    = "https://huggingface.co/datasets/openai/gsm8k/resolve/main/main/test-00000-of-00001.parquet"
	hotPotQADatasetURL = "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json"
)

// setTestURLs overrides both dataset download URLs. For testing purposes.
17 | func setTestURLs(gsm8k, hotpotqa string) { 18 | gsm8kDatasetURL = gsm8k 19 | hotPotQADatasetURL = hotpotqa 20 | } 21 | 22 | func EnsureDataset(datasetName string) (string, error) { 23 | homeDir, err := os.UserHomeDir() 24 | if err != nil { 25 | return "", fmt.Errorf("failed to get user home directory: %w", err) 26 | } 27 | var suffix string 28 | switch datasetName { 29 | case "gsm8k": 30 | suffix = ".parquet" 31 | case "hotpotqa": 32 | suffix = ".json" 33 | default: 34 | suffix = ".parquet" 35 | } 36 | datasetDir := filepath.Join(homeDir, ".dspy-go", "datasets") 37 | if err := os.MkdirAll(datasetDir, 0755); err != nil { 38 | return "", fmt.Errorf("failed to create dataset directory: %w", err) 39 | } 40 | 41 | datasetPath := filepath.Join(datasetDir, datasetName+suffix) 42 | 43 | if _, err := os.Stat(datasetPath); os.IsNotExist(err) { 44 | fmt.Printf("Dataset %s not found locally. Downloading from Hugging Face...\n", datasetName) 45 | if err := downloadDataset(datasetName, datasetPath); err != nil { 46 | return "", fmt.Errorf("failed to download dataset: %w", err) 47 | } 48 | } 49 | 50 | return datasetPath, nil 51 | } 52 | 53 | func downloadDataset(datasetName, datasetPath string) error { 54 | var url string 55 | switch datasetName { 56 | case "gsm8k": 57 | url = gsm8kDatasetURL 58 | case "hotpotqa": 59 | url = hotPotQADatasetURL 60 | default: 61 | return fmt.Errorf("unknown dataset: %s", datasetName) 62 | } 63 | 64 | resp, err := http.Get(url) 65 | if err != nil { 66 | return fmt.Errorf("failed to download dataset: %w", err) 67 | } 68 | defer resp.Body.Close() 69 | // Check for non-200 status codes 70 | if resp.StatusCode != http.StatusOK { 71 | return fmt.Errorf("server returned non-200 status code: %d", resp.StatusCode) 72 | } 73 | out, err := os.Create(datasetPath) 74 | if err != nil { 75 | return fmt.Errorf("failed to create dataset file: %w", err) 76 | } 77 | defer out.Close() 78 | 79 | _, err = io.Copy(out, resp.Body) 80 | if err != nil { 81 | return 
fmt.Errorf("failed to save dataset: %w", err) 82 | } 83 | 84 | return nil 85 | } 86 | -------------------------------------------------------------------------------- /pkg/datasets/dataset_test.go: -------------------------------------------------------------------------------- 1 | package datasets 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "os" 8 | "path/filepath" 9 | "testing" 10 | ) 11 | 12 | func TestEnsureDataset(t *testing.T) { 13 | // Setup 14 | homeDir, _ := os.UserHomeDir() 15 | datasetDir := filepath.Join(homeDir, ".dspy-go", "datasets") 16 | 17 | tests := []struct { 18 | name string 19 | datasetName string 20 | expectedSuffix string 21 | setupFunc func() 22 | cleanupFunc func() 23 | }{ 24 | { 25 | name: "GSM8K dataset - not existing", 26 | datasetName: "gsm8k", 27 | expectedSuffix: ".parquet", 28 | setupFunc: func() { 29 | os.RemoveAll(datasetDir) 30 | }, 31 | cleanupFunc: func() { 32 | os.RemoveAll(datasetDir) 33 | }, 34 | }, 35 | { 36 | name: "HotPotQA dataset - not existing", 37 | datasetName: "hotpotqa", 38 | expectedSuffix: ".json", 39 | setupFunc: func() { 40 | os.RemoveAll(datasetDir) 41 | }, 42 | cleanupFunc: func() { 43 | os.RemoveAll(datasetDir) 44 | }, 45 | }, 46 | { 47 | name: "Unknown dataset", 48 | datasetName: "unknown", 49 | expectedSuffix: ".parquet", 50 | setupFunc: func() {}, 51 | cleanupFunc: func() {}, 52 | }, 53 | { 54 | name: "Existing dataset", 55 | datasetName: "existing", 56 | expectedSuffix: ".parquet", 57 | setupFunc: func() { 58 | if err := os.MkdirAll(datasetDir, 0755); err != nil { 59 | return 60 | } 61 | if err := os.WriteFile(filepath.Join(datasetDir, "existing.parquet"), []byte("test"), 0644); err != nil { 62 | return 63 | } 64 | }, 65 | cleanupFunc: func() { 66 | os.RemoveAll(datasetDir) 67 | }, 68 | }, 69 | } 70 | 71 | for _, tt := range tests { 72 | t.Run(tt.name, func(t *testing.T) { 73 | tt.setupFunc() 74 | defer tt.cleanupFunc() 75 | 76 | // Mock HTTP server 77 | server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 78 | w.WriteHeader(http.StatusOK) 79 | if _, err := w.Write([]byte("mock dataset content")); err != nil { 80 | return 81 | } 82 | })) 83 | defer server.Close() 84 | setTestURLs(server.URL, server.URL) 85 | 86 | path, err := EnsureDataset(tt.datasetName) 87 | 88 | if tt.datasetName == "unknown" { 89 | if err == nil { 90 | t.Errorf("Expected error for unknown dataset, got nil") 91 | } 92 | } else { 93 | if err != nil { 94 | t.Errorf("Unexpected error: %v", err) 95 | } 96 | 97 | expectedPath := filepath.Join(datasetDir, tt.datasetName+tt.expectedSuffix) 98 | if path != expectedPath { 99 | t.Errorf("Expected path %s, got %s", expectedPath, path) 100 | } 101 | 102 | if _, err := os.Stat(path); os.IsNotExist(err) { 103 | t.Errorf("Dataset file not created") 104 | } 105 | } 106 | }) 107 | } 108 | } 109 | 110 | func TestDownloadDataset(t *testing.T) { 111 | // Setup 112 | tempDir, err := os.MkdirTemp("", "dataset-test") 113 | if err != nil { 114 | t.Fatalf("Failed to create temp dir: %v", err) 115 | } 116 | defer os.RemoveAll(tempDir) 117 | 118 | tests := []struct { 119 | name string 120 | datasetName string 121 | setupServer func() *httptest.Server 122 | expectError bool 123 | }{ 124 | { 125 | name: "Successful download - GSM8K", 126 | datasetName: "gsm8k", 127 | setupServer: func() *httptest.Server { 128 | return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 129 | w.WriteHeader(http.StatusOK) 130 | if _, err := w.Write([]byte("mock gsm8k content")); err != nil { 131 | return 132 | } 133 | })) 134 | }, 135 | expectError: false, 136 | }, 137 | { 138 | name: "Successful download - HotPotQA", 139 | datasetName: "hotpotqa", 140 | setupServer: func() *httptest.Server { 141 | return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 142 | w.WriteHeader(http.StatusOK) 143 | if _, err := w.Write([]byte("mock hotpotqa content")); 
err != nil { 144 | return 145 | } 146 | })) 147 | }, 148 | expectError: false, 149 | }, 150 | { 151 | name: "Unknown dataset", 152 | datasetName: "unknown", 153 | setupServer: func() *httptest.Server { 154 | return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 155 | w.WriteHeader(http.StatusOK) 156 | })) 157 | }, 158 | expectError: true, 159 | }, 160 | { 161 | name: "Server error", 162 | datasetName: "gsm8k", 163 | setupServer: func() *httptest.Server { 164 | return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 165 | w.WriteHeader(http.StatusInternalServerError) 166 | })) 167 | }, 168 | expectError: true, 169 | }, 170 | } 171 | 172 | for _, tt := range tests { 173 | t.Run(tt.name, func(t *testing.T) { 174 | server := tt.setupServer() 175 | defer server.Close() 176 | 177 | setTestURLs(server.URL, server.URL) 178 | 179 | datasetPath := filepath.Join(tempDir, tt.datasetName+".dataset") 180 | err := downloadDataset(tt.datasetName, datasetPath) 181 | 182 | if tt.expectError { 183 | if err == nil { 184 | t.Errorf("Expected error, got nil") 185 | } 186 | } else { 187 | if err != nil { 188 | t.Errorf("Unexpected error: %v", err) 189 | } 190 | 191 | content, err := os.ReadFile(datasetPath) 192 | if err != nil { 193 | t.Errorf("Failed to read downloaded file: %v", err) 194 | } 195 | 196 | expectedContent := fmt.Sprintf("mock %s content", tt.datasetName) 197 | if string(content) != expectedContent { 198 | t.Errorf("Expected content %s, got %s", expectedContent, string(content)) 199 | } 200 | } 201 | }) 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /pkg/datasets/gsm8k.go: -------------------------------------------------------------------------------- 1 | //go:build !skip 2 | // +build !skip 3 | 4 | package datasets 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "log" 10 | 11 | "github.com/apache/arrow/go/v13/arrow/array" 12 | 
"github.com/apache/arrow/go/v13/arrow/memory" 13 | "github.com/apache/arrow/go/v13/parquet/file" 14 | "github.com/apache/arrow/go/v13/parquet/pqarrow" 15 | ) 16 | 17 | type GSM8KExample struct { 18 | Question string `json:"question"` 19 | Answer string `json:"answer"` 20 | } 21 | 22 | func LoadGSM8K() ([]GSM8KExample, error) { 23 | datasetPath, err := EnsureDataset("gsm8k") 24 | if err != nil { 25 | return nil, err 26 | } 27 | // Open the Parquet file 28 | reader, err := file.OpenParquetFile(datasetPath, false) 29 | if err != nil { 30 | log.Fatalf("Error opening Parquet file: %v", err) 31 | } 32 | defer reader.Close() 33 | arrowReader, err := pqarrow.NewFileReader(reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator) 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | // Get the schema 39 | schema, _ := arrowReader.Schema() 40 | fmt.Println(schema) 41 | // Find question and answer fields 42 | // Find question and answer field indices 43 | questionIndices := schema.FieldIndices("question") 44 | answerIndices := schema.FieldIndices("answer") 45 | if len(questionIndices) == 0 || len(answerIndices) == 0 { 46 | log.Fatalf("Required columns 'question' and 'answer' not found in the schema") 47 | } 48 | questionIndex := questionIndices[0] 49 | answerIndex := answerIndices[0] 50 | fmt.Printf("Question index: %d, Answer index: %d\n", questionIndex, answerIndex) 51 | 52 | // Prepare a slice to hold all examples 53 | // Read the entire table 54 | table, err := arrowReader.ReadTable(context.Background()) 55 | if err != nil { 56 | log.Fatalf("Error reading table: %v", err) 57 | } 58 | defer table.Release() 59 | 60 | fmt.Printf("Table number of columns: %d\n", table.NumCols()) 61 | fmt.Printf("Table number of rows: %d\n", table.NumRows()) 62 | // Get question and answer columns 63 | questionCol := table.Column(questionIndex) 64 | answerCol := table.Column(answerIndex) 65 | 66 | // Prepare a slice to hold all examples 67 | examples := make([]GSM8KExample, 
table.NumRows()) 68 | 69 | // Create GSM8KExample structs 70 | for i := 0; i < int(table.NumRows()); i++ { 71 | questionChunk := questionCol.Data().Chunk(0) 72 | answerChunk := answerCol.Data().Chunk(0) 73 | 74 | questionValue := questionChunk.(*array.String).Value(i) 75 | answerValue := answerChunk.(*array.String).Value(i) 76 | examples[i] = GSM8KExample{ 77 | Question: questionValue, 78 | Answer: answerValue, 79 | } 80 | } 81 | 82 | fmt.Printf("Total examples read: %d\n", len(examples)) 83 | return examples, nil 84 | } 85 | -------------------------------------------------------------------------------- /pkg/datasets/hotpot_qa.go: -------------------------------------------------------------------------------- 1 | //go:build !skip 2 | // +build !skip 3 | 4 | package datasets 5 | 6 | import ( 7 | "encoding/json" 8 | "fmt" 9 | "io" 10 | "os" 11 | ) 12 | 13 | type HotPotQAExample struct { 14 | ID string `json:"_id"` 15 | SupportingFacts [][]interface{} `json:"supporting_facts"` 16 | Context [][]interface{} `json:"context"` 17 | Question string `json:"question"` 18 | Answer string `json:"answer"` 19 | Type string `json:"type"` 20 | Level string `json:"level"` 21 | } 22 | 23 | func LoadHotpotQA() ([]HotPotQAExample, error) { 24 | datasetPath, err := EnsureDataset("hotpotqa") 25 | if err != nil { 26 | return nil, err 27 | } 28 | file, err := os.Open(datasetPath) 29 | if err != nil { 30 | return nil, err 31 | } 32 | byteValue, err := io.ReadAll(file) 33 | if err != nil { 34 | fmt.Println("Error reading file:", err) 35 | return nil, err 36 | } 37 | var examples []HotPotQAExample 38 | err = json.Unmarshal(byteValue, &examples) 39 | if err != nil { 40 | fmt.Println("Failed to load HotPotQA dataset:", err) 41 | return nil, err 42 | } 43 | return examples, nil 44 | } 45 | -------------------------------------------------------------------------------- /pkg/errors/errors.go: -------------------------------------------------------------------------------- 1 | package errors 2 
import (
	"fmt"
	"sort"
	"strings"
)

// ErrorCode defines known error types in the system.
type ErrorCode int

const (
	// Core error codes.
	Unknown ErrorCode = iota
	InvalidInput
	ValidationFailed
	ResourceNotFound
	Timeout
	RateLimitExceeded
	Canceled
	ResourceExhausted // For scenarios where a resource limit is reached
	// LLM specific errors.
	LLMGenerationFailed
	TokenLimitExceeded
	InvalidResponse

	// Workflow errors.
	WorkflowExecutionFailed
	StepExecutionFailed
	InvalidWorkflowState
)

// Error represents a structured error with context.
type Error struct {
	code     ErrorCode // Type of error
	message  string    // Human-readable message
	original error     // Original/wrapped error
	fields   Fields    // Additional context
}

// Fields carries structured data about the error.
type Fields map[string]interface{}

// Error formats the error as "message: original [k1=v1 k2=v2]".
//
// The ": original" part is omitted when there is no wrapped error or when
// its text is identical to message (WithFields on a plain error copies that
// error's text into message, which would otherwise print twice). Fields are
// printed in sorted key order so the text is deterministic.
func (e *Error) Error() string {
	var b strings.Builder
	b.WriteString(e.message)

	if e.original != nil {
		if orig := e.original.Error(); orig != e.message {
			b.WriteString(": ")
			b.WriteString(orig)
		}
	}

	if len(e.fields) > 0 {
		keys := make([]string, 0, len(e.fields))
		for k := range e.fields {
			keys = append(keys, k)
		}
		sort.Strings(keys)

		b.WriteString(" [")
		for i, k := range keys {
			if i > 0 {
				b.WriteByte(' ')
			}
			fmt.Fprintf(&b, "%s=%v", k, e.fields[k])
		}
		b.WriteString("]")
	}

	return strings.TrimSpace(b.String())
}

// Unwrap returns the wrapped error, or nil if there is none.
func (e *Error) Unwrap() error {
	return e.original
}

// Code returns the error's ErrorCode.
func (e *Error) Code() ErrorCode {
	return e.code
}

// New creates a new error with a code and message.
func New(code ErrorCode, message string) error {
	return &Error{
		code:    code,
		message: message,
	}
}

// Wrap wraps an existing error with additional context.
80 | func Wrap(err error, code ErrorCode, message string) error { 81 | if err == nil { 82 | return nil 83 | } 84 | return &Error{ 85 | code: code, 86 | message: message, 87 | original: err, 88 | } 89 | } 90 | 91 | // WithFields adds structured context to an error. 92 | func WithFields(err error, fields Fields) error { 93 | if err == nil { 94 | return nil 95 | } 96 | 97 | // If it's already our error type, add fields 98 | if e, ok := err.(*Error); ok { 99 | newFields := make(Fields) 100 | for k, v := range e.fields { 101 | newFields[k] = v 102 | } 103 | for k, v := range fields { 104 | newFields[k] = v 105 | } 106 | 107 | return &Error{ 108 | code: e.code, 109 | message: e.message, 110 | original: e.original, 111 | fields: newFields, 112 | } 113 | } 114 | 115 | // Otherwise, create new error 116 | return &Error{ 117 | code: Unknown, 118 | message: err.Error(), 119 | original: err, 120 | fields: fields, 121 | } 122 | } 123 | 124 | // Is implements error matching. 125 | func (e *Error) Is(target error) bool { 126 | t, ok := target.(*Error) 127 | if !ok { 128 | return false 129 | } 130 | return e.code == t.code 131 | } 132 | 133 | // As implements error type casting for errors.As. 
134 | func (e *Error) As(target interface{}) bool { 135 | // Check if target is a pointer to *Error 136 | errorPtr, ok := target.(**Error) 137 | if !ok { 138 | return false 139 | } 140 | // Set the target pointer to our error 141 | *errorPtr = e 142 | return true 143 | } 144 | 145 | func (e *Error) Fields() Fields { 146 | if e.fields == nil { 147 | return Fields{} 148 | } 149 | // Create a copy of the fields map 150 | fields := make(Fields, len(e.fields)) 151 | for k, v := range e.fields { 152 | fields[k] = v 153 | } 154 | return fields 155 | } 156 | -------------------------------------------------------------------------------- /pkg/llms/factory.go: -------------------------------------------------------------------------------- 1 | package llms 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "sync" 7 | 8 | "github.com/XiaoConstantine/anthropic-go/anthropic" 9 | "github.com/XiaoConstantine/dspy-go/pkg/core" 10 | ) 11 | 12 | type DefaultLLMFactory struct{} 13 | 14 | var ( 15 | defaultFactory *DefaultLLMFactory 16 | defaultFactoryOnce sync.Once 17 | ) 18 | 19 | // resetFactoryForTesting resets the factory for testing purposes 20 | // This should only be called from tests. 21 | func resetFactoryForTesting() { 22 | defaultFactory = nil 23 | defaultFactoryOnce = sync.Once{} 24 | core.DefaultFactory = nil 25 | } 26 | 27 | func ensureFactory() { 28 | defaultFactoryOnce.Do(func() { 29 | defaultFactory = &DefaultLLMFactory{} 30 | core.DefaultFactory = defaultFactory 31 | }) 32 | } 33 | 34 | // NewLLM creates a new LLM instance based on the provided model ID. 
35 | func NewLLM(apiKey string, modelID core.ModelID) (core.LLM, error) { 36 | 37 | ensureFactory() 38 | var llm core.LLM 39 | var err error 40 | switch { 41 | case modelID == core.ModelAnthropicHaiku || modelID == core.ModelAnthropicSonnet || modelID == core.ModelAnthropicOpus: 42 | llm, err = NewAnthropicLLM(apiKey, anthropic.ModelID(modelID)) 43 | case modelID == core.ModelGoogleGeminiFlash || modelID == core.ModelGoogleGeminiPro || modelID == core.ModelGoogleGeminiFlashThinking || modelID == core.ModelGoogleGeminiFlashLite: 44 | llm, err = NewGeminiLLM(apiKey, modelID) 45 | case strings.HasPrefix(string(modelID), "ollama:"): 46 | parts := strings.SplitN(string(modelID), ":", 2) 47 | if len(parts) != 2 || parts[1] == "" { 48 | return nil, fmt.Errorf("invalid Ollama model ID format. Use 'ollama:'") 49 | } 50 | llm, err = NewOllamaLLM("http://localhost:11434", parts[1]) 51 | case strings.HasPrefix(string(modelID), "llamacpp:"): 52 | return NewLlamacppLLM("http://localhost:8080") 53 | default: 54 | return nil, fmt.Errorf("unsupported model ID: %s", modelID) 55 | } 56 | if err != nil { 57 | return nil, err 58 | } 59 | return core.Chain(llm, 60 | func(l core.LLM) core.LLM { return core.NewModelContextDecorator(l) }, 61 | ), nil 62 | } 63 | 64 | // Implement the LLMFactory interface. 65 | func (f *DefaultLLMFactory) CreateLLM(apiKey string, modelID core.ModelID) (core.LLM, error) { 66 | return NewLLM(apiKey, modelID) 67 | } 68 | 69 | func init() { 70 | ensureFactory() 71 | } 72 | 73 | func EnsureFactory() { 74 | ensureFactory() 75 | } 76 | -------------------------------------------------------------------------------- /pkg/logging/log_entry.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/XiaoConstantine/dspy-go/pkg/core" 7 | ) 8 | 9 | // contextKey is a custom type for context keys to avoid collisions. 
10 | type contextKey string 11 | 12 | const ( 13 | // ModelIDKey is used to store/retrieve ModelID from context. 14 | ModelIDKey contextKey = "model_id" 15 | 16 | // TokenInfoKey is used to store/retrieve token usage information. 17 | TokenInfoKey contextKey = "token_info" 18 | ) 19 | 20 | // LogEntry represents a structured log record with fields particularly relevant to LLM operations. 21 | type LogEntry struct { 22 | // Standard fields 23 | Time int64 24 | Severity Severity 25 | Message string 26 | File string 27 | Line int 28 | Function string 29 | TraceID string // Added trace ID field 30 | 31 | // LLM-specific fields 32 | ModelID string // The LLM model being used 33 | TokenInfo *core.TokenInfo // Token usage information 34 | Latency int64 // Operation duration in milliseconds 35 | Cost float64 // Operation cost in dollars 36 | 37 | // General structured data 38 | Fields map[string]interface{} 39 | } 40 | 41 | // WithModelID adds a ModelID to the context. 42 | func WithModelID(ctx context.Context, modelID core.ModelID) context.Context { 43 | return context.WithValue(ctx, ModelIDKey, modelID) 44 | } 45 | 46 | // GetModelID retrieves ModelID from context. 47 | func GetModelID(ctx context.Context) (core.ModelID, bool) { 48 | if v := ctx.Value(ModelIDKey); v != nil { 49 | if mid, ok := v.(core.ModelID); ok { 50 | return mid, true 51 | } 52 | } 53 | return "", false 54 | } 55 | 56 | // WithTokenInfo adds TokenInfo to the context. 57 | func WithTokenInfo(ctx context.Context, info *core.TokenInfo) context.Context { 58 | return context.WithValue(ctx, TokenInfoKey, info) 59 | } 60 | 61 | // GetTokenInfo retrieves TokenInfo from context. 
func GetTokenInfo(ctx context.Context) (*core.TokenInfo, bool) {
	// Returns (nil, false) when nothing is stored under TokenInfoKey or the
	// stored value is not a *core.TokenInfo.
	if v := ctx.Value(TokenInfoKey); v != nil {
		if ti, ok := v.(*core.TokenInfo); ok {
			return ti, true
		}
	}
	return nil, false
}

--------------------------------------------------------------------------------
/pkg/logging/log_entry_test.go:
--------------------------------------------------------------------------------
package logging

import (
	"context"
	"testing"

	"github.com/XiaoConstantine/dspy-go/pkg/core"
	"github.com/stretchr/testify/assert"
)

// TestContextValues round-trips ModelID and TokenInfo through a context and
// checks that a bare context reports both as absent.
func TestContextValues(t *testing.T) {
	ctx := context.Background()

	// Test ModelID
	modelID := core.ModelID("test-model")
	ctxWithModel := WithModelID(ctx, modelID)
	retrievedModelID, ok := GetModelID(ctxWithModel)
	assert.True(t, ok)
	assert.Equal(t, modelID, retrievedModelID)

	// Test TokenInfo
	tokenInfo := &core.TokenInfo{
		PromptTokens:     100,
		CompletionTokens: 50,
		TotalTokens:      150,
	}
	ctxWithToken := WithTokenInfo(ctx, tokenInfo)
	retrievedTokenInfo, ok := GetTokenInfo(ctxWithToken)
	assert.True(t, ok)
	assert.Equal(t, tokenInfo, retrievedTokenInfo)

	// Test invalid context values
	_, ok = GetModelID(ctx)
	assert.False(t, ok)
	_, ok = GetTokenInfo(ctx)
	assert.False(t, ok)
}

--------------------------------------------------------------------------------
/pkg/logging/logger.go:
--------------------------------------------------------------------------------
package logging

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"sync"
	"time"

	"github.com/XiaoConstantine/dspy-go/pkg/core"
)

// defaultLogger is the process-wide logger returned by GetLogger; mu guards
// reads and writes of it in GetLogger and SetLogger.
var (
	defaultLogger *Logger
	mu            sync.RWMutex
)

// osExit is the exit function called by Fatal/Fatalf. It is a variable so it
// can be replaced (presumably by tests that must not terminate) — confirm.
var osExit = os.Exit

// Logger provides the core logging functionality.
23 | type Logger struct { 24 | mu sync.Mutex 25 | severity Severity 26 | outputs []Output 27 | sampleRate uint32 // For high-frequency event sampling 28 | fields map[string]interface{} // Default fields for all logs 29 | } 30 | 31 | // Output interface allows for different logging destinations. 32 | type Output interface { 33 | Write(LogEntry) error 34 | Sync() error 35 | Close() error 36 | } 37 | 38 | // Config allows flexible logger configuration. 39 | type Config struct { 40 | Severity Severity 41 | Outputs []Output 42 | SampleRate uint32 43 | DefaultFields map[string]interface{} 44 | } 45 | 46 | // NewLogger creates a new logger with the given configuration. 47 | func NewLogger(cfg Config) *Logger { 48 | return &Logger{ 49 | severity: cfg.Severity, 50 | outputs: cfg.Outputs, 51 | sampleRate: cfg.SampleRate, 52 | fields: cfg.DefaultFields, 53 | } 54 | } 55 | 56 | // logf is the core logging function that handles all severity levels. 57 | func (l *Logger) logf(ctx context.Context, s Severity, format string, args ...interface{}) { 58 | // Early severity check for performance 59 | if s < l.severity { 60 | return 61 | } 62 | 63 | // Get caller information 64 | pc, file, line, _ := runtime.Caller(2) 65 | fn := runtime.FuncForPC(pc).Name() 66 | 67 | // Create base entry 68 | entry := LogEntry{ 69 | Time: time.Now().UnixNano(), 70 | Severity: s, 71 | Message: fmt.Sprintf(format, args...), 72 | File: filepath.Base(file), 73 | Line: line, 74 | Function: filepath.Base(fn), 75 | Fields: make(map[string]interface{}), 76 | } 77 | 78 | // Add context values if present 79 | if ctx != nil { 80 | if modelID, ok := GetModelID(ctx); ok { 81 | entry.ModelID = string(modelID) 82 | } 83 | 84 | if tokenInfo, ok := GetTokenInfo(ctx); ok { 85 | entry.TokenInfo = tokenInfo 86 | } 87 | } 88 | if state := core.GetExecutionState(ctx); state != nil { 89 | entry.TraceID = state.GetTraceID() 90 | } 91 | 92 | // Add default fields 93 | for k, v := range l.fields { 94 | if _, exists := 
entry.Fields[k]; !exists { 95 | entry.Fields[k] = v 96 | } 97 | } 98 | // Add execution context information if available 99 | if state := core.GetExecutionState(ctx); state != nil { 100 | entry.Fields["model_id"] = state.GetModelID() 101 | if usage := state.GetTokenUsage(); usage != nil { 102 | entry.Fields["token_usage"] = usage 103 | } 104 | entry.Fields["spans"] = core.CollectSpans(ctx) 105 | } 106 | 107 | // Write to all outputs 108 | l.mu.Lock() 109 | defer l.mu.Unlock() 110 | 111 | for _, out := range l.outputs { 112 | if err := out.Write(entry); err != nil { 113 | fmt.Fprintf(os.Stderr, "failed to write log entry: %v\n", err) 114 | } 115 | } 116 | } 117 | 118 | // LLM-specific logging methods. 119 | func (l *Logger) PromptCompletion(ctx context.Context, prompt, completion string, tokenInfo *core.TokenInfo) { 120 | if l.severity > DEBUG { 121 | return 122 | } 123 | 124 | l.Debug(ctx, "LLM Interaction: prompt: %s, completion: %v, token_info: %v", 125 | prompt, 126 | completion, 127 | tokenInfo, 128 | ) 129 | } 130 | 131 | // Regular severity-based logging methods. 132 | func (l *Logger) Debug(ctx context.Context, format string, args ...interface{}) { 133 | l.logf(ctx, DEBUG, format, args...) 134 | } 135 | 136 | func (l *Logger) Info(ctx context.Context, format string, args ...interface{}) { 137 | l.logf(ctx, INFO, format, args...) 138 | } 139 | 140 | func (l *Logger) Warn(ctx context.Context, format string, args ...interface{}) { 141 | l.logf(ctx, WARN, format, args...) 142 | } 143 | 144 | func (l *Logger) Error(ctx context.Context, format string, args ...interface{}) { 145 | l.logf(ctx, ERROR, format, args...) 
146 | } 147 | 148 | func (l *Logger) Fatal(ctx context.Context, msg string) { 149 | l.logf(ctx, FATAL, "%s", msg) 150 | 151 | // Ensure all logs are written 152 | for _, out := range l.outputs { 153 | _ = out.Sync() 154 | _ = out.Close() 155 | } 156 | 157 | // Exit the program with status code 1 158 | osExit(1) 159 | } 160 | 161 | func (l *Logger) Fatalf(ctx context.Context, format string, args ...interface{}) { 162 | l.logf(ctx, FATAL, format, args...) 163 | 164 | // Ensure all logs are written 165 | for _, out := range l.outputs { 166 | _ = out.Sync() 167 | _ = out.Close() 168 | } 169 | 170 | // Exit the program with status code 1 171 | osExit(1) 172 | } 173 | 174 | // GetLogger returns the global logger instance. 175 | func GetLogger() *Logger { 176 | // First try reading without a write lock 177 | mu.RLock() 178 | if l := defaultLogger; l != nil { 179 | mu.RUnlock() 180 | return l 181 | } 182 | mu.RUnlock() 183 | 184 | // If no logger exists, create one with write lock 185 | mu.Lock() 186 | defer mu.Unlock() 187 | 188 | if defaultLogger == nil { 189 | // Create default logger with reasonable defaults 190 | defaultLogger = NewLogger(Config{ 191 | Severity: INFO, 192 | Outputs: []Output{ 193 | NewConsoleOutput(false), 194 | }, 195 | }) 196 | } 197 | 198 | return defaultLogger 199 | } 200 | 201 | // SetLogger allows setting a custom configured logger as the global instance. 202 | func SetLogger(l *Logger) { 203 | mu.Lock() 204 | defaultLogger = l 205 | mu.Unlock() 206 | } 207 | -------------------------------------------------------------------------------- /pkg/logging/severity.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | // Severity represents log levels with clear mapping to different stages of LLM operations. 4 | type Severity int32 5 | 6 | const ( 7 | DEBUG Severity = iota 8 | INFO 9 | WARN 10 | ERROR 11 | FATAL 12 | ) 13 | 14 | // String provides human-readable severity levels. 
15 | func (s Severity) String() string { 16 | return [...]string{"DEBUG", "INFO", "WARN", "ERROR", "FATAL"}[s] 17 | } 18 | -------------------------------------------------------------------------------- /pkg/logging/severity_test.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestSeverityString(t *testing.T) { 10 | tests := []struct { 11 | severity Severity 12 | expected string 13 | }{ 14 | {DEBUG, "DEBUG"}, 15 | {INFO, "INFO"}, 16 | {WARN, "WARN"}, 17 | {ERROR, "ERROR"}, 18 | {FATAL, "FATAL"}, 19 | } 20 | 21 | for _, tt := range tests { 22 | t.Run(tt.expected, func(t *testing.T) { 23 | assert.Equal(t, tt.expected, tt.severity.String()) 24 | }) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /pkg/metrics/accuracy.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | ) 7 | 8 | // ExactMatch checks if the predicted answer exactly matches the expected answer for all fields. 9 | func ExactMatch(expected, actual map[string]interface{}) float64 { 10 | for key, expectedValue := range expected { 11 | if actualValue, ok := actual[key]; !ok || !reflect.DeepEqual(expectedValue, actualValue) { 12 | return 0.0 13 | } 14 | } 15 | return 1.0 16 | } 17 | 18 | // AnyMatch checks if any of the predicted answers match the expected answer for all fields. 
19 | func AnyMatch(expected, actual map[string]interface{}) float64 { 20 | for key, expectedValue := range expected { 21 | actualValue, ok := actual[key] 22 | if !ok { 23 | return 0.0 24 | } 25 | 26 | if reflect.TypeOf(actualValue).Kind() == reflect.Slice { 27 | found := false 28 | slice := reflect.ValueOf(actualValue) 29 | for i := 0; i < slice.Len(); i++ { 30 | if reflect.DeepEqual(expectedValue, slice.Index(i).Interface()) { 31 | found = true 32 | break 33 | } 34 | } 35 | if !found { 36 | return 0.0 37 | } 38 | } else if !reflect.DeepEqual(expectedValue, actualValue) { 39 | return 0.0 40 | } 41 | } 42 | return 1.0 43 | } 44 | 45 | // F1Score calculates the F1 score between the expected and actual answers. 46 | func F1Score(expected, actual map[string]interface{}) float64 { 47 | var totalF1 float64 48 | var count int 49 | 50 | for key, expectedValue := range expected { 51 | actualValue, ok := actual[key] 52 | if !ok { 53 | continue 54 | } 55 | 56 | expectedStr, expectedOk := expectedValue.(string) 57 | actualStr, actualOk := actualValue.(string) 58 | if !expectedOk || !actualOk { 59 | continue 60 | } 61 | 62 | expectedTokens := tokenize(expectedStr) 63 | actualTokens := tokenize(actualStr) 64 | 65 | if len(expectedTokens) == 0 && len(actualTokens) == 0 { 66 | totalF1 += 1.0 67 | count++ 68 | continue 69 | } 70 | 71 | if len(expectedTokens) == 0 || len(actualTokens) == 0 { 72 | count++ 73 | continue 74 | } 75 | 76 | intersection := intersection(expectedTokens, actualTokens) 77 | precision := float64(len(intersection)) / float64(len(actualTokens)) 78 | recall := float64(len(intersection)) / float64(len(expectedTokens)) 79 | 80 | if precision+recall > 0 { 81 | f1 := 2 * precision * recall / (precision + recall) 82 | totalF1 += f1 83 | count++ 84 | } else { 85 | count++ 86 | } 87 | } 88 | 89 | if count == 0 { 90 | return 0.0 91 | } 92 | 93 | return totalF1 / float64(count) 94 | } 95 | 96 | // Helper functions. 
// tokenize splits s on runs of Unicode whitespace; no case folding or
// punctuation stripping is applied.
func tokenize(s string) []string {
	return strings.Fields(s)
}

// intersection returns the multiset intersection of a and b, in the order the
// tokens appear in b: each token appears min(count in a, count in b) times.
// Counting occurrences (instead of using a set) keeps repeated tokens from
// collapsing, so identical answers such as "a a" vs "a a" overlap fully.
func intersection(a, b []string) []string {
	remaining := make(map[string]int, len(a))
	for _, tok := range a {
		remaining[tok]++
	}

	var result []string
	for _, tok := range b {
		if remaining[tok] > 0 {
			result = append(result, tok)
			remaining[tok]--
		}
	}
	return result
}

// MetricFunc is the function type shared by metrics such as ExactMatch,
// AnyMatch and F1Score: it scores actual against expected.
type MetricFunc func(expected, actual map[string]interface{}) float64

// Accuracy wraps a MetricFunc so a metric can be configured and passed around
// as a value.
type Accuracy struct {
	MetricFunc MetricFunc
}

// NewAccuracy creates a new Accuracy metric with the specified metric function.
func NewAccuracy(metricFunc MetricFunc) *Accuracy {
	return &Accuracy{MetricFunc: metricFunc}
}

// Evaluate applies the metric function to the expected and actual outputs.
131 | func (a *Accuracy) Evaluate(expected, actual map[string]interface{}) float64 { 132 | return a.MetricFunc(expected, actual) 133 | } 134 | -------------------------------------------------------------------------------- /pkg/metrics/accuracy_test.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestExactMatch(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | expected map[string]interface{} 13 | actual map[string]interface{} 14 | want float64 15 | }{ 16 | { 17 | name: "Exact match", 18 | expected: map[string]interface{}{"answer": "hello"}, 19 | actual: map[string]interface{}{"answer": "hello"}, 20 | want: 1.0, 21 | }, 22 | { 23 | name: "No match", 24 | expected: map[string]interface{}{"answer": "hello"}, 25 | actual: map[string]interface{}{"answer": "world"}, 26 | want: 0.0, 27 | }, 28 | { 29 | name: "Multiple fields match", 30 | expected: map[string]interface{}{"answer": "hello", "confidence": 0.9}, 31 | actual: map[string]interface{}{"answer": "hello", "confidence": 0.9}, 32 | want: 1.0, 33 | }, 34 | { 35 | name: "Multiple fields, partial match", 36 | expected: map[string]interface{}{"answer": "hello", "confidence": 0.9}, 37 | actual: map[string]interface{}{"answer": "hello", "confidence": 0.8}, 38 | want: 0.0, 39 | }, 40 | { 41 | name: "Missing field in actual", 42 | expected: map[string]interface{}{"answer": "hello", "confidence": 0.9}, 43 | actual: map[string]interface{}{"answer": "hello"}, 44 | want: 0.0, 45 | }, 46 | } 47 | 48 | for _, tt := range tests { 49 | t.Run(tt.name, func(t *testing.T) { 50 | got := ExactMatch(tt.expected, tt.actual) 51 | assert.Equal(t, tt.want, got) 52 | }) 53 | } 54 | } 55 | 56 | func TestAnyMatch(t *testing.T) { 57 | tests := []struct { 58 | name string 59 | expected map[string]interface{} 60 | actual map[string]interface{} 61 | want float64 62 | }{ 63 | { 64 |
name: "Single value match", 65 | expected: map[string]interface{}{"answer": "hello"}, 66 | actual: map[string]interface{}{"answer": "hello"}, 67 | want: 1.0, 68 | }, 69 | { 70 | name: "Single value no match", 71 | expected: map[string]interface{}{"answer": "hello"}, 72 | actual: map[string]interface{}{"answer": "world"}, 73 | want: 0.0, 74 | }, 75 | { 76 | name: "Slice match", 77 | expected: map[string]interface{}{"answer": "hello"}, 78 | actual: map[string]interface{}{"answer": []interface{}{"world", "hello", "foo"}}, 79 | want: 1.0, 80 | }, 81 | { 82 | name: "Slice no match", 83 | expected: map[string]interface{}{"answer": "hello"}, 84 | actual: map[string]interface{}{"answer": []interface{}{"world", "foo", "bar"}}, 85 | want: 0.0, 86 | }, 87 | { 88 | name: "Multiple fields, all match", 89 | expected: map[string]interface{}{"answer": "hello", "confidence": 0.9}, 90 | actual: map[string]interface{}{"answer": []interface{}{"world", "hello"}, "confidence": 0.9}, 91 | want: 1.0, 92 | }, 93 | { 94 | name: "Multiple fields, partial match", 95 | expected: map[string]interface{}{"answer": "hello", "confidence": 0.9}, 96 | actual: map[string]interface{}{"answer": []interface{}{"world", "hello"}, "confidence": 0.8}, 97 | want: 0.0, 98 | }, 99 | } 100 | 101 | for _, tt := range tests { 102 | t.Run(tt.name, func(t *testing.T) { 103 | got := AnyMatch(tt.expected, tt.actual) 104 | assert.Equal(t, tt.want, got) 105 | }) 106 | } 107 | } 108 | 109 | func TestF1Score(t *testing.T) { 110 | tests := []struct { 111 | name string 112 | expected map[string]interface{} 113 | actual map[string]interface{} 114 | want float64 115 | }{ 116 | { 117 | name: "Perfect match", 118 | expected: map[string]interface{}{"answer": "the quick brown fox"}, 119 | actual: map[string]interface{}{"answer": "the quick brown fox"}, 120 | want: 1.0, 121 | }, 122 | { 123 | name: "No match", 124 | expected: map[string]interface{}{"answer": "the quick brown fox"}, 125 | actual: map[string]interface{}{"answer": "a
lazy dog"}, 126 | want: 0.0, 127 | }, 128 | { 129 | name: "Partial match", 130 | expected: map[string]interface{}{"answer": "the quick brown fox"}, 131 | actual: map[string]interface{}{"answer": "the quick fox jumps"}, 132 | want: 0.75, // (2 * 3/4 * 3/4) / (3/4 + 3/4) 133 | }, 134 | { 135 | name: "Multiple fields", 136 | expected: map[string]interface{}{"answer1": "the quick brown fox", "answer2": "jumps over the lazy dog"}, 137 | actual: map[string]interface{}{"answer1": "the quick fox", "answer2": "jumps over the dog"}, 138 | want: 0.8730158730158731, // Average of 6/7 ~ 0.8571 (answer1) and 8/9 ~ 0.8889 (answer2) 139 | }, 140 | // { 141 | // name: "Non-string fields", 142 | // expected: map[string]interface{}{"answer": "the quick brown fox", "confidence": 0.9}, 143 | // actual: map[string]interface{}{"answer": "the quick fox", "confidence": 0.8}, 144 | // want: 0.75, // Only considers the string field 145 | // }, 146 | { 147 | name: "Empty string", 148 | expected: map[string]interface{}{"answer": ""}, 149 | actual: map[string]interface{}{"answer": ""}, 150 | want: 1.0, // Both empty strings should be considered a perfect match 151 | }, 152 | { 153 | name: "One empty string", 154 | expected: map[string]interface{}{"answer": "the quick brown fox"}, 155 | actual: map[string]interface{}{"answer": ""}, 156 | want: 0.0, 157 | }, 158 | { 159 | name: "All non-string fields", 160 | expected: map[string]interface{}{"confidence": 0.9, "score": 5}, 161 | actual: map[string]interface{}{"confidence": 0.8, "score": 4}, 162 | want: 0.0, // No string fields to compare 163 | }, 164 | } 165 | 166 | for _, tt := range tests { 167 | t.Run(tt.name, func(t *testing.T) { 168 | got := F1Score(tt.expected, tt.actual) 169 | assert.InDelta(t, tt.want, got, 0.0001) // Using InDelta for float comparison 170 | }) 171 | } 172 | } 173 | func TestAccuracy(t *testing.T) { 174 | expected := map[string]interface{}{"answer": "hello world"} 175 | actual := map[string]interface{}{"answer": "hello world"} 176 | 177 |
exactMatchAccuracy := NewAccuracy(ExactMatch) 178 | assert.Equal(t, 1.0, exactMatchAccuracy.Evaluate(expected, actual)) 179 | 180 | f1ScoreAccuracy := NewAccuracy(F1Score) 181 | assert.Equal(t, 1.0, f1ScoreAccuracy.Evaluate(expected, actual)) 182 | 183 | customMetric := func(expected, actual map[string]interface{}) float64 { 184 | return 0.5 // Always return 0.5 for testing purposes 185 | } 186 | customAccuracy := NewAccuracy(customMetric) 187 | assert.Equal(t, 0.5, customAccuracy.Evaluate(expected, actual)) 188 | } 189 | -------------------------------------------------------------------------------- /pkg/modules/chain_of_thought.go: -------------------------------------------------------------------------------- 1 | package modules 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/XiaoConstantine/dspy-go/pkg/core" 7 | ) 8 | 9 | type ChainOfThought struct { 10 | Predict *Predict 11 | } 12 | 13 | var ( 14 | _ core.Module = (*ChainOfThought)(nil) 15 | _ core.Composable = (*ChainOfThought)(nil) 16 | ) 17 | 18 | func NewChainOfThought(signature core.Signature) *ChainOfThought { 19 | modifiedSignature := appendRationaleField(signature) 20 | return &ChainOfThought{ 21 | Predict: NewPredict(modifiedSignature), 22 | } 23 | } 24 | 25 | // WithDefaultOptions sets default options by configuring the underlying Predict module. 26 | func (c *ChainOfThought) WithDefaultOptions(opts ...core.Option) *ChainOfThought { 27 | // Simply delegate to the Predict module's WithDefaultOptions 28 | c.Predict.WithDefaultOptions(opts...) 29 | return c 30 | } 31 | 32 | func (c *ChainOfThought) Process(ctx context.Context, inputs map[string]any, opts ...core.Option) (map[string]any, error) { 33 | ctx, span := core.StartSpan(ctx, "ChainOfThought") 34 | defer core.EndSpan(ctx) 35 | 36 | span.WithAnnotation("inputs", inputs) 37 | outputs, err := c.Predict.Process(ctx, inputs, opts...)
38 | if err != nil { 39 | span.WithError(err) 40 | return nil, err 41 | } 42 | span.WithAnnotation("outputs", outputs) 43 | 44 | return outputs, nil 45 | } 46 | 47 | func (c *ChainOfThought) GetSignature() core.Signature { 48 | return c.Predict.GetSignature() 49 | } 50 | 51 | // SetSignature implements the core.Module interface. 52 | func (c *ChainOfThought) SetSignature(signature core.Signature) { 53 | modifiedSignature := appendRationaleField(signature) 54 | c.Predict.SetSignature(modifiedSignature) 55 | } 56 | 57 | func (c *ChainOfThought) SetLLM(llm core.LLM) { 58 | c.Predict.SetLLM(llm) 59 | } 60 | 61 | func (c *ChainOfThought) Clone() core.Module { 62 | return &ChainOfThought{ 63 | Predict: c.Predict.Clone().(*Predict), 64 | } 65 | } 66 | func (c *ChainOfThought) Compose(next core.Module) core.Module { 67 | return &core.ModuleChain{ 68 | Modules: []core.Module{c, next}, 69 | } 70 | } 71 | 72 | func (c *ChainOfThought) GetSubModules() []core.Module { 73 | return []core.Module{c.Predict} 74 | } 75 | 76 | func (c *ChainOfThought) SetSubModules(modules []core.Module) { 77 | if len(modules) > 0 { 78 | if predict, ok := modules[0].(*Predict); ok { 79 | c.Predict = predict 80 | } 81 | } 82 | } 83 | // appendRationaleField returns signature with a "rationale" output field 84 | // prepended. If an output named "rationale" is already present (as can happen 85 | // when SetSignature is given the result of GetSignature), the signature is 86 | // returned unchanged so repeated calls never stack duplicate rationale fields. 87 | func appendRationaleField(signature core.Signature) core.Signature { 88 | newSignature := signature 89 | for _, out := range newSignature.Outputs { 90 | if out.Field.Name == "rationale" { 91 | return newSignature 92 | } 93 | } 94 | rationaleField := core.OutputField{ 95 | Field: core.NewField("rationale", 96 | core.WithDescription("Step-by-step reasoning process"), 97 | ), 98 | } 99 | newSignature.Outputs = append([]core.OutputField{rationaleField}, newSignature.Outputs...)
100 | 101 | return newSignature 102 | } 103 | -------------------------------------------------------------------------------- /pkg/modules/module.go: -------------------------------------------------------------------------------- 1 | package modules 2 | 3 | type Module interface { 4 | Forward(inputs map[string]interface{}) (Predict, error) 5 | } 6 | -------------------------------------------------------------------------------- /pkg/modules/predict_test.go: -------------------------------------------------------------------------------- 1 | package modules 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/XiaoConstantine/dspy-go/internal/testutil" 8 | "github.com/XiaoConstantine/dspy-go/pkg/core" 9 | "github.com/XiaoConstantine/dspy-go/pkg/errors" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/mock" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestPredict(t *testing.T) { 16 | // Create a mock LLM 17 | mockLLM := new(testutil.MockLLM) 18 | 19 | // Set up the expected behavior 20 | mockLLM.On("Generate", mock.Anything, mock.Anything, mock.Anything).Return(&core.LLMResponse{ 21 | Content: `answer: 22 | 42 23 | `, 24 | }, nil) 25 | 26 | // Create a Predict module 27 | signature := core.NewSignature( 28 | []core.InputField{{Field: core.Field{Name: "question"}}}, 29 | []core.OutputField{{Field: core.NewField("answer")}}, 30 | ) 31 | predict := NewPredict(signature) 32 | predict.SetLLM(mockLLM) 33 | 34 | // Test the Process method 35 | ctx := context.Background() 36 | ctx = core.WithExecutionState(ctx) 37 | 38 | inputs := map[string]any{"question": "What is the meaning of life?"} 39 | outputs, err := predict.Process(ctx, inputs) 40 | 41 | // Assert the results 42 | assert.NoError(t, err) 43 | assert.Equal(t, "42", outputs["answer"]) 44 | 45 | // Verify that the mock was called as expected 46 | mockLLM.AssertExpectations(t) 47 | // Verify traces 48 | spans := core.CollectSpans(ctx) 49 | require.Len(t, spans, 1)
50 | span := spans[0] 51 | 52 | inputsMap, _ := span.Annotations["inputs"].(map[string]interface{}) 53 | question, _ := inputsMap["question"].(string) 54 | 55 | outputsMap, _ := span.Annotations["outputs"].(map[string]interface{}) 56 | answer, _ := outputsMap["answer"].(string) 57 | 58 | assert.Contains(t, question, "What is the meaning of life?") 59 | assert.Contains(t, answer, "4") 60 | } 61 | 62 | func TestPredict_WithLLMError(t *testing.T) { 63 | // Create a mock LLM 64 | mockLLM := new(testutil.MockLLM) 65 | 66 | // Set up the expected behavior with an error 67 | expectedErr := errors.New(errors.LLMGenerationFailed, "LLM service unavailable") 68 | mockLLM.On("Generate", mock.Anything, mock.Anything, mock.Anything).Return((*core.LLMResponse)(nil), expectedErr) 69 | 70 | // Create a Predict module 71 | signature := core.NewSignature( 72 | []core.InputField{{Field: core.Field{Name: "question"}}}, 73 | []core.OutputField{{Field: core.NewField("answer")}}, 74 | ) 75 | predict := NewPredict(signature) 76 | predict.SetLLM(mockLLM) 77 | 78 | // Test the Process method with an error 79 | ctx := context.Background() 80 | ctx = core.WithExecutionState(ctx) 81 | 82 | inputs := map[string]any{"question": "What is the meaning of life?"} 83 | outputs, err := predict.Process(ctx, inputs) 84 | 85 | // Assert the results 86 | assert.Error(t, err) 87 | assert.Nil(t, outputs) 88 | assert.Contains(t, err.Error(), "failed to generate prediction") 89 | 90 | // Verify that the mock was called as expected 91 | mockLLM.AssertExpectations(t) 92 | } 93 | 94 | func TestPredict_WithMissingInput(t *testing.T) { 95 | // Create a mock LLM 96 | mockLLM := new(testutil.MockLLM) 97 | 98 | // Create a Predict module 99 | signature := core.NewSignature( 100 | []core.InputField{{Field: core.Field{Name: "question"}}}, 101 | []core.OutputField{{Field: core.NewField("answer")}}, 102 | ) 103 | predict := NewPredict(signature) 104 | predict.SetLLM(mockLLM) 105 | 106 | // Test the Process method with
missing input 107 | ctx := context.Background() 108 | ctx = core.WithExecutionState(ctx) 109 | 110 | // Empty inputs map will cause validation to fail 111 | inputs := map[string]any{} 112 | outputs, err := predict.Process(ctx, inputs) 113 | 114 | // Assert the results 115 | assert.Error(t, err) 116 | assert.Nil(t, outputs) 117 | assert.Contains(t, err.Error(), "input validation failed") 118 | } 119 | 120 | func TestPredict_WithGenerateOptions(t *testing.T) { 121 | // Create a mock LLM that can capture the generate options 122 | mockLLM := new(testutil.MockLLM) 123 | 124 | var capturedOpts []core.GenerateOption 125 | 126 | // Set up the expected behavior 127 | mockLLM.On("Generate", mock.Anything, mock.Anything, mock.MatchedBy(func(opts []core.GenerateOption) bool { 128 | capturedOpts = opts 129 | return true 130 | })).Return(&core.LLMResponse{ 131 | Content: "answer: Test response", 132 | }, nil) 133 | 134 | // Create a Predict module with default options 135 | signature := core.NewSignature( 136 | []core.InputField{{Field: core.Field{Name: "question"}}}, 137 | []core.OutputField{{Field: core.NewField("answer")}}, 138 | ) 139 | predict := NewPredict(signature) 140 | predict.SetLLM(mockLLM) 141 | 142 | // Add default options 143 | predict.WithDefaultOptions( 144 | core.WithGenerateOptions( 145 | core.WithTemperature(0.8), 146 | core.WithMaxTokens(1000), 147 | ), 148 | ) 149 | 150 | // Call with additional process-specific options 151 | ctx := context.Background() 152 | inputs := map[string]any{"question": "Test question"} 153 | _, err := predict.Process(ctx, inputs, 154 | core.WithGenerateOptions( 155 | core.WithTemperature(0.5), // Override temperature 156 | ), 157 | ) 158 | 159 | // Verify results 160 | assert.NoError(t, err) 161 | assert.NotEmpty(t, capturedOpts) 162 | 163 | // We can't directly test the options since they're opaque functions, 164 | // but we can verify the mock was called with some options 165 | mockLLM.AssertExpectations(t) 166 | } 167 | 168 |
func TestPredict_WithStreamHandler(t *testing.T) { 169 | // Create a mock LLM 170 | mockLLM := new(testutil.MockLLM) 171 | 172 | // Setup streaming 173 | streamConfig := &testutil.MockStreamConfig{ 174 | Content: "answer: Streaming response", 175 | ChunkSize: 5, 176 | TokenCounts: &core.TokenInfo{ 177 | PromptTokens: 10, 178 | }, 179 | } 180 | 181 | // Set up the mock behavior for streaming 182 | mockLLM.On("StreamGenerate", mock.Anything, mock.Anything, mock.Anything).Return(streamConfig, nil) 183 | 184 | // Create a Predict module 185 | signature := core.NewSignature( 186 | []core.InputField{{Field: core.Field{Name: "question"}}}, 187 | []core.OutputField{{Field: core.NewField("answer")}}, 188 | ) 189 | predict := NewPredict(signature) 190 | predict.SetLLM(mockLLM) 191 | 192 | // Create a handler to collect chunks 193 | var chunks []string 194 | handler := func(chunk core.StreamChunk) error { 195 | if !chunk.Done && chunk.Error == nil { 196 | chunks = append(chunks, chunk.Content) 197 | } 198 | return nil 199 | } 200 | 201 | // Process with streaming 202 | ctx := context.Background() 203 | inputs := map[string]any{"question": "Stream test"} 204 | outputs, err := predict.Process(ctx, inputs, core.WithStreamHandler(handler)) 205 | 206 | // Verify results 207 | assert.NoError(t, err) 208 | assert.NotNil(t, outputs) 209 | assert.Greater(t, len(chunks), 0, "Should have received some chunks") 210 | 211 | // Verify the mock was called with streaming 212 | mockLLM.AssertExpectations(t) 213 | } 214 | -------------------------------------------------------------------------------- /pkg/optimizers/bootstrap_fewshot.go: -------------------------------------------------------------------------------- 1 | package optimizers 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "sync" 8 | "sync/atomic" 9 | 10 | "github.com/XiaoConstantine/dspy-go/pkg/core" 11 | "github.com/XiaoConstantine/dspy-go/pkg/errors" 12 | "github.com/XiaoConstantine/dspy-go/pkg/modules" 13 |
"github.com/sourcegraph/conc/pool" 14 | ) 15 | 16 | type BootstrapFewShot struct { 17 | Metric func(example map[string]interface{}, prediction map[string]interface{}, ctx context.Context) bool 18 | MaxBootstrapped int 19 | } 20 | 21 | func NewBootstrapFewShot(metric func(example map[string]interface{}, prediction map[string]interface{}, ctx context.Context) bool, maxBootstrapped int) *BootstrapFewShot { 22 | return &BootstrapFewShot{ 23 | Metric: metric, 24 | MaxBootstrapped: maxBootstrapped, 25 | } 26 | } 27 | 28 | func (b *BootstrapFewShot) Compile(ctx context.Context, student, teacher core.Program, trainset []map[string]interface{}) (core.Program, error) { 29 | compiledStudent := student.Clone() 30 | teacherLLM := core.GetTeacherLLM() 31 | if teacherLLM == nil { 32 | teacherLLM = core.GetDefaultLLM() 33 | } 34 | if ctx == nil { 35 | ctx = context.Background() 36 | } 37 | if core.GetExecutionState(ctx) == nil { 38 | ctx = core.WithExecutionState(ctx) 39 | } 40 | // NOTE(review): the unconditional call below makes the nil check above redundant; confirm whether core.WithExecutionState keeps or replaces the caller's state here. 41 | ctx = core.WithExecutionState(ctx) 42 | ctx, span := core.StartSpan(ctx, "Compilation") 43 | 44 | defer core.EndSpan(ctx) 45 | 46 | var ( 47 | resultsMu sync.Mutex 48 | results []struct { 49 | demo core.Example 50 | ctx context.Context 51 | } 52 | processed int32 53 | errCh = make(chan error, 1) 54 | ) 55 | examplesNeeded := b.MaxBootstrapped 56 | if examplesNeeded > len(trainset) { 57 | examplesNeeded = len(trainset) 58 | } 59 | 60 | p := pool.New().WithMaxGoroutines(core.GlobalConfig.ConcurrencyLevel) 61 | 62 | for i := 0; i < examplesNeeded; i++ { 63 | if b.enoughBootstrappedDemos(compiledStudent) { 64 | log.Println("Enough bootstrapped demos, breaking loop") 65 | break 66 | } 67 | 68 | ex := trainset[i] 69 | p.Go(func() { 70 | exampleCtx, exampleSpan := core.StartSpan(ctx, "Example") 71 | defer core.EndSpan(exampleCtx) 72 | 73 | exampleSpan.WithAnnotation("Example", ex) 74 | prediction, err := b.predictWithTeacher(ctx, teacher, teacherLLM, ex) 75 | if err != nil { 76 | exampleSpan.WithError(err) 77 |
select { 78 | case errCh <- err: 79 | default: 80 | } 81 | return 82 | } 83 | exampleSpan.WithAnnotation("prediction", prediction) 84 | 85 | if b.Metric(ex, prediction, exampleCtx) { 86 | resultsMu.Lock() 87 | results = append(results, struct { 88 | demo core.Example 89 | ctx context.Context 90 | }{ 91 | demo: core.Example{ 92 | Inputs: ex, 93 | Outputs: prediction, 94 | }, 95 | ctx: exampleCtx, 96 | }) 97 | resultsMu.Unlock() 98 | } 99 | // NOTE(review): processed is incremented here but never read anywhere in Compile. 100 | atomic.AddInt32(&processed, 1) 101 | }) 102 | } 103 | 104 | p.Wait() 105 | 106 | select { 107 | case err := <-errCh: 108 | span.WithError(err) 109 | return compiledStudent, fmt.Errorf("error during compilation: %w", err) 110 | default: 111 | } 112 | 113 | for _, result := range results { 114 | if err := b.addDemonstrations(compiledStudent, result.demo, result.ctx); err != nil { 115 | span.WithError(err) 116 | return compiledStudent, fmt.Errorf("error adding demonstrations: %w", err) 117 | } 118 | if b.enoughBootstrappedDemos(compiledStudent) { 119 | break 120 | } 121 | } 122 | 123 | span.WithAnnotation("compiledStudent", compiledStudent) 124 | return compiledStudent, nil 125 | } 126 | 127 | func (b *BootstrapFewShot) predictWithTeacher(ctx context.Context, teacher core.Program, teacherLLM core.LLM, example map[string]interface{}) (map[string]interface{}, error) { 128 | // Clone the teacher program and set its LLM to the teacher LLM 129 | teacherClone := teacher.Clone() 130 | for _, module := range teacherClone.Modules { 131 | if predictor, ok := module.(interface{ SetLLM(core.LLM) }); ok { 132 | predictor.SetLLM(teacherLLM) 133 | } 134 | } 135 | if ctx == nil { 136 | ctx = context.Background() 137 | } 138 | 139 | ctx = core.WithExecutionState(ctx) 140 | ctx, span := core.StartSpan(ctx, "TeacherPrediction") 141 | defer core.EndSpan(ctx) 142 | 143 | span.WithAnnotation("Example", example) 144 | outputs, err := teacherClone.Execute(ctx, example) 145 | if err != nil { 146 | span.WithError(err) 147 | return nil, err 148 | }
149 | 150 | span.WithAnnotation("outputs", outputs) 151 | return outputs, nil 152 | 153 | } 154 | 155 | func (b *BootstrapFewShot) enoughBootstrappedDemos(program core.Program) bool { 156 | for _, module := range program.Modules { 157 | if predictor, ok := module.(interface{ GetDemos() []core.Example }); ok { 158 | if len(predictor.GetDemos()) < b.MaxBootstrapped { 159 | return false 160 | } 161 | } 162 | } 163 | return true 164 | } 165 | 166 | // addDemonstrations appends demo to a *modules.Predict in program.Modules that 167 | // still holds fewer than MaxBootstrapped demos. Modules are visited in map 168 | // order, so when several have room the recipient is not deterministic. Full 169 | // modules are skipped; ResourceExhausted is returned only when every Predict 170 | // module is full, and ResourceNotFound when the program has none. 171 | func (b *BootstrapFewShot) addDemonstrations(program core.Program, demo core.Example, ctx context.Context) error { 172 | if ctx == nil { 173 | return errors.New(errors.InvalidInput, "cannot add demonstrations: context is nil") 174 | } 175 | 176 | ctx, span := core.StartSpan(ctx, "AddDemonstrations") 177 | defer core.EndSpan(ctx) 178 | span.WithAnnotation("demo_inputs", demo.Inputs) 179 | span.WithAnnotation("demo_outputs", demo.Outputs) 180 | var exhaustedErr error 181 | for moduleName, module := range program.Modules { 182 | predictor, ok := module.(*modules.Predict) 183 | if !ok { 184 | continue 185 | } 186 | 187 | currentDemos := predictor.GetDemos() 188 | 189 | if len(currentDemos) >= b.MaxBootstrapped { 190 | // Remember why this module was skipped and keep looking: another 191 | // Predict module may still have room. 192 | span.WithAnnotation("skipped_module", moduleName) 193 | span.WithAnnotation("reason", "max_demos_reached") 194 | exhaustedErr = errors.WithFields( 195 | errors.New(errors.ResourceExhausted, fmt.Sprintf("max demonstrations reached for module %s", moduleName)), 196 | errors.Fields{ 197 | "module": moduleName, 198 | "max_demos": b.MaxBootstrapped, 199 | "current_demos": len(currentDemos), 200 | }) 201 | continue 202 | } 203 | 204 | // The full slice expression caps capacity so append always copies instead of 205 | // writing into a backing array GetDemos might share with the program this one 206 | // was cloned from (NOTE(review): confirm GetDemos/Clone slice semantics). 207 | newDemos := append(currentDemos[:len(currentDemos):len(currentDemos)], demo) 208 | predictor.SetDemos(newDemos) 209 | span.WithAnnotation("added_to_module", moduleName) 210 | span.WithAnnotation("total_demos", len(newDemos)) 211 | return nil 212 | } 213 | if exhaustedErr != nil { 214 | return exhaustedErr 215 | } 216 | return errors.New(errors.ResourceNotFound, "no suitable module found for adding demonstrations") 217 | } 218 | -------------------------------------------------------------------------------- 
/pkg/optimizers/bootstrap_fewshot_test.go: -------------------------------------------------------------------------------- 1 | package optimizers 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/XiaoConstantine/dspy-go/internal/testutil" 8 | "github.com/XiaoConstantine/dspy-go/pkg/core" 9 | "github.com/XiaoConstantine/dspy-go/pkg/modules" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/mock" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func init() { 16 | mockLLM := new(testutil.MockLLM) 17 | 18 | mockLLM.On("Generate", mock.Anything, mock.Anything, mock.Anything).Return(&core.LLMResponse{Content: `answer: 19 | Paris`}, nil) 20 | mockLLM.On("GenerateWithJSON", mock.Anything, mock.Anything, mock.Anything).Return(map[string]interface{}{"answer": "Paris"}, nil) 21 | 22 | core.GlobalConfig.DefaultLLM = mockLLM 23 | core.GlobalConfig.TeacherLLM = mockLLM 24 | core.GlobalConfig.ConcurrencyLevel = 1 25 | } 26 | 27 | func createProgram() core.Program { 28 | predict := modules.NewPredict(core.NewSignature( 29 | []core.InputField{{Field: core.Field{Name: "question"}}}, 30 | []core.OutputField{{Field: core.NewField("answer")}}, 31 | )) 32 | 33 | forwardFunc := func(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) { 34 | 35 | ctx, span := core.StartSpan(ctx, "Forward") 36 | defer core.EndSpan(ctx) 37 | span.WithAnnotation("inputs", inputs) 38 | outputs, err := predict.Process(ctx, inputs) 39 | if err != nil { 40 | span.WithError(err) 41 | return nil, err 42 | } 43 | span.WithAnnotation("outputs", outputs) 44 | return outputs, nil 45 | } 46 | 47 | return core.NewProgram(map[string]core.Module{"predict": predict}, forwardFunc) 48 | } 49 | 50 | func TestBootstrapFewShot(t *testing.T) { 51 | student := createProgram() 52 | teacher := createProgram() 53 | // Create training set 54 | trainset := []map[string]interface{}{ 55 | {"question": "What is the capital of France?"}, 56 | {"question":
"What is the capital of Germany?"}, 57 | {"question": "What is the capital of Italy?"}, 58 | } 59 | 60 | // Define metric function 61 | metric := func(example, prediction map[string]interface{}, ctx context.Context) bool { 62 | return true // Always return true for this test 63 | } 64 | 65 | // Create BootstrapFewShot optimizer 66 | maxBootstrapped := 2 67 | optimizer := NewBootstrapFewShot(metric, maxBootstrapped) 68 | 69 | ctx := core.WithExecutionState(context.Background()) 70 | 71 | // Compile the program 72 | optimizedProgram, _ := optimizer.Compile(ctx, student, teacher, trainset) 73 | 74 | // Check if the optimized program has the correct number of demonstrations 75 | optimizedPredict, ok := optimizedProgram.Modules["predict"].(*modules.Predict) 76 | assert.True(t, ok) 77 | assert.Equal(t, maxBootstrapped, len(optimizedPredict.Demos)) 78 | 79 | // Check if the demonstrations are correct 80 | for _, demo := range optimizedPredict.Demos { 81 | assert.Contains(t, demo.Inputs, "question") 82 | assert.Contains(t, demo.Outputs, "answer") 83 | assert.Equal(t, "Paris", demo.Outputs["answer"]) 84 | } 85 | // Verify the trace structure 86 | spans := core.CollectSpans(ctx) 87 | require.NotEmpty(t, spans, "Expected spans to be recorded") 88 | rootSpan := spans[0] 89 | assert.Equal(t, "Compilation", rootSpan.Operation, "Expected Compilation as root span") 90 | 91 | // Find Example spans (should be direct children of Compilation) 92 | var exampleSpans []*core.Span 93 | for _, span := range spans { 94 | if span.Operation == "Example" { 95 | exampleSpans = append(exampleSpans, span) 96 | } 97 | } 98 | assert.Equal(t, maxBootstrapped, len(exampleSpans), 99 | "Expected number of Example spans to match maxBootstrapped") 100 | 101 | // Verify span structure and content 102 | var compilationSpan *core.Span 103 | for _, span := range spans { 104 | if span.Operation == "Compilation" { 105 | compilationSpan = span 106 | } 107 | } 108 | 109 | // Verify compilation span 110 |
require.NotNil(t, compilationSpan, "Expected to find compilation span") 111 | assert.NotZero(t, compilationSpan.StartTime) 112 | assert.Nil(t, compilationSpan.Error) 113 | 114 | } 115 | 116 | func TestBootstrapFewShotEdgeCases(t *testing.T) { 117 | 118 | trainset := []map[string]interface{}{ 119 | {"question": "Q1"}, 120 | {"question": "Q2"}, 121 | {"question": "Q3"}, 122 | } 123 | 124 | t.Run("MaxBootstrapped Zero", func(t *testing.T) { 125 | optimizer := NewBootstrapFewShot(func(_, _ map[string]interface{}, _ context.Context) bool { return true }, 0) 126 | ctx := context.Background() 127 | 128 | optimized, err := optimizer.Compile(ctx, createProgram(), createProgram(), trainset) 129 | assert.NoError(t, err) 130 | assert.Equal(t, 0, len(optimized.Modules["predict"].(*modules.Predict).Demos)) 131 | }) 132 | 133 | t.Run("MaxBootstrapped Large", func(t *testing.T) { 134 | optimizer := NewBootstrapFewShot(func(_, _ map[string]interface{}, _ context.Context) bool { 135 | return true 136 | }, 100) 137 | ctx := context.Background() 138 | 139 | optimized, err := optimizer.Compile(ctx, createProgram(), createProgram(), trainset) 140 | if err != nil { 141 | t.Fatalf("Compilation failed: %v", err) 142 | } 143 | demoCount := len(optimized.Modules["predict"].(*modules.Predict).Demos) 144 | assert.Equal(t, len(trainset), demoCount) 145 | }) 146 | 147 | t.Run("Metric Rejects All", func(t *testing.T) { 148 | optimizer := NewBootstrapFewShot(func(_, _ map[string]interface{}, _ context.Context) bool { return false }, 2) 149 | ctx := context.Background() 150 | optimized, _ := optimizer.Compile(ctx, createProgram(), createProgram(), trainset) 151 | assert.Equal(t, 0, len(optimized.Modules["predict"].(*modules.Predict).Demos)) 152 | }) 153 | } 154 | -------------------------------------------------------------------------------- /pkg/optimizers/copro.go: -------------------------------------------------------------------------------- 1 | package optimizers 2 | 3 | import ( 4 |
"context" 5 | "fmt" 6 | 7 | "github.com/XiaoConstantine/dspy-go/pkg/core" 8 | "github.com/XiaoConstantine/dspy-go/pkg/modules" 9 | ) 10 | 11 | type Copro struct { 12 | Metric func(example, prediction map[string]interface{}, ctx context.Context) bool 13 | MaxBootstrapped int 14 | SubOptimizer core.Optimizer 15 | } 16 | 17 | func NewCopro(metric func(example, prediction map[string]interface{}, ctx context.Context) bool, maxBootstrapped int, subOptimizer core.Optimizer) *Copro { 18 | return &Copro{ 19 | Metric: metric, 20 | MaxBootstrapped: maxBootstrapped, 21 | SubOptimizer: subOptimizer, 22 | } 23 | } 24 | func (c *Copro) Compile(ctx context.Context, program core.Program, dataset core.Dataset, metric core.Metric) (core.Program, error) { 25 | compiledProgram := program.Clone() 26 | // Ensure execution state exists 27 | if core.GetExecutionState(ctx) == nil { 28 | ctx = core.WithExecutionState(ctx) 29 | } 30 | 31 | ctx, compilationSpan := core.StartSpan(ctx, "CoproCompilation") 32 | defer core.EndSpan(ctx) 33 | // NOTE(review): the metric argument is never read; evaluation goes through c.Metric via wrappedMetric below. Confirm this is intended. 34 | wrappedMetric := func(expected, actual map[string]interface{}) float64 { 35 | metricCtx, metricSpan := core.StartSpan(ctx, "MetricEvaluation") 36 | defer core.EndSpan(metricCtx) 37 | 38 | metricSpan.WithAnnotation("expected", expected) 39 | metricSpan.WithAnnotation("actual", actual) 40 | 41 | // Use the context-based metric 42 | if c.Metric(expected, actual, metricCtx) { 43 | metricSpan.WithAnnotation("result", 1.0) 44 | return 1.0 45 | } 46 | 47 | metricSpan.WithAnnotation("result", 0.0) 48 | 49 | return 0.0 50 | } 51 | for moduleName, module := range compiledProgram.Modules { 52 | moduleCtx, moduleSpan := core.StartSpan(ctx, fmt.Sprintf("Module_%s", moduleName)) 53 | // Compile inside moduleCtx so spans emitted while compiling this module nest under its Module_ span. 54 | compiledModule, err := c.compileModule(moduleCtx, module, dataset, wrappedMetric) 55 | if err != nil { 56 | moduleSpan.WithError(err) 57 | core.EndSpan(moduleCtx) 58 | compilationSpan.WithError(err) 59 | 60 | return compiledProgram, fmt.Errorf("error compiling module %s: %w", moduleName,
err) 61 | } 62 | 63 | compiledProgram.Modules[moduleName] = compiledModule 64 | moduleSpan.WithAnnotation("compiledModule", compiledModule) 65 | core.EndSpan(moduleCtx) 66 | 67 | } 68 | 69 | compilationSpan.WithAnnotation("compiledProgram", compiledProgram) 70 | 71 | return compiledProgram, nil 72 | } 73 | 74 | func (c *Copro) compileModule(ctx context.Context, module core.Module, dataset core.Dataset, metric core.Metric) (core.Module, error) { 75 | switch m := module.(type) { 76 | case *modules.Predict: 77 | // Create a temporary Program with just this Predict module 78 | tempProgram := core.NewProgram(map[string]core.Module{"predict": m}, nil) 79 | // Compile using the SubOptimizer 80 | compiledProgram, err := c.SubOptimizer.Compile(ctx, tempProgram, dataset, metric) 81 | if err != nil { 82 | return nil, err 83 | } 84 | 85 | // Extract the optimized Predict module from the compiled Program 86 | optimizedPredict, ok := compiledProgram.Modules["predict"] 87 | if !ok { 88 | return nil, fmt.Errorf("compiled program does not contain 'predict' module") 89 | } 90 | 91 | return optimizedPredict, nil 92 | 93 | case core.Composable: 94 | subModules := m.GetSubModules() 95 | for i, subModule := range subModules { 96 | compiledSubModule, err := c.compileModule(ctx, subModule, dataset, metric) 97 | if err != nil { 98 | return nil, err 99 | } 100 | subModules[i] = compiledSubModule 101 | } 102 | m.SetSubModules(subModules) 103 | return m, nil 104 | 105 | default: 106 | // For non-Predict, non-Composable modules, return as-is 107 | 108 | return m, nil 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /pkg/optimizers/copro_test.go: -------------------------------------------------------------------------------- 1 | package optimizers 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/XiaoConstantine/dspy-go/internal/testutil" 8 | "github.com/XiaoConstantine/dspy-go/pkg/core" 9 |
"github.com/XiaoConstantine/dspy-go/pkg/modules" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/mock" 12 | ) 13 | 14 | // MockModule is a mock implementation of core.Module. 15 | type MockModule struct { 16 | mock.Mock 17 | } 18 | 19 | func (m *MockModule) Process(ctx context.Context, inputs map[string]any, opts ...core.Option) (map[string]any, error) { 20 | args := m.Called(ctx, inputs) 21 | return args.Get(0).(map[string]any), args.Error(1) 22 | } 23 | 24 | func (m *MockModule) GetSignature() core.Signature { 25 | args := m.Called() 26 | return args.Get(0).(core.Signature) 27 | } 28 | 29 | func (m *MockModule) SetLLM(llm core.LLM) { 30 | m.Called(llm) 31 | } 32 | 33 | func (m *MockModule) Clone() core.Module { 34 | args := m.Called() 35 | return args.Get(0).(core.Module) 36 | } 37 | 38 | func (m *MockModule) SetSignature(signature core.Signature) { 39 | m.Called(signature) 40 | } 41 | 42 | // MockOptimizer is a mock implementation of core.Optimizer. 43 | type MockOptimizer struct { 44 | mock.Mock 45 | } 46 | 47 | func (m *MockOptimizer) Compile(ctx context.Context, program core.Program, dataset core.Dataset, metric core.Metric) (core.Program, error) { 48 | args := m.Called(ctx, program, dataset, metric) 49 | return args.Get(0).(core.Program), args.Error(1) 50 | } 51 | 52 | func TestCoproCompile(t *testing.T) { 53 | // Create mock objects 54 | mockModule := new(MockModule) 55 | mockSubOptimizer := new(MockOptimizer) 56 | mockDataset := new(testutil.MockDataset) 57 | 58 | // Create a test program 59 | testProgram := core.Program{ 60 | Modules: map[string]core.Module{ 61 | "test": mockModule, 62 | }, 63 | } 64 | 65 | // Create a Copro instance 66 | copro := NewCopro( 67 | func(example, prediction map[string]interface{}, ctx context.Context) bool { return true }, 68 | 5, 69 | mockSubOptimizer, 70 | ) 71 | 72 | // Set up expectations 73 | mockModule.On("Clone").Return(mockModule) 74 | // We no longer expect Compile to be called on 
non-Predict modules 75 | // mockSubOptimizer.On("Compile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(testProgram, nil) 76 | 77 | // Create a context with trace manager 78 | ctx := core.WithExecutionState(context.Background()) 79 | 80 | // Call Compile 81 | compiledProgram, err := copro.Compile(ctx, testProgram, mockDataset, nil) 82 | 83 | // Assert expectations 84 | assert.NoError(t, err) 85 | assert.NotNil(t, compiledProgram) 86 | assert.Equal(t, 1, len(compiledProgram.Modules)) 87 | assert.Contains(t, compiledProgram.Modules, "test") 88 | assert.Equal(t, mockModule, compiledProgram.Modules["test"]) // The module should be unchanged 89 | 90 | mockModule.AssertExpectations(t) 91 | mockSubOptimizer.AssertExpectations(t) 92 | } 93 | 94 | func TestCoproCompileWithPredict(t *testing.T) { 95 | // Create mock objects 96 | mockPredict := modules.NewPredict(core.Signature{}) 97 | mockSubOptimizer := new(MockOptimizer) 98 | mockDataset := new(testutil.MockDataset) 99 | 100 | // Create a test program 101 | testProgram := core.Program{ 102 | Modules: map[string]core.Module{ 103 | "predict": mockPredict, 104 | }, 105 | } 106 | 107 | // Create a Copro instance 108 | copro := NewCopro( 109 | func(example, prediction map[string]interface{}, ctx context.Context) bool { return true }, 110 | 5, 111 | mockSubOptimizer, 112 | ) 113 | 114 | // Set up expectations 115 | mockSubOptimizer.On("Compile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(testProgram, nil) 116 | 117 | // Create a context with trace manager 118 | ctx := core.WithExecutionState(context.Background()) 119 | 120 | // Call Compile 121 | compiledProgram, err := copro.Compile(ctx, testProgram, mockDataset, nil) 122 | 123 | // Assert expectations 124 | assert.NoError(t, err) 125 | assert.NotNil(t, compiledProgram) 126 | assert.Equal(t, 1, len(compiledProgram.Modules)) 127 | assert.Contains(t, compiledProgram.Modules, "predict") 128 | 129 | 
mockSubOptimizer.AssertExpectations(t) 130 | } 131 | 132 | func TestCoproCompileError(t *testing.T) { 133 | // Create mock objects 134 | mockSubOptimizer := new(MockOptimizer) 135 | mockDataset := new(testutil.MockDataset) 136 | 137 | // Create a test program 138 | testProgram := core.Program{ 139 | Modules: map[string]core.Module{ 140 | "test": &modules.Predict{}, // Use a real Predict module instead of mockModule 141 | }, 142 | } 143 | 144 | // Create a Copro instance 145 | copro := NewCopro( 146 | func(example, prediction map[string]interface{}, ctx context.Context) bool { return true }, 147 | 5, 148 | mockSubOptimizer, 149 | ) 150 | 151 | // Set up expectations 152 | mockSubOptimizer.On("Compile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.Program{}, assert.AnError) 153 | 154 | // Create a context with trace manager 155 | ctx := core.WithExecutionState(context.Background()) 156 | 157 | // Call Compile 158 | _, err := copro.Compile(ctx, testProgram, mockDataset, nil) 159 | 160 | // Assert expectations 161 | assert.Error(t, err) 162 | assert.Contains(t, err.Error(), "error compiling module test") 163 | 164 | mockSubOptimizer.AssertExpectations(t) 165 | } 166 | -------------------------------------------------------------------------------- /pkg/tools/func.go: -------------------------------------------------------------------------------- 1 | // pkg/tools/func.go 2 | package tools 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | 8 | "github.com/XiaoConstantine/dspy-go/pkg/core" 9 | models "github.com/XiaoConstantine/mcp-go/pkg/model" 10 | ) 11 | 12 | // ToolFunc represents a function that can be called as a tool. 13 | type ToolFunc func(ctx context.Context, args map[string]interface{}) (*models.CallToolResult, error) 14 | 15 | // FuncTool wraps a Go function as a Tool implementation. 
16 | type FuncTool struct { 17 | name string 18 | description string 19 | schema models.InputSchema 20 | fn ToolFunc 21 | metadata *core.ToolMetadata 22 | matchCutoff float64 23 | } 24 | 25 | // NewFuncTool creates a new function-based tool. 26 | func NewFuncTool(name, description string, schema models.InputSchema, fn ToolFunc) *FuncTool { 27 | // Extract capabilities from description 28 | capabilities := extractCapabilities(description) 29 | 30 | // Create the metadata with the full schema - no conversion needed! 31 | metadata := &core.ToolMetadata{ 32 | Name: name, 33 | Description: description, 34 | InputSchema: schema, // Use the schema directly - no conversion required 35 | Capabilities: capabilities, 36 | Version: "1.0.0", 37 | } 38 | 39 | return &FuncTool{ 40 | name: name, 41 | description: description, 42 | schema: schema, 43 | fn: fn, 44 | metadata: metadata, 45 | matchCutoff: 0.3, 46 | } 47 | } 48 | 49 | // Name returns the tool's identifier. 50 | func (t *FuncTool) Name() string { 51 | return t.name 52 | } 53 | 54 | // Description returns human-readable explanation of the tool. 55 | func (t *FuncTool) Description() string { 56 | return t.description 57 | } 58 | 59 | // InputSchema returns the expected parameter structure. 60 | func (t *FuncTool) InputSchema() models.InputSchema { 61 | return t.schema 62 | } 63 | 64 | // Metadata returns the tool's metadata for intent matching. 65 | func (t *FuncTool) Metadata() *core.ToolMetadata { 66 | return t.metadata 67 | } 68 | 69 | // CanHandle checks if the tool can handle a specific action/intent. 70 | func (t *FuncTool) CanHandle(ctx context.Context, intent string) bool { 71 | score := calculateToolMatchScore(t.metadata, intent) 72 | return score >= t.matchCutoff 73 | } 74 | 75 | // Call executes the wrapped function with the provided arguments. 
76 | func (t *FuncTool) Call(ctx context.Context, args map[string]interface{}) (*models.CallToolResult, error) { 77 | return t.fn(ctx, args) 78 | } 79 | 80 | // Execute runs the tool with provided parameters and adapts the result to the core interface. 81 | func (t *FuncTool) Execute(ctx context.Context, params map[string]interface{}) (core.ToolResult, error) { 82 | result, err := t.Call(ctx, params) 83 | if err != nil { 84 | return core.ToolResult{}, err 85 | } 86 | 87 | // Convert CallToolResult to core.ToolResult 88 | toolResult := core.ToolResult{ 89 | Data: extractContentText(result.Content), 90 | Metadata: map[string]interface{}{"isError": result.IsError}, 91 | Annotations: map[string]interface{}{}, 92 | } 93 | 94 | return toolResult, nil 95 | } 96 | 97 | // Validate checks if the parameters match the expected schema. 98 | func (t *FuncTool) Validate(params map[string]interface{}) error { 99 | // Use the full InputSchema for validation 100 | for name, param := range t.schema.Properties { 101 | if param.Required { 102 | if _, exists := params[name]; !exists { 103 | return fmt.Errorf("missing required parameter: %s", name) 104 | } 105 | 106 | // We could add type checking here based on param.Type 107 | // For example, check if numbers are actually numbers, etc. 108 | } 109 | } 110 | 111 | return nil 112 | } 113 | 114 | // Type returns the tool type. 
115 | func (t *FuncTool) Type() ToolType { 116 | return ToolTypeFunc 117 | } 118 | -------------------------------------------------------------------------------- /pkg/tools/mcp.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | 9 | "github.com/XiaoConstantine/dspy-go/pkg/core" 10 | "github.com/XiaoConstantine/dspy-go/pkg/logging" 11 | "github.com/XiaoConstantine/mcp-go/pkg/client" 12 | models "github.com/XiaoConstantine/mcp-go/pkg/model" 13 | ) 14 | 15 | // MCPTool represents a tool that delegates to an MCP server. 16 | type MCPTool struct { 17 | name string 18 | description string 19 | schema models.InputSchema 20 | client *client.Client 21 | toolName string 22 | metadata *core.ToolMetadata 23 | matchCutoff float64 24 | } 25 | 26 | // NewMCPTool creates a new MCP-based tool. 27 | func NewMCPTool(name, description string, schema models.InputSchema, 28 | client *client.Client, toolName string) *MCPTool { 29 | 30 | // Extract capabilities from description 31 | capabilities := extractCapabilities(description) 32 | 33 | // Create the metadata with the full schema - no conversion needed! 34 | metadata := &core.ToolMetadata{ 35 | Name: name, 36 | Description: description, 37 | InputSchema: schema, // Use the schema directly - no conversion required 38 | Capabilities: capabilities, 39 | Version: "1.0.0", 40 | } 41 | 42 | return &MCPTool{ 43 | name: name, 44 | description: description, 45 | schema: schema, 46 | client: client, 47 | toolName: toolName, 48 | metadata: metadata, 49 | matchCutoff: 0.3, 50 | } 51 | } 52 | 53 | // Name returns the tool's identifier. 54 | func (t *MCPTool) Name() string { 55 | return t.name 56 | } 57 | 58 | // Description returns human-readable explanation of the tool. 59 | func (t *MCPTool) Description() string { 60 | return t.description 61 | } 62 | 63 | // InputSchema returns the expected parameter structure. 
64 | func (t *MCPTool) InputSchema() models.InputSchema { 65 | return t.schema 66 | } 67 | 68 | // Metadata returns the tool's metadata for intent matching. 69 | func (t *MCPTool) Metadata() *core.ToolMetadata { 70 | return t.metadata 71 | } 72 | 73 | // CanHandle checks if the tool can handle a specific action/intent. 74 | func (t *MCPTool) CanHandle(ctx context.Context, intent string) bool { 75 | score := calculateToolMatchScore(t.metadata, intent) 76 | return score >= t.matchCutoff 77 | } 78 | 79 | // Call forwards the call to the MCP server and returns the result. 80 | func (t *MCPTool) Call(ctx context.Context, args map[string]interface{}) (*models.CallToolResult, error) { 81 | return t.client.CallTool(ctx, t.toolName, args) 82 | } 83 | 84 | // Execute runs the tool with provided parameters and adapts the result to the core interface. 85 | func (t *MCPTool) Execute(ctx context.Context, params map[string]interface{}) (core.ToolResult, error) { 86 | 87 | convertedParams := convertMCPParams(ctx, t.schema, params) // Call the helper function 88 | result, err := t.Call(ctx, convertedParams) 89 | if err != nil { 90 | return core.ToolResult{}, err 91 | } 92 | 93 | // Convert MCP call result to core.ToolResult 94 | toolResult := core.ToolResult{ 95 | Data: extractContentText(result.Content), 96 | Metadata: map[string]interface{}{"isError": result.IsError}, 97 | Annotations: map[string]interface{}{}, 98 | } 99 | 100 | return toolResult, nil 101 | } 102 | 103 | // Validate checks if the parameters match the expected schema. 104 | func (t *MCPTool) Validate(params map[string]interface{}) error { 105 | // Use the full InputSchema for validation 106 | for name, param := range t.schema.Properties { 107 | if param.Required { 108 | if _, exists := params[name]; !exists { 109 | return fmt.Errorf("missing required parameter: %s", name) 110 | } 111 | } 112 | } 113 | 114 | return nil 115 | } 116 | 117 | // Type returns the tool type. 
118 | func (t *MCPTool) Type() ToolType { 119 | return ToolTypeMCP 120 | } 121 | 122 | // convertMCPParams attempts to convert parameter values based on the provided MCP schema. 123 | // It prioritizes converting strings to numbers/integers if the schema specifies. 124 | func convertMCPParams(ctx context.Context, schema models.InputSchema, params map[string]interface{}) map[string]interface{} { 125 | logger := logging.GetLogger() 126 | convertedParams := make(map[string]interface{}) 127 | 128 | for key, value := range params { 129 | convertedParams[key] = value // Default: keep original value 130 | 131 | prop, schemaHasKey := schema.Properties[key] 132 | if !schemaHasKey { 133 | continue // No schema info for this key, keep original value 134 | } 135 | 136 | expectedType := strings.ToLower(prop.Type) 137 | currentValueStr, isString := value.(string) 138 | 139 | if isString { 140 | var conversionErr error 141 | // Attempt conversion if schema expects a number/integer and we have a string 142 | switch expectedType { 143 | case "number", "float": 144 | if floatVal, err := strconv.ParseFloat(currentValueStr, 64); err == nil { 145 | convertedParams[key] = floatVal 146 | // logger.Debug(ctx, "Converted MCP param '%s' from string to float64", key) 147 | } else { 148 | conversionErr = err 149 | } 150 | case "integer": 151 | if intVal, err := strconv.Atoi(currentValueStr); err == nil { 152 | convertedParams[key] = intVal 153 | logger.Debug(ctx, "Converted MCP param '%s' from string to int", key) 154 | } else { 155 | conversionErr = err 156 | } 157 | } 158 | 159 | if conversionErr != nil { 160 | logger.Warn(ctx, "Failed to convert param '%s' ('%s') to expected type '%s': %v. Using original string.", key, currentValueStr, prop.Type, conversionErr) 161 | // Keep original string value if conversion fails 162 | } 163 | } 164 | // Handle cases where the input is already a number but schema expects string? 
(Less common) 165 | } 166 | return convertedParams 167 | } 168 | -------------------------------------------------------------------------------- /pkg/tools/registry.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "strings" 5 | "sync" 6 | 7 | "github.com/XiaoConstantine/dspy-go/pkg/core" 8 | "github.com/XiaoConstantine/dspy-go/pkg/errors" 9 | ) 10 | 11 | // InMemoryToolRegistry provides a basic in-memory implementation of the ToolRegistry interface. 12 | type InMemoryToolRegistry struct { 13 | mu sync.RWMutex 14 | tools map[string]core.Tool 15 | } 16 | 17 | // NewInMemoryToolRegistry creates a new, empty InMemoryToolRegistry. 18 | func NewInMemoryToolRegistry() *InMemoryToolRegistry { 19 | return &InMemoryToolRegistry{ 20 | tools: make(map[string]core.Tool), 21 | } 22 | } 23 | 24 | // Register adds a tool to the registry. 25 | // It returns an error if a tool with the same name already exists. 26 | func (r *InMemoryToolRegistry) Register(tool core.Tool) error { 27 | r.mu.Lock() 28 | defer r.mu.Unlock() 29 | 30 | if tool == nil { 31 | return errors.New(errors.InvalidInput, "cannot register a nil tool") 32 | } 33 | 34 | name := tool.Name() 35 | if _, exists := r.tools[name]; exists { 36 | return errors.WithFields(errors.New(errors.InvalidInput, "tool already registered"), errors.Fields{ 37 | "tool_name": name, 38 | }) 39 | } 40 | 41 | r.tools[name] = tool 42 | return nil 43 | } 44 | 45 | // Get retrieves a tool by its name. 46 | // It returns an error if the tool is not found. 47 | func (r *InMemoryToolRegistry) Get(name string) (core.Tool, error) { 48 | r.mu.RLock() 49 | defer r.mu.RUnlock() 50 | 51 | tool, exists := r.tools[name] 52 | if !exists { 53 | return nil, errors.WithFields(errors.New(errors.ResourceNotFound, "tool not found"), errors.Fields{ 54 | "tool_name": name, 55 | }) 56 | } 57 | return tool, nil 58 | } 59 | 60 | // List returns a slice of all registered tools. 
61 | // The order is not guaranteed. 62 | func (r *InMemoryToolRegistry) List() []core.Tool { 63 | r.mu.RLock() 64 | defer r.mu.RUnlock() 65 | 66 | list := make([]core.Tool, 0, len(r.tools)) 67 | for _, tool := range r.tools { 68 | list = append(list, tool) 69 | } 70 | return list 71 | } 72 | 73 | // Match finds tools that might match a given intent string. 74 | // This basic implementation checks if the intent contains the tool name (case-insensitive). 75 | // More sophisticated matching (e.g., using descriptions or CanHandle) could be added. 76 | func (r *InMemoryToolRegistry) Match(intent string) []core.Tool { 77 | r.mu.RLock() 78 | defer r.mu.RUnlock() 79 | 80 | var matches []core.Tool 81 | lowerIntent := strings.ToLower(intent) 82 | 83 | for name, tool := range r.tools { 84 | // Simple substring match on name 85 | if strings.Contains(lowerIntent, strings.ToLower(name)) { 86 | matches = append(matches, tool) 87 | continue 88 | } 89 | } 90 | return matches 91 | } 92 | 93 | // Ensure InMemoryToolRegistry implements the interface. 
94 | var _ core.ToolRegistry = (*InMemoryToolRegistry)(nil) 95 | -------------------------------------------------------------------------------- /pkg/tools/registry_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | testutil "github.com/XiaoConstantine/dspy-go/internal/testutil" 8 | pkgErrors "github.com/XiaoConstantine/dspy-go/pkg/errors" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestNewInMemoryRegistry(t *testing.T) { 14 | registry := NewInMemoryToolRegistry() // Use the new constructor 15 | require.NotNil(t, registry, "Expected non-nil registry") 16 | assert.NotNil(t, registry.tools, "Expected tools map to be initialized") 17 | assert.Empty(t, registry.tools, "Expected empty tools map") 18 | } 19 | 20 | func TestInMemoryRegister(t *testing.T) { 21 | registry := NewInMemoryToolRegistry() 22 | 23 | // Create a mock core tool using the constructor 24 | mockTool := testutil.NewMockCoreTool("test-tool", "A test tool", nil) 25 | 26 | // Test successful registration 27 | err := registry.Register(mockTool) 28 | assert.NoError(t, err, "Unexpected error during registration") 29 | 30 | // Verify registration 31 | retrievedTool, err := registry.Get("test-tool") 32 | assert.NoError(t, err, "Failed to get registered tool") 33 | assert.Equal(t, mockTool, retrievedTool, "Retrieved tool does not match registered tool") 34 | 35 | // Test duplicate registration 36 | err = registry.Register(mockTool) 37 | assert.Error(t, err, "Expected error for duplicate registration, got nil") 38 | if err != nil { 39 | // Check if it's the expected error type using type assertion 40 | e, ok := err.(*pkgErrors.Error) 41 | assert.True(t, ok, "Error should be of type *pkgErrors.Error") 42 | if ok { 43 | assert.Equal(t, pkgErrors.InvalidInput, e.Code(), "Expected InvalidInput error code") 44 | } 45 | } 46 | 47 | // Test registering 
nil tool 48 | err = registry.Register(nil) 49 | assert.Error(t, err, "Expected error for registering nil tool") 50 | if err != nil { 51 | e, ok := err.(*pkgErrors.Error) 52 | assert.True(t, ok, "Error should be of type *pkgErrors.Error") 53 | if ok { 54 | assert.Equal(t, pkgErrors.InvalidInput, e.Code(), "Expected InvalidInput error code") 55 | } 56 | } 57 | } 58 | 59 | func TestInMemoryGet(t *testing.T) { 60 | registry := NewInMemoryToolRegistry() 61 | 62 | // Create and register a mock core tool using the constructor 63 | mockTool := testutil.NewMockCoreTool("test-tool", "", nil) 64 | err := registry.Register(mockTool) 65 | require.NoError(t, err) 66 | 67 | // Test successful retrieval 68 | tool, err := registry.Get("test-tool") 69 | assert.NoError(t, err, "Unexpected error getting existing tool") 70 | require.NotNil(t, tool, "Expected non-nil tool") 71 | assert.Equal(t, "test-tool", tool.Name(), "Expected name 'test-tool'") 72 | 73 | // Test non-existent tool 74 | tool, err = registry.Get("non-existent-tool") 75 | assert.Error(t, err, "Expected error for non-existent tool, got nil") 76 | assert.Nil(t, tool, "Expected nil tool for non-existent name") 77 | if err != nil { 78 | e, ok := err.(*pkgErrors.Error) 79 | assert.True(t, ok, "Error should be of type *pkgErrors.Error") 80 | if ok { 81 | assert.Equal(t, pkgErrors.ResourceNotFound, e.Code(), "Expected ResourceNotFound error code") 82 | } 83 | } 84 | } 85 | 86 | func TestInMemoryList(t *testing.T) { 87 | registry := NewInMemoryToolRegistry() 88 | 89 | // Check empty list 90 | tools := registry.List() 91 | assert.Empty(t, tools, "Expected empty list for new registry") 92 | 93 | // Add some tools using the constructor 94 | mockTool1 := testutil.NewMockCoreTool("tool1", "Tool 1", nil) 95 | mockTool2 := testutil.NewMockCoreTool("tool2", "Tool 2", nil) 96 | 97 | err := registry.Register(mockTool1) 98 | require.NoError(t, err) 99 | err = registry.Register(mockTool2) 100 | require.NoError(t, err) 101 | 102 | // Check 
list with tools 103 | tools = registry.List() 104 | assert.Len(t, tools, 2, "Expected 2 tools in the list") 105 | 106 | // Check that all tools are in the list (order doesn't matter) 107 | foundNames := make(map[string]bool) 108 | for _, tool := range tools { 109 | foundNames[tool.Name()] = true 110 | } 111 | assert.True(t, foundNames["tool1"], "Expected tool1 in the list") 112 | assert.True(t, foundNames["tool2"], "Expected tool2 in the list") 113 | } 114 | 115 | func TestInMemoryMatch(t *testing.T) { 116 | registry := NewInMemoryToolRegistry() 117 | 118 | // Add some tools using the constructor 119 | mockTool1 := testutil.NewMockCoreTool("ReadFile", "Reads a file", nil) 120 | mockTool2 := testutil.NewMockCoreTool("WriteFile", "Writes a file", nil) 121 | mockTool3 := testutil.NewMockCoreTool("ListFiles", "Lists directory contents", nil) 122 | 123 | err := registry.Register(mockTool1) 124 | require.NoError(t, err) 125 | err = registry.Register(mockTool2) 126 | require.NoError(t, err) 127 | err = registry.Register(mockTool3) 128 | require.NoError(t, err) 129 | 130 | tests := []struct { 131 | intent string 132 | expectedLen int 133 | expectedName string // Only checks first match if len=1 134 | }{ 135 | {"read the input file", 0, ""}, // "read the input file" does NOT contain "readfile" 136 | {"I want to WRITE a file", 0, ""}, // "i want to write a file" does NOT contain "writefile" 137 | {"list the files in the directory", 0, ""}, // does NOT contain "listfiles" 138 | {"something about files", 0, ""}, // does NOT contain "readfile", "writefile", or "listfiles" 139 | {"delete everything", 0, ""}, // No matching tool names 140 | {"READFILE", 1, "ReadFile"}, // "readfile" contains "readfile" 141 | {"use WriteFile command", 1, "WriteFile"}, // "use writefile command" contains "writefile" 142 | {"ListFilesPlease", 1, "ListFiles"}, // "listfilesplease" contains "listfiles" 143 | {"no match", 0, ""}, 144 | } 145 | 146 | for _, tt := range tests { 147 | 
t.Run(fmt.Sprintf("Intent_%s", tt.intent), func(t *testing.T) { 148 | matches := registry.Match(tt.intent) 149 | assert.Len(t, matches, tt.expectedLen, "Incorrect number of matches for intent: '%s'", tt.intent) 150 | if tt.expectedLen == 1 && len(matches) == 1 { 151 | assert.Equal(t, tt.expectedName, matches[0].Name(), "Incorrect tool matched for intent: '%s'", tt.intent) 152 | } 153 | }) 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /pkg/tools/tool.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "context" 5 | "encoding/xml" 6 | 7 | models "github.com/XiaoConstantine/mcp-go/pkg/model" 8 | ) 9 | 10 | // Tool defines a callable tool interface that abstracts both local functions 11 | // and remote MCP tools. This provides a unified way to interact with tools 12 | // regardless of their implementation details. 13 | type Tool interface { 14 | // Name returns the tool's identifier 15 | Name() string 16 | 17 | // Description returns human-readable explanation of the tool's purpose 18 | Description() string 19 | 20 | // InputSchema returns the expected parameter structure 21 | InputSchema() models.InputSchema 22 | 23 | // Call executes the tool with the provided arguments 24 | Call(ctx context.Context, args map[string]interface{}) (*models.CallToolResult, error) 25 | } 26 | 27 | // ToolType represents the source/type of a tool. 28 | type ToolType string 29 | 30 | const ( 31 | // ToolTypeFunc represents a tool backed by a local Go function. 32 | ToolTypeFunc ToolType = "function" 33 | 34 | // ToolTypeMCP represents a tool backed by an MCP server. 
35 | ToolTypeMCP ToolType = "mcp" 36 | ) 37 | 38 | type XMLArgument struct { 39 | XMLName xml.Name `xml:"arg"` 40 | Key string `xml:"key,attr"` 41 | Value string `xml:",chardata"` // Store raw value as string for now 42 | } 43 | 44 | type XMLAction struct { 45 | XMLName xml.Name `xml:"action"` 46 | ToolName string `xml:"tool_name,omitempty"` 47 | Arguments []XMLArgument `xml:"arguments>arg,omitempty"` 48 | 49 | Content string `xml:",chardata"` 50 | } 51 | 52 | // Helper to convert XML arguments to map[string]interface{} 53 | // Note: This currently stores all values as strings. More sophisticated type 54 | // inference or checking could be added later if needed based on tool schemas. 55 | func (xa *XMLAction) GetArgumentsMap() map[string]interface{} { 56 | argsMap := make(map[string]interface{}) 57 | if xa == nil { 58 | return argsMap 59 | } 60 | // If it's a finish action or other simple action, there may be no arguments 61 | if len(xa.Arguments) == 0 { 62 | return argsMap 63 | } 64 | for _, arg := range xa.Arguments { 65 | argsMap[arg.Key] = arg.Value 66 | } 67 | return argsMap 68 | } 69 | -------------------------------------------------------------------------------- /pkg/tools/tool_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestToolTypeConstants(t *testing.T) { 8 | // Test that the tool type constants are defined correctly 9 | if ToolTypeFunc != "function" { 10 | t.Errorf("Expected ToolTypeFunc to be 'function', got %s", ToolTypeFunc) 11 | } 12 | 13 | if ToolTypeMCP != "mcp" { 14 | t.Errorf("Expected ToolTypeMCP to be 'mcp', got %s", ToolTypeMCP) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /pkg/utils/util.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | 
"github.com/XiaoConstantine/dspy-go/pkg/errors" 7 | ) 8 | 9 | // ParseJSONResponse attempts to parse a string response as JSON. 10 | func ParseJSONResponse(response string) (map[string]interface{}, error) { 11 | var result map[string]interface{} 12 | err := json.Unmarshal([]byte(response), &result) 13 | if err != nil { 14 | return nil, errors.WithFields( 15 | errors.Wrap(err, errors.InvalidResponse, "failed to parse JSON"), 16 | errors.Fields{ 17 | "error_type": "json_parse_error", 18 | "data_preview": truncateString(response, 100), 19 | "data_length": len(response), 20 | }) 21 | } 22 | return result, nil 23 | } 24 | 25 | func truncateString(s string, maxLen int) string { 26 | if len(s) <= maxLen { 27 | return s 28 | } 29 | return s[:maxLen] + "..." 30 | } 31 | 32 | // Max returns the maximum of two integers. 33 | func Max(a, b int) int { 34 | if a > b { 35 | return a 36 | } 37 | return b 38 | } 39 | 40 | // CloneParams creates a deep copy of a parameter map. 41 | func CloneParams(params map[string]interface{}) map[string]interface{} { 42 | clone := make(map[string]interface{}) 43 | for k, v := range params { 44 | clone[k] = v 45 | } 46 | return clone 47 | } 48 | --------------------------------------------------------------------------------