├── .covignore
├── examples
├── policer-http
│ ├── requirements.txt
│ └── policer.py
├── policer-rego
│ ├── simple.rego
│ ├── basic-auth.rego
│ └── police.rego
└── gen-test-tokens.sh
├── assets
└── imgs
│ ├── otel.png
│ └── mb-arch-overview.png
├── pkgs
├── policer
│ ├── api
│ │ ├── error.go
│ │ ├── response.go
│ │ └── request.go
│ ├── interface.go
│ └── internal
│ │ ├── rego
│ │ ├── utils.go
│ │ └── policer.go
│ │ └── http
│ │ └── policer.go
├── backend
│ ├── client
│ │ ├── cmd_others.go
│ │ ├── cmd_darwin.go
│ │ ├── cmd_linux.go
│ │ ├── interface.go
│ │ ├── server.go
│ │ ├── server_test.go
│ │ ├── options.go
│ │ ├── options_test.go
│ │ ├── stdio.go
│ │ ├── stdio_test.go
│ │ ├── sse.go
│ │ ├── stream_test.go
│ │ └── stream.go
│ ├── interface.go
│ ├── carrier.go
│ ├── helpers.go
│ ├── carrier_test.go
│ ├── options_test.go
│ ├── helpers_test.go
│ ├── options.go
│ └── ws_test.go
├── internal
│ ├── sanitize
│ │ ├── data.go
│ │ └── data_test.go
│ └── cors
│ │ ├── cors.go
│ │ └── cors_test.go
├── info
│ └── info.go
├── memconn
│ ├── addr.go
│ ├── listener.go
│ └── memconn.go
├── mcp
│ ├── notification.go
│ ├── error.go
│ ├── types.go
│ ├── message.go
│ └── id.go
├── frontend
│ ├── hash.go
│ ├── interface.go
│ ├── info.go
│ ├── internal
│ │ └── session
│ │ │ ├── manager.go
│ │ │ ├── manager_test.go
│ │ │ └── session.go
│ ├── connect.go
│ ├── options.go
│ └── stdio.go
├── auth
│ ├── auth_test.go
│ └── auth.go
├── oauth
│ ├── keyring.go
│ ├── oauth.go
│ ├── data.go
│ └── dance.go
├── scan
│ ├── sbom.go
│ ├── scan.go
│ └── sbom_test.go
└── metrics
│ └── manager.go
├── .gitignore
├── main.go
├── ROADMAP.md
├── .github
└── workflows
│ ├── cov.yaml
│ └── build.yaml
├── cli
├── internal
│ └── cmd
│ │ ├── completion.go
│ │ ├── root.go
│ │ ├── backend.go
│ │ ├── flags.go
│ │ ├── scan.go
│ │ ├── frontend.go
│ │ └── aio.go
├── main.go
└── init.go
├── Makefile
├── .goreleaser.yml
├── README.md
└── go.mod
/.covignore:
--------------------------------------------------------------------------------
1 | cli/*
2 |
--------------------------------------------------------------------------------
/examples/policer-http/requirements.txt:
--------------------------------------------------------------------------------
1 | PyJWT
2 |
--------------------------------------------------------------------------------
/assets/imgs/otel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acuvity/minibridge/HEAD/assets/imgs/otel.png
--------------------------------------------------------------------------------
/assets/imgs/mb-arch-overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acuvity/minibridge/HEAD/assets/imgs/mb-arch-overview.png
--------------------------------------------------------------------------------
/pkgs/policer/api/error.go:
--------------------------------------------------------------------------------
1 | package api
2 |
3 | import "errors"
4 |
5 | var ErrBlocked = errors.New("request blocked")
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | minibridge
2 | dist
3 | *.exe
4 | *.pem
5 | *.token
6 | *.sbom
7 | unit_coverage.out
8 | go.work
9 | go.work.sum
10 |
--------------------------------------------------------------------------------
/examples/policer-rego/simple.rego:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import rego.v1
4 |
5 | # This is the most basic policy we can have. It allows everything by default.
6 | default allow := true
7 |
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 |
6 | "go.acuvity.ai/minibridge/cli"
7 | )
8 |
9 | func main() {
10 |
11 | ctx, cancel := context.WithCancel(context.Background())
12 | defer cancel()
13 |
14 | cli.Main(ctx)
15 | }
16 |
--------------------------------------------------------------------------------
/pkgs/backend/client/cmd_others.go:
--------------------------------------------------------------------------------
1 | //go:build !darwin && !linux
2 |
3 | package client
4 |
5 | import (
6 | "os/exec"
7 | "syscall"
8 | )
9 |
10 | func setCaps(cmd *exec.Cmd, _ string, _ *creds) {
11 | cmd.SysProcAttr = &syscall.SysProcAttr{}
12 | }
13 |
--------------------------------------------------------------------------------
/pkgs/internal/sanitize/data.go:
--------------------------------------------------------------------------------
1 | package sanitize
2 |
3 | import "bytes"
4 |
5 | // Data sanitizes the data for internal
6 | // transport. It removes all trailing '\n' and '\r' characters.
7 | func Data(data []byte) []byte {
8 |
9 | return bytes.TrimRight(data, "\n\r")
10 | }
11 |
--------------------------------------------------------------------------------
/pkgs/info/info.go:
--------------------------------------------------------------------------------
1 | package info
2 |
3 | type Info struct {
4 | OAuthAuthorize bool `json:"oauthAuthorize"`
5 | OAuthRegister bool `json:"oauthRegister"`
6 | OAuthToken bool `json:"oauthToken"`
7 | OAuthMetadata bool `json:"oauthMetadata"`
8 | Type string `json:"type"`
9 | Server string `json:"server"`
10 | }
11 |
--------------------------------------------------------------------------------
/pkgs/backend/interface.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
3 | import (
4 | "context"
5 | )
6 |
7 | // A Backend is the interface of object that can
8 | // act as a minibridge Backend.
9 | type Backend interface {
10 |
11 | // Start starts the backend. It will run in the background
12 | // until the given context is done.
13 | Start(context.Context) error
14 | }
15 |
--------------------------------------------------------------------------------
/pkgs/memconn/addr.go:
--------------------------------------------------------------------------------
1 | package memconn
2 |
3 | import "net"
4 |
5 | // Addr is the address of a memlistener.
6 | type Addr struct {
7 | name string
8 | }
9 |
10 | var _ net.Addr = (*Addr)(nil)
11 |
12 | // Network implements net.Addr. Returns "memory."
13 | func (Addr) Network() string { return "memory" }
14 |
15 | // String implements net.Addr. Returns the name of the address.
16 | func (a Addr) String() string { return a.name }
17 |
--------------------------------------------------------------------------------
/pkgs/mcp/notification.go:
--------------------------------------------------------------------------------
1 | package mcp
2 |
3 | type Notification struct {
4 | JSONRPC string `json:"jsonrpc"`
5 | Method string `json:"method,omitempty"`
6 | Params map[string]any `json:"params,omitempty"`
7 | }
8 |
9 | // NewNotification returns a new notification.
10 | func NewNotification(name string) Notification {
11 | return Notification{
12 | JSONRPC: "2.0",
13 | Method: name,
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/pkgs/mcp/error.go:
--------------------------------------------------------------------------------
1 | package mcp
2 |
3 | // An Error represents an inline MCP error.
4 | type Error struct {
5 | Code int `json:"code"`
6 | Message string `json:"message,omitempty"`
7 | Data any `json:"data,omitempty"`
8 | }
9 |
10 | // NewError returns an *Error with code 500
11 | // and the message of the given error.
12 | func NewError(err error) *Error {
13 | return &Error{
14 | Code: 500,
15 | Message: err.Error(),
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/pkgs/frontend/hash.go:
--------------------------------------------------------------------------------
1 | package frontend
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/spaolacci/murmur3"
7 | "go.acuvity.ai/minibridge/pkgs/auth"
8 | )
9 |
10 | func hash(v string) uint64 {
11 | return murmur3.Sum64([]byte(v)) & 0x7FFFFFFFFFFFFFFF // #nosec G115
12 | }
13 |
14 | func hashCreds(auth *auth.Auth, authHeaders []string) uint64 {
15 | if auth != nil {
16 | return hash(fmt.Sprintf("%s-%v", auth.Encode(), authHeaders))
17 | }
18 | return hash(fmt.Sprintf("%v", authHeaders))
19 | }
20 |
--------------------------------------------------------------------------------
/pkgs/backend/client/cmd_darwin.go:
--------------------------------------------------------------------------------
1 | //go:build darwin
2 |
3 | package client
4 |
5 | import (
6 | "os/exec"
7 | "syscall"
8 | )
9 |
10 | func setCaps(cmd *exec.Cmd, chroot string, creds *creds) {
11 |
12 | var screds *syscall.Credential
13 | if creds != nil {
14 | screds = &syscall.Credential{
15 | Uid: creds.Uid,
16 | Gid: creds.Gid,
17 | Groups: creds.Groups,
18 | }
19 | }
20 |
21 | cmd.SysProcAttr = &syscall.SysProcAttr{
22 | Chroot: chroot,
23 | Credential: screds,
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/pkgs/backend/client/cmd_linux.go:
--------------------------------------------------------------------------------
1 | //go:build linux
2 |
3 | package client
4 |
5 | import (
6 | "os/exec"
7 | "syscall"
8 | )
9 |
10 | func setCaps(cmd *exec.Cmd, chroot string, creds *creds) {
11 |
12 | var screds *syscall.Credential
13 | if creds != nil {
14 | screds = &syscall.Credential{
15 | Uid: creds.Uid,
16 | Gid: creds.Gid,
17 | Groups: creds.Groups,
18 | }
19 | }
20 |
21 | cmd.SysProcAttr = &syscall.SysProcAttr{
22 | Pdeathsig: syscall.SIGKILL,
23 | Chroot: chroot,
24 | Credential: screds,
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/examples/gen-test-tokens.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # Prints sample JWTs for Alice, Bob and Eve; Eve's token is signed with a different secret.
3 | if ! command -v jwt >/dev/null 2>&1; then
4 | echo "you must install jwt tool from https://github.com/mike-engel/jwt-cli" >&2
5 | exit 1
6 | fi
7 |
8 | echo "token for Alice"
9 | jwt encode --secret secret --iss pki.example.com -P email=alice@example.com --aud minibridge
10 |
11 | echo
12 | echo "token for Bob"
13 | jwt encode --secret secret --iss pki.example.com -P email=bob@example.com --aud minibridge
14 |
15 | echo
16 | echo "Token for Eve"
17 | jwt encode --secret "not-secret" --iss pki.example.com -P email=eve@example.com --aud minibridge
18 |
--------------------------------------------------------------------------------
/examples/policer-rego/basic-auth.rego:
--------------------------------------------------------------------------------
1 | # This policy enforces a deny-all default and, if the BASIC_AUTH env var is set,
2 | # compares it to input.agent.password, passed by the client as Basic Authentication headers.
3 | # The request is denied with the reason "access denied" when they don't match.
4 | # To set BASIC_AUTH run export REGO_POLICY_RUNTIME_BASIC_AUTH="s3cr3t"
5 | package main
6 |
7 | import rego.v1
8 |
9 | secret := opa.runtime().env.BASIC_AUTH
10 |
11 | reasons contains "access denied" if {
12 | secret != ""
13 | input.agent.password != secret
14 | }
15 |
16 | default allow := false
17 |
18 | allow if {
19 | count(reasons) == 0
20 | }
21 |
--------------------------------------------------------------------------------
/pkgs/backend/client/interface.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "context"
5 | "net/http"
6 |
7 | "go.acuvity.ai/minibridge/pkgs/auth"
8 | )
9 |
10 | type cfg struct {
11 | auth *auth.Auth
12 | }
13 |
14 | type Option func(*cfg)
15 |
16 | func OptionAuth(a *auth.Auth) Option {
17 | return func(c *cfg) {
18 | c.auth = a
19 | }
20 | }
21 |
22 | // A Client is the interface of object that can
23 | // act as a minibridge mcp Client.
24 | type Client interface {
25 | Start(context.Context, ...Option) (*MCPStream, error)
26 | Type() string
27 | Server() string
28 | }
29 |
30 | type RemoteClient interface {
31 | HTTPClient() *http.Client
32 | BaseURL() string
33 | Client
34 | }
35 |
--------------------------------------------------------------------------------
/pkgs/backend/client/server.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "fmt"
5 | "os/exec"
6 | )
7 |
8 | // A MCPServer contains the information needed
9 | // to launch an MCP Server.
10 | type MCPServer struct {
11 | Command string
12 | Args []string
13 | Env []string
14 | }
15 |
16 | // NewMCPServer returns a new MCPServer. Returns an error if the given cmd path
17 | // cannot be resolved to an executable.
18 | func NewMCPServer(path string, args ...string) (MCPServer, error) {
19 | p, err := exec.LookPath(path)
20 | if err != nil {
21 | return MCPServer{}, fmt.Errorf("unable to find server binary: %w", err)
22 | }
23 | return MCPServer{
24 | Command: p,
25 | Args: args,
26 | }, nil
27 | }
28 |
--------------------------------------------------------------------------------
/ROADMAP.md:
--------------------------------------------------------------------------------
1 | # Roadmap
2 |
3 | ## In Progress
4 |
5 | - [ ] Unit tests
6 | - [ ] Wiki with tutorials on specific parts
7 | - [ ] Support for 2025-03-26 (pr #21)
8 |
9 | ## Todo
10 |
11 | - [ ] Advanced sandboxing when not running in containers (firejail/bubblewrap)
12 | - [ ] Support for shared MCP server (when using a Policer)
13 | - [ ] Add MTLS policer
14 | - [ ] Add A3S Policer
15 | - [ ] Add DScope Policer
16 |
17 | ## Done
18 |
19 | - [x] Transport user information over the websocket channel
20 | - [x] Support for user extraction to pass to the policer
21 | - [x] Plug in prometheus metrics
22 | - [x] Opentelemetry
23 | - [x] Optimize communications between front/back in aio mode (use memconn)
24 |
--------------------------------------------------------------------------------
/pkgs/backend/client/server_test.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "testing"
5 |
6 | . "github.com/smartystreets/goconvey/convey"
7 | )
8 |
9 | func TestMCPServer(t *testing.T) {
10 |
11 | Convey("calling NewMCPServer on existing bin should work", t, func() {
12 | srv, err := NewMCPServer("echo", "hello")
13 | So(err, ShouldBeNil)
14 | So(srv.Command, ShouldEndWith, "/bin/echo")
15 | So(srv.Args, ShouldResemble, []string{"hello"})
16 | })
17 |
18 | Convey("calling NewMCPServer on non exiting bin should work", t, func() {
19 | _, err := NewMCPServer("not-echo", "hello")
20 | So(err, ShouldNotBeNil)
21 | So(err.Error(), ShouldEqual, `unable to find server binary: exec: "not-echo": executable file not found in $PATH`)
22 | })
23 | }
24 |
--------------------------------------------------------------------------------
/pkgs/auth/auth_test.go:
--------------------------------------------------------------------------------
1 | package auth
2 |
3 | import (
4 | "testing"
5 |
6 | . "github.com/smartystreets/goconvey/convey"
7 | )
8 |
9 | func TestAuth(t *testing.T) {
10 |
11 | Convey("Basic auth should work", t, func() {
12 | auth := NewBasicAuth("user", "pass")
13 | So(auth.Type(), ShouldEqual, "Basic")
14 | So(auth.User(), ShouldEqual, "user")
15 | So(auth.Password(), ShouldEqual, "pass")
16 | So(auth.Encode(), ShouldEqual, "Basic dXNlcjpwYXNz")
17 | })
18 |
19 | Convey("Bearer auth should work", t, func() {
20 | auth := NewBearerAuth("token")
21 | So(auth.Type(), ShouldEqual, "Bearer")
22 | So(auth.User(), ShouldEqual, "Bearer")
23 | So(auth.Password(), ShouldEqual, "token")
24 | So(auth.Encode(), ShouldEqual, "Bearer token")
25 | })
26 | }
27 |
--------------------------------------------------------------------------------
/pkgs/frontend/interface.go:
--------------------------------------------------------------------------------
1 | package frontend
2 |
3 | import (
4 | "context"
5 | "net/http"
6 |
7 | "go.acuvity.ai/minibridge/pkgs/auth"
8 | "go.acuvity.ai/minibridge/pkgs/info"
9 | )
10 |
11 | // A Frontend is the interface of object that can
12 | // act as a minibridge Frontend.
13 | type Frontend interface {
14 |
15 | // Start starts the frontend. It will run in background until
16 | // the given context is done.
17 | Start(context.Context, *auth.Auth) error
18 |
19 | // BackendURL returns the backend URL the frontend connects to.
20 | BackendURL() string
21 |
22 | // HTTPClient returns a client that can be used to communicate
23 | // with the backend.
24 | HTTPClient() *http.Client
25 |
26 | // BackendInfo queries the backend for information.
27 | BackendInfo() (info.Info, error)
28 | }
29 |
--------------------------------------------------------------------------------
/.github/workflows/cov.yaml:
--------------------------------------------------------------------------------
1 | name: cov
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["build-go"]
6 | types:
7 | - completed
8 |
9 | env:
10 | GO111MODULE: on
11 | GOPRIVATE: github.com/acuvity,go.acuvity.ai
12 | GOPROXY: https://proxy.golang.org,direct
13 | GOTOKEN: ${{ secrets.GO_PRIVATE_REPO_PAT }}
14 |
15 | jobs:
16 | cov:
17 | runs-on: ubuntu-latest
18 | permissions: write-all
19 | steps:
20 | - name: setup
21 | run: |
22 | git config --global url."https://acuvity:${GOTOKEN}@github.com/acuvity".insteadOf "https://github.com/acuvity"
23 | - uses: acuvity/cov@1.0.2
24 | with:
25 | cov_mode: send-status
26 | workflow_run_id: ${{github.event.workflow_run.id}}
27 | workflow_head_sha: ${{github.event.workflow_run.head_sha}}
28 |
--------------------------------------------------------------------------------
/pkgs/policer/interface.go:
--------------------------------------------------------------------------------
1 | package policer
2 |
3 | import (
4 | "context"
5 | "crypto/tls"
6 |
7 | "go.acuvity.ai/minibridge/pkgs/auth"
8 | "go.acuvity.ai/minibridge/pkgs/mcp"
9 | "go.acuvity.ai/minibridge/pkgs/policer/api"
10 | "go.acuvity.ai/minibridge/pkgs/policer/internal/http"
11 | "go.acuvity.ai/minibridge/pkgs/policer/internal/rego"
12 | )
13 |
14 | // A Policer is the interface of objects that can police requests.
15 | type Policer interface {
16 | Police(context.Context, api.Request) (*mcp.Message, error)
17 | Type() string
18 | }
19 |
20 | // NewRego returns a new rego based Policer.
21 | func NewRego(policy string) (Policer, error) {
22 | return rego.New(policy)
23 | }
24 |
25 | // NewHTTP returns a new HTTP based Policer
26 | func NewHTTP(endpoint string, auth *auth.Auth, tlsConfig *tls.Config) Policer {
27 | return http.New(endpoint, auth, tlsConfig)
28 | }
29 |
--------------------------------------------------------------------------------
/pkgs/policer/api/response.go:
--------------------------------------------------------------------------------
1 | package api
2 |
3 | import "go.acuvity.ai/minibridge/pkgs/mcp"
4 |
5 | // GenericDenyReason is the generic reason returned if none is provided.
6 | const GenericDenyReason = "You are not allowed to perform this operation"
7 |
8 | // A Response is returned by the Policer.
9 | type Response struct {
10 |
11 | // If true, the request is allowed. Otherwise
12 | // it will be rejected for the reasons set in
13 | // Reasons.
14 | Allow bool `json:"allow"`
15 |
16 | // Reasons contains reasons for denying the request.
17 | // If no message is given, and Allow is false, a generic
18 | // forbidden message will be used.
19 | Reasons []string `json:"reasons,omitempty"`
20 |
21 | // If non-nil, replace the request MCP call with
22 | // this one. This allows Policers to modify the content
23 | // of an MCP call.
24 | MCP *mcp.Message `json:"mcp,omitempty"`
25 | }
26 |
--------------------------------------------------------------------------------
/cli/internal/cmd/completion.go:
--------------------------------------------------------------------------------
1 | package cmd
2 |
3 | import (
4 | "github.com/spf13/cobra"
5 | "os"
6 | )
7 |
8 | var Completion = &cobra.Command{
9 | Use: "completion [bash|zsh|fish|powershell]",
10 | Short: "Generate completion script",
11 | DisableFlagsInUseLine: true,
12 | ValidArgs: []string{"bash", "zsh", "fish", "powershell"},
13 | Args: cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs),
14 | RunE: func(cmd *cobra.Command, args []string) error {
15 | switch args[0] {
16 | case "bash":
17 | return cmd.Root().GenBashCompletion(os.Stdout)
18 | case "zsh":
19 | return cmd.Root().GenZshCompletion(os.Stdout)
20 | case "fish":
21 | return cmd.Root().GenFishCompletion(os.Stdout, true)
22 | case "powershell":
23 | return cmd.Root().GenPowerShellCompletionWithDesc(os.Stdout)
24 | }
25 |
26 | return nil
27 | },
28 | }
29 |
--------------------------------------------------------------------------------
/pkgs/internal/cors/cors.go:
--------------------------------------------------------------------------------
1 | package cors
2 |
3 | import (
4 | "net/http"
5 |
6 | "go.acuvity.ai/bahamut"
7 | )
8 |
9 | // HandleCORS handles CORS for browsers.
10 | func HandleCORS(w http.ResponseWriter, req *http.Request, corsPolicy *bahamut.CORSPolicy) bool {
11 |
12 | w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
13 | w.Header().Set("X-Frame-Options", "DENY")
14 | w.Header().Set("X-Content-Type-Options", "nosniff")
15 | w.Header().Set("X-Xss-Protection", "1; mode=block")
16 | w.Header().Set("Cache-Control", "private, no-transform")
17 |
18 | if corsPolicy == nil {
19 | return true
20 | }
21 |
22 | w.Header().Del("Origin")
23 |
24 | if req.Method == http.MethodOptions {
25 | corsPolicy.Inject(w.Header(), req.Header.Get("Origin"), true)
26 | w.WriteHeader(http.StatusNoContent)
27 | return false
28 | }
29 |
30 | corsPolicy.Inject(w.Header(), req.Header.Get("Origin"), false)
31 |
32 | return true
33 | }
34 |
--------------------------------------------------------------------------------
/pkgs/frontend/info.go:
--------------------------------------------------------------------------------
1 | package frontend
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "net/http"
7 |
8 | "go.acuvity.ai/elemental"
9 | "go.acuvity.ai/minibridge/pkgs/info"
10 | )
11 |
12 | func getBackendInfo(mfrontend Frontend) (info.Info, error) {
13 |
14 | inf := info.Info{}
15 | cl := mfrontend.HTTPClient()
16 |
17 | resp, err := cl.Get(fmt.Sprintf("%s/_info", mfrontend.BackendURL()))
18 | if err != nil {
19 | return inf, fmt.Errorf("unable to make backend info request: %w", err)
20 | }
21 | defer func() { _ = resp.Body.Close() }()
22 |
23 | if resp.StatusCode != http.StatusOK {
24 | return inf, fmt.Errorf("invalid backend info response status: %s", resp.Status)
25 | }
26 |
27 | data, err := io.ReadAll(resp.Body)
28 | if err != nil {
29 | return inf, fmt.Errorf("unable to read backend info response body: %w", err)
30 | }
31 |
32 | if err := elemental.Decode(elemental.EncodingTypeJSON, data, &inf); err != nil {
33 | return inf, fmt.Errorf("unable to decode backend info response body: %w", err)
34 | }
35 |
36 | return inf, nil
37 | }
38 |
--------------------------------------------------------------------------------
/pkgs/oauth/keyring.go:
--------------------------------------------------------------------------------
1 | package oauth
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/zalando/go-keyring"
7 | "go.acuvity.ai/elemental"
8 | )
9 |
10 | func TokenToKeyring(server string, t Credentials) error {
11 |
12 | data, err := elemental.Encode(elemental.EncodingTypeJSON, t)
13 | if err != nil {
14 | return fmt.Errorf("unable to encode token data: %w", err)
15 | }
16 |
17 | if err := keyring.Set("minibridge", server, string(data)); err != nil {
18 | return fmt.Errorf("unable to store token into keyring: %w", err)
19 | }
20 |
21 | return nil
22 | }
23 |
24 | func TokenFromKeyring(server string) (Credentials, error) {
25 |
26 | data, err := keyring.Get("minibridge", server)
27 | if err != nil {
28 | return Credentials{}, fmt.Errorf("unable to retrieve token from keyring: %w", err)
29 | }
30 |
31 | t := Credentials{}
32 | if err := elemental.Decode(elemental.EncodingTypeJSON, []byte(data), &t); err != nil {
33 | return Credentials{}, fmt.Errorf("unable to decode token data from keychain: %w", err)
34 | }
35 |
36 | return t, nil
37 | }
38 |
--------------------------------------------------------------------------------
/pkgs/backend/carrier.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
3 | import (
4 | "go.acuvity.ai/minibridge/pkgs/mcp"
5 | "go.opentelemetry.io/otel/propagation"
6 | )
7 |
8 | var _ propagation.TextMapCarrier = metaCarrier{}
9 |
10 | type metaCarrier struct {
11 | meta map[string]string
12 | }
13 |
14 | func newMCPMetaCarrier(call mcp.Message) metaCarrier {
15 |
16 | meta := map[string]string{}
17 |
18 | if call.Params != nil {
19 | if pmeta, ok := call.Params["_meta"].(map[string]any); ok {
20 | for k, v := range pmeta {
21 | if s, ok := v.(string); ok {
22 | meta[k] = s
23 | }
24 | }
25 | }
26 | }
27 |
28 | return metaCarrier{
29 | meta: meta,
30 | }
31 | }
32 |
33 | func (c metaCarrier) Get(key string) string {
34 |
35 | v, ok := c.meta[key]
36 | if !ok {
37 | return ""
38 | }
39 |
40 | return v
41 | }
42 |
43 | func (c metaCarrier) Set(key string, value string) {
44 | c.meta[key] = value
45 | }
46 |
47 | func (c metaCarrier) Keys() []string {
48 |
49 | out := make([]string, 0, len(c.meta))
50 | for k := range c.meta {
51 | out = append(out, k)
52 | }
53 |
54 | return out
55 | }
56 |
--------------------------------------------------------------------------------
/cli/main.go:
--------------------------------------------------------------------------------
1 | package cli
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log/slog"
7 | "os"
8 | "os/signal"
9 | "syscall"
10 |
11 | "github.com/spf13/cobra"
12 | "go.acuvity.ai/minibridge/cli/internal/cmd"
13 | )
14 |
15 | // Main is the main run for the cli.
16 | func Main(ctx context.Context) {
17 |
18 | cobra.OnInitialize(initCobra)
19 |
20 | ctx, cancel := context.WithCancel(ctx)
21 | defer cancel()
22 |
23 | installSIGINTHandler(cancel)
24 |
25 | if err := cmd.Root.ExecuteContext(ctx); err != nil {
26 | if _, ok := slog.Default().Handler().(*slog.JSONHandler); ok {
27 | slog.Error("Minibridge exited with error", "err", err)
28 | } else {
29 | fmt.Fprintf(os.Stderr, "error: %s\n", err.Error())
30 | }
31 | os.Exit(1)
32 | }
33 | }
34 |
35 | // installSIGINTHandler calls cancel on the first SIGINT, SIGTERM, SIGHUP or SIGABRT received.
36 | // SIGKILL is deliberately not listed: it can never be caught, so registering it is a no-op.
37 | func installSIGINTHandler(cancel context.CancelFunc) {
38 | sigs := []os.Signal{syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGABRT}
39 | signalCh := make(chan os.Signal, 1)
40 | signal.Reset(sigs...)
41 | signal.Notify(signalCh, sigs...)
42 | go func() {
43 | <-signalCh
44 | cancel()
45 | }()
46 | }
47 |
--------------------------------------------------------------------------------
/.github/workflows/build.yaml:
--------------------------------------------------------------------------------
1 | name: build-go
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 |
8 | defaults:
9 | run:
10 | shell: bash
11 |
12 | env:
13 | GOPRIVATE: "github.com/acuvity,go.acuvity.ai"
14 | GOPROXY: "https://proxy.golang.org,direct"
15 | GOTOKEN: ${{ secrets.GO_PRIVATE_REPO_PAT }}
16 |
17 |
18 | jobs:
19 | build:
20 | runs-on: ubuntu-latest
21 | strategy:
22 | fail-fast: false
23 | matrix:
24 | go:
25 | - "1.24"
26 | steps:
27 | - uses: actions/checkout@v4
28 |
29 | - uses: actions/setup-go@v5
30 | with:
31 | go-version: ${{ matrix.go }}
32 | cache: true
33 |
34 | - name: setup
35 | run: |
36 | go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
37 | go install golang.org/x/vuln/cmd/govulncheck@latest
38 | go install github.com/securego/gosec/v2/cmd/gosec@master
39 |
40 | - name: build
41 | run: |
42 | git config --global url."https://acuvity:${GOTOKEN}@github.com/acuvity".insteadOf "https://github.com/acuvity"
43 | make
44 |
45 | - uses: acuvity/cov@1.0.2
46 | with:
47 | main_branch: main
48 | cov_file: unit_coverage.out
49 | cov_threshold: "0"
50 | cov_mode: coverage
51 |
--------------------------------------------------------------------------------
/pkgs/oauth/oauth.go:
--------------------------------------------------------------------------------
1 | package oauth
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "log/slog"
7 | "net/http"
8 | "strings"
9 | )
10 |
11 | func Forward(baseURL string, client *http.Client, w http.ResponseWriter, req *http.Request, path string) func() {
12 |
13 | u := fmt.Sprintf("%s/%s", baseURL, strings.TrimLeft(path, "/"))
14 |
15 | slog.Debug("OAuth: Forwarding OAuth call", "method", req.Method, "target", u)
16 |
17 | breq, err := http.NewRequestWithContext(req.Context(), req.Method, u, req.Body)
18 | if err != nil {
19 | http.Error(w, fmt.Sprintf("unable to make request: %s", err), http.StatusInternalServerError)
20 | return func() {}
21 | }
22 |
23 | breq.URL.RawQuery = req.URL.Query().Encode()
24 | breq.Header = req.Header.Clone()
25 |
26 | resp, err := client.Do(breq) // nolint: bodyclose
27 | if err != nil {
28 | http.Error(w, err.Error(), http.StatusInternalServerError)
29 | return func() {}
30 | }
31 |
32 | for k, vs := range resp.Header {
33 |
34 | for _, v := range vs {
35 |
36 | if k == "Origin" || strings.HasPrefix(k, "Access-Control-") {
37 | continue
38 | }
39 |
40 | w.Header().Add(k, v)
41 | }
42 | }
43 |
44 | w.WriteHeader(resp.StatusCode)
45 |
46 | if resp.Body != nil {
47 | _, _ = io.Copy(w, resp.Body)
48 | return func() { _ = resp.Body.Close() }
49 | }
50 |
51 | return func() {}
52 | }
53 |
--------------------------------------------------------------------------------
/pkgs/mcp/types.go:
--------------------------------------------------------------------------------
1 | package mcp
2 |
3 | type Tools []Tool
4 | type Tool struct {
5 | Name string `json:"name"`
6 | Description string `json:"description,omitempty"`
7 | InputSchema map[string]any `json:"inputSchema,omitempty"`
8 | }
9 |
10 | type Prompts []*Prompt
11 | type Prompt struct {
12 | Name string `json:"name,omitempty"`
13 | Description string `json:"description,omitempty"`
14 | Arguments PromptArguments `json:"arguments,omitempty"`
15 | }
16 |
17 | type PromptArguments []PromptArgument
18 | type PromptArgument struct {
19 | Name string `json:"name,omitempty"`
20 | Description string `json:"description,omitempty"`
21 | Required bool `json:"required,omitempty"`
22 | }
23 |
24 | type Resources []Resource
25 | type Resource struct {
26 | Name string `json:"name,omitempty"`
27 | Description string `json:"description,omitempty"`
28 | MimeType string `json:"mimeType,omitempty"`
29 | Text string `json:"text,omitempty"`
30 | Blob string `json:"blob,omitempty"`
31 | URI string `json:"uri,omitempty"`
32 | }
33 |
34 | type ResourceTemplates []ResourceTemplate
35 | type ResourceTemplate struct {
36 | Name string `json:"name,omitempty"`
37 | Description string `json:"description,omitempty"`
38 | URITemplate string `json:"uriTemplate,omitempty"`
39 | }
40 |
--------------------------------------------------------------------------------
/pkgs/policer/internal/rego/utils.go:
--------------------------------------------------------------------------------
1 | package rego
2 |
3 | import (
4 | "fmt"
5 | "log/slog"
6 |
7 | "github.com/open-policy-agent/opa/v1/ast"
8 | "github.com/open-policy-agent/opa/v1/topdown/print"
9 | )
10 |
11 | type printer struct{}
12 |
13 | func (p printer) Print(ctx print.Context, s string) error {
14 | slog.Info(fmt.Sprintf("Rego Print: %s", s), "row", ctx.Location.Row)
15 | return nil
16 | }
17 |
18 | func precompile(policy string, name string, modules ...*ast.Module) (*ast.Compiler, error) {
19 |
20 | name = name + ".rego"
21 |
22 | compiler := ast.NewCompiler().WithEnablePrintStatements(true)
23 | module, err := prepareModule("main", policy)
24 | if err != nil {
25 | return nil, err
26 | }
27 |
28 | allModules := map[string]*ast.Module{
29 | name: module,
30 | }
31 | for _, m := range modules {
32 | allModules[m.Package.String()+".rego"] = m
33 | }
34 |
35 | compiler.Compile(allModules)
36 |
37 | if compiler.Failed() {
38 | return nil, fmt.Errorf("unable compile rego module: %w", compiler.Errors)
39 | }
40 |
41 | return compiler, nil
42 | }
43 |
44 | func prepareModule(name string, policy string) (*ast.Module, error) {
45 |
46 | caps := ast.CapabilitiesForThisVersion()
47 |
48 | module, err := ast.ParseModuleWithOpts(
49 | name,
50 | policy,
51 | ast.ParserOptions{
52 | Capabilities: caps,
53 | },
54 | )
55 | if err != nil {
56 | return nil, fmt.Errorf("unable to parse rego module: %w", err)
57 | }
58 |
59 | return module, nil
60 | }
61 |
--------------------------------------------------------------------------------
/cli/internal/cmd/root.go:
--------------------------------------------------------------------------------
1 | package cmd
2 |
3 | import (
4 | "fmt"
5 | "os"
6 |
7 | "github.com/spf13/cobra"
8 | "github.com/spf13/viper"
9 | "go.acuvity.ai/a3s/pkgs/bootstrap"
10 | "go.acuvity.ai/a3s/pkgs/conf"
11 | "go.acuvity.ai/a3s/pkgs/version"
12 | )
13 |
14 | func init() {
15 |
16 | initSharedFlagSet()
17 |
18 | Root.PersistentFlags().String("log-level", "info", "sets the log level.")
19 | Root.PersistentFlags().String("log-format", "console", "sets the log format.")
20 | Root.PersistentFlags().Bool("version", false, "print version and exit.")
21 |
22 | Root.AddCommand(
23 | Backend,
24 | Frontend,
25 | AIO,
26 | Completion,
27 | Scan,
28 | )
29 | }
30 |
31 | var Root = &cobra.Command{
32 | Use: "minibridge",
33 | Short: "Secure your MCP Servers",
34 | SilenceUsage: true,
35 | SilenceErrors: true,
36 | TraverseChildren: true,
37 | PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
38 |
39 | if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil {
40 | return err
41 | }
42 |
43 | if err := viper.BindPFlags(cmd.Flags()); err != nil {
44 | return err
45 | }
46 |
47 | bootstrap.ConfigureLogger("minibridge", conf.LoggingConf{
48 | LogLevel: viper.GetString("log-level"),
49 | LogFormat: viper.GetString("log-format"),
50 | })
51 |
52 | return nil
53 | },
54 | RunE: func(cmd *cobra.Command, args []string) error {
55 | if viper.GetBool("version") {
56 | fmt.Println(version.Short())
57 | os.Exit(0)
58 | }
59 | return cmd.Usage()
60 | },
61 | }
62 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | MAKEFLAGS += --warn-undefined-variables
2 | SHELL := /bin/bash -o pipefail
3 | GIT_SHA=$(shell git rev-parse --short HEAD)
4 | GIT_BRANCH=$(shell git rev-parse --abbrev-ref HEAD)
5 | GIT_TAG=$(shell git describe --tags --abbrev=0 --match='v[0-9]*.[0-9]*.[0-9]*' 2> /dev/null | sed 's/^.//')
6 | BUILD_DATE=$(shell date)
7 | VERSION_PKG="go.acuvity.ai/a3s/pkgs/version"
8 | LDFLAGS = -ldflags="-w -s -X '$(VERSION_PKG).GitSha=$(GIT_SHA)' -X '$(VERSION_PKG).GitBranch=$(GIT_BRANCH)' -X '$(VERSION_PKG).GitTag=$(GIT_TAG)' -X '$(VERSION_PKG).BuildDate=$(BUILD_DATE)'"
9 |
10 | export GO111MODULE = on
11 |
12 | default: lint test build vuln sec
13 |
14 | lint:
15 | golangci-lint run \
16 | --timeout=5m \
17 | --disable=govet \
18 | --enable=errcheck \
19 | --enable=ineffassign \
20 | --enable=unused \
21 | --enable=unconvert \
22 | --enable=misspell \
23 | --enable=prealloc \
24 | --enable=nakedret \
25 | --enable=unparam \
26 | --enable=nilerr \
27 | --enable=bodyclose \
28 | --enable=errorlint \
29 | ./...
30 | test:
31 | go test ./... -vet off -race -cover -covermode=atomic -coverprofile=unit_coverage.out
32 |
33 | sec:
34 | gosec -quiet ./...
35 |
36 | vuln:
37 | govulncheck ./...
38 |
39 | build:
40 | go build $(LDFLAGS) -trimpath .
41 |
42 | remod:
43 | go get go.acuvity.ai/tg@master
44 | go get go.acuvity.ai/wsc@master
45 | go get go.acuvity.ai/regolithe@master
46 | go get go.acuvity.ai/bahamut@master
47 | go get go.acuvity.ai/elemental@master
48 | go get go.acuvity.ai/manipulate@master
49 | go get go.acuvity.ai/a3s@master
50 | go mod tidy
51 |
--------------------------------------------------------------------------------
/pkgs/frontend/internal/session/manager.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import "sync"
4 |
5 | // A Manager manages sessions and keep
6 | // track of them.
7 | type Manager struct {
8 | sessions map[string]*Session
9 | sync.RWMutex
10 | }
11 |
12 | // NewManager returns a new *session.Manager.
13 | func NewManager() *Manager {
14 | return &Manager{
15 | sessions: map[string]*Session{},
16 | }
17 | }
18 |
19 | // Acquire acquires and returns the session with the given sid.
20 | // It returns nil if not session with that sid is found.
21 | // In addition, if ch is not nil, it will be registered as a read hook.
22 | func (p *Manager) Acquire(sid string, ch chan []byte) *Session {
23 |
24 | p.RLock()
25 | defer p.RUnlock()
26 |
27 | s := p.sessions[sid]
28 | if s != nil {
29 | s.acquire()
30 | if ch != nil {
31 | s.register(ch)
32 | }
33 | }
34 |
35 | return s
36 | }
37 |
38 | // Release sessions releases an acquired session.
39 | // if the session is not acquired by anything, the ws connection
40 | // will be closed, and the session deleted.
41 | // In addition, if ch is not nil, it will be unregstered as a read hook.
42 | func (p *Manager) Release(sid string, ch chan []byte) {
43 |
44 | p.Lock()
45 | defer p.Unlock()
46 |
47 | s := p.sessions[sid]
48 | if s == nil {
49 | return
50 | }
51 |
52 | if ch != nil {
53 | s.unregister(ch)
54 | }
55 |
56 | if closed := s.release(); closed {
57 | delete(p.sessions, sid)
58 | }
59 | }
60 |
61 | func (p *Manager) Register(s *Session) {
62 | p.Lock()
63 | defer p.Unlock()
64 |
65 | p.sessions[s.ID()] = s
66 | }
67 |
--------------------------------------------------------------------------------
/pkgs/oauth/data.go:
--------------------------------------------------------------------------------
1 | package oauth
2 |
3 | var successBody = `
4 |
5 |
6 |
7 |
8 |
9 | Authentication Successful
10 |
37 |
38 |
39 |
40 |
44 |
Authentication Successful
45 |
You have successfully authenticated minibridge with %s.
46 |
You can now close this window.
47 |
48 |
49 |
50 | `
51 |
--------------------------------------------------------------------------------
/pkgs/backend/helpers.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
3 | import (
4 | "encoding/base64"
5 | "log/slog"
6 | "net/http"
7 | "strings"
8 |
9 | "go.acuvity.ai/elemental"
10 | "go.acuvity.ai/minibridge/pkgs/auth"
11 | "go.acuvity.ai/minibridge/pkgs/mcp"
12 | "go.opentelemetry.io/otel/codes"
13 | "go.opentelemetry.io/otel/trace"
14 | )
15 |
16 | func makeMCPError(ID any, err error) []byte {
17 |
18 | mpcerr := mcp.Message{
19 | JSONRPC: "2.0",
20 | ID: ID,
21 | Error: &mcp.Error{
22 | Code: 451,
23 | Message: err.Error(),
24 | },
25 | }
26 |
27 | data, err := elemental.Encode(elemental.EncodingTypeJSON, mpcerr)
28 | if err != nil {
29 | panic(err)
30 | }
31 |
32 | slog.Debug("Injecting MCP error", "err", string(data))
33 |
34 | return data
35 | }
36 |
37 | func parseBasicAuth(authString string) (a *auth.Auth, ok bool) {
38 |
39 | const prefix = "Basic "
40 |
41 | if len(authString) < len(prefix) || !strings.EqualFold(authString[:len(prefix)], prefix) {
42 |
43 | parts := strings.SplitN(authString, " ", 2)
44 | if len(parts) == 2 {
45 | a = auth.NewBearerAuth(parts[1])
46 | return a, true
47 | }
48 | return nil, false
49 | }
50 |
51 | c, err := base64.StdEncoding.DecodeString(authString[len(prefix):])
52 | if err != nil {
53 | return nil, false
54 | }
55 |
56 | cs := string(c)
57 |
58 | user, password, ok := strings.Cut(cs, ":")
59 | if !ok {
60 | return nil, false
61 | }
62 |
63 | return auth.NewBasicAuth(user, password), true
64 | }
65 |
66 | func hErr(w http.ResponseWriter, message string, code int, span trace.Span) {
67 | http.Error(w, message, code)
68 | span.SetStatus(codes.Error, message)
69 | }
70 |
--------------------------------------------------------------------------------
/pkgs/backend/carrier_test.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
3 | import (
4 | "testing"
5 |
6 | . "github.com/smartystreets/goconvey/convey"
7 | "go.acuvity.ai/minibridge/pkgs/mcp"
8 | )
9 |
10 | func TestCarrier(t *testing.T) {
11 |
12 | Convey("MCPCarrier from call with no params", t, func() {
13 | msg := newMCPMetaCarrier(mcp.NewMessage(1))
14 | So(len(msg.meta), ShouldEqual, 0)
15 | So(msg.Get("a"), ShouldBeEmpty)
16 | msg.Set("a", "1")
17 | So(msg.Get("a"), ShouldEqual, "1")
18 | So(msg.Keys(), ShouldResemble, []string{"a"})
19 | })
20 |
21 | Convey("MCPCarrier from call with params but no _meta", t, func() {
22 | msg := mcp.NewMessage(1)
23 | msg.Params = map[string]any{}
24 | c := newMCPMetaCarrier(msg)
25 | So(len(c.meta), ShouldEqual, 0)
26 | })
27 |
28 | Convey("MCPCarrier from call with params with _meta with wrong type", t, func() {
29 | msg := mcp.NewMessage(1)
30 | msg.Params = map[string]any{"_meta": "oh no"}
31 | c := newMCPMetaCarrier(msg)
32 | So(len(c.meta), ShouldEqual, 0)
33 | })
34 |
35 | Convey("MCPCarrier from call with params with wrong _meta value type", t, func() {
36 | msg := mcp.NewMessage(1)
37 | msg.Params = map[string]any{"_meta": map[string]any{"a": 42}}
38 | c := newMCPMetaCarrier(msg)
39 | So(len(c.meta), ShouldEqual, 0)
40 | })
41 |
42 | Convey("MCPCarrier from call with valid _meta", t, func() {
43 | msg := mcp.NewMessage(1)
44 | msg.Params = map[string]any{"_meta": map[string]any{"a": "42"}}
45 | c := newMCPMetaCarrier(msg)
46 | So(len(c.meta), ShouldEqual, 1)
47 | So(c.Get("a"), ShouldEqual, "42")
48 | So(c.Keys(), ShouldResemble, []string{"a"})
49 | So(msg.Params["_meta"], ShouldNotBeNil)
50 | })
51 |
52 | }
53 |
--------------------------------------------------------------------------------
/.goreleaser.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | snapshot:
3 | version_template: "v{{ .Tag }}-next"
4 | changelog:
5 | sort: asc
6 | filters:
7 | exclude:
8 | - "^docs:"
9 | - "^test:"
10 | - "^examples:"
11 | builds:
12 | - id: minibridge
13 | binary: minibridge
14 | goos:
15 | - linux
16 | - darwin
17 | - windows
18 | goarch:
19 | - amd64
20 | - arm64
21 | env:
22 | - CGO_ENABLED=0
23 | ldflags:
24 | - -w -s -X 'go.acuvity.ai/a3s/pkgs/version.GitSha={{.Commit}}' -X 'go.acuvity.ai/a3s/pkgs/version.GitBranch=main' -X 'go.acuvity.ai/a3s/pkgs/version.GitTag={{.Version}}' -X 'go.acuvity.ai/a3s/pkgs/version.BuildDate={{.Date}}'
25 |
26 | archives:
27 | - id: minibridge
28 | formats: ["zip"]
29 | builds:
30 | - minibridge
31 |
32 | signs:
33 | - artifacts: checksum
34 | args:
35 | [
36 | "-u",
37 | "0C3214A61024881F5CA1F5F056EDB08A11DCE325",
38 | "--output",
39 | "${signature}",
40 | "--detach-sign",
41 | "${artifact}",
42 | ]
43 |
44 | brews:
45 | - name: minibridge
46 | homepage: "https://github.com/acuvity/minibridge"
47 | description: Minibridge securely connects Agents to MCP servers, exposing them to the internet while enabling optional integration with remote or local Policers for authentication, analysis, and transformation.
48 | license: "Apache"
49 | repository:
50 | owner: acuvity
51 | name: homebrew-tap
52 | commit_author:
53 | name: goreleaserbot
54 | email: goreleaser@acuvity.ai
55 | directory: Formula
56 | install: |
57 | bin.install "minibridge"
58 | test: |
59 | system "#{bin}/minibridge", "--version"
60 |
--------------------------------------------------------------------------------
/pkgs/backend/client/options.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | )
7 |
// creds holds the numeric user, group and supplementary
// groups to run a command as.
type creds struct {
	Uid    uint32
	Gid    uint32
	Groups []uint32
}

// stdioCfg holds the configuration built from the StdioOption
// passed to the client.
type stdioCfg struct {
	useTempDir bool
	creds      *creds
}

// newStdioCfg returns a stdioCfg with all fields at their zero value.
func newStdioCfg() stdioCfg {
	var cfg stdioCfg
	return cfg
}

// An StdioOption can be passed to the Client.
type StdioOption func(*stdioCfg)

// OptStdioUseTempDir defines if the client should
// run the command into its own working dir. If false,
// the command will run in minibridge current cwd.
func OptStdioUseTempDir(use bool) StdioOption {
	return func(cfg *stdioCfg) { cfg.useTempDir = use }
}

// OptStdioCredentials sets the uid, gid and supplementary groups
// to run the command as.
//
// Negative values are treated as unset: negative groups are dropped, and
// a negative uid or gid leaves the matching field at 0. If nothing at all
// is set, the configuration is left untouched. The returned option
// panics if any value does not fit in a uint32.
func OptStdioCredentials(uid int, gid int, groups []int) StdioOption {
	return func(cfg *stdioCfg) {

		extra := make([]uint32, 0, len(groups))

		for idx, g := range groups {
			switch {
			case g < 0:
				// negative means unset: skip it.
			case g > math.MaxUint32:
				panic(fmt.Sprintf("invalid group %d. overflows", idx))
			default:
				extra = append(extra, uint32(g)) // #nosec: G115
			}
		}

		if uid < 0 && gid < 0 && len(extra) == 0 {
			return
		}

		// The credentials are installed before validating uid and gid.
		ident := &creds{Groups: extra}
		cfg.creds = ident

		if uid >= 0 {
			if uid > math.MaxUint32 {
				panic("invalid uid. overflows")
			}
			ident.Uid = uint32(uid) // #nosec: G115
		}

		if gid >= 0 {
			if gid > math.MaxUint32 {
				panic("invalid gid. overflows")
			}
			ident.Gid = uint32(gid) // #nosec: G115
		}
	}
}
79 |
--------------------------------------------------------------------------------
/pkgs/mcp/message.go:
--------------------------------------------------------------------------------
1 | package mcp
2 |
// ProtocolVersion is the string identifier of an MCP protocol version.
type ProtocolVersion string

// Protocol versions declared by this package.
var (
	ProtocolVersion20250326 = ProtocolVersion("2025-03-26")
	ProtocolVersion20241105 = ProtocolVersion("2024-11-05")
)
9 |
10 | // Message represents the inline MPC request.
11 | type Message struct {
12 | JSONRPC string `json:"jsonrpc"`
13 | ID any `json:"id,omitempty,omitzero"`
14 | Method string `json:"method,omitempty"`
15 | Params map[string]any `json:"params,omitempty"`
16 | Result map[string]any `json:"result,omitempty"`
17 | Error *Error `json:"error,omitempty"`
18 | }
19 |
20 | // NewMessage returns a MCPCall initialized with the given id.
21 | // To initialize a call without ID set, use an empty string.
22 | func NewMessage[T int | string](id T) Message {
23 | c := Message{
24 | JSONRPC: "2.0",
25 | }
26 |
27 | var zero T
28 | if id != zero {
29 | c.ID = id
30 | }
31 |
32 | return c
33 | }
34 |
35 | // IDString returns the call ID as a string
36 | // whatever is the original type.
37 | func (c *Message) IDString() string {
38 |
39 | if c.ID == nil {
40 | return ""
41 | }
42 |
43 | return normalizeID(c.ID)
44 | }
45 |
46 | // NewInitMessage makes a new init call using the given protocol version.
47 | func NewInitMessage(proto ProtocolVersion) Message {
48 | return Message{
49 | JSONRPC: "2.0",
50 | ID: 0,
51 | Method: "initialize",
52 | Params: map[string]any{
53 | "protocolVersion": proto,
54 | "capabilities": map[string]any{
55 | "sampling": map[string]any{},
56 | "roots": map[string]any{
57 | "listChanged": true,
58 | },
59 | },
60 | "clientInfo": map[string]any{
61 | "name": "minibridge",
62 | "version": "1.0",
63 | },
64 | },
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/pkgs/internal/sanitize/data_test.go:
--------------------------------------------------------------------------------
1 | package sanitize
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 | )
7 |
8 | func TestSanitize(t *testing.T) {
9 | type args struct {
10 | data []byte
11 | }
12 | tests := []struct {
13 | name string
14 | args func(t *testing.T) args
15 |
16 | want1 []byte
17 | }{
18 |
19 | {
20 | "empty",
21 | func(*testing.T) args {
22 | return args{
23 | data: []byte(""),
24 | }
25 | },
26 | []byte(""),
27 | },
28 | {
29 | "nil",
30 | func(*testing.T) args {
31 | return args{
32 | data: nil,
33 | }
34 | },
35 | nil,
36 | },
37 | {
38 | "no suffix",
39 | func(*testing.T) args {
40 | return args{
41 | data: []byte("hello world"),
42 | }
43 | },
44 | []byte("hello world"),
45 | },
46 | {
47 | "with \n",
48 | func(*testing.T) args {
49 | return args{
50 | data: []byte("hello world\n\n\n\n\n"),
51 | }
52 | },
53 | []byte("hello world"),
54 | },
55 | {
56 | "with \r",
57 | func(*testing.T) args {
58 | return args{
59 | data: []byte("hello world\r\r\r\r"),
60 | }
61 | },
62 | []byte("hello world"),
63 | },
64 | {
65 | "with \n\r",
66 | func(*testing.T) args {
67 | return args{
68 | data: []byte("hello world\r\n"),
69 | }
70 | },
71 | []byte("hello world"),
72 | },
73 | {
74 | "with \n\r at start",
75 | func(*testing.T) args {
76 | return args{
77 | data: []byte("\r\nhello world\r\n"),
78 | }
79 | },
80 | []byte("\r\nhello world"),
81 | },
82 | }
83 |
84 | for _, tt := range tests {
85 | t.Run(tt.name, func(t *testing.T) {
86 | tArgs := tt.args(t)
87 |
88 | got1 := Data(tArgs.data)
89 |
90 | if !reflect.DeepEqual(got1, tt.want1) {
91 | t.Errorf("Sanitize got1 = %v, want1: %v", got1, tt.want1)
92 | }
93 | })
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
/examples/policer-rego/police.rego:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import rego.v1
4 |
5 | # the call is allowed if we don't have any reason to deny it.
6 | allow if {
7 | count(reasons) == 0
8 | }
9 |
10 | # verifies the claims from the JWT of the agent.
11 | claims := x if {
12 | [verified, _, x] := io.jwt.decode_verify(
13 | input.agent.password,
14 | {
15 | "secret": "secret",
16 | "iss": "pki.example.com",
17 | "aud": "minibridge",
18 | },
19 | )
20 | verified == true
21 | }
22 |
23 | # if we have no claims, we add a deny reason.
24 | reasons contains msg if {
25 | not claims
26 | msg := "You must provide a valid token"
27 | }
28 |
29 | # if the call is a tools/call and the name
30 | # is printEnv and the user is not Alice,
31 | # we add a deny reason.
32 | reasons contains msg if {
33 | input.mcp.method == "tools/call"
34 | input.mcp.params.name == "printEnv"
35 | claims.email != "alice@example.com"
36 | msg := "only alice can run printEnv"
37 | }
38 |
39 | # if the call is a tools/call and the name
40 | # is longRunningOperation and the user is Bob,
41 | # we add a deny reason.
42 | reasons contains msg if {
43 | input.mcp.method == "tools/call"
44 | claims.email == "bob@example.com"
45 | input.mcp.params.name == "longRunningOperation"
46 | msg := "bob cannot run longRunningOperation"
47 | }
48 |
49 | # if the message carries a list of tools (result.tools) and the request is
50 | # from Bob, we remove longRunningOperation from the response. If Bob still
51 | # tries to call that tool, it will be denied by the rule above. This allows
52 | # the agent to not lose time trying a tool that will be denied anyway.
53 | mcp := x if {
54 | input.mcp.result.tools
55 | claims.email == "bob@example.com"
56 |
57 | x := json.patch(input.mcp, [{
58 | "op": "replace",
59 | "path": "/result/tools",
60 | "value": [x | x := input.mcp.result.tools[_]; x.name != "longRunningOperation"],
61 | }])
62 | }
63 |
--------------------------------------------------------------------------------
/pkgs/backend/client/options_test.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "math"
5 | "testing"
6 |
7 | . "github.com/smartystreets/goconvey/convey"
8 | )
9 |
10 | func TestOptions(t *testing.T) {
11 |
12 | Convey("OptUseTempDir should work", t, func() {
13 | cfg := newStdioCfg()
14 | OptStdioUseTempDir(true)(&cfg)
15 | So(cfg.useTempDir, ShouldBeTrue)
16 | })
17 |
18 | Convey("OptCredentials should work", t, func() {
19 | cfg := newStdioCfg()
20 | OptStdioCredentials(1000, 1001, []int{2001, 2002})(&cfg)
21 | So(cfg.creds.Uid, ShouldEqual, 1000)
22 | So(cfg.creds.Gid, ShouldEqual, 1001)
23 | So(cfg.creds.Groups, ShouldResemble, []uint32{2001, 2002})
24 | })
25 |
26 | Convey("OptCredentials with -1 should work", t, func() {
27 | cfg := newStdioCfg()
28 | OptStdioCredentials(-1, -1, []int{-1})(&cfg)
29 | So(cfg.creds, ShouldBeNil)
30 | })
31 |
32 | Convey("OptCredentials with only gid -1 should work", t, func() {
33 | cfg := newStdioCfg()
34 | OptStdioCredentials(100, -1, []int{-1})(&cfg)
35 | So(cfg.creds.Uid, ShouldEqual, 100)
36 | So(cfg.creds.Gid, ShouldEqual, 0)
37 | So(cfg.creds.Groups, ShouldResemble, []uint32{})
38 | })
39 |
40 | Convey("OptCredentials with uint32 overflow on uid should fail", t, func() {
41 | cfg := newStdioCfg()
42 | So(func() { OptStdioCredentials(math.MaxInt64, 1001, []int{2001, 2002})(&cfg) }, ShouldPanicWith, "invalid uid. overflows")
43 | })
44 |
45 | Convey("OptCredentials with uint32 overflow on uid should fail", t, func() {
46 | cfg := newStdioCfg()
47 | So(func() { OptStdioCredentials(1, math.MaxInt64, []int{2001, 2002})(&cfg) }, ShouldPanicWith, "invalid gid. overflows")
48 | })
49 |
50 | Convey("OptCredentials with uint32 overflow on groups should fail", t, func() {
51 | cfg := newStdioCfg()
52 | So(func() { OptStdioCredentials(1, 1, []int{2001, math.MaxInt64})(&cfg) }, ShouldPanicWith, "invalid group 1. overflows")
53 | })
54 | }
55 |
--------------------------------------------------------------------------------
/cli/init.go:
--------------------------------------------------------------------------------
1 | package cli
2 |
3 | import (
4 | "errors"
5 | "log/slog"
6 | "os"
7 | "strings"
8 |
9 | "github.com/adrg/xdg"
10 | "github.com/spf13/viper"
11 | )
12 |
13 | var (
14 | cfgFile string
15 | cfgName string
16 | )
17 |
18 | func initCobra() {
19 |
20 | viper.SetEnvPrefix("minibridge")
21 | viper.AutomaticEnv()
22 | viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
23 |
24 | dataFolder, err := xdg.DataFile("minibridge")
25 | if err != nil {
26 | slog.Error("failed to retrieve xdg data folder: %w", err)
27 | os.Exit(1)
28 | }
29 |
30 | configFolder, err := xdg.ConfigFile("minibridge")
31 | if err != nil {
32 | slog.Error("failed to retrieve xdg data folder: %w", err)
33 | os.Exit(1)
34 | }
35 |
36 | slog.Debug("Folders configured", "config", configFolder, "data", dataFolder)
37 |
38 | if cfgFile == "" {
39 | cfgFile = os.Getenv("MINIBRIDGE_CONFIG")
40 | }
41 |
42 | if cfgFile != "" {
43 | if _, err := os.Stat(cfgFile); os.IsNotExist(err) {
44 | slog.Error("Config file does not exist", err)
45 | os.Exit(1)
46 | }
47 |
48 | viper.SetConfigType("yaml")
49 | viper.SetConfigFile(cfgFile)
50 |
51 | if err = viper.ReadInConfig(); err != nil {
52 | slog.Error("Unable to read config",
53 | "path", cfgFile,
54 | err,
55 | )
56 | os.Exit(1)
57 | }
58 |
59 | slog.Debug("Using config file", "path", cfgFile)
60 | return
61 | }
62 |
63 | viper.AddConfigPath(configFolder)
64 | viper.AddConfigPath("/usr/local/etc/minibridge")
65 | viper.AddConfigPath("/etc/minibridge")
66 |
67 | if cfgName == "" {
68 | cfgName = os.Getenv("MINIBRIDGE_CONFIG_NAME")
69 | }
70 |
71 | if cfgName == "" {
72 | cfgName = "default"
73 | }
74 |
75 | viper.SetConfigName(cfgName)
76 |
77 | if err = viper.ReadInConfig(); err != nil {
78 | if !errors.As(err, &viper.ConfigFileNotFoundError{}) {
79 | slog.Error("Unable to read config", err)
80 | os.Exit(1)
81 | }
82 | }
83 |
84 | slog.Debug("Using config name", "name", cfgName)
85 | }
86 |
--------------------------------------------------------------------------------
/pkgs/policer/api/request.go:
--------------------------------------------------------------------------------
1 | package api
2 |
3 | import (
4 | "time"
5 |
6 | "go.acuvity.ai/minibridge/pkgs/mcp"
7 | )
8 |
// CallType is the type of call sent to the policer.
type CallType string

// Various values of CallType.
var (
	CallTypeRequest  = CallType("request")
	CallTypeResponse = CallType("response")
)
17 |
// SpanContext contains information about the OTEL span
// related to a Request: the trace and span identifiers,
// the optional parent span, the span name and its time bounds.
type SpanContext struct {
	TraceID      string    `json:"traceID" `
	ParentSpanID string    `json:"parentSpanID,omitempty"`
	End          time.Time `json:"end"`
	ID           string    `json:"ID"`
	Name         string    `json:"name"`
	Start        time.Time `json:"start"`
}

// IsValid returns true when both the TraceID and the ID are set.
func (c SpanContext) IsValid() bool {

	if c.TraceID == "" || c.ID == "" {
		return false
	}

	return true
}
32 |
33 | // A Request represents the data sent to the Policer
34 | type Request struct {
35 |
36 | // Type of the request. Request will be set for request from the agent
37 | // and Response will be set for replies from the MPC server.
38 | Type CallType `json:"type"`
39 |
40 | // MPC embeds the full MPC call, either request or response,
41 | // based on the Type.
42 | MCP mcp.Message `json:"mcp,omitzero"`
43 |
44 | // Agent contains callers information.
45 | Agent Agent `json:"agent,omitzero"`
46 |
47 | // SpanContext contains info about the eventual OTEL span
48 | // for the request. There are advanced use cases where you
49 | // want correlation between a Request and the OTEL traces
50 | // associated.
51 | SpanContext SpanContext `json:"spanContext,omitzero"`
52 | }
53 |
// Agent describes the caller of a request.
type Agent struct {

	// User is the user taken from the Auth header.
	User string `json:"user"`

	// Password is the password taken from the Auth header.
	Password string `json:"password"`

	// RemoteAddr is the remote address of the agent, as seen by minibridge.
	RemoteAddr string `json:"remoteAddr"`

	// UserAgent is the user agent advertised by the agent.
	// It is omitted from JSON when empty.
	UserAgent string `json:"userAgent,omitempty"`
}
69 |
--------------------------------------------------------------------------------
/pkgs/scan/sbom.go:
--------------------------------------------------------------------------------
1 | package scan
2 |
3 | import (
4 | "fmt"
5 | "os"
6 |
7 | "go.acuvity.ai/elemental"
8 | )
9 |
// Exclusions lists the kinds of resources to exclude from a scan.
// A field set to true excludes the matching kind.
type Exclusions struct {
	Prompts   bool
	Resources bool
	Tools     bool
}
17 |
18 | // SBOM contains a list of hashes for hashable
19 | // resources.
20 | type SBOM struct {
21 | Tools Hashes `json:"tools,omitzero"`
22 | Prompts Hashes `json:"prompts,omitzero"`
23 | }
24 |
25 | func LoadSBOM(path string) (sbom SBOM, err error) {
26 |
27 | data, err := os.ReadFile(path) // #nosec: G304
28 | if err != nil {
29 | return sbom, fmt.Errorf("unable to load sbom file at '%s': %w", path, err)
30 | }
31 |
32 | if err := elemental.Decode(elemental.EncodingTypeJSON, data, &sbom); err != nil {
33 | return sbom, fmt.Errorf("unable to decode content of sbom file: %w", err)
34 | }
35 |
36 | return sbom, nil
37 | }
38 |
// Hashes are a list of Hash.
type Hashes []Hash

// Matches returns nil when o is compatible with the receiver:
// o must not have more entries than h, and every entry of o must
// exist in h under the same name with the same hash. When the entry
// found in h has params, they are checked the same way against the
// params of the entry from o.
func (h Hashes) Matches(o Hashes) error {
	return cmpH(h, o)
}

// Map converts the Hashes into a map[string]Hash keyed by the Hash Name.
// If several entries share a name, the last one wins.
func (h Hashes) Map() map[string]Hash {

	byName := make(map[string]Hash, len(h))

	for i := range h {
		byName[h[i].Name] = h[i]
	}

	return byName
}

// A Hash represents the hash of an item, with its name
// and the hashes of its potential parameters.
type Hash struct {
	Name   string `json:"name"`
	Hash   string `json:"hash"`
	Params Hashes `json:"params,omitzero"`
}

// cmpH checks every entry of b against the entries of a.
// It fails if b is longer than a, if an entry of b is missing from a,
// or if the hashes differ. Params are compared recursively, only when
// the entry coming from a has some.
func cmpH(a Hashes, b Hashes) error {

	if len(a) < len(b) {
		return fmt.Errorf("invalid len. left: %d right: %d", len(a), len(b))
	}

	left, right := a.Map(), b.Map()

	for name, candidate := range right {

		ref, found := left[name]

		switch {
		case !found:
			return fmt.Errorf("'%s': missing", name)
		case candidate.Hash != ref.Hash:
			return fmt.Errorf("'%s': hash mismatch", name)
		case len(ref.Params) == 0:
			continue
		}

		if err := cmpH(ref.Params, candidate.Params); err != nil {
			return fmt.Errorf("'%s': invalid param: %w", name, err)
		}
	}

	return nil
}
99 |
--------------------------------------------------------------------------------
/pkgs/memconn/listener.go:
--------------------------------------------------------------------------------
1 | package memconn
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net"
7 | )
8 |
9 | // Listener is an in-memory net.Listener. Call DialContext to create a new
10 | // connection.
11 | type Listener struct {
12 | pending chan *Conn
13 | closed chan struct{}
14 | }
15 |
16 | var _ net.Listener = (*Listener)(nil)
17 |
18 | // NewListener creates a new in-memory Listener.
19 | func NewListener() *Listener {
20 | return &Listener{
21 | pending: make(chan *Conn),
22 | closed: make(chan struct{}),
23 | }
24 | }
25 |
26 | // Accept waits for and returns the next connection to l. Connections to l are
27 | // established by calling l.DialContext.
28 | //
29 | // The returned net.Conn is the server side of the connection.
30 | func (l *Listener) Accept() (net.Conn, error) {
31 | select {
32 | case peer := <-l.pending:
33 | local := newConn()
34 | peer.Attach(local)
35 | local.Attach(peer)
36 | return local, nil
37 |
38 | case <-l.closed:
39 | return nil, fmt.Errorf("Listener closed")
40 | }
41 | }
42 |
43 | // Close closes l. Any blocked Accept operations will immediately be unblocked
44 | // and return errors. Already Accepted connections are not closed.
45 | func (l *Listener) Close() error {
46 | select {
47 | default:
48 | close(l.closed)
49 | return nil
50 | case <-l.closed:
51 | return fmt.Errorf("already closed")
52 | }
53 | }
54 |
55 | // Addr returns l's address. This will always be a fake "memory"
56 | // address.
57 | func (l *Listener) Addr() net.Addr {
58 | return Addr{}
59 | }
60 |
61 | // DialContext creates a new connection to l. DialContext will block until the
62 | // connection is accepted through a blocked l.Accept call or until ctx is
63 | // canceled.
64 | //
65 | // Note that unlike other Dial methods in different packages, there is no
66 | // address to supply because the remote side of the connection is always the
67 | // in-memory listener.
68 | func (l *Listener) DialContext(ctx context.Context, clientName string) (net.Conn, error) {
69 | local := newConn()
70 | local.name = clientName
71 |
72 | select {
73 | case l.pending <- local:
74 | // Wait for our peer to be connected.
75 | if err := local.WaitPeer(ctx); err != nil {
76 | return nil, err
77 | }
78 | return local, nil
79 | case <-l.closed:
80 | return nil, fmt.Errorf("server closed")
81 | case <-ctx.Done():
82 | return nil, ctx.Err()
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
/pkgs/auth/auth.go:
--------------------------------------------------------------------------------
1 | package auth
2 |
3 | import (
4 | "encoding/base64"
5 | "fmt"
6 |
7 | "go.acuvity.ai/minibridge/pkgs/oauth"
8 | )
9 |
// AuthScheme identifies one of the supported auth schemes.
type AuthScheme int

// Supported auth schemes.
const (
	AuthSchemeBasic = AuthScheme(iota)
	AuthSchemeBearer
	AuthSchemeOAuth
)
19 |
20 | // Auth holds user credentials.
21 | type Auth struct {
22 | mode AuthScheme
23 | user string
24 | password string
25 |
26 | oauthCreds oauth.Credentials
27 | }
28 |
29 | // NewBasicAuth returns a new Basic Auth.
30 | func NewBasicAuth(user string, password string) *Auth {
31 | return &Auth{
32 | mode: AuthSchemeBasic,
33 | user: user,
34 | password: password,
35 | }
36 | }
37 |
38 | // NewBearerAuth returns a new Bearer auth.
39 | // User() will be set to "Bearer" and Password() ]
40 | // will hold the token.
41 | func NewBearerAuth(token string) *Auth {
42 | return &Auth{
43 | mode: AuthSchemeBearer,
44 | user: "Bearer",
45 | password: token,
46 | }
47 | }
48 |
49 | // NewOAuthAuth returns a new OAuth auth.
50 | // User() will be set to "Bearer" and Password() ]
51 | // will hold the access token.
52 | func NewOAuthAuth(creds oauth.Credentials) *Auth {
53 | return &Auth{
54 | mode: AuthSchemeOAuth,
55 | user: "Bearer",
56 | password: creds.AccessToken,
57 | oauthCreds: creds,
58 | }
59 | }
60 |
61 | // Type returns the current type of Auth as a string.
62 | func (a *Auth) Type() string {
63 | switch a.mode {
64 | case AuthSchemeBasic:
65 | return "Basic"
66 | case AuthSchemeBearer:
67 | return "Bearer"
68 | case AuthSchemeOAuth:
69 | return "OAuth"
70 | default:
71 | panic("unknown auth mode")
72 | }
73 | }
74 |
75 | // Encode encode the Auth to transmit on the wire.
76 | func (a *Auth) Encode() string {
77 | switch a.mode {
78 | case AuthSchemeBasic:
79 | return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString(fmt.Appendf([]byte{}, "%s:%s", a.user, a.password)))
80 | case AuthSchemeBearer, AuthSchemeOAuth:
81 | return fmt.Sprintf("Bearer %s", a.password)
82 | default:
83 | panic("unknown auth mode")
84 | }
85 | }
86 |
// User returns the user. For an Auth created with NewBearerAuth
// or NewOAuthAuth, this is the literal string "Bearer".
func (a *Auth) User() string {
	return a.user
}
91 |
// Password returns the password. For an Auth created with NewBearerAuth
// or NewOAuthAuth, this is the token.
func (a *Auth) Password() string {
	return a.password
}
96 |
--------------------------------------------------------------------------------
/pkgs/backend/options_test.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
3 | import (
4 | "context"
5 | "crypto/tls"
6 | "testing"
7 |
8 | . "github.com/smartystreets/goconvey/convey"
9 | "go.acuvity.ai/bahamut"
10 | "go.acuvity.ai/minibridge/pkgs/mcp"
11 | "go.acuvity.ai/minibridge/pkgs/metrics"
12 | "go.acuvity.ai/minibridge/pkgs/policer/api"
13 | "go.acuvity.ai/minibridge/pkgs/scan"
14 | "go.opentelemetry.io/otel/trace/noop"
15 | )
16 |
// fakePolicer is a do-nothing policer used as a stand-in by the tests.
type fakePolicer struct {
}

// Police returns no message and no error, whatever the request is.
func (f fakePolicer) Police(context.Context, api.Request) (*mcp.Message, error) {
	return nil, nil
}

// Type always returns "fake".
func (f fakePolicer) Type() string { return "fake" }
25 |
// TestOptions checks that each Option stores its value in the wsCfg
// it is applied to. Every case starts from a fresh newWSCfg().
func TestOptions(t *testing.T) {

	Convey("OptPolicer should work", t, func() {
		cfg := newWSCfg()
		f := fakePolicer{}
		OptPolicer(f)(&cfg)
		So(cfg.policer, ShouldEqual, f)
	})

	Convey("OptDumpStderrOnError", t, func() {
		cfg := newWSCfg()
		OptDumpStderrOnError(true)(&cfg)
		So(cfg.dumpStderr, ShouldBeTrue)
	})

	Convey("OPtCORSPolicy should work", t, func() {
		cfg := newWSCfg()
		p := &bahamut.CORSPolicy{}
		OptCORSPolicy(p)(&cfg)
		So(cfg.corsPolicy, ShouldEqual, p)
	})

	Convey("OptSBOM should work", t, func() {
		cfg := newWSCfg()
		s := scan.SBOM{}
		OptSBOM(s)(&cfg)
		So(cfg.sbom, ShouldEqual, s)
	})

	Convey("OptMetricsManager should work", t, func() {
		cfg := newWSCfg()
		mm := &metrics.Manager{}
		OptMetricsManager(mm)(&cfg)
		So(cfg.metricsManager, ShouldEqual, mm)
	})

	Convey("OptTracer should work", t, func() {
		cfg := newWSCfg()
		// This t is the tracer and shadows the *testing.T inside the closure.
		t := noop.NewTracerProvider().Tracer("test")
		OptTracer(t)(&cfg)
		So(cfg.tracer, ShouldEqual, t)
	})

	// A nil tracer must be replaced by a noop tracer rather than stored as is.
	Convey("OptTracer with nil should work", t, func() {
		cfg := newWSCfg()
		OptTracer(nil)(&cfg)
		So(cfg.tracer, ShouldHaveSameTypeAs, noop.NewTracerProvider().Tracer("test"))
	})

	Convey("OptListener should work", t, func() {
		cfg := newWSCfg()
		listener := tls.NewListener(nil, nil)
		OptListener(listener)(&cfg)
		So(cfg.listener, ShouldEqual, listener)
	})

	// Enforcement is on by default and can be switched off.
	Convey("OptPolicerEnforce should work", t, func() {
		cfg := newWSCfg()
		So(cfg.policerEnforced, ShouldBeTrue)
		OptPolicerEnforce(false)(&cfg)
		So(cfg.policerEnforced, ShouldBeFalse)
	})
}
89 |
--------------------------------------------------------------------------------
/pkgs/mcp/id.go:
--------------------------------------------------------------------------------
1 | package mcp
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "math"
7 | "reflect"
8 | )
9 |
10 | // RelatedIDs returns true if the two given
11 | // ID as any are equal.
12 | // Slow clap on spec that basically says
13 | // MUST SHOULD BE AN INT. Or a string. whatever...
14 | func RelatedIDs(a any, b any) bool {
15 |
16 | if sa, ok := a.(string); ok {
17 | sb, ok := b.(string)
18 | return ok && sa == sb
19 | }
20 |
21 | if isNumeric(a) && isNumeric(b) {
22 | ia, oka := extractInt64(a)
23 | ib, okb := extractInt64(b)
24 | return oka && okb && ia == ib
25 | }
26 |
27 | return reflect.TypeOf(a) == reflect.TypeOf(b) && reflect.DeepEqual(a, b)
28 | }
29 |
// normalizeID renders an ID as a string.
// Strings are returned untouched, int, int64 and uint64 are printed in
// decimal, float64 is printed with no fractional digits, json.Number
// uses its own textual form, and anything else uses the %v format.
func normalizeID(id any) string {
	switch v := id.(type) {
	case string:
		return v
	case json.Number:
		return v.String()
	case float64:
		return fmt.Sprintf("%.0f", v)
	case int, int64, uint64:
		return fmt.Sprintf("%d", v)
	}

	return fmt.Sprintf("%v", id)
}
48 |
// isNumeric reports whether v holds one of the built-in integer
// or floating point types, or a json.Number.
func isNumeric(v any) bool {
	switch v.(type) {
	case json.Number:
	case float32, float64:
	case int, int8, int16, int32, int64:
	case uint, uint8, uint16, uint32, uint64:
	default:
		return false
	}

	return true
}
59 |
// extractInt64 converts a numeric value to an int64.
//
// It returns the converted value and true when v holds an integer type,
// a floating point type or a json.Number whose value is a whole number
// that fits in an int64. It returns 0 and false in every other case:
// unsupported type, fractional or NaN float, or value out of range.
func extractInt64(v any) (int64, bool) {
	switch val := v.(type) {
	case int:
		return int64(val), true
	case int8:
		return int64(val), true
	case int16:
		return int64(val), true
	case int32:
		return int64(val), true
	case int64:
		return val, true
	case uint8:
		return int64(val), true
	case uint16:
		return int64(val), true
	case uint32:
		return int64(val), true
	case uint:
		// Widen first: comparing a uint against math.MaxInt64 directly
		// does not compile on platforms where uint is 32 bits.
		return extractInt64(uint64(val))
	case uint64:
		if val <= math.MaxInt64 {
			return int64(val), true
		}
		return 0, false
	case float32:
		return extractInt64(float64(val))
	case float64:
		const (
			minInt64 = -9223372036854775808.0 // -2^63, exactly representable
			endInt64 = 9223372036854775808.0  // 2^63, first value past the range
		)
		// Check the range before converting: the result of converting an
		// out-of-range float to an integer is implementation-defined.
		// NaN fails every comparison and is rejected as well.
		if val >= minInt64 && val < endInt64 && val == math.Trunc(val) {
			return int64(val), true
		}
		return 0, false
	case json.Number:
		if i, err := val.Int64(); err == nil {
			return i, true
		}
		return 0, false
	default:
		return 0, false
	}
}
104 |
--------------------------------------------------------------------------------
/pkgs/policer/internal/http/policer.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "crypto/tls"
7 | "fmt"
8 | "io"
9 | "net/http"
10 | "strings"
11 |
12 | "go.acuvity.ai/elemental"
13 | "go.acuvity.ai/minibridge/pkgs/auth"
14 | "go.acuvity.ai/minibridge/pkgs/mcp"
15 | "go.acuvity.ai/minibridge/pkgs/policer/api"
16 | )
17 |
// Policer is a policer that delegates its decisions to a remote
// HTTP service.
type Policer struct {
	endpoint string       // URL the policing requests are POSTed to
	auth     auth.Auth    // credentials sent in the Authorization header
	client   *http.Client // HTTP client used to send the requests
}
23 |
24 | // New returns a new HTTP based Policer.
25 | func New(endpoint string, auth *auth.Auth, tlsConfig *tls.Config) *Policer {
26 |
27 | return &Policer{
28 | endpoint: endpoint,
29 | auth: *auth,
30 | client: &http.Client{
31 | Transport: &http.Transport{
32 | TLSClientConfig: tlsConfig,
33 | },
34 | },
35 | }
36 | }
37 |
// Type returns the type of the policer, which is always "http".
func (p *Policer) Type() string { return "http" }
39 |
40 | func (p *Policer) Police(ctx context.Context, preq api.Request) (*mcp.Message, error) {
41 |
42 | body, err := elemental.Encode(elemental.EncodingTypeJSON, preq)
43 | if err != nil {
44 | return nil, fmt.Errorf("unable to encode scan request: %w", err)
45 | }
46 |
47 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.endpoint, bytes.NewBuffer(body))
48 | if err != nil {
49 | return nil, fmt.Errorf("unable to create new http request: %w", err)
50 | }
51 |
52 | req.Header.Add("Accept", "application/json")
53 | req.Header.Add("Content-Type", "application/json")
54 | req.Header.Add("Authorization", p.auth.Encode())
55 |
56 | resp, err := p.client.Do(req)
57 | if err != nil {
58 | return nil, fmt.Errorf("unable to send request: %w", err)
59 | }
60 |
61 | if resp.StatusCode == http.StatusNoContent {
62 | return nil, nil
63 | }
64 |
65 | rbody, err := io.ReadAll(resp.Body)
66 | if err != nil {
67 | return nil, fmt.Errorf("unable to read response body: %w", err)
68 | }
69 | defer func() { _ = resp.Body.Close() }()
70 |
71 | if resp.StatusCode != http.StatusOK {
72 | return nil, fmt.Errorf("invalid response from policer `%s`: %s", string(rbody), resp.Status)
73 | }
74 |
75 | sresp := api.Response{}
76 | if err := elemental.Decode(elemental.EncodingTypeJSON, rbody, &sresp); err != nil {
77 | return nil, fmt.Errorf("unable to decode response body: %w", err)
78 | }
79 |
80 | if sresp.MCP != nil && preq.MCP.ID != nil {
81 | sresp.MCP.ID = preq.MCP.ID
82 | }
83 |
84 | if sresp.Allow {
85 | return sresp.MCP, nil
86 | }
87 |
88 | if len(sresp.Reasons) == 0 {
89 | sresp.Reasons = []string{api.GenericDenyReason}
90 | }
91 |
92 | return nil, fmt.Errorf("%w: %s", api.ErrBlocked, strings.Join(sresp.Reasons, ", "))
93 | }
94 |
--------------------------------------------------------------------------------
/pkgs/backend/helpers_test.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "testing"
7 |
8 | "go.acuvity.ai/minibridge/pkgs/auth"
9 | )
10 |
11 | func Test_makeMCPError(t *testing.T) {
12 | type args struct {
13 | ID any
14 | err error
15 | }
16 | tests := []struct {
17 | name string
18 | args func(t *testing.T) args
19 |
20 | want1 []byte
21 | }{
22 | {
23 | "basic",
24 | func(*testing.T) args {
25 | return args{
26 | ID: 42,
27 | err: fmt.Errorf("oh noes!"),
28 | }
29 | },
30 | []byte(`{"error":{"code":451,"message":"oh noes!"},"id":42,"jsonrpc":"2.0"}`),
31 | },
32 | }
33 |
34 | for _, tt := range tests {
35 | t.Run(tt.name, func(t *testing.T) {
36 | tArgs := tt.args(t)
37 |
38 | got1 := makeMCPError(tArgs.ID, tArgs.err)
39 |
40 | if !reflect.DeepEqual(got1, tt.want1) {
41 | t.Errorf("makeMCPError got1 = %v, want1: %v", string(got1), string(tt.want1))
42 | }
43 | })
44 | }
45 | }
46 |
47 | func Test_parseBasicAuth(t *testing.T) {
48 | type args struct {
49 | authString string
50 | }
51 | tests := []struct {
52 | name string
53 | args func(t *testing.T) args
54 |
55 | want1 *auth.Auth
56 | want2 bool
57 | }{
58 | {
59 | "empty header",
60 | func(t *testing.T) args {
61 | return args{
62 | authString: "",
63 | }
64 | },
65 | nil,
66 | false,
67 | },
68 | {
69 | "bearer",
70 | func(t *testing.T) args {
71 | return args{
72 | authString: "Bearer token",
73 | }
74 | },
75 | auth.NewBearerAuth("token"),
76 | true,
77 | },
78 | {
79 | "basic",
80 | func(t *testing.T) args {
81 | return args{
82 | authString: "Basic dXNlcjpwYXNz",
83 | }
84 | },
85 | auth.NewBasicAuth("user", "pass"),
86 | true,
87 | },
88 | {
89 | "invalid basic b64",
90 | func(t *testing.T) args {
91 | return args{
92 | authString: "Basic not-b64",
93 | }
94 | },
95 | nil,
96 | false,
97 | },
98 | {
99 | "invalid basic decoded",
100 | func(t *testing.T) args {
101 | return args{
102 | authString: "Basic aGVsbG8=",
103 | }
104 | },
105 | nil,
106 | false,
107 | },
108 | }
109 |
110 | for _, tt := range tests {
111 | t.Run(tt.name, func(t *testing.T) {
112 | tArgs := tt.args(t)
113 |
114 | got1, got2 := parseBasicAuth(tArgs.authString)
115 |
116 | if !reflect.DeepEqual(got1, tt.want1) {
117 | t.Errorf("parseBasicAuth got1 = %v, want1: %v", got1, tt.want1)
118 | }
119 |
120 | if !reflect.DeepEqual(got2, tt.want2) {
121 | t.Errorf("parseBasicAuth got2 = %v, want2: %v", got2, tt.want2)
122 | }
123 | })
124 | }
125 | }
126 |
--------------------------------------------------------------------------------
/cli/internal/cmd/backend.go:
--------------------------------------------------------------------------------
1 | package cmd
2 |
3 | import (
4 | "fmt"
5 | "log/slog"
6 |
7 | "github.com/spf13/cobra"
8 | "github.com/spf13/pflag"
9 | "github.com/spf13/viper"
10 | "go.acuvity.ai/minibridge/pkgs/backend"
11 | )
12 |
// fBackend holds the flags that are specific to the backend command.
var fBackend = pflag.NewFlagSet("backend", pflag.ExitOnError)
14 |
15 | func init() {
16 |
17 | initSharedFlagSet()
18 |
19 | fBackend.StringP("listen", "l", ":8000", "listen address of the bridge for incoming websocket connections.")
20 |
21 | Backend.Flags().AddFlagSet(fBackend)
22 | Backend.Flags().AddFlagSet(fPolicer)
23 | Backend.Flags().AddFlagSet(fTLSServer)
24 | Backend.Flags().AddFlagSet(fHealth)
25 | Backend.Flags().AddFlagSet(fProfiler)
26 | Backend.Flags().AddFlagSet(fCORS)
27 | Backend.Flags().AddFlagSet(fSBOM)
28 | Backend.Flags().AddFlagSet(fMCP)
29 | }
30 |
// Backend is the cobra command to run the server.
// It requires at least one positional argument: the command (and its
// arguments) handed to makeMCPClient.
var Backend = &cobra.Command{
	Use:              "backend [flags] -- command [args...]",
	Short:            "Start a minibridge backend to expose an MCP server",
	SilenceUsage:     true,
	SilenceErrors:    true,
	TraverseChildren: true,
	Args:             cobra.MinimumNArgs(1),

	RunE: func(cmd *cobra.Command, args []string) error {

		listen := viper.GetString("listen")

		if listen == "" {
			return fmt.Errorf("--listen must be set")
		}

		// Build every collaborator from the flags first, so that a
		// configuration error is reported before anything is started.
		backendTLSConfig, err := tlsConfigFromFlags(fTLSServer)
		if err != nil {
			return err
		}

		policer, penforce, err := makePolicer()
		if err != nil {
			return fmt.Errorf("unable to make policer: %w", err)
		}

		sbom, err := makeSBOM()
		if err != nil {
			return fmt.Errorf("unable to make hashes: %w", err)
		}

		tracer, err := makeTracer(cmd.Context(), "backend")
		if err != nil {
			return fmt.Errorf("unable to configure tracer: %w", err)
		}

		corsPolicy := makeCORSPolicy()

		mcpClient, err := makeMCPClient(args, true)
		if err != nil {
			return fmt.Errorf("unable to create MCP client: %w", err)
		}

		mm := startHealthServer(cmd.Context())

		slog.Info("Minibridge backend configured",
			"server-tls", backendTLSConfig != nil,
			"server-mtls", mtlsMode(backendTLSConfig),
			"listen", listen,
		)

		proxy := backend.NewWebSocket(listen, backendTLSConfig, mcpClient,
			backend.OptPolicer(policer),
			backend.OptPolicerEnforce(penforce),
			// stderr of the MCP server is dumped as is unless logs are in JSON.
			backend.OptDumpStderrOnError(viper.GetString("log-format") != "json"),
			backend.OptCORSPolicy(corsPolicy),
			backend.OptSBOM(sbom),
			backend.OptMetricsManager(mm),
			backend.OptTracer(tracer),
		)

		return proxy.Start(cmd.Context())
	},
}
96 |
--------------------------------------------------------------------------------
/pkgs/backend/options.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
3 | import (
4 | "net"
5 |
6 | "go.acuvity.ai/bahamut"
7 | "go.acuvity.ai/minibridge/pkgs/metrics"
8 | "go.acuvity.ai/minibridge/pkgs/policer"
9 | "go.acuvity.ai/minibridge/pkgs/scan"
10 | "go.opentelemetry.io/otel/trace"
11 | "go.opentelemetry.io/otel/trace/noop"
12 | )
13 |
// wsCfg holds the optional settings of the websocket backend.
// Each field is set through the Option of the matching name.
type wsCfg struct {
	corsPolicy      *bahamut.CORSPolicy // set by OptCORSPolicy
	dumpStderr      bool                // set by OptDumpStderrOnError
	listener        net.Listener        // set by OptListener
	metricsManager  *metrics.Manager    // set by OptMetricsManager
	policer         policer.Policer     // set by OptPolicer
	policerEnforced bool                // set by OptPolicerEnforce; true by default
	sbom            scan.SBOM           // set by OptSBOM
	tracer          trace.Tracer        // set by OptTracer; a noop tracer by default
}
24 |
25 | func newWSCfg() wsCfg {
26 | return wsCfg{
27 | tracer: noop.NewTracerProvider().Tracer("noop"),
28 | policerEnforced: true,
29 | }
30 | }
31 |
// Option is an optional setting that can be given to NewWebSocket().
type Option func(*wsCfg)
34 |
35 | // OptPolicer sets the Policer to forward the traffic to.
36 | func OptPolicer(policer policer.Policer) Option {
37 | return func(cfg *wsCfg) {
38 | cfg.policer = policer
39 | }
40 | }
41 |
42 | // OptPolicerEnforce sets the Policer decision should be enforced
43 | // or just logged. The default is true if a policer is set
44 | func OptPolicerEnforce(enforced bool) Option {
45 | return func(cfg *wsCfg) {
46 | cfg.policerEnforced = enforced
47 | }
48 | }
49 |
50 | // OptDumpStderrOnError controls whether the WS server should
51 | // dump the stderr of the MCP server as is, or in a log.
52 | func OptDumpStderrOnError(dump bool) Option {
53 | return func(cfg *wsCfg) {
54 | cfg.dumpStderr = dump
55 | }
56 | }
57 |
58 | // OptCORSPolicy sets the bahamut.CORSPolicy to use for
59 | // connection originating from a webrowser.
60 | func OptCORSPolicy(policy *bahamut.CORSPolicy) Option {
61 | return func(cfg *wsCfg) {
62 | cfg.corsPolicy = policy
63 | }
64 | }
65 |
66 | // OptSBOM sets a the utils.SBOM to use to verify
67 | // server integrity.
68 | func OptSBOM(sbom scan.SBOM) Option {
69 | return func(cfg *wsCfg) {
70 | cfg.sbom = sbom
71 | }
72 | }
73 |
74 | // OptMetricsManager sets the metric manager to use to collect
75 | // prometheus metrics.
76 | func OptMetricsManager(m *metrics.Manager) Option {
77 | return func(cfg *wsCfg) {
78 | cfg.metricsManager = m
79 | }
80 | }
81 |
82 | // OptTracer sets the otel trace.Tracer to use to trace requests
83 | func OptTracer(tracer trace.Tracer) Option {
84 | return func(cfg *wsCfg) {
85 | if tracer == nil {
86 | tracer = noop.NewTracerProvider().Tracer("noop")
87 | }
88 | cfg.tracer = tracer
89 | }
90 | }
91 |
92 | // OptListener sets the listener to use for the server.
93 | // by defaut, it will use a classic listener.
94 | func OptListener(listener net.Listener) Option {
95 | return func(cfg *wsCfg) {
96 | cfg.listener = listener
97 | }
98 | }
99 |
--------------------------------------------------------------------------------
/examples/policer-http/policer.py:
--------------------------------------------------------------------------------
1 | #!/bin/python
2 | """
3 | This is an example implementation for a policer
4 | """
5 |
6 | import json
7 | import jwt
8 | from termcolor import colored as c
9 | from flask import Flask, request, Response
10 |
app = Flask(__name__)

# Names of the tools that must not be called, per user email.
# The "*" entry applies to everybody.
FORBIDDEN_TOOLS = {
    "*": ["printEnv"],
    "bob@example.com": ["longRunningOperation"],
    "alice@example.com": [],
}
18 |
19 |
@app.route("/police", methods=["POST"])
def police():
    """Handle POST /police: decide what to do with an MCP message.

    The JSON body must carry the keys "type", "agent" and "mcp".

    Returns:
        - a JSON body with "allow" set to False and the "reasons" when the
          agent token is invalid or the tool call is forbidden,
        - a 200 response with "allow" set to True and the redacted "mcp"
          message when tools were removed from a response,
        - an empty 204 response when the message is allowed as is.
    """

    req = request.get_json()
    agent = req["agent"]
    mcp = req["mcp"]

    print()
    print("---")
    print(c(f"Type: {req['type']}", "green" if req["type"] == "request" else "yellow"))
    print(c(f"Agent: {agent['userAgent']} {agent['remoteAddr']}", "blue"))
    print()
    print(c(f"{request.headers}", "dark_grey"))
    print(json.dumps(req, sort_keys=True, indent=4))

    # Check the agent token. Deny if not valid.
    # This example is meant to work with JWT issued
    # with the gen-test-tokens.sh located in the parent folder.
    try:
        claims = jwt.decode(
            agent["password"],
            "secret",
            algorithms=["HS256"],
            issuer="pki.example.com",
            audience="minibridge",
        )
    except Exception as e:
        print(c(f"DENIED: invalid token: {e}", "red"))
        return json.dumps({"allow": False, "reasons": [f"{e}"]})

    # A valid token may carry no email claim, and not every user has an
    # entry in FORBIDDEN_TOOLS: such users only get the "*" restrictions
    # instead of crashing the handler with a KeyError.
    email = claims.get("email")
    forbidden = FORBIDDEN_TOOLS["*"] + FORBIDDEN_TOOLS.get(email, [])

    # This is an example of blanket policing. We deny access
    # to the tool/calls declared in FORBIDDEN_TOOLS
    if req["type"] == "request":
        # not every request carries params (or params holding a name)
        params = mcp.get("params")
        if not isinstance(params, dict):
            params = {}

        if (
            mcp.get("method") == "tools/call"
            and "name" in params
            and params["name"] in forbidden
        ):
            dmsg = f"forbidden method call {params['name']} {mcp['method']}"
            print(c(f"DENIED: {dmsg}", "red"))
            return json.dumps({"allow": False, "reasons": [dmsg]})

    # This is an example of redaction: If the
    # user is Bob, then we remove the tool named `longRunningOperation`
    # from the response.
    if (
        req["type"] == "response"
        and "result" in req["mcp"]
        and email == "bob@example.com"
    ):
        result = req["mcp"]["result"]
        if "tools" in result:
            result["tools"] = [
                cell
                for cell in result["tools"]
                if cell["name"] != "longRunningOperation"
            ]
            # mcp is the same object as req["mcp"], so it carries the redaction
            return Response(
                status=200, response=json.dumps({"allow": True, "mcp": mcp})
            )

    # otherwise we allow everything
    return Response(status=204)
87 |
88 |
89 | app.run(port=5000)
90 |
--------------------------------------------------------------------------------
/pkgs/internal/cors/cors_test.go:
--------------------------------------------------------------------------------
1 | package cors
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "testing"
7 |
8 | . "github.com/smartystreets/goconvey/convey"
9 | "go.acuvity.ai/bahamut"
10 | )
11 |
// TestThing checks the return value of HandleCORS and the headers it
// writes for an OPTIONS request, for a non OPTIONS request, and when no
// CORS policy is given.
func TestThing(t *testing.T) {

	Convey("Given I have a response writer, req and corsPolicy", t, func() {
		w := httptest.NewRecorder()
		req := &http.Request{}
		pol := &bahamut.CORSPolicy{
			AllowOrigin:      "https://coucou.test",
			AllowCredentials: true,
			MaxAge:           1500,
			AllowHeaders: []string{
				"Authorization",
			},
			AllowMethods: []string{
				"GET",
				"POST",
				"OPTIONS",
			},
		}

		// On OPTIONS the caller is told not to continue, and the
		// allowed headers, methods and max age are advertised.
		Convey("Calling HandleCors should work on OPTIONS", func() {

			req.Method = http.MethodOptions

			shouldCont := HandleCORS(w, req, pol)
			So(shouldCont, ShouldBeFalse)
			So(w.Result().Header, ShouldResemble, http.Header{
				"Access-Control-Allow-Credentials": {"true"},
				"Access-Control-Allow-Headers":     {"Authorization"},
				"Access-Control-Allow-Methods":     {"GET, POST, OPTIONS"},
				"Access-Control-Allow-Origin":      {"https://coucou.test"},
				"Access-Control-Expose-Headers":    {""},
				"Access-Control-Max-Age":           {"1500"},
				"Cache-Control":                    {"private, no-transform"},
				"Strict-Transport-Security":        {"max-age=31536000; includeSubDomains; preload"},
				"X-Content-Type-Options":           {"nosniff"},
				"X-Frame-Options":                  {"DENY"},
				"X-Xss-Protection":                 {"1; mode=block"},
			})
		})

		// On any other method the caller is told to continue.
		Convey("Calling HandleCors should work on non OPTIONS", func() {

			req.Method = http.MethodPost

			shouldCont := HandleCORS(w, req, pol)
			So(shouldCont, ShouldBeTrue)
			So(w.Result().Header, ShouldResemble, http.Header{
				"Access-Control-Allow-Credentials": {"true"},
				"Access-Control-Allow-Origin":      {"https://coucou.test"},
				"Access-Control-Expose-Headers":    {""},
				"Cache-Control":                    {"private, no-transform"},
				"Strict-Transport-Security":        {"max-age=31536000; includeSubDomains; preload"},
				"X-Content-Type-Options":           {"nosniff"},
				"X-Frame-Options":                  {"DENY"},
				"X-Xss-Protection":                 {"1; mode=block"},
			})
		})

		// Without a policy no Access-Control-* header is expected.
		Convey("Calling HandleCors should work with no policy", func() {

			req.Method = http.MethodPost

			shouldCont := HandleCORS(w, req, nil)
			So(shouldCont, ShouldBeTrue)
			So(w.Result().Header, ShouldResemble, http.Header{
				"Cache-Control":             {"private, no-transform"},
				"Strict-Transport-Security": {"max-age=31536000; includeSubDomains; preload"},
				"X-Content-Type-Options":    {"nosniff"},
				"X-Frame-Options":           {"DENY"},
				"X-Xss-Protection":          {"1; mode=block"},
			})
		})
	})
}
86 |
--------------------------------------------------------------------------------
/pkgs/frontend/connect.go:
--------------------------------------------------------------------------------
1 | package frontend
2 |
3 | import (
4 | "context"
5 | "crypto/tls"
6 | "errors"
7 | "fmt"
8 | "io"
9 | "log/slog"
10 | "net"
11 | "net/http"
12 | "strings"
13 |
14 | "go.acuvity.ai/minibridge/pkgs/auth"
15 | "go.acuvity.ai/wsc"
16 | "go.opentelemetry.io/otel"
17 | "go.opentelemetry.io/otel/propagation"
18 | )
19 |
// ErrAuthRequired is returned by Connect when the backend rejects
// the websocket connection with a 401 Unauthorized status.
var ErrAuthRequired = errors.New("authorization required")
21 |
// AgentInfo holds information about the agent
// who sent an MCPCall.
type AgentInfo struct {
	Auth        *auth.Auth // credentials encoded into the Authorization header, if set
	AuthHeaders []string   // raw Authorization header values, used when Auth is nil
	UserAgent   string     // forwarded in the X-Forwarded-UA header
	RemoteAddr  string     // forwarded in the X-Forwarded-For header
}
30 |
31 | // Connect is a low level function to connect to the backend's websocket
32 | func Connect(
33 | ctx context.Context,
34 | dialer func(ctx context.Context, network, addr string) (net.Conn, error),
35 | backendURL string,
36 | tlsConfig *tls.Config,
37 | info AgentInfo,
38 | ) (wsc.Websocket, error) {
39 |
40 | slog.Debug("New websocket connection",
41 | "url", backendURL,
42 | "using-auth", info.Auth != nil,
43 | "using-headers", len(info.AuthHeaders) > 0,
44 | "tls", strings.HasPrefix(backendURL, "wss://"),
45 | "tls-config", tlsConfig != nil,
46 | )
47 |
48 | if dialer == nil && (info.Auth != nil || len(info.AuthHeaders) > 0) && tlsConfig == nil {
49 | slog.Warn("Security: connecting to a websocket with crendentials sent over the network in clear-text. Refused. Credentials have been stripped. Request will proceed and will likely fail.")
50 | }
51 |
52 | wsconfig := wsc.Config{
53 | WriteChanSize: 64,
54 | ReadChanSize: 16,
55 | TLSConfig: tlsConfig,
56 | NetDialContextFunc: dialer,
57 | }
58 |
59 | wsconfig.Headers = http.Header{
60 | "X-Forwarded-UA": {info.UserAgent},
61 | "X-Forwarded-For": {info.RemoteAddr},
62 | }
63 |
64 | otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(wsconfig.Headers))
65 |
66 | if tlsConfig != nil || dialer != nil {
67 | if info.Auth != nil {
68 | wsconfig.Headers["Authorization"] = []string{info.Auth.Encode()}
69 | } else if len(info.AuthHeaders) > 0 {
70 | wsconfig.Headers["Authorization"] = info.AuthHeaders
71 | }
72 | }
73 |
74 | session, resp, err := wsc.Connect(ctx, backendURL, wsconfig)
75 | if err != nil {
76 |
77 | var data []byte
78 | var code int
79 | status := ""
80 |
81 | if resp != nil {
82 |
83 | if resp.StatusCode == http.StatusUnauthorized {
84 | return nil, ErrAuthRequired
85 | }
86 |
87 | data, _ = io.ReadAll(resp.Body)
88 | _ = resp.Body.Close()
89 |
90 | code = resp.StatusCode
91 | status = resp.Status
92 | }
93 |
94 | slog.Error("WS connection failed", "code", code, "status", status, "data", strings.TrimSpace(string(data)), err)
95 |
96 | return nil, fmt.Errorf("unable to connect to the websocket. code: %d, status: %s: %w", code, status, err)
97 | }
98 |
99 | defer func() { _ = resp.Body.Close() }()
100 |
101 | if resp.StatusCode != http.StatusSwitchingProtocols {
102 | return nil, fmt.Errorf("invalid response from other end of the tunnel (must be 101): %s", resp.Status)
103 | }
104 |
105 | return session, nil
106 | }
107 |
--------------------------------------------------------------------------------
/pkgs/frontend/internal/session/manager_test.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "testing"
5 | "time"
6 |
7 | . "github.com/smartystreets/goconvey/convey"
8 | "go.acuvity.ai/wsc"
9 | )
10 |
// TestManager walks two sessions through the Manager: registration,
// acquire/release counting, removal once the count drops to zero,
// delivery of incoming messages to the registered hooks, and cleanup
// once a session is done. The steps depend on each other and must
// stay in this order.
func TestManager(t *testing.T) {

	Convey("Manager should work", t, func() {

		ws1 := wsc.NewMockWebsocket(t.Context())
		ws2 := wsc.NewMockWebsocket(t.Context())

		m := NewManager()
		s1 := New(ws1, 1, "1")
		s2 := newSession(ws2, 2, "2", 2*time.Second)

		So(s1.nextDeadline, ShouldEqual, defaultDeadlineDuration)
		So(s2.nextDeadline, ShouldEqual, 2*time.Second)
		So(s1.ValidateHash(1), ShouldBeTrue)
		So(s2.ValidateHash(2), ShouldBeTrue)

		m.Register(s1)
		m.Register(s2)
		So(len(m.sessions), ShouldEqual, 2)
		So(m.sessions["1"].getCount(), ShouldEqual, 1)
		So(m.sessions["2"].getCount(), ShouldEqual, 1)

		// We acquire 1
		m.Acquire("1", nil)
		So(len(m.sessions), ShouldEqual, 2)
		So(m.sessions["1"].getCount(), ShouldEqual, 2)
		So(m.sessions["2"].getCount(), ShouldEqual, 1)

		// The write is done from a goroutine while this one reads it
		// back from the mock websocket.
		go func() { s1.Write([]byte("coucou")) }()
		So(string(<-ws1.LastWrite()), ShouldEqual, "coucou")

		// We release 1
		m.Release("1", nil)
		So(len(m.sessions), ShouldEqual, 2)
		So(m.sessions["1"].getCount(), ShouldEqual, 1)
		So(m.sessions["2"].getCount(), ShouldEqual, 1)

		// We release 1 again, it should be removed
		m.Release("1", nil)
		So(len(m.sessions), ShouldEqual, 1)
		So(m.sessions["1"], ShouldBeNil)
		So(m.sessions["2"].getCount(), ShouldEqual, 1)

		// we over release, it should be noop
		m.Release("1", nil)
		So(len(m.sessions), ShouldEqual, 1)
		So(m.sessions["1"], ShouldBeNil)
		So(m.sessions["2"].getCount(), ShouldEqual, 1)

		// We simulate a message when there is no hook
		// this should be noop and should not break the rest of
		// the test
		ws2.NextRead([]byte("nobody there"))
		<-time.After(time.Second)

		// We acquire 2 with a chan
		ch1 := make(chan []byte)
		ch2 := make(chan []byte)
		m.Acquire("2", ch1)
		m.Acquire("2", ch2)
		So(len(m.sessions), ShouldEqual, 1)
		So(m.sessions["2"].getCount(), ShouldEqual, 3)
		So(len(m.sessions["2"].getHooks()), ShouldEqual, 2)

		// we simulate a message from the ws
		ws2.NextRead([]byte("coucou"))

		// we should get it on ch1
		// (on timeout data stays empty and the assertion below fails)
		var data []byte
		select {
		case data = <-ch1:
		case <-time.After(time.Second):
		}
		So(string(data), ShouldEqual, "coucou")

		// we should get it on ch2
		select {
		case data = <-ch2:
		case <-time.After(time.Second):
		}
		So(string(data), ShouldEqual, "coucou")

		// Now we release 2 twice
		m.Release("2", ch1)
		m.Release("2", ch2)
		So(len(m.sessions), ShouldEqual, 1)
		So(m.sessions["2"].getCount(), ShouldEqual, 1)
		So(len(m.sessions["2"].getHooks()), ShouldEqual, 0)

		// We wait for session 2 to be done, then drop the last reference.
		// s2 was created with a 2 second deadline duration.
		select {
		case <-s2.Done():
			m.Release("2", nil)
		case <-time.After(5 * time.Second):
		}

		So(len(m.sessions), ShouldEqual, 0)
	})
}
109 |
--------------------------------------------------------------------------------
/cli/internal/cmd/flags.go:
--------------------------------------------------------------------------------
1 | package cmd
2 |
3 | import (
4 | "github.com/spf13/pflag"
5 | )
6 |
// Flag sets shared by the commands. They are created empty here and
// filled once by initSharedFlagSet.
var (
	fTLSClient = pflag.NewFlagSet("tlsclient", pflag.ExitOnError)
	fTLSServer = pflag.NewFlagSet("tlsserver", pflag.ExitOnError)
	fProfiler  = pflag.NewFlagSet("profile", pflag.ExitOnError)
	fHealth    = pflag.NewFlagSet("health", pflag.ExitOnError)
	fPolicer   = pflag.NewFlagSet("policer", pflag.ExitOnError)
	fCORS      = pflag.NewFlagSet("cors", pflag.ExitOnError)
	fAgentAuth = pflag.NewFlagSet("agentauth", pflag.ExitOnError)
	fSBOM      = pflag.NewFlagSet("sbom", pflag.ExitOnError)
	fMCP       = pflag.NewFlagSet("mcp", pflag.ExitOnError)

	// initialized records that initSharedFlagSet has already run.
	initialized = false
)
20 |
// initSharedFlagSet declares the flags of the shared flag sets.
// It only does so on its first call: later calls return immediately,
// so every command can call it from its own init function.
func initSharedFlagSet() {

	if initialized {
		return
	}

	initialized = true

	// TLS settings for incoming connections.
	fTLSServer.StringP("tls-server-cert", "c", "", "path to the server certificate for incoming HTTPS connections.")
	fTLSServer.StringP("tls-server-key", "k", "", "path to the key for the server certificate.")
	fTLSServer.StringP("tls-server-key-pass", "p", "", "passphrase for the server certificate key.")
	fTLSServer.String("tls-server-client-ca", "", "path to a CA to require and validate incoming client certificates.")

	// TLS settings for connecting to the minibridge backend.
	fTLSClient.StringP("tls-client-cert", "C", "", "path to the client certificate to authenticate against the minibridge backend.")
	fTLSClient.StringP("tls-client-key", "K", "", "path to the key for the client certificate.")
	fTLSClient.StringP("tls-client-key-pass", "P", "", "passphrase for the client certificate key.")
	fTLSClient.String("tls-client-backend-ca", "", "path to a CA to validate the minibridge backend server certificates.")
	fTLSClient.Bool("tls-client-insecure-skip-verify", false, "skip backend's server certificates validation. INSECURE.")

	fHealth.String("health-listen", "", "if set, start health server on that address.")

	// Policer selection and settings.
	// NOTE(review): the "P" shorthand is also declared by tls-client-key-pass
	// above; attaching fPolicer and fTLSClient to the same command would
	// presumably conflict — confirm they are never combined.
	fPolicer.StringP("policer-type", "P", "", "type of policer to use. 'rego' or 'http'.")
	fPolicer.Bool("policer-enforce", true, "enforce policy or only log verdict.")
	fPolicer.String("policer-rego-policy", "", "path to a rego policy file for the rego policer.")
	fPolicer.String("policer-http-url", "", "URL of the HTTP policer to POST agent policing requests.")
	fPolicer.String("policer-http-bearer-token", "", "token to use to authenticate against the HTTP policer using Bearer scheme.")
	fPolicer.String("policer-http-basic-user", "", "user to use to authenticate against the HTTP policer using Basic scheme.")
	fPolicer.String("policer-http-basic-pass", "", "password to use to authenticate against the HTTP policer using Basic scheme.")
	fPolicer.String("policer-http-ca", "", "path to a CA to validate the policer server certificates.")
	fPolicer.Bool("policer-http-insecure-skip-verify", false, "skip policer's server certificates validation. INSECURE.")

	fCORS.String("cors-origin", "*", "sets the valid HTTP Origin for CORS responses.")

	// Credentials the agent presents to the backend.
	fAgentAuth.StringP("agent-token", "t", "", "JWT token to pass to the minibridge backend for agent identification.")
	fAgentAuth.String("agent-user", "", "User to send to the backend as Basic auth.")
	fAgentAuth.String("agent-pass", "", "Password to send to the backend as Basic auth.")
	fAgentAuth.Bool("oauth-disabled", false, "If set, skip trying to perform the oauth dance.")

	fSBOM.String("sbom", "", "path to a sbom file (generated by minibridge scan sbom) to ensure server integrity.")

	// How the MCP server command is run or reached.
	fMCP.Int("mcp-uid", -1, "if greater than -1, use as UID to run the MCP server command.")
	fMCP.Int("mcp-gid", -1, "if greater than -1, use as GID to run the MCP server command.")
	fMCP.IntSlice("mcp-groups", nil, "additional GIDs to to run the MCP server command.")
	fMCP.Bool("mcp-use-tempdir", false, "if set, create a new temp execution dir for each MCP server instance.")
	fMCP.String("mcp-tls-ca", "", "when using SSE, path to a CA to valide MCP server server certificates.")
	fMCP.Bool("mcp-tls-insecure-skip-verify", false, "skip MCP server certificates validation. INSECURE.")
}
68 |
--------------------------------------------------------------------------------
/pkgs/backend/client/stdio.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "bufio"
5 | "context"
6 | "fmt"
7 | "io"
8 | "log/slog"
9 | "os"
10 | "os/exec"
11 | "strings"
12 |
13 | "go.acuvity.ai/minibridge/pkgs/internal/sanitize"
14 | )
15 |
// Compile-time check that *stdioClient satisfies the Client interface.
var _ Client = (*stdioClient)(nil)

// stdioClient is a Client that runs an MCP server as a local command and
// exchanges data with it through the command's stdin, stdout and stderr.
type stdioClient struct {
	// srv describes the command to run (command, arguments, environment).
	srv MCPServer
	// cfg holds the settings collected from the StdioOption values.
	cfg stdioCfg
}
22 |
23 | // NewStdio returns a Client communicating through stdio.
24 | func NewStdio(srv MCPServer, options ...StdioOption) Client {
25 |
26 | cfg := newStdioCfg()
27 | for _, o := range options {
28 | o(&cfg)
29 | }
30 |
31 | return &stdioClient{
32 | srv: srv,
33 | cfg: cfg,
34 | }
35 | }
36 |
37 | func (c *stdioClient) Type() string {
38 | return "stdio"
39 | }
40 |
41 | func (c *stdioClient) Server() string { return c.srv.Command }
42 |
43 | func (c *stdioClient) Start(ctx context.Context, _ ...Option) (pipe *MCPStream, err error) {
44 |
45 | dir, err := os.Getwd()
46 | if err != nil {
47 | return nil, fmt.Errorf("unable to get current directory: %w", err)
48 | }
49 |
50 | if c.cfg.useTempDir {
51 | dir, err = os.MkdirTemp(os.TempDir(), "minibridge-")
52 | if err != nil {
53 | return nil, fmt.Errorf("unable to create tempdir: %w", err)
54 | }
55 | }
56 |
57 | cmd := exec.CommandContext(ctx, c.srv.Command, c.srv.Args...) // #nosec: G204
58 | cmd.Env = append(os.Environ(), c.srv.Env...)
59 | for i, e := range cmd.Env {
60 | cmd.Env[i] = strings.ReplaceAll(e, "_MINIBRIDGE_PREFIX_", dir)
61 | }
62 |
63 | cmd.Dir = dir
64 | cmd.Cancel = func() error {
65 | if err := cmd.Process.Signal(os.Interrupt); err != nil {
66 | return err
67 | }
68 |
69 | if c.cfg.useTempDir {
70 | return os.RemoveAll(dir)
71 | }
72 |
73 | return nil
74 | }
75 |
76 | setCaps(cmd, "", c.cfg.creds)
77 |
78 | slog.Debug("Client: starting command",
79 | "path", cmd.Path,
80 | "dir", cmd.Dir,
81 | "creds", c.cfg.creds,
82 | )
83 |
84 | stdin, err := cmd.StdinPipe()
85 | if err != nil {
86 | return nil, fmt.Errorf("unable to create stdin pipe: %w", err)
87 | }
88 |
89 | stdout, err := cmd.StdoutPipe()
90 | if err != nil {
91 | return nil, fmt.Errorf("unable to create stdout pipe: %w", err)
92 | }
93 |
94 | stderr, err := cmd.StderrPipe()
95 | if err != nil {
96 | return nil, fmt.Errorf("unable to create stderr pipe: %w", err)
97 | }
98 |
99 | stream := NewMCPStream(ctx)
100 |
101 | go c.readRequests(ctx, stdin, stream.stdin)
102 | go c.readResponses(ctx, stdout, stream.stdout)
103 | go c.readErrors(ctx, stderr, stream.stderr)
104 |
105 | if err := cmd.Start(); err != nil {
106 | return nil, fmt.Errorf("unable to start command: %w", err)
107 | }
108 |
109 | go func() { stream.exit <- cmd.Wait() }()
110 |
111 | return stream, nil
112 | }
113 |
114 | func (c *stdioClient) readRequests(ctx context.Context, stdin io.WriteCloser, ch chan []byte) {
115 |
116 | for {
117 | select {
118 | case <-ctx.Done():
119 | return
120 | case data := <-ch:
121 | if _, err := stdin.Write(append(sanitize.Data(data), '\n')); err != nil {
122 | slog.Error("Unable to write data to stdin", "err", err)
123 | return
124 | }
125 | }
126 | }
127 | }
128 |
129 | func (c *stdioClient) readResponses(ctx context.Context, stdout io.ReadCloser, ch chan []byte) {
130 |
131 | bstdout := bufio.NewReader(stdout)
132 | for {
133 | data, err := bstdout.ReadBytes('\n')
134 | if err != nil {
135 | if err != io.EOF {
136 | slog.Error("Unable to read response from stdout", "err", err)
137 | }
138 | return
139 | }
140 | select {
141 | case ch <- sanitize.Data(data):
142 | case <-ctx.Done():
143 | return
144 | }
145 | }
146 | }
147 |
148 | func (c *stdioClient) readErrors(ctx context.Context, stderr io.ReadCloser, ch chan []byte) {
149 |
150 | bstderr := bufio.NewReader(stderr)
151 | for {
152 | data, err := bstderr.ReadBytes('\n')
153 | if err != nil {
154 | if err != io.EOF {
155 | slog.Error("Unable to read error response from stderr", "err", err)
156 | }
157 | return
158 | }
159 | select {
160 | case ch <- sanitize.Data(data):
161 | case <-ctx.Done():
162 | return
163 | }
164 | }
165 | }
166 |
--------------------------------------------------------------------------------
/pkgs/policer/internal/rego/policer.go:
--------------------------------------------------------------------------------
1 | package rego
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "os"
7 | "strings"
8 | "time"
9 |
10 | "github.com/go-viper/mapstructure/v2"
11 | "github.com/open-policy-agent/opa/v1/ast"
12 | "github.com/open-policy-agent/opa/v1/rego"
13 | "go.acuvity.ai/minibridge/pkgs/mcp"
14 | "go.acuvity.ai/minibridge/pkgs/policer/api"
15 | )
16 |
// Policer is a policer backed by a Rego policy. It holds the queries
// prepared by New and evaluated by Police.
type Policer struct {
	// queryAllow evaluates data.main.allow.
	queryAllow rego.PreparedEvalQuery
	// queryReasons binds data.main.reasons to "reasons".
	queryReasons rego.PreparedEvalQuery
	// queryMCP binds data.main.mcp to "mcp".
	queryMCP rego.PreparedEvalQuery
}
22 |
// RegoRuntimeEnvPrefix is the prefix an environment variable name must start
// with in order to be exposed to the Rego runtime by makeRegoRuntimeTerm.
const RegoRuntimeEnvPrefix = "REGO_POLICY_RUNTIME_"
24 |
25 | // New returns a new Rego based Policer.
26 | func New(policy string) (*Policer, error) {
27 |
28 | comp, err := precompile(policy, "default")
29 | if err != nil {
30 | return nil, fmt.Errorf("unable to compile rego policy: %w", err)
31 | }
32 |
33 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
34 | defer cancel()
35 |
36 | rTerm := makeRegoRuntimeTerm()
37 |
38 | queryAllow, err := rego.New(rego.Compiler(comp), rego.Query("data.main.allow"), rego.Runtime(rTerm)).PrepareForEval(ctx)
39 | if err != nil {
40 | return nil, fmt.Errorf("unable to prepare rego deny query: %w", err)
41 | }
42 |
43 | queryReasons, err := rego.New(rego.Compiler(comp), rego.Query("reasons := data.main.reasons"), rego.Runtime(rTerm)).PrepareForEval(ctx)
44 | if err != nil {
45 | return nil, fmt.Errorf("unable to prepare rego deny query: %w", err)
46 | }
47 |
48 | queryMCP, err := rego.New(rego.Compiler(comp), rego.Query("mcp := data.main.mcp"), rego.Runtime(rTerm)).PrepareForEval(ctx)
49 | if err != nil {
50 | return nil, fmt.Errorf("unable to prepare rego mcp query: %w", err)
51 | }
52 |
53 | return &Policer{
54 | queryAllow: queryAllow,
55 | queryReasons: queryReasons,
56 | queryMCP: queryMCP,
57 | }, nil
58 | }
59 |
60 | func (p *Policer) Type() string { return "rego" }
61 |
62 | func (p *Policer) Police(ctx context.Context, preq api.Request) (*mcp.Message, error) {
63 |
64 | res, err := p.queryAllow.Eval(ctx, rego.EvalInput(preq), rego.EvalPrintHook(printer{}))
65 | if err != nil {
66 | return nil, fmt.Errorf("unable to eval allow query: %w", err)
67 | }
68 |
69 | if !res.Allowed() {
70 |
71 | res, err = p.queryReasons.Eval(ctx, rego.EvalInput(preq), rego.EvalPrintHook(printer{}))
72 | if err != nil {
73 | return nil, fmt.Errorf("unable to eval reasons query: %w", err)
74 | }
75 |
76 | reasons := []string{api.GenericDenyReason}
77 |
78 | if len(res) > 0 {
79 | bindings := res[0].Bindings
80 | breasons, _ := bindings["reasons"].([]any)
81 |
82 | if len(breasons) > 0 {
83 | reasons = make([]string, len(breasons))
84 | for i, v := range breasons {
85 | reasons[i], _ = v.(string)
86 | }
87 | }
88 | }
89 |
90 | return nil, fmt.Errorf("%w: %s", api.ErrBlocked, strings.Join(reasons, ", "))
91 | }
92 |
93 | res, err = p.queryMCP.Eval(ctx, rego.EvalInput(preq), rego.EvalPrintHook(printer{}))
94 | if err != nil {
95 | return nil, fmt.Errorf("unable to eval mcp query: %w", err)
96 | }
97 |
98 | if len(res) == 0 {
99 | return nil, nil
100 | }
101 |
102 | bindings := res[0].Bindings
103 |
104 | newmcp, ok := bindings["mcp"].(map[string]any)
105 | if !ok {
106 | return nil, fmt.Errorf("invalid binding: mcp must be an map[string]any, got %T", bindings["mcp"])
107 | }
108 |
109 | mcall := &mcp.Message{}
110 | if err := mapstructure.Decode(newmcp, mcall); err != nil {
111 | return nil, fmt.Errorf("unable to decode rego mcp into valid MCP call: %w", err)
112 | }
113 |
114 | mcall.ID = preq.MCP.ID
115 |
116 | return mcall, nil
117 | }
118 |
119 | // makeRegoRuntimeTerm create a rego ast Term
120 | // to expose prefixed env var to the rego runtime.
121 | func makeRegoRuntimeTerm() *ast.Term {
122 |
123 | env := ast.NewObject()
124 |
125 | for _, s := range os.Environ() {
126 | parts := strings.SplitN(s, "=", 2)
127 | if !strings.HasPrefix(parts[0], RegoRuntimeEnvPrefix) {
128 | continue
129 | }
130 | if len(parts) == 2 {
131 | env.Insert(ast.StringTerm(strings.ReplaceAll(parts[0], RegoRuntimeEnvPrefix, "")), ast.StringTerm(parts[1]))
132 | }
133 | }
134 |
135 | obj := ast.NewObject()
136 | obj.Insert(ast.StringTerm("env"), ast.NewTerm(env))
137 |
138 | return ast.NewTerm(obj)
139 | }
140 |
--------------------------------------------------------------------------------
/cli/internal/cmd/scan.go:
--------------------------------------------------------------------------------
1 | package cmd
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "os"
8 | "time"
9 |
10 | "github.com/spf13/cobra"
11 | "github.com/spf13/pflag"
12 | "github.com/spf13/viper"
13 | "go.acuvity.ai/minibridge/pkgs/backend/client"
14 | "go.acuvity.ai/minibridge/pkgs/scan"
15 | )
16 |
// fScan is the flag set holding the flags specific to the scan command.
var fScan = pflag.NewFlagSet("scan", pflag.ExitOnError)
18 |
19 | func init() {
20 |
21 | initSharedFlagSet()
22 |
23 | fScan.Duration("timeout", 2*time.Minute, "maximum time to allow the scan to run.")
24 | fScan.Bool("exclude-resources", false, "exclude resources from scan")
25 | fScan.Bool("exclude-tools", false, "exclude tools from scan")
26 | fScan.Bool("exclude-prompts", false, "exclude prompts from scan")
27 |
28 | Scan.Flags().AddFlagSet(fScan)
29 | Scan.Flags().AddFlagSet(fMCP)
30 | Scan.Flags().AddFlagSet(fAgentAuth)
31 | }
32 |
33 | // Scan is the cobra command to run the server.
34 | var Scan = &cobra.Command{
35 | Use: "scan [dump|sbom|check file.sbom] -- command [args...]",
36 | Short: "Scan an MCP server for resources, prompts, etc and generate sbom",
37 | SilenceUsage: true,
38 | SilenceErrors: true,
39 | TraverseChildren: true,
40 | Args: cobra.MinimumNArgs(2),
41 |
42 | RunE: func(cmd *cobra.Command, args []string) error {
43 |
44 | timeout := viper.GetDuration("timeout")
45 |
46 | exclusions := &scan.Exclusions{
47 | Prompts: viper.GetBool("exclude-prompts"),
48 | Resources: viper.GetBool("exclude-resources"),
49 | Tools: viper.GetBool("exclude-tools"),
50 | }
51 |
52 | var ctx context.Context
53 | var cancel context.CancelFunc
54 |
55 | if timeout > 0 {
56 | ctx, cancel = context.WithTimeout(cmd.Context(), timeout)
57 | } else {
58 | ctx, cancel = context.WithCancel(cmd.Context())
59 | }
60 | defer cancel()
61 |
62 | var err error
63 | var mcpCommand string
64 | var mcpArgs []string
65 | if args[0] == "check" {
66 | mcpCommand = args[2]
67 | mcpArgs = args[3:]
68 | } else {
69 | mcpCommand = args[1]
70 | mcpArgs = args[2:]
71 | }
72 |
73 | mcpClient, err := makeMCPClient(append([]string{mcpCommand}, mcpArgs...), false)
74 | if err != nil {
75 | return err
76 | }
77 |
78 | agentAuth, err := makeAgentAuth(false)
79 | if err != nil {
80 | return fmt.Errorf("unable to build auth: %w", err)
81 | }
82 |
83 | stream, err := mcpClient.Start(ctx, client.OptionAuth(agentAuth))
84 | if err != nil {
85 | return fmt.Errorf("unable to start MCP server: %w", err)
86 | }
87 |
88 | dump, err := scan.DumpAll(ctx, stream, exclusions)
89 | if err != nil {
90 | return fmt.Errorf("unable to dump tools: %w", err)
91 | }
92 |
93 | cancel()
94 |
95 | var toolHashes scan.Hashes
96 |
97 | if !exclusions.Tools {
98 | toolHashes, err = scan.HashTools(dump.Tools)
99 | if err != nil {
100 | return fmt.Errorf("unable to hash tools: %w", err)
101 | }
102 | }
103 |
104 | var promptHashes scan.Hashes
105 |
106 | if !exclusions.Prompts {
107 | promptHashes, err = scan.HashPrompts(dump.Prompts)
108 | if err != nil {
109 | return fmt.Errorf("unable to hash prompts: %w", err)
110 | }
111 | }
112 |
113 | sbom := scan.SBOM{
114 | Tools: toolHashes,
115 | Prompts: promptHashes,
116 | }
117 |
118 | switch args[0] {
119 |
120 | case "check":
121 |
122 | refSBOM, err := scan.LoadSBOM(args[1])
123 | if err != nil {
124 | return fmt.Errorf("unable to load sbom: %w", err)
125 | }
126 |
127 | if err := refSBOM.Tools.Matches(sbom.Tools); err != nil {
128 | return fmt.Errorf("tools sbom does not match: %w", err)
129 | }
130 |
131 | if err := refSBOM.Prompts.Matches(sbom.Prompts); err != nil {
132 | return fmt.Errorf("prompts sbom does not match: %w", err)
133 | }
134 |
135 | case "sbom":
136 |
137 | enc := json.NewEncoder(os.Stdout)
138 | enc.SetIndent("", " ")
139 | if err := enc.Encode(sbom); err != nil {
140 | return fmt.Errorf("unable to encode sbom: %w", err)
141 | }
142 |
143 | case "dump":
144 |
145 | enc := json.NewEncoder(os.Stdout)
146 | enc.SetIndent("", " ")
147 | if err := enc.Encode(dump); err != nil {
148 | return fmt.Errorf("unable to encode dump: %w", err)
149 | }
150 |
151 | default:
152 | return fmt.Errorf("first command must be either dump, sbom or check")
153 | }
154 |
155 | return nil
156 | },
157 | }
158 |
--------------------------------------------------------------------------------
/cli/internal/cmd/frontend.go:
--------------------------------------------------------------------------------
1 | package cmd
2 |
3 | import (
4 | "fmt"
5 | "log/slog"
6 | "strings"
7 |
8 | "github.com/spf13/cobra"
9 | "github.com/spf13/pflag"
10 | "github.com/spf13/viper"
11 | "go.acuvity.ai/minibridge/pkgs/frontend"
12 | )
13 |
// fFrontend is the flag set holding the flags specific to the frontend
// command.
var fFrontend = pflag.NewFlagSet("", pflag.ExitOnError)
15 |
16 | func init() {
17 |
18 | initSharedFlagSet()
19 |
20 | fFrontend.StringP("listen", "l", "", "listen address of the bridge for incoming connections. If this is unset, stdio is used.")
21 | fFrontend.StringP("backend", "A", "", "URL of the minibridge backend to connect to.")
22 | fFrontend.String("endpoint-mcp", "/mcp", "when using HTTP, sets the endpoint to send messages (proto 2025-03-26).")
23 | fFrontend.String("endpoint-messages", "/message", "when using HTTP, sets the endpoint to post messages.")
24 | fFrontend.String("endpoint-sse", "/sse", "when using HTTP, sets the endpoint to connect to the event stream.")
25 | fAgentAuth.BoolP("agent-auth-passthrough", "b", false, "Forwards incoming HTTP Authorization header to the minibridge backend as-is.")
26 |
27 | Frontend.Flags().AddFlagSet(fFrontend)
28 | Frontend.Flags().AddFlagSet(fTLSClient)
29 | Frontend.Flags().AddFlagSet(fTLSServer)
30 | Frontend.Flags().AddFlagSet(fHealth)
31 | Frontend.Flags().AddFlagSet(fProfiler)
32 | Frontend.Flags().AddFlagSet(fCORS)
33 | Frontend.Flags().AddFlagSet(fAgentAuth)
34 | }
35 |
36 | // Frontend is the cobra command to run the client.
37 | var Frontend = &cobra.Command{
38 | Use: "frontend",
39 | Short: "Start a minibridge frontend to connect to a minibridge backend",
40 | SilenceUsage: true,
41 | SilenceErrors: true,
42 | TraverseChildren: true,
43 |
44 | RunE: func(cmd *cobra.Command, args []string) error {
45 |
46 | listen := viper.GetString("listen")
47 | backendURL := viper.GetString("backend")
48 | mcpEndpoint := viper.GetString("endpoint-mcp")
49 | sseEndpoint := viper.GetString("endpoint-sse")
50 | messageEndpoint := viper.GetString("endpoint-messages")
51 | agentAuthPassthrough := viper.GetBool("agent-auth-passthrough")
52 |
53 | if backendURL == "" {
54 | return fmt.Errorf("--backend must be set")
55 | }
56 | if !strings.HasPrefix(backendURL, "wss://") && !strings.HasPrefix(backendURL, "ws://") {
57 | return fmt.Errorf("--backend must use wss:// or ws:// scheme")
58 | }
59 | if !strings.HasSuffix(backendURL, "/ws") {
60 | backendURL = backendURL + "/ws"
61 | }
62 |
63 | agentAuth, err := makeAgentAuth(true)
64 | if err != nil {
65 | return fmt.Errorf("unable to build auth: %w", err)
66 | }
67 |
68 | clientTLSConfig, err := tlsConfigFromFlags(fTLSClient)
69 | if err != nil {
70 | return err
71 | }
72 |
73 | tracer, err := makeTracer(cmd.Context(), "backend")
74 | if err != nil {
75 | return fmt.Errorf("unable to configure tracer: %w", err)
76 | }
77 |
78 | corsPolicy := makeCORSPolicy()
79 |
80 | mm := startHealthServer(cmd.Context())
81 |
82 | var mfrontend frontend.Frontend
83 |
84 | if listen != "" {
85 |
86 | serverTLSConfig, err := tlsConfigFromFlags(fTLSServer)
87 | if err != nil {
88 | return err
89 | }
90 |
91 | slog.Info("Minibridge frontend configured",
92 | "backend", backendURL,
93 | "mcp", mcpEndpoint,
94 | "sse", sseEndpoint,
95 | "messages", messageEndpoint,
96 | "mode", "http",
97 | "server-tls", serverTLSConfig != nil,
98 | "server-mtls", mtlsMode(serverTLSConfig),
99 | "client-tls", clientTLSConfig != nil,
100 | "listen", listen,
101 | )
102 |
103 | mfrontend = frontend.NewHTTP(listen, backendURL, serverTLSConfig, clientTLSConfig,
104 | frontend.OptHTTPMCPEndpoint(mcpEndpoint),
105 | frontend.OptHTTPSSEEndpoint(sseEndpoint),
106 | frontend.OptHTTPMessageEndpoint(messageEndpoint),
107 | frontend.OptHTTPAgentTokenPassthrough(agentAuthPassthrough),
108 | frontend.OptHTTPCORSPolicy(corsPolicy),
109 | frontend.OptHTTPMetricsManager(mm),
110 | frontend.OptHTTPTracer(tracer),
111 | )
112 |
113 | } else {
114 |
115 | slog.Info("Minibridge frontend configured",
116 | "backend", backendURL,
117 | "mode", "stdio",
118 | )
119 |
120 | mfrontend = frontend.NewStdio(backendURL, clientTLSConfig,
121 | frontend.OptStdioTracer(tracer),
122 | frontend.OptStdioRetry(false),
123 | )
124 | }
125 |
126 | return startFrontendWithOAuth(cmd.Context(), mfrontend, agentAuth)
127 | },
128 | }
129 |
--------------------------------------------------------------------------------
/pkgs/frontend/options.go:
--------------------------------------------------------------------------------
1 | package frontend
2 |
3 | import (
4 | "context"
5 | "net"
6 |
7 | "go.acuvity.ai/bahamut"
8 | "go.acuvity.ai/minibridge/pkgs/metrics"
9 | "go.opentelemetry.io/otel/trace"
10 | "go.opentelemetry.io/otel/trace/noop"
11 | )
12 |
// httpCfg holds the configuration of the HTTP frontend. It is created with
// defaults by newHTTPCfg and modified through OptHTTP options.
type httpCfg struct {
	// mcpEndpoint is the MCP endpoint (protocol 2025-03-26). Defaults to /mcp.
	mcpEndpoint string
	// sseEndpoint is the SSE endpoint (protocol 2024-11-05). Defaults to /sse.
	sseEndpoint string
	// messagesEndpoint is the message endpoint (protocol 2024-11-05).
	// Defaults to /message.
	messagesEndpoint string
	// agentTokenPassthrough tells if the incoming Authorization header is
	// passed as-is to the backend.
	agentTokenPassthrough bool
	// corsPolicy is the CORS policy set by OptHTTPCORSPolicy.
	corsPolicy *bahamut.CORSPolicy
	// metricsManager is the metrics manager set by OptHTTPMetricsManager.
	metricsManager *metrics.Manager
	// tracer is the otel tracer. Defaults to a noop tracer.
	tracer trace.Tracer
	// backendDialer is the dialer set by OptHTTPBackendDialer.
	backendDialer func(ctx context.Context, network, addr string) (net.Conn, error)
	// oauthEndpointRegister defaults to /register.
	oauthEndpointRegister string
	// oauthEndpointAuthorize defaults to /authorize.
	oauthEndpointAuthorize string
	// oauthEndpointToken defaults to /token.
	oauthEndpointToken string
}
26 |
27 | func newHTTPCfg() httpCfg {
28 | return httpCfg{
29 | mcpEndpoint: "/mcp",
30 | sseEndpoint: "/sse",
31 | messagesEndpoint: "/message",
32 | oauthEndpointRegister: "/register",
33 | oauthEndpointAuthorize: "/authorize",
34 | oauthEndpointToken: "/token",
35 | tracer: noop.NewTracerProvider().Tracer("noop"),
36 | }
37 | }
38 |
// OptHTTP are options that can be given to NewHTTP().
type OptHTTP func(*httpCfg)
41 |
42 | // OptHTTPMCPEndpoint sets the mcp endpoint (protocol 2025-03-26)
43 | // where agents can connect to the response stream.
44 | // Defaults to /mcp
45 | func OptHTTPMCPEndpoint(ep string) OptHTTP {
46 | return func(cfg *httpCfg) {
47 | cfg.mcpEndpoint = ep
48 | }
49 | }
50 |
51 | // OptHTTPSSEEndpoint sets the sse endpoint (protocol 2024-11-05)
52 | // where agents can connect to the response stream.
53 | // Defaults to /sse
54 | func OptHTTPSSEEndpoint(ep string) OptHTTP {
55 | return func(cfg *httpCfg) {
56 | cfg.sseEndpoint = ep
57 | }
58 | }
59 |
60 | // OptHTTPMessageEndpoint sets the message endpoint (protocol 2024-11-05)
61 | // where agents can post request.
62 | // Defaults to /messages
63 | func OptHTTPMessageEndpoint(ep string) OptHTTP {
64 | return func(cfg *httpCfg) {
65 | cfg.messagesEndpoint = ep
66 | }
67 | }
68 |
69 | // OptHTTPCORSPolicy sets the bahamut.CORSPolicy to use for
70 | // connection originating from a webrowser.
71 | func OptHTTPCORSPolicy(policy *bahamut.CORSPolicy) OptHTTP {
72 | return func(cfg *httpCfg) {
73 | cfg.corsPolicy = policy
74 | }
75 | }
76 |
77 | // OptHTTPAgentTokenPassthrough decides if the HTTP request Authorization header
78 | // should be passed as-is to the minibridge backend.
79 | func OptHTTPAgentTokenPassthrough(passthrough bool) OptHTTP {
80 | return func(cfg *httpCfg) {
81 | cfg.agentTokenPassthrough = passthrough
82 | }
83 | }
84 |
85 | // OptHTTPMetricsManager sets the metric manager to use to collect
86 | // prometheus metrics.
87 | func OptHTTPMetricsManager(m *metrics.Manager) OptHTTP {
88 | return func(cfg *httpCfg) {
89 | cfg.metricsManager = m
90 | }
91 | }
92 |
93 | // OptHTTPTracer sets the otel trace.Tracer to use to trace requests
94 | func OptHTTPTracer(tracer trace.Tracer) OptHTTP {
95 | return func(cfg *httpCfg) {
96 | if tracer == nil {
97 | tracer = noop.NewTracerProvider().Tracer("noop")
98 | }
99 | cfg.tracer = tracer
100 | }
101 | }
102 |
103 | // OptHTTPBackendDialer sets the dialer to use to connect to the backend.
104 | func OptHTTPBackendDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) OptHTTP {
105 | return func(cfg *httpCfg) {
106 | cfg.backendDialer = dialer
107 | }
108 | }
109 |
// stdioCfg holds the configuration of the stdio frontend. It is created with
// defaults by newStdioCfg and modified through OptStdio options.
type stdioCfg struct {
	// retry is set by OptStdioRetry. Defaults to true.
	retry bool
	// tracer is the otel tracer. Defaults to a noop tracer.
	tracer trace.Tracer
	// backendDialer is the dialer set by OptStdioBackendDialer.
	backendDialer func(ctx context.Context, network, addr string) (net.Conn, error)
}
115 |
116 | func newStdioCfg() stdioCfg {
117 | return stdioCfg{
118 | retry: true,
119 | tracer: noop.NewTracerProvider().Tracer("noop"),
120 | }
121 | }
122 |
// OptStdio are options that can be given to NewStdio().
// Each option modifies the stdioCfg it is given.
type OptStdio func(*stdioCfg)
125 |
126 | // OptStdioRetry allows to control if the Stdio frontend
127 | // should retry or not after a wbesocket connection failure.
128 | func OptStdioRetry(retry bool) OptStdio {
129 | return func(cfg *stdioCfg) {
130 | cfg.retry = retry
131 | }
132 | }
133 |
134 | // OptStdioTracer sets the otel trace.Tracer to use to trace requests
135 | func OptStdioTracer(tracer trace.Tracer) OptStdio {
136 | return func(cfg *stdioCfg) {
137 | if tracer == nil {
138 | tracer = noop.NewTracerProvider().Tracer("noop")
139 | }
140 | cfg.tracer = tracer
141 | }
142 | }
143 |
144 | // OptStdioBackendDialer sets the dialer to use to connect to the backend.
145 | func OptStdioBackendDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) OptStdio {
146 | return func(cfg *stdioCfg) {
147 | cfg.backendDialer = dialer
148 | }
149 | }
150 |
--------------------------------------------------------------------------------
/pkgs/backend/ws_test.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net"
7 | "testing"
8 | "time"
9 |
10 | . "github.com/smartystreets/goconvey/convey"
11 | "go.acuvity.ai/minibridge/pkgs/backend/client"
12 | "go.acuvity.ai/minibridge/pkgs/frontend"
13 | "go.acuvity.ai/minibridge/pkgs/policer"
14 | "go.acuvity.ai/wsc"
15 | )
16 |
// freePort returns a TCP port that was free at the time of the call. It
// briefly listens on localhost with port 0, reads the port that was
// assigned, then closes the listener. It panics if localhost cannot be
// resolved or listened on.
func freePort() int {

	loopback, err := net.ResolveTCPAddr("tcp", "localhost:0")
	if err != nil {
		panic(err)
	}

	listener, err := net.ListenTCP("tcp", loopback)
	if err != nil {
		panic(err)
	}

	port := listener.Addr().(*net.TCPAddr).Port
	_ = listener.Close()

	return port
}
30 |
31 | func startBackend(ctx context.Context, opts ...Option) (wsc.Websocket, error) {
32 |
33 | backendListen := fmt.Sprintf("127.0.0.1:%d", freePort())
34 |
35 | srv, err := client.NewMCPServer("cat")
36 | if err != nil {
37 | return nil, err
38 | }
39 | client := client.NewStdio(srv)
40 |
41 | backend := NewWebSocket(backendListen, nil, client, opts...)
42 |
43 | go func() {
44 | if err := backend.Start(ctx); err != nil {
45 | panic(err)
46 | }
47 | }()
48 |
49 | <-time.After(time.Second) // wait a bit.. gh workers are slow
50 |
51 | ws, err := frontend.Connect(ctx, nil, fmt.Sprintf("ws://%s/ws", backendListen), nil, frontend.AgentInfo{UserAgent: "go-test"})
52 | if err != nil {
53 | return nil, err
54 | }
55 |
56 | return ws, nil
57 | }
58 |
// TestWS exercises the websocket backend end to end through startBackend,
// which runs cat as MCP server: whatever reaches the server is echoed back.
// Each scenario writes a payload on the websocket and checks what is read
// back, with and without a rego policer.
func TestWS(t *testing.T) {

	Convey("Given a ws backend without policer or tls", t, func() {

		ctx, cancel := context.WithCancel(t.Context())
		defer cancel()

		ws, err := startBackend(ctx)
		So(err, ShouldBeNil)

		// Valid JSON is expected to come back unchanged.
		echo := `{"hello": "world"}`
		ws.Write([]byte(echo))

		// Wait at most one second for the answer. On timeout data keeps its
		// previous value and the assertion below fails.
		var data []byte
		select {
		case data = <-ws.Read():
		case <-time.After(time.Second):
		}

		So(string(data), ShouldEqual, echo)

		// Invalid JSON is expected to be answered with a decoding error.
		echo = `not-json`
		ws.Write([]byte(echo))

		select {
		case data = <-ws.Read():
		case <-time.After(time.Second):
		}

		So(string(data), ShouldEqual, `{"error":{"code":500,"message":"unable to decode application/json: json decode error [pos 4]: expecting ot-: got ull"},"jsonrpc":"2.0"}`)
	})

	Convey("Given a ws backend with a rego policer that denies the call", t, func() {

		ctx, cancel := context.WithCancel(t.Context())
		defer cancel()

		// Policy denying everything, with one explicit reason.
		policer, err := policer.NewRego(`package main
import rego.v1
default allow := false
reasons contains "you can't do that, Dave"
`)
		So(err, ShouldBeNil)

		ws, err := startBackend(ctx, OptPolicer(policer))
		So(err, ShouldBeNil)

		echo := `{"jsonrpc": "2.0", "id": 2}`
		ws.Write([]byte(echo))

		var data []byte
		select {
		case data = <-ws.Read():
		case <-time.After(time.Second):
		}

		// The error carries the reason of the policy and the ID of the request.
		So(string(data), ShouldEqual, `{"error":{"code":451,"message":"request blocked: you can't do that, Dave"},"id":2,"jsonrpc":"2.0"}`)
	})

	Convey("Given a ws backend with a rego policer that allows the call without mutation", t, func() {

		ctx, cancel := context.WithCancel(t.Context())
		defer cancel()

		// Policy allowing everything and not defining any mcp rule.
		policer, err := policer.NewRego(`package main
import rego.v1
default allow := true
`)
		So(err, ShouldBeNil)

		ws, err := startBackend(ctx, OptPolicer(policer))
		So(err, ShouldBeNil)

		echo := `{"jsonrpc": "2.0", "id": 1}`
		ws.Write([]byte(echo))

		var data []byte
		select {
		case data = <-ws.Read():
		case <-time.After(time.Second):
		}

		// The payload is expected to come back byte for byte.
		So(string(data), ShouldEqual, `{"jsonrpc": "2.0", "id": 1}`)
	})

	Convey("Given a ws backend with a rego policer that allows the call with mutation", t, func() {

		ctx, cancel := context.WithCancel(t.Context())
		defer cancel()

		// Policy allowing everything, and rewriting both result.hello and
		// the ID of the message.
		policer, err := policer.NewRego(`package main
import rego.v1
default allow := true

mcp := x if {
x := json.patch(input.mcp, [{
"op": "replace",
"path": "/result/hello",
"value": "world"
}, {
"op": "replace",
"path": "/id",
"value": 2,
}])
}
`)
		So(err, ShouldBeNil)

		ws, err := startBackend(ctx, OptPolicer(policer))
		So(err, ShouldBeNil)

		echo := `{"id": 1, "jsonrpc": "2.0", "result": {"hello": "monde"}}`
		ws.Write([]byte(echo))

		var data []byte
		select {
		case data = <-ws.Read():
		case <-time.After(time.Second):
		}

		// The result is mutated, but the ID is expected to remain the one of
		// the original message even though the policy tried to replace it.
		So(string(data), ShouldEqual, `{"id":1,"jsonrpc":"2.0","result":{"hello":"world"}}`)
	})
}
182 |
--------------------------------------------------------------------------------
/pkgs/backend/client/stdio_test.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "context"
5 | "os"
6 | "testing"
7 | "time"
8 |
9 | . "github.com/smartystreets/goconvey/convey"
10 | )
11 |
// TestStdioClient exercises the stdio client against real commands (cat, sh,
// bash): type reporting, data round trip, environment passing, start
// failure, unexpected exit, working directory selection and context
// cancellation.
func TestStdioClient(t *testing.T) {

	Convey("Type is correct", t, func() {
		cl := NewStdio(MCPServer{})
		So(cl.Type(), ShouldEqual, "stdio")
	})

	Convey("Given I have a client cat and I send lots of trailing \n", t, func() {

		srv := MCPServer{
			Command: "cat",
			Env:     []string{"A=A"},
		}
		cl := NewStdio(srv)

		ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
		defer cancel()

		stream, err := cl.Start(ctx)
		So(err, ShouldBeNil)
		So(stream, ShouldNotBeNil)

		out, unregister := stream.Stdout()
		defer unregister()

		stream.Stdin() <- []byte("hello world\r\n\n\n")

		// The trailing \r and \n are expected to be gone from what cat echoes.
		data := <-out
		So(data, ShouldResemble, []byte("hello world"))
	})

	Convey("Given I have a client with env", t, func() {

		srv := MCPServer{
			Command: "sh",
			Args:    []string{"-c", "echo $MTEST"},
			Env:     []string{"MTEST=HELLO"},
		}
		cl := NewStdio(srv)

		ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
		defer cancel()

		stream, err := cl.Start(ctx)
		So(err, ShouldBeNil)
		So(stream, ShouldNotBeNil)

		out, unregisterOut := stream.Stdout()
		defer unregisterOut()

		exit, unregisterExit := stream.Exit()
		defer unregisterExit()

		// The variable given in Env must be visible to the command.
		data := <-out
		So(string(data), ShouldEqual, "HELLO")

		err = <-exit
		So(err, ShouldBeNil)
	})

	Convey("Given I have a client to which I give an invalid server", t, func() {

		// "dog" is expected not to exist in $PATH.
		srv := MCPServer{
			Command: "dog",
		}
		cl := NewStdio(srv)

		ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
		defer cancel()

		stream, err := cl.Start(ctx)
		So(err, ShouldNotBeNil)
		So(err.Error(), ShouldEqual, `unable to start command: exec: "dog": executable file not found in $PATH`)
		So(stream, ShouldBeNil)
	})

	Convey("Given I have a client with a command that exits unexpectedly", t, func() {

		srv := MCPServer{
			Command: "bash",
			Args:    []string{"-c", "sleep 1 && exit 1"},
		}
		cl := NewStdio(srv)

		ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
		defer cancel()

		stream, err := cl.Start(ctx)
		So(err, ShouldBeNil)
		So(stream, ShouldNotBeNil)

		exit, unregister := stream.Exit()
		defer unregister()

		// Sleep slightly longer than the command does.
		time.Sleep(1050 * time.Millisecond)

		err = <-exit
		So(err, ShouldNotBeNil)
		So(err.Error(), ShouldEqual, "exit status 1")
	})

	Convey("Given I have a client that writes a file", t, func() {

		srv := MCPServer{
			Command: "sh",
			Args:    []string{"-c", "touch testfile"},
		}
		cl := NewStdio(srv)

		ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
		defer cancel()

		stream, err := cl.Start(ctx)
		So(err, ShouldBeNil)
		So(stream, ShouldNotBeNil)

		exit, unregister := stream.Exit()
		defer unregister()

		So(<-exit, ShouldBeNil)

		// Without tempdir, the file is created in the current directory.
		_, err = os.Stat("testfile")
		So(err, ShouldBeNil)
		_ = os.RemoveAll("testfile")
	})

	Convey("Given I have a client that writes a file with tempdir", t, func() {

		srv := MCPServer{
			Command: "sh",
			Args:    []string{"-c", "touch testfile"},
		}
		cl := NewStdio(srv, OptStdioUseTempDir(true))

		ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
		defer cancel()

		stream, err := cl.Start(ctx)
		So(err, ShouldBeNil)
		So(stream, ShouldNotBeNil)

		exit, unregister := stream.Exit()
		defer unregister()

		So(<-exit, ShouldBeNil)

		// With tempdir, nothing must be created in the current directory.
		_, err = os.Stat("testfile")
		So(err.Error(), ShouldEqual, "stat testfile: no such file or directory")
	})

	Convey("Given I have a running client and an expiring context", t, func() {

		srv := MCPServer{
			Command: "cat",
		}
		cl := NewStdio(srv)

		ctx, cancel := context.WithCancel(t.Context())
		defer cancel()

		stream, err := cl.Start(ctx)
		So(err, ShouldBeNil)
		So(stream, ShouldNotBeNil)

		exit, unregister := stream.Exit()
		defer unregister()

		// Give the command some time to run before canceling the context.
		time.Sleep(300 * time.Millisecond)
		cancel()

		// Canceling the context is expected to interrupt the command.
		err = <-exit
		So(err, ShouldNotBeNil)
		So(err.Error(), ShouldEqual, "signal: interrupt")
	})
}
187 |
--------------------------------------------------------------------------------
/pkgs/frontend/internal/session/session.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "log/slog"
5 | "maps"
6 | "sync"
7 | "time"
8 |
9 | "go.acuvity.ai/wsc"
10 | )
11 |
// defaultDeadlineDuration is the deadline duration given to the sessions
// created through New.
const defaultDeadlineDuration = 10 * time.Second
13 |
// A Session represents an active agent session.
//
// It contains the underlying websocket to communicate
// with the correct MCP server instance running in minibridge backend.
//
// The session must be acquired through a session manager when used,
// then a channel must be registered to receive the data coming from the
// backend.
//
// When not in use, the Session must be released through the session manager
// and all hooks must be unregistered.
// When the count is at 1 (only the initial acquirement), the session will
// linger for the duration of deadline, after which the websocket will be closed
// and so the MCP server running in the backend will be terminated.
type Session struct {
	// closeCh receives the websocket termination error; it is exposed by Done.
	closeCh chan error
	// count is the usage count. It starts at 1 and is guarded by countLock.
	count     int
	countLock sync.RWMutex
	// nextDeadline is the duration added to time.Now() each time the
	// deadline is pushed forward.
	nextDeadline time.Duration
	// deadline is the current expiration time, guarded by deadlineLock.
	deadline     time.Time
	deadlineLock sync.RWMutex
	// h is the credentials hash compared by ValidateHash.
	h        uint64
	hookLock sync.RWMutex
	// hooks is the set of channels receiving the data read from the
	// websocket, guarded by hookLock.
	hooks map[chan []byte]struct{}
	// id is the session ID returned by ID.
	id string
	// ws is the underlying websocket.
	ws wsc.Websocket
}
41 |
42 | // New retutns a new session backed by the given websocket.
43 | func New(ws wsc.Websocket, credsHash uint64, sid string) *Session {
44 | return newSession(ws, credsHash, sid, defaultDeadlineDuration)
45 | }
46 |
47 | func newSession(ws wsc.Websocket, credsHash uint64, sid string, deadline time.Duration) *Session {
48 |
49 | s := &Session{
50 | ws: ws,
51 | h: credsHash,
52 | count: 1,
53 | id: sid,
54 | deadline: time.Now().Add(deadline),
55 | nextDeadline: deadline,
56 | closeCh: make(chan error),
57 | hooks: map[chan []byte]struct{}{},
58 | }
59 |
60 | slog.Debug("session created", "sid", s.id, "c", s.count)
61 | s.start()
62 |
63 | return s
64 | }
65 |
// ID returns the session ID given when the session was created.
func (s *Session) ID() string {
	return s.id
}
70 |
71 | // Write writes to the underlying websocket and
72 | // advances the deadline
73 | func (s *Session) Write(data []byte) {
74 | s.setDeadline(time.Now().Add(s.nextDeadline))
75 | s.ws.Write(data)
76 | }
77 |
// Done returns the channel that will receive
// an error (or nil) when the websocket closes
// for any reason. The value is sent by the session
// background loop right before it exits.
func (s *Session) Done() chan error {
	return s.closeCh
}
84 |
85 | // ValidateHash validates the session hash.
86 | func (s *Session) ValidateHash(h uint64) bool {
87 | return h == s.h
88 | }
89 |
// Close closes the underlying websocket by calling its Close
// method with the value 1001.
func (s *Session) Close() {
	s.ws.Close(1001)
}
94 |
95 | func (s *Session) acquire() {
96 | s.countLock.Lock()
97 | defer s.countLock.Unlock()
98 |
99 | s.count++
100 |
101 | s.setDeadline(time.Now().Add(s.nextDeadline))
102 | slog.Debug("session acquired", "sid", s.id, "c", s.count)
103 | }
104 |
105 | func (s *Session) release() bool {
106 | s.countLock.Lock()
107 | defer s.countLock.Unlock()
108 |
109 | s.count--
110 | slog.Debug("session released", "sid", s.id, "c", s.count, "deleted", s.count <= 0)
111 |
112 | if s.count <= 0 {
113 | s.Close()
114 | return true
115 | }
116 |
117 | return false
118 | }
119 |
// register adds c to the set of hooks that receive the data
// read from the websocket.
func (s *Session) register(c chan []byte) {
	s.hookLock.Lock()
	defer s.hookLock.Unlock()

	s.hooks[c] = struct{}{}
}
126 |
// unregister removes c from the set of hooks. Removing a channel
// that is not registered is a no-op.
func (s *Session) unregister(c chan []byte) {
	s.hookLock.Lock()
	defer s.hookLock.Unlock()

	delete(s.hooks, c)
}
133 |
134 | func (s *Session) start() {
135 |
136 | go func() {
137 |
138 | ticker := time.NewTicker(time.Second)
139 |
140 | for {
141 | select {
142 |
143 | case data := <-s.ws.Read():
144 |
145 | for c := range s.getHooks() {
146 | select {
147 | case c <- data:
148 | default:
149 | slog.Error("Session sent data to inactive hook")
150 | }
151 | }
152 |
153 | s.setDeadline(time.Now().Add(s.nextDeadline))
154 |
155 | case err := <-s.ws.Done():
156 | s.release()
157 | s.closeCh <- err
158 | return
159 |
160 | case <-ticker.C:
161 | if time.Now().After(s.getDeadline()) && s.getCount() <= 1 {
162 | slog.Debug("session terminated: deadline exceeded", "sid", s.ID())
163 | s.release()
164 | }
165 | }
166 | }
167 | }()
168 | }
169 |
170 | func (s *Session) getDeadline() time.Time {
171 | s.deadlineLock.RLock()
172 | defer s.deadlineLock.RUnlock()
173 | return s.deadline
174 | }
175 |
176 | func (s *Session) setDeadline(deadline time.Time) {
177 | s.deadlineLock.Lock()
178 | defer s.deadlineLock.Unlock()
179 | s.deadline = deadline
180 | }
181 |
182 | func (s *Session) getCount() int {
183 | s.countLock.RLock()
184 | defer s.countLock.RUnlock()
185 | return s.count
186 | }
187 |
188 | func (s *Session) getHooks() map[chan []byte]struct{} {
189 | s.hookLock.RLock()
190 | defer s.hookLock.RUnlock()
191 | return maps.Clone(s.hooks)
192 | }
193 |
--------------------------------------------------------------------------------
/pkgs/frontend/stdio.go:
--------------------------------------------------------------------------------
1 | package frontend
2 |
3 | import (
4 | "bufio"
5 | "context"
6 | "crypto/tls"
7 | "fmt"
8 | "log/slog"
9 | "net/http"
10 | "net/url"
11 | "os"
12 | "os/user"
13 | "time"
14 |
15 | "go.acuvity.ai/minibridge/pkgs/auth"
16 | "go.acuvity.ai/minibridge/pkgs/info"
17 | "go.acuvity.ai/minibridge/pkgs/internal/sanitize"
18 | )
19 |
// Compile-time check that *stdioFrontend implements Frontend.
var _ Frontend = (*stdioFrontend)(nil)
21 |
// stdioFrontend is the Frontend implementation that reads requests from
// stdin and prints the backend responses to stdout.
type stdioFrontend struct {
	// u is the parsed form of backendURL.
	u *url.URL
	// agentAuth is the auth given to Start; it is passed to Connect.
	agentAuth *auth.Auth
	// backendURL is the backend endpoint as given to NewStdio.
	backendURL string
	cfg        stdioCfg
	// claims is built by Start from the local user and host.
	claims          []string
	tlsClientConfig *tls.Config
	// user is set by Start to "username@hostname".
	user string
	// wsWrite carries the lines read from stdin to the websocket pump.
	wsWrite chan []byte
}
32 |
33 | // NewStdio returns a new *StdioProxy that will connect to the given
34 | // endpoint using the given tlsConfig. Agents can write request to stdin and read
35 | // responses from stdout. stderr contains the logs.
36 | //
37 | // A single session to the backend will be created and it will
38 | // reconnect in case of disconnection.
39 | func NewStdio(backend string, tlsConfig *tls.Config, opts ...OptStdio) Frontend {
40 |
41 | cfg := newStdioCfg()
42 | for _, o := range opts {
43 | o(&cfg)
44 | }
45 |
46 | u, err := url.Parse(backend)
47 | if err != nil {
48 | panic(err)
49 | }
50 |
51 | return &stdioFrontend{
52 | u: u,
53 | backendURL: backend,
54 | tlsClientConfig: tlsConfig,
55 | wsWrite: make(chan []byte),
56 | cfg: cfg,
57 | }
58 | }
59 |
60 | // Start starts the proxy. It will run until the given context is canceled or until
61 | // the server returns an error.
62 | func (p *stdioFrontend) Start(ctx context.Context, agentAuth *auth.Auth) error {
63 |
64 | p.agentAuth = agentAuth
65 |
66 | user, err := user.Current()
67 | if err != nil {
68 | return fmt.Errorf("unable to get current user: %w", err)
69 | }
70 |
71 | host, err := os.Hostname()
72 | if err != nil {
73 | return fmt.Errorf("unable to get current hostname: %w", err)
74 | }
75 |
76 | p.user = fmt.Sprintf("%s@%s", user.Username, host)
77 | p.claims = []string{
78 | fmt.Sprintf("gid=%s", user.Gid),
79 | fmt.Sprintf("uid=%s", user.Uid),
80 | fmt.Sprintf("username=%s", user.Username),
81 | fmt.Sprintf("hostname=%s", host),
82 | "minibridge=stdio",
83 | }
84 |
85 | slog.Debug("Local machine user set", "user", p.user, "claims", p.claims)
86 |
87 | subctx, cancel := context.WithCancel(ctx)
88 | defer cancel()
89 |
90 | errCh := make(chan error, 2)
91 |
92 | go func() { errCh <- p.stdiopump(subctx) }()
93 | go func() { errCh <- p.wspump(subctx) }()
94 |
95 | return <-errCh
96 | }
97 |
98 | func (p *stdioFrontend) HTTPClient() *http.Client {
99 | return &http.Client{
100 | Transport: &http.Transport{
101 | TLSClientConfig: p.tlsClientConfig,
102 | DialContext: p.cfg.backendDialer,
103 | },
104 | }
105 | }
106 |
107 | func (p *stdioFrontend) BackendURL() string {
108 |
109 | scheme := "http"
110 | if p.u.Scheme == "wss" || p.u.Scheme == "https" {
111 | scheme = "https"
112 | }
113 |
114 | return fmt.Sprintf("%s://%s", scheme, p.u.Host)
115 | }
116 |
// BackendInfo returns the result of getBackendInfo for this frontend.
func (p *stdioFrontend) BackendInfo() (info.Info, error) {
	return getBackendInfo(p)
}
120 |
121 | func (p *stdioFrontend) wspump(ctx context.Context) error {
122 |
123 | var failures int
124 |
125 | for {
126 |
127 | select {
128 |
129 | case <-ctx.Done():
130 | return nil
131 |
132 | default:
133 | session, err := Connect(ctx, p.cfg.backendDialer, p.backendURL, p.tlsClientConfig, AgentInfo{
134 | Auth: p.agentAuth,
135 | UserAgent: "stdio",
136 | RemoteAddr: "local",
137 | })
138 | if err != nil {
139 |
140 | if !p.cfg.retry {
141 | return err
142 | }
143 |
144 | if failures == 1 {
145 | slog.Error("Retrying...", err)
146 | }
147 |
148 | failures++
149 | time.Sleep(2 * time.Second)
150 |
151 | continue
152 | }
153 |
154 | if failures > 0 {
155 | slog.Info("Connection restored", "attempts", failures)
156 | }
157 |
158 | failures = 0
159 |
160 | L:
161 | for {
162 |
163 | select {
164 |
165 | case buf := <-p.wsWrite:
166 | session.Write(sanitize.Data(buf))
167 |
168 | case data := <-session.Read():
169 | fmt.Println(string(sanitize.Data(data)))
170 |
171 | case err := <-session.Error():
172 | failures++
173 | slog.Error("Error from webscoket", err)
174 |
175 | case err := <-session.Done():
176 | failures++
177 | if !p.cfg.retry {
178 | return err
179 | }
180 | break L
181 |
182 | case <-ctx.Done():
183 | session.Close(1000)
184 | return nil
185 | }
186 | }
187 | }
188 | }
189 | }
190 |
191 | func (p *stdioFrontend) stdiopump(ctx context.Context) error {
192 |
193 | stdin := bufio.NewReader(os.Stdin)
194 |
195 | for {
196 | select {
197 |
198 | default:
199 |
200 | buf, err := stdin.ReadBytes('\n')
201 |
202 | if err != nil {
203 | slog.Debug("unable to read stdin", err)
204 | return err
205 | }
206 |
207 | if len(buf) == 0 {
208 | continue
209 | }
210 |
211 | p.wsWrite <- sanitize.Data(buf)
212 |
213 | case <-ctx.Done():
214 | return nil
215 | }
216 | }
217 | }
218 |
--------------------------------------------------------------------------------
/cli/internal/cmd/aio.go:
--------------------------------------------------------------------------------
1 | package cmd
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log/slog"
7 | "net"
8 | "time"
9 |
10 | "github.com/spf13/cobra"
11 | "github.com/spf13/pflag"
12 | "github.com/spf13/viper"
13 | "go.acuvity.ai/minibridge/pkgs/backend"
14 | "go.acuvity.ai/minibridge/pkgs/frontend"
15 | "go.acuvity.ai/minibridge/pkgs/memconn"
16 | "golang.org/x/sync/errgroup"
17 | )
18 |
// fAIO is the flag set holding the flags specific to the aio command.
var fAIO = pflag.NewFlagSet("aio", pflag.ExitOnError)
20 |
21 | func init() {
22 |
23 | initSharedFlagSet()
24 |
25 | fAIO.StringP("listen", "l", "", "listen address of the bridge for incoming connections. If unset, stdio is used.")
26 | fAIO.String("endpoint-mcp", "/mcp", "when using HTTP, sets the endpoint to send messages (proto 2025-03-26).")
27 | fAIO.String("endpoint-messages", "/message", "when using HTTP, sets the endpoint to post messages (proto 2024-11-05).")
28 | fAIO.String("endpoint-sse", "/sse", "when using HTTP, sets the endpoint to connect to the event stream (proto 2024-11-05).")
29 |
30 | AIO.Flags().AddFlagSet(fAIO)
31 | AIO.Flags().AddFlagSet(fPolicer)
32 | AIO.Flags().AddFlagSet(fTLSServer)
33 | AIO.Flags().AddFlagSet(fHealth)
34 | AIO.Flags().AddFlagSet(fProfiler)
35 | AIO.Flags().AddFlagSet(fCORS)
36 | AIO.Flags().AddFlagSet(fAgentAuth)
37 | AIO.Flags().AddFlagSet(fSBOM)
38 | AIO.Flags().AddFlagSet(fMCP)
39 | }
40 |
// AIO is the command that starts a minibridge backend and a minibridge
// frontend in the same process.
//
// The backend listens on an in-memory listener and the frontend reaches it
// through a dialer backed by that same listener. The frontend is an HTTP
// frontend when the "listen" flag is set, and a stdio frontend otherwise.
var AIO = &cobra.Command{
	Use:              "aio [flags] -- command [args...]",
	Short:            "Start an all-in-one minibridge frontend and backend",
	Args:             cobra.MinimumNArgs(1),
	SilenceUsage:     true,
	SilenceErrors:    true,
	TraverseChildren: true,

	RunE: func(cmd *cobra.Command, args []string) (err error) {

		ctx, cancel := context.WithCancel(cmd.Context())
		defer cancel()

		listen := viper.GetString("listen")
		mcpEndpoint := viper.GetString("endpoint-mcp")
		sseEndpoint := viper.GetString("endpoint-sse")
		messageEndpoint := viper.GetString("endpoint-messages")

		agentAuth, err := makeAgentAuth(true)
		if err != nil {
			return fmt.Errorf("unable to build auth: %w", err)
		}

		policer, penforce, err := makePolicer()
		if err != nil {
			return fmt.Errorf("unable to make policer: %w", err)
		}

		sbom, err := makeSBOM()
		if err != nil {
			return fmt.Errorf("unable to make hashes: %w", err)
		}

		tracer, err := makeTracer(ctx, "aio")
		if err != nil {
			return fmt.Errorf("unable to configure tracer: %w", err)
		}

		corsPolicy := makeCORSPolicy()

		mcpClient, err := makeMCPClient(args, true)
		if err != nil {
			return fmt.Errorf("unable to create MCP client: %w", err)
		}

		mm := startHealthServer(ctx)

		// In-memory listener shared by the backend (server side) and the
		// frontend dialer (client side).
		listener := memconn.NewListener()
		defer func() { _ = listener.Close() }()

		var eg errgroup.Group

		var mbackend backend.Backend
		eg.Go(func() error {

			// Cancel the shared context when the backend returns, so the
			// frontend goroutine stops too.
			defer cancel()

			slog.Info("Minibridge backend configured")

			mbackend = backend.NewWebSocket("self", nil, mcpClient,
				backend.OptListener(listener),
				backend.OptPolicer(policer),
				backend.OptPolicerEnforce(penforce),
				backend.OptDumpStderrOnError(viper.GetString("log-format") != "json"),
				backend.OptSBOM(sbom),
				backend.OptMetricsManager(mm),
				backend.OptTracer(tracer),
			)

			return mbackend.Start(ctx)
		})

		eg.Go(func() error {

			// Cancel the shared context when the frontend returns, so the
			// backend goroutine stops too.
			defer cancel()

			var mfrontend frontend.Frontend

			frontendServerTLSConfig, err := tlsConfigFromFlags(fTLSServer)
			if err != nil {
				return err
			}

			// The network and addr arguments are ignored: every dial goes
			// to the in-memory listener.
			// NOTE(review): the ctx argument is ignored as well and the dial
			// uses cmd.Context() instead. Confirm this is intentional.
			dialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
				return listener.DialContext(cmd.Context(), "127.0.0.1:443")
			}

			if listen != "" {

				slog.Info("Minibridge frontend configured",
					"mcp", mcpEndpoint,
					"sse", sseEndpoint,
					"messages", messageEndpoint,
					"agent-token", agentAuth != nil,
					"mode", "http",
					"server-tls", frontendServerTLSConfig != nil,
					"server-mtls", mtlsMode(frontendServerTLSConfig),
					"listen", listen,
				)

				mfrontend = frontend.NewHTTP(listen, "ws://self/ws", frontendServerTLSConfig, nil,
					frontend.OptHTTPBackendDialer(dialer),
					frontend.OptHTTPMCPEndpoint(mcpEndpoint),
					frontend.OptHTTPSSEEndpoint(sseEndpoint),
					frontend.OptHTTPMessageEndpoint(messageEndpoint),
					frontend.OptHTTPAgentTokenPassthrough(true),
					frontend.OptHTTPCORSPolicy(corsPolicy),
					frontend.OptHTTPMetricsManager(mm),
					frontend.OptHTTPTracer(tracer),
				)
			} else {

				slog.Info("Minibridge frontend configured",
					"mode", "stdio",
				)

				mfrontend = frontend.NewStdio("ws://self/ws", nil,
					frontend.OptStdioBackendDialer(dialer),
					frontend.OptStdioRetry(false),
					frontend.OptStdioTracer(tracer),
				)
			}

			// NOTE(review): presumably gives the backend goroutine time to
			// start before the frontend connects. Confirm.
			time.Sleep(300 * time.Millisecond)

			return startFrontendWithOAuth(ctx, mfrontend, agentAuth)
		})

		return eg.Wait()
	},
}
172 |
--------------------------------------------------------------------------------
/pkgs/scan/scan.go:
--------------------------------------------------------------------------------
1 | package scan
2 |
3 | import (
4 | "context"
5 | "crypto/sha256"
6 | "fmt"
7 | "slices"
8 | "strings"
9 |
10 | "github.com/gofrs/uuid"
11 | "github.com/mitchellh/mapstructure"
12 | "go.acuvity.ai/minibridge/pkgs/backend/client"
13 | "go.acuvity.ai/minibridge/pkgs/mcp"
14 | )
15 |
// Dump holds the tools, resources, resource templates and prompts
// collected by DumpAll. Empty collections are omitted from the JSON
// encoding.
type Dump struct {
	Tools             mcp.Tools             `json:"tools,omitempty"`
	Resources         mcp.Resources         `json:"resources,omitempty"`
	ResourceTemplates mcp.ResourceTemplates `json:"resourceTemplates,omitempty"`
	Prompts           mcp.Prompts           `json:"prompts,omitempty"`
}
22 |
23 | // DumpAll dumps all the all available tools/resource/prompts from the given client.MCPStream.
24 | func DumpAll(ctx context.Context, stream *client.MCPStream, exclusions *Exclusions) (Dump, error) {
25 |
26 | if _, err := stream.SendRequest(ctx, mcp.NewInitMessage(mcp.ProtocolVersion20250326)); err != nil {
27 | return Dump{}, fmt.Errorf("unable to send mcp request: %w", err)
28 | }
29 |
30 | if err := stream.SendNotification(ctx, mcp.NewNotification("notifications/initialize")); err != nil {
31 | return Dump{}, fmt.Errorf("unable to send mcp initialized notif: %w", err)
32 | }
33 |
34 | dump := Dump{}
35 |
36 | // Tools
37 | if !exclusions.Tools {
38 | toolsReq := mcp.NewMessage(uuid.Must(uuid.NewV7()).String())
39 | toolsReq.Method = "tools/list"
40 | resps, err := stream.SendPaginatedRequest(ctx, toolsReq)
41 | if err != nil {
42 | return Dump{}, fmt.Errorf("unable to send tools/list mcp request: %w", err)
43 | }
44 |
45 | for _, resp := range resps {
46 |
47 | if _, ok := resp.Result["tools"]; !ok {
48 | continue
49 | }
50 |
51 | tools := mcp.Tools{}
52 | if err := mapstructure.Decode(resp.Result["tools"], &tools); err != nil {
53 | return Dump{}, fmt.Errorf("unable to convert to tools: %w", err)
54 | }
55 |
56 | dump.Tools = append(dump.Tools, tools...)
57 | }
58 | }
59 |
60 | // Resources
61 | if !exclusions.Resources {
62 | resourcesReq := mcp.NewMessage(uuid.Must(uuid.NewV7()).String())
63 | resourcesReq.Method = "resources/list"
64 | resps, err := stream.SendPaginatedRequest(ctx, resourcesReq)
65 | if err != nil {
66 | return Dump{}, fmt.Errorf("unable to send resources/list mcp request: %w", err)
67 | }
68 |
69 | for _, resp := range resps {
70 |
71 | if _, ok := resp.Result["resources"]; !ok {
72 | continue
73 | }
74 |
75 | resources := mcp.Resources{}
76 | if err := mapstructure.Decode(resp.Result["resources"], &resources); err != nil {
77 | return Dump{}, fmt.Errorf("unable to convert to resources: %w", err)
78 | }
79 |
80 | dump.Resources = append(dump.Resources, resources...)
81 | }
82 |
83 | // Resources Templates
84 | resourcesTemplateReq := mcp.NewMessage(uuid.Must(uuid.NewV7()).String())
85 | resourcesTemplateReq.Method = "resources/templates/list"
86 | resps, err = stream.SendPaginatedRequest(ctx, resourcesTemplateReq)
87 | if err != nil {
88 | return Dump{}, fmt.Errorf("unable to send resources/templates/list mcp request: %w", err)
89 | }
90 |
91 | for _, resp := range resps {
92 |
93 | if _, ok := resp.Result["resourceTemplates"]; !ok {
94 | continue
95 | }
96 |
97 | resourceTemplates := mcp.ResourceTemplates{}
98 | if err := mapstructure.Decode(resp.Result["resourceTemplates"], &resourceTemplates); err != nil {
99 | return Dump{}, fmt.Errorf("unable to convert to resources templates: %w", err)
100 | }
101 |
102 | dump.ResourceTemplates = append(dump.ResourceTemplates, resourceTemplates...)
103 | }
104 |
105 | }
106 |
107 | // Prompts
108 | if !exclusions.Prompts {
109 | promptsReq := mcp.NewMessage(uuid.Must(uuid.NewV7()).String())
110 | promptsReq.Method = "prompts/list"
111 | resps, err := stream.SendPaginatedRequest(ctx, promptsReq)
112 | if err != nil {
113 | return Dump{}, fmt.Errorf("unable to send prompts/list mcp request: %w", err)
114 | }
115 |
116 | for _, resp := range resps {
117 |
118 | if _, ok := resp.Result["prompts"]; !ok {
119 | continue
120 | }
121 |
122 | prompts := mcp.Prompts{}
123 | if err := mapstructure.Decode(resp.Result["prompts"], &prompts); err != nil {
124 | return Dump{}, fmt.Errorf("unable to convert to prompts: %w", err)
125 | }
126 |
127 | dump.Prompts = append(dump.Prompts, prompts...)
128 | }
129 | }
130 | return dump, nil
131 | }
132 |
133 | // HashTools will generate Hashes for the given api.Tools
134 | func HashTools(tools mcp.Tools) (Hashes, error) {
135 |
136 | hashes := []Hash{}
137 | for _, tool := range tools {
138 |
139 | h := Hash{
140 | Name: tool.Name,
141 | Hash: fmt.Sprintf("%x", sha256.Sum256([]byte(tool.Description))),
142 | }
143 |
144 | for k, v := range tool.InputSchema {
145 | if k != "properties" {
146 | continue
147 | }
148 | for pk, pv := range v.(map[string]any) {
149 |
150 | pvv, ok := pv.(map[string]any)
151 | if !ok {
152 | continue
153 | }
154 |
155 | pdesc, ok := pvv["description"].(string)
156 | if !ok {
157 | continue
158 | }
159 |
160 | h.Params = append(h.Params, Hash{
161 | Name: pk,
162 | Hash: fmt.Sprintf("%x", sha256.Sum256([]byte(pdesc))),
163 | })
164 | }
165 | }
166 |
167 | slices.SortFunc(h.Params, func(a Hash, b Hash) int {
168 | return strings.Compare(a.Name, b.Name)
169 | })
170 |
171 | hashes = append(hashes, h)
172 | }
173 |
174 | slices.SortFunc(hashes, func(a Hash, b Hash) int {
175 | return strings.Compare(a.Name, b.Name)
176 | })
177 |
178 | return hashes, nil
179 | }
180 |
181 | // HashPrompt generate Hashes for the given api.Prompt
182 | func HashPrompts(prompts mcp.Prompts) (Hashes, error) {
183 |
184 | hashes := []Hash{}
185 | for _, tool := range prompts {
186 |
187 | h := Hash{
188 | Name: tool.Name,
189 | Hash: fmt.Sprintf("%x", sha256.Sum256([]byte(tool.Description))),
190 | }
191 |
192 | for _, p := range tool.Arguments {
193 |
194 | h.Params = append(h.Params, Hash{
195 | Name: p.Name,
196 | Hash: fmt.Sprintf("%x", sha256.Sum256([]byte(p.Description))),
197 | })
198 | }
199 |
200 | slices.SortFunc(h.Params, func(a Hash, b Hash) int {
201 | return strings.Compare(a.Name, b.Name)
202 | })
203 |
204 | hashes = append(hashes, h)
205 | }
206 |
207 | slices.SortFunc(hashes, func(a Hash, b Hash) int {
208 | return strings.Compare(a.Name, b.Name)
209 | })
210 |
211 | return hashes, nil
212 | }
213 |
--------------------------------------------------------------------------------
/pkgs/backend/client/sse.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "context"
7 | "crypto/tls"
8 | "errors"
9 | "fmt"
10 | "io"
11 | "log/slog"
12 | "net/http"
13 | "net/url"
14 | "strings"
15 | "time"
16 |
17 | "go.acuvity.ai/minibridge/pkgs/auth"
18 | "go.acuvity.ai/minibridge/pkgs/internal/sanitize"
19 | )
20 |
// Compile-time checks that *sseClient implements Client and RemoteClient.
var _ Client = (*sseClient)(nil)
var _ RemoteClient = (*sseClient)(nil)

// ErrAuthRequired is returned when the remote server answers
// with a 401 Unauthorized status.
var ErrAuthRequired = errors.New("authorization required")
25 |
// sseClient is a Client that talks to a remote MCP server over SSE.
type sseClient struct {
	// u is the parsed form of the endpoint given to NewSSE.
	u *url.URL
	// endpoint is the endpoint given to NewSSE, without trailing slashes.
	endpoint string
	// messageEndpoint is the URL where requests are posted. It is set by
	// Start from the first message received on the event stream.
	messageEndpoint string
	client          *http.Client
}
32 |
33 | func NewSSE(endpoint string, tlsConfig *tls.Config) Client {
34 |
35 | client := &http.Client{
36 | Transport: &http.Transport{
37 | TLSClientConfig: tlsConfig,
38 | ResponseHeaderTimeout: 10 * time.Second,
39 | },
40 | }
41 |
42 | u, err := url.Parse(endpoint)
43 | if err != nil {
44 | panic(err)
45 | }
46 |
47 | return &sseClient{
48 | endpoint: strings.TrimRight(endpoint, "/"),
49 | client: client,
50 | u: u,
51 | }
52 | }
53 |
54 | func (c *sseClient) Type() string { return "sse" }
55 |
56 | func (c *sseClient) Server() string { return c.BaseURL() }
57 |
58 | func (c *sseClient) BaseURL() string { return fmt.Sprintf("%s://%s", c.u.Scheme, c.u.Host) }
59 |
60 | func (c *sseClient) HTTPClient() *http.Client { return c.client }
61 |
62 | func (c *sseClient) MPCClient() Client { return c }
63 |
64 | func (c *sseClient) Start(ctx context.Context, opts ...Option) (pipe *MCPStream, err error) {
65 |
66 | cfg := cfg{}
67 | for _, o := range opts {
68 | o(&cfg)
69 | }
70 |
71 | sseEndpoint := fmt.Sprintf("%s/sse", c.endpoint)
72 |
73 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, sseEndpoint, nil)
74 | if err != nil {
75 | return nil, fmt.Errorf("unable to initiate request: %w", err)
76 | }
77 | req.Header.Set("Accept", "text/event-stream")
78 | if cfg.auth != nil {
79 | req.Header.Set("Authorization", cfg.auth.Encode())
80 | }
81 |
82 | stream := NewMCPStream(ctx)
83 | out, unregisterOut := stream.Stdout()
84 | defer unregisterOut()
85 | exit, unregisterExit := stream.Exit()
86 | defer unregisterExit()
87 |
88 | // we don't close the body here (which makes the linter all triggered)
89 | // because it's a long running connection. however readResponse will
90 | // do it on exit
91 | resp, err := c.client.Do(req) // nolint
92 | if err != nil {
93 | return nil, fmt.Errorf("unable to send initial sse request (%s): %w", req.URL.String(), err)
94 | }
95 |
96 | if resp.StatusCode == http.StatusUnauthorized {
97 | return nil, ErrAuthRequired
98 | }
99 |
100 | if resp.StatusCode != http.StatusOK {
101 | return nil, fmt.Errorf("invalid response from sse initialization (%s): %s", req.URL.String(), resp.Status)
102 | }
103 |
104 | go c.readRequest(ctx, stream.stdin, stream.exit, cfg.auth)
105 | go c.readResponse(ctx, resp.Body, stream.stdout, stream.exit)
106 |
107 | // Get the first response to get the endpoint
108 | var data []byte
109 |
110 | L:
111 | for {
112 | select {
113 | case data = <-out:
114 | break L
115 |
116 | case err := <-exit:
117 | if !errors.Is(err, io.EOF) {
118 | return nil, fmt.Errorf("unable to process sse message: %w", err)
119 | }
120 |
121 | case <-time.After(time.Second):
122 | return nil, fmt.Errorf("did not receive /message endpoint in time: timeout")
123 |
124 | case <-ctx.Done():
125 | return nil, fmt.Errorf("did not receive /message endpoint in time: %w", ctx.Err())
126 | }
127 | }
128 |
129 | c.messageEndpoint = fmt.Sprintf(
130 | "%s/%s",
131 | c.endpoint,
132 | strings.TrimLeft(string(data), "/"),
133 | )
134 | slog.Debug("SSE Client: message endpoint set", "endpoint", c.messageEndpoint)
135 |
136 | return stream, nil
137 | }
138 |
139 | func (c *sseClient) readRequest(ctx context.Context, ch chan []byte, exitCh chan error, auth *auth.Auth) {
140 |
141 | for {
142 |
143 | select {
144 |
145 | case <-ctx.Done():
146 | return
147 |
148 | case data := <-ch:
149 |
150 | buf := bytes.NewBuffer(append(sanitize.Data(data), '\n', '\n'))
151 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.messageEndpoint, buf)
152 | if err != nil {
153 | exitCh <- fmt.Errorf("unable to make post request: %w", err)
154 | return
155 | }
156 | req.Header.Set("Content-Type", "application/json")
157 | req.Header.Set("Accept", "application/json")
158 | if auth != nil {
159 | req.Header.Set("Authorization", auth.Encode())
160 | }
161 |
162 | resp, err := c.client.Do(req)
163 | if err != nil {
164 | exitCh <- fmt.Errorf("unable to send post request: %w", err)
165 | return
166 | }
167 | defer func() { _ = resp.Body.Close() }()
168 |
169 | if resp.StatusCode == http.StatusUnauthorized {
170 | exitCh <- ErrAuthRequired
171 | return
172 | }
173 |
174 | if resp.StatusCode != http.StatusAccepted {
175 | exitCh <- fmt.Errorf("invalid mcp server response status: %s", resp.Status)
176 | return
177 | }
178 | }
179 | }
180 | }
181 |
182 | func (c *sseClient) readResponse(ctx context.Context, r io.ReadCloser, ch chan []byte, exitCh chan error) {
183 |
184 | defer func() { _ = r.Close() }()
185 |
186 | scan := bufio.NewScanner(r)
187 | scan.Split(split)
188 | scan.Buffer(make([]byte, 1024), 5*1024*1024)
189 |
190 | for scan.Scan() {
191 |
192 | data := sanitize.Data(scan.Bytes())
193 |
194 | parts := bytes.SplitN(data, []byte{'\n'}, 2)
195 | if len(parts) != 2 {
196 | exitCh <- fmt.Errorf("invalid sse message: %s", string(data))
197 | return
198 | }
199 |
200 | data = bytes.TrimPrefix(parts[1], []byte("data: "))
201 |
202 | select {
203 | case ch <- data:
204 | case <-ctx.Done():
205 | return
206 | }
207 | }
208 |
209 | if err := scan.Err(); err != nil {
210 | exitCh <- fmt.Errorf("sse stream closed: %w", scan.Err())
211 | } else {
212 | exitCh <- fmt.Errorf("sse stream closed: %w", io.EOF)
213 | }
214 | }
215 |
216 | func split(data []byte, atEOF bool) (int, []byte, error) {
217 |
218 | if atEOF && len(data) == 0 {
219 | return 0, nil, nil
220 | }
221 |
222 | if i, nlen := hasNewLine(data); i >= 0 {
223 | return i + nlen, data[0:i], nil
224 | }
225 |
226 | if atEOF {
227 | return len(data), data, nil
228 | }
229 |
230 | return 0, nil, nil
231 | }
232 |
233 | func hasNewLine(data []byte) (int, int) {
234 |
235 | crunix := bytes.Index(data, []byte("\n\n"))
236 | crwin := bytes.Index(data, []byte("\r\n\r\n"))
237 | minPos := minPos(crunix, crwin)
238 | nlen := 2
239 | if minPos == crwin {
240 | nlen = 4
241 | }
242 | return minPos, nlen
243 | }
244 |
// minPos returns the smallest of the two given positions, ignoring the
// negative ones. If a is negative, b is returned as is, even when b is
// negative too.
func minPos(a, b int) int {
	switch {
	case a < 0:
		return b
	case b < 0:
		return a
	case b < a:
		return b
	default:
		return a
	}
}
257 |
--------------------------------------------------------------------------------
/pkgs/metrics/manager.go:
--------------------------------------------------------------------------------
1 | package metrics
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "log/slog"
7 | "net"
8 | "net/http"
9 | "strconv"
10 | "time"
11 |
12 | "github.com/prometheus/client_golang/prometheus"
13 | "github.com/prometheus/client_golang/prometheus/promhttp"
14 | "go.acuvity.ai/minibridge/pkgs/policer/api"
15 | )
16 |
// Manager holds the prometheus metrics exposed by minibridge.
type Manager struct {
	// reqDurationMetric is the "http_requests_duration_seconds" histogram.
	reqDurationMetric *prometheus.HistogramVec
	// reqTotalMetric is the "http_requests_total" counter.
	reqTotalMetric *prometheus.CounterVec
	// errorMetric is the "http_errors_5xx_total" counter.
	errorMetric *prometheus.CounterVec
	// tcpConnTotalMetric is the "tcp_connections_total" counter.
	tcpConnTotalMetric prometheus.Counter
	// tcpConnCurrentMetric is the "tcp_connections_current" gauge.
	tcpConnCurrentMetric prometheus.Gauge
	// wsConnTotalMetric is the "http_ws_connections_total" counter.
	wsConnTotalMetric prometheus.Counter
	// wsConnCurrentMetric is the "http_ws_connections_current" gauge.
	wsConnCurrentMetric prometheus.Gauge
	// policerDurationMetric is the "policer_requests_duration_seconds" histogram.
	policerDurationMetric *prometheus.HistogramVec
	// policerRequestTotalMetric is the "policer_request_total" counter.
	policerRequestTotalMetric *prometheus.CounterVec

	server *http.Server
}
30 |
31 | func NewManager(listen string) *Manager {
32 |
33 | r := prometheus.DefaultRegisterer
34 |
35 | mc := &Manager{
36 |
37 | reqTotalMetric: prometheus.NewCounterVec(
38 | prometheus.CounterOpts{
39 | Name: "http_requests_total",
40 | Help: "The total number of requests.",
41 | },
42 | []string{"method", "url", "code"},
43 | ),
44 | reqDurationMetric: prometheus.NewHistogramVec(
45 | prometheus.HistogramOpts{
46 | Name: "http_requests_duration_seconds",
47 | Help: "The average duration of the requests",
48 | Buckets: []float64{0.001, 0.0025, 0.005, 0.010, 0.025, 0.050, 0.100, 0.250, 0.500, 1.0, 2.5, 5.0, 10.0},
49 | },
50 | []string{"method", "url"},
51 | ),
52 | tcpConnTotalMetric: prometheus.NewCounter(
53 | prometheus.CounterOpts{
54 | Name: "tcp_connections_total",
55 | Help: "The total number of TCP connection.",
56 | },
57 | ),
58 | tcpConnCurrentMetric: prometheus.NewGauge(
59 | prometheus.GaugeOpts{
60 | Name: "tcp_connections_current",
61 | Help: "The current number of TCP connection.",
62 | },
63 | ),
64 | wsConnTotalMetric: prometheus.NewCounter(
65 | prometheus.CounterOpts{
66 | Name: "http_ws_connections_total",
67 | Help: "The total number of ws connection.",
68 | },
69 | ),
70 | wsConnCurrentMetric: prometheus.NewGauge(
71 | prometheus.GaugeOpts{
72 | Name: "http_ws_connections_current",
73 | Help: "The current number of ws connection.",
74 | },
75 | ),
76 | errorMetric: prometheus.NewCounterVec(
77 | prometheus.CounterOpts{
78 | Name: "http_errors_5xx_total",
79 | Help: "The total number of 5xx errors.",
80 | },
81 | []string{"trace", "method", "url", "code"},
82 | ),
83 |
84 | policerDurationMetric: prometheus.NewHistogramVec(
85 | prometheus.HistogramOpts{
86 | Name: "policer_requests_duration_seconds",
87 | Help: "The average duration of the policing requests",
88 | Buckets: []float64{0.001, 0.0025, 0.005, 0.010, 0.025, 0.050, 0.100, 0.250, 0.500, 1.0, 2.5, 5.0, 10.0},
89 | },
90 | []string{"policer_type", "call_type"},
91 | ),
92 | policerRequestTotalMetric: prometheus.NewCounterVec(
93 | prometheus.CounterOpts{
94 | Name: "policer_request_total",
95 | Help: "The total number of policer requests.",
96 | },
97 | []string{"policer_type", "call_type", "decision"},
98 | ),
99 | }
100 |
101 | r.MustRegister(mc.tcpConnCurrentMetric)
102 | r.MustRegister(mc.tcpConnTotalMetric)
103 | r.MustRegister(mc.reqTotalMetric)
104 | r.MustRegister(mc.reqDurationMetric)
105 | r.MustRegister(mc.wsConnTotalMetric)
106 | r.MustRegister(mc.wsConnCurrentMetric)
107 | r.MustRegister(mc.errorMetric)
108 | r.MustRegister(mc.policerDurationMetric)
109 | r.MustRegister(mc.policerRequestTotalMetric)
110 |
111 | mc.server = &http.Server{
112 | Addr: listen,
113 | ReadHeaderTimeout: time.Second,
114 | Handler: mc,
115 | }
116 |
117 | return mc
118 | }
119 |
120 | func (c *Manager) Start(ctx context.Context) error {
121 |
122 | errCh := make(chan error, 1)
123 |
124 | sctx, cancel := context.WithCancel(ctx)
125 | defer cancel()
126 |
127 | c.server.BaseContext = func(net.Listener) context.Context { return sctx }
128 | c.server.RegisterOnShutdown(func() { cancel() })
129 |
130 | go func() {
131 | err := c.server.ListenAndServe()
132 | if err != nil {
133 | if !errors.Is(err, http.ErrServerClosed) {
134 | slog.Error("unable to start health server", "err", err)
135 | }
136 | }
137 | errCh <- err
138 | }()
139 |
140 | select {
141 | case <-sctx.Done():
142 | case err := <-errCh:
143 | return err
144 | }
145 |
146 | stopCtx, stopCancel := context.WithTimeout(context.Background(), 10*time.Second)
147 | defer stopCancel()
148 |
149 | return c.server.Shutdown(stopCtx)
150 | }
151 |
152 | func (c *Manager) MeasureRequest(method string, path string) func(int) time.Duration {
153 |
154 | timer := prometheus.NewTimer(
155 | prometheus.ObserverFunc(
156 | func(v float64) {
157 | c.reqDurationMetric.With(
158 | prometheus.Labels{
159 | "method": method,
160 | "url": path,
161 | },
162 | ).Observe(v)
163 | },
164 | ),
165 | )
166 |
167 | return func(code int) time.Duration {
168 |
169 | c.reqTotalMetric.With(prometheus.Labels{
170 | "method": method,
171 | "url": path,
172 | "code": strconv.Itoa(code),
173 | }).Inc()
174 |
175 | if code >= http.StatusInternalServerError {
176 |
177 | c.errorMetric.With(prometheus.Labels{
178 | "method": method,
179 | "url": path,
180 | "code": strconv.Itoa(code),
181 | }).Inc()
182 | }
183 |
184 | return timer.ObserveDuration()
185 | }
186 | }
187 |
188 | func (c *Manager) MeasurePolicer(ptype string, rtype api.CallType) func(allow bool) time.Duration {
189 |
190 | timer := prometheus.NewTimer(
191 | prometheus.ObserverFunc(
192 | func(v float64) {
193 | c.policerDurationMetric.With(
194 | prometheus.Labels{
195 | "policer_type": ptype,
196 | "call_type": string(rtype),
197 | },
198 | ).Observe(v)
199 | },
200 | ),
201 | )
202 |
203 | return func(allow bool) time.Duration {
204 | c.policerRequestTotalMetric.With(prometheus.Labels{
205 | "policer_type": ptype,
206 | "call_type": string(rtype),
207 | "decision": func() string {
208 | if allow {
209 | return "allow"
210 | }
211 | return "deny"
212 | }(),
213 | }).Inc()
214 |
215 | return timer.ObserveDuration()
216 | }
217 | }
218 |
// RegisterWSConnection records a new websocket connection: it increments both
// the total counter and the current gauge.
func (c *Manager) RegisterWSConnection() {
	c.wsConnTotalMetric.Inc()
	c.wsConnCurrentMetric.Inc()
}
223 |
// UnregisterWSConnection records the end of a websocket connection by
// decrementing the current gauge. The total counter is left untouched.
func (c *Manager) UnregisterWSConnection() {
	c.wsConnCurrentMetric.Dec()
}
227 |
// RegisterTCPConnection records a new TCP connection: it increments both the
// total counter and the current gauge.
func (c *Manager) RegisterTCPConnection() {
	c.tcpConnTotalMetric.Inc()
	c.tcpConnCurrentMetric.Inc()
}
232 |
// UnregisterTCPConnection records the end of a TCP connection by decrementing
// the current gauge. The total counter is left untouched.
func (c *Manager) UnregisterTCPConnection() {
	c.tcpConnCurrentMetric.Dec()
}
236 |
237 | func (c *Manager) ServeHTTP(w http.ResponseWriter, req *http.Request) {
238 |
239 | switch req.URL.Path {
240 |
241 | case "/":
242 | w.WriteHeader(http.StatusNoContent)
243 |
244 | case "/metrics":
245 | promhttp.Handler().ServeHTTP(w, req)
246 |
247 | default:
248 | http.Error(w, "Not found", http.StatusNotFound)
249 | }
250 | }
251 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Minibridge
2 |
3 | [](https://github.com/acuvity/minibridge/actions/workflows/build.yaml)
4 | [](https://goreportcard.com/report/github.com/acuvity/minibridge)
5 | [](https://pkg.go.dev/github.com/acuvity/minibridge)
6 | [](https://hub.docker.com/u/acuvity?page=1&search=mcp-server)
7 |
8 | Minibridge serves as a backend-to-frontend bridge, streamlining and securing
9 | communication between Agents and MCP servers. It safely exposes [MCP
10 | servers](https://modelcontextprotocol.io) to the internet and can optionally
11 | integrate with generic policing services — known as Policers — for agent
12 | authentication, content analysis, and transformation. Policers can be
13 | implemented remotely via HTTP or locally using [OPA
14 | Rego](https://www.openpolicyagent.org/docs/latest/policy-reference/) policies.
15 |
16 | Minibridge can help ensure the integrity of MCP servers through
17 | SBOM (Software Bill of Materials) generation and real-time validation.
18 |
Additionally, Minibridge supports [OTEL](https://opentelemetry.io/) and can
report/reattach spans from classical OTEL headers, as well as directly from the
MCP call, as inserted by certain tools like [Open
Inference](https://arize-ai.github.io/openinference).
23 |
24 | 
25 |
26 | - **Minibridge Frontend**: The Client connects to the Frontend part of Minibridge.
27 | - **Minibridge Backend**: The Frontend connects to the Backend which wraps the MCP server.
- **Minibridge Policer**: The Policer runs in the Backend and can optionally make decisions on the input and output based on policies (locally with Rego or remotely over HTTPS)
29 |
30 | > [!TIP]
31 | > Conveniently, Minibridge can be started in an "all-in-one" (AIO) mode to act as a single process.
32 |
33 | ## Why using Minibridge ?
34 |
35 | Minibridge covers the following:
36 |
- **Secure Transport**: Use TLS, optionally with client certificate validation
- **Integrity**: Ensure the MCP server cannot mutate tools, templates, etc. during execution
- **User Authentication**: Transport the user information to the Policer
- **Monitoring**: Expose Prometheus metrics
- **Telemetry**: Report traces and spans using OpenTelemetry
42 |
43 | ## Installation
44 |
45 | Minibridge can be installed from various places:
46 |
47 | ### Homebrew
48 |
On macOS, you can use Homebrew:
50 |
51 | ```console
52 | brew tap acuvity/tap
53 | brew install minibridge
54 | ```
55 |
56 | ### AUR
57 |
58 | On Arch based Linux distributions, you can run:
59 |
60 | ```console
61 | yay -S minibridge
62 | ```
63 |
64 | Alternatively, to get the latest version from the main branch:
65 |
66 | ```console
67 | yay -S minibridge-git
68 | ```
69 |
70 | ### Go
71 |
If you have the Go toolchain installed:
73 |
74 | ```console
75 | go install go.acuvity.ai/minibridge@latest
76 | ```
77 |
78 | Alternatively, to get the latest version from the main branch:
79 |
80 | ```console
81 | go install go.acuvity.ai/minibridge@main
82 | ```
83 |
84 | ### Manually
85 |
86 | You can easily grab a binary version for your platform from the [release
87 | page](https://github.com/acuvity/minibridge/releases/tag/v0.6.2).
88 |
89 |
90 | ## Features comparisons
91 |
92 | | 🚀 **Feature** | 🔹 **MCP** | 🔸 **Minibridge** | 📦 **ARC (Acuvity Containers)** |
93 | | ------------------------------ | ----------- | ----------------- | ------------------------------- |
94 | | 🌐 Remote Access | ⚠️ | ✅ | ✅ |
95 | | 🔒 TLS Support | ❌ | ✅ | ✅ |
96 | | 📃 Tool integrity check | ❌ | ✅ | ✅ |
97 | | 📊 Visualization and Tracing | ❌ | ✅ | ✅ |
98 | | 🛡️ Isolation | ❌ | ⚠️ | ✅ |
99 | | 🔐 Security Policy Management | ❌ | 👤 | ⚠️ |
100 | | 🕵️ Secrets Redaction | ❌ | 👤 | ⚠️ |
101 | | 🔑 Authorization Controls | ❌ | 👤 | 👤 |
102 | | 🧑💻 PII Detection and Redaction | ❌ | 👤 | 👤 |
103 | | 📌 Version Pinning | ❌ | ❌ | ✅ |
104 |
105 | ✅ _Included_ | ⚠️ _Partial/Basic Support_ | 👤 _Custom User Implementation_ | ❌ _Not Supported_
106 |
107 | ## Example: Configuring Minibridge in your MCP Client
108 |
109 | Suppose your client configuration originally specifies an MCP server like this:
110 |
111 | ```json
112 | {
113 | "mcpServers": {
114 | "fetch": {
115 | "command": "uvx",
116 | "args": ["mcp-server-fetch"]
117 | }
118 | }
119 | }
120 | ```
121 |
122 | To route requests through Minibridge (enabling SBOM checks, policy enforcement, etc.), update the entry:
123 |
124 | ```json
125 | {
126 | "mcpServers": {
127 | "fetch": {
128 | "command": "minibridge",
129 | "args": ["aio", "--", "uvx", "mcp-server-fetch"]
130 | }
131 | }
132 | }
133 | ```
134 |
135 | - **`minibridge aio`**: Invokes Minibridge in “all-in-one” mode, wrapping the downstream tool.
136 | - **`uvx mcp-server-fetch`**: The original MCP server command, now executed inside Minibridge.
137 |
138 | > [!TIP]
139 | > The location of the configuration files depends on your Client. For example, if you use Claude Desktop, configuration files are located:
140 | >
141 | > - macOS: `~/Library/Application Support/Claude/claude_desktop_config.json`
142 | > - Windows: `%APPDATA%\Claude\claude_desktop_config.json`
143 | >
144 | > See the official [MCP QuickStart for Claude Desktop Users](https://modelcontextprotocol.io/quickstart/user#2-add-the-filesystem-mcp-server) documentation.
145 |
146 | > [!IMPORTANT]
147 | > Your client must be able to resolve the path of the binary.
148 | > If you see an error like `MCP fetch: spawn minibridge ENOENT`, set the `command` parameter above to the full path of minibridge (`which minibridge` will give you the full path).
149 |
150 | ## Documentation
151 |
152 | Check out the complete [documentation](https://github.com/acuvity/minibridge/wiki) from the wiki pages.
153 |
154 | ## Contribute
155 |
156 | We are excited to welcome contributions from everyone! 🎉 Whether you're fixing bugs, enhancing features, improving documentation, or proposing entirely new ideas, your involvement helps strengthen the project and benefits the entire community.
157 |
158 | You do not need to sign a Contributor License Agreement (CLA) — just open a pull request and let's collaborate!
159 |
160 | ## Join us
161 |
162 | - [Discord](https://discord.gg/BkU7fBkrNk)
163 | - [LinkedIn](https://www.linkedin.com/company/acuvity)
164 | - [Bluesky](https://bsky.app/profile/acuvity.bsky.social)
165 | - [Docker](https://hub.docker.com/u/acuvity)
166 |
--------------------------------------------------------------------------------
/pkgs/memconn/memconn.go:
--------------------------------------------------------------------------------
// Package memconn provides in-memory network connections. This allows
// applications to connect to themselves without having to open up ports on the
// network.
4 | package memconn
5 |
6 | import (
7 | "bytes"
8 | "context"
9 | "errors"
10 | "fmt"
11 | "io"
12 | "net"
13 | "os"
14 | "sync"
15 | "time"
16 | )
17 |
const connBufferSize = 64 // maximum number of bytes buffered per Conn (64 bytes)
19 |
// Conn is an in-memory connection. Every Conn has a remote peer representing
// the other side of the connection. Writes to Conn will be sent to the peer's
// buffer for reading.
//
// Each Conn buffers at most connBufferSize bytes, which writes can fill
// immediately. Writes beyond the buffer size will be blocked until they are
// read by the peer.
type Conn struct {
	peerCh chan struct{} // Closed when peer is set
	peer   *Conn

	// cnd is waited on by readers of this Conn and by the peer's writers; its
	// lock is held while buf, the deadlines and closed are accessed.
	cnd *sync.Cond
	// buf holds the bytes written by the peer that have not been read yet.
	buf bytes.Buffer
	// readTimeout is the read deadline; the zero value means no deadline.
	readTimeout time.Time
	// readTimeoutCancel stops the pending read deadline timer, if any.
	readTimeoutCancel context.CancelFunc
	// writeTimeout is the write deadline; the zero value means no deadline.
	writeTimeout time.Time
	// writeTimeoutCancel stops the pending write deadline timer, if any.
	writeTimeoutCancel context.CancelFunc
	// closed is set once by handleClose.
	closed bool
	// name is the value reported through LocalAddr.
	name string
}
39 |
// Compile-time check that *Conn implements net.Conn.
var _ net.Conn = (*Conn)(nil)
41 |
42 | func newConn() *Conn {
43 | var mut sync.Mutex
44 | return &Conn{
45 | peerCh: make(chan struct{}),
46 | cnd: sync.NewCond(&mut),
47 | name: "memory",
48 | }
49 | }
50 |
51 | // Attach sets the remote peer. Panics if called more than once.
52 | func (c *Conn) Attach(peer *Conn) {
53 | select {
54 | default:
55 | c.peer = peer
56 | close(c.peerCh)
57 | case <-c.peerCh:
58 | panic("peer already set")
59 | }
60 | }
61 |
62 | // WaitPeer waits for a peer to be set or until ctx is canceled.
63 | func (c *Conn) WaitPeer(ctx context.Context) error {
64 | select {
65 | case <-ctx.Done():
66 | return ctx.Err()
67 | case <-c.peerCh:
68 | return nil
69 | }
70 | }
71 |
72 | func (c *Conn) Read(b []byte) (n int, err error) {
73 | for n == 0 {
74 | n2, err := c.readOrBlock(b)
75 | if err != nil {
76 | return n2, err
77 | }
78 | n += n2
79 | }
80 | c.cnd.Signal() // Wake up calls to .Write
81 | return n, nil
82 | }
83 |
84 | func (c *Conn) readOrBlock(b []byte) (int, error) {
85 | c.cnd.L.Lock()
86 | defer c.cnd.L.Unlock()
87 |
88 | if !c.readTimeout.IsZero() && !time.Now().Before(c.readTimeout) {
89 | return 0, os.ErrDeadlineExceeded
90 | }
91 |
92 | n, err := c.buf.Read(b)
93 |
94 | // We expect to get EOF from our buffer whenever there's no pending data. We
95 | // don't want to propagate the EOF to our reader until the conn itself is
96 | // closed.
97 | if errors.Is(err, io.EOF) {
98 | if c.closed {
99 | return n, err
100 | }
101 |
102 | // Wait until we're woken up by something, either because there's a timeout
103 | // or there's data to read. Spurious wakeups may happen, which would
104 | // eventually cause us to re-enter the wait.
105 | c.cnd.Wait()
106 | }
107 | return n, nil
108 | }
109 |
110 | func (c *Conn) Write(b []byte) (n int, err error) {
111 | for len(b) > 0 {
112 | n2, err := c.writeOrBlock(b)
113 | if err != nil {
114 | return n + n2, err
115 | }
116 | n += n2
117 | b = b[n2:]
118 | }
119 | return n, nil
120 | }
121 |
122 | func (c *Conn) writeOrBlock(b []byte) (int, error) {
123 | if err := c.writeAvail(); err != nil {
124 | return 0, err
125 | }
126 | return c.peer.enqueueOrBlock(b)
127 | }
128 |
129 | // writeAvail returns nil when writing is available.
130 | func (c *Conn) writeAvail() error {
131 | c.cnd.L.Lock()
132 | defer c.cnd.L.Unlock()
133 |
134 | switch {
135 | case c.closed:
136 | return net.ErrClosed
137 | case !c.writeTimeout.IsZero() && !time.Now().Before(c.writeTimeout):
138 | return os.ErrDeadlineExceeded
139 | default:
140 | return nil
141 | }
142 | }
143 |
144 | // enqueueOrBlock is invoked by a peer and writes b into the local buffer.
145 | func (c *Conn) enqueueOrBlock(b []byte) (int, error) {
146 | c.cnd.L.Lock()
147 | defer c.cnd.L.Unlock()
148 | if c.closed {
149 | return 0, net.ErrClosed
150 | }
151 |
152 | // Try to write as much as possible.
153 | n := len(b)
154 | limit := connBufferSize - c.buf.Len()
155 | if limit < n {
156 | n = limit
157 | }
158 |
159 | if n == 0 {
160 | // Buffer is completely full; wait for it to free up.
161 | c.cnd.Wait()
162 | return 0, nil
163 | }
164 |
165 | c.buf.Write(b[:n])
166 | c.cnd.Signal() // Signal that data can be read
167 | return n, nil
168 | }
169 |
// Close closes both sides of the connection: the local side first, then the
// peer. Only the result of closing the local side is returned.
func (c *Conn) Close() error {
	err := c.handleClose()
	_ = c.peer.handleClose() // best effort; handleClose always returns nil
	return err
}
176 |
177 | func (c *Conn) handleClose() error {
178 | c.cnd.L.Lock()
179 | defer c.cnd.L.Unlock()
180 |
181 | if c.closed {
182 | return nil
183 | }
184 | c.closed = true
185 |
186 | if c.readTimeoutCancel != nil {
187 | c.readTimeoutCancel()
188 | c.readTimeoutCancel = nil
189 | }
190 | if c.writeTimeoutCancel != nil {
191 | c.writeTimeoutCancel()
192 | c.writeTimeoutCancel = nil
193 | }
194 | c.broadcast()
195 | return nil
196 | }
197 |
// broadcast will wake up all sleeping goroutines on the local and peer
// connections. It dereferences c.peer, so a peer must have been attached.
func (c *Conn) broadcast() {
	// We *MUST* wake up goroutines waiting on both the local and remote
	// condition variables because usage of a conn depends on both.
	c.cnd.Broadcast()
	c.peer.cnd.Broadcast()
}
206 |
207 | func (c *Conn) LocalAddr() net.Addr {
208 | return Addr{
209 | name: c.name,
210 | }
211 | }
212 |
// RemoteAddr returns the local address of the peer. It dereferences c.peer,
// so a peer must have been attached.
func (c *Conn) RemoteAddr() net.Addr {
	return c.peer.LocalAddr()
}
216 |
217 | func (c *Conn) SetDeadline(t time.Time) error {
218 | var firstError error
219 | if err := c.SetReadDeadline(t); err != nil {
220 | firstError = err
221 | }
222 | if err := c.SetWriteDeadline(t); err != nil && firstError == nil {
223 | firstError = err
224 | }
225 | return firstError
226 | }
227 |
228 | func (c *Conn) SetReadDeadline(t time.Time) error {
229 |
230 | c.cnd.L.Lock()
231 | defer c.cnd.L.Unlock()
232 | if c.closed {
233 | return fmt.Errorf("conn closed")
234 | }
235 |
236 | c.readTimeout = t
237 |
238 | // There should only be one deadline goroutine at a time, so cancel it if it
239 | // already exists.
240 | if c.readTimeoutCancel != nil {
241 | c.readTimeoutCancel()
242 | c.readTimeoutCancel = nil
243 | }
244 | c.readTimeoutCancel = c.deadlineTimer(t)
245 | return nil
246 | }
247 |
248 | func (c *Conn) deadlineTimer(t time.Time) context.CancelFunc {
249 | if t.IsZero() {
250 | // Deadline of zero means to wait forever.
251 | return nil
252 | }
253 | if t.Before(time.Now()) {
254 | c.broadcast()
255 | }
256 | ctx, cancel := context.WithDeadline(context.Background(), t)
257 | go func() {
258 | <-ctx.Done()
259 | if errors.Is(ctx.Err(), context.DeadlineExceeded) {
260 | c.broadcast()
261 | }
262 | }()
263 | return cancel
264 | }
265 |
266 | func (c *Conn) SetWriteDeadline(t time.Time) error {
267 | c.cnd.L.Lock()
268 | defer c.cnd.L.Unlock()
269 | if c.closed {
270 | return fmt.Errorf("conn closed")
271 | }
272 |
273 | c.writeTimeout = t
274 |
275 | // There should only be one deadline goroutine at a time, so cancel it if it
276 | // already exists.
277 | if c.writeTimeoutCancel != nil {
278 | c.writeTimeoutCancel()
279 | c.writeTimeoutCancel = nil
280 | }
281 | c.writeTimeoutCancel = c.deadlineTimer(t)
282 | return nil
283 | }
284 |
--------------------------------------------------------------------------------
/pkgs/oauth/dance.go:
--------------------------------------------------------------------------------
1 | package oauth
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "log/slog"
11 | "net/http"
12 | "net/url"
13 | "time"
14 |
15 | "github.com/gofrs/uuid"
16 | "github.com/pkg/browser"
17 | "go.acuvity.ai/minibridge/pkgs/info"
18 | )
19 |
// RegistrationRequest is the JSON body sent by Dance to the
// <backendURL>/oauth2/register endpoint to register a client.
type RegistrationRequest struct {
	RedirectURI         []string `json:"redirect_uris"`
	ClientName          string   `json:"client_name"`
	ClientURI           string   `json:"client_uri,omitempty"`
	TokenEndpointMethod string   `json:"token_endpoint_auth_method"`
	LogoURI             string   `json:"logo_uri,omitempty"`
	ResponseTypes       []string `json:"response_types,omitempty"`
	GrantTypes          []string `json:"grant_types,omitempty"`
}
29 |
// RegistrationResponse is the JSON body decoded by Dance from the registration
// endpoint response. Only ClientID is used by Dance.
type RegistrationResponse struct {
	ClientID string `json:"client_id,omitempty"`
	// NOTE(review): the field is named "ID" but is decoded from the
	// "registration_client_uri" key.
	RegistrationClientID string `json:"registration_client_uri,omitempty"`
	// Presumably a Unix timestamp in seconds (RFC 7591) — TODO confirm.
	ClientIDIssuedAt int `json:"client_id_issued_at,omitempty"`
}
35 |
// Credentials holds the tokens decoded from a token endpoint response. The
// ClientID is not taken from the response: Refresh and Dance set it to the
// client ID they used for the request.
type Credentials struct {
	ClientID     string `json:"client_id"`
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
}
41 |
42 | // Refresh performs a refresh of the access token using the RefreshToken in the gviven Creds
43 | func Refresh(ctx context.Context, backendURL string, cl *http.Client, int info.Info, rt Credentials) (t Credentials, err error) {
44 |
45 | u := fmt.Sprintf("%s/oauth2/token", backendURL)
46 |
47 | form := url.Values{
48 | "grant_type": {"refresh_token"},
49 | "refresh_token": {rt.RefreshToken},
50 | "client_id": {rt.ClientID},
51 | }
52 |
53 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, bytes.NewBuffer([]byte(form.Encode())))
54 | if err != nil {
55 | return t, fmt.Errorf("unable to make token request: %w", err)
56 | }
57 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
58 | req.Header.Set("Accept", "application/json")
59 |
60 | resp, err := cl.Do(req)
61 | if err != nil {
62 | return t, fmt.Errorf("unable to send token request: %w", err)
63 | }
64 | defer func(r *http.Response) { _ = r.Body.Close() }(resp)
65 |
66 | data, err := io.ReadAll(resp.Body)
67 | if err != nil && !errors.Is(err, io.EOF) {
68 | return t, fmt.Errorf("unable to read token response body: %w", err)
69 | }
70 |
71 | if resp.StatusCode != http.StatusOK {
72 | return t, fmt.Errorf("invalid token response status code: %s (%s)", resp.Status, string(data))
73 | }
74 |
75 | if err := json.Unmarshal(data, &t); err != nil {
76 | return t, fmt.Errorf("unable to unmarshal token response body: %w", err)
77 | }
78 |
79 | t.ClientID = rt.ClientID
80 |
81 | return t, nil
82 | }
83 |
84 | func Dance(ctx context.Context, backendURL string, cl *http.Client, inf info.Info) (t Credentials, err error) {
85 |
86 | redirectURI := "http://127.0.0.1:9977/callback"
87 | clientID := ""
88 | state := uuid.Must(uuid.NewV7()).String()
89 |
90 | if inf.OAuthRegister {
91 |
92 | oreq := RegistrationRequest{
93 | ClientName: "minibridge",
94 | ClientURI: "https://github.com/acuvity/minibridge",
95 | TokenEndpointMethod: "none",
96 | RedirectURI: []string{redirectURI},
97 | GrantTypes: []string{"authorization_code", "refresh_token"},
98 | ResponseTypes: []string{"code"},
99 | }
100 |
101 | data, err := json.MarshalIndent(oreq, "", " ")
102 | if err != nil {
103 | return t, fmt.Errorf("unable to marshal registration request: %w", err)
104 | }
105 |
106 | u := fmt.Sprintf("%s/oauth2/register", backendURL)
107 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, bytes.NewBuffer(data))
108 | if err != nil {
109 | return t, fmt.Errorf("unable to build registration request: %w", err)
110 | }
111 | req.Header.Set("Content-Type", "application/json")
112 | req.Header.Set("Accept", "application/json")
113 |
114 | resp, err := cl.Do(req)
115 | if err != nil {
116 | return t, fmt.Errorf("unable to send registration request: %w", err)
117 | }
118 | defer func(r *http.Response) { _ = r.Body.Close() }(resp)
119 |
120 | data, err = io.ReadAll(resp.Body)
121 | if err != nil && !errors.Is(err, io.EOF) {
122 | return t, fmt.Errorf("unable to read registration response body: %w", err)
123 | }
124 |
125 | if resp.StatusCode != http.StatusCreated {
126 | return t, fmt.Errorf("invalid registration response status code: %s (%s)", resp.Status, string(data))
127 | }
128 |
129 | oresp := RegistrationResponse{}
130 | if err := json.Unmarshal(data, &oresp); err != nil {
131 | return t, fmt.Errorf("unable to unmarshal registration response body: %w", err)
132 | }
133 |
134 | clientID = oresp.ClientID
135 | }
136 |
137 | values := url.Values{
138 | "response_type": {"code"},
139 | "client_id": {clientID},
140 | "redirect_uri": {redirectURI},
141 | "state": {state},
142 | }
143 |
144 | u := fmt.Sprintf("%s/authorize?%s", inf.Server, values.Encode())
145 | if err := browser.OpenURL(u); err != nil {
146 | fmt.Println("Open the following URL in your browser:", u)
147 | }
148 |
149 | codeCh := make(chan string, 1)
150 |
151 | server := &http.Server{
152 | ReadHeaderTimeout: 3 * time.Second,
153 | Addr: "127.0.0.1:9977",
154 | }
155 |
156 | server.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
157 |
158 | codeCh <- req.URL.Query().Get("code")
159 |
160 | w.WriteHeader(http.StatusOK)
161 | _, _ = w.Write(fmt.Appendf([]byte{}, successBody, inf.Server))
162 |
163 | sctx, cancel := context.WithTimeout(ctx, 1*time.Second)
164 | defer cancel()
165 |
166 | _ = server.Shutdown(sctx)
167 | })
168 |
169 | go func() {
170 | if err := server.ListenAndServe(); err != nil {
171 | if !errors.Is(err, http.ErrServerClosed) {
172 | slog.Error("Unable to start oauth callback server", err)
173 | return
174 | }
175 | }
176 | }()
177 |
178 | var code string
179 | select {
180 | case code = <-codeCh:
181 | case <-ctx.Done():
182 | return t, ctx.Err()
183 | case <-time.After(10 * time.Minute):
184 | return t, fmt.Errorf("oauth timeout")
185 | }
186 |
187 | u = fmt.Sprintf("%s/oauth2/token", backendURL)
188 |
189 | form := url.Values{
190 | "grant_type": {"authorization_code"},
191 | "client_id": {clientID},
192 | "code": {code},
193 | "redirect_uri": {redirectURI},
194 | }
195 |
196 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, bytes.NewBuffer([]byte(form.Encode())))
197 | if err != nil {
198 | return t, fmt.Errorf("unable to make token request: %w", err)
199 | }
200 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
201 | req.Header.Set("Accept", "application/json")
202 |
203 | resp, err := cl.Do(req)
204 | if err != nil {
205 | return t, fmt.Errorf("unable to send token request: %w", err)
206 | }
207 | defer func(r *http.Response) { _ = r.Body.Close() }(resp)
208 |
209 | data, err := io.ReadAll(resp.Body)
210 | if err != nil && !errors.Is(err, io.EOF) {
211 | return t, fmt.Errorf("unable to read token response body: %w", err)
212 | }
213 |
214 | if resp.StatusCode != http.StatusOK {
215 | return t, fmt.Errorf("invalid token response status code: %s (%s)", resp.Status, string(data))
216 | }
217 |
218 | if err := json.Unmarshal(data, &t); err != nil {
219 | return t, fmt.Errorf("unable to unmarshal token response body: %w", err)
220 | }
221 |
222 | t.ClientID = clientID
223 |
224 | return t, nil
225 | }
226 |
--------------------------------------------------------------------------------
/pkgs/backend/client/stream_test.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "testing"
7 | "time"
8 |
9 | . "github.com/smartystreets/goconvey/convey"
10 | "go.acuvity.ai/minibridge/pkgs/mcp"
11 | )
12 |
13 | func TestMCPStream(t *testing.T) {
14 |
15 | Convey("I have a running stream, registrations should work", t, func() {
16 | ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
17 | defer cancel()
18 |
19 | stream := NewMCPStream(ctx)
20 |
21 | stdout1, closeStdout1 := stream.Stdout()
22 | defer closeStdout1()
23 | stdout2, closeStdout2 := stream.Stdout()
24 | defer closeStdout2()
25 |
26 | stderr1, closeStderr1 := stream.Stderr()
27 | defer closeStderr1()
28 | stderr2, closeStderr2 := stream.Stderr()
29 | defer closeStderr2()
30 |
31 | exit1, closeExit1 := stream.Exit()
32 | defer closeExit1()
33 | exit2, closeExit2 := stream.Exit()
34 | defer closeExit2()
35 |
36 | cstdout1 := make(chan []byte)
37 | go func() { cstdout1 <- <-stdout1 }()
38 | cstdout2 := make(chan []byte)
39 | go func() { cstdout2 <- <-stdout2 }()
40 | cstderr1 := make(chan []byte)
41 | go func() { cstderr1 <- <-stderr1 }()
42 | cstderr2 := make(chan []byte)
43 | go func() { cstderr2 <- <-stderr2 }()
44 | cexit1 := make(chan error)
45 | go func() { cexit1 <- <-exit1 }()
46 | cexit2 := make(chan error)
47 | go func() { cexit2 <- <-exit2 }()
48 |
49 | go func() {
50 | stream.stdout <- []byte("hello stdout")
51 | stream.stderr <- []byte("hello stderr")
52 | stream.exit <- fmt.Errorf("hello from error")
53 | }()
54 |
55 | So(string(<-cstdout1), ShouldEqual, "hello stdout")
56 | So(string(<-cstdout2), ShouldEqual, "hello stdout")
57 | So(string(<-cstderr1), ShouldEqual, "hello stderr")
58 | So(string(<-cstderr2), ShouldEqual, "hello stderr")
59 | So((<-cexit1).Error(), ShouldEqual, "hello from error")
60 | So((<-cexit2).Error(), ShouldEqual, "hello from error")
61 |
62 | // let's unregister all 1s
63 | // this will test that all channel will
64 | // correctly unregister and will not block
65 | // the rest.
66 | closeStdout1()
67 | closeStderr1()
68 | closeExit1()
69 |
70 | go func() { cstdout2 <- <-stdout2 }()
71 | go func() { cstderr2 <- <-stderr2 }()
72 | go func() { cexit2 <- <-exit2 }()
73 |
74 | go func() {
75 | stream.stdout <- []byte("hello stdout 2")
76 | stream.stderr <- []byte("hello stderr 2")
77 | stream.exit <- fmt.Errorf("hello from error 2")
78 | }()
79 |
80 | So(string(<-cstdout2), ShouldEqual, "hello stdout 2")
81 | So(string(<-cstderr2), ShouldEqual, "hello stderr 2")
82 | So((<-cexit2).Error(), ShouldEqual, "hello from error 2")
83 | })
84 |
85 | Convey("SendRequest should work", t, func() {
86 |
87 | stream := NewMCPStream(t.Context())
88 |
89 | ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
90 | defer cancel()
91 |
92 | done := make(chan bool, 1)
93 | cstdin := make(chan []byte, 2)
94 | go func() {
95 | cstdin <- <-stream.stdin
96 | stream.stdout <- []byte(`{"id":"not-id"}`)
97 | stream.stdout <- []byte(`{"id":"not-id-again-3"}`)
98 | stream.stdout <- []byte(`{"id":"id","result":{"a":1}}`)
99 | stream.stdout <- []byte(`{"id":"not-id-again-4"}`)
100 | done <- true
101 | }()
102 |
103 | resp, err := stream.SendRequest(ctx, mcp.NewMessage("id"))
104 |
105 | So(string(<-cstdin), ShouldResemble, `{"id":"id","jsonrpc":"2.0"}`)
106 |
107 | So(err, ShouldBeNil)
108 | So(resp, ShouldNotBeNil)
109 | So(resp.ID, ShouldEqual, "id")
110 | So(resp.Result["a"], ShouldEqual, 1)
111 | So(<-done, ShouldBeTrue)
112 | })
113 |
114 | Convey("SendRequest should work when context cancels before sending data", t, func() {
115 |
116 | stream := NewMCPStream(t.Context())
117 |
118 | ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
119 | cancel()
120 |
121 | _, err := stream.SendRequest(ctx, mcp.NewMessage("id"))
122 | So(err, ShouldNotBeNil)
123 | So(err.Error(), ShouldEqual, "unable to send request: context canceled")
124 | })
125 |
126 | Convey("SendRequest should work when context cancels while awaiting for response", t, func() {
127 |
128 | stream := NewMCPStream(t.Context())
129 |
130 | ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
131 | defer cancel()
132 |
133 | done := make(chan bool, 1)
134 | cstdin := make(chan []byte, 1)
135 | go func() {
136 | cstdin <- <-stream.stdin
137 | stream.stdout <- []byte(`{"id":"not-id"}`)
138 | stream.stdout <- []byte(`{"id":"not-id-again-3"}`)
139 | stream.stdout <- []byte(`{"id":"not-id-again-4"}`)
140 | cancel()
141 | done <- true
142 | }()
143 |
144 | _, err := stream.SendRequest(ctx, mcp.NewMessage("id"))
145 |
146 | So(string(<-cstdin), ShouldResemble, `{"id":"id","jsonrpc":"2.0"}`)
147 | So(err, ShouldNotBeNil)
148 | So(err.Error(), ShouldEqual, "unable to get response: context canceled")
149 | })
150 |
151 | Convey("SendRequest should handle invalid json response for summary", t, func() {
152 |
153 | stream := NewMCPStream(t.Context())
154 |
155 | ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
156 | defer cancel()
157 |
158 | go func() {
159 | <-stream.stdin
160 | stream.stdout <- []byte(`"id":"not-id"}`)
161 | }()
162 |
163 | _, err := stream.SendRequest(ctx, mcp.NewMessage("id"))
164 |
165 | So(err, ShouldNotBeNil)
166 | So(err.Error(), ShouldStartWith, "unable to decode mcp call as summary: ")
167 | })
168 |
169 | Convey("SendRequest should handle invalid json response for mcp call", t, func() {
170 |
171 | stream := NewMCPStream(t.Context())
172 |
173 | ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second)
174 | defer cancel()
175 |
176 | go func() {
177 | <-stream.stdin
178 | stream.stdout <- []byte(`{"id":"id", "result": "not a map"}`)
179 | }()
180 |
181 | _, err := stream.SendRequest(ctx, mcp.NewMessage("id"))
182 |
183 | So(err, ShouldNotBeNil)
184 | So(err.Error(), ShouldStartWith, "unable to decode mcp call: ")
185 | })
186 |
187 | Convey("calling SendPaginatedRequest should work", t, func() {
188 |
189 | stream := NewMCPStream(t.Context())
190 |
191 | ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
192 | defer cancel()
193 |
194 | // enqueue the resp immediately
195 | cstdin := make(chan []byte, 3)
196 | go func() {
197 | cstdin <- <-stream.stdin
198 | stream.stdout <- []byte(`{"id":"a","jsonrpc":"2.0","result":{"nextCursor":"1"}}`)
199 | cstdin <- <-stream.stdin
200 | stream.stdout <- []byte(`{"id":"a","jsonrpc":"2.0"}`)
201 | }()
202 |
203 | calls, err := stream.SendPaginatedRequest(ctx, mcp.NewMessage("a"))
204 | So(err, ShouldBeNil)
205 | So(len(calls), ShouldEqual, 2)
206 |
207 | So(string(<-cstdin), ShouldEqual, `{"id":"a","jsonrpc":"2.0"}`)
208 | So(string(<-cstdin), ShouldEqual, `{"id":"a","jsonrpc":"2.0","params":{"cursor":"1"}}`)
209 | })
210 |
211 | Convey("calling SendPaginatedRequest while context cancels should work", t, func() {
212 |
213 | ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
214 | defer cancel()
215 |
216 | stream := NewMCPStream(ctx)
217 |
218 | sctx, cancel := context.WithCancel(t.Context())
219 | cancel()
220 |
221 | // enqueue the resp immediately
222 | go func() { stream.stdout <- []byte(`{"id":46,"jsonrpc":"2.0","result":{"nextCursor":"1"}}`) }()
223 |
224 | calls, err := stream.SendPaginatedRequest(sctx, mcp.NewMessage(44))
225 | So(err, ShouldNotBeNil)
226 | So(err.Error(), ShouldEqual, "unable to send paginated request: unable to send request: context canceled")
227 | So(len(calls), ShouldEqual, 0)
228 | })
229 | }
230 |
--------------------------------------------------------------------------------
/pkgs/backend/client/stream.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log/slog"
7 | "sync"
8 | "time"
9 |
10 | "go.acuvity.ai/elemental"
11 | "go.acuvity.ai/minibridge/pkgs/mcp"
12 | )
13 |
// MCPStream exposes the stdio streams of an MCP server as channels.
//
// It only moves []byte values containing MCP messages around. The data
// is never validated, neither on the way in nor on the way out.
//
// Input is written to the chan []byte returned by Stdin().
//
// Output is read by calling Stdout(), Stderr() or Exit(). Each call
// creates a new buffered subscriber channel and registers it in a pool;
// every registered subscriber is offered each item read from the
// corresponding stream.
//
// These functions also return a func() that must be called to remove
// the subscriber from the pool once it is not needed anymore.
//
// Subscribers must consume quickly: delivery to a subscriber never blocks
// the listener, so when a subscriber's buffer is full the item is dropped
// for that subscriber.
type MCPStream struct {
	// RWMutex guards the three subscriber sets below.
	sync.RWMutex

	// Subscriber sets, keyed by the channel handed out to the subscriber.
	outChs  map[chan []byte]struct{}
	errChs  map[chan []byte]struct{}
	exitChs map[chan error]struct{}

	// Raw streams. stdin receives the messages to send; stdout, stderr
	// and exit are read by the background listener and fanned out to
	// the subscriber sets above.
	stdin  chan []byte
	stdout chan []byte
	stderr chan []byte
	exit   chan error
}
48 |
49 | // NewMCPStream returns an initialized *MCPStream.
50 | // It will start a listener in the background that will run
51 | // ultimately up until the provided context cancels.
52 | func NewMCPStream(ctx context.Context) *MCPStream {
53 |
54 | s := &MCPStream{
55 | stderr: make(chan []byte),
56 | stdout: make(chan []byte),
57 | stdin: make(chan []byte),
58 | exit: make(chan error),
59 | outChs: map[chan []byte]struct{}{},
60 | errChs: map[chan []byte]struct{}{},
61 | exitChs: map[chan error]struct{}{},
62 | }
63 |
64 | s.start(ctx)
65 |
66 | return s
67 | }
68 |
// Stdin returns the channel that accepts []byte
// containing MCP messages from the client.
// The same underlying channel is returned on every call.
func (s *MCPStream) Stdin() chan []byte {
	return s.stdin
}
74 |
75 | // Stdout returns a channel that will produce []byte
76 | // containing MCP messages from the MCP server.
77 | // It also returns a function that must be called
78 | // when the channel is not needed anymore.
79 | // Failure to do so will leak go routines.
80 | func (s *MCPStream) Stdout() (chan []byte, func()) {
81 | c := make(chan []byte, 8)
82 | s.registerOut(c)
83 | return c, func() { s.unregisterOut(c) }
84 | }
85 |
86 | // Stderr returns a channel that will produce []byte
87 | // containing MCP Server logs.
88 | // It also returns a function that must be called
89 | // when the channel is not needed anymore.
90 | // Failure to do so will leak go routines.
91 | func (s *MCPStream) Stderr() (chan []byte, func()) {
92 | c := make(chan []byte, 8)
93 | s.registerErr(c)
94 | return c, func() { s.unregisterErr(c) }
95 | }
96 |
97 | // Exit returns a channel that will produce an error
98 | // representing the end of the MCP server execution.
99 | // Once a message is received from this channel,
100 | // The MCPStream should be considered dead.
101 | // It also returns a function that must be called
102 | // when the channel is not needed anymore.
103 | // Failure to do so will leak go routines.
104 | func (s *MCPStream) Exit() (chan error, func()) {
105 | c := make(chan error, 1)
106 | s.registerExit(c)
107 | return c, func() { s.unregisterExit(c) }
108 | }
109 |
110 | // SendNotification sends a mcp.Message without waiting for a reply.
111 | func (s *MCPStream) SendNotification(ctx context.Context, notif mcp.Notification) error {
112 |
113 | data, err := elemental.Encode(elemental.EncodingTypeJSON, notif)
114 | if err != nil {
115 | return fmt.Errorf("unable to encode mcp notification: %w", err)
116 | }
117 |
118 | select {
119 | case s.stdin <- data:
120 | case <-ctx.Done():
121 | return fmt.Errorf("unable to send mcp notification: %w", ctx.Err())
122 | }
123 |
124 | return nil
125 | }
126 |
127 | // SendRequest sends the given MCP request and waits for a MCP response related to the
128 | // request ID, up until the provider context expires.
129 | // The request is not validated.
130 | func (s *MCPStream) SendRequest(ctx context.Context, req mcp.Message) (resp mcp.Message, err error) {
131 |
132 | data, err := elemental.Encode(elemental.EncodingTypeJSON, req)
133 | if err != nil {
134 | return resp, fmt.Errorf("unable to encode mcp call: %w", err)
135 | }
136 |
137 | stdout, unregister := s.Stdout()
138 | defer unregister()
139 |
140 | select {
141 | case s.stdin <- data:
142 | case <-ctx.Done():
143 | return resp, fmt.Errorf("unable to send request: %w", ctx.Err())
144 | }
145 |
146 | summary := struct {
147 | ID any `json:"id"`
148 | }{}
149 |
150 | for {
151 | select {
152 |
153 | case <-ctx.Done():
154 | return resp, fmt.Errorf("unable to get response: %w", ctx.Err())
155 |
156 | case data := <-stdout:
157 |
158 | if err := elemental.Decode(elemental.EncodingTypeJSON, data, &summary); err != nil {
159 | return req, fmt.Errorf("unable to decode mcp call as summary: %w", err)
160 | }
161 |
162 | if !mcp.RelatedIDs(req.ID, summary.ID) {
163 | continue
164 | }
165 |
166 | if err := elemental.Decode(elemental.EncodingTypeJSON, data, &resp); err != nil {
167 | return req, fmt.Errorf("unable to decode mcp call: %w", err)
168 | }
169 |
170 | return resp, nil
171 | }
172 | }
173 | }
174 |
175 | // SendPaginatedRequest works like SendRequest, but will retrieve the next pages until
176 | // it reaches the end. All responses are returned at once in an slice of mcp.Message.
177 | func (s *MCPStream) SendPaginatedRequest(ctx context.Context, msg mcp.Message) (out []mcp.Message, err error) {
178 |
179 | var resp mcp.Message
180 |
181 | currentRequest := msg
182 |
183 | for {
184 |
185 | resp, err = s.SendRequest(ctx, currentRequest)
186 | if err != nil {
187 | return nil, fmt.Errorf("unable to send paginated request: %w", err)
188 | }
189 |
190 | out = append(out, resp)
191 |
192 | cursor, ok := resp.Result["nextCursor"].(string)
193 | if !ok || cursor == "" {
194 | return out, nil
195 | }
196 |
197 | currentRequest = mcp.NewMessage("")
198 | currentRequest.ID = msg.ID
199 | currentRequest.Method = msg.Method
200 | currentRequest.Params = map[string]any{"cursor": cursor}
201 | }
202 | }
203 |
204 | func (s *MCPStream) registerOut(ch chan []byte) { s.Lock(); s.outChs[ch] = struct{}{}; s.Unlock() }
205 | func (s *MCPStream) unregisterOut(ch chan []byte) { s.Lock(); delete(s.outChs, ch); s.Unlock() }
206 |
207 | func (s *MCPStream) registerErr(ch chan []byte) { s.Lock(); s.errChs[ch] = struct{}{}; s.Unlock() }
208 | func (s *MCPStream) unregisterErr(ch chan []byte) { s.Lock(); delete(s.errChs, ch); s.Unlock() }
209 |
210 | func (s *MCPStream) registerExit(ch chan error) { s.Lock(); s.exitChs[ch] = struct{}{}; s.Unlock() }
211 | func (s *MCPStream) unregisterExit(ch chan error) { s.Lock(); delete(s.exitChs, ch); s.Unlock() }
212 |
213 | func (s *MCPStream) start(ctx context.Context) {
214 |
215 | go func() {
216 |
217 | for {
218 | select {
219 |
220 | case data := <-s.stdout:
221 |
222 | s.RLock()
223 | for c := range s.outChs {
224 | select {
225 | case c <- data:
226 | default:
227 | slog.Error("stdout message dropped for registered channel")
228 | }
229 | }
230 | s.RUnlock()
231 |
232 | case data := <-s.stderr:
233 |
234 | s.RLock()
235 | for c := range s.errChs {
236 | select {
237 | case c <- data:
238 | default:
239 | slog.Error("stderr message dropped for registered channel")
240 | }
241 | }
242 | s.RUnlock()
243 |
244 | case err := <-s.exit:
245 |
246 | s.RLock()
247 | for c := range s.exitChs {
248 | select {
249 | case c <- err:
250 | default:
251 | slog.Error("exit message dropped for registered channel")
252 | }
253 | }
254 | s.RUnlock()
255 |
256 | case <-ctx.Done():
257 |
258 | var err error
259 | select {
260 | case err = <-s.exit:
261 | case <-time.After(time.Second):
262 | break
263 | }
264 |
265 | if err == nil {
266 | err = ctx.Err()
267 | }
268 |
269 | s.RLock()
270 | for c := range s.exitChs {
271 | select {
272 | case c <- err:
273 | default:
274 | }
275 | }
276 | s.RUnlock()
277 | }
278 | }
279 | }()
280 | }
281 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module go.acuvity.ai/minibridge
2 |
3 | go 1.24.2
4 |
5 | require (
6 | go.acuvity.ai/a3s v0.0.0-20250513140342-5386bdb1c75f
7 | go.acuvity.ai/bahamut v0.0.0-20250416135203-fa73110b2604
8 | go.acuvity.ai/elemental v0.0.0-20250430230636-ac931152934a
9 | go.acuvity.ai/manipulate v0.0.0-20250416135246-f7d22e975de4 // indirect
10 | go.acuvity.ai/regolithe v0.0.0-20250321141528-1fe83b60f317 // indirect
11 | go.acuvity.ai/tg v0.0.0-20250220234315-d9494083aa3a
12 | go.acuvity.ai/wsc v0.0.0-20250506232542-8de7ff436ec0
13 | )
14 |
15 | require (
16 | github.com/adrg/xdg v0.5.3
17 | github.com/go-viper/mapstructure/v2 v2.4.0
18 | github.com/gofrs/uuid v4.4.0+incompatible
19 | github.com/gorilla/websocket v1.5.3
20 | github.com/karlseguin/ccache/v3 v3.0.6
21 | github.com/mitchellh/mapstructure v1.5.0
22 | github.com/open-policy-agent/opa v1.4.0
23 | github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
24 | github.com/prometheus/client_golang v1.22.0
25 | github.com/smallnest/ringbuffer v0.0.0-20250317021400-0da97b586904
26 | github.com/smartystreets/goconvey v1.8.1
27 | github.com/spaolacci/murmur3 v1.1.0
28 | github.com/spf13/cobra v1.9.1
29 | github.com/spf13/pflag v1.0.6
30 | github.com/spf13/viper v1.20.1
31 | github.com/zalando/go-keyring v0.2.6
32 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0
33 | go.opentelemetry.io/otel v1.35.0
34 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.35.0
35 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.35.0
36 | go.opentelemetry.io/otel/sdk v1.35.0
37 | go.opentelemetry.io/otel/trace v1.35.0
38 | golang.org/x/sync v0.13.0
39 | )
40 |
41 | require (
42 | al.essio.dev/pkg/shellescape v1.5.1 // indirect
43 | cloud.google.com/go/compute/metadata v0.6.0 // indirect
44 | dario.cat/mergo v1.0.1 // indirect
45 | github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect
46 | github.com/Microsoft/go-winio v0.6.2 // indirect
47 | github.com/NYTimes/gziphandler v1.1.1 // indirect
48 | github.com/ProtonMail/go-crypto v1.1.5 // indirect
49 | github.com/agnivade/levenshtein v1.2.1 // indirect
50 | github.com/apparentlymart/go-cidr v1.1.0 // indirect
51 | github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de // indirect
52 | github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a // indirect
53 | github.com/beorn7/perks v1.0.1 // indirect
54 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect
55 | github.com/cespare/xxhash/v2 v2.3.0 // indirect
56 | github.com/cloudflare/circl v1.6.1 // indirect
57 | github.com/cyphar/filepath-securejoin v0.3.6 // indirect
58 | github.com/danieljoos/wincred v1.2.2 // indirect
59 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
60 | github.com/deckarep/golang-set v1.8.0 // indirect
61 | github.com/emirpasic/gods v1.18.1 // indirect
62 | github.com/fatih/color v1.18.0 // indirect
63 | github.com/fatih/structs v1.1.0 // indirect
64 | github.com/felixge/httpsnoop v1.0.4 // indirect
65 | github.com/fsnotify/fsnotify v1.8.0 // indirect
66 | github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 // indirect
67 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
68 | github.com/go-git/go-billy/v5 v5.6.1 // indirect
69 | github.com/go-git/go-git/v5 v5.13.1 // indirect
70 | github.com/go-ini/ini v1.67.0 // indirect
71 | github.com/go-logr/logr v1.4.2 // indirect
72 | github.com/go-logr/stdr v1.2.2 // indirect
73 | github.com/go-ole/go-ole v1.2.6 // indirect
74 | github.com/go-zoo/bone v1.3.0 // indirect
75 | github.com/gobwas/glob v0.2.3 // indirect
76 | github.com/godbus/dbus/v5 v5.1.0 // indirect
77 | github.com/golang-jwt/jwt/v5 v5.2.2 // indirect
78 | github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
79 | github.com/google/go-tpm v0.9.3 // indirect
80 | github.com/google/uuid v1.6.0 // indirect
81 | github.com/gopherjs/gopherjs v1.17.2 // indirect
82 | github.com/gorilla/mux v1.8.1 // indirect
83 | github.com/gravitational/trace v1.5.0 // indirect
84 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect
85 | github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f // indirect
86 | github.com/inconshreveable/mousetrap v1.1.0 // indirect
87 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
88 | github.com/jtolds/gls v4.20.0+incompatible // indirect
89 | github.com/karlseguin/ccache/v2 v2.0.8 // indirect
90 | github.com/kevinburke/ssh_config v1.2.0 // indirect
91 | github.com/klauspost/compress v1.18.0 // indirect
92 | github.com/lmittmann/tint v1.0.3 // indirect
93 | github.com/lufia/plan9stats v0.0.0-20230110061619-bbe2e5e100de // indirect
94 | github.com/mailgun/multibuf v0.2.0 // indirect
95 | github.com/mattn/go-colorable v0.1.13 // indirect
96 | github.com/mattn/go-isatty v0.0.20 // indirect
97 | github.com/mdp/qrterminal v1.0.1 // indirect
98 | github.com/minio/highwayhash v1.0.3 // indirect
99 | github.com/mitchellh/copystructure v1.2.0 // indirect
100 | github.com/mitchellh/go-wordwrap v1.0.1 // indirect
101 | github.com/mitchellh/reflectwalk v1.0.2 // indirect
102 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
103 | github.com/nats-io/jwt/v2 v2.7.3 // indirect
104 | github.com/nats-io/nats-server/v2 v2.11.1 // indirect
105 | github.com/nats-io/nats.go v1.39.1 // indirect
106 | github.com/nats-io/nkeys v0.4.10 // indirect
107 | github.com/nats-io/nuid v1.0.1 // indirect
108 | github.com/opentracing/opentracing-go v1.2.0 // indirect
109 | github.com/pelletier/go-toml/v2 v2.2.3 // indirect
110 | github.com/pjbgf/sha1cd v0.3.2 // indirect
111 | github.com/pkg/errors v0.9.1 // indirect
112 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
113 | github.com/power-devops/perfstat v0.0.0-20221212215047-62379fc7944b // indirect
114 | github.com/prometheus/client_model v0.6.1 // indirect
115 | github.com/prometheus/common v0.62.0 // indirect
116 | github.com/prometheus/procfs v0.15.1 // indirect
117 | github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect
118 | github.com/sagikazarmark/locafero v0.7.0 // indirect
119 | github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
120 | github.com/shirou/gopsutil/v3 v3.24.5 // indirect
121 | github.com/shoenig/go-m1cpu v0.1.6 // indirect
122 | github.com/sirupsen/logrus v1.9.3 // indirect
123 | github.com/skeema/knownhosts v1.3.0 // indirect
124 | github.com/smarty/assertions v1.15.0 // indirect
125 | github.com/sourcegraph/conc v0.3.0 // indirect
126 | github.com/spf13/afero v1.12.0 // indirect
127 | github.com/spf13/cast v1.7.1 // indirect
128 | github.com/subosito/gotenv v1.6.0 // indirect
129 | github.com/tchap/go-patricia/v2 v2.3.2 // indirect
130 | github.com/tklauser/go-sysconf v0.3.12 // indirect
131 | github.com/tklauser/numcpus v0.6.1 // indirect
132 | github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect
133 | github.com/uber/jaeger-lib v2.4.1+incompatible // indirect
134 | github.com/ugorji/go/codec v1.2.12 // indirect
135 | github.com/valyala/tcplisten v1.0.0 // indirect
136 | github.com/vulcand/oxy/v2 v2.0.2 // indirect
137 | github.com/vulcand/predicate v1.2.0 // indirect
138 | github.com/xanzy/ssh-agent v0.3.3 // indirect
139 | github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
140 | github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
141 | github.com/xeipuuv/gojsonschema v1.2.0 // indirect
142 | github.com/yashtewari/glob-intersection v0.2.0 // indirect
143 | github.com/yusufpapurcu/wmi v1.2.4 // indirect
144 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect
145 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.35.0 // indirect
146 | go.opentelemetry.io/otel/metric v1.35.0 // indirect
147 | go.opentelemetry.io/proto/otlp v1.5.0 // indirect
148 | go.uber.org/atomic v1.11.0 // indirect
149 | go.uber.org/automaxprocs v1.6.0 // indirect
150 | go.uber.org/multierr v1.11.0 // indirect
151 | golang.org/x/crypto v0.37.0 // indirect
152 | golang.org/x/net v0.39.0 // indirect
153 | golang.org/x/sys v0.32.0 // indirect
154 | golang.org/x/text v0.24.0 // indirect
155 | golang.org/x/time v0.11.0 // indirect
156 | google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect
157 | google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect
158 | google.golang.org/grpc v1.71.1 // indirect
159 | google.golang.org/protobuf v1.36.6 // indirect
160 | gopkg.in/ini.v1 v1.67.0 // indirect
161 | gopkg.in/warnings.v0 v0.1.2 // indirect
162 | gopkg.in/yaml.v2 v2.4.0 // indirect
163 | gopkg.in/yaml.v3 v3.0.1 // indirect
164 | rsc.io/qr v0.2.0 // indirect
165 | sigs.k8s.io/yaml v1.4.0 // indirect
166 | )
167 |
--------------------------------------------------------------------------------
/pkgs/scan/sbom_test.go:
--------------------------------------------------------------------------------
1 | package scan
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestSBOM_Matches(t *testing.T) {
8 | type args struct {
9 | o Hashes
10 | }
11 | tests := []struct {
12 | name string
13 | init func(t *testing.T) Hashes
14 | inspect func(r Hashes, t *testing.T) //inspects receiver after test run
15 |
16 | args func(t *testing.T) args
17 |
18 | wantErr bool
19 | inspectErr func(err error, t *testing.T) //use for more precise error evaluation after test
20 | }{
21 | {
22 | "matching",
23 | func(t *testing.T) Hashes {
24 | return Hashes{
25 | {
26 | Name: "a1",
27 | Hash: "ah1",
28 | Params: Hashes{
29 | {
30 | Name: "p1",
31 | Hash: "ph1",
32 | },
33 | },
34 | },
35 | }
36 | },
37 | nil,
38 | func(*testing.T) args {
39 | return args{
40 | Hashes{
41 | {
42 | Name: "a1",
43 | Hash: "ah1",
44 | Params: Hashes{
45 | {
46 | Name: "p1",
47 | Hash: "ph1",
48 | },
49 | },
50 | },
51 | },
52 | }
53 | },
54 | false,
55 | nil,
56 | },
57 | {
58 | "no params",
59 | func(t *testing.T) Hashes {
60 | return Hashes{
61 | {
62 | Name: "a1",
63 | Hash: "ah1",
64 | },
65 | }
66 | },
67 | nil,
68 | func(*testing.T) args {
69 | return args{
70 | Hashes{
71 | {
72 | Name: "a1",
73 | Hash: "ah1",
74 | },
75 | },
76 | }
77 | },
78 | false,
79 | nil,
80 | },
81 | {
82 | "empty",
83 | func(t *testing.T) Hashes {
84 | return Hashes{}
85 | },
86 | nil,
87 | func(*testing.T) args {
88 | return args{
89 | Hashes{},
90 | }
91 | },
92 | false,
93 | nil,
94 | },
95 | {
96 | "missing param",
97 | func(t *testing.T) Hashes {
98 | return Hashes{
99 | {
100 | Name: "a1",
101 | Hash: "ah1",
102 | Params: Hashes{
103 | {
104 | Name: "p1",
105 | Hash: "ph1",
106 | },
107 | {
108 | Name: "p2",
109 | Hash: "ph2",
110 | },
111 | },
112 | },
113 | }
114 | },
115 | nil,
116 | func(*testing.T) args {
117 | return args{
118 | Hashes{
119 | {
120 | Name: "a1",
121 | Hash: "ah1",
122 | Params: Hashes{
123 | {
124 | Name: "p1",
125 | Hash: "ph1",
126 | },
127 | },
128 | },
129 | },
130 | }
131 | },
132 | false,
133 | nil,
134 | },
135 | {
136 | "extra param",
137 | func(t *testing.T) Hashes {
138 | return Hashes{
139 | {
140 | Name: "a1",
141 | Hash: "ah1",
142 | Params: Hashes{
143 | {
144 | Name: "p1",
145 | Hash: "ph1",
146 | },
147 | },
148 | },
149 | }
150 | },
151 | nil,
152 | func(*testing.T) args {
153 | return args{
154 | Hashes{
155 | {
156 | Name: "a1",
157 | Hash: "ah1",
158 | Params: Hashes{
159 | {
160 | Name: "p1",
161 | Hash: "ph1",
162 | },
163 | {
164 | Name: "p2",
165 | Hash: "ph2",
166 | },
167 | },
168 | },
169 | },
170 | }
171 | },
172 | true,
173 | func(err error, t *testing.T) {
174 | want := "'a1': invalid param: invalid len. left: 1 right: 2"
175 | if err.Error() != want {
176 | t.Logf("invalid err. want: %s got: %s", want, err.Error())
177 | t.Fail()
178 | }
179 | },
180 | },
181 | {
182 | "missing tool",
183 | func(t *testing.T) Hashes {
184 | return Hashes{
185 | {
186 | Name: "a1",
187 | Hash: "ah1",
188 | Params: Hashes{
189 | {
190 | Name: "p1",
191 | Hash: "ph1",
192 | },
193 | },
194 | },
195 | {
196 | Name: "a2",
197 | Hash: "ah2",
198 | Params: Hashes{
199 | {
200 | Name: "p1",
201 | Hash: "ph1",
202 | },
203 | },
204 | },
205 | }
206 | },
207 | nil,
208 | func(*testing.T) args {
209 | return args{
210 | Hashes{
211 | {
212 | Name: "a1",
213 | Hash: "ah1",
214 | Params: Hashes{
215 | {
216 | Name: "p1",
217 | Hash: "ph1",
218 | },
219 | },
220 | },
221 | },
222 | }
223 | },
224 | false,
225 | nil,
226 | },
227 | {
228 | "extra tool",
229 | func(t *testing.T) Hashes {
230 | return Hashes{
231 | {
232 | Name: "a1",
233 | Hash: "ah1",
234 | Params: Hashes{
235 | {
236 | Name: "p1",
237 | Hash: "ph1",
238 | },
239 | },
240 | },
241 | }
242 | },
243 | nil,
244 | func(*testing.T) args {
245 | return args{
246 | Hashes{
247 | {
248 | Name: "a1",
249 | Hash: "ah1",
250 | Params: Hashes{
251 | {
252 | Name: "p1",
253 | Hash: "ph1",
254 | },
255 | },
256 | },
257 | {
258 | Name: "a2",
259 | Hash: "ah2",
260 | Params: Hashes{
261 | {
262 | Name: "p1",
263 | Hash: "ph1",
264 | },
265 | },
266 | },
267 | },
268 | }
269 | },
270 | true,
271 | func(err error, t *testing.T) {
272 | want := "invalid len. left: 1 right: 2"
273 | if err.Error() != want {
274 | t.Logf("invalid err. want: %s got: %s", want, err.Error())
275 | t.Fail()
276 | }
277 | },
278 | },
279 | {
280 | "invalid tool hash",
281 | func(t *testing.T) Hashes {
282 | return Hashes{
283 | {
284 | Name: "a1",
285 | Hash: "ah1",
286 | Params: Hashes{
287 | {
288 | Name: "p1",
289 | Hash: "ph1",
290 | },
291 | },
292 | },
293 | }
294 | },
295 | nil,
296 | func(*testing.T) args {
297 | return args{
298 | Hashes{
299 | {
300 | Name: "a1",
301 | Hash: "NOT_ah1",
302 | Params: Hashes{
303 | {
304 | Name: "p1",
305 | Hash: "ph1",
306 | },
307 | },
308 | },
309 | },
310 | }
311 | },
312 | true,
313 | func(err error, t *testing.T) {
314 | want := "'a1': hash mismatch"
315 | if err.Error() != want {
316 | t.Logf("invalid err. want: %s got: %s", want, err.Error())
317 | t.Fail()
318 | }
319 | },
320 | },
321 | {
322 | "invalid param hash",
323 | func(t *testing.T) Hashes {
324 | return Hashes{
325 | {
326 | Name: "a1",
327 | Hash: "ah1",
328 | Params: Hashes{
329 | {
330 | Name: "p1",
331 | Hash: "ph1",
332 | },
333 | },
334 | },
335 | }
336 | },
337 | nil,
338 | func(*testing.T) args {
339 | return args{
340 | Hashes{
341 | {
342 | Name: "a1",
343 | Hash: "ah1",
344 | Params: Hashes{
345 | {
346 | Name: "p1",
347 | Hash: "NOT-ph1",
348 | },
349 | },
350 | },
351 | },
352 | }
353 | },
354 | true,
355 | func(err error, t *testing.T) {
356 | want := "'a1': invalid param: 'p1': hash mismatch"
357 | if err.Error() != want {
358 | t.Logf("invalid err. want: %s got: %s", want, err.Error())
359 | t.Fail()
360 | }
361 | },
362 | },
363 | {
364 | "same len, different tool",
365 | func(t *testing.T) Hashes {
366 | return Hashes{
367 | {
368 | Name: "a1",
369 | Hash: "ah1",
370 | Params: Hashes{
371 | {
372 | Name: "p1",
373 | Hash: "ph1",
374 | },
375 | },
376 | },
377 | }
378 | },
379 | nil,
380 | func(*testing.T) args {
381 | return args{
382 | Hashes{
383 | {
384 | Name: "b1",
385 | Hash: "bh1",
386 | Params: Hashes{
387 | {
388 | Name: "p1",
389 | Hash: "ph1",
390 | },
391 | },
392 | },
393 | },
394 | }
395 | },
396 | true,
397 | func(err error, t *testing.T) {
398 | want := "'b1': missing"
399 | if err.Error() != want {
400 | t.Logf("invalid err. want: %s got: %s", want, err.Error())
401 | t.Fail()
402 | }
403 | },
404 | },
405 | {
406 | "param name missing",
407 | func(t *testing.T) Hashes {
408 | return Hashes{
409 | {
410 | Name: "a1",
411 | Hash: "ah1",
412 | Params: Hashes{
413 | {
414 | Name: "p1",
415 | Hash: "ph1",
416 | },
417 | },
418 | },
419 | }
420 | },
421 | nil,
422 | func(*testing.T) args {
423 | return args{
424 | Hashes{
425 | {
426 | Name: "a1",
427 | Hash: "ah1",
428 | Params: Hashes{
429 | {
430 | Name: "q1",
431 | Hash: "qh1",
432 | },
433 | },
434 | },
435 | },
436 | }
437 | },
438 | true,
439 | func(err error, t *testing.T) {
440 | want := "'a1': invalid param: 'q1': missing"
441 | if err.Error() != want {
442 | t.Logf("invalid err. want: %s got: %s", want, err.Error())
443 | t.Fail()
444 | }
445 | },
446 | },
447 | }
448 |
449 | for _, tt := range tests {
450 | t.Run(tt.name, func(t *testing.T) {
451 | tArgs := tt.args(t)
452 |
453 | receiver := tt.init(t)
454 | err := receiver.Matches(tArgs.o)
455 |
456 | if tt.inspect != nil {
457 | tt.inspect(receiver, t)
458 | }
459 |
460 | if (err != nil) != tt.wantErr {
461 | t.Fatalf("SBOM.Matches error = %v, wantErr: %t", err, tt.wantErr)
462 | }
463 |
464 | if tt.inspectErr != nil {
465 | tt.inspectErr(err, t)
466 | }
467 | })
468 | }
469 | }
470 |
--------------------------------------------------------------------------------