├── .covignore ├── examples ├── policer-http │ ├── requirements.txt │ └── policer.py ├── policer-rego │ ├── simple.rego │ ├── basic-auth.rego │ └── police.rego └── gen-test-tokens.sh ├── assets └── imgs │ ├── otel.png │ └── mb-arch-overview.png ├── pkgs ├── policer │ ├── api │ │ ├── error.go │ │ ├── response.go │ │ └── request.go │ ├── interface.go │ └── internal │ │ ├── rego │ │ ├── utils.go │ │ └── policer.go │ │ └── http │ │ └── policer.go ├── backend │ ├── client │ │ ├── cmd_others.go │ │ ├── cmd_darwin.go │ │ ├── cmd_linux.go │ │ ├── interface.go │ │ ├── server.go │ │ ├── server_test.go │ │ ├── options.go │ │ ├── options_test.go │ │ ├── stdio.go │ │ ├── stdio_test.go │ │ ├── sse.go │ │ ├── stream_test.go │ │ └── stream.go │ ├── interface.go │ ├── carrier.go │ ├── helpers.go │ ├── carrier_test.go │ ├── options_test.go │ ├── helpers_test.go │ ├── options.go │ └── ws_test.go ├── internal │ ├── sanitize │ │ ├── data.go │ │ └── data_test.go │ └── cors │ │ ├── cors.go │ │ └── cors_test.go ├── info │ └── info.go ├── memconn │ ├── addr.go │ ├── listener.go │ └── memconn.go ├── mcp │ ├── notification.go │ ├── error.go │ ├── types.go │ ├── message.go │ └── id.go ├── frontend │ ├── hash.go │ ├── interface.go │ ├── info.go │ ├── internal │ │ └── session │ │ │ ├── manager.go │ │ │ ├── manager_test.go │ │ │ └── session.go │ ├── connect.go │ ├── options.go │ └── stdio.go ├── auth │ ├── auth_test.go │ └── auth.go ├── oauth │ ├── keyring.go │ ├── oauth.go │ ├── data.go │ └── dance.go ├── scan │ ├── sbom.go │ ├── scan.go │ └── sbom_test.go └── metrics │ └── manager.go ├── .gitignore ├── main.go ├── ROADMAP.md ├── .github └── workflows │ ├── cov.yaml │ └── build.yaml ├── cli ├── internal │ └── cmd │ │ ├── completion.go │ │ ├── root.go │ │ ├── backend.go │ │ ├── flags.go │ │ ├── scan.go │ │ ├── frontend.go │ │ └── aio.go ├── main.go └── init.go ├── Makefile ├── .goreleaser.yml ├── README.md └── go.mod /.covignore: 
--------------------------------------------------------------------------------
1 | cli/*
2 | 
--------------------------------------------------------------------------------
/examples/policer-http/requirements.txt:
--------------------------------------------------------------------------------
1 | PyJWT
2 | 
--------------------------------------------------------------------------------
/assets/imgs/otel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acuvity/minibridge/HEAD/assets/imgs/otel.png
--------------------------------------------------------------------------------
/assets/imgs/mb-arch-overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acuvity/minibridge/HEAD/assets/imgs/mb-arch-overview.png
--------------------------------------------------------------------------------
/pkgs/policer/api/error.go:
--------------------------------------------------------------------------------
1 | package api
2 | 
3 | import "errors"
4 | 
5 | // ErrBlocked is the sentinel error reporting that a request was blocked.
6 | // Compare against it with errors.Is.
7 | var ErrBlocked = errors.New("request blocked")
8 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | minibridge
2 | dist
3 | *.exe
4 | *.pem
5 | *.token
6 | *.sbom
7 | unit_coverage.out
8 | go.work
9 | go.work.sum
10 | 
--------------------------------------------------------------------------------
/examples/policer-rego/simple.rego:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import rego.v1
4 | 
5 | # This is the most basic policy possible: it allows everything by default.
6 | default allow := true
7 | 
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"context"
5 | 
6 | 	"go.acuvity.ai/minibridge/cli"
7 | )
8 | 
9 | func main() {
10 | 
11 | 	ctx, cancel := context.WithCancel(context.Background())
12 | 	defer cancel()
13 | 
14 | 	cli.Main(ctx)
15 | }
16 | 
--------------------------------------------------------------------------------
/pkgs/backend/client/cmd_others.go:
--------------------------------------------------------------------------------
1 | //go:build !darwin && !linux
2 | 
3 | package client
4 | 
5 | import (
6 | 	"os/exec"
7 | 	"syscall"
8 | )
9 | 
10 | // setCaps is the fallback for platforms other than darwin and linux.
11 | // The chroot and creds arguments are ignored and cmd only gets an empty
12 | // SysProcAttr. NOTE(review): callers asking for a chroot or credentials on
13 | // these platforms silently get no sandboxing.
14 | func setCaps(cmd *exec.Cmd, _ string, _ *creds) {
15 | 	cmd.SysProcAttr = &syscall.SysProcAttr{}
16 | }
17 | 
--------------------------------------------------------------------------------
/pkgs/internal/sanitize/data.go:
--------------------------------------------------------------------------------
1 | package sanitize
2 | 
3 | import "bytes"
4 | 
5 | // Data sanitizes the data for internal
6 | // transport. It removes all trailing '\n' and '\r'
7 | // characters. The returned slice shares the backing
8 | // array of data; it is not a copy.
9 | func Data(data []byte) []byte {
10 | 
11 | 	return bytes.TrimRight(data, "\n\r")
12 | }
13 | 
--------------------------------------------------------------------------------
/pkgs/info/info.go:
--------------------------------------------------------------------------------
1 | package info
2 | 
3 | // Info describes a backend. The frontend fetches it from the
4 | // backend's /_info endpoint and decodes it from JSON.
5 | // The OAuth* flags presumably report which OAuth endpoints the
6 | // backend supports (TODO confirm against the backend).
7 | type Info struct {
8 | 	OAuthAuthorize bool   `json:"oauthAuthorize"`
9 | 	OAuthRegister  bool   `json:"oauthRegister"`
10 | 	OAuthToken     bool   `json:"oauthToken"`
11 | 	OAuthMetadata  bool   `json:"oauthMetadata"`
12 | 	Type           string `json:"type"`
13 | 	Server         string `json:"server"`
14 | }
15 | 
--------------------------------------------------------------------------------
/pkgs/backend/interface.go:
--------------------------------------------------------------------------------
1 | package backend
2 | 
3 | import (
4 | 	"context"
5 | )
6 | 
7 | // A Backend is the interface of objects that can
8 | // act as a minibridge Backend.
9 | type Backend interface {
10 | 
11 | 	// Start starts the backend. It will run in background
12 | 	// until the given context is done.
13 | 	Start(context.Context) error
14 | }
15 | 
--------------------------------------------------------------------------------
/pkgs/memconn/addr.go:
--------------------------------------------------------------------------------
1 | package memconn
2 | 
3 | import "net"
4 | 
5 | // Addr is the address of a memlistener.
6 | type Addr struct {
7 | 	name string
8 | }
9 | 
10 | var _ net.Addr = (*Addr)(nil)
11 | 
12 | // Network implements net.Addr. It always returns "memory".
13 | func (Addr) Network() string { return "memory" }
14 | 
15 | // String implements net.Addr. It returns the name of the address.
16 | func (a Addr) String() string { return a.name }
17 | 
--------------------------------------------------------------------------------
/pkgs/mcp/notification.go:
--------------------------------------------------------------------------------
1 | package mcp
2 | 
3 | // A Notification is a JSON-RPC message carrying a method and optional
4 | // params. It has no ID field: JSON-RPC 2.0 notifications are not
5 | // answered.
6 | type Notification struct {
7 | 	JSONRPC string         `json:"jsonrpc"`
8 | 	Method  string         `json:"method,omitempty"`
9 | 	Params  map[string]any `json:"params,omitempty"`
10 | }
11 | 
12 | // NewNotification returns a JSON-RPC 2.0 notification for the
13 | // given method name, without params.
14 | func NewNotification(name string) Notification {
15 | 	return Notification{
16 | 		JSONRPC: "2.0",
17 | 		Method:  name,
18 | 	}
19 | }
20 | 
--------------------------------------------------------------------------------
/pkgs/mcp/error.go:
--------------------------------------------------------------------------------
1 | package mcp
2 | 
3 | // An Error represents an inline MCP error.
4 | type Error struct {
5 | 	Code    int    `json:"code"`
6 | 	Message string `json:"message,omitempty"`
7 | 	Data    any    `json:"data,omitempty"`
8 | }
9 | 
10 | // NewError returns an *Error with code 500 whose message
11 | // is err.Error(). err must not be nil.
12 | func NewError(err error) *Error {
13 | 	return &Error{
14 | 		Code:    500,
15 | 		Message: err.Error(),
16 | 	}
17 | }
18 | 
--------------------------------------------------------------------------------
/pkgs/frontend/hash.go:
--------------------------------------------------------------------------------
1 | package frontend
2 | 
3 | import (
4 | 	"fmt"
5 | 
6 | 	"github.com/spaolacci/murmur3"
7 | 	"go.acuvity.ai/minibridge/pkgs/auth"
8 | )
9 | 
10 | // hash returns the 64-bit murmur3 hash of v with the top bit
11 | // cleared, so the value also fits in a non-negative int64.
12 | func hash(v string) uint64 {
13 | 	return murmur3.Sum64([]byte(v)) & 0x7FFFFFFFFFFFFFFF // #nosec G115
14 | }
15 | 
16 | // hashCreds hashes the encoded auth (when non-nil) together
17 | // with the given auth headers.
18 | func hashCreds(auth *auth.Auth, authHeaders []string) uint64 {
19 | 	if auth != nil {
20 | 		return hash(fmt.Sprintf("%s-%v", auth.Encode(), authHeaders))
21 | 	}
22 | 	return hash(fmt.Sprintf("%v", authHeaders))
23 | }
24 | 
--------------------------------------------------------------------------------
/pkgs/backend/client/cmd_darwin.go:
--------------------------------------------------------------------------------
1 | //go:build darwin
2 | 
3 | package client
4 | 
5 | import (
6 | 	"os/exec"
7 | 	"syscall"
8 | )
9 | 
10 | // setCaps configures cmd to run inside chroot and, when creds is
11 | // non-nil, with the given uid, gid and supplementary groups.
12 | // An empty chroot leaves the root directory unchanged.
13 | func setCaps(cmd *exec.Cmd, chroot string, creds *creds) {
14 | 
15 | 	var screds *syscall.Credential
16 | 	if creds != nil {
17 | 		screds = &syscall.Credential{
18 | 			Uid:    creds.Uid,
19 | 			Gid:    creds.Gid,
20 | 			Groups: creds.Groups,
21 | 		}
22 | 	}
23 | 
24 | 	cmd.SysProcAttr = &syscall.SysProcAttr{
25 | 		Chroot:     chroot,
26 | 		Credential: screds,
27 | 	}
28 | }
29 | 
--------------------------------------------------------------------------------
/pkgs/backend/client/cmd_linux.go:
--------------------------------------------------------------------------------
1 | //go:build linux
2 | 
3 | package client
4 | 
5 | import (
6 | 	"os/exec"
7 | 	"syscall"
8 | )
9 | 
10 | // setCaps configures cmd to run inside chroot and, when creds is
11 | // non-nil, with the given uid, gid and supplementary groups.
12 | // An empty chroot leaves the root directory unchanged.
13 | //
14 | // Pdeathsig asks the kernel to SIGKILL the child when its parent dies.
15 | // On Linux the "parent" is the OS thread that started the child, not
16 | // the whole process.
17 | func setCaps(cmd *exec.Cmd, chroot string, creds *creds) {
18 | 
19 | 	var screds *syscall.Credential
20 | 	if creds != nil {
21 | 		screds = &syscall.Credential{
22 | 			Uid:    creds.Uid,
23 | 			Gid:    creds.Gid,
24 | 			Groups: creds.Groups,
25 | 		}
26 | 	}
27 | 
28 | 	cmd.SysProcAttr = &syscall.SysProcAttr{
29 | 		Pdeathsig:  syscall.SIGKILL,
30 | 		Chroot:     chroot,
31 | 		Credential: screds,
32 | 	}
33 | }
34 | 
--------------------------------------------------------------------------------
/examples/gen-test-tokens.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | # `command -v` is the POSIX way to look a command up; `which` is not
4 | # guaranteed to exist or to be silent on failure.
5 | if ! command -v jwt >/dev/null 2>&1; then
6 | 	echo "you must install jwt tool from https://github.com/mike-engel/jwt-cli" >&2
7 | 	exit 1
8 | fi
9 | 
10 | echo "token for Alice"
11 | jwt encode --secret secret --iss pki.example.com -P email=alice@example.com --aud minibridge
12 | 
13 | echo
14 | echo "token for Bob"
15 | jwt encode --secret secret --iss pki.example.com -P email=bob@example.com --aud minibridge
16 | 
17 | echo
18 | # Eve's token is signed with "not-secret", presumably to exercise signature failures.
19 | echo "Token for Eve"
20 | jwt encode --secret "not-secret" --iss pki.example.com -P email=eve@example.com --aud minibridge
21 | 
--------------------------------------------------------------------------------
/examples/policer-rego/basic-auth.rego:
--------------------------------------------------------------------------------
1 | # This policy checks the password the client sends with HTTP Basic
2 | # authentication (input.agent.password) against the BASIC_AUTH runtime env var.
3 | # When BASIC_AUTH is set and non-empty, requests whose password is missing or
4 | # different are denied with the reason "access denied". When BASIC_AUTH is
5 | # unset or empty, every request is allowed.
6 | # To set BASIC_AUTH run: export REGO_POLICY_RUNTIME_BASIC_AUTH="s3cr3t"
7 | package main
8 | 
9 | import rego.v1
10 | 
11 | secret := opa.runtime().env.BASIC_AUTH
12 | 
13 | # `not password_matches` (rather than `password != secret`) makes a request
14 | # without any password count as a mismatch: comparing an undefined value
15 | # leaves the rule undefined, which would silently allow the request.
16 | password_matches if input.agent.password == secret
17 | 
18 | reasons contains "access denied" if {
19 | 	secret != ""
20 | 	not password_matches
21 | }
22 | 
23 | default allow := false
24 | 
25 | allow if {
26 | 	count(reasons) == 0
27 | }
28 | 
--------------------------------------------------------------------------------
/pkgs/backend/client/interface.go:
--------------------------------------------------------------------------------
1 | package client
2 | 
3 | import (
4 | 	"context"
5 | 	"net/http"
6 | 
7 | 	"go.acuvity.ai/minibridge/pkgs/auth"
8 | )
9 | 
10 | // cfg holds the settings applied by Options.
11 | type cfg struct {
12 | 	auth *auth.Auth
13 | }
14 | 
15 | // An Option configures a call to Client.Start.
16 | type Option func(*cfg)
17 | 
18 | // OptionAuth sets the auth the client should use.
19 | func OptionAuth(a *auth.Auth) Option {
20 | 	return func(c *cfg) {
21 | 		c.auth = a
22 | 	}
23 | }
24 | 
25 | // A Client is the interface of objects that can
26 | // act as a minibridge MCP Client.
27 | type Client interface {
28 | 	Start(context.Context, ...Option) (*MCPStream, error)
29 | 	Type() string
30 | 	Server() string
31 | }
32 | 
33 | // A RemoteClient is a Client that also exposes an HTTP client
34 | // and a base URL, presumably used to reach a remote server.
35 | type RemoteClient interface {
36 | 	HTTPClient() *http.Client
37 | 	BaseURL() string
38 | 	Client
39 | }
40 | 
--------------------------------------------------------------------------------
/pkgs/backend/client/server.go:
--------------------------------------------------------------------------------
1 | package client
2 | 
3 | import (
4 | 	"fmt"
5 | 	"os/exec"
6 | )
7 | 
8 | // An MCPServer contains the information needed
9 | // to launch an MCP Server.
10 | type MCPServer struct {
11 | 	Command string
12 | 	Args    []string
13 | 	Env     []string
14 | }
15 | 
16 | // NewMCPServer returns a new MCPServer. path is resolved with
17 | // exec.LookPath, so a bare name is searched in $PATH. It returns
18 | // an error if the given command cannot be found.
19 | func NewMCPServer(path string, args ...string) (MCPServer, error) {
20 | 	p, err := exec.LookPath(path)
21 | 	if err != nil {
22 | 		return MCPServer{}, fmt.Errorf("unable to find server binary: %w", err)
23 | 	}
24 | 	return MCPServer{
25 | 		Command: p,
26 | 		Args:    args,
27 | 	}, nil
28 | }
29 | 
--------------------------------------------------------------------------------
/ROADMAP.md:
--------------------------------------------------------------------------------
1 | # Roadmap
2 | 
3 | ## In Progress
4 | 
5 | - [ ] Unit tests
6 | - [ ] Wiki with tutorials on specific parts
7 | - [ ] Support for MCP spec 2025-03-26 (PR #21)
8 | 
9 | ## Todo
10 | 
11 | - [ ] Advanced sandboxing when not running in containers (firejail/bubblewrap)
12 | - [ ] Support for shared MCP server (when using a Policer)
13 | - [ ] Add mTLS policer
14 | - [ ] Add A3S Policer
15 | - [ ] Add DScope Policer
16 | 
17 | ## Done
18 | 
19 | - [x] Transport user information over the websocket channel
20 | - [x] Support for user extraction to pass to the policer
21 | - [x] Plug in Prometheus metrics
22 | - [x] OpenTelemetry
23 | - [x] Optimize communications between front/back in aio mode (use memconn)
24 | 
--------------------------------------------------------------------------------
/pkgs/backend/client/server_test.go:
--------------------------------------------------------------------------------
1 | package client
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	. "github.com/smartystreets/goconvey/convey"
7 | )
8 | 
9 | func TestMCPServer(t *testing.T) {
10 | 
11 | 	Convey("calling NewMCPServer on existing bin should work", t, func() {
12 | 		srv, err := NewMCPServer("echo", "hello")
13 | 		So(err, ShouldBeNil)
14 | 		So(srv.Command, ShouldEndWith, "/bin/echo")
15 | 		So(srv.Args, ShouldResemble, []string{"hello"})
16 | 	})
17 | 
18 | 	Convey("calling NewMCPServer on non-existing bin should fail", t, func() {
19 | 		_, err := NewMCPServer("not-echo", "hello")
20 | 		So(err, ShouldNotBeNil)
21 | 		So(err.Error(), ShouldEqual, `unable to find server binary: exec: "not-echo": executable file not found in $PATH`)
22 | 	})
23 | }
24 | 
--------------------------------------------------------------------------------
/pkgs/auth/auth_test.go:
--------------------------------------------------------------------------------
1 | package auth
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	. "github.com/smartystreets/goconvey/convey"
7 | )
8 | 
9 | func TestAuth(t *testing.T) {
10 | 
11 | 	Convey("Basic auth should work", t, func() {
12 | 		auth := NewBasicAuth("user", "pass")
13 | 		So(auth.Type(), ShouldEqual, "Basic")
14 | 		So(auth.User(), ShouldEqual, "user")
15 | 		So(auth.Password(), ShouldEqual, "pass")
16 | 		So(auth.Encode(), ShouldEqual, "Basic dXNlcjpwYXNz")
17 | 	})
18 | 
19 | 	Convey("Bearer auth should work", t, func() {
20 | 		auth := NewBearerAuth("token")
21 | 		So(auth.Type(), ShouldEqual, "Bearer")
22 | 		So(auth.User(), ShouldEqual, "Bearer")
23 | 		So(auth.Password(), ShouldEqual, "token")
24 | 		So(auth.Encode(), ShouldEqual, "Bearer token")
25 | 	})
26 | }
27 | 
--------------------------------------------------------------------------------
/pkgs/frontend/interface.go:
--------------------------------------------------------------------------------
1 | package frontend
2 | 
3 | import (
4 | 	"context"
5 | 	"net/http"
6 | 
7 | 	"go.acuvity.ai/minibridge/pkgs/auth"
8 | 	"go.acuvity.ai/minibridge/pkgs/info"
9 | )
10 | 
11 | // A Frontend is the interface of objects that can
12 | // act as a minibridge Frontend.
13 | type Frontend interface {
14 | 
15 | 	// Start starts the frontend. It will run in background until
16 | 	// the given context is done.
17 | 	Start(context.Context, *auth.Auth) error
18 | 
19 | 	// BackendURL returns the backend URL the frontend connects to.
20 | 	BackendURL() string
21 | 
22 | 	// HTTPClient returns a client that can be used to communicate
23 | 	// with the backend.
24 | 	HTTPClient() *http.Client
25 | 
26 | 	// BackendInfo queries the backend for information.
27 | 	BackendInfo() (info.Info, error)
28 | }
29 | 
--------------------------------------------------------------------------------
/.github/workflows/cov.yaml:
--------------------------------------------------------------------------------
1 | name: cov
2 | 
3 | on:
4 |   workflow_run:
5 |     workflows: ["build-go"]
6 |     types:
7 |       - completed
8 | 
9 | env:
10 |   GO111MODULE: on
11 |   GOPRIVATE: github.com/acuvity,go.acuvity.ai
12 |   GOPROXY: https://proxy.golang.org,direct
13 |   GOTOKEN: ${{ secrets.GO_PRIVATE_REPO_PAT }}
14 | 
15 | jobs:
16 |   cov:
17 |     runs-on: ubuntu-latest
18 |     permissions: write-all
19 |     steps:
20 |       - name: setup
21 |         run: |
22 |           git config --global url."https://acuvity:${GOTOKEN}@github.com/acuvity".insteadOf "https://github.com/acuvity"
23 |       - uses: acuvity/cov@1.0.2
24 |         with:
25 |           cov_mode: send-status
26 |           workflow_run_id: ${{github.event.workflow_run.id}}
27 |           workflow_head_sha: ${{github.event.workflow_run.head_sha}}
28 | 
--------------------------------------------------------------------------------
/pkgs/policer/interface.go:
--------------------------------------------------------------------------------
1 | package policer
2 | 
3 | import (
4 | 	"context"
5 | 	"crypto/tls"
6 | 
7 | 	"go.acuvity.ai/minibridge/pkgs/auth"
8 | 	"go.acuvity.ai/minibridge/pkgs/mcp"
9 | 	"go.acuvity.ai/minibridge/pkgs/policer/api"
10 | 	"go.acuvity.ai/minibridge/pkgs/policer/internal/http"
11 | 	"go.acuvity.ai/minibridge/pkgs/policer/internal/rego"
12 | )
13 | 
14 | // A Policer is the interface of objects that can police requests.
15 | type Policer interface {
16 | 	// Police evaluates the given request. A non-nil returned message
17 | 	// presumably replaces the original MCP call (compare api.Response.MCP);
18 | 	// TODO confirm against the implementations.
19 | 	Police(context.Context, api.Request) (*mcp.Message, error)
20 | 
21 | 	// Type returns the type of the policer.
22 | 	Type() string
23 | }
24 | 
25 | // NewRego returns a new rego based Policer built from the given policy.
26 | func NewRego(policy string) (Policer, error) {
27 | 	p, err := rego.New(policy)
28 | 	if err != nil {
29 | 		// Return a literal nil so callers never get a non-nil Policer
30 | 		// wrapping a nil implementation (the typed-nil interface trap).
31 | 		return nil, err
32 | 	}
33 | 	return p, nil
34 | }
35 | 
36 | // NewHTTP returns a new HTTP based Policer targeting endpoint,
37 | // configured with the given auth and TLS configuration.
38 | func NewHTTP(endpoint string, auth *auth.Auth, tlsConfig *tls.Config) Policer {
39 | 	return http.New(endpoint, auth, tlsConfig)
40 | }
41 | 
--------------------------------------------------------------------------------
/pkgs/policer/api/response.go:
--------------------------------------------------------------------------------
1 | package api
2 | 
3 | import "go.acuvity.ai/minibridge/pkgs/mcp"
4 | 
5 | // GenericDenyReason is the generic reason returned if none is provided.
6 | const GenericDenyReason = "You are not allowed to perform this operation"
7 | 
8 | // A Response is returned by the Policer.
9 | type Response struct {
10 | 
11 | 	// If true, the request is allowed. Otherwise
12 | 	// it will be rejected for the reasons set in
13 | 	// Reasons.
14 | 	Allow bool `json:"allow"`
15 | 
16 | 	// Reasons contains reasons for denying the request.
17 | 	// If no reason is given and Allow is false,
18 | 	// GenericDenyReason will be used.
19 | 	Reasons []string `json:"reasons,omitempty"`
20 | 
21 | 	// If non-nil, replace the request MCP call with
22 | 	// this one. This allows Policers to modify the content
23 | 	// of an MCP call.
24 | 	MCP *mcp.Message `json:"mcp,omitempty"`
25 | }
26 | 
--------------------------------------------------------------------------------
/cli/internal/cmd/completion.go:
--------------------------------------------------------------------------------
1 | package cmd
2 | 
3 | import (
4 | 	"os"
5 | 
6 | 	"github.com/spf13/cobra"
7 | )
8 | 
9 | // Completion is the command that writes a completion script for
10 | // the requested shell to stdout.
11 | var Completion = &cobra.Command{
12 | 	Use:                   "completion [bash|zsh|fish|powershell]",
13 | 	Short:                 "Generate completion script",
14 | 	DisableFlagsInUseLine: true,
15 | 	ValidArgs:             []string{"bash", "zsh", "fish", "powershell"},
16 | 	Args:                  cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs),
17 | 	RunE: func(cmd *cobra.Command, args []string) error {
18 | 		switch args[0] {
19 | 		case "bash":
20 | 			return cmd.Root().GenBashCompletion(os.Stdout)
21 | 		case "zsh":
22 | 			return cmd.Root().GenZshCompletion(os.Stdout)
23 | 		case "fish":
24 | 			return cmd.Root().GenFishCompletion(os.Stdout, true)
25 | 		case "powershell":
26 | 			return cmd.Root().GenPowerShellCompletionWithDesc(os.Stdout)
27 | 		}
28 | 
29 | 		// Not reached in practice: Args only accepts the shells above.
30 | 		return nil
31 | 	},
32 | }
33 | 
--------------------------------------------------------------------------------
/pkgs/internal/cors/cors.go:
--------------------------------------------------------------------------------
1 | package cors
2 | 
3 | import (
4 | 	"net/http"
5 | 
6 | 	"go.acuvity.ai/bahamut"
7 | )
8 | 
9 | // HandleCORS sets standard security headers on w and, when corsPolicy
10 | // is non-nil, applies it for the request's Origin header.
11 | //
12 | // It returns false when the request has been fully answered (any OPTIONS
13 | // request is treated as a preflight and gets 204 No Content), in which
14 | // case the caller must stop; it returns true when the caller should go on.
15 | func HandleCORS(w http.ResponseWriter, req *http.Request, corsPolicy *bahamut.CORSPolicy) bool {
16 | 
17 | 	w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
18 | 	w.Header().Set("X-Frame-Options", "DENY")
19 | 	w.Header().Set("X-Content-Type-Options", "nosniff")
20 | 	w.Header().Set("X-Xss-Protection", "1; mode=block")
21 | 	w.Header().Set("Cache-Control", "private, no-transform")
22 | 
23 | 	if corsPolicy == nil {
24 | 		return true
25 | 	}
26 | 
27 | 	// NOTE(review): this removes an Origin *response* header already set
28 | 	// on w (e.g. copied from an upstream); the request is left untouched.
29 | 	w.Header().Del("Origin")
30 | 
31 | 	origin := req.Header.Get("Origin")
32 | 
33 | 	if req.Method == http.MethodOptions {
34 | 		corsPolicy.Inject(w.Header(), origin, true)
35 | 		w.WriteHeader(http.StatusNoContent)
36 | 		return false
37 | 	}
38 | 
39 | 	corsPolicy.Inject(w.Header(), origin, false)
40 | 
41 | 	return true
42 | }
43 | 
--------------------------------------------------------------------------------
/pkgs/frontend/info.go:
--------------------------------------------------------------------------------
1 | package frontend
2 | 
3 | import (
4 | 	"fmt"
5 | 	"io"
6 | 	"net/http"
7 | 
8 | 	"go.acuvity.ai/elemental"
9 | 	"go.acuvity.ai/minibridge/pkgs/info"
10 | )
11 | 
12 | // getBackendInfo GETs <BackendURL>/_info with the frontend's HTTP client
13 | // and decodes the JSON body into an info.Info. On any error the returned
14 | // Info is the zero value.
15 | func getBackendInfo(mfrontend Frontend) (info.Info, error) {
16 | 
17 | 	cl := mfrontend.HTTPClient()
18 | 
19 | 	resp, err := cl.Get(fmt.Sprintf("%s/_info", mfrontend.BackendURL()))
20 | 	if err != nil {
21 | 		return info.Info{}, fmt.Errorf("unable to make backend info request: %w", err)
22 | 	}
23 | 	defer func() { _ = resp.Body.Close() }()
24 | 
25 | 	if resp.StatusCode != http.StatusOK {
26 | 		return info.Info{}, fmt.Errorf("invalid backend info response status: %s", resp.Status)
27 | 	}
28 | 
29 | 	data, err := io.ReadAll(resp.Body)
30 | 	if err != nil {
31 | 		return info.Info{}, fmt.Errorf("unable to read backend info response body: %w", err)
32 | 	}
33 | 
34 | 	inf := info.Info{}
35 | 	if err := elemental.Decode(elemental.EncodingTypeJSON, data, &inf); err != nil {
36 | 		// Don't hand a partially decoded value back to the caller.
37 | 		return info.Info{}, fmt.Errorf("unable to decode backend info response body: %w", err)
38 | 	}
39 | 
40 | 	return inf, nil
41 | }
42 | 
--------------------------------------------------------------------------------
/pkgs/oauth/keyring.go:
--------------------------------------------------------------------------------
1 | package oauth
2 | 
3 | import (
4 | 	"fmt"
5 | 
6 | 	"github.com/zalando/go-keyring"
7 | 	"go.acuvity.ai/elemental"
8 | )
9 | 
10 | // minibridgeKeyringService is the keyring service name under which
11 | // credentials are stored, one entry per server.
12 | const minibridgeKeyringService = "minibridge"
13 | 
14 | // TokenToKeyring JSON-encodes t and stores it in the system keyring
15 | // under the given server.
16 | func TokenToKeyring(server string, t Credentials) error {
17 | 
18 | 	data, err := elemental.Encode(elemental.EncodingTypeJSON, t)
19 | 	if err != nil {
20 | 		return fmt.Errorf("unable to encode token data: %w", err)
21 | 	}
22 | 
23 | 	if err := keyring.Set(minibridgeKeyringService, server, string(data)); err != nil {
24 | 		return fmt.Errorf("unable to store token into keyring: %w", err)
25 | 	}
26 | 
27 | 	return nil
28 | }
29 | 
30 | // TokenFromKeyring loads and decodes the credentials stored by
31 | // TokenToKeyring for the given server.
32 | func TokenFromKeyring(server string) (Credentials, error) {
33 | 
34 | 	data, err := keyring.Get(minibridgeKeyringService, server)
35 | 	if err != nil {
36 | 		return Credentials{}, fmt.Errorf("unable to retrieve token from keyring: %w", err)
37 | 	}
38 | 
39 | 	t := Credentials{}
40 | 	if err := elemental.Decode(elemental.EncodingTypeJSON, []byte(data), &t); err != nil {
41 | 		return Credentials{}, fmt.Errorf("unable to decode token data from keychain: %w", err)
42 | 	}
43 | 
44 | 	return t, nil
45 | }
46 | 
--------------------------------------------------------------------------------
/pkgs/backend/carrier.go:
--------------------------------------------------------------------------------
1 | package backend
2 | 
3 | import (
4 | 	"go.acuvity.ai/minibridge/pkgs/mcp"
5 | 	"go.opentelemetry.io/otel/propagation"
6 | )
7 | 
8 | var _ propagation.TextMapCarrier = metaCarrier{}
9 | 
10 | // metaCarrier adapts the "_meta" object of an MCP call's params to an
11 | // OpenTelemetry TextMapCarrier. Only string values are kept, and Set
12 | // writes to the carrier's own map, not back into the originating message.
13 | type metaCarrier struct {
14 | 	meta map[string]string
15 | }
16 | 
17 | // newMCPMetaCarrier copies the string entries of call.Params["_meta"] into
18 | // a new carrier. A missing or non-object "_meta" yields an empty carrier.
19 | func newMCPMetaCarrier(call mcp.Message) metaCarrier {
20 | 
21 | 	meta := map[string]string{}
22 | 
23 | 	// Indexing a nil Params map is safe and simply finds nothing.
24 | 	if pmeta, ok := call.Params["_meta"].(map[string]any); ok {
25 | 		for k, v := range pmeta {
26 | 			if s, ok := v.(string); ok {
27 | 				meta[k] = s
28 | 			}
29 | 		}
30 | 	}
31 | 
32 | 	return metaCarrier{
33 | 		meta: meta,
34 | 	}
35 | }
36 | 
37 | // Get returns the value stored for key, or "" if there is none.
38 | func (c metaCarrier) Get(key string) string {
39 | 	return c.meta[key]
40 | }
41 | 
42 | // Set stores value under key in the carrier.
43 | func (c metaCarrier) Set(key string, value string) {
44 | 	c.meta[key] = value
45 | }
46 | 
47 | // Keys returns the carrier's keys in unspecified order.
48 | func (c metaCarrier) Keys() []string {
49 | 
50 | 	out := make([]string, 0, len(c.meta))
51 | 	for k := range c.meta {
52 | 		out = append(out, k)
53 | 	}
54 | 
55 | 	return out
56 | }
57 | 
--------------------------------------------------------------------------------
/cli/main.go:
--------------------------------------------------------------------------------
1 | package cli
2 | 
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"log/slog"
7 | 	"os"
8 | 	"os/signal"
9 | 	"syscall"
10 | 
11 | 	"github.com/spf13/cobra"
12 | 	"go.acuvity.ai/minibridge/cli/internal/cmd"
13 | )
14 | 
15 | // Main is the main run for the cli. It executes the root command with a
16 | // context cancelled on SIGINT, SIGTERM, SIGHUP or SIGABRT, and exits the
17 | // process with status 1 if the command returns an error.
18 | func Main(ctx context.Context) {
19 | 
20 | 	cobra.OnInitialize(initCobra)
21 | 
22 | 	ctx, cancel := context.WithCancel(ctx)
23 | 	defer cancel()
24 | 
25 | 	installSIGINTHandler(cancel)
26 | 
27 | 	if err := cmd.Root.ExecuteContext(ctx); err != nil {
28 | 		if _, ok := slog.Default().Handler().(*slog.JSONHandler); ok {
29 | 			slog.Error("Minibridge exited with error", "err", err)
30 | 		} else {
31 | 			fmt.Fprintf(os.Stderr, "error: %s\n", err.Error())
32 | 		}
33 | 		os.Exit(1)
34 | 	}
35 | }
36 | 
37 | // installSIGINTHandler calls cancel when the process receives one of the
38 | // handled termination signals. SIGKILL is deliberately absent: it cannot
39 | // be caught, so registering it has no effect.
40 | //
41 | // After the first signal the handler is removed, so a second signal gets
42 | // the default behavior and terminates the process even if the graceful
43 | // shutdown triggered by cancel is stuck.
44 | func installSIGINTHandler(cancel context.CancelFunc) {
45 | 
46 | 	sigs := []os.Signal{syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGABRT}
47 | 	signalCh := make(chan os.Signal, 1)
48 | 	signal.Reset(sigs...)
49 | 	signal.Notify(signalCh, sigs...)
50 | 
51 | 	go func() {
52 | 		<-signalCh
53 | 		signal.Stop(signalCh)
54 | 		cancel()
55 | 	}()
56 | }
57 | 
--------------------------------------------------------------------------------
/.github/workflows/build.yaml:
--------------------------------------------------------------------------------
1 | name: build-go
2 | on:
3 |   push:
4 |     branches:
5 |       - main
6 |   pull_request:
7 | 
8 | defaults:
9 |   run:
10 |     shell: bash
11 | 
12 | env:
13 |   GOPRIVATE: "github.com/acuvity,go.acuvity.ai"
14 |   GOPROXY: "https://proxy.golang.org,direct"
15 |   GOTOKEN: ${{ secrets.GO_PRIVATE_REPO_PAT }}
16 | 
17 | 
18 | jobs:
19 |   build:
20 |     runs-on: ubuntu-latest
21 |     strategy:
22 |       fail-fast: false
23 |       matrix:
24 |         go:
25 |           - "1.24"
26 |     steps:
27 |       - uses: actions/checkout@v4
28 | 
29 |       - uses: actions/setup-go@v5
30 |         with:
31 |           go-version: ${{ matrix.go }}
32 |           cache: true
33 | 
34 |       - name: setup
35 |         run: |
36 |           go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
37 |           go install golang.org/x/vuln/cmd/govulncheck@latest
38 |           go install github.com/securego/gosec/v2/cmd/gosec@master
39 | 
40 |       - name: build
41 |         run: |
42 |           git config --global url."https://acuvity:${GOTOKEN}@github.com/acuvity".insteadOf "https://github.com/acuvity"
43 |           make
44 | 
45 |       - uses: acuvity/cov@1.0.2
46 |         with:
47 |           main_branch: main
48 |           cov_file: unit_coverage.out
49 |           cov_threshold: "0"
50 |           cov_mode: coverage
51 | 
--------------------------------------------------------------------------------
/pkgs/oauth/oauth.go:
--------------------------------------------------------------------------------
1 | package oauth
2 | 
3 | import (
4 | 	"fmt"
5 | 	"io"
6 | 	"log/slog"
7 | 	"net/http"
8 | 	"strings"
9 | )
10 | 
11 | // Forward proxies req to path under baseURL using client, then copies the
12 | // upstream status, headers and body to w. Upstream "Origin" and
13 | // "Access-Control-*" headers are dropped (presumably so the local CORS
14 | // handling stays authoritative). On failure it writes a 500 error to w.
15 | //
16 | // It returns a function the caller must call once done; it closes the
17 | // upstream response body when there is one.
18 | func Forward(baseURL string, client *http.Client, w http.ResponseWriter, req *http.Request, path string) func() {
19 | 
20 | 	u := fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), strings.TrimLeft(path, "/"))
21 | 
22 | 	slog.Debug("OAuth: Forwarding OAuth call", "method", req.Method, "target", u)
23 | 
24 | 	breq, err := http.NewRequestWithContext(req.Context(), req.Method, u, req.Body)
25 | 	if err != nil {
26 | 		http.Error(w, fmt.Sprintf("unable to make request: %s", err), http.StatusInternalServerError)
27 | 		return func() {}
28 | 	}
29 | 
30 | 	// Forward the query string verbatim: re-encoding req.URL.Query()
31 | 	// would reorder parameters and silently drop malformed pairs.
32 | 	breq.URL.RawQuery = req.URL.RawQuery
33 | 	breq.Header = req.Header.Clone()
34 | 
35 | 	resp, err := client.Do(breq) // nolint: bodyclose
36 | 	if err != nil {
37 | 		http.Error(w, err.Error(), http.StatusInternalServerError)
38 | 		return func() {}
39 | 	}
40 | 
41 | 	for k, vs := range resp.Header {
42 | 
43 | 		// Keys are canonicalized by net/http, so exact matching is enough.
44 | 		if k == "Origin" || strings.HasPrefix(k, "Access-Control-") {
45 | 			continue
46 | 		}
47 | 
48 | 		for _, v := range vs {
49 | 			w.Header().Add(k, v)
50 | 		}
51 | 	}
52 | 
53 | 	w.WriteHeader(resp.StatusCode)
54 | 
55 | 	if resp.Body != nil {
56 | 		_, _ = io.Copy(w, resp.Body)
57 | 		return func() { _ = resp.Body.Close() }
58 | 	}
59 | 
60 | 	return func() {}
61 | }
62 | 
--------------------------------------------------------------------------------
/pkgs/mcp/types.go:
--------------------------------------------------------------------------------
1 | package mcp
2 | 
3 | // Tools is a list of Tool.
4 | type Tools []Tool
5 | 
6 | // A Tool describes an MCP tool: its name, an optional description
7 | // and the schema of its input.
8 | type Tool struct {
9 | 	Name        string         `json:"name"`
10 | 	Description string         `json:"description,omitempty"`
11 | 	InputSchema map[string]any `json:"inputSchema,omitempty"`
12 | }
13 | 
14 | // Prompts is a list of *Prompt.
15 | type Prompts []*Prompt
16 | 
17 | // A Prompt describes an MCP prompt and the arguments it accepts.
18 | type Prompt struct {
19 | 	Name        string          `json:"name,omitempty"`
20 | 	Description string          `json:"description,omitempty"`
21 | 	Arguments   PromptArguments `json:"arguments,omitempty"`
22 | }
23 | 
24 | // PromptArguments is a list of PromptArgument.
25 | type PromptArguments []PromptArgument
26 | 
27 | // A PromptArgument describes one argument of a Prompt.
28 | type PromptArgument struct {
29 | 	Name        string `json:"name,omitempty"`
30 | 	Description string `json:"description,omitempty"`
31 | 	Required    bool   `json:"required,omitempty"`
32 | }
33 | 
34 | // Resources is a list of Resource.
35 | type Resources []Resource
36 | 
37 | // A Resource describes an MCP resource. Content may be carried in Text
38 | // or in Blob (presumably base64-encoded, per the MCP spec).
39 | type Resource struct {
40 | 	Name        string `json:"name,omitempty"`
41 | 	Description string `json:"description,omitempty"`
42 | 	MimeType    string `json:"mimeType,omitempty"`
43 | 	Text        string `json:"text,omitempty"`
44 | 	Blob        string `json:"blob,omitempty"`
45 | 	URI         string `json:"uri,omitempty"`
46 | }
47 | 
48 | // ResourceTemplates is a list of ResourceTemplate.
49 | type ResourceTemplates []ResourceTemplate
50 | 
51 | // A ResourceTemplate describes a parameterized resource URI
52 | // (URITemplate) together with its name and description.
53 | type ResourceTemplate struct {
54 | 	Name        string `json:"name,omitempty"`
55 | 	Description string `json:"description,omitempty"`
56 | 	URITemplate string `json:"uriTemplate,omitempty"`
57 | }
58 | 
--------------------------------------------------------------------------------
/pkgs/policer/internal/rego/utils.go:
--------------------------------------------------------------------------------
1 | package rego
2 | 
3 | import (
4 | 	"fmt"
5 | 	"log/slog"
6 | 
7 | 	"github.com/open-policy-agent/opa/v1/ast"
8 | 	"github.com/open-policy-agent/opa/v1/topdown/print"
9 | )
10 | 
11 | // printer forwards Rego print() output to slog at info level.
12 | type printer struct{}
13 | 
14 | // Print implements print.Printer. row is logged as 0 when the location is unknown.
15 | func (p printer) Print(ctx print.Context, s string) error {
16 | 	row := 0
17 | 	if ctx.Location != nil {
18 | 		row = ctx.Location.Row
19 | 	}
20 | 	slog.Info(fmt.Sprintf("Rego Print: %s", s), "row", row)
21 | 	return nil
22 | }
23 | 
24 | // precompile parses policy as the module "<name>.rego", compiles it together
25 | // with the given extra modules (each registered as "<package>.rego") and
26 | // returns the compiler. Print statements are enabled so policies can print().
27 | func precompile(policy string, name string, modules ...*ast.Module) (*ast.Compiler, error) {
28 | 
29 | 	name = name + ".rego"
30 | 
31 | 	compiler := ast.NewCompiler().WithEnablePrintStatements(true)
32 | 
33 | 	// Parse under the same file name the module is registered with, so
34 | 	// locations in errors match the compiled module's name.
35 | 	module, err := prepareModule(name, policy)
36 | 	if err != nil {
37 | 		return nil, err
38 | 	}
39 | 
40 | 	allModules := map[string]*ast.Module{
41 | 		name: module,
42 | 	}
43 | 	for _, m := range modules {
44 | 		allModules[m.Package.String()+".rego"] = m
45 | 	}
46 | 
47 | 	compiler.Compile(allModules)
48 | 
49 | 	if compiler.Failed() {
50 | 		return nil, fmt.Errorf("unable to compile rego module: %w", compiler.Errors)
51 | 	}
52 | 
53 | 	return compiler, nil
54 | }
55 | 
56 | // prepareModule parses policy as a Rego module named name, using the
57 | // capabilities of the linked OPA version.
58 | func prepareModule(name string, policy string) (*ast.Module, error) {
59 | 
60 | 	caps := ast.CapabilitiesForThisVersion()
61 | 
62 | 	module, err := ast.ParseModuleWithOpts(
63 | 		name,
64 | 		policy,
65 | 		ast.ParserOptions{
66 | 			Capabilities: caps,
67 | 		},
68 | 	)
69 | 	if err != nil {
70 | 		return nil, fmt.Errorf("unable to parse rego module: %w", err)
71 | 	}
72 | 
73 | 	return module, nil
74 | }
75 | 
-------------------------------------------------------------------------------- /cli/internal/cmd/root.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/spf13/cobra" 8 | "github.com/spf13/viper" 9 | "go.acuvity.ai/a3s/pkgs/bootstrap" 10 | "go.acuvity.ai/a3s/pkgs/conf" 11 | "go.acuvity.ai/a3s/pkgs/version" 12 | ) 13 | 14 | func init() { 15 | 16 | initSharedFlagSet() 17 | 18 | Root.PersistentFlags().String("log-level", "info", "sets the log level.") 19 | Root.PersistentFlags().String("log-format", "console", "sets the log format.") 20 | Root.PersistentFlags().Bool("version", false, "print version and exit.") 21 | 22 | Root.AddCommand( 23 | Backend, 24 | Frontend, 25 | AIO, 26 | Completion, 27 | Scan, 28 | ) 29 | } 30 | 31 | var Root = &cobra.Command{ 32 | Use: "minibridge", 33 | Short: "Secure your MCP Servers", 34 | SilenceUsage: true, 35 | SilenceErrors: true, 36 | TraverseChildren: true, 37 | PersistentPreRunE: func(cmd *cobra.Command, args []string) error { 38 | 39 | if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil { 40 | return err 41 | } 42 | 43 | if err := viper.BindPFlags(cmd.Flags()); err != nil { 44 | return err 45 | } 46 | 47 | bootstrap.ConfigureLogger("minibridge", conf.LoggingConf{ 48 | LogLevel: viper.GetString("log-level"), 49 | LogFormat: viper.GetString("log-format"), 50 | }) 51 | 52 | return nil 53 | }, 54 | RunE: func(cmd *cobra.Command, args []string) error { 55 | if viper.GetBool("version") { 56 | fmt.Println(version.Short()) 57 | os.Exit(0) 58 | } 59 | return cmd.Usage() 60 | }, 61 | } 62 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | MAKEFLAGS += --warn-undefined-variables 2 | SHELL := /bin/bash -o pipefail 3 | GIT_SHA=$(shell git rev-parse --short HEAD) 4 | GIT_BRANCH=$(shell git rev-parse 
--abbrev-ref HEAD) 5 | GIT_TAG=$(shell git describe --tags --abbrev=0 --match='v[0-9]*.[0-9]*.[0-9]*' 2> /dev/null | sed 's/^.//') 6 | BUILD_DATE=$(shell date) 7 | VERSION_PKG="go.acuvity.ai/a3s/pkgs/version" 8 | LDFLAGS = -ldflags="-w -s -X '$(VERSION_PKG).GitSha=$(GIT_SHA)' -X '$(VERSION_PKG).GitBranch=$(GIT_BRANCH)' -X '$(VERSION_PKG).GitTag=$(GIT_TAG)' -X '$(VERSION_PKG).BuildDate=$(BUILD_DATE)'" 9 | 10 | export GO111MODULE = on 11 | 12 | default: lint test build vuln sec 13 | 14 | lint: 15 | golangci-lint run \ 16 | --timeout=5m \ 17 | --disable=govet \ 18 | --enable=errcheck \ 19 | --enable=ineffassign \ 20 | --enable=unused \ 21 | --enable=unconvert \ 22 | --enable=misspell \ 23 | --enable=prealloc \ 24 | --enable=nakedret \ 25 | --enable=unparam \ 26 | --enable=nilerr \ 27 | --enable=bodyclose \ 28 | --enable=errorlint \ 29 | ./... 30 | test: 31 | go test ./... -vet off -race -cover -covermode=atomic -coverprofile=unit_coverage.out 32 | 33 | sec: 34 | gosec -quiet ./... 35 | 36 | vuln: 37 | govulncheck ./... 38 | 39 | build: 40 | go build $(LDFLAGS) -trimpath . 41 | 42 | remod: 43 | go get go.acuvity.ai/tg@master 44 | go get go.acuvity.ai/wsc@master 45 | go get go.acuvity.ai/regolithe@master 46 | go get go.acuvity.ai/bahamut@master 47 | go get go.acuvity.ai/elemental@master 48 | go get go.acuvity.ai/manipulate@master 49 | go get go.acuvity.ai/a3s@master 50 | go mod tidy 51 | -------------------------------------------------------------------------------- /pkgs/frontend/internal/session/manager.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import "sync" 4 | 5 | // A Manager manages sessions and keep 6 | // track of them. 7 | type Manager struct { 8 | sessions map[string]*Session 9 | sync.RWMutex 10 | } 11 | 12 | // NewManager returns a new *session.Manager. 
13 | func NewManager() *Manager { 14 | return &Manager{ 15 | sessions: map[string]*Session{}, 16 | } 17 | } 18 | 19 | // Acquire acquires and returns the session with the given sid. 20 | // It returns nil if not session with that sid is found. 21 | // In addition, if ch is not nil, it will be registered as a read hook. 22 | func (p *Manager) Acquire(sid string, ch chan []byte) *Session { 23 | 24 | p.RLock() 25 | defer p.RUnlock() 26 | 27 | s := p.sessions[sid] 28 | if s != nil { 29 | s.acquire() 30 | if ch != nil { 31 | s.register(ch) 32 | } 33 | } 34 | 35 | return s 36 | } 37 | 38 | // Release sessions releases an acquired session. 39 | // if the session is not acquired by anything, the ws connection 40 | // will be closed, and the session deleted. 41 | // In addition, if ch is not nil, it will be unregstered as a read hook. 42 | func (p *Manager) Release(sid string, ch chan []byte) { 43 | 44 | p.Lock() 45 | defer p.Unlock() 46 | 47 | s := p.sessions[sid] 48 | if s == nil { 49 | return 50 | } 51 | 52 | if ch != nil { 53 | s.unregister(ch) 54 | } 55 | 56 | if closed := s.release(); closed { 57 | delete(p.sessions, sid) 58 | } 59 | } 60 | 61 | func (p *Manager) Register(s *Session) { 62 | p.Lock() 63 | defer p.Unlock() 64 | 65 | p.sessions[s.ID()] = s 66 | } 67 | -------------------------------------------------------------------------------- /pkgs/oauth/data.go: -------------------------------------------------------------------------------- 1 | package oauth 2 | 3 | var successBody = ` 4 | 5 | 6 | 7 | 8 | 9 | Authentication Successful 10 | 37 | 38 | 39 |
40 | 44 |

Authentication Successful

45 |

You have successfully authenticated minibridge with %s.

46 |

You can now close this window.

47 |
48 | 49 | 50 | ` 51 | -------------------------------------------------------------------------------- /pkgs/backend/helpers.go: -------------------------------------------------------------------------------- 1 | package backend 2 | 3 | import ( 4 | "encoding/base64" 5 | "log/slog" 6 | "net/http" 7 | "strings" 8 | 9 | "go.acuvity.ai/elemental" 10 | "go.acuvity.ai/minibridge/pkgs/auth" 11 | "go.acuvity.ai/minibridge/pkgs/mcp" 12 | "go.opentelemetry.io/otel/codes" 13 | "go.opentelemetry.io/otel/trace" 14 | ) 15 | 16 | func makeMCPError(ID any, err error) []byte { 17 | 18 | mpcerr := mcp.Message{ 19 | JSONRPC: "2.0", 20 | ID: ID, 21 | Error: &mcp.Error{ 22 | Code: 451, 23 | Message: err.Error(), 24 | }, 25 | } 26 | 27 | data, err := elemental.Encode(elemental.EncodingTypeJSON, mpcerr) 28 | if err != nil { 29 | panic(err) 30 | } 31 | 32 | slog.Debug("Injecting MCP error", "err", string(data)) 33 | 34 | return data 35 | } 36 | 37 | func parseBasicAuth(authString string) (a *auth.Auth, ok bool) { 38 | 39 | const prefix = "Basic " 40 | 41 | if len(authString) < len(prefix) || !strings.EqualFold(authString[:len(prefix)], prefix) { 42 | 43 | parts := strings.SplitN(authString, " ", 2) 44 | if len(parts) == 2 { 45 | a = auth.NewBearerAuth(parts[1]) 46 | return a, true 47 | } 48 | return nil, false 49 | } 50 | 51 | c, err := base64.StdEncoding.DecodeString(authString[len(prefix):]) 52 | if err != nil { 53 | return nil, false 54 | } 55 | 56 | cs := string(c) 57 | 58 | user, password, ok := strings.Cut(cs, ":") 59 | if !ok { 60 | return nil, false 61 | } 62 | 63 | return auth.NewBasicAuth(user, password), true 64 | } 65 | 66 | func hErr(w http.ResponseWriter, message string, code int, span trace.Span) { 67 | http.Error(w, message, code) 68 | span.SetStatus(codes.Error, message) 69 | } 70 | -------------------------------------------------------------------------------- /pkgs/backend/carrier_test.go: 
-------------------------------------------------------------------------------- 1 | package backend 2 | 3 | import ( 4 | "testing" 5 | 6 | . "github.com/smartystreets/goconvey/convey" 7 | "go.acuvity.ai/minibridge/pkgs/mcp" 8 | ) 9 | 10 | func TestCarrier(t *testing.T) { 11 | 12 | Convey("MCPCarrier from call with no params", t, func() { 13 | msg := newMCPMetaCarrier(mcp.NewMessage(1)) 14 | So(len(msg.meta), ShouldEqual, 0) 15 | So(msg.Get("a"), ShouldBeEmpty) 16 | msg.Set("a", "1") 17 | So(msg.Get("a"), ShouldEqual, "1") 18 | So(msg.Keys(), ShouldResemble, []string{"a"}) 19 | }) 20 | 21 | Convey("MCPCarrier from call with params but no _meta", t, func() { 22 | msg := mcp.NewMessage(1) 23 | msg.Params = map[string]any{} 24 | c := newMCPMetaCarrier(msg) 25 | So(len(c.meta), ShouldEqual, 0) 26 | }) 27 | 28 | Convey("MCPCarrier from call with params with _meta with wrong type", t, func() { 29 | msg := mcp.NewMessage(1) 30 | msg.Params = map[string]any{"_meta": "oh no"} 31 | c := newMCPMetaCarrier(msg) 32 | So(len(c.meta), ShouldEqual, 0) 33 | }) 34 | 35 | Convey("MCPCarrier from call with params with wrong _meta value type", t, func() { 36 | msg := mcp.NewMessage(1) 37 | msg.Params = map[string]any{"_meta": map[string]any{"a": 42}} 38 | c := newMCPMetaCarrier(msg) 39 | So(len(c.meta), ShouldEqual, 0) 40 | }) 41 | 42 | Convey("MCPCarrier from call with valid _meta", t, func() { 43 | msg := mcp.NewMessage(1) 44 | msg.Params = map[string]any{"_meta": map[string]any{"a": "42"}} 45 | c := newMCPMetaCarrier(msg) 46 | So(len(c.meta), ShouldEqual, 1) 47 | So(c.Get("a"), ShouldEqual, "42") 48 | So(c.Keys(), ShouldResemble, []string{"a"}) 49 | So(msg.Params["_meta"], ShouldNotBeNil) 50 | }) 51 | 52 | } 53 | -------------------------------------------------------------------------------- /.goreleaser.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | snapshot: 3 | version_template: "v{{ .Tag }}-next" 4 | changelog: 5 | sort: 
asc 6 | filters: 7 | exclude: 8 | - "^docs:" 9 | - "^test:" 10 | - "^examples:" 11 | builds: 12 | - id: minibridge 13 | binary: minibridge 14 | goos: 15 | - linux 16 | - darwin 17 | - windows 18 | goarch: 19 | - amd64 20 | - arm64 21 | env: 22 | - CGO_ENABLED=0 23 | ldflags: 24 | - -w -s -X 'go.acuvity.ai/a3s/pkgs/version.GitSha={{.Commit}}' -X 'go.acuvity.ai/a3s/pkgs/version.GitBranch=main' -X 'go.acuvity.ai/a3s/pkgs/version.GitTag={{.Version}}' -X 'go.acuvity.ai/a3s/pkgs/version.BuildDate={{.Date}}' 25 | 26 | archives: 27 | - id: minibridge 28 | formats: ["zip"] 29 | builds: 30 | - minibridge 31 | 32 | signs: 33 | - artifacts: checksum 34 | args: 35 | [ 36 | "-u", 37 | "0C3214A61024881F5CA1F5F056EDB08A11DCE325", 38 | "--output", 39 | "${signature}", 40 | "--detach-sign", 41 | "${artifact}", 42 | ] 43 | 44 | brews: 45 | - name: minibridge 46 | homepage: "https://github.com/acuvity/minibridge" 47 | description: Minibridge securely connects Agents to MCP servers, exposing them to the internet while enabling optional integration with remote or local Policers for authentication, analysis, and transformation. 48 | license: "Apache" 49 | repository: 50 | owner: acuvity 51 | name: homebrew-tap 52 | commit_author: 53 | name: goreleaserbot 54 | email: goreleaser@acuvity.ai 55 | directory: Formula 56 | install: | 57 | bin.install "minibridge" 58 | test: | 59 | system "#{bin}/minibridge", "--version" 60 | -------------------------------------------------------------------------------- /pkgs/backend/client/options.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | ) 7 | 8 | type creds struct { 9 | Uid uint32 10 | Gid uint32 11 | Groups []uint32 12 | } 13 | 14 | type stdioCfg struct { 15 | useTempDir bool 16 | creds *creds 17 | } 18 | 19 | func newStdioCfg() stdioCfg { 20 | return stdioCfg{} 21 | } 22 | 23 | // An StdioOption can be passed to the Client. 
24 | type StdioOption func(*stdioCfg) 25 | 26 | // OptStdioUseTempDir defines if the the client should 27 | // run the command into it's own working dir. If false, 28 | // the command will run in minibridge current cwd 29 | func OptStdioUseTempDir(use bool) StdioOption { 30 | return func(c *stdioCfg) { 31 | c.useTempDir = use 32 | } 33 | } 34 | 35 | // OptStdioCredentials sets the uid and gid to run the command as. 36 | func OptStdioCredentials(uid int, gid int, groups []int) StdioOption { 37 | return func(c *stdioCfg) { 38 | 39 | grps := make([]uint32, 0, len(groups)) 40 | 41 | for i, g := range groups { 42 | 43 | if g < 0 { 44 | continue 45 | } 46 | 47 | if g > math.MaxUint32 { 48 | panic(fmt.Sprintf("invalid group %d. overflows", i)) 49 | } 50 | 51 | grps = append(grps, uint32(g)) // #nosec: G115 52 | } 53 | 54 | if len(grps) == 0 && uid < 0 && gid < 0 { 55 | return 56 | } 57 | 58 | c.creds = &creds{Groups: grps} 59 | 60 | if uid > -1 { 61 | 62 | if uid > math.MaxUint32 { 63 | panic("invalid uid. overflows") 64 | } 65 | 66 | c.creds.Uid = uint32(uid) // #nosec: G115 67 | } 68 | 69 | if gid > -1 { 70 | 71 | if gid > math.MaxUint32 { 72 | panic("invalid gid. overflows") 73 | } 74 | 75 | c.creds.Gid = uint32(gid) // #nosec: G115 76 | } 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /pkgs/mcp/message.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | type ProtocolVersion string 4 | 5 | var ( 6 | ProtocolVersion20250326 ProtocolVersion = "2025-03-26" 7 | ProtocolVersion20241105 ProtocolVersion = "2024-11-05" 8 | ) 9 | 10 | // Message represents the inline MPC request. 
11 | type Message struct { 12 | JSONRPC string `json:"jsonrpc"` 13 | ID any `json:"id,omitempty,omitzero"` 14 | Method string `json:"method,omitempty"` 15 | Params map[string]any `json:"params,omitempty"` 16 | Result map[string]any `json:"result,omitempty"` 17 | Error *Error `json:"error,omitempty"` 18 | } 19 | 20 | // NewMessage returns a MCPCall initialized with the given id. 21 | // To initialize a call without ID set, use an empty string. 22 | func NewMessage[T int | string](id T) Message { 23 | c := Message{ 24 | JSONRPC: "2.0", 25 | } 26 | 27 | var zero T 28 | if id != zero { 29 | c.ID = id 30 | } 31 | 32 | return c 33 | } 34 | 35 | // IDString returns the call ID as a string 36 | // whatever is the original type. 37 | func (c *Message) IDString() string { 38 | 39 | if c.ID == nil { 40 | return "" 41 | } 42 | 43 | return normalizeID(c.ID) 44 | } 45 | 46 | // NewInitMessage makes a new init call using the given protocol version. 47 | func NewInitMessage(proto ProtocolVersion) Message { 48 | return Message{ 49 | JSONRPC: "2.0", 50 | ID: 0, 51 | Method: "initialize", 52 | Params: map[string]any{ 53 | "protocolVersion": proto, 54 | "capabilities": map[string]any{ 55 | "sampling": map[string]any{}, 56 | "roots": map[string]any{ 57 | "listChanged": true, 58 | }, 59 | }, 60 | "clientInfo": map[string]any{ 61 | "name": "minibridge", 62 | "version": "1.0", 63 | }, 64 | }, 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /pkgs/internal/sanitize/data_test.go: -------------------------------------------------------------------------------- 1 | package sanitize 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestSanitize(t *testing.T) { 9 | type args struct { 10 | data []byte 11 | } 12 | tests := []struct { 13 | name string 14 | args func(t *testing.T) args 15 | 16 | want1 []byte 17 | }{ 18 | 19 | { 20 | "empty", 21 | func(*testing.T) args { 22 | return args{ 23 | data: []byte(""), 24 | } 25 | }, 26 | 
[]byte(""), 27 | }, 28 | { 29 | "nil", 30 | func(*testing.T) args { 31 | return args{ 32 | data: nil, 33 | } 34 | }, 35 | nil, 36 | }, 37 | { 38 | "no suffix", 39 | func(*testing.T) args { 40 | return args{ 41 | data: []byte("hello world"), 42 | } 43 | }, 44 | []byte("hello world"), 45 | }, 46 | { 47 | "with \n", 48 | func(*testing.T) args { 49 | return args{ 50 | data: []byte("hello world\n\n\n\n\n"), 51 | } 52 | }, 53 | []byte("hello world"), 54 | }, 55 | { 56 | "with \r", 57 | func(*testing.T) args { 58 | return args{ 59 | data: []byte("hello world\r\r\r\r"), 60 | } 61 | }, 62 | []byte("hello world"), 63 | }, 64 | { 65 | "with \n\r", 66 | func(*testing.T) args { 67 | return args{ 68 | data: []byte("hello world\r\n"), 69 | } 70 | }, 71 | []byte("hello world"), 72 | }, 73 | { 74 | "with \n\r at start", 75 | func(*testing.T) args { 76 | return args{ 77 | data: []byte("\r\nhello world\r\n"), 78 | } 79 | }, 80 | []byte("\r\nhello world"), 81 | }, 82 | } 83 | 84 | for _, tt := range tests { 85 | t.Run(tt.name, func(t *testing.T) { 86 | tArgs := tt.args(t) 87 | 88 | got1 := Data(tArgs.data) 89 | 90 | if !reflect.DeepEqual(got1, tt.want1) { 91 | t.Errorf("Sanitize got1 = %v, want1: %v", got1, tt.want1) 92 | } 93 | }) 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /examples/policer-rego/police.rego: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import rego.v1 4 | 5 | # the call is allowed is we don't have any reason to deny it. 6 | allow if { 7 | count(reasons) == 0 8 | } 9 | 10 | # verifies the claims from the JWT of the agent. 11 | claims := x if { 12 | [verified, _, x] := io.jwt.decode_verify( 13 | input.agent.password, 14 | { 15 | "secret": "secret", 16 | "iss": "pki.example.com", 17 | "aud": "minibridge", 18 | }, 19 | ) 20 | verified == true 21 | } 22 | 23 | # if we have no claims, we add a deny reason. 
24 | reasons contains msg if { 25 | not claims 26 | msg := "You must provide a valid token" 27 | } 28 | 29 | # if the call is a tools/call and the name 30 | # is printEnv and the user is not Alice, 31 | # we add a deny reason. 32 | reasons contains msg if { 33 | input.mcp.method == "tools/call" 34 | input.mcp.params.name == "printEnv" 35 | claims.email != "alice@example.com" 36 | msg := "only alice can run printEnv" 37 | } 38 | 39 | # if the call is a tools/call and the name 40 | # is longRunningOperation and the user is Bob, 41 | # we add a deny reason. 42 | reasons contains msg if { 43 | input.mcp.method == "tools/call" 44 | claims.email == "bob@example.com" 45 | input.mcp.params.name == "longRunningOperation" 46 | msg := "bob cannot run longRunningOperation" 47 | } 48 | 49 | # if the call is a tools/call and the request is from Bob, we remove the 50 | # longRunningOperation from the response. If Bob still tries to call that tool, 51 | # it will be denied by the rule above. This allows the agent to not loose time 52 | # trying a tool that will be denied anyways 53 | mcp := x if { 54 | input.mcp.result.tools 55 | claims.email == "bob@example.com" 56 | 57 | x := json.patch(input.mcp, [{ 58 | "op": "replace", 59 | "path": "/result/tools", 60 | "value": [x | x := input.mcp.result.tools[_]; x.name != "longRunningOperation"], 61 | }]) 62 | } 63 | -------------------------------------------------------------------------------- /pkgs/backend/client/options_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | 7 | . 
"github.com/smartystreets/goconvey/convey" 8 | ) 9 | 10 | func TestOptions(t *testing.T) { 11 | 12 | Convey("OptUseTempDir should work", t, func() { 13 | cfg := newStdioCfg() 14 | OptStdioUseTempDir(true)(&cfg) 15 | So(cfg.useTempDir, ShouldBeTrue) 16 | }) 17 | 18 | Convey("OptCredentials should work", t, func() { 19 | cfg := newStdioCfg() 20 | OptStdioCredentials(1000, 1001, []int{2001, 2002})(&cfg) 21 | So(cfg.creds.Uid, ShouldEqual, 1000) 22 | So(cfg.creds.Gid, ShouldEqual, 1001) 23 | So(cfg.creds.Groups, ShouldResemble, []uint32{2001, 2002}) 24 | }) 25 | 26 | Convey("OptCredentials with -1 should work", t, func() { 27 | cfg := newStdioCfg() 28 | OptStdioCredentials(-1, -1, []int{-1})(&cfg) 29 | So(cfg.creds, ShouldBeNil) 30 | }) 31 | 32 | Convey("OptCredentials with only gid -1 should work", t, func() { 33 | cfg := newStdioCfg() 34 | OptStdioCredentials(100, -1, []int{-1})(&cfg) 35 | So(cfg.creds.Uid, ShouldEqual, 100) 36 | So(cfg.creds.Gid, ShouldEqual, 0) 37 | So(cfg.creds.Groups, ShouldResemble, []uint32{}) 38 | }) 39 | 40 | Convey("OptCredentials with uint32 overflow on uid should fail", t, func() { 41 | cfg := newStdioCfg() 42 | So(func() { OptStdioCredentials(math.MaxInt64, 1001, []int{2001, 2002})(&cfg) }, ShouldPanicWith, "invalid uid. overflows") 43 | }) 44 | 45 | Convey("OptCredentials with uint32 overflow on uid should fail", t, func() { 46 | cfg := newStdioCfg() 47 | So(func() { OptStdioCredentials(1, math.MaxInt64, []int{2001, 2002})(&cfg) }, ShouldPanicWith, "invalid gid. overflows") 48 | }) 49 | 50 | Convey("OptCredentials with uint32 overflow on groups should fail", t, func() { 51 | cfg := newStdioCfg() 52 | So(func() { OptStdioCredentials(1, 1, []int{2001, math.MaxInt64})(&cfg) }, ShouldPanicWith, "invalid group 1. 
overflows") 53 | }) 54 | } 55 | -------------------------------------------------------------------------------- /cli/init.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "errors" 5 | "log/slog" 6 | "os" 7 | "strings" 8 | 9 | "github.com/adrg/xdg" 10 | "github.com/spf13/viper" 11 | ) 12 | 13 | var ( 14 | cfgFile string 15 | cfgName string 16 | ) 17 | 18 | func initCobra() { 19 | 20 | viper.SetEnvPrefix("minibridge") 21 | viper.AutomaticEnv() 22 | viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) 23 | 24 | dataFolder, err := xdg.DataFile("minibridge") 25 | if err != nil { 26 | slog.Error("failed to retrieve xdg data folder: %w", err) 27 | os.Exit(1) 28 | } 29 | 30 | configFolder, err := xdg.ConfigFile("minibridge") 31 | if err != nil { 32 | slog.Error("failed to retrieve xdg data folder: %w", err) 33 | os.Exit(1) 34 | } 35 | 36 | slog.Debug("Folders configured", "config", configFolder, "data", dataFolder) 37 | 38 | if cfgFile == "" { 39 | cfgFile = os.Getenv("MINIBRIDGE_CONFIG") 40 | } 41 | 42 | if cfgFile != "" { 43 | if _, err := os.Stat(cfgFile); os.IsNotExist(err) { 44 | slog.Error("Config file does not exist", err) 45 | os.Exit(1) 46 | } 47 | 48 | viper.SetConfigType("yaml") 49 | viper.SetConfigFile(cfgFile) 50 | 51 | if err = viper.ReadInConfig(); err != nil { 52 | slog.Error("Unable to read config", 53 | "path", cfgFile, 54 | err, 55 | ) 56 | os.Exit(1) 57 | } 58 | 59 | slog.Debug("Using config file", "path", cfgFile) 60 | return 61 | } 62 | 63 | viper.AddConfigPath(configFolder) 64 | viper.AddConfigPath("/usr/local/etc/minibridge") 65 | viper.AddConfigPath("/etc/minibridge") 66 | 67 | if cfgName == "" { 68 | cfgName = os.Getenv("MINIBRIDGE_CONFIG_NAME") 69 | } 70 | 71 | if cfgName == "" { 72 | cfgName = "default" 73 | } 74 | 75 | viper.SetConfigName(cfgName) 76 | 77 | if err = viper.ReadInConfig(); err != nil { 78 | if !errors.As(err, &viper.ConfigFileNotFoundError{}) { 79 | 
slog.Error("Unable to read config", err) 80 | os.Exit(1) 81 | } 82 | } 83 | 84 | slog.Debug("Using config name", "name", cfgName) 85 | } 86 | -------------------------------------------------------------------------------- /pkgs/policer/api/request.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "time" 5 | 6 | "go.acuvity.ai/minibridge/pkgs/mcp" 7 | ) 8 | 9 | // CallType type of request to the policer. 10 | type CallType string 11 | 12 | // Various values of RequestType 13 | var ( 14 | CallTypeRequest CallType = "request" 15 | CallTypeResponse CallType = "response" 16 | ) 17 | 18 | // SpanContext contains information about the OTEL span 19 | // related to a Request. 20 | type SpanContext struct { 21 | TraceID string `json:"traceID" ` 22 | ParentSpanID string `json:"parentSpanID,omitempty"` 23 | End time.Time `json:"end"` 24 | ID string `json:"ID"` 25 | Name string `json:"name"` 26 | Start time.Time `json:"start"` 27 | } 28 | 29 | func (c SpanContext) IsValid() bool { 30 | return c.TraceID != "" && c.ID != "" 31 | } 32 | 33 | // A Request represents the data sent to the Policer 34 | type Request struct { 35 | 36 | // Type of the request. Request will be set for request from the agent 37 | // and Response will be set for replies from the MPC server. 38 | Type CallType `json:"type"` 39 | 40 | // MPC embeds the full MPC call, either request or response, 41 | // based on the Type. 42 | MCP mcp.Message `json:"mcp,omitzero"` 43 | 44 | // Agent contains callers information. 45 | Agent Agent `json:"agent,omitzero"` 46 | 47 | // SpanContext contains info about the eventual OTEL span 48 | // for the request. There are advanced use cases where you 49 | // want correlation between a Request and the OTEL traces 50 | // associated. 51 | SpanContext SpanContext `json:"spanContext,omitzero"` 52 | } 53 | 54 | // Agent contains information about the caller of the request. 
55 | type Agent struct { 56 | 57 | // User contains the user from the Auth header. 58 | User string `json:"user"` 59 | 60 | // Password contains the password from the Auth header. 61 | Password string `json:"password"` 62 | 63 | // RemoteAddr contains the agent's RemoteAddr, as seen by minibridge. 64 | RemoteAddr string `json:"remoteAddr"` 65 | 66 | // User Agent contains the user agent field of the agent. 67 | UserAgent string `json:"userAgent,omitempty"` 68 | } 69 | -------------------------------------------------------------------------------- /pkgs/scan/sbom.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "go.acuvity.ai/elemental" 8 | ) 9 | 10 | // Exclusions contains the resources we 11 | // want to exclude from scan 12 | type Exclusions struct { 13 | Prompts bool 14 | Resources bool 15 | Tools bool 16 | } 17 | 18 | // SBOM contains a list of hashes for hashable 19 | // resources. 20 | type SBOM struct { 21 | Tools Hashes `json:"tools,omitzero"` 22 | Prompts Hashes `json:"prompts,omitzero"` 23 | } 24 | 25 | func LoadSBOM(path string) (sbom SBOM, err error) { 26 | 27 | data, err := os.ReadFile(path) // #nosec: G304 28 | if err != nil { 29 | return sbom, fmt.Errorf("unable to load sbom file at '%s': %w", path, err) 30 | } 31 | 32 | if err := elemental.Decode(elemental.EncodingTypeJSON, data, &sbom); err != nil { 33 | return sbom, fmt.Errorf("unable to decode content of sbom file: %w", err) 34 | } 35 | 36 | return sbom, nil 37 | } 38 | 39 | // Hashes are a list of Hash. 40 | type Hashes []Hash 41 | 42 | // Matches return nil if both receiver and o 43 | // match, meaning len are identical, and all hashes 44 | // on h match hashes on o. 45 | func (h Hashes) Matches(o Hashes) error { 46 | return cmpH(h, o) 47 | } 48 | 49 | // Map converts the Hashes into a map[string]Hash keyed by the Hash Name. 
50 | func (l Hashes) Map() map[string]Hash { 51 | 52 | out := make(map[string]Hash, len(l)) 53 | 54 | for _, h := range l { 55 | out[h.Name] = h 56 | } 57 | 58 | return out 59 | } 60 | 61 | // A Hash represent the hash of an item with it's name 62 | // and potential parameters. 63 | type Hash struct { 64 | Name string `json:"name"` 65 | Hash string `json:"hash"` 66 | Params Hashes `json:"params,omitzero"` 67 | } 68 | 69 | func cmpH(a Hashes, b Hashes) error { 70 | 71 | if len(b) > len(a) { 72 | return fmt.Errorf("invalid len. left: %d right: %d", len(a), len(b)) 73 | } 74 | 75 | am := a.Map() 76 | bm := b.Map() 77 | 78 | for name, h := range bm { 79 | 80 | o, ok := am[name] 81 | if !ok { 82 | return fmt.Errorf("'%s': missing", name) 83 | } 84 | 85 | if h.Hash != o.Hash { 86 | return fmt.Errorf("'%s': hash mismatch", name) 87 | } 88 | 89 | if len(o.Params) > 0 { 90 | 91 | if err := cmpH(o.Params, h.Params); err != nil { 92 | return fmt.Errorf("'%s': invalid param: %w", name, err) 93 | } 94 | } 95 | } 96 | 97 | return nil 98 | } 99 | -------------------------------------------------------------------------------- /pkgs/memconn/listener.go: -------------------------------------------------------------------------------- 1 | package memconn 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | ) 8 | 9 | // Listener is an in-memory net.Listener. Call DialContext to create a new 10 | // connection. 11 | type Listener struct { 12 | pending chan *Conn 13 | closed chan struct{} 14 | } 15 | 16 | var _ net.Listener = (*Listener)(nil) 17 | 18 | // NewListener creates a new in-memory Listener. 19 | func NewListener() *Listener { 20 | return &Listener{ 21 | pending: make(chan *Conn), 22 | closed: make(chan struct{}), 23 | } 24 | } 25 | 26 | // Accept waits for and returns the next connection to l. Connections to l are 27 | // established by calling l.DialContext. 28 | // 29 | // The returned net.Conn is the server side of the connection. 
30 | func (l *Listener) Accept() (net.Conn, error) { 31 | select { 32 | case peer := <-l.pending: 33 | local := newConn() 34 | peer.Attach(local) 35 | local.Attach(peer) 36 | return local, nil 37 | 38 | case <-l.closed: 39 | return nil, fmt.Errorf("Listener closed") 40 | } 41 | } 42 | 43 | // Close closes l. Any blocked Accept operations will immediately be unblocked 44 | // and return errors. Already Accepted connections are not closed. 45 | func (l *Listener) Close() error { 46 | select { 47 | default: 48 | close(l.closed) 49 | return nil 50 | case <-l.closed: 51 | return fmt.Errorf("already closed") 52 | } 53 | } 54 | 55 | // Addr returns l's address. This will always be a fake "memory" 56 | // address. 57 | func (l *Listener) Addr() net.Addr { 58 | return Addr{} 59 | } 60 | 61 | // DialContext creates a new connection to l. DialContext will block until the 62 | // connection is accepted through a blocked l.Accept call or until ctx is 63 | // canceled. 64 | // 65 | // Note that unlike other Dial methods in different packages, there is no 66 | // address to supply because the remote side of the connection is always the 67 | // in-memory listener. 68 | func (l *Listener) DialContext(ctx context.Context, clientName string) (net.Conn, error) { 69 | local := newConn() 70 | local.name = clientName 71 | 72 | select { 73 | case l.pending <- local: 74 | // Wait for our peer to be connected. 
75 | if err := local.WaitPeer(ctx); err != nil { 76 | return nil, err 77 | } 78 | return local, nil 79 | case <-l.closed: 80 | return nil, fmt.Errorf("server closed") 81 | case <-ctx.Done(): 82 | return nil, ctx.Err() 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /pkgs/auth/auth.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "encoding/base64" 5 | "fmt" 6 | 7 | "go.acuvity.ai/minibridge/pkgs/oauth" 8 | ) 9 | 10 | // AuthScheme represents the various auth schemes. 11 | type AuthScheme int 12 | 13 | // Supported version of auth schemes. 14 | const ( 15 | AuthSchemeBasic AuthScheme = iota 16 | AuthSchemeBearer 17 | AuthSchemeOAuth 18 | ) 19 | 20 | // Auth holds user credentials. 21 | type Auth struct { 22 | mode AuthScheme 23 | user string 24 | password string 25 | 26 | oauthCreds oauth.Credentials 27 | } 28 | 29 | // NewBasicAuth returns a new Basic Auth. 30 | func NewBasicAuth(user string, password string) *Auth { 31 | return &Auth{ 32 | mode: AuthSchemeBasic, 33 | user: user, 34 | password: password, 35 | } 36 | } 37 | 38 | // NewBearerAuth returns a new Bearer auth. 39 | // User() will be set to "Bearer" and Password() ] 40 | // will hold the token. 41 | func NewBearerAuth(token string) *Auth { 42 | return &Auth{ 43 | mode: AuthSchemeBearer, 44 | user: "Bearer", 45 | password: token, 46 | } 47 | } 48 | 49 | // NewOAuthAuth returns a new OAuth auth. 50 | // User() will be set to "Bearer" and Password() ] 51 | // will hold the access token. 52 | func NewOAuthAuth(creds oauth.Credentials) *Auth { 53 | return &Auth{ 54 | mode: AuthSchemeOAuth, 55 | user: "Bearer", 56 | password: creds.AccessToken, 57 | oauthCreds: creds, 58 | } 59 | } 60 | 61 | // Type returns the current type of Auth as a string. 
62 | func (a *Auth) Type() string { 63 | switch a.mode { 64 | case AuthSchemeBasic: 65 | return "Basic" 66 | case AuthSchemeBearer: 67 | return "Bearer" 68 | case AuthSchemeOAuth: 69 | return "OAuth" 70 | default: 71 | panic("unknown auth mode") 72 | } 73 | } 74 | 75 | // Encode encode the Auth to transmit on the wire. 76 | func (a *Auth) Encode() string { 77 | switch a.mode { 78 | case AuthSchemeBasic: 79 | return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString(fmt.Appendf([]byte{}, "%s:%s", a.user, a.password))) 80 | case AuthSchemeBearer, AuthSchemeOAuth: 81 | return fmt.Sprintf("Bearer %s", a.password) 82 | default: 83 | panic("unknown auth mode") 84 | } 85 | } 86 | 87 | // User returns the user. 88 | func (a *Auth) User() string { 89 | return a.user 90 | } 91 | 92 | // Password returns the password. 93 | func (a *Auth) Password() string { 94 | return a.password 95 | } 96 | -------------------------------------------------------------------------------- /pkgs/backend/options_test.go: -------------------------------------------------------------------------------- 1 | package backend 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "testing" 7 | 8 | . 
"github.com/smartystreets/goconvey/convey" 9 | "go.acuvity.ai/bahamut" 10 | "go.acuvity.ai/minibridge/pkgs/mcp" 11 | "go.acuvity.ai/minibridge/pkgs/metrics" 12 | "go.acuvity.ai/minibridge/pkgs/policer/api" 13 | "go.acuvity.ai/minibridge/pkgs/scan" 14 | "go.opentelemetry.io/otel/trace/noop" 15 | ) 16 | 17 | type fakePolicer struct { 18 | } 19 | 20 | func (f fakePolicer) Police(context.Context, api.Request) (*mcp.Message, error) { 21 | return nil, nil 22 | } 23 | 24 | func (f fakePolicer) Type() string { return "fake" } 25 | 26 | func TestOptions(t *testing.T) { 27 | 28 | Convey("OptPolicer should work", t, func() { 29 | cfg := newWSCfg() 30 | f := fakePolicer{} 31 | OptPolicer(f)(&cfg) 32 | So(cfg.policer, ShouldEqual, f) 33 | }) 34 | 35 | Convey("OptDumpStderrOnError", t, func() { 36 | cfg := newWSCfg() 37 | OptDumpStderrOnError(true)(&cfg) 38 | So(cfg.dumpStderr, ShouldBeTrue) 39 | }) 40 | 41 | Convey("OPtCORSPolicy should work", t, func() { 42 | cfg := newWSCfg() 43 | p := &bahamut.CORSPolicy{} 44 | OptCORSPolicy(p)(&cfg) 45 | So(cfg.corsPolicy, ShouldEqual, p) 46 | }) 47 | 48 | Convey("OptSBOM should work", t, func() { 49 | cfg := newWSCfg() 50 | s := scan.SBOM{} 51 | OptSBOM(s)(&cfg) 52 | So(cfg.sbom, ShouldEqual, s) 53 | }) 54 | 55 | Convey("OptMetricsManager should work", t, func() { 56 | cfg := newWSCfg() 57 | mm := &metrics.Manager{} 58 | OptMetricsManager(mm)(&cfg) 59 | So(cfg.metricsManager, ShouldEqual, mm) 60 | }) 61 | 62 | Convey("OptTracer should work", t, func() { 63 | cfg := newWSCfg() 64 | t := noop.NewTracerProvider().Tracer("test") 65 | OptTracer(t)(&cfg) 66 | So(cfg.tracer, ShouldEqual, t) 67 | }) 68 | 69 | Convey("OptTracer with nil should work", t, func() { 70 | cfg := newWSCfg() 71 | OptTracer(nil)(&cfg) 72 | So(cfg.tracer, ShouldHaveSameTypeAs, noop.NewTracerProvider().Tracer("test")) 73 | }) 74 | 75 | Convey("OptListener should work", t, func() { 76 | cfg := newWSCfg() 77 | listener := tls.NewListener(nil, nil) 78 | 
OptListener(listener)(&cfg) 79 | So(cfg.listener, ShouldEqual, listener) 80 | }) 81 | 82 | Convey("OptPolicerEnforce should work", t, func() { 83 | cfg := newWSCfg() 84 | So(cfg.policerEnforced, ShouldBeTrue) 85 | OptPolicerEnforce(false)(&cfg) 86 | So(cfg.policerEnforced, ShouldBeFalse) 87 | }) 88 | } 89 | -------------------------------------------------------------------------------- /pkgs/mcp/id.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "math" 7 | "reflect" 8 | ) 9 | 10 | // RelatedIDs returns true if the two given 11 | // ID as any are equal. 12 | // Slow clap on spec that basically says 13 | // MUST SHOULD BE AN INT. Or a string. whatever... 14 | func RelatedIDs(a any, b any) bool { 15 | 16 | if sa, ok := a.(string); ok { 17 | sb, ok := b.(string) 18 | return ok && sa == sb 19 | } 20 | 21 | if isNumeric(a) && isNumeric(b) { 22 | ia, oka := extractInt64(a) 23 | ib, okb := extractInt64(b) 24 | return oka && okb && ia == ib 25 | } 26 | 27 | return reflect.TypeOf(a) == reflect.TypeOf(b) && reflect.DeepEqual(a, b) 28 | } 29 | 30 | func normalizeID(id any) string { 31 | switch v := id.(type) { 32 | case string: 33 | return v 34 | case int: 35 | return fmt.Sprintf("%d", v) 36 | case int64: 37 | return fmt.Sprintf("%d", v) 38 | case uint64: 39 | return fmt.Sprintf("%d", v) 40 | case float64: 41 | return fmt.Sprintf("%.0f", v) 42 | case json.Number: 43 | return v.String() 44 | default: 45 | return fmt.Sprintf("%v", v) 46 | } 47 | } 48 | 49 | func isNumeric(v any) bool { 50 | switch v.(type) { 51 | case int, int8, int16, int32, int64, 52 | uint, uint8, uint16, uint32, uint64, 53 | float32, float64, json.Number: 54 | return true 55 | default: 56 | return false 57 | } 58 | } 59 | 60 | func extractInt64(v any) (int64, bool) { 61 | switch val := v.(type) { 62 | case int: 63 | return int64(val), true 64 | case int64: 65 | return val, true 66 | case int32: 67 | 
return int64(val), true 68 | case int16: 69 | return int64(val), true 70 | case int8: 71 | return int64(val), true 72 | case uint: 73 | if val <= math.MaxInt64 { 74 | return int64(val), true 75 | } 76 | return 0, false 77 | case uint64: 78 | if val <= math.MaxInt64 { 79 | return int64(val), true 80 | } 81 | return 0, false 82 | case uint32: 83 | return int64(val), true 84 | case float64: 85 | if val == float64(int64(val)) { 86 | return int64(val), true 87 | } 88 | return 0, false 89 | case float32: 90 | f := float64(val) 91 | if f == float64(int64(f)) { 92 | return int64(f), true 93 | } 94 | return 0, false 95 | case json.Number: 96 | if i, err := val.Int64(); err == nil { 97 | return i, true 98 | } 99 | return 0, false 100 | default: 101 | return 0, false 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /pkgs/policer/internal/http/policer.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "crypto/tls" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "strings" 11 | 12 | "go.acuvity.ai/elemental" 13 | "go.acuvity.ai/minibridge/pkgs/auth" 14 | "go.acuvity.ai/minibridge/pkgs/mcp" 15 | "go.acuvity.ai/minibridge/pkgs/policer/api" 16 | ) 17 | 18 | type Policer struct { 19 | endpoint string 20 | auth auth.Auth 21 | client *http.Client 22 | } 23 | 24 | // New returns a new HTTP based Policer. 
25 | func New(endpoint string, auth *auth.Auth, tlsConfig *tls.Config) *Policer { 26 | 27 | return &Policer{ 28 | endpoint: endpoint, 29 | auth: *auth, 30 | client: &http.Client{ 31 | Transport: &http.Transport{ 32 | TLSClientConfig: tlsConfig, 33 | }, 34 | }, 35 | } 36 | } 37 | 38 | func (p *Policer) Type() string { return "http" } 39 | 40 | func (p *Policer) Police(ctx context.Context, preq api.Request) (*mcp.Message, error) { 41 | 42 | body, err := elemental.Encode(elemental.EncodingTypeJSON, preq) 43 | if err != nil { 44 | return nil, fmt.Errorf("unable to encode scan request: %w", err) 45 | } 46 | 47 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.endpoint, bytes.NewBuffer(body)) 48 | if err != nil { 49 | return nil, fmt.Errorf("unable to create new http request: %w", err) 50 | } 51 | 52 | req.Header.Add("Accept", "application/json") 53 | req.Header.Add("Content-Type", "application/json") 54 | req.Header.Add("Authorization", p.auth.Encode()) 55 | 56 | resp, err := p.client.Do(req) 57 | if err != nil { 58 | return nil, fmt.Errorf("unable to send request: %w", err) 59 | } 60 | 61 | if resp.StatusCode == http.StatusNoContent { 62 | return nil, nil 63 | } 64 | 65 | rbody, err := io.ReadAll(resp.Body) 66 | if err != nil { 67 | return nil, fmt.Errorf("unable to read response body: %w", err) 68 | } 69 | defer func() { _ = resp.Body.Close() }() 70 | 71 | if resp.StatusCode != http.StatusOK { 72 | return nil, fmt.Errorf("invalid response from policer `%s`: %s", string(rbody), resp.Status) 73 | } 74 | 75 | sresp := api.Response{} 76 | if err := elemental.Decode(elemental.EncodingTypeJSON, rbody, &sresp); err != nil { 77 | return nil, fmt.Errorf("unable to decode response body: %w", err) 78 | } 79 | 80 | if sresp.MCP != nil && preq.MCP.ID != nil { 81 | sresp.MCP.ID = preq.MCP.ID 82 | } 83 | 84 | if sresp.Allow { 85 | return sresp.MCP, nil 86 | } 87 | 88 | if len(sresp.Reasons) == 0 { 89 | sresp.Reasons = []string{api.GenericDenyReason} 90 | } 91 | 92 | 
return nil, fmt.Errorf("%w: %s", api.ErrBlocked, strings.Join(sresp.Reasons, ", ")) 93 | } 94 | -------------------------------------------------------------------------------- /pkgs/backend/helpers_test.go: -------------------------------------------------------------------------------- 1 | package backend 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | 8 | "go.acuvity.ai/minibridge/pkgs/auth" 9 | ) 10 | 11 | func Test_makeMCPError(t *testing.T) { 12 | type args struct { 13 | ID any 14 | err error 15 | } 16 | tests := []struct { 17 | name string 18 | args func(t *testing.T) args 19 | 20 | want1 []byte 21 | }{ 22 | { 23 | "basic", 24 | func(*testing.T) args { 25 | return args{ 26 | ID: 42, 27 | err: fmt.Errorf("oh noes!"), 28 | } 29 | }, 30 | []byte(`{"error":{"code":451,"message":"oh noes!"},"id":42,"jsonrpc":"2.0"}`), 31 | }, 32 | } 33 | 34 | for _, tt := range tests { 35 | t.Run(tt.name, func(t *testing.T) { 36 | tArgs := tt.args(t) 37 | 38 | got1 := makeMCPError(tArgs.ID, tArgs.err) 39 | 40 | if !reflect.DeepEqual(got1, tt.want1) { 41 | t.Errorf("makeMCPError got1 = %v, want1: %v", string(got1), string(tt.want1)) 42 | } 43 | }) 44 | } 45 | } 46 | 47 | func Test_parseBasicAuth(t *testing.T) { 48 | type args struct { 49 | authString string 50 | } 51 | tests := []struct { 52 | name string 53 | args func(t *testing.T) args 54 | 55 | want1 *auth.Auth 56 | want2 bool 57 | }{ 58 | { 59 | "empty header", 60 | func(t *testing.T) args { 61 | return args{ 62 | authString: "", 63 | } 64 | }, 65 | nil, 66 | false, 67 | }, 68 | { 69 | "bearer", 70 | func(t *testing.T) args { 71 | return args{ 72 | authString: "Bearer token", 73 | } 74 | }, 75 | auth.NewBearerAuth("token"), 76 | true, 77 | }, 78 | { 79 | "basic", 80 | func(t *testing.T) args { 81 | return args{ 82 | authString: "Basic dXNlcjpwYXNz", 83 | } 84 | }, 85 | auth.NewBasicAuth("user", "pass"), 86 | true, 87 | }, 88 | { 89 | "invalid basic b64", 90 | func(t *testing.T) args { 91 | return args{ 92 | 
authString: "Basic not-b64", 93 | } 94 | }, 95 | nil, 96 | false, 97 | }, 98 | { 99 | "invalid basic decoded", 100 | func(t *testing.T) args { 101 | return args{ 102 | authString: "Basic aGVsbG8=", 103 | } 104 | }, 105 | nil, 106 | false, 107 | }, 108 | } 109 | 110 | for _, tt := range tests { 111 | t.Run(tt.name, func(t *testing.T) { 112 | tArgs := tt.args(t) 113 | 114 | got1, got2 := parseBasicAuth(tArgs.authString) 115 | 116 | if !reflect.DeepEqual(got1, tt.want1) { 117 | t.Errorf("parseBasicAuth got1 = %v, want1: %v", got1, tt.want1) 118 | } 119 | 120 | if !reflect.DeepEqual(got2, tt.want2) { 121 | t.Errorf("parseBasicAuth got2 = %v, want2: %v", got2, tt.want2) 122 | } 123 | }) 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /cli/internal/cmd/backend.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | "log/slog" 6 | 7 | "github.com/spf13/cobra" 8 | "github.com/spf13/pflag" 9 | "github.com/spf13/viper" 10 | "go.acuvity.ai/minibridge/pkgs/backend" 11 | ) 12 | 13 | var fBackend = pflag.NewFlagSet("backend", pflag.ExitOnError) 14 | 15 | func init() { 16 | 17 | initSharedFlagSet() 18 | 19 | fBackend.StringP("listen", "l", ":8000", "listen address of the bridge for incoming websocket connections.") 20 | 21 | Backend.Flags().AddFlagSet(fBackend) 22 | Backend.Flags().AddFlagSet(fPolicer) 23 | Backend.Flags().AddFlagSet(fTLSServer) 24 | Backend.Flags().AddFlagSet(fHealth) 25 | Backend.Flags().AddFlagSet(fProfiler) 26 | Backend.Flags().AddFlagSet(fCORS) 27 | Backend.Flags().AddFlagSet(fSBOM) 28 | Backend.Flags().AddFlagSet(fMCP) 29 | } 30 | 31 | // Backend is the cobra command to run the server. 
32 | var Backend = &cobra.Command{ 33 | Use: "backend [flags] -- command [args...]", 34 | Short: "Start a minibridge backend to expose an MCP server", 35 | SilenceUsage: true, 36 | SilenceErrors: true, 37 | TraverseChildren: true, 38 | Args: cobra.MinimumNArgs(1), 39 | 40 | RunE: func(cmd *cobra.Command, args []string) error { 41 | 42 | listen := viper.GetString("listen") 43 | 44 | if listen == "" { 45 | return fmt.Errorf("--listen must be set") 46 | } 47 | 48 | backendTLSConfig, err := tlsConfigFromFlags(fTLSServer) 49 | if err != nil { 50 | return err 51 | } 52 | 53 | policer, penforce, err := makePolicer() 54 | if err != nil { 55 | return fmt.Errorf("unable to make policer: %w", err) 56 | } 57 | 58 | sbom, err := makeSBOM() 59 | if err != nil { 60 | return fmt.Errorf("unable to make hashes: %w", err) 61 | } 62 | 63 | tracer, err := makeTracer(cmd.Context(), "backend") 64 | if err != nil { 65 | return fmt.Errorf("unable to configure tracer: %w", err) 66 | } 67 | 68 | corsPolicy := makeCORSPolicy() 69 | 70 | mcpClient, err := makeMCPClient(args, true) 71 | if err != nil { 72 | return fmt.Errorf("unable to create MCP client: %w", err) 73 | } 74 | 75 | mm := startHealthServer(cmd.Context()) 76 | 77 | slog.Info("Minibridge backend configured", 78 | "server-tls", backendTLSConfig != nil, 79 | "server-mtls", mtlsMode(backendTLSConfig), 80 | "listen", listen, 81 | ) 82 | 83 | proxy := backend.NewWebSocket(listen, backendTLSConfig, mcpClient, 84 | backend.OptPolicer(policer), 85 | backend.OptPolicerEnforce(penforce), 86 | backend.OptDumpStderrOnError(viper.GetString("log-format") != "json"), 87 | backend.OptCORSPolicy(corsPolicy), 88 | backend.OptSBOM(sbom), 89 | backend.OptMetricsManager(mm), 90 | backend.OptTracer(tracer), 91 | ) 92 | 93 | return proxy.Start(cmd.Context()) 94 | }, 95 | } 96 | -------------------------------------------------------------------------------- /pkgs/backend/options.go: 
-------------------------------------------------------------------------------- 1 | package backend 2 | 3 | import ( 4 | "net" 5 | 6 | "go.acuvity.ai/bahamut" 7 | "go.acuvity.ai/minibridge/pkgs/metrics" 8 | "go.acuvity.ai/minibridge/pkgs/policer" 9 | "go.acuvity.ai/minibridge/pkgs/scan" 10 | "go.opentelemetry.io/otel/trace" 11 | "go.opentelemetry.io/otel/trace/noop" 12 | ) 13 | 14 | type wsCfg struct { 15 | corsPolicy *bahamut.CORSPolicy 16 | dumpStderr bool 17 | listener net.Listener 18 | metricsManager *metrics.Manager 19 | policer policer.Policer 20 | policerEnforced bool 21 | sbom scan.SBOM 22 | tracer trace.Tracer 23 | } 24 | 25 | func newWSCfg() wsCfg { 26 | return wsCfg{ 27 | tracer: noop.NewTracerProvider().Tracer("noop"), 28 | policerEnforced: true, 29 | } 30 | } 31 | 32 | // Option are options that can be given to NewStdio(). 33 | type Option func(*wsCfg) 34 | 35 | // OptPolicer sets the Policer to forward the traffic to. 36 | func OptPolicer(policer policer.Policer) Option { 37 | return func(cfg *wsCfg) { 38 | cfg.policer = policer 39 | } 40 | } 41 | 42 | // OptPolicerEnforce sets the Policer decision should be enforced 43 | // or just logged. The default is true if a policer is set 44 | func OptPolicerEnforce(enforced bool) Option { 45 | return func(cfg *wsCfg) { 46 | cfg.policerEnforced = enforced 47 | } 48 | } 49 | 50 | // OptDumpStderrOnError controls whether the WS server should 51 | // dump the stderr of the MCP server as is, or in a log. 52 | func OptDumpStderrOnError(dump bool) Option { 53 | return func(cfg *wsCfg) { 54 | cfg.dumpStderr = dump 55 | } 56 | } 57 | 58 | // OptCORSPolicy sets the bahamut.CORSPolicy to use for 59 | // connection originating from a webrowser. 60 | func OptCORSPolicy(policy *bahamut.CORSPolicy) Option { 61 | return func(cfg *wsCfg) { 62 | cfg.corsPolicy = policy 63 | } 64 | } 65 | 66 | // OptSBOM sets a the utils.SBOM to use to verify 67 | // server integrity. 
68 | func OptSBOM(sbom scan.SBOM) Option { 69 | return func(cfg *wsCfg) { 70 | cfg.sbom = sbom 71 | } 72 | } 73 | 74 | // OptMetricsManager sets the metric manager to use to collect 75 | // prometheus metrics. 76 | func OptMetricsManager(m *metrics.Manager) Option { 77 | return func(cfg *wsCfg) { 78 | cfg.metricsManager = m 79 | } 80 | } 81 | 82 | // OptTracer sets the otel trace.Tracer to use to trace requests 83 | func OptTracer(tracer trace.Tracer) Option { 84 | return func(cfg *wsCfg) { 85 | if tracer == nil { 86 | tracer = noop.NewTracerProvider().Tracer("noop") 87 | } 88 | cfg.tracer = tracer 89 | } 90 | } 91 | 92 | // OptListener sets the listener to use for the server. 93 | // by defaut, it will use a classic listener. 94 | func OptListener(listener net.Listener) Option { 95 | return func(cfg *wsCfg) { 96 | cfg.listener = listener 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /examples/policer-http/policer.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | """ 3 | This is an example implementation for a policer 4 | """ 5 | 6 | import json 7 | import jwt 8 | from termcolor import colored as c 9 | from flask import Flask, request, Response 10 | 11 | app = Flask(__name__) 12 | 13 | FORBIDDEN_TOOLS = { 14 | "*": ["printEnv"], 15 | "bob@example.com": ["longRunningOperation"], 16 | "alice@example.com": [], 17 | } 18 | 19 | 20 | @app.route("/police", methods=["POST"]) 21 | def police(): 22 | """handles /police""" 23 | 24 | req = request.get_json() 25 | agent = req["agent"] 26 | mcp = req["mcp"] 27 | 28 | print() 29 | print("---") 30 | print(c(f"Type: {req['type']}", "green" if req["type"] == "request" else "yellow")) 31 | print(c(f"Agent: {agent['userAgent']} {agent['remoteAddr']}", "blue")) 32 | print() 33 | print(c(f"{request.headers}", "dark_grey")) 34 | print(json.dumps(req, sort_keys=True, indent=4)) 35 | 36 | # Check the agent token. 
Deny if not valid. 37 | # This example is meant to work with JWT issued 38 | # with the gen-test-tokens.sh located a the parent folder. 39 | try: 40 | claims = jwt.decode( 41 | agent["password"], 42 | "secret", 43 | algorithms=["HS256"], 44 | issuer="pki.example.com", 45 | audience="minibridge", 46 | ) 47 | except Exception as e: 48 | print(c(f"DENIED: invalid token: {e}", "red")) 49 | return json.dumps({"allow": False, "reasons": [f"{e}"]}) 50 | 51 | # This is an example of blanket policing. We deny access 52 | # to the tool/calls declared in FORBIDDEN_TOOLS 53 | if req["type"] == "request": 54 | if ( 55 | mcp["method"] == "tools/call" 56 | and "name" in mcp["params"] 57 | and ( 58 | mcp["params"]["name"] in FORBIDDEN_TOOLS["*"] 59 | or mcp["params"]["name"] in FORBIDDEN_TOOLS[claims["email"]] 60 | ) 61 | ): 62 | dmsg = f"forbidden method call {mcp['params']['name']} {mcp['method']}" 63 | print(c(f"DENIED: {dmsg}", "red")) 64 | return json.dumps({"allow": False, "reasons": [dmsg]}) 65 | 66 | # This is an example of redaction: If the 67 | # user is Bob, then we remove the tool named `longRunningOperation`. 68 | # from the response. 69 | if ( 70 | req["type"] == "response" 71 | and "result" in req["mcp"] 72 | and claims["email"] == "bob@example.com" 73 | ): 74 | result = req["mcp"]["result"] 75 | if "tools" in result: 76 | result["tools"] = [ 77 | cell 78 | for cell in result["tools"] 79 | if cell["name"] != "longRunningOperation" 80 | ] 81 | return Response( 82 | status=200, response=json.dumps({"allow": True, "mcp": mcp}) 83 | ) 84 | 85 | # otherwise we allow everything 86 | return Response(status=204) 87 | 88 | 89 | app.run(port=5000) 90 | -------------------------------------------------------------------------------- /pkgs/internal/cors/cors_test.go: -------------------------------------------------------------------------------- 1 | package cors 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | . 
"github.com/smartystreets/goconvey/convey" 9 | "go.acuvity.ai/bahamut" 10 | ) 11 | 12 | func TestThing(t *testing.T) { 13 | 14 | Convey("Given I have a response writer, req and corsPolicy", t, func() { 15 | w := httptest.NewRecorder() 16 | req := &http.Request{} 17 | pol := &bahamut.CORSPolicy{ 18 | AllowOrigin: "https://coucou.test", 19 | AllowCredentials: true, 20 | MaxAge: 1500, 21 | AllowHeaders: []string{ 22 | "Authorization", 23 | }, 24 | AllowMethods: []string{ 25 | "GET", 26 | "POST", 27 | "OPTIONS", 28 | }, 29 | } 30 | 31 | Convey("Calling HandleCors should work on OPTIONS", func() { 32 | 33 | req.Method = http.MethodOptions 34 | 35 | shouldCont := HandleCORS(w, req, pol) 36 | So(shouldCont, ShouldBeFalse) 37 | So(w.Result().Header, ShouldResemble, http.Header{ 38 | "Access-Control-Allow-Credentials": {"true"}, 39 | "Access-Control-Allow-Headers": {"Authorization"}, 40 | "Access-Control-Allow-Methods": {"GET, POST, OPTIONS"}, 41 | "Access-Control-Allow-Origin": {"https://coucou.test"}, 42 | "Access-Control-Expose-Headers": {""}, 43 | "Access-Control-Max-Age": {"1500"}, 44 | "Cache-Control": {"private, no-transform"}, 45 | "Strict-Transport-Security": {"max-age=31536000; includeSubDomains; preload"}, 46 | "X-Content-Type-Options": {"nosniff"}, 47 | "X-Frame-Options": {"DENY"}, 48 | "X-Xss-Protection": {"1; mode=block"}, 49 | }) 50 | }) 51 | 52 | Convey("Calling HandleCors should work on non OPTIONS", func() { 53 | 54 | req.Method = http.MethodPost 55 | 56 | shouldCont := HandleCORS(w, req, pol) 57 | So(shouldCont, ShouldBeTrue) 58 | So(w.Result().Header, ShouldResemble, http.Header{ 59 | "Access-Control-Allow-Credentials": {"true"}, 60 | "Access-Control-Allow-Origin": {"https://coucou.test"}, 61 | "Access-Control-Expose-Headers": {""}, 62 | "Cache-Control": {"private, no-transform"}, 63 | "Strict-Transport-Security": {"max-age=31536000; includeSubDomains; preload"}, 64 | "X-Content-Type-Options": {"nosniff"}, 65 | "X-Frame-Options": {"DENY"}, 66 | 
"X-Xss-Protection": {"1; mode=block"}, 67 | }) 68 | }) 69 | 70 | Convey("Calling HandleCors should work with no policy", func() { 71 | 72 | req.Method = http.MethodPost 73 | 74 | shouldCont := HandleCORS(w, req, nil) 75 | So(shouldCont, ShouldBeTrue) 76 | So(w.Result().Header, ShouldResemble, http.Header{ 77 | "Cache-Control": {"private, no-transform"}, 78 | "Strict-Transport-Security": {"max-age=31536000; includeSubDomains; preload"}, 79 | "X-Content-Type-Options": {"nosniff"}, 80 | "X-Frame-Options": {"DENY"}, 81 | "X-Xss-Protection": {"1; mode=block"}, 82 | }) 83 | }) 84 | }) 85 | } 86 | -------------------------------------------------------------------------------- /pkgs/frontend/connect.go: -------------------------------------------------------------------------------- 1 | package frontend 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "log/slog" 10 | "net" 11 | "net/http" 12 | "strings" 13 | 14 | "go.acuvity.ai/minibridge/pkgs/auth" 15 | "go.acuvity.ai/wsc" 16 | "go.opentelemetry.io/otel" 17 | "go.opentelemetry.io/otel/propagation" 18 | ) 19 | 20 | var ErrAuthRequired = errors.New("authorization required") 21 | 22 | // AgentInfo holds information about the agent 23 | // who sent an MCPCall. 
24 | type AgentInfo struct { 25 | Auth *auth.Auth 26 | AuthHeaders []string 27 | UserAgent string 28 | RemoteAddr string 29 | } 30 | 31 | // Connect is a low level function to connect to the backend's websocket 32 | func Connect( 33 | ctx context.Context, 34 | dialer func(ctx context.Context, network, addr string) (net.Conn, error), 35 | backendURL string, 36 | tlsConfig *tls.Config, 37 | info AgentInfo, 38 | ) (wsc.Websocket, error) { 39 | 40 | slog.Debug("New websocket connection", 41 | "url", backendURL, 42 | "using-auth", info.Auth != nil, 43 | "using-headers", len(info.AuthHeaders) > 0, 44 | "tls", strings.HasPrefix(backendURL, "wss://"), 45 | "tls-config", tlsConfig != nil, 46 | ) 47 | 48 | if dialer == nil && (info.Auth != nil || len(info.AuthHeaders) > 0) && tlsConfig == nil { 49 | slog.Warn("Security: connecting to a websocket with crendentials sent over the network in clear-text. Refused. Credentials have been stripped. Request will proceed and will likely fail.") 50 | } 51 | 52 | wsconfig := wsc.Config{ 53 | WriteChanSize: 64, 54 | ReadChanSize: 16, 55 | TLSConfig: tlsConfig, 56 | NetDialContextFunc: dialer, 57 | } 58 | 59 | wsconfig.Headers = http.Header{ 60 | "X-Forwarded-UA": {info.UserAgent}, 61 | "X-Forwarded-For": {info.RemoteAddr}, 62 | } 63 | 64 | otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(wsconfig.Headers)) 65 | 66 | if tlsConfig != nil || dialer != nil { 67 | if info.Auth != nil { 68 | wsconfig.Headers["Authorization"] = []string{info.Auth.Encode()} 69 | } else if len(info.AuthHeaders) > 0 { 70 | wsconfig.Headers["Authorization"] = info.AuthHeaders 71 | } 72 | } 73 | 74 | session, resp, err := wsc.Connect(ctx, backendURL, wsconfig) 75 | if err != nil { 76 | 77 | var data []byte 78 | var code int 79 | status := "" 80 | 81 | if resp != nil { 82 | 83 | if resp.StatusCode == http.StatusUnauthorized { 84 | return nil, ErrAuthRequired 85 | } 86 | 87 | data, _ = io.ReadAll(resp.Body) 88 | _ = resp.Body.Close() 89 | 90 | code = 
resp.StatusCode 91 | status = resp.Status 92 | } 93 | 94 | slog.Error("WS connection failed", "code", code, "status", status, "data", strings.TrimSpace(string(data)), err) 95 | 96 | return nil, fmt.Errorf("unable to connect to the websocket. code: %d, status: %s: %w", code, status, err) 97 | } 98 | 99 | defer func() { _ = resp.Body.Close() }() 100 | 101 | if resp.StatusCode != http.StatusSwitchingProtocols { 102 | return nil, fmt.Errorf("invalid response from other end of the tunnel (must be 101): %s", resp.Status) 103 | } 104 | 105 | return session, nil 106 | } 107 | -------------------------------------------------------------------------------- /pkgs/frontend/internal/session/manager_test.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | . "github.com/smartystreets/goconvey/convey" 8 | "go.acuvity.ai/wsc" 9 | ) 10 | 11 | func TestManager(t *testing.T) { 12 | 13 | Convey("Manager should work", t, func() { 14 | 15 | ws1 := wsc.NewMockWebsocket(t.Context()) 16 | ws2 := wsc.NewMockWebsocket(t.Context()) 17 | 18 | m := NewManager() 19 | s1 := New(ws1, 1, "1") 20 | s2 := newSession(ws2, 2, "2", 2*time.Second) 21 | 22 | So(s1.nextDeadline, ShouldEqual, defaultDeadlineDuration) 23 | So(s2.nextDeadline, ShouldEqual, 2*time.Second) 24 | So(s1.ValidateHash(1), ShouldBeTrue) 25 | So(s2.ValidateHash(2), ShouldBeTrue) 26 | 27 | m.Register(s1) 28 | m.Register(s2) 29 | So(len(m.sessions), ShouldEqual, 2) 30 | So(m.sessions["1"].getCount(), ShouldEqual, 1) 31 | So(m.sessions["2"].getCount(), ShouldEqual, 1) 32 | 33 | // We acquire 1 34 | m.Acquire("1", nil) 35 | So(len(m.sessions), ShouldEqual, 2) 36 | So(m.sessions["1"].getCount(), ShouldEqual, 2) 37 | So(m.sessions["2"].getCount(), ShouldEqual, 1) 38 | 39 | go func() { s1.Write([]byte("coucou")) }() 40 | So(string(<-ws1.LastWrite()), ShouldEqual, "coucou") 41 | 42 | // We releease 1 43 | m.Release("1", nil) 44 | 
So(len(m.sessions), ShouldEqual, 2) 45 | So(m.sessions["1"].getCount(), ShouldEqual, 1) 46 | So(m.sessions["2"].getCount(), ShouldEqual, 1) 47 | 48 | // We release 1 again, it should be removed 49 | m.Release("1", nil) 50 | So(len(m.sessions), ShouldEqual, 1) 51 | So(m.sessions["1"], ShouldBeNil) 52 | So(m.sessions["2"].getCount(), ShouldEqual, 1) 53 | 54 | // we over release, it should be noop 55 | m.Release("1", nil) 56 | So(len(m.sessions), ShouldEqual, 1) 57 | So(m.sessions["1"], ShouldBeNil) 58 | So(m.sessions["2"].getCount(), ShouldEqual, 1) 59 | 60 | // We simulate a message when there is no hook 61 | // this should be noop and should not break the rest of 62 | // the test 63 | ws2.NextRead([]byte("nobody there")) 64 | <-time.After(time.Second) 65 | 66 | // We acquire 2 with a chan 67 | ch1 := make(chan []byte) 68 | ch2 := make(chan []byte) 69 | m.Acquire("2", ch1) 70 | m.Acquire("2", ch2) 71 | So(len(m.sessions), ShouldEqual, 1) 72 | So(m.sessions["2"].getCount(), ShouldEqual, 3) 73 | So(len(m.sessions["2"].getHooks()), ShouldEqual, 2) 74 | 75 | // we simulate a message from the ws 76 | ws2.NextRead([]byte("coucou")) 77 | 78 | // we should get it on ch1 79 | var data []byte 80 | select { 81 | case data = <-ch1: 82 | case <-time.After(time.Second): 83 | } 84 | So(string(data), ShouldEqual, "coucou") 85 | 86 | // we should get it on ch2 87 | select { 88 | case data = <-ch2: 89 | case <-time.After(time.Second): 90 | } 91 | So(string(data), ShouldEqual, "coucou") 92 | 93 | // Now we release 2 twice 94 | m.Release("2", ch1) 95 | m.Release("2", ch2) 96 | So(len(m.sessions), ShouldEqual, 1) 97 | So(m.sessions["2"].getCount(), ShouldEqual, 1) 98 | So(len(m.sessions["2"].getHooks()), ShouldEqual, 0) 99 | 100 | select { 101 | case <-s2.Done(): 102 | m.Release("2", nil) 103 | case <-time.After(5 * time.Second): 104 | } 105 | 106 | So(len(m.sessions), ShouldEqual, 0) 107 | }) 108 | } 109 | -------------------------------------------------------------------------------- 
/cli/internal/cmd/flags.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/spf13/pflag" 5 | ) 6 | 7 | var ( 8 | fTLSClient = pflag.NewFlagSet("tlsclient", pflag.ExitOnError) 9 | fTLSServer = pflag.NewFlagSet("tlsserver", pflag.ExitOnError) 10 | fProfiler = pflag.NewFlagSet("profile", pflag.ExitOnError) 11 | fHealth = pflag.NewFlagSet("health", pflag.ExitOnError) 12 | fPolicer = pflag.NewFlagSet("policer", pflag.ExitOnError) 13 | fCORS = pflag.NewFlagSet("cors", pflag.ExitOnError) 14 | fAgentAuth = pflag.NewFlagSet("agentauth", pflag.ExitOnError) 15 | fSBOM = pflag.NewFlagSet("sbom", pflag.ExitOnError) 16 | fMCP = pflag.NewFlagSet("mcp", pflag.ExitOnError) 17 | 18 | initialized = false 19 | ) 20 | 21 | func initSharedFlagSet() { 22 | 23 | if initialized { 24 | return 25 | } 26 | 27 | initialized = true 28 | 29 | fTLSServer.StringP("tls-server-cert", "c", "", "path to the server certificate for incoming HTTPS connections.") 30 | fTLSServer.StringP("tls-server-key", "k", "", "path to the key for the server certificate.") 31 | fTLSServer.StringP("tls-server-key-pass", "p", "", "passphrase for the server certificate key.") 32 | fTLSServer.String("tls-server-client-ca", "", "path to a CA to require and validate incoming client certificates.") 33 | 34 | fTLSClient.StringP("tls-client-cert", "C", "", "path to the client certificate to authenticate against the minibridge backend.") 35 | fTLSClient.StringP("tls-client-key", "K", "", "path to the key for the client certificate.") 36 | fTLSClient.StringP("tls-client-key-pass", "P", "", "passphrase for the client certificate key.") 37 | fTLSClient.String("tls-client-backend-ca", "", "path to a CA to validate the minibridge backend server certificates.") 38 | fTLSClient.Bool("tls-client-insecure-skip-verify", false, "skip backend's server certificates validation. 
INSECURE.") 39 | 40 | fHealth.String("health-listen", "", "if set, start health server on that address.") 41 | 42 | fPolicer.StringP("policer-type", "P", "", "type of policer to use. 'rego' or 'http'.") 43 | fPolicer.Bool("policer-enforce", true, "enforce policy or only log verdict.") 44 | fPolicer.String("policer-rego-policy", "", "path to a rego policy file for the rego policer.") 45 | fPolicer.String("policer-http-url", "", "URL of the HTTP policer to POST agent policing requests.") 46 | fPolicer.String("policer-http-bearer-token", "", "token to use to authenticate against the HTTP policer using Bearer scheme.") 47 | fPolicer.String("policer-http-basic-user", "", "user to use to authenticate against the HTTP policer using Basic scheme.") 48 | fPolicer.String("policer-http-basic-pass", "", "password to use to authenticate against the HTTP policer using Basic scheme.") 49 | fPolicer.String("policer-http-ca", "", "path to a CA to validate the policer server certificates.") 50 | fPolicer.Bool("policer-http-insecure-skip-verify", false, "skip policer's server certificates validation. 
INSECURE.") 51 | 52 | fCORS.String("cors-origin", "*", "sets the valid HTTP Origin for CORS responses.") 53 | 54 | fAgentAuth.StringP("agent-token", "t", "", "JWT token to pass to the minibridge backend for agent identification.") 55 | fAgentAuth.String("agent-user", "", "User to send to the backend as Basic auth.") 56 | fAgentAuth.String("agent-pass", "", "Password to send to the backend as Basic auth.") 57 | fAgentAuth.Bool("oauth-disabled", false, "If set, skip trying to perform the oauth dance.") 58 | 59 | fSBOM.String("sbom", "", "path to a sbom file (generated by minibridge scan sbom) to ensure server integrity.") 60 | 61 | fMCP.Int("mcp-uid", -1, "if greater than -1, use as UID to run the MCP server command.") 62 | fMCP.Int("mcp-gid", -1, "if greater than -1, use as GID to run the MCP server command.") 63 | fMCP.IntSlice("mcp-groups", nil, "additional GIDs to to run the MCP server command.") 64 | fMCP.Bool("mcp-use-tempdir", false, "if set, create a new temp execution dir for each MCP server instance.") 65 | fMCP.String("mcp-tls-ca", "", "when using SSE, path to a CA to valide MCP server server certificates.") 66 | fMCP.Bool("mcp-tls-insecure-skip-verify", false, "skip MCP server certificates validation. INSECURE.") 67 | } 68 | -------------------------------------------------------------------------------- /pkgs/backend/client/stdio.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "fmt" 7 | "io" 8 | "log/slog" 9 | "os" 10 | "os/exec" 11 | "strings" 12 | 13 | "go.acuvity.ai/minibridge/pkgs/internal/sanitize" 14 | ) 15 | 16 | var _ Client = (*stdioClient)(nil) 17 | 18 | type stdioClient struct { 19 | srv MCPServer 20 | cfg stdioCfg 21 | } 22 | 23 | // NewStdio returns a Client communicating through stdio. 
24 | func NewStdio(srv MCPServer, options ...StdioOption) Client { 25 | 26 | cfg := newStdioCfg() 27 | for _, o := range options { 28 | o(&cfg) 29 | } 30 | 31 | return &stdioClient{ 32 | srv: srv, 33 | cfg: cfg, 34 | } 35 | } 36 | 37 | func (c *stdioClient) Type() string { 38 | return "stdio" 39 | } 40 | 41 | func (c *stdioClient) Server() string { return c.srv.Command } 42 | 43 | func (c *stdioClient) Start(ctx context.Context, _ ...Option) (pipe *MCPStream, err error) { 44 | 45 | dir, err := os.Getwd() 46 | if err != nil { 47 | return nil, fmt.Errorf("unable to get current directory: %w", err) 48 | } 49 | 50 | if c.cfg.useTempDir { 51 | dir, err = os.MkdirTemp(os.TempDir(), "minibridge-") 52 | if err != nil { 53 | return nil, fmt.Errorf("unable to create tempdir: %w", err) 54 | } 55 | } 56 | 57 | cmd := exec.CommandContext(ctx, c.srv.Command, c.srv.Args...) // #nosec: G204 58 | cmd.Env = append(os.Environ(), c.srv.Env...) 59 | for i, e := range cmd.Env { 60 | cmd.Env[i] = strings.ReplaceAll(e, "_MINIBRIDGE_PREFIX_", dir) 61 | } 62 | 63 | cmd.Dir = dir 64 | cmd.Cancel = func() error { 65 | if err := cmd.Process.Signal(os.Interrupt); err != nil { 66 | return err 67 | } 68 | 69 | if c.cfg.useTempDir { 70 | return os.RemoveAll(dir) 71 | } 72 | 73 | return nil 74 | } 75 | 76 | setCaps(cmd, "", c.cfg.creds) 77 | 78 | slog.Debug("Client: starting command", 79 | "path", cmd.Path, 80 | "dir", cmd.Dir, 81 | "creds", c.cfg.creds, 82 | ) 83 | 84 | stdin, err := cmd.StdinPipe() 85 | if err != nil { 86 | return nil, fmt.Errorf("unable to create stdin pipe: %w", err) 87 | } 88 | 89 | stdout, err := cmd.StdoutPipe() 90 | if err != nil { 91 | return nil, fmt.Errorf("unable to create stdout pipe: %w", err) 92 | } 93 | 94 | stderr, err := cmd.StderrPipe() 95 | if err != nil { 96 | return nil, fmt.Errorf("unable to create stderr pipe: %w", err) 97 | } 98 | 99 | stream := NewMCPStream(ctx) 100 | 101 | go c.readRequests(ctx, stdin, stream.stdin) 102 | go c.readResponses(ctx, stdout, 
stream.stdout) 103 | go c.readErrors(ctx, stderr, stream.stderr) 104 | 105 | if err := cmd.Start(); err != nil { 106 | return nil, fmt.Errorf("unable to start command: %w", err) 107 | } 108 | 109 | go func() { stream.exit <- cmd.Wait() }() 110 | 111 | return stream, nil 112 | } 113 | 114 | func (c *stdioClient) readRequests(ctx context.Context, stdin io.WriteCloser, ch chan []byte) { 115 | 116 | for { 117 | select { 118 | case <-ctx.Done(): 119 | return 120 | case data := <-ch: 121 | if _, err := stdin.Write(append(sanitize.Data(data), '\n')); err != nil { 122 | slog.Error("Unable to write data to stdin", "err", err) 123 | return 124 | } 125 | } 126 | } 127 | } 128 | 129 | func (c *stdioClient) readResponses(ctx context.Context, stdout io.ReadCloser, ch chan []byte) { 130 | 131 | bstdout := bufio.NewReader(stdout) 132 | for { 133 | data, err := bstdout.ReadBytes('\n') 134 | if err != nil { 135 | if err != io.EOF { 136 | slog.Error("Unable to read response from stdout", "err", err) 137 | } 138 | return 139 | } 140 | select { 141 | case ch <- sanitize.Data(data): 142 | case <-ctx.Done(): 143 | return 144 | } 145 | } 146 | } 147 | 148 | func (c *stdioClient) readErrors(ctx context.Context, stderr io.ReadCloser, ch chan []byte) { 149 | 150 | bstderr := bufio.NewReader(stderr) 151 | for { 152 | data, err := bstderr.ReadBytes('\n') 153 | if err != nil { 154 | if err != io.EOF { 155 | slog.Error("Unable to read error response from stderr", "err", err) 156 | } 157 | return 158 | } 159 | select { 160 | case ch <- sanitize.Data(data): 161 | case <-ctx.Done(): 162 | return 163 | } 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /pkgs/policer/internal/rego/policer.go: -------------------------------------------------------------------------------- 1 | package rego 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "strings" 8 | "time" 9 | 10 | "github.com/go-viper/mapstructure/v2" 11 | 
"github.com/open-policy-agent/opa/v1/ast" 12 | "github.com/open-policy-agent/opa/v1/rego" 13 | "go.acuvity.ai/minibridge/pkgs/mcp" 14 | "go.acuvity.ai/minibridge/pkgs/policer/api" 15 | ) 16 | 17 | type Policer struct { 18 | queryAllow rego.PreparedEvalQuery 19 | queryReasons rego.PreparedEvalQuery 20 | queryMCP rego.PreparedEvalQuery 21 | } 22 | 23 | const RegoRuntimeEnvPrefix = "REGO_POLICY_RUNTIME_" 24 | 25 | // New returns a new Rego based Policer. 26 | func New(policy string) (*Policer, error) { 27 | 28 | comp, err := precompile(policy, "default") 29 | if err != nil { 30 | return nil, fmt.Errorf("unable to compile rego policy: %w", err) 31 | } 32 | 33 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 34 | defer cancel() 35 | 36 | rTerm := makeRegoRuntimeTerm() 37 | 38 | queryAllow, err := rego.New(rego.Compiler(comp), rego.Query("data.main.allow"), rego.Runtime(rTerm)).PrepareForEval(ctx) 39 | if err != nil { 40 | return nil, fmt.Errorf("unable to prepare rego deny query: %w", err) 41 | } 42 | 43 | queryReasons, err := rego.New(rego.Compiler(comp), rego.Query("reasons := data.main.reasons"), rego.Runtime(rTerm)).PrepareForEval(ctx) 44 | if err != nil { 45 | return nil, fmt.Errorf("unable to prepare rego deny query: %w", err) 46 | } 47 | 48 | queryMCP, err := rego.New(rego.Compiler(comp), rego.Query("mcp := data.main.mcp"), rego.Runtime(rTerm)).PrepareForEval(ctx) 49 | if err != nil { 50 | return nil, fmt.Errorf("unable to prepare rego mcp query: %w", err) 51 | } 52 | 53 | return &Policer{ 54 | queryAllow: queryAllow, 55 | queryReasons: queryReasons, 56 | queryMCP: queryMCP, 57 | }, nil 58 | } 59 | 60 | func (p *Policer) Type() string { return "rego" } 61 | 62 | func (p *Policer) Police(ctx context.Context, preq api.Request) (*mcp.Message, error) { 63 | 64 | res, err := p.queryAllow.Eval(ctx, rego.EvalInput(preq), rego.EvalPrintHook(printer{})) 65 | if err != nil { 66 | return nil, fmt.Errorf("unable to eval allow query: %w", err) 67 | } 
68 | 69 | if !res.Allowed() { 70 | 71 | res, err = p.queryReasons.Eval(ctx, rego.EvalInput(preq), rego.EvalPrintHook(printer{})) 72 | if err != nil { 73 | return nil, fmt.Errorf("unable to eval reasons query: %w", err) 74 | } 75 | 76 | reasons := []string{api.GenericDenyReason} 77 | 78 | if len(res) > 0 { 79 | bindings := res[0].Bindings 80 | breasons, _ := bindings["reasons"].([]any) 81 | 82 | if len(breasons) > 0 { 83 | reasons = make([]string, len(breasons)) 84 | for i, v := range breasons { 85 | reasons[i], _ = v.(string) 86 | } 87 | } 88 | } 89 | 90 | return nil, fmt.Errorf("%w: %s", api.ErrBlocked, strings.Join(reasons, ", ")) 91 | } 92 | 93 | res, err = p.queryMCP.Eval(ctx, rego.EvalInput(preq), rego.EvalPrintHook(printer{})) 94 | if err != nil { 95 | return nil, fmt.Errorf("unable to eval mcp query: %w", err) 96 | } 97 | 98 | if len(res) == 0 { 99 | return nil, nil 100 | } 101 | 102 | bindings := res[0].Bindings 103 | 104 | newmcp, ok := bindings["mcp"].(map[string]any) 105 | if !ok { 106 | return nil, fmt.Errorf("invalid binding: mcp must be an map[string]any, got %T", bindings["mcp"]) 107 | } 108 | 109 | mcall := &mcp.Message{} 110 | if err := mapstructure.Decode(newmcp, mcall); err != nil { 111 | return nil, fmt.Errorf("unable to decode rego mcp into valid MCP call: %w", err) 112 | } 113 | 114 | mcall.ID = preq.MCP.ID 115 | 116 | return mcall, nil 117 | } 118 | 119 | // makeRegoRuntimeTerm create a rego ast Term 120 | // to expose prefixed env var to the rego runtime. 
121 | func makeRegoRuntimeTerm() *ast.Term { 122 | 123 | env := ast.NewObject() 124 | 125 | for _, s := range os.Environ() { 126 | parts := strings.SplitN(s, "=", 2) 127 | if !strings.HasPrefix(parts[0], RegoRuntimeEnvPrefix) { 128 | continue 129 | } 130 | if len(parts) == 2 { 131 | env.Insert(ast.StringTerm(strings.ReplaceAll(parts[0], RegoRuntimeEnvPrefix, "")), ast.StringTerm(parts[1])) 132 | } 133 | } 134 | 135 | obj := ast.NewObject() 136 | obj.Insert(ast.StringTerm("env"), ast.NewTerm(env)) 137 | 138 | return ast.NewTerm(obj) 139 | } 140 | -------------------------------------------------------------------------------- /cli/internal/cmd/scan.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "os" 8 | "time" 9 | 10 | "github.com/spf13/cobra" 11 | "github.com/spf13/pflag" 12 | "github.com/spf13/viper" 13 | "go.acuvity.ai/minibridge/pkgs/backend/client" 14 | "go.acuvity.ai/minibridge/pkgs/scan" 15 | ) 16 | 17 | var fScan = pflag.NewFlagSet("scan", pflag.ExitOnError) 18 | 19 | func init() { 20 | 21 | initSharedFlagSet() 22 | 23 | fScan.Duration("timeout", 2*time.Minute, "maximum time to allow the scan to run.") 24 | fScan.Bool("exclude-resources", false, "exclude resources from scan") 25 | fScan.Bool("exclude-tools", false, "exclude tools from scan") 26 | fScan.Bool("exclude-prompts", false, "exclude prompts from scan") 27 | 28 | Scan.Flags().AddFlagSet(fScan) 29 | Scan.Flags().AddFlagSet(fMCP) 30 | Scan.Flags().AddFlagSet(fAgentAuth) 31 | } 32 | 33 | // Scan is the cobra command to run the server. 
34 | var Scan = &cobra.Command{ 35 | Use: "scan [dump|sbom|check file.sbom] -- command [args...]", 36 | Short: "Scan an MCP server for resources, prompts, etc and generate sbom", 37 | SilenceUsage: true, 38 | SilenceErrors: true, 39 | TraverseChildren: true, 40 | Args: cobra.MinimumNArgs(2), 41 | 42 | RunE: func(cmd *cobra.Command, args []string) error { 43 | 44 | timeout := viper.GetDuration("timeout") 45 | 46 | exclusions := &scan.Exclusions{ 47 | Prompts: viper.GetBool("exclude-prompts"), 48 | Resources: viper.GetBool("exclude-resources"), 49 | Tools: viper.GetBool("exclude-tools"), 50 | } 51 | 52 | var ctx context.Context 53 | var cancel context.CancelFunc 54 | 55 | if timeout > 0 { 56 | ctx, cancel = context.WithTimeout(cmd.Context(), timeout) 57 | } else { 58 | ctx, cancel = context.WithCancel(cmd.Context()) 59 | } 60 | defer cancel() 61 | 62 | var err error 63 | var mcpCommand string 64 | var mcpArgs []string 65 | if args[0] == "check" { 66 | mcpCommand = args[2] 67 | mcpArgs = args[3:] 68 | } else { 69 | mcpCommand = args[1] 70 | mcpArgs = args[2:] 71 | } 72 | 73 | mcpClient, err := makeMCPClient(append([]string{mcpCommand}, mcpArgs...), false) 74 | if err != nil { 75 | return err 76 | } 77 | 78 | agentAuth, err := makeAgentAuth(false) 79 | if err != nil { 80 | return fmt.Errorf("unable to build auth: %w", err) 81 | } 82 | 83 | stream, err := mcpClient.Start(ctx, client.OptionAuth(agentAuth)) 84 | if err != nil { 85 | return fmt.Errorf("unable to start MCP server: %w", err) 86 | } 87 | 88 | dump, err := scan.DumpAll(ctx, stream, exclusions) 89 | if err != nil { 90 | return fmt.Errorf("unable to dump tools: %w", err) 91 | } 92 | 93 | cancel() 94 | 95 | var toolHashes scan.Hashes 96 | 97 | if !exclusions.Tools { 98 | toolHashes, err = scan.HashTools(dump.Tools) 99 | if err != nil { 100 | return fmt.Errorf("unable to hash tools: %w", err) 101 | } 102 | } 103 | 104 | var promptHashes scan.Hashes 105 | 106 | if !exclusions.Prompts { 107 | promptHashes, err = 
scan.HashPrompts(dump.Prompts) 108 | if err != nil { 109 | return fmt.Errorf("unable to hash prompts: %w", err) 110 | } 111 | } 112 | 113 | sbom := scan.SBOM{ 114 | Tools: toolHashes, 115 | Prompts: promptHashes, 116 | } 117 | 118 | switch args[0] { 119 | 120 | case "check": 121 | 122 | refSBOM, err := scan.LoadSBOM(args[1]) 123 | if err != nil { 124 | return fmt.Errorf("unable to load sbom: %w", err) 125 | } 126 | 127 | if err := refSBOM.Tools.Matches(sbom.Tools); err != nil { 128 | return fmt.Errorf("tools sbom does not match: %w", err) 129 | } 130 | 131 | if err := refSBOM.Prompts.Matches(sbom.Prompts); err != nil { 132 | return fmt.Errorf("prompts sbom does not match: %w", err) 133 | } 134 | 135 | case "sbom": 136 | 137 | enc := json.NewEncoder(os.Stdout) 138 | enc.SetIndent("", " ") 139 | if err := enc.Encode(sbom); err != nil { 140 | return fmt.Errorf("unable to encode sbom: %w", err) 141 | } 142 | 143 | case "dump": 144 | 145 | enc := json.NewEncoder(os.Stdout) 146 | enc.SetIndent("", " ") 147 | if err := enc.Encode(dump); err != nil { 148 | return fmt.Errorf("unable to encode dump: %w", err) 149 | } 150 | 151 | default: 152 | return fmt.Errorf("first command must be either dump, sbom or check") 153 | } 154 | 155 | return nil 156 | }, 157 | } 158 | -------------------------------------------------------------------------------- /cli/internal/cmd/frontend.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | "log/slog" 6 | "strings" 7 | 8 | "github.com/spf13/cobra" 9 | "github.com/spf13/pflag" 10 | "github.com/spf13/viper" 11 | "go.acuvity.ai/minibridge/pkgs/frontend" 12 | ) 13 | 14 | var fFrontend = pflag.NewFlagSet("", pflag.ExitOnError) 15 | 16 | func init() { 17 | 18 | initSharedFlagSet() 19 | 20 | fFrontend.StringP("listen", "l", "", "listen address of the bridge for incoming connections. 
If this is unset, stdio is used.") 21 | fFrontend.StringP("backend", "A", "", "URL of the minibridge backend to connect to.") 22 | fFrontend.String("endpoint-mcp", "/mcp", "when using HTTP, sets the endpoint to send messages (proto 2025-03-26).") 23 | fFrontend.String("endpoint-messages", "/message", "when using HTTP, sets the endpoint to post messages.") 24 | fFrontend.String("endpoint-sse", "/sse", "when using HTTP, sets the endpoint to connect to the event stream.") 25 | fAgentAuth.BoolP("agent-auth-passthrough", "b", false, "Forwards incoming HTTP Authorization header to the minibridge backend as-is.") 26 | 27 | Frontend.Flags().AddFlagSet(fFrontend) 28 | Frontend.Flags().AddFlagSet(fTLSClient) 29 | Frontend.Flags().AddFlagSet(fTLSServer) 30 | Frontend.Flags().AddFlagSet(fHealth) 31 | Frontend.Flags().AddFlagSet(fProfiler) 32 | Frontend.Flags().AddFlagSet(fCORS) 33 | Frontend.Flags().AddFlagSet(fAgentAuth) 34 | } 35 | 36 | // Frontend is the cobra command to run the client. 37 | var Frontend = &cobra.Command{ 38 | Use: "frontend", 39 | Short: "Start a minibridge frontend to connect to a minibridge backend", 40 | SilenceUsage: true, 41 | SilenceErrors: true, 42 | TraverseChildren: true, 43 | 44 | RunE: func(cmd *cobra.Command, args []string) error { 45 | 46 | listen := viper.GetString("listen") 47 | backendURL := viper.GetString("backend") 48 | mcpEndpoint := viper.GetString("endpoint-mcp") 49 | sseEndpoint := viper.GetString("endpoint-sse") 50 | messageEndpoint := viper.GetString("endpoint-messages") 51 | agentAuthPassthrough := viper.GetBool("agent-auth-passthrough") 52 | 53 | if backendURL == "" { 54 | return fmt.Errorf("--backend must be set") 55 | } 56 | if !strings.HasPrefix(backendURL, "wss://") && !strings.HasPrefix(backendURL, "ws://") { 57 | return fmt.Errorf("--backend must use wss:// or ws:// scheme") 58 | } 59 | if !strings.HasSuffix(backendURL, "/ws") { 60 | backendURL = backendURL + "/ws" 61 | } 62 | 63 | agentAuth, err := makeAgentAuth(true) 64 | 
if err != nil { 65 | return fmt.Errorf("unable to build auth: %w", err) 66 | } 67 | 68 | clientTLSConfig, err := tlsConfigFromFlags(fTLSClient) 69 | if err != nil { 70 | return err 71 | } 72 | 73 | tracer, err := makeTracer(cmd.Context(), "backend") 74 | if err != nil { 75 | return fmt.Errorf("unable to configure tracer: %w", err) 76 | } 77 | 78 | corsPolicy := makeCORSPolicy() 79 | 80 | mm := startHealthServer(cmd.Context()) 81 | 82 | var mfrontend frontend.Frontend 83 | 84 | if listen != "" { 85 | 86 | serverTLSConfig, err := tlsConfigFromFlags(fTLSServer) 87 | if err != nil { 88 | return err 89 | } 90 | 91 | slog.Info("Minibridge frontend configured", 92 | "backend", backendURL, 93 | "mcp", mcpEndpoint, 94 | "sse", sseEndpoint, 95 | "messages", messageEndpoint, 96 | "mode", "http", 97 | "server-tls", serverTLSConfig != nil, 98 | "server-mtls", mtlsMode(serverTLSConfig), 99 | "client-tls", clientTLSConfig != nil, 100 | "listen", listen, 101 | ) 102 | 103 | mfrontend = frontend.NewHTTP(listen, backendURL, serverTLSConfig, clientTLSConfig, 104 | frontend.OptHTTPMCPEndpoint(mcpEndpoint), 105 | frontend.OptHTTPSSEEndpoint(sseEndpoint), 106 | frontend.OptHTTPMessageEndpoint(messageEndpoint), 107 | frontend.OptHTTPAgentTokenPassthrough(agentAuthPassthrough), 108 | frontend.OptHTTPCORSPolicy(corsPolicy), 109 | frontend.OptHTTPMetricsManager(mm), 110 | frontend.OptHTTPTracer(tracer), 111 | ) 112 | 113 | } else { 114 | 115 | slog.Info("Minibridge frontend configured", 116 | "backend", backendURL, 117 | "mode", "stdio", 118 | ) 119 | 120 | mfrontend = frontend.NewStdio(backendURL, clientTLSConfig, 121 | frontend.OptStdioTracer(tracer), 122 | frontend.OptStdioRetry(false), 123 | ) 124 | } 125 | 126 | return startFrontendWithOAuth(cmd.Context(), mfrontend, agentAuth) 127 | }, 128 | } 129 | -------------------------------------------------------------------------------- /pkgs/frontend/options.go: -------------------------------------------------------------------------------- 
1 | package frontend 2 | 3 | import ( 4 | "context" 5 | "net" 6 | 7 | "go.acuvity.ai/bahamut" 8 | "go.acuvity.ai/minibridge/pkgs/metrics" 9 | "go.opentelemetry.io/otel/trace" 10 | "go.opentelemetry.io/otel/trace/noop" 11 | ) 12 | 13 | type httpCfg struct { 14 | mcpEndpoint string 15 | sseEndpoint string 16 | messagesEndpoint string 17 | agentTokenPassthrough bool 18 | corsPolicy *bahamut.CORSPolicy 19 | metricsManager *metrics.Manager 20 | tracer trace.Tracer 21 | backendDialer func(ctx context.Context, network, addr string) (net.Conn, error) 22 | oauthEndpointRegister string 23 | oauthEndpointAuthorize string 24 | oauthEndpointToken string 25 | } 26 | 27 | func newHTTPCfg() httpCfg { 28 | return httpCfg{ 29 | mcpEndpoint: "/mcp", 30 | sseEndpoint: "/sse", 31 | messagesEndpoint: "/message", 32 | oauthEndpointRegister: "/register", 33 | oauthEndpointAuthorize: "/authorize", 34 | oauthEndpointToken: "/token", 35 | tracer: noop.NewTracerProvider().Tracer("noop"), 36 | } 37 | } 38 | 39 | // OptHTTP are options that can be given to NewSSE(). 40 | type OptHTTP func(*httpCfg) 41 | 42 | // OptHTTPMCPEndpoint sets the mcp endpoint (protocol 2025-03-26) 43 | // where agents can connect to the response stream. 44 | // Defaults to /mcp 45 | func OptHTTPMCPEndpoint(ep string) OptHTTP { 46 | return func(cfg *httpCfg) { 47 | cfg.mcpEndpoint = ep 48 | } 49 | } 50 | 51 | // OptHTTPSSEEndpoint sets the sse endpoint (protocol 2024-11-05) 52 | // where agents can connect to the response stream. 53 | // Defaults to /sse 54 | func OptHTTPSSEEndpoint(ep string) OptHTTP { 55 | return func(cfg *httpCfg) { 56 | cfg.sseEndpoint = ep 57 | } 58 | } 59 | 60 | // OptHTTPMessageEndpoint sets the message endpoint (protocol 2024-11-05) 61 | // where agents can post request. 
62 | // Defaults to /messages 63 | func OptHTTPMessageEndpoint(ep string) OptHTTP { 64 | return func(cfg *httpCfg) { 65 | cfg.messagesEndpoint = ep 66 | } 67 | } 68 | 69 | // OptHTTPCORSPolicy sets the bahamut.CORSPolicy to use for 70 | // connection originating from a webrowser. 71 | func OptHTTPCORSPolicy(policy *bahamut.CORSPolicy) OptHTTP { 72 | return func(cfg *httpCfg) { 73 | cfg.corsPolicy = policy 74 | } 75 | } 76 | 77 | // OptHTTPAgentTokenPassthrough decides if the HTTP request Authorization header 78 | // should be passed as-is to the minibridge backend. 79 | func OptHTTPAgentTokenPassthrough(passthrough bool) OptHTTP { 80 | return func(cfg *httpCfg) { 81 | cfg.agentTokenPassthrough = passthrough 82 | } 83 | } 84 | 85 | // OptHTTPMetricsManager sets the metric manager to use to collect 86 | // prometheus metrics. 87 | func OptHTTPMetricsManager(m *metrics.Manager) OptHTTP { 88 | return func(cfg *httpCfg) { 89 | cfg.metricsManager = m 90 | } 91 | } 92 | 93 | // OptHTTPTracer sets the otel trace.Tracer to use to trace requests 94 | func OptHTTPTracer(tracer trace.Tracer) OptHTTP { 95 | return func(cfg *httpCfg) { 96 | if tracer == nil { 97 | tracer = noop.NewTracerProvider().Tracer("noop") 98 | } 99 | cfg.tracer = tracer 100 | } 101 | } 102 | 103 | // OptHTTPBackendDialer sets the dialer to use to connect to the backend. 104 | func OptHTTPBackendDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) OptHTTP { 105 | return func(cfg *httpCfg) { 106 | cfg.backendDialer = dialer 107 | } 108 | } 109 | 110 | type stdioCfg struct { 111 | retry bool 112 | tracer trace.Tracer 113 | backendDialer func(ctx context.Context, network, addr string) (net.Conn, error) 114 | } 115 | 116 | func newStdioCfg() stdioCfg { 117 | return stdioCfg{ 118 | retry: true, 119 | tracer: noop.NewTracerProvider().Tracer("noop"), 120 | } 121 | } 122 | 123 | // OptStdio are options that can be given to NewStdio(). 
124 | type OptStdio func(*stdioCfg) 125 | 126 | // OptStdioRetry allows to control if the Stdio frontend 127 | // should retry or not after a wbesocket connection failure. 128 | func OptStdioRetry(retry bool) OptStdio { 129 | return func(cfg *stdioCfg) { 130 | cfg.retry = retry 131 | } 132 | } 133 | 134 | // OptStdioTracer sets the otel trace.Tracer to use to trace requests 135 | func OptStdioTracer(tracer trace.Tracer) OptStdio { 136 | return func(cfg *stdioCfg) { 137 | if tracer == nil { 138 | tracer = noop.NewTracerProvider().Tracer("noop") 139 | } 140 | cfg.tracer = tracer 141 | } 142 | } 143 | 144 | // OptStdioBackendDialer sets the dialer to use to connect to the backend. 145 | func OptStdioBackendDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) OptStdio { 146 | return func(cfg *stdioCfg) { 147 | cfg.backendDialer = dialer 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /pkgs/backend/ws_test.go: -------------------------------------------------------------------------------- 1 | package backend 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "testing" 8 | "time" 9 | 10 | . 
"github.com/smartystreets/goconvey/convey" 11 | "go.acuvity.ai/minibridge/pkgs/backend/client" 12 | "go.acuvity.ai/minibridge/pkgs/frontend" 13 | "go.acuvity.ai/minibridge/pkgs/policer" 14 | "go.acuvity.ai/wsc" 15 | ) 16 | 17 | func freePort() int { 18 | addr, err := net.ResolveTCPAddr("tcp", "localhost:0") 19 | if err != nil { 20 | panic(err) 21 | } 22 | 23 | l, err := net.ListenTCP("tcp", addr) 24 | if err != nil { 25 | panic(err) 26 | } 27 | defer func() { _ = l.Close() }() 28 | return l.Addr().(*net.TCPAddr).Port 29 | } 30 | 31 | func startBackend(ctx context.Context, opts ...Option) (wsc.Websocket, error) { 32 | 33 | backendListen := fmt.Sprintf("127.0.0.1:%d", freePort()) 34 | 35 | srv, err := client.NewMCPServer("cat") 36 | if err != nil { 37 | return nil, err 38 | } 39 | client := client.NewStdio(srv) 40 | 41 | backend := NewWebSocket(backendListen, nil, client, opts...) 42 | 43 | go func() { 44 | if err := backend.Start(ctx); err != nil { 45 | panic(err) 46 | } 47 | }() 48 | 49 | <-time.After(time.Second) // wait a bit.. 
gh workers are slow 50 | 51 | ws, err := frontend.Connect(ctx, nil, fmt.Sprintf("ws://%s/ws", backendListen), nil, frontend.AgentInfo{UserAgent: "go-test"}) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return ws, nil 57 | } 58 | 59 | func TestWS(t *testing.T) { 60 | 61 | Convey("Given a ws backend without policer or tls", t, func() { 62 | 63 | ctx, cancel := context.WithCancel(t.Context()) 64 | defer cancel() 65 | 66 | ws, err := startBackend(ctx) 67 | So(err, ShouldBeNil) 68 | 69 | echo := `{"hello": "world"}` 70 | ws.Write([]byte(echo)) 71 | 72 | var data []byte 73 | select { 74 | case data = <-ws.Read(): 75 | case <-time.After(time.Second): 76 | } 77 | 78 | So(string(data), ShouldEqual, echo) 79 | 80 | echo = `not-json` 81 | ws.Write([]byte(echo)) 82 | 83 | select { 84 | case data = <-ws.Read(): 85 | case <-time.After(time.Second): 86 | } 87 | 88 | So(string(data), ShouldEqual, `{"error":{"code":500,"message":"unable to decode application/json: json decode error [pos 4]: expecting ot-: got ull"},"jsonrpc":"2.0"}`) 89 | }) 90 | 91 | Convey("Given a ws backend with a rego policer that denies the call", t, func() { 92 | 93 | ctx, cancel := context.WithCancel(t.Context()) 94 | defer cancel() 95 | 96 | policer, err := policer.NewRego(`package main 97 | import rego.v1 98 | default allow := false 99 | reasons contains "you can't do that, Dave" 100 | `) 101 | So(err, ShouldBeNil) 102 | 103 | ws, err := startBackend(ctx, OptPolicer(policer)) 104 | So(err, ShouldBeNil) 105 | 106 | echo := `{"jsonrpc": "2.0", "id": 2}` 107 | ws.Write([]byte(echo)) 108 | 109 | var data []byte 110 | select { 111 | case data = <-ws.Read(): 112 | case <-time.After(time.Second): 113 | } 114 | 115 | So(string(data), ShouldEqual, `{"error":{"code":451,"message":"request blocked: you can't do that, Dave"},"id":2,"jsonrpc":"2.0"}`) 116 | }) 117 | 118 | Convey("Given a ws backend with a rego policer that allows the call without mutation", t, func() { 119 | 120 | ctx, cancel := 
context.WithCancel(t.Context()) 121 | defer cancel() 122 | 123 | policer, err := policer.NewRego(`package main 124 | import rego.v1 125 | default allow := true 126 | `) 127 | So(err, ShouldBeNil) 128 | 129 | ws, err := startBackend(ctx, OptPolicer(policer)) 130 | So(err, ShouldBeNil) 131 | 132 | echo := `{"jsonrpc": "2.0", "id": 1}` 133 | ws.Write([]byte(echo)) 134 | 135 | var data []byte 136 | select { 137 | case data = <-ws.Read(): 138 | case <-time.After(time.Second): 139 | } 140 | 141 | So(string(data), ShouldEqual, `{"jsonrpc": "2.0", "id": 1}`) 142 | }) 143 | 144 | Convey("Given a ws backend with a rego policer that allows the call with mutation", t, func() { 145 | 146 | ctx, cancel := context.WithCancel(t.Context()) 147 | defer cancel() 148 | 149 | policer, err := policer.NewRego(`package main 150 | import rego.v1 151 | default allow := true 152 | 153 | mcp := x if { 154 | x := json.patch(input.mcp, [{ 155 | "op": "replace", 156 | "path": "/result/hello", 157 | "value": "world" 158 | }, { 159 | "op": "replace", 160 | "path": "/id", 161 | "value": 2, 162 | }]) 163 | } 164 | `) 165 | So(err, ShouldBeNil) 166 | 167 | ws, err := startBackend(ctx, OptPolicer(policer)) 168 | So(err, ShouldBeNil) 169 | 170 | echo := `{"id": 1, "jsonrpc": "2.0", "result": {"hello": "monde"}}` 171 | ws.Write([]byte(echo)) 172 | 173 | var data []byte 174 | select { 175 | case data = <-ws.Read(): 176 | case <-time.After(time.Second): 177 | } 178 | 179 | So(string(data), ShouldEqual, `{"id":1,"jsonrpc":"2.0","result":{"hello":"world"}}`) 180 | }) 181 | } 182 | -------------------------------------------------------------------------------- /pkgs/backend/client/stdio_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | "time" 8 | 9 | . 
"github.com/smartystreets/goconvey/convey" 10 | ) 11 | 12 | func TestStdioClient(t *testing.T) { 13 | 14 | Convey("Type is correct", t, func() { 15 | cl := NewStdio(MCPServer{}) 16 | So(cl.Type(), ShouldEqual, "stdio") 17 | }) 18 | 19 | Convey("Given I have a client cat and I send lots of trailing \n", t, func() { 20 | 21 | srv := MCPServer{ 22 | Command: "cat", 23 | Env: []string{"A=A"}, 24 | } 25 | cl := NewStdio(srv) 26 | 27 | ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) 28 | defer cancel() 29 | 30 | stream, err := cl.Start(ctx) 31 | So(err, ShouldBeNil) 32 | So(stream, ShouldNotBeNil) 33 | 34 | out, unregister := stream.Stdout() 35 | defer unregister() 36 | 37 | stream.Stdin() <- []byte("hello world\r\n\n\n") 38 | 39 | data := <-out 40 | So(data, ShouldResemble, []byte("hello world")) 41 | }) 42 | 43 | Convey("Given I have a client with env", t, func() { 44 | 45 | srv := MCPServer{ 46 | Command: "sh", 47 | Args: []string{"-c", "echo $MTEST"}, 48 | Env: []string{"MTEST=HELLO"}, 49 | } 50 | cl := NewStdio(srv) 51 | 52 | ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) 53 | defer cancel() 54 | 55 | stream, err := cl.Start(ctx) 56 | So(err, ShouldBeNil) 57 | So(stream, ShouldNotBeNil) 58 | 59 | out, unregisterOut := stream.Stdout() 60 | defer unregisterOut() 61 | 62 | exit, unregisterExit := stream.Exit() 63 | defer unregisterExit() 64 | 65 | data := <-out 66 | So(string(data), ShouldEqual, "HELLO") 67 | 68 | err = <-exit 69 | So(err, ShouldBeNil) 70 | }) 71 | 72 | Convey("Given I have a client to which I give an invalid server", t, func() { 73 | 74 | srv := MCPServer{ 75 | Command: "dog", 76 | } 77 | cl := NewStdio(srv) 78 | 79 | ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) 80 | defer cancel() 81 | 82 | stream, err := cl.Start(ctx) 83 | So(err, ShouldNotBeNil) 84 | So(err.Error(), ShouldEqual, `unable to start command: exec: "dog": executable file not found in $PATH`) 85 | So(stream, ShouldBeNil) 86 | }) 87 | 88 | 
Convey("Given I have a client with a command that exits unexpectedly", t, func() { 89 | 90 | srv := MCPServer{ 91 | Command: "bash", 92 | Args: []string{"-c", "sleep 1 && exit 1"}, 93 | } 94 | cl := NewStdio(srv) 95 | 96 | ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) 97 | defer cancel() 98 | 99 | stream, err := cl.Start(ctx) 100 | So(err, ShouldBeNil) 101 | So(stream, ShouldNotBeNil) 102 | 103 | exit, unregister := stream.Exit() 104 | defer unregister() 105 | 106 | time.Sleep(1050 * time.Millisecond) 107 | 108 | err = <-exit 109 | So(err, ShouldNotBeNil) 110 | So(err.Error(), ShouldEqual, "exit status 1") 111 | }) 112 | 113 | Convey("Given I have a client that writes a file", t, func() { 114 | 115 | srv := MCPServer{ 116 | Command: "sh", 117 | Args: []string{"-c", "touch testfile"}, 118 | } 119 | cl := NewStdio(srv) 120 | 121 | ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) 122 | defer cancel() 123 | 124 | stream, err := cl.Start(ctx) 125 | So(err, ShouldBeNil) 126 | So(stream, ShouldNotBeNil) 127 | 128 | exit, unregister := stream.Exit() 129 | defer unregister() 130 | 131 | So(<-exit, ShouldBeNil) 132 | 133 | _, err = os.Stat("testfile") 134 | So(err, ShouldBeNil) 135 | _ = os.RemoveAll("testfile") 136 | }) 137 | 138 | Convey("Given I have a client that writes a file with tempdir", t, func() { 139 | 140 | srv := MCPServer{ 141 | Command: "sh", 142 | Args: []string{"-c", "touch testfile"}, 143 | } 144 | cl := NewStdio(srv, OptStdioUseTempDir(true)) 145 | 146 | ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) 147 | defer cancel() 148 | 149 | stream, err := cl.Start(ctx) 150 | So(err, ShouldBeNil) 151 | So(stream, ShouldNotBeNil) 152 | 153 | exit, unregister := stream.Exit() 154 | defer unregister() 155 | 156 | So(<-exit, ShouldBeNil) 157 | 158 | _, err = os.Stat("testfile") 159 | So(err.Error(), ShouldEqual, "stat testfile: no such file or directory") 160 | }) 161 | 162 | Convey("Given I have a running client and an 
expiring context", t, func() { 163 | 164 | srv := MCPServer{ 165 | Command: "cat", 166 | } 167 | cl := NewStdio(srv) 168 | 169 | ctx, cancel := context.WithCancel(t.Context()) 170 | defer cancel() 171 | 172 | stream, err := cl.Start(ctx) 173 | So(err, ShouldBeNil) 174 | So(stream, ShouldNotBeNil) 175 | 176 | exit, unregister := stream.Exit() 177 | defer unregister() 178 | 179 | time.Sleep(300 * time.Millisecond) 180 | cancel() 181 | 182 | err = <-exit 183 | So(err, ShouldNotBeNil) 184 | So(err.Error(), ShouldEqual, "signal: interrupt") 185 | }) 186 | } 187 | -------------------------------------------------------------------------------- /pkgs/frontend/internal/session/session.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "log/slog" 5 | "maps" 6 | "sync" 7 | "time" 8 | 9 | "go.acuvity.ai/wsc" 10 | ) 11 | 12 | const defaultDeadlineDuration = 10 * time.Second 13 | 14 | // A Session represents an active agent session. 15 | // 16 | // It contains the underlying websocket to communicate 17 | // with the correct MCP server instance running in minibridge backend. 18 | // 19 | // The session must be acquired through a session manager when used 20 | // then a channel must be registered to receive the data coming from the 21 | // backend. 22 | // 23 | // When not in use, the Session must be released though the session manager 24 | // and all hooks must be unregistered. 25 | // When the the count is at 1 (only the initial aqcuirement), the session will 26 | // linger for the duration of deadline, after which the websocket will be closed 27 | // and so the MCP server running in the backend will be temrinated. 
28 | type Session struct { 29 | closeCh chan error 30 | count int 31 | countLock sync.RWMutex 32 | nextDeadline time.Duration 33 | deadline time.Time 34 | deadlineLock sync.RWMutex 35 | h uint64 36 | hookLock sync.RWMutex 37 | hooks map[chan []byte]struct{} 38 | id string 39 | ws wsc.Websocket 40 | } 41 | 42 | // New retutns a new session backed by the given websocket. 43 | func New(ws wsc.Websocket, credsHash uint64, sid string) *Session { 44 | return newSession(ws, credsHash, sid, defaultDeadlineDuration) 45 | } 46 | 47 | func newSession(ws wsc.Websocket, credsHash uint64, sid string, deadline time.Duration) *Session { 48 | 49 | s := &Session{ 50 | ws: ws, 51 | h: credsHash, 52 | count: 1, 53 | id: sid, 54 | deadline: time.Now().Add(deadline), 55 | nextDeadline: deadline, 56 | closeCh: make(chan error), 57 | hooks: map[chan []byte]struct{}{}, 58 | } 59 | 60 | slog.Debug("session created", "sid", s.id, "c", s.count) 61 | s.start() 62 | 63 | return s 64 | } 65 | 66 | // ID returns the session ID. 67 | func (s *Session) ID() string { 68 | return s.id 69 | } 70 | 71 | // Write writes to the underlying websocket and 72 | // advances the deadline 73 | func (s *Session) Write(data []byte) { 74 | s.setDeadline(time.Now().Add(s.nextDeadline)) 75 | s.ws.Write(data) 76 | } 77 | 78 | // Done returns a channel that will receive 79 | // an error (or nil) when the websocket closes 80 | // for any reason. 81 | func (s *Session) Done() chan error { 82 | return s.closeCh 83 | } 84 | 85 | // ValidateHash validates the session hash. 
86 | func (s *Session) ValidateHash(h uint64) bool { 87 | return h == s.h 88 | } 89 | 90 | // Close closes the websocket 91 | func (s *Session) Close() { 92 | s.ws.Close(1001) 93 | } 94 | 95 | func (s *Session) acquire() { 96 | s.countLock.Lock() 97 | defer s.countLock.Unlock() 98 | 99 | s.count++ 100 | 101 | s.setDeadline(time.Now().Add(s.nextDeadline)) 102 | slog.Debug("session acquired", "sid", s.id, "c", s.count) 103 | } 104 | 105 | func (s *Session) release() bool { 106 | s.countLock.Lock() 107 | defer s.countLock.Unlock() 108 | 109 | s.count-- 110 | slog.Debug("session released", "sid", s.id, "c", s.count, "deleted", s.count <= 0) 111 | 112 | if s.count <= 0 { 113 | s.Close() 114 | return true 115 | } 116 | 117 | return false 118 | } 119 | 120 | func (s *Session) register(c chan []byte) { 121 | s.hookLock.Lock() 122 | defer s.hookLock.Unlock() 123 | 124 | s.hooks[c] = struct{}{} 125 | } 126 | 127 | func (s *Session) unregister(c chan []byte) { 128 | s.hookLock.Lock() 129 | defer s.hookLock.Unlock() 130 | 131 | delete(s.hooks, c) 132 | } 133 | 134 | func (s *Session) start() { 135 | 136 | go func() { 137 | 138 | ticker := time.NewTicker(time.Second) 139 | 140 | for { 141 | select { 142 | 143 | case data := <-s.ws.Read(): 144 | 145 | for c := range s.getHooks() { 146 | select { 147 | case c <- data: 148 | default: 149 | slog.Error("Session sent data to inactive hook") 150 | } 151 | } 152 | 153 | s.setDeadline(time.Now().Add(s.nextDeadline)) 154 | 155 | case err := <-s.ws.Done(): 156 | s.release() 157 | s.closeCh <- err 158 | return 159 | 160 | case <-ticker.C: 161 | if time.Now().After(s.getDeadline()) && s.getCount() <= 1 { 162 | slog.Debug("session terminated: deadline exceeded", "sid", s.ID()) 163 | s.release() 164 | } 165 | } 166 | } 167 | }() 168 | } 169 | 170 | func (s *Session) getDeadline() time.Time { 171 | s.deadlineLock.RLock() 172 | defer s.deadlineLock.RUnlock() 173 | return s.deadline 174 | } 175 | 176 | func (s *Session) setDeadline(deadline 
time.Time) { 177 | s.deadlineLock.Lock() 178 | defer s.deadlineLock.Unlock() 179 | s.deadline = deadline 180 | } 181 | 182 | func (s *Session) getCount() int { 183 | s.countLock.RLock() 184 | defer s.countLock.RUnlock() 185 | return s.count 186 | } 187 | 188 | func (s *Session) getHooks() map[chan []byte]struct{} { 189 | s.hookLock.RLock() 190 | defer s.hookLock.RUnlock() 191 | return maps.Clone(s.hooks) 192 | } 193 | -------------------------------------------------------------------------------- /pkgs/frontend/stdio.go: -------------------------------------------------------------------------------- 1 | package frontend 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "crypto/tls" 7 | "fmt" 8 | "log/slog" 9 | "net/http" 10 | "net/url" 11 | "os" 12 | "os/user" 13 | "time" 14 | 15 | "go.acuvity.ai/minibridge/pkgs/auth" 16 | "go.acuvity.ai/minibridge/pkgs/info" 17 | "go.acuvity.ai/minibridge/pkgs/internal/sanitize" 18 | ) 19 | 20 | var _ Frontend = (*stdioFrontend)(nil) 21 | 22 | type stdioFrontend struct { 23 | u *url.URL 24 | agentAuth *auth.Auth 25 | backendURL string 26 | cfg stdioCfg 27 | claims []string 28 | tlsClientConfig *tls.Config 29 | user string 30 | wsWrite chan []byte 31 | } 32 | 33 | // NewStdio returns a new *StdioProxy that will connect to the given 34 | // endpoint using the given tlsConfig. Agents can write request to stdin and read 35 | // responses from stdout. stderr contains the logs. 36 | // 37 | // A single session to the backend will be created and it will 38 | // reconnect in case of disconnection. 
39 | func NewStdio(backend string, tlsConfig *tls.Config, opts ...OptStdio) Frontend { 40 | 41 | cfg := newStdioCfg() 42 | for _, o := range opts { 43 | o(&cfg) 44 | } 45 | 46 | u, err := url.Parse(backend) 47 | if err != nil { 48 | panic(err) 49 | } 50 | 51 | return &stdioFrontend{ 52 | u: u, 53 | backendURL: backend, 54 | tlsClientConfig: tlsConfig, 55 | wsWrite: make(chan []byte), 56 | cfg: cfg, 57 | } 58 | } 59 | 60 | // Start starts the proxy. It will run until the given context is canceled or until 61 | // the server returns an error. 62 | func (p *stdioFrontend) Start(ctx context.Context, agentAuth *auth.Auth) error { 63 | 64 | p.agentAuth = agentAuth 65 | 66 | user, err := user.Current() 67 | if err != nil { 68 | return fmt.Errorf("unable to get current user: %w", err) 69 | } 70 | 71 | host, err := os.Hostname() 72 | if err != nil { 73 | return fmt.Errorf("unable to get current hostname: %w", err) 74 | } 75 | 76 | p.user = fmt.Sprintf("%s@%s", user.Username, host) 77 | p.claims = []string{ 78 | fmt.Sprintf("gid=%s", user.Gid), 79 | fmt.Sprintf("uid=%s", user.Uid), 80 | fmt.Sprintf("username=%s", user.Username), 81 | fmt.Sprintf("hostname=%s", host), 82 | "minibridge=stdio", 83 | } 84 | 85 | slog.Debug("Local machine user set", "user", p.user, "claims", p.claims) 86 | 87 | subctx, cancel := context.WithCancel(ctx) 88 | defer cancel() 89 | 90 | errCh := make(chan error, 2) 91 | 92 | go func() { errCh <- p.stdiopump(subctx) }() 93 | go func() { errCh <- p.wspump(subctx) }() 94 | 95 | return <-errCh 96 | } 97 | 98 | func (p *stdioFrontend) HTTPClient() *http.Client { 99 | return &http.Client{ 100 | Transport: &http.Transport{ 101 | TLSClientConfig: p.tlsClientConfig, 102 | DialContext: p.cfg.backendDialer, 103 | }, 104 | } 105 | } 106 | 107 | func (p *stdioFrontend) BackendURL() string { 108 | 109 | scheme := "http" 110 | if p.u.Scheme == "wss" || p.u.Scheme == "https" { 111 | scheme = "https" 112 | } 113 | 114 | return fmt.Sprintf("%s://%s", scheme, p.u.Host) 
115 | } 116 | 117 | func (p *stdioFrontend) BackendInfo() (info.Info, error) { 118 | return getBackendInfo(p) 119 | } 120 | 121 | func (p *stdioFrontend) wspump(ctx context.Context) error { 122 | 123 | var failures int 124 | 125 | for { 126 | 127 | select { 128 | 129 | case <-ctx.Done(): 130 | return nil 131 | 132 | default: 133 | session, err := Connect(ctx, p.cfg.backendDialer, p.backendURL, p.tlsClientConfig, AgentInfo{ 134 | Auth: p.agentAuth, 135 | UserAgent: "stdio", 136 | RemoteAddr: "local", 137 | }) 138 | if err != nil { 139 | 140 | if !p.cfg.retry { 141 | return err 142 | } 143 | 144 | if failures == 1 { 145 | slog.Error("Retrying...", err) 146 | } 147 | 148 | failures++ 149 | time.Sleep(2 * time.Second) 150 | 151 | continue 152 | } 153 | 154 | if failures > 0 { 155 | slog.Info("Connection restored", "attempts", failures) 156 | } 157 | 158 | failures = 0 159 | 160 | L: 161 | for { 162 | 163 | select { 164 | 165 | case buf := <-p.wsWrite: 166 | session.Write(sanitize.Data(buf)) 167 | 168 | case data := <-session.Read(): 169 | fmt.Println(string(sanitize.Data(data))) 170 | 171 | case err := <-session.Error(): 172 | failures++ 173 | slog.Error("Error from webscoket", err) 174 | 175 | case err := <-session.Done(): 176 | failures++ 177 | if !p.cfg.retry { 178 | return err 179 | } 180 | break L 181 | 182 | case <-ctx.Done(): 183 | session.Close(1000) 184 | return nil 185 | } 186 | } 187 | } 188 | } 189 | } 190 | 191 | func (p *stdioFrontend) stdiopump(ctx context.Context) error { 192 | 193 | stdin := bufio.NewReader(os.Stdin) 194 | 195 | for { 196 | select { 197 | 198 | default: 199 | 200 | buf, err := stdin.ReadBytes('\n') 201 | 202 | if err != nil { 203 | slog.Debug("unable to read stdin", err) 204 | return err 205 | } 206 | 207 | if len(buf) == 0 { 208 | continue 209 | } 210 | 211 | p.wsWrite <- sanitize.Data(buf) 212 | 213 | case <-ctx.Done(): 214 | return nil 215 | } 216 | } 217 | } 218 | 
-------------------------------------------------------------------------------- /cli/internal/cmd/aio.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log/slog" 7 | "net" 8 | "time" 9 | 10 | "github.com/spf13/cobra" 11 | "github.com/spf13/pflag" 12 | "github.com/spf13/viper" 13 | "go.acuvity.ai/minibridge/pkgs/backend" 14 | "go.acuvity.ai/minibridge/pkgs/frontend" 15 | "go.acuvity.ai/minibridge/pkgs/memconn" 16 | "golang.org/x/sync/errgroup" 17 | ) 18 | 19 | var fAIO = pflag.NewFlagSet("aio", pflag.ExitOnError) 20 | 21 | func init() { 22 | 23 | initSharedFlagSet() 24 | 25 | fAIO.StringP("listen", "l", "", "listen address of the bridge for incoming connections. If unset, stdio is used.") 26 | fAIO.String("endpoint-mcp", "/mcp", "when using HTTP, sets the endpoint to send messages (proto 2025-03-26).") 27 | fAIO.String("endpoint-messages", "/message", "when using HTTP, sets the endpoint to post messages (proto 2024-11-05).") 28 | fAIO.String("endpoint-sse", "/sse", "when using HTTP, sets the endpoint to connect to the event stream (proto 2024-11-05).") 29 | 30 | AIO.Flags().AddFlagSet(fAIO) 31 | AIO.Flags().AddFlagSet(fPolicer) 32 | AIO.Flags().AddFlagSet(fTLSServer) 33 | AIO.Flags().AddFlagSet(fHealth) 34 | AIO.Flags().AddFlagSet(fProfiler) 35 | AIO.Flags().AddFlagSet(fCORS) 36 | AIO.Flags().AddFlagSet(fAgentAuth) 37 | AIO.Flags().AddFlagSet(fSBOM) 38 | AIO.Flags().AddFlagSet(fMCP) 39 | } 40 | 41 | var AIO = &cobra.Command{ 42 | Use: "aio [flags] -- command [args...]", 43 | Short: "Start an all-in-one minibridge frontend and backend", 44 | Args: cobra.MinimumNArgs(1), 45 | SilenceUsage: true, 46 | SilenceErrors: true, 47 | TraverseChildren: true, 48 | 49 | RunE: func(cmd *cobra.Command, args []string) (err error) { 50 | 51 | ctx, cancel := context.WithCancel(cmd.Context()) 52 | defer cancel() 53 | 54 | listen := viper.GetString("listen") 55 | mcpEndpoint := 
viper.GetString("endpoint-mcp") 56 | sseEndpoint := viper.GetString("endpoint-sse") 57 | messageEndpoint := viper.GetString("endpoint-messages") 58 | 59 | agentAuth, err := makeAgentAuth(true) 60 | if err != nil { 61 | return fmt.Errorf("unable to build auth: %w", err) 62 | } 63 | 64 | policer, penforce, err := makePolicer() 65 | if err != nil { 66 | return fmt.Errorf("unable to make policer: %w", err) 67 | } 68 | 69 | sbom, err := makeSBOM() 70 | if err != nil { 71 | return fmt.Errorf("unable to make hashes: %w", err) 72 | } 73 | 74 | tracer, err := makeTracer(ctx, "aio") 75 | if err != nil { 76 | return fmt.Errorf("unable to configure tracer: %w", err) 77 | } 78 | 79 | corsPolicy := makeCORSPolicy() 80 | 81 | mcpClient, err := makeMCPClient(args, true) 82 | if err != nil { 83 | return fmt.Errorf("unable to create MCP client: %w", err) 84 | } 85 | 86 | mm := startHealthServer(ctx) 87 | 88 | listener := memconn.NewListener() 89 | defer func() { _ = listener.Close() }() 90 | 91 | var eg errgroup.Group 92 | 93 | var mbackend backend.Backend 94 | eg.Go(func() error { 95 | 96 | defer cancel() 97 | 98 | slog.Info("Minibridge backend configured") 99 | 100 | mbackend = backend.NewWebSocket("self", nil, mcpClient, 101 | backend.OptListener(listener), 102 | backend.OptPolicer(policer), 103 | backend.OptPolicerEnforce(penforce), 104 | backend.OptDumpStderrOnError(viper.GetString("log-format") != "json"), 105 | backend.OptSBOM(sbom), 106 | backend.OptMetricsManager(mm), 107 | backend.OptTracer(tracer), 108 | ) 109 | 110 | return mbackend.Start(ctx) 111 | }) 112 | 113 | eg.Go(func() error { 114 | 115 | defer cancel() 116 | 117 | var mfrontend frontend.Frontend 118 | 119 | frontendServerTLSConfig, err := tlsConfigFromFlags(fTLSServer) 120 | if err != nil { 121 | return err 122 | } 123 | 124 | dialer := func(ctx context.Context, network, addr string) (net.Conn, error) { 125 | return listener.DialContext(cmd.Context(), "127.0.0.1:443") 126 | } 127 | 128 | if listen != "" { 129 | 
130 | slog.Info("Minibridge frontend configured", 131 | "mcp", mcpEndpoint, 132 | "sse", sseEndpoint, 133 | "messages", messageEndpoint, 134 | "agent-token", agentAuth != nil, 135 | "mode", "http", 136 | "server-tls", frontendServerTLSConfig != nil, 137 | "server-mtls", mtlsMode(frontendServerTLSConfig), 138 | "listen", listen, 139 | ) 140 | 141 | mfrontend = frontend.NewHTTP(listen, "ws://self/ws", frontendServerTLSConfig, nil, 142 | frontend.OptHTTPBackendDialer(dialer), 143 | frontend.OptHTTPMCPEndpoint(mcpEndpoint), 144 | frontend.OptHTTPSSEEndpoint(sseEndpoint), 145 | frontend.OptHTTPMessageEndpoint(messageEndpoint), 146 | frontend.OptHTTPAgentTokenPassthrough(true), 147 | frontend.OptHTTPCORSPolicy(corsPolicy), 148 | frontend.OptHTTPMetricsManager(mm), 149 | frontend.OptHTTPTracer(tracer), 150 | ) 151 | } else { 152 | 153 | slog.Info("Minibridge frontend configured", 154 | "mode", "stdio", 155 | ) 156 | 157 | mfrontend = frontend.NewStdio("ws://self/ws", nil, 158 | frontend.OptStdioBackendDialer(dialer), 159 | frontend.OptStdioRetry(false), 160 | frontend.OptStdioTracer(tracer), 161 | ) 162 | } 163 | 164 | time.Sleep(300 * time.Millisecond) 165 | 166 | return startFrontendWithOAuth(ctx, mfrontend, agentAuth) 167 | }) 168 | 169 | return eg.Wait() 170 | }, 171 | } 172 | -------------------------------------------------------------------------------- /pkgs/scan/scan.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "context" 5 | "crypto/sha256" 6 | "fmt" 7 | "slices" 8 | "strings" 9 | 10 | "github.com/gofrs/uuid" 11 | "github.com/mitchellh/mapstructure" 12 | "go.acuvity.ai/minibridge/pkgs/backend/client" 13 | "go.acuvity.ai/minibridge/pkgs/mcp" 14 | ) 15 | 16 | type Dump struct { 17 | Tools mcp.Tools `json:"tools,omitempty"` 18 | Resources mcp.Resources `json:"resources,omitempty"` 19 | ResourceTemplates mcp.ResourceTemplates `json:"resourceTemplates,omitempty"` 20 | Prompts mcp.Prompts 
`json:"prompts,omitempty"` 21 | } 22 | 23 | // DumpAll dumps all the all available tools/resource/prompts from the given client.MCPStream. 24 | func DumpAll(ctx context.Context, stream *client.MCPStream, exclusions *Exclusions) (Dump, error) { 25 | 26 | if _, err := stream.SendRequest(ctx, mcp.NewInitMessage(mcp.ProtocolVersion20250326)); err != nil { 27 | return Dump{}, fmt.Errorf("unable to send mcp request: %w", err) 28 | } 29 | 30 | if err := stream.SendNotification(ctx, mcp.NewNotification("notifications/initialize")); err != nil { 31 | return Dump{}, fmt.Errorf("unable to send mcp initialized notif: %w", err) 32 | } 33 | 34 | dump := Dump{} 35 | 36 | // Tools 37 | if !exclusions.Tools { 38 | toolsReq := mcp.NewMessage(uuid.Must(uuid.NewV7()).String()) 39 | toolsReq.Method = "tools/list" 40 | resps, err := stream.SendPaginatedRequest(ctx, toolsReq) 41 | if err != nil { 42 | return Dump{}, fmt.Errorf("unable to send tools/list mcp request: %w", err) 43 | } 44 | 45 | for _, resp := range resps { 46 | 47 | if _, ok := resp.Result["tools"]; !ok { 48 | continue 49 | } 50 | 51 | tools := mcp.Tools{} 52 | if err := mapstructure.Decode(resp.Result["tools"], &tools); err != nil { 53 | return Dump{}, fmt.Errorf("unable to convert to tools: %w", err) 54 | } 55 | 56 | dump.Tools = append(dump.Tools, tools...) 
57 | } 58 | } 59 | 60 | // Resources 61 | if !exclusions.Resources { 62 | resourcesReq := mcp.NewMessage(uuid.Must(uuid.NewV7()).String()) 63 | resourcesReq.Method = "resources/list" 64 | resps, err := stream.SendPaginatedRequest(ctx, resourcesReq) 65 | if err != nil { 66 | return Dump{}, fmt.Errorf("unable to send resources/list mcp request: %w", err) 67 | } 68 | 69 | for _, resp := range resps { 70 | 71 | if _, ok := resp.Result["resources"]; !ok { 72 | continue 73 | } 74 | 75 | resources := mcp.Resources{} 76 | if err := mapstructure.Decode(resp.Result["resources"], &resources); err != nil { 77 | return Dump{}, fmt.Errorf("unable to convert to resources: %w", err) 78 | } 79 | 80 | dump.Resources = append(dump.Resources, resources...) 81 | } 82 | 83 | // Resources Templates 84 | resourcesTemplateReq := mcp.NewMessage(uuid.Must(uuid.NewV7()).String()) 85 | resourcesTemplateReq.Method = "resources/templates/list" 86 | resps, err = stream.SendPaginatedRequest(ctx, resourcesTemplateReq) 87 | if err != nil { 88 | return Dump{}, fmt.Errorf("unable to send resources/templates/list mcp request: %w", err) 89 | } 90 | 91 | for _, resp := range resps { 92 | 93 | if _, ok := resp.Result["resourceTemplates"]; !ok { 94 | continue 95 | } 96 | 97 | resourceTemplates := mcp.ResourceTemplates{} 98 | if err := mapstructure.Decode(resp.Result["resourceTemplates"], &resourceTemplates); err != nil { 99 | return Dump{}, fmt.Errorf("unable to convert to resources templates: %w", err) 100 | } 101 | 102 | dump.ResourceTemplates = append(dump.ResourceTemplates, resourceTemplates...) 
103 | } 104 | 105 | } 106 | 107 | // Prompts 108 | if !exclusions.Prompts { 109 | promptsReq := mcp.NewMessage(uuid.Must(uuid.NewV7()).String()) 110 | promptsReq.Method = "prompts/list" 111 | resps, err := stream.SendPaginatedRequest(ctx, promptsReq) 112 | if err != nil { 113 | return Dump{}, fmt.Errorf("unable to send prompts/list mcp request: %w", err) 114 | } 115 | 116 | for _, resp := range resps { 117 | 118 | if _, ok := resp.Result["prompts"]; !ok { 119 | continue 120 | } 121 | 122 | prompts := mcp.Prompts{} 123 | if err := mapstructure.Decode(resp.Result["prompts"], &prompts); err != nil { 124 | return Dump{}, fmt.Errorf("unable to convert to prompts: %w", err) 125 | } 126 | 127 | dump.Prompts = append(dump.Prompts, prompts...) 128 | } 129 | } 130 | return dump, nil 131 | } 132 | 133 | // HashTools will generate Hashes for the given api.Tools 134 | func HashTools(tools mcp.Tools) (Hashes, error) { 135 | 136 | hashes := []Hash{} 137 | for _, tool := range tools { 138 | 139 | h := Hash{ 140 | Name: tool.Name, 141 | Hash: fmt.Sprintf("%x", sha256.Sum256([]byte(tool.Description))), 142 | } 143 | 144 | for k, v := range tool.InputSchema { 145 | if k != "properties" { 146 | continue 147 | } 148 | for pk, pv := range v.(map[string]any) { 149 | 150 | pvv, ok := pv.(map[string]any) 151 | if !ok { 152 | continue 153 | } 154 | 155 | pdesc, ok := pvv["description"].(string) 156 | if !ok { 157 | continue 158 | } 159 | 160 | h.Params = append(h.Params, Hash{ 161 | Name: pk, 162 | Hash: fmt.Sprintf("%x", sha256.Sum256([]byte(pdesc))), 163 | }) 164 | } 165 | } 166 | 167 | slices.SortFunc(h.Params, func(a Hash, b Hash) int { 168 | return strings.Compare(a.Name, b.Name) 169 | }) 170 | 171 | hashes = append(hashes, h) 172 | } 173 | 174 | slices.SortFunc(hashes, func(a Hash, b Hash) int { 175 | return strings.Compare(a.Name, b.Name) 176 | }) 177 | 178 | return hashes, nil 179 | } 180 | 181 | // HashPrompt generate Hashes for the given api.Prompt 182 | func HashPrompts(prompts 
mcp.Prompts) (Hashes, error) { 183 | 184 | hashes := []Hash{} 185 | for _, tool := range prompts { 186 | 187 | h := Hash{ 188 | Name: tool.Name, 189 | Hash: fmt.Sprintf("%x", sha256.Sum256([]byte(tool.Description))), 190 | } 191 | 192 | for _, p := range tool.Arguments { 193 | 194 | h.Params = append(h.Params, Hash{ 195 | Name: p.Name, 196 | Hash: fmt.Sprintf("%x", sha256.Sum256([]byte(p.Description))), 197 | }) 198 | } 199 | 200 | slices.SortFunc(h.Params, func(a Hash, b Hash) int { 201 | return strings.Compare(a.Name, b.Name) 202 | }) 203 | 204 | hashes = append(hashes, h) 205 | } 206 | 207 | slices.SortFunc(hashes, func(a Hash, b Hash) int { 208 | return strings.Compare(a.Name, b.Name) 209 | }) 210 | 211 | return hashes, nil 212 | } 213 | -------------------------------------------------------------------------------- /pkgs/backend/client/sse.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "crypto/tls" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "log/slog" 12 | "net/http" 13 | "net/url" 14 | "strings" 15 | "time" 16 | 17 | "go.acuvity.ai/minibridge/pkgs/auth" 18 | "go.acuvity.ai/minibridge/pkgs/internal/sanitize" 19 | ) 20 | 21 | var _ Client = (*sseClient)(nil) 22 | var _ RemoteClient = (*sseClient)(nil) 23 | 24 | var ErrAuthRequired = errors.New("authorization required") 25 | 26 | type sseClient struct { 27 | u *url.URL 28 | endpoint string 29 | messageEndpoint string 30 | client *http.Client 31 | } 32 | 33 | func NewSSE(endpoint string, tlsConfig *tls.Config) Client { 34 | 35 | client := &http.Client{ 36 | Transport: &http.Transport{ 37 | TLSClientConfig: tlsConfig, 38 | ResponseHeaderTimeout: 10 * time.Second, 39 | }, 40 | } 41 | 42 | u, err := url.Parse(endpoint) 43 | if err != nil { 44 | panic(err) 45 | } 46 | 47 | return &sseClient{ 48 | endpoint: strings.TrimRight(endpoint, "/"), 49 | client: client, 50 | u: u, 51 | } 52 | } 53 | 54 | func (c 
*sseClient) Type() string { return "sse" } 55 | 56 | func (c *sseClient) Server() string { return c.BaseURL() } 57 | 58 | func (c *sseClient) BaseURL() string { return fmt.Sprintf("%s://%s", c.u.Scheme, c.u.Host) } 59 | 60 | func (c *sseClient) HTTPClient() *http.Client { return c.client } 61 | 62 | func (c *sseClient) MPCClient() Client { return c } 63 | 64 | func (c *sseClient) Start(ctx context.Context, opts ...Option) (pipe *MCPStream, err error) { 65 | 66 | cfg := cfg{} 67 | for _, o := range opts { 68 | o(&cfg) 69 | } 70 | 71 | sseEndpoint := fmt.Sprintf("%s/sse", c.endpoint) 72 | 73 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, sseEndpoint, nil) 74 | if err != nil { 75 | return nil, fmt.Errorf("unable to initiate request: %w", err) 76 | } 77 | req.Header.Set("Accept", "text/event-stream") 78 | if cfg.auth != nil { 79 | req.Header.Set("Authorization", cfg.auth.Encode()) 80 | } 81 | 82 | stream := NewMCPStream(ctx) 83 | out, unregisterOut := stream.Stdout() 84 | defer unregisterOut() 85 | exit, unregisterExit := stream.Exit() 86 | defer unregisterExit() 87 | 88 | // we don't close the body here (which makes the linter all triggered) 89 | // because it's a long running connection. 
however readResponse will 90 | // do it on exit 91 | resp, err := c.client.Do(req) // nolint 92 | if err != nil { 93 | return nil, fmt.Errorf("unable to send initial sse request (%s): %w", req.URL.String(), err) 94 | } 95 | 96 | if resp.StatusCode == http.StatusUnauthorized { 97 | return nil, ErrAuthRequired 98 | } 99 | 100 | if resp.StatusCode != http.StatusOK { 101 | return nil, fmt.Errorf("invalid response from sse initialization (%s): %s", req.URL.String(), resp.Status) 102 | } 103 | 104 | go c.readRequest(ctx, stream.stdin, stream.exit, cfg.auth) 105 | go c.readResponse(ctx, resp.Body, stream.stdout, stream.exit) 106 | 107 | // Get the first response to get the endpoint 108 | var data []byte 109 | 110 | L: 111 | for { 112 | select { 113 | case data = <-out: 114 | break L 115 | 116 | case err := <-exit: 117 | if !errors.Is(err, io.EOF) { 118 | return nil, fmt.Errorf("unable to process sse message: %w", err) 119 | } 120 | 121 | case <-time.After(time.Second): 122 | return nil, fmt.Errorf("did not receive /message endpoint in time: timeout") 123 | 124 | case <-ctx.Done(): 125 | return nil, fmt.Errorf("did not receive /message endpoint in time: %w", ctx.Err()) 126 | } 127 | } 128 | 129 | c.messageEndpoint = fmt.Sprintf( 130 | "%s/%s", 131 | c.endpoint, 132 | strings.TrimLeft(string(data), "/"), 133 | ) 134 | slog.Debug("SSE Client: message endpoint set", "endpoint", c.messageEndpoint) 135 | 136 | return stream, nil 137 | } 138 | 139 | func (c *sseClient) readRequest(ctx context.Context, ch chan []byte, exitCh chan error, auth *auth.Auth) { 140 | 141 | for { 142 | 143 | select { 144 | 145 | case <-ctx.Done(): 146 | return 147 | 148 | case data := <-ch: 149 | 150 | buf := bytes.NewBuffer(append(sanitize.Data(data), '\n', '\n')) 151 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.messageEndpoint, buf) 152 | if err != nil { 153 | exitCh <- fmt.Errorf("unable to make post request: %w", err) 154 | return 155 | } 156 | req.Header.Set("Content-Type", 
"application/json") 157 | req.Header.Set("Accept", "application/json") 158 | if auth != nil { 159 | req.Header.Set("Authorization", auth.Encode()) 160 | } 161 | 162 | resp, err := c.client.Do(req) 163 | if err != nil { 164 | exitCh <- fmt.Errorf("unable to send post request: %w", err) 165 | return 166 | } 167 | defer func() { _ = resp.Body.Close() }() 168 | 169 | if resp.StatusCode == http.StatusUnauthorized { 170 | exitCh <- ErrAuthRequired 171 | return 172 | } 173 | 174 | if resp.StatusCode != http.StatusAccepted { 175 | exitCh <- fmt.Errorf("invalid mcp server response status: %s", resp.Status) 176 | return 177 | } 178 | } 179 | } 180 | } 181 | 182 | func (c *sseClient) readResponse(ctx context.Context, r io.ReadCloser, ch chan []byte, exitCh chan error) { 183 | 184 | defer func() { _ = r.Close() }() 185 | 186 | scan := bufio.NewScanner(r) 187 | scan.Split(split) 188 | scan.Buffer(make([]byte, 1024), 5*1024*1024) 189 | 190 | for scan.Scan() { 191 | 192 | data := sanitize.Data(scan.Bytes()) 193 | 194 | parts := bytes.SplitN(data, []byte{'\n'}, 2) 195 | if len(parts) != 2 { 196 | exitCh <- fmt.Errorf("invalid sse message: %s", string(data)) 197 | return 198 | } 199 | 200 | data = bytes.TrimPrefix(parts[1], []byte("data: ")) 201 | 202 | select { 203 | case ch <- data: 204 | case <-ctx.Done(): 205 | return 206 | } 207 | } 208 | 209 | if err := scan.Err(); err != nil { 210 | exitCh <- fmt.Errorf("sse stream closed: %w", scan.Err()) 211 | } else { 212 | exitCh <- fmt.Errorf("sse stream closed: %w", io.EOF) 213 | } 214 | } 215 | 216 | func split(data []byte, atEOF bool) (int, []byte, error) { 217 | 218 | if atEOF && len(data) == 0 { 219 | return 0, nil, nil 220 | } 221 | 222 | if i, nlen := hasNewLine(data); i >= 0 { 223 | return i + nlen, data[0:i], nil 224 | } 225 | 226 | if atEOF { 227 | return len(data), data, nil 228 | } 229 | 230 | return 0, nil, nil 231 | } 232 | 233 | func hasNewLine(data []byte) (int, int) { 234 | 235 | crunix := bytes.Index(data, 
[]byte("\n\n")) 236 | crwin := bytes.Index(data, []byte("\r\n\r\n")) 237 | minPos := minPos(crunix, crwin) 238 | nlen := 2 239 | if minPos == crwin { 240 | nlen = 4 241 | } 242 | return minPos, nlen 243 | } 244 | 245 | func minPos(a, b int) int { 246 | if a < 0 { 247 | return b 248 | } 249 | if b < 0 { 250 | return a 251 | } 252 | if a > b { 253 | return b 254 | } 255 | return a 256 | } 257 | -------------------------------------------------------------------------------- /pkgs/metrics/manager.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "log/slog" 7 | "net" 8 | "net/http" 9 | "strconv" 10 | "time" 11 | 12 | "github.com/prometheus/client_golang/prometheus" 13 | "github.com/prometheus/client_golang/prometheus/promhttp" 14 | "go.acuvity.ai/minibridge/pkgs/policer/api" 15 | ) 16 | 17 | type Manager struct { 18 | reqDurationMetric *prometheus.HistogramVec 19 | reqTotalMetric *prometheus.CounterVec 20 | errorMetric *prometheus.CounterVec 21 | tcpConnTotalMetric prometheus.Counter 22 | tcpConnCurrentMetric prometheus.Gauge 23 | wsConnTotalMetric prometheus.Counter 24 | wsConnCurrentMetric prometheus.Gauge 25 | policerDurationMetric *prometheus.HistogramVec 26 | policerRequestTotalMetric *prometheus.CounterVec 27 | 28 | server *http.Server 29 | } 30 | 31 | func NewManager(listen string) *Manager { 32 | 33 | r := prometheus.DefaultRegisterer 34 | 35 | mc := &Manager{ 36 | 37 | reqTotalMetric: prometheus.NewCounterVec( 38 | prometheus.CounterOpts{ 39 | Name: "http_requests_total", 40 | Help: "The total number of requests.", 41 | }, 42 | []string{"method", "url", "code"}, 43 | ), 44 | reqDurationMetric: prometheus.NewHistogramVec( 45 | prometheus.HistogramOpts{ 46 | Name: "http_requests_duration_seconds", 47 | Help: "The average duration of the requests", 48 | Buckets: []float64{0.001, 0.0025, 0.005, 0.010, 0.025, 0.050, 0.100, 0.250, 0.500, 1.0, 2.5, 5.0, 10.0}, 49 | 
}, 50 | []string{"method", "url"}, 51 | ), 52 | tcpConnTotalMetric: prometheus.NewCounter( 53 | prometheus.CounterOpts{ 54 | Name: "tcp_connections_total", 55 | Help: "The total number of TCP connection.", 56 | }, 57 | ), 58 | tcpConnCurrentMetric: prometheus.NewGauge( 59 | prometheus.GaugeOpts{ 60 | Name: "tcp_connections_current", 61 | Help: "The current number of TCP connection.", 62 | }, 63 | ), 64 | wsConnTotalMetric: prometheus.NewCounter( 65 | prometheus.CounterOpts{ 66 | Name: "http_ws_connections_total", 67 | Help: "The total number of ws connection.", 68 | }, 69 | ), 70 | wsConnCurrentMetric: prometheus.NewGauge( 71 | prometheus.GaugeOpts{ 72 | Name: "http_ws_connections_current", 73 | Help: "The current number of ws connection.", 74 | }, 75 | ), 76 | errorMetric: prometheus.NewCounterVec( 77 | prometheus.CounterOpts{ 78 | Name: "http_errors_5xx_total", 79 | Help: "The total number of 5xx errors.", 80 | }, 81 | []string{"trace", "method", "url", "code"}, 82 | ), 83 | 84 | policerDurationMetric: prometheus.NewHistogramVec( 85 | prometheus.HistogramOpts{ 86 | Name: "policer_requests_duration_seconds", 87 | Help: "The average duration of the policing requests", 88 | Buckets: []float64{0.001, 0.0025, 0.005, 0.010, 0.025, 0.050, 0.100, 0.250, 0.500, 1.0, 2.5, 5.0, 10.0}, 89 | }, 90 | []string{"policer_type", "call_type"}, 91 | ), 92 | policerRequestTotalMetric: prometheus.NewCounterVec( 93 | prometheus.CounterOpts{ 94 | Name: "policer_request_total", 95 | Help: "The total number of policer requests.", 96 | }, 97 | []string{"policer_type", "call_type", "decision"}, 98 | ), 99 | } 100 | 101 | r.MustRegister(mc.tcpConnCurrentMetric) 102 | r.MustRegister(mc.tcpConnTotalMetric) 103 | r.MustRegister(mc.reqTotalMetric) 104 | r.MustRegister(mc.reqDurationMetric) 105 | r.MustRegister(mc.wsConnTotalMetric) 106 | r.MustRegister(mc.wsConnCurrentMetric) 107 | r.MustRegister(mc.errorMetric) 108 | r.MustRegister(mc.policerDurationMetric) 109 | 
r.MustRegister(mc.policerRequestTotalMetric) 110 | 111 | mc.server = &http.Server{ 112 | Addr: listen, 113 | ReadHeaderTimeout: time.Second, 114 | Handler: mc, 115 | } 116 | 117 | return mc 118 | } 119 | 120 | func (c *Manager) Start(ctx context.Context) error { 121 | 122 | errCh := make(chan error, 1) 123 | 124 | sctx, cancel := context.WithCancel(ctx) 125 | defer cancel() 126 | 127 | c.server.BaseContext = func(net.Listener) context.Context { return sctx } 128 | c.server.RegisterOnShutdown(func() { cancel() }) 129 | 130 | go func() { 131 | err := c.server.ListenAndServe() 132 | if err != nil { 133 | if !errors.Is(err, http.ErrServerClosed) { 134 | slog.Error("unable to start health server", "err", err) 135 | } 136 | } 137 | errCh <- err 138 | }() 139 | 140 | select { 141 | case <-sctx.Done(): 142 | case err := <-errCh: 143 | return err 144 | } 145 | 146 | stopCtx, stopCancel := context.WithTimeout(context.Background(), 10*time.Second) 147 | defer stopCancel() 148 | 149 | return c.server.Shutdown(stopCtx) 150 | } 151 | 152 | func (c *Manager) MeasureRequest(method string, path string) func(int) time.Duration { 153 | 154 | timer := prometheus.NewTimer( 155 | prometheus.ObserverFunc( 156 | func(v float64) { 157 | c.reqDurationMetric.With( 158 | prometheus.Labels{ 159 | "method": method, 160 | "url": path, 161 | }, 162 | ).Observe(v) 163 | }, 164 | ), 165 | ) 166 | 167 | return func(code int) time.Duration { 168 | 169 | c.reqTotalMetric.With(prometheus.Labels{ 170 | "method": method, 171 | "url": path, 172 | "code": strconv.Itoa(code), 173 | }).Inc() 174 | 175 | if code >= http.StatusInternalServerError { 176 | 177 | c.errorMetric.With(prometheus.Labels{ 178 | "method": method, 179 | "url": path, 180 | "code": strconv.Itoa(code), 181 | }).Inc() 182 | } 183 | 184 | return timer.ObserveDuration() 185 | } 186 | } 187 | 188 | func (c *Manager) MeasurePolicer(ptype string, rtype api.CallType) func(allow bool) time.Duration { 189 | 190 | timer := prometheus.NewTimer( 191 | 
prometheus.ObserverFunc( 192 | func(v float64) { 193 | c.policerDurationMetric.With( 194 | prometheus.Labels{ 195 | "policer_type": ptype, 196 | "call_type": string(rtype), 197 | }, 198 | ).Observe(v) 199 | }, 200 | ), 201 | ) 202 | 203 | return func(allow bool) time.Duration { 204 | c.policerRequestTotalMetric.With(prometheus.Labels{ 205 | "policer_type": ptype, 206 | "call_type": string(rtype), 207 | "decision": func() string { 208 | if allow { 209 | return "allow" 210 | } 211 | return "deny" 212 | }(), 213 | }).Inc() 214 | 215 | return timer.ObserveDuration() 216 | } 217 | } 218 | 219 | func (c *Manager) RegisterWSConnection() { 220 | c.wsConnTotalMetric.Inc() 221 | c.wsConnCurrentMetric.Inc() 222 | } 223 | 224 | func (c *Manager) UnregisterWSConnection() { 225 | c.wsConnCurrentMetric.Dec() 226 | } 227 | 228 | func (c *Manager) RegisterTCPConnection() { 229 | c.tcpConnTotalMetric.Inc() 230 | c.tcpConnCurrentMetric.Inc() 231 | } 232 | 233 | func (c *Manager) UnregisterTCPConnection() { 234 | c.tcpConnCurrentMetric.Dec() 235 | } 236 | 237 | func (c *Manager) ServeHTTP(w http.ResponseWriter, req *http.Request) { 238 | 239 | switch req.URL.Path { 240 | 241 | case "/": 242 | w.WriteHeader(http.StatusNoContent) 243 | 244 | case "/metrics": 245 | promhttp.Handler().ServeHTTP(w, req) 246 | 247 | default: 248 | http.Error(w, "Not found", http.StatusNotFound) 249 | } 250 | } 251 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Minibridge 2 | 3 | [![Build](https://github.com/acuvity/minibridge/actions/workflows/build.yaml/badge.svg?branch=main)](https://github.com/acuvity/minibridge/actions/workflows/build.yaml) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/acuvity/minibridge?cache)](https://goreportcard.com/report/github.com/acuvity/minibridge) 5 | 
[![GoDoc](https://pkg.go.dev/badge/github.com/acuvity/minibridge.svg)](https://pkg.go.dev/github.com/acuvity/minibridge) 6 | [![DockerHub](https://img.shields.io/badge/containers-dockerhub-blue.svg)](https://hub.docker.com/u/acuvity?page=1&search=mcp-server) 7 | 8 | Minibridge serves as a backend-to-frontend bridge, streamlining and securing 9 | communication between Agents and MCP servers. It safely exposes [MCP 10 | servers](https://modelcontextprotocol.io) to the internet and can optionally 11 | integrate with generic policing services — known as Policers — for agent 12 | authentication, content analysis, and transformation. Policers can be 13 | implemented remotely via HTTP or locally using [OPA 14 | Rego](https://www.openpolicyagent.org/docs/latest/policy-reference/) policies. 15 | 16 | Minibridge can help ensure the integrity of MCP servers through 17 | SBOM (Software Bill of Materials) generation and real-time validation. 18 | 19 | Additionally, Minibridge supports [OTEL](https://opentelemetry.io/) and can 20 | report/reattach spans from standard OTEL headers, as well as directly from the 21 | MCP call, as inserted by certain tools like [Open 22 | Inference](https://arize-ai.github.io/openinference). 23 | 24 | ![arch-overview](assets/imgs/mb-arch-overview.png) 25 | 26 | - **Minibridge Frontend**: The Client connects to the Frontend part of Minibridge. 27 | - **Minibridge Backend**: The Frontend connects to the Backend, which wraps the MCP server. 28 | - **Minibridge Policer**: The Policer runs in the Backend and can optionally make decisions on the input and output based on policies (locally with Rego or remotely over HTTPS). 29 | 30 | > [!TIP] 31 | > Conveniently, Minibridge can be started in an "all-in-one" (AIO) mode to act as a single process. 32 | 33 | ## Why use Minibridge?
34 | 35 | Minibridge covers the following: 36 | 37 | - **Secure Transport**: Use TLS with optional client certificate validation 38 | - **Integrity**: Ensure the MCP server cannot mutate tools, templates, etc. during execution 39 | - **User Authentication**: Transport the user information to the Policer 40 | - **Monitoring**: Expose Prometheus metrics 41 | - **Telemetry**: Report traces and spans using OpenTelemetry 42 | 43 | ## Installation 44 | 45 | Minibridge can be installed from various places: 46 | 47 | ### Homebrew 48 | 49 | On macOS, you can use Homebrew: 50 | 51 | ```console 52 | brew tap acuvity/tap 53 | brew install minibridge 54 | ``` 55 | 56 | ### AUR 57 | 58 | On Arch-based Linux distributions, you can run: 59 | 60 | ```console 61 | yay -S minibridge 62 | ``` 63 | 64 | Alternatively, to get the latest version from the main branch: 65 | 66 | ```console 67 | yay -S minibridge-git 68 | ``` 69 | 70 | ### Go 71 | 72 | If you have the Go toolchain installed: 73 | 74 | ```console 75 | go install go.acuvity.ai/minibridge@latest 76 | ``` 77 | 78 | Alternatively, to get the latest version from the main branch: 79 | 80 | ```console 81 | go install go.acuvity.ai/minibridge@main 82 | ``` 83 | 84 | ### Manually 85 | 86 | You can easily grab a binary version for your platform from the [release 87 | page](https://github.com/acuvity/minibridge/releases/tag/v0.6.2).
88 | 89 | 90 | ## Features comparisons 91 | 92 | | 🚀 **Feature** | 🔹 **MCP** | 🔸 **Minibridge** | 📦 **ARC (Acuvity Containers)** | 93 | | ------------------------------ | ----------- | ----------------- | ------------------------------- | 94 | | 🌐 Remote Access | ⚠️ | ✅ | ✅ | 95 | | 🔒 TLS Support | ❌ | ✅ | ✅ | 96 | | 📃 Tool integrity check | ❌ | ✅ | ✅ | 97 | | 📊 Visualization and Tracing | ❌ | ✅ | ✅ | 98 | | 🛡️ Isolation | ❌ | ⚠️ | ✅ | 99 | | 🔐 Security Policy Management | ❌ | 👤 | ⚠️ | 100 | | 🕵️ Secrets Redaction | ❌ | 👤 | ⚠️ | 101 | | 🔑 Authorization Controls | ❌ | 👤 | 👤 | 102 | | 🧑‍💻 PII Detection and Redaction | ❌ | 👤 | 👤 | 103 | | 📌 Version Pinning | ❌ | ❌ | ✅ | 104 | 105 | ✅ _Included_ | ⚠️ _Partial/Basic Support_ | 👤 _Custom User Implementation_ | ❌ _Not Supported_ 106 | 107 | ## Example: Configuring Minibridge in your MCP Client 108 | 109 | Suppose your client configuration originally specifies an MCP server like this: 110 | 111 | ```json 112 | { 113 | "mcpServers": { 114 | "fetch": { 115 | "command": "uvx", 116 | "args": ["mcp-server-fetch"] 117 | } 118 | } 119 | } 120 | ``` 121 | 122 | To route requests through Minibridge (enabling SBOM checks, policy enforcement, etc.), update the entry: 123 | 124 | ```json 125 | { 126 | "mcpServers": { 127 | "fetch": { 128 | "command": "minibridge", 129 | "args": ["aio", "--", "uvx", "mcp-server-fetch"] 130 | } 131 | } 132 | } 133 | ``` 134 | 135 | - **`minibridge aio`**: Invokes Minibridge in “all-in-one” mode, wrapping the downstream tool. 136 | - **`uvx mcp-server-fetch`**: The original MCP server command, now executed inside Minibridge. 137 | 138 | > [!TIP] 139 | > The location of the configuration files depends on your Client. 
For example, if you use Claude Desktop, configuration files are located: 140 | > 141 | > - macOS: `~/Library/Application Support/Claude/claude_desktop_config.json` 142 | > - Windows: `%APPDATA%\Claude\claude_desktop_config.json` 143 | > 144 | > See the official [MCP QuickStart for Claude Desktop Users](https://modelcontextprotocol.io/quickstart/user#2-add-the-filesystem-mcp-server) documentation. 145 | 146 | > [!IMPORTANT] 147 | > Your client must be able to resolve the path of the binary. 148 | > If you see an error like `MCP fetch: spawn minibridge ENOENT`, set the `command` parameter above to the full path of minibridge (`which minibridge` will give you the full path). 149 | 150 | ## Documentation 151 | 152 | Check out the complete [documentation](https://github.com/acuvity/minibridge/wiki) from the wiki pages. 153 | 154 | ## Contribute 155 | 156 | We are excited to welcome contributions from everyone! 🎉 Whether you're fixing bugs, enhancing features, improving documentation, or proposing entirely new ideas, your involvement helps strengthen the project and benefits the entire community. 157 | 158 | You do not need to sign a Contributor License Agreement (CLA) — just open a pull request and let's collaborate! 159 | 160 | ## Join us 161 | 162 | - [Discord](https://discord.gg/BkU7fBkrNk) 163 | - [LinkedIn](https://www.linkedin.com/company/acuvity) 164 | - [Bluesky](https://bsky.app/profile/acuvity.bsky.social) 165 | - [Docker](https://hub.docker.com/u/acuvity) 166 | -------------------------------------------------------------------------------- /pkgs/memconn/memconn.go: -------------------------------------------------------------------------------- 1 | // Package memconn provides an in-memory network connections. This allows 2 | // applications to connect to themselves without having to open up ports on the 3 | // network. 
4 | package memconn 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "errors" 10 | "fmt" 11 | "io" 12 | "net" 13 | "os" 14 | "sync" 15 | "time" 16 | ) 17 | // connBufferSize is the capacity, in bytes, of each Conn's receive buffer (filled by the peer's writes, drained by Read). 18 | const connBufferSize = 64 // 64 bytes, not 1KB 19 | 20 | // Conn is an in-memory connection. Every Conn has a remote peer representing 21 | // the other side of the connection. Writes to Conn will be sent to the peer's 22 | // buffer for reading. 23 | // 24 | // Conns use a connBufferSize-byte (64 B) buffer where writes can be sent immediately. Writes beyond 25 | // the buffer size will be blocked until they are read by the peer. 26 | type Conn struct { 27 | peerCh chan struct{} // Closed when peer is set 28 | peer *Conn 29 | 30 | cnd *sync.Cond 31 | buf bytes.Buffer 32 | readTimeout time.Time 33 | readTimeoutCancel context.CancelFunc 34 | writeTimeout time.Time 35 | writeTimeoutCancel context.CancelFunc 36 | closed bool 37 | name string 38 | } 39 | 40 | var _ net.Conn = (*Conn)(nil) 41 | // newConn returns an unattached Conn; Attach must set its peer before Write, Close or RemoteAddr are used, since those dereference c.peer. 42 | func newConn() *Conn { 43 | var mut sync.Mutex 44 | return &Conn{ 45 | peerCh: make(chan struct{}), 46 | cnd: sync.NewCond(&mut), 47 | name: "memory", 48 | } 49 | } 50 | 51 | // Attach sets the remote peer. Panics if called more than once. 52 | func (c *Conn) Attach(peer *Conn) { 53 | select { 54 | default: 55 | c.peer = peer 56 | close(c.peerCh) 57 | case <-c.peerCh: 58 | panic("peer already set") 59 | } 60 | } 61 | 62 | // WaitPeer waits for a peer to be set or until ctx is canceled.
63 | func (c *Conn) WaitPeer(ctx context.Context) error { 64 | select { 65 | case <-ctx.Done(): 66 | return ctx.Err() 67 | case <-c.peerCh: 68 | return nil 69 | } 70 | } 71 | // Read blocks until at least one byte is read, the read deadline passes (os.ErrDeadlineExceeded), or the conn is closed with nothing left to read (io.EOF). A zero-length b returns 0, nil immediately, as io.Reader permits, instead of spinning forever. 72 | func (c *Conn) Read(b []byte) (n int, err error) { 73 | for n == 0 && len(b) > 0 { 74 | n2, err := c.readOrBlock(b) 75 | if err != nil { 76 | return n2, err 77 | } 78 | n += n2 79 | } 80 | c.cnd.Signal() // Wake up calls to .Write 81 | return n, nil 82 | } 83 | // readOrBlock makes one locked read attempt. When nothing is buffered and the conn is open, it waits on the condition variable and returns (0, nil) so the caller retries. 84 | func (c *Conn) readOrBlock(b []byte) (int, error) { 85 | c.cnd.L.Lock() 86 | defer c.cnd.L.Unlock() 87 | 88 | if !c.readTimeout.IsZero() && !time.Now().Before(c.readTimeout) { 89 | return 0, os.ErrDeadlineExceeded 90 | } 91 | 92 | n, err := c.buf.Read(b) 93 | 94 | // We expect to get EOF from our buffer whenever there's no pending data. We 95 | // don't want to propagate the EOF to our reader until the conn itself is 96 | // closed. 97 | if errors.Is(err, io.EOF) { 98 | if c.closed { 99 | return n, err 100 | } 101 | 102 | // Wait until we're woken up by something, either because there's a timeout 103 | // or there's data to read. Spurious wakeups may happen, which would 104 | // eventually cause us to re-enter the wait. 105 | c.cnd.Wait() 106 | } 107 | return n, nil 108 | } 109 | // Write blocks until all of b is queued in the peer's buffer, or returns early with the byte count written so far and an error (closed conn or write deadline). 110 | func (c *Conn) Write(b []byte) (n int, err error) { 111 | for len(b) > 0 { 112 | n2, err := c.writeOrBlock(b) 113 | if err != nil { 114 | return n + n2, err 115 | } 116 | n += n2 117 | b = b[n2:] 118 | } 119 | return n, nil 120 | } 121 | // writeOrBlock checks that this side is writable, then makes one enqueue attempt into the peer's buffer. 122 | func (c *Conn) writeOrBlock(b []byte) (int, error) { 123 | if err := c.writeAvail(); err != nil { 124 | return 0, err 125 | } 126 | return c.peer.enqueueOrBlock(b) 127 | } 128 | 129 | // writeAvail returns nil when writing is available.
130 | func (c *Conn) writeAvail() error { 131 | c.cnd.L.Lock() 132 | defer c.cnd.L.Unlock() 133 | 134 | switch { 135 | case c.closed: 136 | return net.ErrClosed 137 | case !c.writeTimeout.IsZero() && !time.Now().Before(c.writeTimeout): 138 | return os.ErrDeadlineExceeded 139 | default: 140 | return nil 141 | } 142 | } 143 | 144 | // enqueueOrBlock is invoked by a peer and writes b into the local buffer. 145 | func (c *Conn) enqueueOrBlock(b []byte) (int, error) { 146 | c.cnd.L.Lock() 147 | defer c.cnd.L.Unlock() 148 | if c.closed { 149 | return 0, net.ErrClosed 150 | } 151 | 152 | // Try to write as much as possible. 153 | n := len(b) 154 | limit := connBufferSize - c.buf.Len() 155 | if limit < n { 156 | n = limit 157 | } 158 | 159 | if n == 0 { 160 | // Buffer is completely full; wait for it to free up. 161 | c.cnd.Wait() 162 | return 0, nil 163 | } 164 | 165 | c.buf.Write(b[:n]) 166 | c.cnd.Signal() // Signal that data can be read 167 | return n, nil 168 | } 169 | 170 | // Close closes both sides of the connection. 171 | func (c *Conn) Close() error { 172 | err := c.handleClose() 173 | _ = c.peer.handleClose() 174 | return err 175 | } 176 | // handleClose marks this side closed (a no-op if it already is), cancels any pending deadline timers and wakes all waiters on both sides. 177 | func (c *Conn) handleClose() error { 178 | c.cnd.L.Lock() 179 | defer c.cnd.L.Unlock() 180 | 181 | if c.closed { 182 | return nil 183 | } 184 | c.closed = true 185 | 186 | if c.readTimeoutCancel != nil { 187 | c.readTimeoutCancel() 188 | c.readTimeoutCancel = nil 189 | } 190 | if c.writeTimeoutCancel != nil { 191 | c.writeTimeoutCancel() 192 | c.writeTimeoutCancel = nil 193 | } 194 | c.broadcast() 195 | return nil 196 | } 197 | 198 | // broadcast will wake up all sleeping goroutines on the local and peer 199 | // connections. 200 | func (c *Conn) broadcast() { 201 | // We *MUST* wake up goroutines waiting on both the local and remote 202 | // condition variables because usage of a conn depends on both.
203 | c.cnd.Broadcast() 204 | c.peer.cnd.Broadcast() 205 | } 206 | // LocalAddr returns an Addr carrying this conn's name. 207 | func (c *Conn) LocalAddr() net.Addr { 208 | return Addr{ 209 | name: c.name, 210 | } 211 | } 212 | // RemoteAddr returns the peer's LocalAddr; the peer must already be attached. 213 | func (c *Conn) RemoteAddr() net.Addr { 214 | return c.peer.LocalAddr() 215 | } 216 | // SetDeadline sets both the read and write deadlines and returns the first error encountered. 217 | func (c *Conn) SetDeadline(t time.Time) error { 218 | var firstError error 219 | if err := c.SetReadDeadline(t); err != nil { 220 | firstError = err 221 | } 222 | if err := c.SetWriteDeadline(t); err != nil && firstError == nil { 223 | firstError = err 224 | } 225 | return firstError 226 | } 227 | // SetReadDeadline sets the read deadline; a zero t disables it. Once the conn is closed it fails with an error wrapping net.ErrClosed. 228 | func (c *Conn) SetReadDeadline(t time.Time) error { 229 | 230 | c.cnd.L.Lock() 231 | defer c.cnd.L.Unlock() 232 | if c.closed { 233 | return fmt.Errorf("conn closed: %w", net.ErrClosed) 234 | } 235 | 236 | c.readTimeout = t 237 | 238 | // There should only be one deadline goroutine at a time, so cancel it if it 239 | // already exists. 240 | if c.readTimeoutCancel != nil { 241 | c.readTimeoutCancel() 242 | c.readTimeoutCancel = nil 243 | } 244 | c.readTimeoutCancel = c.deadlineTimer(t) 245 | return nil 246 | } 247 | // deadlineTimer wakes all waiters on both sides once t has passed (immediately if it already has) and returns the cancel func of its timer goroutine, or nil for a zero t. 248 | func (c *Conn) deadlineTimer(t time.Time) context.CancelFunc { 249 | if t.IsZero() { 250 | // Deadline of zero means to wait forever. 251 | return nil 252 | } 253 | if t.Before(time.Now()) { 254 | c.broadcast() 255 | } 256 | ctx, cancel := context.WithDeadline(context.Background(), t) 257 | go func() { 258 | <-ctx.Done() 259 | if errors.Is(ctx.Err(), context.DeadlineExceeded) { 260 | c.broadcast() 261 | } 262 | }() 263 | return cancel 264 | } 265 | // SetWriteDeadline sets the write deadline; a zero t disables it. Once the conn is closed it fails with an error wrapping net.ErrClosed. 266 | func (c *Conn) SetWriteDeadline(t time.Time) error { 267 | c.cnd.L.Lock() 268 | defer c.cnd.L.Unlock() 269 | if c.closed { 270 | return fmt.Errorf("conn closed: %w", net.ErrClosed) 271 | } 272 | 273 | c.writeTimeout = t 274 | 275 | // There should only be one deadline goroutine at a time, so cancel it if it 276 | // already exists.
277 | if c.writeTimeoutCancel != nil { 278 | c.writeTimeoutCancel() 279 | c.writeTimeoutCancel = nil 280 | } 281 | c.writeTimeoutCancel = c.deadlineTimer(t) 282 | return nil 283 | } 284 | -------------------------------------------------------------------------------- /pkgs/oauth/dance.go: -------------------------------------------------------------------------------- 1 | package oauth 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "log/slog" 11 | "net/http" 12 | "net/url" 13 | "time" 14 | 15 | "github.com/gofrs/uuid" 16 | "github.com/pkg/browser" 17 | "go.acuvity.ai/minibridge/pkgs/info" 18 | ) 19 | // RegistrationRequest is the dynamic client registration body that Dance posts to /oauth2/register. 20 | type RegistrationRequest struct { 21 | RedirectURI []string `json:"redirect_uris"` 22 | ClientName string `json:"client_name"` 23 | ClientURI string `json:"client_uri,omitempty"` 24 | TokenEndpointMethod string `json:"token_endpoint_auth_method"` 25 | LogoURI string `json:"logo_uri,omitempty"` 26 | ResponseTypes []string `json:"response_types,omitempty"` 27 | GrantTypes []string `json:"grant_types,omitempty"` 28 | } 29 | // RegistrationResponse holds the part of the registration response that Dance decodes; only ClientID is currently used. 30 | type RegistrationResponse struct { 31 | ClientID string `json:"client_id,omitempty"` 32 | RegistrationClientID string `json:"registration_client_uri,omitempty"` 33 | ClientIDIssuedAt int `json:"client_id_issued_at,omitempty"` 34 | } 35 | // Credentials holds the client ID and the tokens obtained by Dance or Refresh. 36 | type Credentials struct { 37 | ClientID string `json:"client_id"` 38 | AccessToken string `json:"access_token"` 39 | RefreshToken string `json:"refresh_token"` 40 | } 41 | 42 | // Refresh exchanges the RefreshToken of the given Credentials for new tokens at /oauth2/token and keeps rt.ClientID. The info.Info argument is currently unused. 43 | func Refresh(ctx context.Context, backendURL string, cl *http.Client, _ info.Info, rt Credentials) (t Credentials, err error) { 44 | 45 | u := fmt.Sprintf("%s/oauth2/token", backendURL) 46 | 47 | form := url.Values{ 48 | "grant_type": {"refresh_token"}, 49 | "refresh_token": {rt.RefreshToken}, 50 | "client_id": {rt.ClientID}, 51 | } 52 | 53 | req, err :=
http.NewRequestWithContext(ctx, http.MethodPost, u, bytes.NewBuffer([]byte(form.Encode()))) 54 | if err != nil { 55 | return t, fmt.Errorf("unable to make token request: %w", err) 56 | } 57 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 58 | req.Header.Set("Accept", "application/json") 59 | 60 | resp, err := cl.Do(req) 61 | if err != nil { 62 | return t, fmt.Errorf("unable to send token request: %w", err) 63 | } 64 | defer func(r *http.Response) { _ = r.Body.Close() }(resp) 65 | 66 | data, err := io.ReadAll(resp.Body) 67 | if err != nil && !errors.Is(err, io.EOF) { 68 | return t, fmt.Errorf("unable to read token response body: %w", err) 69 | } 70 | 71 | if resp.StatusCode != http.StatusOK { 72 | return t, fmt.Errorf("invalid token response status code: %s (%s)", resp.Status, string(data)) 73 | } 74 | 75 | if err := json.Unmarshal(data, &t); err != nil { 76 | return t, fmt.Errorf("unable to unmarshal token response body: %w", err) 77 | } 78 | 79 | t.ClientID = rt.ClientID 80 | 81 | return t, nil 82 | } 83 | // Dance runs the OAuth authorization-code flow: it optionally registers a client (inf.OAuthRegister), opens the authorize URL in a browser, waits up to 10 minutes for a redirect carrying the matching state on a local callback server at 127.0.0.1:9977, then exchanges the code for Credentials at /oauth2/token. 84 | func Dance(ctx context.Context, backendURL string, cl *http.Client, inf info.Info) (t Credentials, err error) { 85 | 86 | redirectURI := "http://127.0.0.1:9977/callback" 87 | clientID := "" 88 | state := uuid.Must(uuid.NewV7()).String() 89 | 90 | if inf.OAuthRegister { 91 | 92 | oreq := RegistrationRequest{ 93 | ClientName: "minibridge", 94 | ClientURI: "https://github.com/acuvity/minibridge", 95 | TokenEndpointMethod: "none", 96 | RedirectURI: []string{redirectURI}, 97 | GrantTypes: []string{"authorization_code", "refresh_token"}, 98 | ResponseTypes: []string{"code"}, 99 | } 100 | 101 | data, err := json.MarshalIndent(oreq, "", " ") 102 | if err != nil { 103 | return t, fmt.Errorf("unable to marshal registration request: %w", err) 104 | } 105 | 106 | u := fmt.Sprintf("%s/oauth2/register", backendURL) 107 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, bytes.NewBuffer(data)) 108 | if err != nil { 109 | return t,
fmt.Errorf("unable to build registration request: %w", err) 110 | } 111 | req.Header.Set("Content-Type", "application/json") 112 | req.Header.Set("Accept", "application/json") 113 | 114 | resp, err := cl.Do(req) 115 | if err != nil { 116 | return t, fmt.Errorf("unable to send registration request: %w", err) 117 | } 118 | defer func(r *http.Response) { _ = r.Body.Close() }(resp) 119 | 120 | data, err = io.ReadAll(resp.Body) 121 | if err != nil && !errors.Is(err, io.EOF) { 122 | return t, fmt.Errorf("unable to read registration response body: %w", err) 123 | } 124 | 125 | if resp.StatusCode != http.StatusCreated { 126 | return t, fmt.Errorf("invalid registration response status code: %s (%s)", resp.Status, string(data)) 127 | } 128 | 129 | oresp := RegistrationResponse{} 130 | if err := json.Unmarshal(data, &oresp); err != nil { 131 | return t, fmt.Errorf("unable to unmarshal registration response body: %w", err) 132 | } 133 | 134 | clientID = oresp.ClientID 135 | } 136 | 137 | values := url.Values{ 138 | "response_type": {"code"}, 139 | "client_id": {clientID}, 140 | "redirect_uri": {redirectURI}, 141 | "state": {state}, 142 | } 143 | 144 | u := fmt.Sprintf("%s/authorize?%s", inf.Server, values.Encode()) 145 | if err := browser.OpenURL(u); err != nil { 146 | fmt.Println("Open the following URL in your browser:", u) 147 | } 148 | 149 | codeCh := make(chan string, 1) 150 | 151 | server := &http.Server{ 152 | ReadHeaderTimeout: 3 * time.Second, 153 | Addr: "127.0.0.1:9977", 154 | } 155 | 156 | server.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 157 | if req.URL.Query().Get("state") != state { http.Error(w, "invalid oauth state", http.StatusBadRequest); return } 158 | select { case codeCh <- req.URL.Query().Get("code"): default: } 159 | 160 | w.WriteHeader(http.StatusOK) 161 | _, _ = w.Write(fmt.Appendf([]byte{}, successBody, inf.Server)) 162 | 163 | sctx, cancel := context.WithTimeout(ctx, 1*time.Second) 164 | defer cancel() 165 | 166 | _ = server.Shutdown(sctx) 167 | }) 168 | 169 | go func() { 170 | if err := server.ListenAndServe(); err != nil { 171 | if !errors.Is(err,
http.ErrServerClosed) { 172 | slog.Error("Unable to start oauth callback server", "err", err) 173 | return 174 | } 175 | } 176 | }() 177 | var code string 178 | defer func() { if code == "" { _ = server.Close() } }() // release 127.0.0.1:9977 when no usable callback arrived 179 | select { 180 | case code = <-codeCh: 181 | case <-ctx.Done(): 182 | return t, ctx.Err() 183 | case <-time.After(10 * time.Minute): 184 | return t, fmt.Errorf("oauth timeout") 185 | } 186 | if code == "" { return t, fmt.Errorf("oauth callback did not include an authorization code") } 187 | u = fmt.Sprintf("%s/oauth2/token", backendURL) 188 | 189 | form := url.Values{ 190 | "grant_type": {"authorization_code"}, 191 | "client_id": {clientID}, 192 | "code": {code}, 193 | "redirect_uri": {redirectURI}, 194 | } 195 | 196 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, bytes.NewBuffer([]byte(form.Encode()))) 197 | if err != nil { 198 | return t, fmt.Errorf("unable to make token request: %w", err) 199 | } 200 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 201 | req.Header.Set("Accept", "application/json") 202 | 203 | resp, err := cl.Do(req) 204 | if err != nil { 205 | return t, fmt.Errorf("unable to send token request: %w", err) 206 | } 207 | defer func(r *http.Response) { _ = r.Body.Close() }(resp) 208 | 209 | data, err := io.ReadAll(resp.Body) 210 | if err != nil && !errors.Is(err, io.EOF) { 211 | return t, fmt.Errorf("unable to read token response body: %w", err) 212 | } 213 | 214 | if resp.StatusCode != http.StatusOK { 215 | return t, fmt.Errorf("invalid token response status code: %s (%s)", resp.Status, string(data)) 216 | } 217 | 218 | if err := json.Unmarshal(data, &t); err != nil { 219 | return t, fmt.Errorf("unable to unmarshal token response body: %w", err) 220 | } 221 | 222 | t.ClientID = clientID 223 | 224 | return t, nil 225 | } 226 | -------------------------------------------------------------------------------- /pkgs/backend/client/stream_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | "time" 8 | 9 | .
"github.com/smartystreets/goconvey/convey" 10 | "go.acuvity.ai/minibridge/pkgs/mcp" 11 | ) 12 | 13 | func TestMCPStream(t *testing.T) { 14 | 15 | Convey("I have a running stream, registrations should work", t, func() { 16 | ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) 17 | defer cancel() 18 | 19 | stream := NewMCPStream(ctx) 20 | 21 | stdout1, closeStdout1 := stream.Stdout() 22 | defer closeStdout1() 23 | stdout2, closeStdout2 := stream.Stdout() 24 | defer closeStdout2() 25 | 26 | stderr1, closeStderr1 := stream.Stderr() 27 | defer closeStderr1() 28 | stderr2, closeStderr2 := stream.Stderr() 29 | defer closeStderr2() 30 | 31 | exit1, closeExit1 := stream.Exit() 32 | defer closeExit1() 33 | exit2, closeExit2 := stream.Exit() 34 | defer closeExit2() 35 | 36 | cstdout1 := make(chan []byte) 37 | go func() { cstdout1 <- <-stdout1 }() 38 | cstdout2 := make(chan []byte) 39 | go func() { cstdout2 <- <-stdout2 }() 40 | cstderr1 := make(chan []byte) 41 | go func() { cstderr1 <- <-stderr1 }() 42 | cstderr2 := make(chan []byte) 43 | go func() { cstderr2 <- <-stderr2 }() 44 | cexit1 := make(chan error) 45 | go func() { cexit1 <- <-exit1 }() 46 | cexit2 := make(chan error) 47 | go func() { cexit2 <- <-exit2 }() 48 | 49 | go func() { 50 | stream.stdout <- []byte("hello stdout") 51 | stream.stderr <- []byte("hello stderr") 52 | stream.exit <- fmt.Errorf("hello from error") 53 | }() 54 | 55 | So(string(<-cstdout1), ShouldEqual, "hello stdout") 56 | So(string(<-cstdout2), ShouldEqual, "hello stdout") 57 | So(string(<-cstderr1), ShouldEqual, "hello stderr") 58 | So(string(<-cstderr2), ShouldEqual, "hello stderr") 59 | So((<-cexit1).Error(), ShouldEqual, "hello from error") 60 | So((<-cexit2).Error(), ShouldEqual, "hello from error") 61 | 62 | // let's unregister all 1s 63 | // this will test that all channel will 64 | // correctly unregister and will not block 65 | // the rest.
66 | closeStdout1() 67 | closeStderr1() 68 | closeExit1() 69 | 70 | go func() { cstdout2 <- <-stdout2 }() 71 | go func() { cstderr2 <- <-stderr2 }() 72 | go func() { cexit2 <- <-exit2 }() 73 | 74 | go func() { 75 | stream.stdout <- []byte("hello stdout 2") 76 | stream.stderr <- []byte("hello stderr 2") 77 | stream.exit <- fmt.Errorf("hello from error 2") 78 | }() 79 | 80 | So(string(<-cstdout2), ShouldEqual, "hello stdout 2") 81 | So(string(<-cstderr2), ShouldEqual, "hello stderr 2") 82 | So((<-cexit2).Error(), ShouldEqual, "hello from error 2") 83 | }) 84 | 85 | Convey("SendRequest should work", t, func() { 86 | 87 | stream := NewMCPStream(t.Context()) 88 | 89 | ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) 90 | defer cancel() 91 | 92 | done := make(chan bool, 1) 93 | cstdin := make(chan []byte, 2) 94 | go func() { 95 | cstdin <- <-stream.stdin 96 | stream.stdout <- []byte(`{"id":"not-id"}`) 97 | stream.stdout <- []byte(`{"id":"not-id-again-3"}`) 98 | stream.stdout <- []byte(`{"id":"id","result":{"a":1}}`) 99 | stream.stdout <- []byte(`{"id":"not-id-again-4"}`) 100 | done <- true 101 | }() 102 | 103 | resp, err := stream.SendRequest(ctx, mcp.NewMessage("id")) 104 | 105 | So(string(<-cstdin), ShouldResemble, `{"id":"id","jsonrpc":"2.0"}`) 106 | 107 | So(err, ShouldBeNil) 108 | So(resp, ShouldNotBeNil) 109 | So(resp.ID, ShouldEqual, "id") 110 | So(resp.Result["a"], ShouldEqual, 1) 111 | So(<-done, ShouldBeTrue) 112 | }) 113 | 114 | Convey("SendRequest should work when context cancels before sending data", t, func() { 115 | 116 | stream := NewMCPStream(t.Context()) 117 | 118 | ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) 119 | cancel() 120 | 121 | _, err := stream.SendRequest(ctx, mcp.NewMessage("id")) 122 | So(err, ShouldNotBeNil) 123 | So(err.Error(), ShouldEqual, "unable to send request: context canceled") 124 | }) 125 | 126 | Convey("SendRequest should work when context cancels while awaiting for response", t, func() { 127 | 128
| stream := NewMCPStream(t.Context()) 129 | 130 | ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) 131 | defer cancel() 132 | 133 | done := make(chan bool, 1) 134 | cstdin := make(chan []byte, 1) 135 | go func() { 136 | cstdin <- <-stream.stdin 137 | stream.stdout <- []byte(`{"id":"not-id"}`) 138 | stream.stdout <- []byte(`{"id":"not-id-again-3"}`) 139 | stream.stdout <- []byte(`{"id":"not-id-again-4"}`) 140 | cancel() 141 | done <- true 142 | }() 143 | 144 | _, err := stream.SendRequest(ctx, mcp.NewMessage("id")) 145 | 146 | So(string(<-cstdin), ShouldResemble, `{"id":"id","jsonrpc":"2.0"}`) 147 | So(err, ShouldNotBeNil) 148 | So(err.Error(), ShouldEqual, "unable to get response: context canceled") 149 | }) 150 | 151 | Convey("SendRequest should handle invalid json response for summary", t, func() { 152 | 153 | stream := NewMCPStream(t.Context()) 154 | 155 | ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) 156 | defer cancel() 157 | 158 | go func() { 159 | <-stream.stdin 160 | stream.stdout <- []byte(`"id":"not-id"}`) 161 | }() 162 | 163 | _, err := stream.SendRequest(ctx, mcp.NewMessage("id")) 164 | 165 | So(err, ShouldNotBeNil) 166 | So(err.Error(), ShouldStartWith, "unable to decode mcp call as summary: ") 167 | }) 168 | 169 | Convey("SendRequest should handle invalid json response for mcp call", t, func() { 170 | 171 | stream := NewMCPStream(t.Context()) 172 | 173 | ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) 174 | defer cancel() 175 | 176 | go func() { 177 | <-stream.stdin 178 | stream.stdout <- []byte(`{"id":"id", "result": "not a map"}`) 179 | }() 180 | 181 | _, err := stream.SendRequest(ctx, mcp.NewMessage("id")) 182 | 183 | So(err, ShouldNotBeNil) 184 | So(err.Error(), ShouldStartWith, "unable to decode mcp call: ") 185 | }) 186 | 187 | Convey("calling SendPaginatedRequest should work", t, func() { 188 | 189 | stream := NewMCPStream(t.Context()) 190 | 191 | ctx, cancel := context.WithTimeout(t.Context(),
2*time.Second) 192 | defer cancel() 193 | 194 | // enqueue the resp immediately 195 | cstdin := make(chan []byte, 3) 196 | go func() { 197 | cstdin <- <-stream.stdin 198 | stream.stdout <- []byte(`{"id":"a","jsonrpc":"2.0","result":{"nextCursor":"1"}}`) 199 | cstdin <- <-stream.stdin 200 | stream.stdout <- []byte(`{"id":"a","jsonrpc":"2.0"}`) 201 | }() 202 | 203 | calls, err := stream.SendPaginatedRequest(ctx, mcp.NewMessage("a")) 204 | So(err, ShouldBeNil) 205 | So(len(calls), ShouldEqual, 2) 206 | 207 | So(string(<-cstdin), ShouldEqual, `{"id":"a","jsonrpc":"2.0"}`) 208 | So(string(<-cstdin), ShouldEqual, `{"id":"a","jsonrpc":"2.0","params":{"cursor":"1"}}`) 209 | }) 210 | 211 | Convey("calling SendPaginatedRequest while context cancels should work", t, func() { 212 | 213 | ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) 214 | defer cancel() 215 | 216 | stream := NewMCPStream(ctx) 217 | 218 | sctx, cancel := context.WithCancel(t.Context()) 219 | cancel() 220 | 221 | // enqueue the resp immediately 222 | go func() { stream.stdout <- []byte(`{"id":46,"jsonrpc":"2.0","result":{"nextCursor":"1"}}`) }() 223 | 224 | calls, err := stream.SendPaginatedRequest(sctx, mcp.NewMessage(44)) 225 | So(err, ShouldNotBeNil) 226 | So(err.Error(), ShouldEqual, "unable to send paginated request: unable to send request: context canceled") 227 | So(len(calls), ShouldEqual, 0) 228 | }) 229 | } 230 | -------------------------------------------------------------------------------- /pkgs/backend/client/stream.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log/slog" 7 | "sync" 8 | "time" 9 | 10 | "go.acuvity.ai/elemental" 11 | "go.acuvity.ai/minibridge/pkgs/mcp" 12 | ) 13 | 14 | // MCPStream holds the MCP Server stdio streams as channels. 15 | // 16 | // It only deals with []byte containing MCP messages. The data 17 | // is not validated, in or out.
18 | // 19 | // It accepts input via the chan []byte returned by Stdin(). 20 | // 21 | // To access stdout, stderr or the exit channel, you must call 22 | // Stdout(), Stderr() or Exit() to get a channel you can 23 | // pull data from. 24 | // 25 | // The returned channel is registered in a pool of other subscriber 26 | // channels that will all receive a broadcast of the data when it arrives. 27 | // 28 | // These functions also return a func() that must be called to 29 | // unregister the channel from the pool. 30 | // Failure to do so will leak goroutines. 31 | // 32 | // The consumers of the channels must be mindful of the rest of the system. 33 | // All channel operations are blocking to avoid possibly delivering messages 34 | // out of order (responses before requests). So when you register a channel, be sure 35 | // to consume it as fast as possible. 36 | type MCPStream struct { 37 | stdin chan []byte 38 | stdout chan []byte 39 | stderr chan []byte 40 | exit chan error 41 | 42 | outChs map[chan []byte]struct{} 43 | errChs map[chan []byte]struct{} 44 | exitChs map[chan error]struct{} 45 | 46 | sync.RWMutex 47 | } 48 | 49 | // NewMCPStream returns an initialized *MCPStream. 50 | // It starts a listener in the background that runs 51 | // until the provided context is canceled. 52 | func NewMCPStream(ctx context.Context) *MCPStream { 53 | 54 | s := &MCPStream{ 55 | stderr: make(chan []byte), 56 | stdout: make(chan []byte), 57 | stdin: make(chan []byte), 58 | exit: make(chan error), 59 | outChs: map[chan []byte]struct{}{}, 60 | errChs: map[chan []byte]struct{}{}, 61 | exitChs: map[chan error]struct{}{}, 62 | } 63 | 64 | s.start(ctx) 65 | 66 | return s 67 | } 68 | 69 | // Stdin returns a channel that accepts []byte 70 | // containing MCP messages from the client.
71 | func (s *MCPStream) Stdin() chan []byte { 72 | return s.stdin 73 | } 74 | 75 | // Stdout returns a channel that will produce []byte 76 | // containing MCP messages from the MCP server. 77 | // It also returns a function that must be called 78 | // when the channel is not needed anymore. 79 | // Failure to do so will leak go routines. 80 | func (s *MCPStream) Stdout() (chan []byte, func()) { 81 | c := make(chan []byte, 8) 82 | s.registerOut(c) 83 | return c, func() { s.unregisterOut(c) } 84 | } 85 | 86 | // Stderr returns a channel that will produce []byte 87 | // containing MCP Server logs. 88 | // It also returns a function that must be called 89 | // when the channel is not needed anymore. 90 | // Failure to do so will leak go routines. 91 | func (s *MCPStream) Stderr() (chan []byte, func()) { 92 | c := make(chan []byte, 8) 93 | s.registerErr(c) 94 | return c, func() { s.unregisterErr(c) } 95 | } 96 | 97 | // Exit returns a channel that will produce an error 98 | // representing the end of the MCP server execution. 99 | // Once a message is received from this channel, 100 | // The MCPStream should be considered dead. 101 | // It also returns a function that must be called 102 | // when the channel is not needed anymore. 103 | // Failure to do so will leak go routines. 104 | func (s *MCPStream) Exit() (chan error, func()) { 105 | c := make(chan error, 1) 106 | s.registerExit(c) 107 | return c, func() { s.unregisterExit(c) } 108 | } 109 | 110 | // SendNotification sends a mcp.Message without waiting for a reply. 
111 | func (s *MCPStream) SendNotification(ctx context.Context, notif mcp.Notification) error { 112 | 113 | data, err := elemental.Encode(elemental.EncodingTypeJSON, notif) 114 | if err != nil { 115 | return fmt.Errorf("unable to encode mcp notification: %w", err) 116 | } 117 | 118 | select { 119 | case s.stdin <- data: 120 | case <-ctx.Done(): 121 | return fmt.Errorf("unable to send mcp notification: %w", ctx.Err()) 122 | } 123 | 124 | return nil 125 | } 126 | 127 | // SendRequest sends the given MCP request and waits for a MCP response related to the 128 | // request ID, up until the provider context expires. 129 | // The request is not validated. 130 | func (s *MCPStream) SendRequest(ctx context.Context, req mcp.Message) (resp mcp.Message, err error) { 131 | 132 | data, err := elemental.Encode(elemental.EncodingTypeJSON, req) 133 | if err != nil { 134 | return resp, fmt.Errorf("unable to encode mcp call: %w", err) 135 | } 136 | 137 | stdout, unregister := s.Stdout() 138 | defer unregister() 139 | 140 | select { 141 | case s.stdin <- data: 142 | case <-ctx.Done(): 143 | return resp, fmt.Errorf("unable to send request: %w", ctx.Err()) 144 | } 145 | 146 | summary := struct { 147 | ID any `json:"id"` 148 | }{} 149 | 150 | for { 151 | select { 152 | 153 | case <-ctx.Done(): 154 | return resp, fmt.Errorf("unable to get response: %w", ctx.Err()) 155 | 156 | case data := <-stdout: 157 | 158 | if err := elemental.Decode(elemental.EncodingTypeJSON, data, &summary); err != nil { 159 | return req, fmt.Errorf("unable to decode mcp call as summary: %w", err) 160 | } 161 | 162 | if !mcp.RelatedIDs(req.ID, summary.ID) { 163 | continue 164 | } 165 | 166 | if err := elemental.Decode(elemental.EncodingTypeJSON, data, &resp); err != nil { 167 | return req, fmt.Errorf("unable to decode mcp call: %w", err) 168 | } 169 | 170 | return resp, nil 171 | } 172 | } 173 | } 174 | 175 | // SendPaginatedRequest works like SendRequest, but will retrieve the next pages until 176 | // it reaches 
the end. All responses are returned at once in an slice of mcp.Message. 177 | func (s *MCPStream) SendPaginatedRequest(ctx context.Context, msg mcp.Message) (out []mcp.Message, err error) { 178 | 179 | var resp mcp.Message 180 | 181 | currentRequest := msg 182 | 183 | for { 184 | 185 | resp, err = s.SendRequest(ctx, currentRequest) 186 | if err != nil { 187 | return nil, fmt.Errorf("unable to send paginated request: %w", err) 188 | } 189 | 190 | out = append(out, resp) 191 | 192 | cursor, ok := resp.Result["nextCursor"].(string) 193 | if !ok || cursor == "" { 194 | return out, nil 195 | } 196 | 197 | currentRequest = mcp.NewMessage("") 198 | currentRequest.ID = msg.ID 199 | currentRequest.Method = msg.Method 200 | currentRequest.Params = map[string]any{"cursor": cursor} 201 | } 202 | } 203 | 204 | func (s *MCPStream) registerOut(ch chan []byte) { s.Lock(); s.outChs[ch] = struct{}{}; s.Unlock() } 205 | func (s *MCPStream) unregisterOut(ch chan []byte) { s.Lock(); delete(s.outChs, ch); s.Unlock() } 206 | 207 | func (s *MCPStream) registerErr(ch chan []byte) { s.Lock(); s.errChs[ch] = struct{}{}; s.Unlock() } 208 | func (s *MCPStream) unregisterErr(ch chan []byte) { s.Lock(); delete(s.errChs, ch); s.Unlock() } 209 | 210 | func (s *MCPStream) registerExit(ch chan error) { s.Lock(); s.exitChs[ch] = struct{}{}; s.Unlock() } 211 | func (s *MCPStream) unregisterExit(ch chan error) { s.Lock(); delete(s.exitChs, ch); s.Unlock() } 212 | 213 | func (s *MCPStream) start(ctx context.Context) { 214 | 215 | go func() { 216 | 217 | for { 218 | select { 219 | 220 | case data := <-s.stdout: 221 | 222 | s.RLock() 223 | for c := range s.outChs { 224 | select { 225 | case c <- data: 226 | default: 227 | slog.Error("stdout message dropped for registered channel") 228 | } 229 | } 230 | s.RUnlock() 231 | 232 | case data := <-s.stderr: 233 | 234 | s.RLock() 235 | for c := range s.errChs { 236 | select { 237 | case c <- data: 238 | default: 239 | slog.Error("stderr message dropped for 
registered channel") 240 | } 241 | } 242 | s.RUnlock() 243 | 244 | case err := <-s.exit: 245 | 246 | s.RLock() 247 | for c := range s.exitChs { 248 | select { 249 | case c <- err: 250 | default: 251 | slog.Error("exit message dropped for registered channel") 252 | } 253 | } 254 | s.RUnlock() 255 | 256 | case <-ctx.Done(): 257 | 258 | var err error 259 | select { 260 | case err = <-s.exit: 261 | case <-time.After(time.Second): 262 | break 263 | } 264 | 265 | if err == nil { 266 | err = ctx.Err() 267 | } 268 | 269 | s.RLock() 270 | for c := range s.exitChs { 271 | select { 272 | case c <- err: 273 | default: 274 | } 275 | } 276 | s.RUnlock() 277 | } 278 | } 279 | }() 280 | } 281 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module go.acuvity.ai/minibridge 2 | 3 | go 1.24.2 4 | 5 | require ( 6 | go.acuvity.ai/a3s v0.0.0-20250513140342-5386bdb1c75f 7 | go.acuvity.ai/bahamut v0.0.0-20250416135203-fa73110b2604 8 | go.acuvity.ai/elemental v0.0.0-20250430230636-ac931152934a 9 | go.acuvity.ai/manipulate v0.0.0-20250416135246-f7d22e975de4 // indirect 10 | go.acuvity.ai/regolithe v0.0.0-20250321141528-1fe83b60f317 // indirect 11 | go.acuvity.ai/tg v0.0.0-20250220234315-d9494083aa3a 12 | go.acuvity.ai/wsc v0.0.0-20250506232542-8de7ff436ec0 13 | ) 14 | 15 | require ( 16 | github.com/adrg/xdg v0.5.3 17 | github.com/go-viper/mapstructure/v2 v2.4.0 18 | github.com/gofrs/uuid v4.4.0+incompatible 19 | github.com/gorilla/websocket v1.5.3 20 | github.com/karlseguin/ccache/v3 v3.0.6 21 | github.com/mitchellh/mapstructure v1.5.0 22 | github.com/open-policy-agent/opa v1.4.0 23 | github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c 24 | github.com/prometheus/client_golang v1.22.0 25 | github.com/smallnest/ringbuffer v0.0.0-20250317021400-0da97b586904 26 | github.com/smartystreets/goconvey v1.8.1 27 | github.com/spaolacci/murmur3 v1.1.0 28 | 
github.com/spf13/cobra v1.9.1 29 | github.com/spf13/pflag v1.0.6 30 | github.com/spf13/viper v1.20.1 31 | github.com/zalando/go-keyring v0.2.6 32 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 33 | go.opentelemetry.io/otel v1.35.0 34 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.35.0 35 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.35.0 36 | go.opentelemetry.io/otel/sdk v1.35.0 37 | go.opentelemetry.io/otel/trace v1.35.0 38 | golang.org/x/sync v0.13.0 39 | ) 40 | 41 | require ( 42 | al.essio.dev/pkg/shellescape v1.5.1 // indirect 43 | cloud.google.com/go/compute/metadata v0.6.0 // indirect 44 | dario.cat/mergo v1.0.1 // indirect 45 | github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect 46 | github.com/Microsoft/go-winio v0.6.2 // indirect 47 | github.com/NYTimes/gziphandler v1.1.1 // indirect 48 | github.com/ProtonMail/go-crypto v1.1.5 // indirect 49 | github.com/agnivade/levenshtein v1.2.1 // indirect 50 | github.com/apparentlymart/go-cidr v1.1.0 // indirect 51 | github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de // indirect 52 | github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a // indirect 53 | github.com/beorn7/perks v1.0.1 // indirect 54 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect 55 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 56 | github.com/cloudflare/circl v1.6.1 // indirect 57 | github.com/cyphar/filepath-securejoin v0.3.6 // indirect 58 | github.com/danieljoos/wincred v1.2.2 // indirect 59 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 60 | github.com/deckarep/golang-set v1.8.0 // indirect 61 | github.com/emirpasic/gods v1.18.1 // indirect 62 | github.com/fatih/color v1.18.0 // indirect 63 | github.com/fatih/structs v1.1.0 // indirect 64 | github.com/felixge/httpsnoop v1.0.4 // indirect 65 | github.com/fsnotify/fsnotify v1.8.0 // indirect 66 | github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 // 
indirect 67 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect 68 | github.com/go-git/go-billy/v5 v5.6.1 // indirect 69 | github.com/go-git/go-git/v5 v5.13.1 // indirect 70 | github.com/go-ini/ini v1.67.0 // indirect 71 | github.com/go-logr/logr v1.4.2 // indirect 72 | github.com/go-logr/stdr v1.2.2 // indirect 73 | github.com/go-ole/go-ole v1.2.6 // indirect 74 | github.com/go-zoo/bone v1.3.0 // indirect 75 | github.com/gobwas/glob v0.2.3 // indirect 76 | github.com/godbus/dbus/v5 v5.1.0 // indirect 77 | github.com/golang-jwt/jwt/v5 v5.2.2 // indirect 78 | github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect 79 | github.com/google/go-tpm v0.9.3 // indirect 80 | github.com/google/uuid v1.6.0 // indirect 81 | github.com/gopherjs/gopherjs v1.17.2 // indirect 82 | github.com/gorilla/mux v1.8.1 // indirect 83 | github.com/gravitational/trace v1.5.0 // indirect 84 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect 85 | github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f // indirect 86 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 87 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect 88 | github.com/jtolds/gls v4.20.0+incompatible // indirect 89 | github.com/karlseguin/ccache/v2 v2.0.8 // indirect 90 | github.com/kevinburke/ssh_config v1.2.0 // indirect 91 | github.com/klauspost/compress v1.18.0 // indirect 92 | github.com/lmittmann/tint v1.0.3 // indirect 93 | github.com/lufia/plan9stats v0.0.0-20230110061619-bbe2e5e100de // indirect 94 | github.com/mailgun/multibuf v0.2.0 // indirect 95 | github.com/mattn/go-colorable v0.1.13 // indirect 96 | github.com/mattn/go-isatty v0.0.20 // indirect 97 | github.com/mdp/qrterminal v1.0.1 // indirect 98 | github.com/minio/highwayhash v1.0.3 // indirect 99 | github.com/mitchellh/copystructure v1.2.0 // indirect 100 | github.com/mitchellh/go-wordwrap v1.0.1 // indirect 101 | github.com/mitchellh/reflectwalk v1.0.2 // 
indirect 102 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 103 | github.com/nats-io/jwt/v2 v2.7.3 // indirect 104 | github.com/nats-io/nats-server/v2 v2.11.1 // indirect 105 | github.com/nats-io/nats.go v1.39.1 // indirect 106 | github.com/nats-io/nkeys v0.4.10 // indirect 107 | github.com/nats-io/nuid v1.0.1 // indirect 108 | github.com/opentracing/opentracing-go v1.2.0 // indirect 109 | github.com/pelletier/go-toml/v2 v2.2.3 // indirect 110 | github.com/pjbgf/sha1cd v0.3.2 // indirect 111 | github.com/pkg/errors v0.9.1 // indirect 112 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 113 | github.com/power-devops/perfstat v0.0.0-20221212215047-62379fc7944b // indirect 114 | github.com/prometheus/client_model v0.6.1 // indirect 115 | github.com/prometheus/common v0.62.0 // indirect 116 | github.com/prometheus/procfs v0.15.1 // indirect 117 | github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect 118 | github.com/sagikazarmark/locafero v0.7.0 // indirect 119 | github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect 120 | github.com/shirou/gopsutil/v3 v3.24.5 // indirect 121 | github.com/shoenig/go-m1cpu v0.1.6 // indirect 122 | github.com/sirupsen/logrus v1.9.3 // indirect 123 | github.com/skeema/knownhosts v1.3.0 // indirect 124 | github.com/smarty/assertions v1.15.0 // indirect 125 | github.com/sourcegraph/conc v0.3.0 // indirect 126 | github.com/spf13/afero v1.12.0 // indirect 127 | github.com/spf13/cast v1.7.1 // indirect 128 | github.com/subosito/gotenv v1.6.0 // indirect 129 | github.com/tchap/go-patricia/v2 v2.3.2 // indirect 130 | github.com/tklauser/go-sysconf v0.3.12 // indirect 131 | github.com/tklauser/numcpus v0.6.1 // indirect 132 | github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect 133 | github.com/uber/jaeger-lib v2.4.1+incompatible // indirect 134 | github.com/ugorji/go/codec v1.2.12 // indirect 135 | github.com/valyala/tcplisten 
v1.0.0 // indirect 136 | github.com/vulcand/oxy/v2 v2.0.2 // indirect 137 | github.com/vulcand/predicate v1.2.0 // indirect 138 | github.com/xanzy/ssh-agent v0.3.3 // indirect 139 | github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect 140 | github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect 141 | github.com/xeipuuv/gojsonschema v1.2.0 // indirect 142 | github.com/yashtewari/glob-intersection v0.2.0 // indirect 143 | github.com/yusufpapurcu/wmi v1.2.4 // indirect 144 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect 145 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.35.0 // indirect 146 | go.opentelemetry.io/otel/metric v1.35.0 // indirect 147 | go.opentelemetry.io/proto/otlp v1.5.0 // indirect 148 | go.uber.org/atomic v1.11.0 // indirect 149 | go.uber.org/automaxprocs v1.6.0 // indirect 150 | go.uber.org/multierr v1.11.0 // indirect 151 | golang.org/x/crypto v0.37.0 // indirect 152 | golang.org/x/net v0.39.0 // indirect 153 | golang.org/x/sys v0.32.0 // indirect 154 | golang.org/x/text v0.24.0 // indirect 155 | golang.org/x/time v0.11.0 // indirect 156 | google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect 157 | google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect 158 | google.golang.org/grpc v1.71.1 // indirect 159 | google.golang.org/protobuf v1.36.6 // indirect 160 | gopkg.in/ini.v1 v1.67.0 // indirect 161 | gopkg.in/warnings.v0 v0.1.2 // indirect 162 | gopkg.in/yaml.v2 v2.4.0 // indirect 163 | gopkg.in/yaml.v3 v3.0.1 // indirect 164 | rsc.io/qr v0.2.0 // indirect 165 | sigs.k8s.io/yaml v1.4.0 // indirect 166 | ) 167 | -------------------------------------------------------------------------------- /pkgs/scan/sbom_test.go: -------------------------------------------------------------------------------- 1 | package scan 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestSBOM_Matches(t *testing.T) { 8 | type args struct { 9 | 
o Hashes 10 | } 11 | tests := []struct { 12 | name string 13 | init func(t *testing.T) Hashes 14 | inspect func(r Hashes, t *testing.T) //inspects receiver after test run 15 | 16 | args func(t *testing.T) args 17 | 18 | wantErr bool 19 | inspectErr func(err error, t *testing.T) //use for more precise error evaluation after test 20 | }{ 21 | { 22 | "matching", 23 | func(t *testing.T) Hashes { 24 | return Hashes{ 25 | { 26 | Name: "a1", 27 | Hash: "ah1", 28 | Params: Hashes{ 29 | { 30 | Name: "p1", 31 | Hash: "ph1", 32 | }, 33 | }, 34 | }, 35 | } 36 | }, 37 | nil, 38 | func(*testing.T) args { 39 | return args{ 40 | Hashes{ 41 | { 42 | Name: "a1", 43 | Hash: "ah1", 44 | Params: Hashes{ 45 | { 46 | Name: "p1", 47 | Hash: "ph1", 48 | }, 49 | }, 50 | }, 51 | }, 52 | } 53 | }, 54 | false, 55 | nil, 56 | }, 57 | { 58 | "no params", 59 | func(t *testing.T) Hashes { 60 | return Hashes{ 61 | { 62 | Name: "a1", 63 | Hash: "ah1", 64 | }, 65 | } 66 | }, 67 | nil, 68 | func(*testing.T) args { 69 | return args{ 70 | Hashes{ 71 | { 72 | Name: "a1", 73 | Hash: "ah1", 74 | }, 75 | }, 76 | } 77 | }, 78 | false, 79 | nil, 80 | }, 81 | { 82 | "empty", 83 | func(t *testing.T) Hashes { 84 | return Hashes{} 85 | }, 86 | nil, 87 | func(*testing.T) args { 88 | return args{ 89 | Hashes{}, 90 | } 91 | }, 92 | false, 93 | nil, 94 | }, 95 | { 96 | "missing param", 97 | func(t *testing.T) Hashes { 98 | return Hashes{ 99 | { 100 | Name: "a1", 101 | Hash: "ah1", 102 | Params: Hashes{ 103 | { 104 | Name: "p1", 105 | Hash: "ph1", 106 | }, 107 | { 108 | Name: "p2", 109 | Hash: "ph2", 110 | }, 111 | }, 112 | }, 113 | } 114 | }, 115 | nil, 116 | func(*testing.T) args { 117 | return args{ 118 | Hashes{ 119 | { 120 | Name: "a1", 121 | Hash: "ah1", 122 | Params: Hashes{ 123 | { 124 | Name: "p1", 125 | Hash: "ph1", 126 | }, 127 | }, 128 | }, 129 | }, 130 | } 131 | }, 132 | false, 133 | nil, 134 | }, 135 | { 136 | "extra param", 137 | func(t *testing.T) Hashes { 138 | return Hashes{ 139 | { 140 | Name: "a1", 
141 | Hash: "ah1", 142 | Params: Hashes{ 143 | { 144 | Name: "p1", 145 | Hash: "ph1", 146 | }, 147 | }, 148 | }, 149 | } 150 | }, 151 | nil, 152 | func(*testing.T) args { 153 | return args{ 154 | Hashes{ 155 | { 156 | Name: "a1", 157 | Hash: "ah1", 158 | Params: Hashes{ 159 | { 160 | Name: "p1", 161 | Hash: "ph1", 162 | }, 163 | { 164 | Name: "p2", 165 | Hash: "ph2", 166 | }, 167 | }, 168 | }, 169 | }, 170 | } 171 | }, 172 | true, 173 | func(err error, t *testing.T) { 174 | want := "'a1': invalid param: invalid len. left: 1 right: 2" 175 | if err.Error() != want { 176 | t.Logf("invalid err. want: %s got: %s", want, err.Error()) 177 | t.Fail() 178 | } 179 | }, 180 | }, 181 | { 182 | "missing tool", 183 | func(t *testing.T) Hashes { 184 | return Hashes{ 185 | { 186 | Name: "a1", 187 | Hash: "ah1", 188 | Params: Hashes{ 189 | { 190 | Name: "p1", 191 | Hash: "ph1", 192 | }, 193 | }, 194 | }, 195 | { 196 | Name: "a2", 197 | Hash: "ah2", 198 | Params: Hashes{ 199 | { 200 | Name: "p1", 201 | Hash: "ph1", 202 | }, 203 | }, 204 | }, 205 | } 206 | }, 207 | nil, 208 | func(*testing.T) args { 209 | return args{ 210 | Hashes{ 211 | { 212 | Name: "a1", 213 | Hash: "ah1", 214 | Params: Hashes{ 215 | { 216 | Name: "p1", 217 | Hash: "ph1", 218 | }, 219 | }, 220 | }, 221 | }, 222 | } 223 | }, 224 | false, 225 | nil, 226 | }, 227 | { 228 | "extra tool", 229 | func(t *testing.T) Hashes { 230 | return Hashes{ 231 | { 232 | Name: "a1", 233 | Hash: "ah1", 234 | Params: Hashes{ 235 | { 236 | Name: "p1", 237 | Hash: "ph1", 238 | }, 239 | }, 240 | }, 241 | } 242 | }, 243 | nil, 244 | func(*testing.T) args { 245 | return args{ 246 | Hashes{ 247 | { 248 | Name: "a1", 249 | Hash: "ah1", 250 | Params: Hashes{ 251 | { 252 | Name: "p1", 253 | Hash: "ph1", 254 | }, 255 | }, 256 | }, 257 | { 258 | Name: "a2", 259 | Hash: "ah2", 260 | Params: Hashes{ 261 | { 262 | Name: "p1", 263 | Hash: "ph1", 264 | }, 265 | }, 266 | }, 267 | }, 268 | } 269 | }, 270 | true, 271 | func(err error, t *testing.T) { 272 
| want := "invalid len. left: 1 right: 2" 273 | if err.Error() != want { 274 | t.Logf("invalid err. want: %s got: %s", want, err.Error()) 275 | t.Fail() 276 | } 277 | }, 278 | }, 279 | { 280 | "invalid tool hash", 281 | func(t *testing.T) Hashes { 282 | return Hashes{ 283 | { 284 | Name: "a1", 285 | Hash: "ah1", 286 | Params: Hashes{ 287 | { 288 | Name: "p1", 289 | Hash: "ph1", 290 | }, 291 | }, 292 | }, 293 | } 294 | }, 295 | nil, 296 | func(*testing.T) args { 297 | return args{ 298 | Hashes{ 299 | { 300 | Name: "a1", 301 | Hash: "NOT_ah1", 302 | Params: Hashes{ 303 | { 304 | Name: "p1", 305 | Hash: "ph1", 306 | }, 307 | }, 308 | }, 309 | }, 310 | } 311 | }, 312 | true, 313 | func(err error, t *testing.T) { 314 | want := "'a1': hash mismatch" 315 | if err.Error() != want { 316 | t.Logf("invalid err. want: %s got: %s", want, err.Error()) 317 | t.Fail() 318 | } 319 | }, 320 | }, 321 | { 322 | "invalid param hash", 323 | func(t *testing.T) Hashes { 324 | return Hashes{ 325 | { 326 | Name: "a1", 327 | Hash: "ah1", 328 | Params: Hashes{ 329 | { 330 | Name: "p1", 331 | Hash: "ph1", 332 | }, 333 | }, 334 | }, 335 | } 336 | }, 337 | nil, 338 | func(*testing.T) args { 339 | return args{ 340 | Hashes{ 341 | { 342 | Name: "a1", 343 | Hash: "ah1", 344 | Params: Hashes{ 345 | { 346 | Name: "p1", 347 | Hash: "NOT-ph1", 348 | }, 349 | }, 350 | }, 351 | }, 352 | } 353 | }, 354 | true, 355 | func(err error, t *testing.T) { 356 | want := "'a1': invalid param: 'p1': hash mismatch" 357 | if err.Error() != want { 358 | t.Logf("invalid err. 
want: %s got: %s", want, err.Error()) 359 | t.Fail() 360 | } 361 | }, 362 | }, 363 | { 364 | "same len, different tool", 365 | func(t *testing.T) Hashes { 366 | return Hashes{ 367 | { 368 | Name: "a1", 369 | Hash: "ah1", 370 | Params: Hashes{ 371 | { 372 | Name: "p1", 373 | Hash: "ph1", 374 | }, 375 | }, 376 | }, 377 | } 378 | }, 379 | nil, 380 | func(*testing.T) args { 381 | return args{ 382 | Hashes{ 383 | { 384 | Name: "b1", 385 | Hash: "bh1", 386 | Params: Hashes{ 387 | { 388 | Name: "p1", 389 | Hash: "ph1", 390 | }, 391 | }, 392 | }, 393 | }, 394 | } 395 | }, 396 | true, 397 | func(err error, t *testing.T) { 398 | want := "'b1': missing" 399 | if err.Error() != want { 400 | t.Logf("invalid err. want: %s got: %s", want, err.Error()) 401 | t.Fail() 402 | } 403 | }, 404 | }, 405 | { 406 | "param name missing", 407 | func(t *testing.T) Hashes { 408 | return Hashes{ 409 | { 410 | Name: "a1", 411 | Hash: "ah1", 412 | Params: Hashes{ 413 | { 414 | Name: "p1", 415 | Hash: "ph1", 416 | }, 417 | }, 418 | }, 419 | } 420 | }, 421 | nil, 422 | func(*testing.T) args { 423 | return args{ 424 | Hashes{ 425 | { 426 | Name: "a1", 427 | Hash: "ah1", 428 | Params: Hashes{ 429 | { 430 | Name: "q1", 431 | Hash: "qh1", 432 | }, 433 | }, 434 | }, 435 | }, 436 | } 437 | }, 438 | true, 439 | func(err error, t *testing.T) { 440 | want := "'a1': invalid param: 'q1': missing" 441 | if err.Error() != want { 442 | t.Logf("invalid err. 
want: %s got: %s", want, err.Error()) 443 | t.Fail() 444 | } 445 | }, 446 | }, 447 | } 448 | 449 | for _, tt := range tests { 450 | t.Run(tt.name, func(t *testing.T) { 451 | tArgs := tt.args(t) 452 | 453 | receiver := tt.init(t) 454 | err := receiver.Matches(tArgs.o) 455 | 456 | if tt.inspect != nil { 457 | tt.inspect(receiver, t) 458 | } 459 | 460 | if (err != nil) != tt.wantErr { 461 | t.Fatalf("SBOM.Matches error = %v, wantErr: %t", err, tt.wantErr) 462 | } 463 | 464 | if tt.inspectErr != nil { 465 | tt.inspectErr(err, t) 466 | } 467 | }) 468 | } 469 | } 470 | --------------------------------------------------------------------------------