├── bin
│   ├── hermit.hcl
│   ├── go
│   ├── gofmt
│   ├── .go-1.23.4.pkg
│   ├── README.hermit.md
│   ├── activate-hermit
│   └── hermit
├── testdata
│   ├── docstore.json
│   └── simple_store.json
├── renovate.json
├── .idea
│   ├── vcs.xml
│   ├── .gitignore
│   ├── modules.xml
│   └── gollum.iml
├── packages
│   ├── agents
│   │   ├── agent.go
│   │   ├── calcagent_test.go
│   │   └── calcagent.go
│   ├── tools
│   │   ├── tools.go
│   │   ├── calculator_test.go
│   │   ├── readme.md
│   │   └── calculator.go
│   ├── llm
│   │   ├── readme.md
│   │   ├── modelconfigs.go
│   │   ├── cache
│   │   │   └── cache.go
│   │   ├── llm.go
│   │   ├── providers
│   │   │   ├── cached
│   │   │   │   ├── cached_provider.go
│   │   │   │   ├── cached_embedder.go
│   │   │   │   ├── sqlitecache
│   │   │   │   │   └── sqlite.go
│   │   │   │   └── cached_provider_test.go
│   │   │   ├── voyage
│   │   │   │   └── voyage.go
│   │   │   ├── mixedbread
│   │   │   │   └── mixedbread.go
│   │   │   ├── anthropic
│   │   │   │   └── anthropic.go
│   │   │   ├── openai
│   │   │   │   └── openai.go
│   │   │   ├── vertex
│   │   │   │   └── vertex.go
│   │   │   └── google
│   │   │       └── google.go
│   │   ├── internal
│   │   │   └── mocks
│   │   │       └── llm.go
│   │   └── configs.go
│   ├── hyde
│   │   ├── README.md
│   │   ├── hyde_test.go
│   │   └── hyde.go
│   ├── syncpool
│   │   └── syncpool.go
│   ├── queryplanner
│   │   ├── queryplanner.go
│   │   └── queryplanner_test.go
│   ├── docstore
│   │   ├── docstore_test.go
│   │   └── docstore.go
│   ├── dispatch
│   │   ├── functions.go
│   │   ├── dispatch.go
│   │   ├── dispatch_test.go
│   │   └── functions_test.go
│   ├── jsonparser
│   │   ├── parser.go
│   │   └── parser_test.go
│   └── vectorstore
│       ├── vectorstore.go
│       ├── vectorstore_memory.go
│       ├── vectorstore.md
│       ├── vectorstore_compressed.go
│       ├── vectorstore_test.go
│       └── vectorstore_compressed_test.go
├── gollum.go
├── docs
│   └── best_practices.md
├── .gitignore
├── llm.go
├── LICENSE
├── internal
│   ├── hash
│   │   ├── hash_bench_test.go
│   │   └── readme.md
│   ├── testutil
│   │   └── utils.go
│   └── mocks
│       └── llm.go
├── go.mod
├── math_test.go
├── README.md
└── bench.md
/bin/hermit.hcl:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/bin/go:
--------------------------------------------------------------------------------
1 | .go-1.23.4.pkg
--------------------------------------------------------------------------------
/bin/gofmt:
--------------------------------------------------------------------------------
1 | .go-1.23.4.pkg
--------------------------------------------------------------------------------
/bin/.go-1.23.4.pkg:
--------------------------------------------------------------------------------
1 | hermit
--------------------------------------------------------------------------------
/testdata/docstore.json:
--------------------------------------------------------------------------------
1 | {"1":{"id":"1","content":"test data"},"2":{"id":"2","content":"test data 2"}}
--------------------------------------------------------------------------------
/renovate.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json",
3 | "extends": [
4 | "config:base"
5 | ]
6 | }
7 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/packages/agents/agent.go:
--------------------------------------------------------------------------------
1 | package agents
2 |
3 | import "context"
4 |
5 | type Agent interface {
6 | Name() string
7 | Description() string
8 | Run(context.Context, interface{}) (interface{}, error)
9 | }
10 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/testdata/simple_store.json:
--------------------------------------------------------------------------------
1 | [{"id":"02a4d737-121b-4845-9649-64781326fdfd","content":"Apple","embedding":[0.1,0.1,0.1]},{"id":"15d85177-1b7e-46fe-9489-77ef27f10ead","content":"Orange","embedding":[0.1,1.1,1.1]},{"id":"f5c42d8f-9c17-424d-934c-88192d4c336f","content":"Basketball","embedding":[0.1,2.1,2.1]}]
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/packages/tools/tools.go:
--------------------------------------------------------------------------------
1 | package tools
2 |
3 | import "context"
4 |
5 | // Tool is an incredibly generic interface for a tool that an LLM can use.
6 | type Tool interface {
7 | Name() string
8 | Description() string
9 | Run(ctx context.Context, input interface{}) (interface{}, error)
10 | }
11 |
--------------------------------------------------------------------------------
/bin/README.hermit.md:
--------------------------------------------------------------------------------
1 | # Hermit environment
2 |
3 | This is a [Hermit](https://github.com/cashapp/hermit) bin directory.
4 |
5 | The symlinks in this directory are managed by Hermit and will automatically
6 | download and install Hermit itself as well as packages. These packages are
7 | local to this environment.
8 |
--------------------------------------------------------------------------------
/.idea/gollum.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/packages/llm/readme.md:
--------------------------------------------------------------------------------
1 | # llm
2 |
3 | Useful abstractions for working with LLMs.
4 |
5 | Features:
6 |
7 | - synchronous and async wrappers (streaming output)
8 | - prompt caching for supported providers
9 | - automatically load supported providers from environment variables
10 |
11 | We support
12 |
13 | - Anthropic
14 | - Google Gemini
15 | - OpenAI
16 | - OpenAI compatible providers (Together, Groq, Hyperbolic, Deepseek, ...)
17 |
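18 | A minimal usage sketch, assuming a provider that implements the `Responder` interface from `llm.go` (the model name and provider wiring here are illustrative; see the packages under `providers/` for the concrete constructors):
19 | 
20 | ```go
21 | cfg := llm.ModelConfig{
22 | 	ProviderType: "openai",      // assumption: the name the OpenAI provider registers under
23 | 	ModelName:    "gpt-4o-mini", // illustrative model name
24 | 	ModelType:    llm.ModelTypeLLM,
25 | }
26 | req := llm.InferRequest{
27 | 	Messages:       []llm.InferMessage{{Role: "user", Content: "Say hello."}},
28 | 	ModelConfig:    cfg,
29 | 	MessageOptions: llm.MessageOptions{MaxTokens: 128, Temperature: 0.2},
30 | }
31 | 
32 | // responder is any llm.Responder, e.g. one built from the providers/openai package.
33 | out, err := responder.GenerateResponse(ctx, req)
34 | 
35 | // Streaming variant: read StreamDelta values until the channel closes.
36 | stream, err := responder.GenerateResponseAsync(ctx, req)
37 | for delta := range stream {
38 | 	fmt.Print(delta.Text)
39 | }
40 | ```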
--------------------------------------------------------------------------------
/gollum.go:
--------------------------------------------------------------------------------
1 | package gollum
2 |
3 | import (
4 | "github.com/google/uuid"
5 | )
6 |
7 | type Document struct {
8 | ID string `json:"id"`
9 | Content string `json:"content,omitempty"`
10 | Embedding []float32 `json:"embedding,omitempty"`
11 | Metadata map[string]interface{} `json:"metadata,omitempty"`
12 | }
13 |
14 | func NewDocumentFromString(content string) Document {
15 | return Document{
16 | ID: uuid.New().String(),
17 | Content: content,
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/docs/best_practices.md:
--------------------------------------------------------------------------------
1 | # Best Practices
2 |
3 | A collection of tips and tricks for working with LLMs in Go.
4 |
5 | ## Retry
6 |
7 | OpenAI has a somewhat flaky rate limit, and it's easy to hit it. You can use a [retryablehttp](https://pkg.go.dev/github.com/hashicorp/go-retryablehttp) client to retry requests automatically:
8 |
9 | ```go
10 | retryableClient := retryablehttp.NewClient()
11 | retryableClient.RetryMax = 5
12 | oaiCfg := openai.DefaultConfig(mustGetEnv("OPENAI_API_KEY"))
13 | oaiCfg.HTTPClient = retryableClient.StandardClient()
14 | oai := openai.NewClientWithConfig(oaiCfg)
15 | ```
16 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .hermit/
2 | enwik8
3 |
4 | # If you prefer the allow list template instead of the deny list, see community template:
5 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
6 | #
7 | # Binaries for programs and plugins
8 | *.exe
9 | *.exe~
10 | *.dll
11 | *.so
12 | *.dylib
13 |
14 | # Test binary, built with `go test -c`
15 | *.test
16 |
17 | # Output of the go coverage tool, specifically when used with LiteIDE
18 | *.out
19 |
20 | # Dependency directories (remove the comment below to include it)
21 | # vendor/
22 |
23 | # Go workspace file
24 | go.work
25 |
--------------------------------------------------------------------------------
/bin/activate-hermit:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This file must be used with "source bin/activate-hermit" from bash or zsh.
3 | # You cannot run it directly
4 | #
5 | # THIS FILE IS GENERATED; DO NOT MODIFY
6 |
7 | if [ "${BASH_SOURCE-}" = "$0" ]; then
8 | echo "You must source this script: \$ source $0" >&2
9 | exit 33
10 | fi
11 |
12 | BIN_DIR="$(dirname "${BASH_SOURCE[0]:-${(%):-%x}}")"
13 | if "${BIN_DIR}/hermit" noop > /dev/null; then
14 | eval "$("${BIN_DIR}/hermit" activate "${BIN_DIR}/..")"
15 |
16 | if [ -n "${BASH-}" ] || [ -n "${ZSH_VERSION-}" ]; then
17 | hash -r 2>/dev/null
18 | fi
19 |
20 | echo "Hermit environment $("${HERMIT_ENV}"/bin/hermit env HERMIT_ENV) activated"
21 | fi
22 |
--------------------------------------------------------------------------------
/packages/hyde/README.md:
--------------------------------------------------------------------------------
1 | # HyDE
2 |
3 | This module is an implementation of [HyDE: Precise Zero-Shot Dense Retrieval without Relevance Labels](https://github.com/texttron/hyde).
4 |
5 | We differ from the reference Python implementation by using an in-memory exact search (instead of FAISS) and OpenAI Ada embeddings instead of Contriever. Realistically you should expect slightly worse performance (exact search is slower than approximate search, calling OpenAI is probably slower than a local model for smaller batch sizes). However, the results should still be valid.
6 |
7 | Note that this module _only_ expects the interfaces, not the actual implementations -- so you could implement a FAISS interface that connects to a Docker instance. The same goes for an embedding service: just make a wrapper compatible with the OpenAI interface.
--------------------------------------------------------------------------------
/packages/syncpool/syncpool.go:
--------------------------------------------------------------------------------
1 | // Package syncpool provides a generic wrapper around sync.Pool
2 | // Copied from https://github.com/mkmik/syncpool
3 | package syncpool
4 |
5 | import (
6 | "sync"
7 | )
8 |
9 | // A Pool is a generic wrapper around a sync.Pool.
10 | type Pool[T any] struct {
11 | pool sync.Pool
12 | }
13 |
14 | // New creates a new Pool with the provided new function.
15 | //
16 | // The equivalent sync.Pool construct is "sync.Pool{New: fn}"
17 | func New[T any](fn func() T) Pool[T] {
18 | return Pool[T]{
19 | pool: sync.Pool{New: func() interface{} { return fn() }},
20 | }
21 | }
22 |
23 | // Get is a generic wrapper around sync.Pool's Get method.
24 | func (p *Pool[T]) Get() T {
25 | return p.pool.Get().(T)
26 | }
27 |
28 | // Put is a generic wrapper around sync.Pool's Put method.
29 | func (p *Pool[T]) Put(x T) {
30 | p.pool.Put(x)
31 | }
32 |
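33 | // Example usage (a sketch; any type works, a *bytes.Buffer pool is shown here):
34 | //
35 | //	bufPool := New(func() *bytes.Buffer { return new(bytes.Buffer) })
36 | //	buf := bufPool.Get()
37 | //	buf.Reset()
38 | //	defer bufPool.Put(buf)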
--------------------------------------------------------------------------------
/llm.go:
--------------------------------------------------------------------------------
1 | //go:generate mockgen -source llm.go -destination internal/mocks/llm.go
2 |
3 | package gollum
4 |
5 | import (
6 | "context"
7 |
8 | "github.com/sashabaranov/go-openai"
9 | )
10 |
11 | // Deprecated: who even uses this anymore?
12 | type Completer interface {
13 | CreateCompletion(context.Context, openai.CompletionRequest) (openai.CompletionResponse, error)
14 | }
15 |
16 | // Deprecated: use packages/llm/llm.go implementation
17 | type ChatCompleter interface {
18 | CreateChatCompletion(context.Context, openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)
19 | }
20 |
21 | type Embedder interface {
22 | CreateEmbeddings(context.Context, openai.EmbeddingRequest) (openai.EmbeddingResponse, error)
23 | }
24 |
25 | type Moderator interface {
26 | Moderations(context.Context, openai.ModerationRequest) (openai.ModerationResponse, error)
27 | }
28 |
--------------------------------------------------------------------------------
/packages/llm/modelconfigs.go:
--------------------------------------------------------------------------------
1 | package llm
2 |
3 | type ModelConfigStore struct {
4 | configs map[string]ModelConfig
5 | }
6 |
7 | func NewModelConfigStore() *ModelConfigStore {
8 | return &ModelConfigStore{
9 | configs: configs,
10 | }
11 | }
12 |
13 | func NewModelConfigStoreWithConfigs(configs map[string]ModelConfig) *ModelConfigStore {
14 | return &ModelConfigStore{
15 | configs: configs,
16 | }
17 | }
18 |
19 | func (m *ModelConfigStore) GetConfig(configName string) (ModelConfig, bool) {
20 | config, ok := m.configs[configName]
21 | return config, ok
22 | }
23 |
24 | func (m *ModelConfigStore) GetConfigNames() []string {
25 | var configNames []string
26 | for k := range m.configs {
27 | configNames = append(configNames, k)
28 | }
29 | return configNames
30 | }
31 |
32 | func (m *ModelConfigStore) AddConfig(configName string, config ModelConfig) {
33 | m.configs[configName] = config
34 | }
35 |
--------------------------------------------------------------------------------
/packages/tools/calculator_test.go:
--------------------------------------------------------------------------------
1 | package tools_test
2 |
3 | import (
4 | "context"
5 | tools2 "github.com/stillmatic/gollum/packages/tools"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestCalculator(t *testing.T) {
12 | calc := tools2.CalculatorTool{}
13 | var _ tools2.Tool = &calc
14 | ctx := context.Background()
15 | t.Run("simple", func(t *testing.T) {
16 | calcInput := tools2.CalculatorInput{
17 | Expression: "1 + 1",
18 | }
19 | output, err := calc.Run(ctx, calcInput)
20 | assert.NoError(t, err)
21 | assert.Equal(t, "2", output)
22 | })
23 |
24 | t.Run("simple with env", func(t *testing.T) {
25 | env := map[string]interface{}{
26 | "foo": 1,
27 | "bar": "baz",
28 | }
29 |
30 | calcInput := tools2.CalculatorInput{
31 | Expression: "foo + foo",
32 | Environment: env,
33 | }
34 | output, err := calc.Run(ctx, calcInput)
35 | assert.NoError(t, err)
36 | assert.Equal(t, "2", output)
37 | })
38 | }
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2023 Christopher Hua
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
9 |
--------------------------------------------------------------------------------
/packages/queryplanner/queryplanner.go:
--------------------------------------------------------------------------------
1 | package queryplanner
2 |
3 | type QueryNode struct {
4 | ID int `json:"id" jsonschema_description:"unique query id - incrementing integer"`
5 | Question string `json:"question" jsonschema:"required" jsonschema_description:"Question we are asking using a question answer system, if we are asking multiple questions, this question is asked by also providing the answers to the sub questions"`
6 | // NodeType string `json:"node_type" jsonschema:"required,enum=single_question,enum=merge_responses" jsonschema_description:"type of question. Either a single question or a multi question merge when there are multiple questions."`
7 | DependencyIDs []int `json:"dependency_ids,omitempty" jsonschema_description:"list of sub-question ID's that need to be answered before this question can be answered. Use a subquery when anything may be unknown, and we need to ask multiple questions to get the answer. Dependencies must only be other query IDs."`
8 | }
9 |
10 | type QueryPlan struct {
11 | Nodes []QueryNode `json:"nodes" jsonschema_description:"list of questions to ask"`
12 | }
13 |
--------------------------------------------------------------------------------
/packages/tools/readme.md:
--------------------------------------------------------------------------------
1 | # tools
2 |
3 | This is an implementation of a Tools interface for LLMs to interact with. Tools are entirely arbitrary, can do whatever you want, and have a very simple API:
4 |
5 | ```go
6 | type Tool interface {
7 | Name() string
8 | Description() string
9 | Run(ctx context.Context, input interface{}) (interface{}, error)
10 | }
11 | ```
12 |
13 | A simple implementation would then be something like
14 |
15 | ```go
16 | type CalculatorInput struct {
17 | Expression string `json:"expression" jsonschema:"required" jsonschema_description:"mathematical expression to evaluate"`
18 | Environment map[string]interface{} `json:"environment,omitempty" jsonschema_description:"optional environment variables to use when evaluating the expression"`
19 | }
20 |
21 | type CalculatorTool struct{}
22 |
23 | // Run evaluates a mathematical expression and returns it as a string.
24 | func (c *CalculatorTool) Run(ctx context.Context, input interface{}) (interface{}, error) {
25 | // do stuff
26 | }
27 | ```
28 |
29 | Note that the input struct has a lot of JSON Schema mappings -- this is so that we can feed the description directly to OpenAI!
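30 | 
31 | For example, a sketch using the sibling `dispatch` package (as the tests do) to turn the input struct into an OpenAI tool definition:
32 | 
33 | ```go
34 | fi := dispatch.StructToJsonSchemaGeneric[CalculatorInput]("calculator", "evaluate mathematical expressions")
35 | 
36 | req := openai.ChatCompletionRequest{
37 | 	Model:    "gpt-3.5-turbo-0613",
38 | 	Messages: messages,
39 | 	Tools:    []openai.Tool{dispatch.FunctionInputToTool(fi)},
40 | }
41 | ```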
--------------------------------------------------------------------------------
/packages/llm/cache/cache.go:
--------------------------------------------------------------------------------
1 | package cache
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/stillmatic/gollum/packages/llm"
7 | )
8 |
9 | // Cache defines the interface for caching LLM responses and embeddings
10 | type Cache interface {
11 | // GetResponse retrieves a cached response for a given request
12 | GetResponse(ctx context.Context, req llm.InferRequest) (string, error)
13 |
14 | // SetResponse stores a response for a given request
15 | SetResponse(ctx context.Context, req llm.InferRequest, response string) error
16 |
17 | // GetEmbedding retrieves a cached embedding for a given input and model config
18 | GetEmbedding(ctx context.Context, modelConfig string, input string) ([]float32, error)
19 |
20 | // SetEmbedding stores an embedding for a given input and model config
21 | SetEmbedding(ctx context.Context, modelConfig string, input string, embedding []float32) error
22 |
23 | // Close closes the cache, releasing any resources
24 | Close() error
25 |
26 | // GetStats returns cache statistics (e.g., number of requests, cache hits)
27 | GetStats() CacheStats
28 | }
29 |
30 | // CacheStats represents cache usage statistics
31 | type CacheStats struct {
32 | NumRequests int
33 | NumCacheHits int
34 | }
35 |
--------------------------------------------------------------------------------
/packages/tools/calculator.go:
--------------------------------------------------------------------------------
1 | package tools
2 |
3 | import (
4 | "context"
5 | "strconv"
6 |
7 | "github.com/antonmedv/expr"
8 | "github.com/pkg/errors"
9 | )
10 |
11 | type CalculatorInput struct {
12 | Expression string `json:"expression" jsonschema:"required" jsonschema_description:"mathematical expression to evaluate"`
13 | Environment map[string]interface{} `json:"environment,omitempty" jsonschema_description:"optional environment variables to use when evaluating the expression"`
14 | }
15 |
16 | type CalculatorTool struct{}
17 |
18 | func (c *CalculatorTool) Name() string {
19 | return "calculator"
20 | }
21 |
22 | func (c *CalculatorTool) Description() string {
23 | return "evaluate mathematical expressions"
24 | }
25 |
26 | // Run evaluates a mathematical expression and returns it as a string.
27 | func (c *CalculatorTool) Run(ctx context.Context, input interface{}) (interface{}, error) {
28 | cinput, ok := input.(CalculatorInput)
29 | if !ok {
30 | return "", errors.New("invalid input")
31 | }
32 |
33 | output, err := expr.Eval(cinput.Expression, cinput.Environment)
34 | if err != nil {
35 | return "", errors.Wrap(err, "couldn't run expression")
36 | }
37 | switch t := output.(type) {
38 | case string:
39 | return t, nil
40 | case int:
41 | return strconv.Itoa(t), nil
42 | case float64:
43 | return strconv.FormatFloat(t, 'f', -1, 64), nil
44 | default:
45 | return "", errors.New("invalid output")
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/bin/hermit:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # THIS FILE IS GENERATED; DO NOT MODIFY
4 |
5 | set -eo pipefail
6 |
7 | export HERMIT_USER_HOME=~
8 |
9 | if [ -z "${HERMIT_STATE_DIR}" ]; then
10 | case "$(uname -s)" in
11 | Darwin)
12 | export HERMIT_STATE_DIR="${HERMIT_USER_HOME}/Library/Caches/hermit"
13 | ;;
14 | Linux)
15 | export HERMIT_STATE_DIR="${XDG_CACHE_HOME:-${HERMIT_USER_HOME}/.cache}/hermit"
16 | ;;
17 | esac
18 | fi
19 |
20 | export HERMIT_DIST_URL="${HERMIT_DIST_URL:-https://github.com/cashapp/hermit/releases/download/stable}"
21 | HERMIT_CHANNEL="$(basename "${HERMIT_DIST_URL}")"
22 | export HERMIT_CHANNEL
23 | export HERMIT_EXE=${HERMIT_EXE:-${HERMIT_STATE_DIR}/pkg/hermit@${HERMIT_CHANNEL}/hermit}
24 |
25 | if [ ! -x "${HERMIT_EXE}" ]; then
26 | echo "Bootstrapping ${HERMIT_EXE} from ${HERMIT_DIST_URL}" 1>&2
27 | INSTALL_SCRIPT="$(mktemp)"
28 | # This value must match that of the install script
29 | INSTALL_SCRIPT_SHA256="180e997dd837f839a3072a5e2f558619b6d12555cd5452d3ab19d87720704e38"
30 | if [ "${INSTALL_SCRIPT_SHA256}" = "BYPASS" ]; then
31 | curl -fsSL "${HERMIT_DIST_URL}/install.sh" -o "${INSTALL_SCRIPT}"
32 | else
33 | # Install script is versioned by its sha256sum value
34 | curl -fsSL "${HERMIT_DIST_URL}/install-${INSTALL_SCRIPT_SHA256}.sh" -o "${INSTALL_SCRIPT}"
35 | # Verify install script's sha256sum
36 | openssl dgst -sha256 "${INSTALL_SCRIPT}" | \
37 | awk -v EXPECTED="$INSTALL_SCRIPT_SHA256" \
38 | '$2!=EXPECTED {print "Install script sha256 " $2 " does not match " EXPECTED; exit 1}'
39 | fi
40 | /bin/bash "${INSTALL_SCRIPT}" 1>&2
41 | fi
42 |
43 | exec "${HERMIT_EXE}" --level=fatal exec "$0" -- "$@"
44 |
--------------------------------------------------------------------------------
/packages/docstore/docstore_test.go:
--------------------------------------------------------------------------------
1 | package docstore_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | . "github.com/stillmatic/gollum"
8 | "github.com/stillmatic/gollum/packages/docstore"
9 | "github.com/stretchr/testify/assert"
10 | "gocloud.dev/blob/fileblob"
11 | )
12 |
13 | func TestMemoryDocStore(t *testing.T) {
14 | ctx := context.Background()
15 | store := docstore.NewMemoryDocStore()
16 | doc := Document{ID: "1", Content: "test data"}
17 | doc2 := Document{ID: "2", Content: "test data 2"}
18 |
19 | // ensure store implements the DocStore interface
20 | var _ docstore.DocStore = store
21 |
22 | t.Run("Insert document", func(t *testing.T) {
23 | err := store.Insert(ctx, doc)
24 | assert.NoError(t, err)
25 | err = store.Insert(ctx, doc2)
26 | assert.NoError(t, err)
27 | })
28 |
29 | t.Run("Retrieve document", func(t *testing.T) {
30 | retrievedDoc, err := store.Retrieve(ctx, "1")
31 | assert.NoError(t, err)
32 | assert.Equal(t, doc, retrievedDoc)
33 | })
34 |
35 | t.Run("Retrieve non-existing document", func(t *testing.T) {
36 | _, err := store.Retrieve(ctx, "non-existing-id")
37 | assert.Error(t, err)
38 | })
39 |
40 | t.Run("Persist document store", func(t *testing.T) {
41 | // persist to testdata/docstore.json
42 | bucket, err := fileblob.OpenBucket("testdata", nil)
43 | assert.NoError(t, err)
44 | err = store.Persist(ctx, bucket, "docstore.json")
45 | assert.NoError(t, err)
46 | })
47 |
48 | t.Run("Load document store from disk", func(t *testing.T) {
49 | // load from testdata/docstore.json
50 | bucket, err := fileblob.OpenBucket("testdata", nil)
51 | assert.NoError(t, err)
52 | loadedStore, err := docstore.NewMemoryDocStoreFromDisk(ctx, bucket, "docstore.json")
53 | assert.NoError(t, err)
54 | assert.Equal(t, store, loadedStore)
55 | })
56 |
57 | }
58 |
--------------------------------------------------------------------------------
/internal/hash/hash_bench_test.go:
--------------------------------------------------------------------------------
1 | package hash
2 |
3 | // copied from https://gist.github.com/wizjin/e103e1040db0c4c427db4104cce67566
4 |
5 | import (
6 | "crypto/md5"
7 | "crypto/rand"
8 | "crypto/sha1"
9 | "crypto/sha256"
10 | "crypto/sha512"
11 | "github.com/cespare/xxhash/v2"
12 | "hash"
13 | "hash/fnv"
14 | "testing"
15 | )
16 |
17 | // NB chua: we care about caching / hashing lots of tokens.
18 | // That probably starts mattering at O(1k), O(10k), O(100k) tokens.
19 | // note that there are ~4 characters per token, and 1-4 bytes per character in UTF-8
20 | // so 1k tokens is 4k-16k bytes, 10k tokens is 40k-160k bytes, 100k tokens is 400k-1.6M bytes.
21 | const (
22 | K = 1024
23 | DATALEN = 512 * K
24 | )
25 |
26 | func runHash(b *testing.B, h hash.Hash, n int) {
27 | var data = make([]byte, n)
28 | rand.Read(data)
29 | b.ResetTimer()
30 |
31 | for i := 0; i < b.N; i++ {
32 | h.Write(data)
33 | h.Sum(nil)
34 | }
35 | }
36 |
37 | func BenchmarkFNV32(b *testing.B) {
38 | runHash(b, fnv.New32(), DATALEN)
39 | }
40 |
41 | func BenchmarkFNV64(b *testing.B) {
42 | runHash(b, fnv.New64(), DATALEN)
43 | }
44 |
45 | func BenchmarkFNV128(b *testing.B) {
46 | runHash(b, fnv.New128(), DATALEN)
47 | }
48 |
49 | func BenchmarkMD5(b *testing.B) {
50 | runHash(b, md5.New(), DATALEN)
51 | }
52 |
53 | func BenchmarkSHA1(b *testing.B) {
54 | runHash(b, sha1.New(), DATALEN)
55 | }
56 |
57 | func BenchmarkSHA224(b *testing.B) {
58 | runHash(b, sha256.New224(), DATALEN)
59 | }
60 |
61 | func BenchmarkSHA256(b *testing.B) {
62 | runHash(b, sha256.New(), DATALEN)
63 | }
64 |
65 | func BenchmarkSHA512(b *testing.B) {
66 | runHash(b, sha512.New(), DATALEN)
67 | }
68 |
69 | // func BenchmarkMurmur3(b *testing.B) {
70 | // runHash(b, murmur3.New32(), DATALEN)
71 | // }
72 | func BenchmarkXxhash(b *testing.B) {
73 | runHash(b, xxhash.New(), DATALEN)
74 | }
75 |
--------------------------------------------------------------------------------
/packages/agents/calcagent_test.go:
--------------------------------------------------------------------------------
1 | package agents_test
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "github.com/stillmatic/gollum/packages/agents"
7 | "github.com/stillmatic/gollum/packages/tools"
8 | "os"
9 | "testing"
10 |
11 | "github.com/sashabaranov/go-openai"
12 | mock_gollum "github.com/stillmatic/gollum/internal/mocks"
13 | "github.com/stretchr/testify/assert"
14 | "go.uber.org/mock/gomock"
15 | )
16 |
17 | func TestCalcAgentMocked(t *testing.T) {
18 | ctrl := gomock.NewController(t)
19 | llm := mock_gollum.NewMockChatCompleter(ctrl)
20 | agent := agents.NewCalcAgent(llm)
21 | assert.NotNil(t, agent)
22 | ctx := context.Background()
23 | inp := agents.CalcAgentInput{
24 | Content: "What's 2 + 2?",
25 | }
26 | expectedResp := tools.CalculatorInput{
27 | Expression: "2 + 2",
28 | }
29 | expectedBytes, err := json.Marshal(expectedResp)
30 | assert.NoError(t, err)
31 | expectedChatCompletionResp := openai.ChatCompletionResponse{
32 | Choices: []openai.ChatCompletionChoice{
33 | {
34 | Message: openai.ChatCompletionMessage{
35 | Role: openai.ChatMessageRoleAssistant,
36 | Content: "",
37 | ToolCalls: []openai.ToolCall{
38 | {Function: openai.FunctionCall{Name: "calc", Arguments: string(expectedBytes)}},
39 | },
40 | },
41 | },
42 | },
43 | }
44 | llm.EXPECT().CreateChatCompletion(ctx, gomock.Any()).Return(expectedChatCompletionResp, nil)
45 |
46 | resp, err := agent.Run(ctx, inp)
47 | assert.NoError(t, err)
48 | assert.Equal(t, "4", resp.(string))
49 | }
50 |
51 | func TestCalcAgentReal(t *testing.T) {
52 | openai_key := os.Getenv("OPENAI_API_KEY")
53 | assert.NotEmpty(t, openai_key)
54 | llm := openai.NewClient(openai_key)
55 |
56 | agent := agents.NewCalcAgent(llm)
57 | assert.NotNil(t, agent)
58 | ctx := context.Background()
59 | inp := agents.CalcAgentInput{
60 | Content: "What's 2 + 2?",
61 | }
62 |
63 | resp, err := agent.Run(ctx, inp)
64 | assert.NoError(t, err)
65 | assert.Equal(t, "4", resp.(string))
66 | }
67 |
--------------------------------------------------------------------------------
/packages/llm/llm.go:
--------------------------------------------------------------------------------
1 | //go:generate mockgen -source llm.go -destination internal/mocks/llm.go
2 | package llm
3 |
4 | import (
5 | "context"
6 | )
7 |
8 | type ProviderType string
9 | type ModelType string
10 |
11 | // yuck sorry
12 | const (
13 | ModelTypeLLM ModelType = "llm"
14 | ModelTypeEmbedding ModelType = "embedding"
15 | )
16 |
17 | type ModelConfig struct {
18 | ProviderType ProviderType
19 | ModelName string
20 | BaseURL string
21 |
22 | ModelType ModelType
23 | CentiCentsPerMillionInputTokens int
24 | CentiCentsPerMillionOutputTokens int
25 | }
26 |
27 | // MessageOptions are options that can be passed to the model for generating a response.
28 | // NB chua: these are the only ones I use, I assume others are useful too...
29 | type MessageOptions struct {
30 | MaxTokens int
31 | Temperature float32
32 | }
33 |
34 | type InferMessage struct {
35 | Content string
36 | Role string
37 | Image []byte
38 | Audio []byte
39 |
40 | ShouldCache bool
41 | }
42 |
43 | type InferRequest struct {
44 | Messages []InferMessage
45 |
46 | // ModelConfig describes the model to use for generating a response.
47 | ModelConfig ModelConfig
48 | // MessageOptions are options that can be passed to the model for generating a response.
49 | MessageOptions MessageOptions
50 | }
51 |
52 | type StreamDelta struct {
53 | Text string
54 | EOF bool
55 | }
56 |
57 | type Responder interface {
58 | GenerateResponse(ctx context.Context, req InferRequest) (string, error)
59 | GenerateResponseAsync(ctx context.Context, req InferRequest) (<-chan StreamDelta, error)
60 | }
61 |
62 | type EmbedRequest struct {
63 | Input []string
64 | Image []byte
65 |
66 | // Prompt is an instruction applied to all the input strings in this request.
67 | // Ignored unless the model specifically supports it
68 | Prompt string
69 |
70 | ModelConfig ModelConfig
71 | // only supported for openai (matryoshka) models
72 | Dimensions int
73 | }
74 |
75 | type Embedding struct {
76 | Values []float32
77 | }
78 |
79 | type EmbeddingResponse struct {
80 | Data []Embedding
81 | }
82 |
83 | type Embedder interface {
84 | GenerateEmbedding(ctx context.Context, req EmbedRequest) (*EmbeddingResponse, error)
85 | }
86 |
--------------------------------------------------------------------------------
/packages/dispatch/functions.go:
--------------------------------------------------------------------------------
1 | package dispatch
2 |
3 | import (
4 | "reflect"
5 |
6 | "github.com/invopop/jsonschema"
7 | "github.com/sashabaranov/go-openai"
8 | )
9 |
10 | type FunctionInput struct {
11 | Name string `json:"name"`
12 | Description string `json:"description,omitempty"`
13 | Parameters any `json:"parameters"`
14 | }
15 |
16 | type OAITool struct {
17 | // Type is always "function" for now.
18 | Type string `json:"type"`
19 | Function FunctionInput `json:"function"`
20 | }
21 |
22 | func FunctionInputToTool(fi FunctionInput) openai.Tool {
23 | f_ := openai.FunctionDefinition(fi)
24 | return openai.Tool{
25 | Type: "function",
26 | Function: &f_,
27 | }
28 | }
29 |
30 | func StructToJsonSchema(functionName string, functionDescription string, inputStruct interface{}) FunctionInput {
31 | t := reflect.TypeOf(inputStruct)
32 | schema := jsonschema.ReflectFromType(reflect.Type(t))
33 | inputStructName := t.Name()
34 | // only get the single struct we care about
35 | inputProperties, ok := schema.Definitions[inputStructName]
36 | if !ok {
37 | // this should not happen
38 | panic("could not find input struct in schema")
39 | }
40 | parameters := jsonschema.Schema{
41 | Type: "object",
42 | Properties: inputProperties.Properties,
43 | Required: inputProperties.Required,
44 | }
45 | return FunctionInput{
46 | Name: functionName,
47 | Description: functionDescription,
48 | Parameters: parameters,
49 | }
50 | }
51 |
52 | func StructToJsonSchemaGeneric[T any](functionName string, functionDescription string) FunctionInput {
53 | var tArr [0]T
54 | tt := reflect.TypeOf(tArr).Elem()
55 | schema := jsonschema.ReflectFromType(reflect.Type(tt))
56 | inputStructName := tt.Name()
57 | // only get the single struct we care about
58 | inputProperties, ok := schema.Definitions[inputStructName]
59 | if !ok {
60 | // this should not happen
61 | panic("could not find input struct in schema")
62 | }
63 | parameters := jsonschema.Schema{
64 | Type: "object",
65 | Properties: inputProperties.Properties,
66 | Required: inputProperties.Required,
67 | }
68 | return FunctionInput{
69 | Name: functionName,
70 | Description: functionDescription,
71 | Parameters: parameters,
72 | }
73 | }
74 |
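75 | // Example (sketch, mirroring queryplanner_test.go): reflect a struct into an
76 | // OpenAI tool definition and attach it to a chat request.
77 | //
78 | //	fi := StructToJsonSchemaGeneric[QueryPlan]("QueryPlan", "Use this to plan a query.")
79 | //	tool := FunctionInputToTool(fi)
80 | //	req := openai.ChatCompletionRequest{Model: "gpt-3.5-turbo-0613", Tools: []openai.Tool{tool}}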
--------------------------------------------------------------------------------
/packages/queryplanner/queryplanner_test.go:
--------------------------------------------------------------------------------
1 | package queryplanner_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/stillmatic/gollum/packages/dispatch"
7 | "github.com/stillmatic/gollum/packages/jsonparser"
8 | . "github.com/stillmatic/gollum/packages/queryplanner"
9 | "os"
10 | "testing"
11 |
12 | "github.com/joho/godotenv"
13 | "github.com/sashabaranov/go-openai"
14 | "github.com/stillmatic/gollum/internal/testutil"
15 | "github.com/stretchr/testify/assert"
16 | )
17 |
18 | func TestQueryPlanner(t *testing.T) {
19 | godotenv.Load()
20 | baseAPIURL := "https://api.openai.com/v1/chat/completions"
21 | openAIKey := os.Getenv("OPENAI_API_KEY")
22 | assert.NotEmpty(t, openAIKey)
23 |
24 | api := testutil.NewTestAPI(baseAPIURL, openAIKey)
25 | fi := dispatch.StructToJsonSchemaGeneric[QueryPlan]("QueryPlan", "Use this to plan a query.")
26 | question := "What is the difference between populations of Canada and Jason's home country?"
27 |
28 | messages := []openai.ChatCompletionMessage{
29 | {
30 | Role: openai.ChatMessageRoleSystem,
31 | Content: "You are a world class query planning algorithm capable of breaking apart questions into their dependency queries such that the answers can be used to inform the parent question. Do not answer the questions, simply provide the correct compute graph with good specific questions to ask and relevant dependencies. Before you call the function, think step by step to get a better understanding of the problem.",
32 | },
33 | {
34 | Role: openai.ChatMessageRoleUser,
35 | Content: fmt.Sprintf("Consider: %s\nGenerate the correct query plan.", question),
36 | },
37 | }
38 | f_ := openai.FunctionDefinition(fi)
39 | chatRequest := openai.ChatCompletionRequest{
40 | Messages: messages,
41 | Model: "gpt-3.5-turbo-0613",
42 | Temperature: 0.0,
43 | Tools: []openai.Tool{
44 | {
45 | Type: "function",
46 | Function: &f_,
47 | },
48 | },
49 | }
50 | ctx := context.Background()
51 | resp, err := api.SendRequest(ctx, chatRequest)
52 |
53 | assert.NoError(t, err)
54 | assert.NotNil(t, resp)
55 | parser := jsonparser.NewJSONParserGeneric[QueryPlan](false)
56 | queryPlan, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
57 | assert.NoError(t, err)
58 | assert.NotNil(t, queryPlan)
59 | t.Log(queryPlan)
60 | assert.Equal(t, 0, 1)
61 | }
62 |
--------------------------------------------------------------------------------
/internal/testutil/utils.go:
--------------------------------------------------------------------------------
1 | package testutil
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "fmt"
8 | "math/rand"
9 | "net/http"
10 |
11 | "github.com/sashabaranov/go-openai"
12 | )
13 |
14 | type TestAPI struct {
15 | baseAPIURL string
16 | apiKey string
17 | client *http.Client
18 | }
19 |
20 | func NewTestAPI(baseAPIURL, apiKey string) *TestAPI {
21 | return &TestAPI{
22 | baseAPIURL: baseAPIURL,
23 | apiKey: apiKey,
24 | client: &http.Client{},
25 | }
26 | }
27 |
28 | func (api *TestAPI) SendRequest(ctx context.Context, chatRequest openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error) {
29 | b, err := json.Marshal(chatRequest)
30 | if err != nil {
31 | return nil, err
32 | }
33 |
34 | req, err := http.NewRequestWithContext(ctx, "POST", api.baseAPIURL, bytes.NewReader(b))
35 | if err != nil {
36 | return nil, err
37 | }
38 | req.Header.Set("Authorization", "Bearer "+api.apiKey)
39 | req.Header.Set("Content-Type", "application/json")
40 |
41 | resp, err := api.client.Do(req)
42 | if err != nil {
43 | return nil, err
44 | }
45 | defer resp.Body.Close()
46 |
47 | if resp.StatusCode != 200 {
48 | return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
49 | }
50 |
51 | var chatResponse openai.ChatCompletionResponse
52 | err = json.NewDecoder(resp.Body).Decode(&chatResponse)
53 | if err != nil {
54 | return nil, err
55 | }
56 |
57 | return &chatResponse, nil
58 | }
59 |
60 | func GetRandomEmbedding(n int) []float32 {
61 | vec := make([]float32, n)
62 | for i := range vec {
63 | vec[i] = rand.Float32()
64 | }
65 | return vec
66 | }
67 |
68 | func GetRandomEmbeddingResponse(n int, dim int) openai.EmbeddingResponse {
69 | data := make([]openai.Embedding, n)
70 | for i := range data {
71 | data[i] = openai.Embedding{
72 | Embedding: GetRandomEmbedding(dim),
73 | }
74 | }
75 | resp := openai.EmbeddingResponse{
76 | Data: data,
77 | }
78 | return resp
79 | }
80 |
81 | func GetRandomChatCompletionResponse(n int) openai.ChatCompletionResponse {
82 | choices := make([]openai.ChatCompletionChoice, n)
83 | for i := range choices {
84 | choices[i] = openai.ChatCompletionChoice{
85 | Message: openai.ChatCompletionMessage{
86 | Role: openai.ChatMessageRoleSystem,
87 | Content: fmt.Sprintf("test? %d", i),
88 | },
89 | }
90 | }
91 | return openai.ChatCompletionResponse{
92 | Choices: choices,
93 | }
94 | }
95 |
--------------------------------------------------------------------------------
/packages/jsonparser/parser.go:
--------------------------------------------------------------------------------
1 | package jsonparser
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "reflect"
7 |
8 | "github.com/invopop/jsonschema"
9 | "github.com/pkg/errors"
10 | val "github.com/santhosh-tekuri/jsonschema/v5"
11 | )
12 |
13 | // Parser is an interface for parsing strings into structs
14 | // It is threadsafe and can be used concurrently. The underlying validator is threadsafe as well.
15 | type Parser[T any] interface {
16 | Parse(ctx context.Context, input []byte) (T, error)
17 | }
18 |
19 | // JSONParserGeneric is a parser that parses arbitrary JSON structs.
20 | // It is threadsafe and can be used concurrently. The underlying validator is threadsafe as well.
21 | type JSONParserGeneric[T any] struct {
22 | validate bool
23 | schema *val.Schema
24 | }
25 |
26 | // NewJSONParserGeneric returns a new JSONParserGeneric.
27 | // Validation is done via jsonschema.
28 | func NewJSONParserGeneric[T any](validate bool) *JSONParserGeneric[T] {
29 | var sch *val.Schema
30 | var t_ T
31 | if validate {
32 | // reflect T into schema
33 | t := reflect.TypeOf(t_)
34 | schema := jsonschema.ReflectFromType(t)
35 | // compile schema
36 | b, err := schema.MarshalJSON()
37 | if err != nil {
38 | panic(errors.Wrap(err, "could not marshal schema"))
39 | }
40 | schemaStr := string(b)
41 | sch, err = val.CompileString("schema.json", schemaStr)
42 | if err != nil {
43 | panic(errors.Wrap(err, "could not compile schema"))
44 | }
45 | }
46 |
47 | return &JSONParserGeneric[T]{
48 | validate: validate,
49 | schema: sch,
50 | }
51 | }
52 |
53 | func (p *JSONParserGeneric[T]) Parse(ctx context.Context, input []byte) (T, error) {
54 | var t T
55 | // unmarshal directly to the struct first, because it's easy to get a type conversion error
56 | // e.g. if the go struct name doesn't match the json name
57 | err := json.Unmarshal(input, &t)
58 | if err != nil {
59 | return t, errors.Wrap(err, "could not unmarshal input json to struct")
60 | }
61 |
62 | // annoying, must pass an interface to the validate function
63 | // so we have to unmarshal to interface{} twice
64 | if p.validate {
65 | var v interface{}
66 | err := json.Unmarshal(input, &v)
67 | if err != nil {
68 | return t, errors.Wrap(err, "could not unmarshal input json to interface")
69 | }
70 | err = p.schema.Validate(v)
71 | if err != nil {
72 | return t, errors.Wrap(err, "error validating input json")
73 | }
74 | }
75 |
76 | return t, nil
77 | }
78 |
--------------------------------------------------------------------------------
/packages/vectorstore/vectorstore.go:
--------------------------------------------------------------------------------
1 | package vectorstore
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/stillmatic/gollum"
7 | )
8 |
9 | // QueryRequest is a struct that contains the query and optional query strings or embeddings
10 | type QueryRequest struct {
11 | // Query is the text to query
12 | Query string
13 | // EmbeddingStrings is a list of strings to concatenate and embed instead of Query
14 | EmbeddingStrings []string
15 | // EmbeddingFloats is a query vector to use instead of Query
16 | EmbeddingFloats []float32
17 | // K is the number of results to return
18 | K int
19 | }
20 |
21 | type VectorStore interface {
22 | Insert(context.Context, gollum.Document) error
23 | Query(ctx context.Context, qb QueryRequest) ([]*gollum.Document, error)
24 | RetrieveAll(ctx context.Context) ([]gollum.Document, error)
25 | }
26 |
27 | type NodeSimilarity struct {
28 | Document *gollum.Document
29 | Similarity float32
30 | }
31 |
32 | // Heap is a custom heap implementation, to avoid interface{} conversion.
33 | // I _think_ theoretically that a memory arena would be useful here, but that feels a bit beyond the pale, even for me.
34 | // In benchmarking, we see that allocations scale with k --
35 | // since K is known, we should be able to allocate a fixed-size arena and use that.
36 | // That being said... let's revisit in the future :)
37 | type Heap []NodeSimilarity
38 |
39 | func (h *Heap) Init(k int) {
40 | *h = make(Heap, 0, k)
41 | }
42 |
43 | func (h Heap) down(u int) {
44 | v := u
45 | if 2*u+1 < len(h) && h[2*u+1].Similarity < h[v].Similarity {
46 | v = 2*u + 1
47 | }
48 | if 2*u+2 < len(h) && h[2*u+2].Similarity < h[v].Similarity {
49 | v = 2*u + 2
50 | }
51 | if v != u {
52 | h[v], h[u] = h[u], h[v]
53 | h.down(v)
54 | }
55 | }
56 |
57 | func (h Heap) up(u int) {
58 | for u != 0 && h[(u-1)/2].Similarity > h[u].Similarity {
59 | h[(u-1)/2], h[u] = h[u], h[(u-1)/2]
60 | u = (u - 1) / 2
61 | }
62 | }
63 |
64 | func (h *Heap) Push(e NodeSimilarity) {
65 | *h = append(*h, e)
66 | h.up(len(*h) - 1)
67 | }
68 |
69 | func (h *Heap) Pop() NodeSimilarity {
70 | x := (*h)[0]
71 | n := len(*h)
72 | (*h)[0], (*h)[n-1] = (*h)[n-1], (*h)[0]
73 | *h = (*h)[:n-1]
74 | h.down(0)
75 | return x
76 | }
77 |
78 | func (h Heap) Less(i, j int) bool {
79 | return h[i].Similarity < h[j].Similarity
80 | }
81 |
82 | func (h Heap) Swap(i, j int) {
83 | h[i], h[j] = h[j], h[i]
84 | }
85 |
86 | func (h *Heap) Len() int {
87 | return len(*h)
88 | }
89 |
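90 | // Example (sketch): keep the K most similar documents with the min-heap --
91 | // push each candidate and pop the minimum whenever the heap grows past K.
92 | //
93 | //	h := Heap{}
94 | //	h.Init(k)
95 | //	for _, cand := range candidates {
96 | //		h.Push(cand)
97 | //		if h.Len() > k {
98 | //			h.Pop()
99 | //		}
100 | //	}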
--------------------------------------------------------------------------------
/packages/llm/providers/cached/cached_provider.go:
--------------------------------------------------------------------------------
1 | package cached
2 |
3 | import (
4 | "context"
5 | "crypto/sha256"
6 | "fmt"
7 | "log"
8 |
9 | "github.com/stillmatic/gollum/packages/llm"
10 | "github.com/stillmatic/gollum/packages/llm/cache"
11 | "github.com/stillmatic/gollum/packages/llm/providers/cached/sqlitecache"
12 |
13 | "hash"
14 |
15 | _ "modernc.org/sqlite"
16 | )
17 |
18 | // CachedResponder implements the Responder interface with caching
19 | type CachedResponder struct {
20 | underlying llm.Responder
21 | cache cache.Cache
22 | hasher hash.Hash
23 | }
24 |
25 | // NewLocalCachedResponder creates a new CachedResponder with a local SQLite cache
26 | // For example, initialize an OpenAI provider and then wrap it with this cache.
27 | func NewLocalCachedResponder(underlying llm.Responder, dbPath string) (*CachedResponder, error) {
28 | cache, err := sqlitecache.NewSQLiteCache(dbPath)
29 | if err != nil {
30 | return nil, fmt.Errorf("failed to create cache: %w", err)
31 | }
32 |
33 | // we use sha256 to avoid pulling down the xxhash dep if you don't need to
34 | // these are small strings to cache so shouldn't make a big diff
35 | hasher := sha256.New()
36 |
37 | return &CachedResponder{
38 | underlying: underlying,
39 | cache: cache,
40 | hasher: hasher,
41 | }, nil
42 | }
43 |
44 | func (cr *CachedResponder) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) {
45 | // Check cache
46 | cachedResponse, err := cr.cache.GetResponse(ctx, req)
47 | if err == nil {
48 | return cachedResponse, nil
49 | }
50 |
51 | // If not in cache, call underlying provider
52 | response, err := cr.underlying.GenerateResponse(ctx, req)
53 | if err != nil {
54 | return "", err
55 | }
56 |
57 | // Cache the result
58 | if err := cr.cache.SetResponse(ctx, req, response); err != nil {
59 | log.Printf("Failed to cache response: %v", err)
60 | }
61 |
62 | return response, nil
63 | }
64 |
65 | func (cr *CachedResponder) GenerateResponseAsync(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) {
66 | // For async responses, we don't cache and just pass through to the underlying provider
67 | // TODO: think about if we should just cache final response and return immediately
68 | return cr.underlying.GenerateResponseAsync(ctx, req)
69 | }
70 |
71 | func (cr *CachedResponder) Close() error {
72 | return cr.cache.Close()
73 | }
74 |
75 | func (cr *CachedResponder) GetCacheStats() cache.CacheStats {
76 | return cr.cache.GetStats()
77 | }
78 |
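79 | // Example (sketch): wrap any llm.Responder in a local SQLite-backed cache.
80 | //
81 | //	var base llm.Responder // e.g. a provider from packages/llm/providers/openai
82 | //	responder, err := NewLocalCachedResponder(base, "llm_cache.db")
83 | //	if err != nil {
84 | //		log.Fatal(err)
85 | //	}
86 | //	out, err := responder.GenerateResponse(ctx, req)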
--------------------------------------------------------------------------------
/internal/hash/readme.md:
--------------------------------------------------------------------------------
1 | with `go test -bench=. -cpu 1 -benchmem -benchtime=1s`
2 |
3 | with n=32
4 |
5 | ```
6 | goos: linux
7 | goarch: amd64
8 | pkg: github.com/stillmatic/gollum/internal/hash
9 | cpu: AMD Ryzen 9 7950X 16-Core Processor
10 | BenchmarkFNV32 48111 25479 ns/op 8 B/op 1 allocs/op
11 | BenchmarkFNV64 47475 25640 ns/op 8 B/op 1 allocs/op
12 | BenchmarkFNV128 38230 31021 ns/op 16 B/op 1 allocs/op
13 | BenchmarkMD5 44103 27324 ns/op 16 B/op 1 allocs/op
14 | BenchmarkSHA1 69037 17428 ns/op 24 B/op 1 allocs/op
15 | BenchmarkSHA224 98686 12203 ns/op 32 B/op 1 allocs/op
16 | BenchmarkSHA256 97653 12591 ns/op 32 B/op 1 allocs/op
17 | BenchmarkSHA512 39342 29469 ns/op 64 B/op 1 allocs/op
18 | BenchmarkMurmur3 170241 6990 ns/op 8 B/op 1 allocs/op
19 | BenchmarkXxhash 756992 1552 ns/op 8 B/op 1 allocs/op
20 | PASS
21 | ok github.com/stillmatic/gollum/internal/hash 13.927s
22 | ```
23 |
24 | with n=512
25 |
26 | ```
27 | goos: linux
28 | goarch: amd64
29 | pkg: github.com/stillmatic/gollum/internal/hash
30 | cpu: AMD Ryzen 9 7950X 16-Core Processor
31 | BenchmarkFNV32 3097 407335 ns/op 8 B/op 1 allocs/op
32 | BenchmarkFNV64 3090 389704 ns/op 8 B/op 1 allocs/op
33 | BenchmarkFNV128 2314 499748 ns/op 16 B/op 1 allocs/op
34 | BenchmarkMD5 2682 443034 ns/op 16 B/op 1 allocs/op
35 | BenchmarkSHA1 4224 280997 ns/op 24 B/op 1 allocs/op
36 | BenchmarkSHA224 5968 198692 ns/op 32 B/op 1 allocs/op
37 | BenchmarkSHA256 5952 201876 ns/op 32 B/op 1 allocs/op
38 | BenchmarkSHA512 2544 489953 ns/op 64 B/op 1 allocs/op
39 | BenchmarkMurmur3 9975 119104 ns/op 8 B/op 1 allocs/op
40 | BenchmarkXxhash 46592 24758 ns/op 8 B/op 1 allocs/op
41 | PASS
42 | ok github.com/stillmatic/gollum/internal/hash 12.580s
43 | ```
44 |
45 | Conclusion: `xxhash` is an order of magnitude improvement on large strings. `murmur3` is good for medium size but doesn't seem to scale as well.
--------------------------------------------------------------------------------
/packages/docstore/docstore.go:
--------------------------------------------------------------------------------
1 | package docstore
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "github.com/stillmatic/gollum"
7 |
8 | "github.com/pkg/errors"
9 | "gocloud.dev/blob"
10 | )
11 |
12 | type DocStore interface {
13 | Insert(context.Context, gollum.Document) error
14 | Retrieve(ctx context.Context, id string) (gollum.Document, error)
15 | }
16 |
17 | // MemoryDocStore is a simple in-memory document store.
18 | // It's functionally a hashmap / inverted-index.
19 | type MemoryDocStore struct {
20 | Documents map[string]gollum.Document
21 | }
22 |
23 | func NewMemoryDocStore() *MemoryDocStore {
24 | return &MemoryDocStore{
25 | Documents: make(map[string]gollum.Document),
26 | }
27 | }
28 |
29 | func NewMemoryDocStoreFromDisk(ctx context.Context, bucket *blob.Bucket, path string) (*MemoryDocStore, error) {
30 | // load documents from disk
31 | data, err := bucket.ReadAll(ctx, path)
32 | if err != nil {
33 | return nil, errors.Wrap(err, "failed to read documents from disk")
34 | }
35 | var nodes map[string]gollum.Document
36 | err = json.Unmarshal(data, &nodes)
37 | if err != nil {
38 | return nil, errors.Wrap(err, "failed to unmarshal documents from JSON")
39 | }
40 | return &MemoryDocStore{
41 | Documents: nodes,
42 | }, nil
43 | }
44 |
45 | // Insert adds a node to the document store. It overwrites duplicates.
46 | func (m *MemoryDocStore) Insert(ctx context.Context, d gollum.Document) error {
47 | m.Documents[d.ID] = d
48 | return nil
49 | }
50 |
51 | // Retrieve returns a node from the document store matching an ID.
52 | func (m *MemoryDocStore) Retrieve(ctx context.Context, id string) (gollum.Document, error) {
53 | v, ok := m.Documents[id]
54 | if !ok {
55 | return gollum.Document{}, errors.New("document not found")
56 | }
57 | return v, nil
58 | }
59 |
60 | // Persist saves the document store to disk.
61 | func (m *MemoryDocStore) Persist(ctx context.Context, bucket *blob.Bucket, path string) error {
62 | // save documents to disk
63 | data, err := json.Marshal(m.Documents)
64 | if err != nil {
65 | return errors.Wrap(err, "failed to marshal documents to JSON")
66 | }
67 | err = bucket.WriteAll(ctx, path, data, nil)
68 | if err != nil {
69 | return errors.Wrap(err, "failed to write documents to disk")
70 | }
71 | return nil
72 | }
73 |
74 | // Load loads the document store from disk.
75 | func (m *MemoryDocStore) Load(ctx context.Context, bucket *blob.Bucket, path string) error {
76 | // load documents from disk
77 | data, err := bucket.ReadAll(ctx, path)
78 | if err != nil {
79 | return errors.Wrap(err, "failed to read documents from disk")
80 | }
81 | var nodes map[string]gollum.Document
82 | err = json.Unmarshal(data, &nodes)
83 | if err != nil {
84 | return errors.Wrap(err, "failed to unmarshal documents from JSON")
85 | }
86 | m.Documents = nodes
87 | return nil
88 | }
89 |
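90 | // Example (sketch, mirroring docstore_test.go): persist a store to a blob
91 | // bucket and load it back.
92 | //
93 | //	store := NewMemoryDocStore()
94 | //	_ = store.Insert(ctx, gollum.NewDocumentFromString("hello"))
95 | //	bucket, _ := fileblob.OpenBucket("testdata", nil)
96 | //	_ = store.Persist(ctx, bucket, "docstore.json")
97 | //	loaded, _ := NewMemoryDocStoreFromDisk(ctx, bucket, "docstore.json")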
--------------------------------------------------------------------------------
/packages/llm/providers/voyage/voyage.go:
--------------------------------------------------------------------------------
1 | package voyage
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "fmt"
8 | "github.com/stillmatic/gollum/packages/llm"
9 | "io"
10 | "net/http"
11 | )
12 |
13 | const (
14 | apiURL = "https://api.voyageai.com/v1/embeddings"
15 | )
16 |
17 | type VoyageAIEmbedder struct {
18 | APIKey string
19 | }
20 |
21 | type voyageAIRequest struct {
22 | Input []string `json:"input"`
23 | Model string `json:"model"`
24 | }
25 |
26 | type voyageAIResponse struct {
27 | Object string `json:"object"`
28 | Data []struct {
29 | Object string `json:"object"`
30 | Embedding []float32 `json:"embedding"`
31 | Index int `json:"index"`
32 | } `json:"data"`
33 | Model string `json:"model"`
34 | Usage struct {
35 | TotalTokens int `json:"total_tokens"`
36 | } `json:"usage"`
37 | }
38 |
39 | func NewVoyageAIEmbedder(apiKey string) *VoyageAIEmbedder {
40 | return &VoyageAIEmbedder{APIKey: apiKey}
41 | }
42 |
43 | func (e *VoyageAIEmbedder) GenerateEmbedding(ctx context.Context, req llm.EmbedRequest) (*llm.EmbeddingResponse, error) {
44 | if len(req.Image) > 0 {
45 | return nil, fmt.Errorf("image embedding not supported by Voyage AI")
46 | }
47 |
48 | if req.Dimensions != 0 {
49 | return nil, fmt.Errorf("custom dimensions not supported by Voyage AI")
50 | }
51 |
52 | voyageReq := voyageAIRequest{
53 | Input: req.Input,
54 | Model: req.ModelConfig.ModelName,
55 | }
56 |
57 | jsonData, err := json.Marshal(voyageReq)
58 | if err != nil {
59 | return nil, fmt.Errorf("failed to marshal request: %w", err)
60 | }
61 |
62 | httpReq, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonData))
63 | if err != nil {
64 | return nil, fmt.Errorf("failed to create request: %w", err)
65 | }
66 |
67 | httpReq.Header.Set("Content-Type", "application/json")
68 | httpReq.Header.Set("Authorization", "Bearer "+e.APIKey)
69 |
70 | client := &http.Client{}
71 | resp, err := client.Do(httpReq)
72 | if err != nil {
73 | return nil, fmt.Errorf("failed to send request: %w", err)
74 | }
75 | defer resp.Body.Close()
76 |
77 | body, err := io.ReadAll(resp.Body)
78 | if err != nil {
79 | return nil, fmt.Errorf("failed to read response body: %w", err)
80 | }
81 |
82 | if resp.StatusCode != http.StatusOK {
83 | return nil, fmt.Errorf("API request failed with status code %d: %s", resp.StatusCode, string(body))
84 | }
85 |
86 | var voyageResp voyageAIResponse
87 | err = json.Unmarshal(body, &voyageResp)
88 | if err != nil {
89 | return nil, fmt.Errorf("failed to unmarshal response: %w", err)
90 | }
91 |
92 | embeddings := make([]llm.Embedding, len(voyageResp.Data))
93 | for i, data := range voyageResp.Data {
94 | embeddings[i] = llm.Embedding{Values: data.Embedding}
95 | }
96 |
97 | return &llm.EmbeddingResponse{Data: embeddings}, nil
98 | }
99 |
--------------------------------------------------------------------------------
/packages/jsonparser/parser_test.go:
--------------------------------------------------------------------------------
1 | package jsonparser_test
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "testing"
7 |
8 | "github.com/stillmatic/gollum/packages/jsonparser"
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | type employee struct {
13 | Name string `json:"name" yaml:"name" jsonschema:"required,minLength=1,maxLength=100"`
14 | Age int `json:"age" yaml:"age" jsonschema:"minimum=18,maximum=80,required"`
15 | }
16 |
17 | type company struct {
18 | Name string `json:"name" yaml:"name"`
19 | Employees []employee `json:"employees,omitempty" yaml:"employees"`
20 | }
21 |
22 | type bulletList []string
23 |
24 | var testCo = company{
25 | Name: "Acme",
26 | Employees: []employee{
27 | {
28 | Name: "John",
29 | Age: 30,
30 | },
31 | {
32 | Name: "Jane",
33 | Age: 25,
34 | },
35 | },
36 | }
37 | var badEmployees = []employee{
38 | {
39 | Name: "John",
40 | Age: 0,
41 | },
42 | {
43 | Name: "",
44 | Age: 25,
45 | },
46 | }
47 |
48 | func TestParsers(t *testing.T) {
49 | t.Run("JSONParser", func(t *testing.T) {
50 | jsonParser := jsonparser.NewJSONParserGeneric[company](false)
51 | input, err := json.Marshal(testCo)
52 | assert.NoError(t, err)
53 |
54 | actual, err := jsonParser.Parse(context.Background(), input)
55 | assert.NoError(t, err)
56 | assert.Equal(t, testCo, actual)
57 |
58 | // test failure
59 | employeeParser := jsonparser.NewJSONParserGeneric[employee](true)
60 | input2, err := json.Marshal(badEmployees)
61 | assert.NoError(t, err)
62 | _, err = employeeParser.Parse(context.Background(), input2)
63 | assert.Error(t, err)
64 | })
65 |
66 | t.Run("testbenchmark", func(t *testing.T) {
67 | jsonParser := jsonparser.NewJSONParserGeneric[company](true)
68 | input, err := json.Marshal(testCo)
69 | assert.NoError(t, err)
70 | actual, err := jsonParser.Parse(context.Background(), input)
71 | assert.NoError(t, err)
72 | assert.Equal(t, testCo.Name, actual.Name)
73 | })
74 | }
75 |
76 | func BenchmarkParser(b *testing.B) {
77 | b.Run("JSONParser-NoValidate", func(b *testing.B) {
78 | jsonParser := jsonparser.NewJSONParserGeneric[company](false)
79 | input, err := json.Marshal(testCo)
80 | assert.NoError(b, err)
81 | b.ResetTimer()
82 | for i := 0; i < b.N; i++ {
83 | actual, err := jsonParser.Parse(context.Background(), input)
84 | assert.NoError(b, err)
85 | assert.Equal(b, testCo, actual)
86 | }
87 | })
88 | b.Run("JSONParser-Validate", func(b *testing.B) {
89 | jsonParser := jsonparser.NewJSONParserGeneric[company](true)
90 | input, err := json.Marshal(testCo)
91 | assert.NoError(b, err)
92 | b.ResetTimer()
93 | for i := 0; i < b.N; i++ {
94 | actual, err := jsonParser.Parse(context.Background(), input)
95 | assert.NoError(b, err)
96 | assert.Equal(b, testCo, actual)
97 | }
98 | })
99 | }
100 |
--------------------------------------------------------------------------------
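Outside of tests, the parser is used the same way; a small self-contained sketch (the struct and its tags are illustrative):

```go
package main

import (
	"context"
	"fmt"

	"github.com/stillmatic/gollum/packages/jsonparser"
)

type person struct {
	Name string `json:"name" jsonschema:"required,minLength=1"`
	Age  int    `json:"age" jsonschema:"minimum=18"`
}

func main() {
	// true enables JSON Schema validation of the parsed value
	parser := jsonparser.NewJSONParserGeneric[person](true)
	p, err := parser.Parse(context.Background(), []byte(`{"name": "Ada", "age": 36}`))
	if err != nil {
		panic(err)
	}
	fmt.Println(p.Name, p.Age)
}
```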
/packages/agents/calcagent.go:
--------------------------------------------------------------------------------
1 | package agents
2 |
3 | import (
4 | "context"
5 | "github.com/stillmatic/gollum/packages/dispatch"
6 | "github.com/stillmatic/gollum/packages/jsonparser"
7 | "github.com/stillmatic/gollum/packages/tools"
8 | "strconv"
9 |
10 | "github.com/pkg/errors"
11 | "github.com/sashabaranov/go-openai"
12 | "github.com/stillmatic/gollum"
13 | )
14 |
15 | type CalcAgent struct {
16 | tool tools.CalculatorTool
17 | env map[string]interface{}
18 | llm gollum.ChatCompleter
19 | functionInput openai.FunctionDefinition
20 | parser jsonparser.Parser[tools.CalculatorInput]
21 | }
22 |
23 | func NewCalcAgent(llm gollum.ChatCompleter) *CalcAgent {
24 | // might as well pre-compute it
25 | fi := dispatch.StructToJsonSchemaGeneric[tools.CalculatorInput]("calculator", "evaluate mathematical expressions")
26 | parser := jsonparser.NewJSONParserGeneric[tools.CalculatorInput](true)
27 | return &CalcAgent{
28 | tool: tools.CalculatorTool{},
29 | env: make(map[string]interface{}),
30 | llm: llm,
31 | functionInput: openai.FunctionDefinition(fi),
32 | parser: parser,
33 | }
34 | }
35 |
36 | type CalcAgentInput struct {
37 | Content string `json:"content" jsonschema:"required" jsonschema_description:"Natural language input to the calculator"`
38 | }
39 |
40 | func (c *CalcAgent) Name() string {
41 | return "calcagent"
42 | }
43 |
44 | func (c *CalcAgent) Description() string {
45 | return "convert natural language and evaluate mathematical expressions"
46 | }
47 |
48 | func (c *CalcAgent) Run(ctx context.Context, input interface{}) (interface{}, error) {
49 | cinput, ok := input.(CalcAgentInput)
50 | if !ok {
51 | return "", errors.New("invalid input")
52 | }
53 | // call LLM to convert natural language to mathematical expression
54 | resp, err := c.llm.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
55 | Model: openai.GPT3Dot5Turbo0613,
56 | Messages: []openai.ChatCompletionMessage{
57 | {
58 | Role: openai.ChatMessageRoleSystem,
59 | Content: "Convert the user's natural language input to a mathematical expression and input to a calculator function. Do not use prior knowledge.",
60 | },
61 | {
62 | Role: openai.ChatMessageRoleUser,
63 | Content: cinput.Content,
64 | },
65 | },
66 | MaxTokens: 128,
67 | Tools: []openai.Tool{{
68 | Type: "function",
69 | Function: &c.functionInput,
70 | }},
71 | 		ToolChoice: openai.ToolChoice{Type: "function", Function: openai.ToolFunction{Name: "calculator"}},
72 | })
73 | if err != nil {
74 | return "", errors.Wrap(err, "couldn't call the LLM")
75 | }
76 | // parse response
77 | parsed, err := c.parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
78 | if err != nil {
79 | return "", errors.Wrap(err, "couldn't parse response")
80 | }
81 | output, err := c.tool.Run(ctx, parsed)
82 | if err != nil {
83 | return "", errors.Wrap(err, "couldn't run expression")
84 | }
85 | switch t := output.(type) {
86 | case string:
87 | return t, nil
88 | case int:
89 | return strconv.Itoa(t), nil
90 | case float64:
91 | return strconv.FormatFloat(t, 'f', -1, 64), nil
92 | default:
93 | return "", errors.New("invalid output")
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
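A sketch of driving the agent with a real client; it assumes, as the dispatch integration test elsewhere in this repo does, that `*openai.Client` satisfies `gollum.ChatCompleter`:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/sashabaranov/go-openai"
	"github.com/stillmatic/gollum/packages/agents"
)

func main() {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	agent := agents.NewCalcAgent(client)
	out, err := agent.Run(context.Background(), agents.CalcAgentInput{
		Content: "what is twelve times seven?",
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // expected to print "84"
}
```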
/packages/llm/providers/cached/cached_embedder.go:
--------------------------------------------------------------------------------
1 | package cached
2 |
3 | import (
4 | "context"
5 | "crypto/sha256"
6 | "fmt"
7 | "hash"
8 | "log"
9 |
10 | "github.com/stillmatic/gollum/packages/llm"
11 | "github.com/stillmatic/gollum/packages/llm/cache"
12 | "github.com/stillmatic/gollum/packages/llm/providers/cached/sqlitecache"
13 | )
14 |
15 | // CachedEmbedder implements the llm.Embedder interface with caching
16 | type CachedEmbedder struct {
17 | underlying llm.Embedder
18 | cache cache.Cache
19 | hasher hash.Hash
20 | }
21 |
22 | // NewLocalCachedEmbedder creates a new CachedEmbedder with a local SQLite cache
23 | func NewLocalCachedEmbedder(underlying llm.Embedder, dbPath string) (*CachedEmbedder, error) {
24 | cache, err := sqlitecache.NewSQLiteCache(dbPath)
25 | if err != nil {
26 | return nil, fmt.Errorf("failed to create cache: %w", err)
27 | }
28 |
29 | return &CachedEmbedder{
30 | underlying: underlying,
31 | cache: cache,
32 | hasher: sha256.New(),
33 | }, nil
34 | }
35 |
36 | func (ce *CachedEmbedder) GenerateEmbedding(ctx context.Context, req llm.EmbedRequest) (*llm.EmbeddingResponse, error) {
37 | cachedEmbeddings := make([]llm.Embedding, 0, len(req.Input))
38 | uncachedIndices := make([]int, 0)
39 | uncachedInputs := make([]string, 0)
40 |
41 | // Check cache for each input string
42 | for i, input := range req.Input {
43 | embedding, err := ce.cache.GetEmbedding(ctx, req.ModelConfig.ModelName, input)
44 | if err == nil {
45 | cachedEmbeddings = append(cachedEmbeddings, llm.Embedding{Values: embedding})
46 | } else {
47 | uncachedIndices = append(uncachedIndices, i)
48 | uncachedInputs = append(uncachedInputs, input)
49 | }
50 | }
51 |
52 | // If all embeddings were cached, return immediately
53 | if len(uncachedInputs) == 0 {
54 | return &llm.EmbeddingResponse{
55 | Data: cachedEmbeddings,
56 | }, nil
57 | }
58 |
59 | // Generate embeddings for uncached inputs
60 | uncachedReq := llm.EmbedRequest{
61 | ModelConfig: req.ModelConfig,
62 | Input: uncachedInputs,
63 | }
64 | uncachedResponse, err := ce.underlying.GenerateEmbedding(ctx, uncachedReq)
65 | if err != nil {
66 | return nil, err
67 | }
68 |
69 | // Cache the new embeddings
70 | for i, embedding := range uncachedResponse.Data {
71 | if err := ce.cache.SetEmbedding(ctx, req.ModelConfig.ModelName, uncachedInputs[i], embedding.Values); err != nil {
72 | log.Printf("Failed to cache embedding: %v", err)
73 | }
74 | }
75 |
76 | // Merge cached and new embeddings
77 | finalEmbeddings := make([]llm.Embedding, len(req.Input))
78 | cachedIndex, uncachedIndex := 0, 0
79 | for i := range req.Input {
80 | if contains(uncachedIndices, i) {
81 | finalEmbeddings[i] = uncachedResponse.Data[uncachedIndex]
82 | uncachedIndex++
83 | } else {
84 | finalEmbeddings[i] = cachedEmbeddings[cachedIndex]
85 | cachedIndex++
86 | }
87 | }
88 |
89 | return &llm.EmbeddingResponse{
90 | Data: finalEmbeddings,
91 | }, nil
92 | }
93 | func (ce *CachedEmbedder) Close() error {
94 | return ce.cache.Close()
95 | }
96 |
97 | func (ce *CachedEmbedder) GetCacheStats() cache.CacheStats {
98 | return ce.cache.GetStats()
99 | }
100 |
101 | func contains(slice []int, val int) bool {
102 | for _, item := range slice {
103 | if item == val {
104 | return true
105 | }
106 | }
107 | return false
108 | }
109 |
--------------------------------------------------------------------------------
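A sketch of wrapping an existing embedder with the cache; the OpenAI provider and model name below are assumptions used only for illustration:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/stillmatic/gollum/packages/llm"
	"github.com/stillmatic/gollum/packages/llm/providers/cached"
	openaiprovider "github.com/stillmatic/gollum/packages/llm/providers/openai"
)

func main() {
	base := openaiprovider.NewOpenAIProvider(os.Getenv("OPENAI_API_KEY"))
	// embeddings persist in a local SQLite file between runs
	emb, err := cached.NewLocalCachedEmbedder(base, "embeddings.db")
	if err != nil {
		panic(err)
	}
	defer emb.Close()

	_, err = emb.GenerateEmbedding(context.Background(), llm.EmbedRequest{
		Input:       []string{"hello world"},
		ModelConfig: llm.ModelConfig{ModelName: "text-embedding-3-small"}, // model name is an assumption
	})
	if err != nil {
		panic(err)
	}
	stats := emb.GetCacheStats()
	fmt.Println(stats.NumRequests, stats.NumCacheHits)
}
```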
/packages/llm/providers/mixedbread/mixedbread.go:
--------------------------------------------------------------------------------
1 | package mixedbread
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "fmt"
8 | "github.com/stillmatic/gollum/packages/llm"
9 | "io"
10 | "net/http"
11 | )
12 |
13 | const (
14 | apiURL = "https://api.mixedbread.ai/v1/embeddings"
15 | )
16 |
17 | type MixedbreadEmbedder struct {
18 | APIKey string
19 | }
20 |
21 | type mixedbreadRequest struct {
22 | Input interface{} `json:"input"`
23 | Model string `json:"model"`
24 | Prompt string `json:"prompt,omitempty"`
25 | Normalized *bool `json:"normalized,omitempty"`
26 | Dimensions *int `json:"dimensions,omitempty"`
27 | EncodingFormat string `json:"encoding_format,omitempty"`
28 | TruncationStrategy string `json:"truncation_strategy,omitempty"`
29 | }
30 |
31 | type mixedbreadResponse struct {
32 | Model string `json:"model"`
33 | Object string `json:"object"`
34 | Data []struct {
35 | Embedding interface{} `json:"embedding"`
36 | Index int `json:"index"`
37 | Object string `json:"object"`
38 | } `json:"data"`
39 | Usage struct {
40 | PromptTokens int `json:"prompt_tokens"`
41 | TotalTokens int `json:"total_tokens"`
42 | } `json:"usage"`
43 | Normalized bool `json:"normalized"`
44 | }
45 |
46 | func NewMixedbreadEmbedder(apiKey string) *MixedbreadEmbedder {
47 | return &MixedbreadEmbedder{APIKey: apiKey}
48 | }
49 |
50 | func ptr[T any](x T) *T {
51 | return &x
52 | }
53 |
54 | func (e *MixedbreadEmbedder) GenerateEmbedding(ctx context.Context, req llm.EmbedRequest) (*llm.EmbeddingResponse, error) {
55 | if len(req.Image) > 0 {
56 | return nil, fmt.Errorf("image embedding not supported by Mixedbread API")
57 | }
58 |
59 | mixedReq := mixedbreadRequest{
60 | Input: req.Input,
61 | Model: req.ModelConfig.ModelName,
62 | Normalized: ptr(true),
63 | }
64 | if req.Prompt != "" {
65 | mixedReq.Prompt = req.Prompt
66 | }
67 |
68 | if req.Dimensions != 0 {
69 | mixedReq.Dimensions = &req.Dimensions
70 | }
71 |
72 | jsonData, err := json.Marshal(mixedReq)
73 | if err != nil {
74 | return nil, fmt.Errorf("failed to marshal request: %w", err)
75 | }
76 |
77 | httpReq, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonData))
78 | if err != nil {
79 | return nil, fmt.Errorf("failed to create request: %w", err)
80 | }
81 |
82 | httpReq.Header.Set("Content-Type", "application/json")
83 | httpReq.Header.Set("Authorization", "Bearer "+e.APIKey)
84 |
85 | client := &http.Client{}
86 | resp, err := client.Do(httpReq)
87 | if err != nil {
88 | return nil, fmt.Errorf("failed to send request: %w", err)
89 | }
90 | defer resp.Body.Close()
91 |
92 | body, err := io.ReadAll(resp.Body)
93 | if err != nil {
94 | return nil, fmt.Errorf("failed to read response body: %w", err)
95 | }
96 |
97 | if resp.StatusCode != http.StatusOK {
98 | return nil, fmt.Errorf("API request failed with status code %d: %s", resp.StatusCode, string(body))
99 | }
100 |
101 | var mixedResp mixedbreadResponse
102 | err = json.Unmarshal(body, &mixedResp)
103 | if err != nil {
104 | return nil, fmt.Errorf("failed to unmarshal response: %w", err)
105 | }
106 |
107 | embeddings := make([]llm.Embedding, len(mixedResp.Data))
108 | for i, data := range mixedResp.Data {
109 | switch v := data.Embedding.(type) {
110 | case []interface{}:
111 | values := make([]float32, len(v))
112 | for j, val := range v {
113 | if f, ok := val.(float64); ok {
114 | values[j] = float32(f)
115 | }
116 | }
117 | embeddings[i] = llm.Embedding{Values: values}
118 | default:
119 | return nil, fmt.Errorf("unexpected embedding format")
120 | }
121 | }
122 |
123 | return &llm.EmbeddingResponse{Data: embeddings}, nil
124 | }
125 |
--------------------------------------------------------------------------------
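Usage mirrors the Voyage embedder; the main addition is the optional `Dimensions` field. A short sketch, with the model name assumed:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/stillmatic/gollum/packages/llm"
	"github.com/stillmatic/gollum/packages/llm/providers/mixedbread"
)

func main() {
	emb := mixedbread.NewMixedbreadEmbedder(os.Getenv("MIXEDBREAD_API_KEY"))
	resp, err := emb.GenerateEmbedding(context.Background(), llm.EmbedRequest{
		Input:       []string{"hello world"},
		Dimensions:  512, // forwarded as the optional "dimensions" field
		ModelConfig: llm.ModelConfig{ModelName: "mxbai-embed-large-v1"}, // assumed model name
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(len(resp.Data[0].Values))
}
```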
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/stillmatic/gollum
2 |
3 | go 1.23
4 |
5 | require (
6 | cloud.google.com/go/vertexai v0.13.0
7 | github.com/antonmedv/expr v1.15.3
8 | github.com/cespare/xxhash/v2 v2.3.0
9 | github.com/chewxy/math32 v1.10.1
10 | github.com/google/generative-ai-go v0.19.0
11 | github.com/google/uuid v1.6.0
12 | github.com/invopop/jsonschema v0.12.0
13 | github.com/joho/godotenv v1.5.1
14 | github.com/klauspost/compress v1.17.2
15 | github.com/liushuangls/go-anthropic/v2 v2.13.0
16 | github.com/pkg/errors v0.9.1
17 | github.com/santhosh-tekuri/jsonschema/v5 v5.3.1
18 | github.com/sashabaranov/go-openai v1.36.1
19 | github.com/stretchr/testify v1.9.0
20 | github.com/viterin/vek v0.4.2
21 | go.uber.org/mock v0.3.0
22 | gocloud.dev v0.38.0
23 | google.golang.org/api v0.215.0
24 | modernc.org/sqlite v1.32.0
25 | )
26 |
27 | require (
28 | cloud.google.com/go v0.115.1 // indirect
29 | cloud.google.com/go/ai v0.8.0 // indirect
30 | cloud.google.com/go/aiplatform v1.68.0 // indirect
31 | cloud.google.com/go/auth v0.13.0 // indirect
32 | cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
33 | cloud.google.com/go/compute/metadata v0.6.0 // indirect
34 | cloud.google.com/go/iam v1.1.12 // indirect
35 | cloud.google.com/go/longrunning v0.5.11 // indirect
36 | github.com/bahlo/generic-list-go v0.2.0 // indirect
37 | github.com/buger/jsonparser v1.1.1 // indirect
38 | github.com/davecgh/go-spew v1.1.1 // indirect
39 | github.com/dustin/go-humanize v1.0.1 // indirect
40 | github.com/felixge/httpsnoop v1.0.4 // indirect
41 | github.com/go-logr/logr v1.4.2 // indirect
42 | github.com/go-logr/stdr v1.2.2 // indirect
43 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
44 | github.com/google/s2a-go v0.1.8 // indirect
45 | github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
46 | github.com/googleapis/gax-go/v2 v2.14.1 // indirect
47 | github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
48 | github.com/kr/text v0.2.0 // indirect
49 | github.com/mailru/easyjson v0.7.7 // indirect
50 | github.com/mattn/go-isatty v0.0.20 // indirect
51 | github.com/ncruces/go-strftime v0.1.9 // indirect
52 | github.com/pmezard/go-difflib v1.0.0 // indirect
53 | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
54 | github.com/viterin/partial v1.1.0 // indirect
55 | github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
56 | go.opencensus.io v0.24.0 // indirect
57 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
58 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
59 | go.opentelemetry.io/otel v1.29.0 // indirect
60 | go.opentelemetry.io/otel/metric v1.29.0 // indirect
61 | go.opentelemetry.io/otel/trace v1.29.0 // indirect
62 | golang.org/x/crypto v0.31.0 // indirect
63 | golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect
64 | golang.org/x/net v0.33.0 // indirect
65 | golang.org/x/oauth2 v0.24.0 // indirect
66 | golang.org/x/sync v0.10.0 // indirect
67 | golang.org/x/sys v0.28.0 // indirect
68 | golang.org/x/text v0.21.0 // indirect
69 | golang.org/x/time v0.8.0 // indirect
70 | golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect
71 | google.golang.org/genproto v0.0.0-20240814211410-ddb44dafa142 // indirect
72 | google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 // indirect
73 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect
74 | google.golang.org/grpc v1.67.1 // indirect
75 | google.golang.org/protobuf v1.36.1 // indirect
76 | gopkg.in/yaml.v3 v3.0.1 // indirect
77 | modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect
78 | modernc.org/libc v1.55.3 // indirect
79 | modernc.org/mathutil v1.6.0 // indirect
80 | modernc.org/memory v1.8.0 // indirect
81 | modernc.org/strutil v1.2.0 // indirect
82 | modernc.org/token v1.1.0 // indirect
83 | )
84 |
--------------------------------------------------------------------------------
/packages/llm/providers/cached/sqlitecache/sqlite.go:
--------------------------------------------------------------------------------
1 | package sqlitecache
2 |
3 | import (
4 | "context"
5 | "crypto/sha256"
6 | "database/sql"
7 | "encoding/json"
8 | "fmt"
9 | "hash"
10 |
11 | "github.com/stillmatic/gollum/packages/llm"
12 | "github.com/stillmatic/gollum/packages/llm/cache"
13 | _ "modernc.org/sqlite"
14 | )
15 |
16 | type SQLiteCache struct {
17 | db *sql.DB
18 | hasher hash.Hash
19 | numRequests int
20 | numCacheHits int
21 | }
22 |
23 | func NewSQLiteCache(dbPath string) (*SQLiteCache, error) {
24 | db, err := sql.Open("sqlite", dbPath)
25 | if err != nil {
26 | return nil, fmt.Errorf("failed to open database: %w", err)
27 | }
28 |
29 | if err := initDB(db); err != nil {
30 | return nil, fmt.Errorf("failed to initialize database: %w", err)
31 | }
32 |
33 | return &SQLiteCache{
34 | db: db,
35 | hasher: sha256.New(),
36 | }, nil
37 | }
38 |
39 | func (c *SQLiteCache) GetResponse(ctx context.Context, req llm.InferRequest) (string, error) {
40 | c.numRequests++
41 | requestJSON, err := json.Marshal(req)
42 | if err != nil {
43 | return "", err
44 | }
45 |
46 | 	hashedRequest := sha256.Sum256(requestJSON)
47 | 
48 | 	var response string
49 | 	err = c.db.QueryRowContext(ctx, "SELECT response FROM response_cache WHERE request = ?", hashedRequest[:]).Scan(&response)
50 | if err != nil {
51 | return "", err
52 | }
53 | c.numCacheHits++
54 |
55 | return response, nil
56 | }
57 |
58 | func (c *SQLiteCache) SetResponse(ctx context.Context, req llm.InferRequest, response string) error {
59 | requestJSON, err := json.Marshal(req)
60 | if err != nil {
61 | return err
62 | }
63 | 	hashedRequest := sha256.Sum256(requestJSON)
64 | 
65 | 	_, err = c.db.ExecContext(ctx, "INSERT INTO response_cache (request, response) VALUES (?, ?)", hashedRequest[:], response)
66 | return err
67 | }
68 |
69 | func (c *SQLiteCache) GetEmbedding(ctx context.Context, modelConfig string, input string) ([]float32, error) {
70 | c.numRequests++
71 | var embeddingBlob []byte
72 | err := c.db.QueryRowContext(ctx, "SELECT embedding FROM embedding_cache WHERE model_config = ? AND input_string = ?", modelConfig, input).Scan(&embeddingBlob)
73 | if err != nil {
74 | return nil, err
75 | }
76 | c.numCacheHits++
77 |
78 | var embedding []float32
79 | err = json.Unmarshal(embeddingBlob, &embedding)
80 | if err != nil {
81 | return nil, err
82 | }
83 |
84 | return embedding, nil
85 | }
86 |
87 | func (c *SQLiteCache) SetEmbedding(ctx context.Context, modelConfig string, input string, embedding []float32) error {
88 | embeddingBlob, err := json.Marshal(embedding)
89 | if err != nil {
90 | return err
91 | }
92 |
93 | _, err = c.db.ExecContext(ctx, "INSERT OR REPLACE INTO embedding_cache (model_config, input_string, embedding) VALUES (?, ?, ?)", modelConfig, input, embeddingBlob)
94 | return err
95 | }
96 |
97 | func (c *SQLiteCache) Close() error {
98 | return c.db.Close()
99 | }
100 |
101 | func (c *SQLiteCache) GetStats() cache.CacheStats {
102 | return cache.CacheStats{
103 | NumRequests: c.numRequests,
104 | NumCacheHits: c.numCacheHits,
105 | }
106 | }
107 |
108 | func initDB(db *sql.DB) error {
109 | _, err := db.Exec(`
110 | CREATE TABLE IF NOT EXISTS response_cache (
111 | id INTEGER PRIMARY KEY AUTOINCREMENT,
112 | request BLOB,
113 | response TEXT
114 | );
115 | CREATE TABLE IF NOT EXISTS embedding_cache (
116 | id INTEGER PRIMARY KEY AUTOINCREMENT,
117 | model_config TEXT,
118 | input_string TEXT,
119 | embedding BLOB,
120 | UNIQUE(model_config, input_string)
121 | );
122 | `)
123 | if err != nil {
124 | return err
125 | }
126 |
127 | // Set to WAL mode for better performance
128 | _, err = db.Exec("PRAGMA journal_mode=WAL;")
129 | return err
130 | }
131 |
132 | // Ensure SQLiteCache implements the Cache interface
133 | var _ cache.Cache = (*SQLiteCache)(nil)
134 |
--------------------------------------------------------------------------------
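The cache can also be exercised directly; a minimal sketch using an in-memory database:

```go
package main

import (
	"context"
	"fmt"

	"github.com/stillmatic/gollum/packages/llm/providers/cached/sqlitecache"
)

func main() {
	ctx := context.Background()
	c, err := sqlitecache.NewSQLiteCache(":memory:")
	if err != nil {
		panic(err)
	}
	defer c.Close()

	// embeddings are keyed by (model name, input string) and stored as JSON blobs
	if err := c.SetEmbedding(ctx, "fake_model", "hello", []float32{0.1, 0.2, 0.3}); err != nil {
		panic(err)
	}
	vals, err := c.GetEmbedding(ctx, "fake_model", "hello")
	if err != nil {
		panic(err)
	}
	fmt.Println(vals, c.GetStats().NumCacheHits)
}
```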
/packages/llm/internal/mocks/llm.go:
--------------------------------------------------------------------------------
1 | // Code generated by MockGen. DO NOT EDIT.
2 | // Source: llm.go
3 |
4 | // Package mock_llm is a generated GoMock package.
5 | package mock_llm
6 |
7 | import (
8 | context "context"
9 | reflect "reflect"
10 |
11 | llm "github.com/stillmatic/gollum/packages/llm"
12 | gomock "go.uber.org/mock/gomock"
13 | )
14 |
15 | // MockResponder is a mock of Responder interface.
16 | type MockResponder struct {
17 | ctrl *gomock.Controller
18 | recorder *MockResponderMockRecorder
19 | }
20 |
21 | // MockResponderMockRecorder is the mock recorder for MockResponder.
22 | type MockResponderMockRecorder struct {
23 | mock *MockResponder
24 | }
25 |
26 | // NewMockResponder creates a new mock instance.
27 | func NewMockResponder(ctrl *gomock.Controller) *MockResponder {
28 | mock := &MockResponder{ctrl: ctrl}
29 | mock.recorder = &MockResponderMockRecorder{mock}
30 | return mock
31 | }
32 |
33 | // EXPECT returns an object that allows the caller to indicate expected use.
34 | func (m *MockResponder) EXPECT() *MockResponderMockRecorder {
35 | return m.recorder
36 | }
37 |
38 | // GenerateResponse mocks base method.
39 | func (m *MockResponder) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) {
40 | m.ctrl.T.Helper()
41 | ret := m.ctrl.Call(m, "GenerateResponse", ctx, req)
42 | ret0, _ := ret[0].(string)
43 | ret1, _ := ret[1].(error)
44 | return ret0, ret1
45 | }
46 |
47 | // GenerateResponse indicates an expected call of GenerateResponse.
48 | func (mr *MockResponderMockRecorder) GenerateResponse(ctx, req interface{}) *gomock.Call {
49 | mr.mock.ctrl.T.Helper()
50 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateResponse", reflect.TypeOf((*MockResponder)(nil).GenerateResponse), ctx, req)
51 | }
52 |
53 | // GenerateResponseAsync mocks base method.
54 | func (m *MockResponder) GenerateResponseAsync(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) {
55 | m.ctrl.T.Helper()
56 | ret := m.ctrl.Call(m, "GenerateResponseAsync", ctx, req)
57 | ret0, _ := ret[0].(<-chan llm.StreamDelta)
58 | ret1, _ := ret[1].(error)
59 | return ret0, ret1
60 | }
61 |
62 | // GenerateResponseAsync indicates an expected call of GenerateResponseAsync.
63 | func (mr *MockResponderMockRecorder) GenerateResponseAsync(ctx, req interface{}) *gomock.Call {
64 | mr.mock.ctrl.T.Helper()
65 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateResponseAsync", reflect.TypeOf((*MockResponder)(nil).GenerateResponseAsync), ctx, req)
66 | }
67 |
68 | // MockEmbedder is a mock of Embedder interface.
69 | type MockEmbedder struct {
70 | ctrl *gomock.Controller
71 | recorder *MockEmbedderMockRecorder
72 | }
73 |
74 | // MockEmbedderMockRecorder is the mock recorder for MockEmbedder.
75 | type MockEmbedderMockRecorder struct {
76 | mock *MockEmbedder
77 | }
78 |
79 | // NewMockEmbedder creates a new mock instance.
80 | func NewMockEmbedder(ctrl *gomock.Controller) *MockEmbedder {
81 | mock := &MockEmbedder{ctrl: ctrl}
82 | mock.recorder = &MockEmbedderMockRecorder{mock}
83 | return mock
84 | }
85 |
86 | // EXPECT returns an object that allows the caller to indicate expected use.
87 | func (m *MockEmbedder) EXPECT() *MockEmbedderMockRecorder {
88 | return m.recorder
89 | }
90 |
91 | // GenerateEmbedding mocks base method.
92 | func (m *MockEmbedder) GenerateEmbedding(ctx context.Context, req llm.EmbedRequest) (*llm.EmbeddingResponse, error) {
93 | m.ctrl.T.Helper()
94 | ret := m.ctrl.Call(m, "GenerateEmbedding", ctx, req)
95 | ret0, _ := ret[0].(*llm.EmbeddingResponse)
96 | ret1, _ := ret[1].(error)
97 | return ret0, ret1
98 | }
99 |
100 | // GenerateEmbedding indicates an expected call of GenerateEmbedding.
101 | func (mr *MockEmbedderMockRecorder) GenerateEmbedding(ctx, req interface{}) *gomock.Call {
102 | mr.mock.ctrl.T.Helper()
103 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateEmbedding", reflect.TypeOf((*MockEmbedder)(nil).GenerateEmbedding), ctx, req)
104 | }
105 |
--------------------------------------------------------------------------------
/packages/vectorstore/vectorstore_memory.go:
--------------------------------------------------------------------------------
1 | package vectorstore
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "strings"
7 |
8 | "github.com/pkg/errors"
9 | "github.com/sashabaranov/go-openai"
10 | "github.com/stillmatic/gollum"
11 | "github.com/viterin/vek/vek32"
12 | "gocloud.dev/blob"
13 | )
14 |
15 | // MemoryVectorStore embeds documents on insert and stores them in memory
16 | type MemoryVectorStore struct {
17 | Documents []gollum.Document
18 | LLM gollum.Embedder
19 | }
20 |
21 | func NewMemoryVectorStore(llm gollum.Embedder) *MemoryVectorStore {
22 | return &MemoryVectorStore{
23 | Documents: make([]gollum.Document, 0),
24 | LLM: llm,
25 | }
26 | }
27 |
28 | func (m *MemoryVectorStore) Insert(ctx context.Context, d gollum.Document) error {
29 | // replace newlines with spaces and strip whitespace, per OpenAI's recommendation
30 | if d.Embedding == nil {
31 | cleanText := strings.ReplaceAll(d.Content, "\n", " ")
32 | cleanText = strings.TrimSpace(cleanText)
33 |
34 | embedding, err := m.LLM.CreateEmbeddings(ctx, openai.EmbeddingRequest{
35 | Input: []string{cleanText},
36 | // TODO: make this configurable -- may require forking the base library, this expects an enum
37 | Model: openai.AdaEmbeddingV2,
38 | })
39 | if err != nil {
40 | return errors.Wrap(err, "failed to create embedding")
41 | }
42 | d.Embedding = embedding.Data[0].Embedding
43 | }
44 |
45 | m.Documents = append(m.Documents, d)
46 | return nil
47 | }
48 |
49 | func (m *MemoryVectorStore) Persist(ctx context.Context, bucket *blob.Bucket, path string) error {
50 | // save documents to disk
51 | data, err := json.Marshal(m.Documents)
52 | if err != nil {
53 | return errors.Wrap(err, "failed to marshal documents to JSON")
54 | }
55 | err = bucket.WriteAll(ctx, path, data, nil)
56 | if err != nil {
57 | return errors.Wrap(err, "failed to write documents to file")
58 | }
59 | return nil
60 | }
61 |
62 | func NewMemoryVectorStoreFromDisk(ctx context.Context, bucket *blob.Bucket, path string, llm gollum.Embedder) (*MemoryVectorStore, error) {
63 | data, err := bucket.ReadAll(ctx, path)
64 | if err != nil {
65 | return nil, errors.Wrap(err, "failed to read file")
66 | }
67 | var documents []gollum.Document
68 | err = json.Unmarshal(data, &documents)
69 | if err != nil {
70 | return nil, errors.Wrap(err, "failed to unmarshal JSON")
71 | }
72 | return &MemoryVectorStore{
73 | Documents: documents,
74 | LLM: llm,
75 | }, nil
76 | }
77 |
78 | func (m *MemoryVectorStore) Query(ctx context.Context, qb QueryRequest) ([]*gollum.Document, error) {
79 | if len(m.Documents) == 0 {
80 | return nil, errors.New("no documents in store")
81 | }
82 | if len(qb.EmbeddingStrings) > 0 {
83 | // concatenate strings and set query
84 | qb.Query = strings.Join(qb.EmbeddingStrings, " ")
85 | }
86 | if len(qb.EmbeddingFloats) == 0 {
87 | // create embedding
88 | embedding, err := m.LLM.CreateEmbeddings(ctx, openai.EmbeddingRequest{
89 | Input: []string{qb.Query},
90 | // TODO: make this configurable
91 | Model: openai.AdaEmbeddingV2,
92 | })
93 | if err != nil {
94 | return nil, errors.Wrap(err, "failed to create embedding")
95 | }
96 | qb.EmbeddingFloats = embedding.Data[0].Embedding
97 | }
98 | scores := Heap{}
99 | k := qb.K
100 | scores.Init(k)
101 |
102 | for _, doc := range m.Documents {
103 | score := vek32.CosineSimilarity(qb.EmbeddingFloats, doc.Embedding)
104 | doc := doc
105 | ns := NodeSimilarity{
106 | Document: &doc,
107 | Similarity: score,
108 | }
109 | // maintain a max-heap of size k
110 | scores.Push(ns)
111 | if scores.Len() > k {
112 | scores.Pop()
113 | }
114 | }
115 |
116 | result := make([]*gollum.Document, k)
117 | for i := 0; i < k; i++ {
118 | ns := scores.Pop()
119 | doc := ns.Document
120 | result[k-i-1] = doc
121 | }
122 | return result, nil
123 | }
124 |
125 | // RetrieveAll returns all documents
126 | func (m *MemoryVectorStore) RetrieveAll(ctx context.Context) ([]gollum.Document, error) {
127 | return m.Documents, nil
128 | }
129 |
--------------------------------------------------------------------------------
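A sketch of the insert/query round trip; it assumes `*openai.Client` satisfies `gollum.Embedder` (it exposes the `CreateEmbeddings` method the store calls):

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/sashabaranov/go-openai"
	"github.com/stillmatic/gollum"
	"github.com/stillmatic/gollum/packages/vectorstore"
)

func main() {
	ctx := context.Background()
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	vs := vectorstore.NewMemoryVectorStore(client)

	for _, s := range []string{"the capital of France is Paris", "gophers live underground"} {
		if err := vs.Insert(ctx, gollum.NewDocumentFromString(s)); err != nil {
			panic(err)
		}
	}

	docs, err := vs.Query(ctx, vectorstore.QueryRequest{Query: "French capital", K: 1})
	if err != nil {
		panic(err)
	}
	fmt.Println(docs[0].Content)
}
```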
/packages/llm/providers/cached/cached_provider_test.go:
--------------------------------------------------------------------------------
1 | package cached_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/stillmatic/gollum/packages/llm"
8 | mock_llm "github.com/stillmatic/gollum/packages/llm/internal/mocks"
9 | "github.com/stillmatic/gollum/packages/llm/providers/cached"
10 | "github.com/stretchr/testify/assert"
11 | "go.uber.org/mock/gomock"
12 | )
13 |
14 | func TestCachedProvider(t *testing.T) {
15 | ctrl := gomock.NewController(t)
16 | t.Run("responder", func(t *testing.T) {
17 | mockProvider := mock_llm.NewMockResponder(ctrl)
18 | ctx := context.Background()
19 | req := llm.InferRequest{
20 | Messages: []llm.InferMessage{
21 | {Content: "hello world",
22 | Role: "user",
23 | },
24 | },
25 | ModelConfig: llm.ModelConfig{
26 | ModelName: "fake_model",
27 | ProviderType: llm.ProviderAnthropic,
28 | },
29 | }
30 |
31 | mockProvider.EXPECT().GenerateResponse(ctx, req).Return("hello user", nil)
32 |
33 | cachedProvider, err := cached.NewLocalCachedResponder(mockProvider, ":memory:")
34 | assert.NoError(t, err)
35 | resp, err := cachedProvider.GenerateResponse(ctx, req)
36 | assert.NoError(t, err)
37 | assert.Equal(t, "hello user", resp)
38 |
39 | cs := cachedProvider.GetCacheStats()
40 | assert.Equal(t, 1, cs.NumRequests)
41 | assert.Equal(t, 0, cs.NumCacheHits)
42 |
43 | resp, err = cachedProvider.GenerateResponse(ctx, req)
44 | assert.NoError(t, err)
45 | assert.Equal(t, "hello user", resp)
46 |
47 | cs = cachedProvider.GetCacheStats()
48 | assert.Equal(t, 2, cs.NumRequests)
49 | assert.Equal(t, 1, cs.NumCacheHits)
50 | })
51 |
52 | t.Run("embedder", func(t *testing.T) {
53 | mockProvider := mock_llm.NewMockEmbedder(ctrl)
54 | ctx := context.Background()
55 | req := llm.EmbedRequest{
56 | Input: []string{"abc"},
57 | ModelConfig: llm.ModelConfig{
58 | ModelName: "fake_model",
59 | ProviderType: llm.ProviderAnthropic,
60 | },
61 | }
62 |
63 | // we call the function twice but it returns the cached value second time, so
64 | // the provider should only be called once
65 | mockProvider.EXPECT().GenerateEmbedding(ctx, req).Return(&llm.EmbeddingResponse{
66 | Data: []llm.Embedding{{Values: []float32{1.0, 2.0, 3.0}}}}, nil).Times(1)
67 |
68 | cachedProvider, err := cached.NewLocalCachedEmbedder(mockProvider, ":memory:")
69 | assert.NoError(t, err)
70 | resp, err := cachedProvider.GenerateEmbedding(ctx, req)
71 | assert.NoError(t, err)
72 | assert.Equal(t, []float32{1.0, 2.0, 3.0}, resp.Data[0].Values)
73 |
74 | cs := cachedProvider.GetCacheStats()
75 | assert.Equal(t, 1, cs.NumRequests)
76 | assert.Equal(t, 0, cs.NumCacheHits)
77 |
78 | resp, err = cachedProvider.GenerateEmbedding(ctx, req)
79 | assert.NoError(t, err)
80 | assert.Equal(t, []float32{1.0, 2.0, 3.0}, resp.Data[0].Values)
81 |
82 | cs = cachedProvider.GetCacheStats()
83 | assert.Equal(t, 2, cs.NumRequests)
84 | assert.Equal(t, 1, cs.NumCacheHits)
85 |
86 | req = llm.EmbedRequest{
87 | Input: []string{"abc", "def"},
88 | ModelConfig: llm.ModelConfig{
89 | ModelName: "fake_model",
90 | ProviderType: llm.ProviderAnthropic,
91 | },
92 | }
93 | // this should be the request to the provider since we don't have the second embedding cached
94 | cachedReq := llm.EmbedRequest{
95 | Input: []string{"def"},
96 | ModelConfig: llm.ModelConfig{
97 | ModelName: "fake_model",
98 | ProviderType: llm.ProviderAnthropic,
99 | },
100 | }
101 | // and the provider only returns one
102 | mockProvider.EXPECT().GenerateEmbedding(ctx, cachedReq).Return(&llm.EmbeddingResponse{
103 | Data: []llm.Embedding{{Values: []float32{2.0, 3.0, 4.0}}}}, nil).Times(1)
104 |
105 | //but we should expect to get both embeddings back
106 | resp, err = cachedProvider.GenerateEmbedding(ctx, req)
107 | assert.NoError(t, err)
108 | assert.Equal(t, []float32{1.0, 2.0, 3.0}, resp.Data[0].Values)
109 | assert.Equal(t, []float32{2.0, 3.0, 4.0}, resp.Data[1].Values)
110 |
111 | cs = cachedProvider.GetCacheStats()
112 | assert.Equal(t, 4, cs.NumRequests)
113 | assert.Equal(t, 2, cs.NumCacheHits)
114 | })
115 | }
116 |
--------------------------------------------------------------------------------
/packages/dispatch/dispatch.go:
--------------------------------------------------------------------------------
1 | package dispatch
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/stillmatic/gollum"
7 | "github.com/stillmatic/gollum/packages/jsonparser"
8 | "strings"
9 | "text/template"
10 |
11 | "github.com/sashabaranov/go-openai"
12 | )
13 |
14 | type Dispatcher[T any] interface {
15 | // Prompt generates an object of type T from the given prompt.
16 | Prompt(ctx context.Context, prompt string) (T, error)
17 | // PromptTemplate generates an object of type T from a given template.
18 | // The prompt is then a template string that is rendered with the given values.
19 | PromptTemplate(ctx context.Context, template *template.Template, values interface{}) (T, error)
20 | }
21 |
22 | type DummyDispatcher[T any] struct{}
23 |
24 | func NewDummyDispatcher[T any]() *DummyDispatcher[T] {
25 | return &DummyDispatcher[T]{}
26 | }
27 |
28 | func (d *DummyDispatcher[T]) Prompt(ctx context.Context, prompt string) (T, error) {
29 | var t T
30 | return t, nil
31 | }
32 |
33 | func (d *DummyDispatcher[T]) PromptTemplate(ctx context.Context, template *template.Template, values interface{}) (T, error) {
34 | var t T
35 | var sb strings.Builder
36 | err := template.Execute(&sb, values)
37 | if err != nil {
38 | return t, fmt.Errorf("error executing template: %w", err)
39 | }
40 | return t, nil
41 | }
42 |
43 | type OpenAIDispatcherConfig struct {
44 | Model *string
45 | Temperature *float32
46 | MaxTokens *int
47 | }
48 |
49 | // OpenAIDispatcher dispatches to any OpenAI compatible model.
50 | // For any type T and prompt, it will generate and parse the response into T.
51 | type OpenAIDispatcher[T any] struct {
52 | *OpenAIDispatcherConfig
53 | completer gollum.ChatCompleter
54 | ti openai.Tool
55 | systemPrompt string
56 | parser jsonparser.Parser[T]
57 | }
58 |
59 | func NewOpenAIDispatcher[T any](name, description, systemPrompt string, completer gollum.ChatCompleter, cfg *OpenAIDispatcherConfig) *OpenAIDispatcher[T] {
60 | // note: name must not have spaces - valid json
61 | // we won't check here but the openai client will throw an error
62 | var t T
63 | fi := StructToJsonSchema(name, description, t)
64 | ti := FunctionInputToTool(fi)
65 | parser := jsonparser.NewJSONParserGeneric[T](true)
66 | return &OpenAIDispatcher[T]{
67 | OpenAIDispatcherConfig: cfg,
68 | completer: completer,
69 | ti: ti,
70 | parser: parser,
71 | systemPrompt: systemPrompt,
72 | }
73 | }
74 |
75 | func (d *OpenAIDispatcher[T]) Prompt(ctx context.Context, prompt string) (T, error) {
76 | var output T
77 | model := openai.GPT3Dot5Turbo1106
78 | temperature := float32(0.0)
79 | maxTokens := 512
80 | if d.OpenAIDispatcherConfig != nil {
81 | if d.Model != nil {
82 | model = *d.Model
83 | }
84 | if d.Temperature != nil {
85 | temperature = *d.Temperature
86 | }
87 | if d.MaxTokens != nil {
88 | maxTokens = *d.MaxTokens
89 | }
90 | }
91 |
92 | req := openai.ChatCompletionRequest{
93 | Model: model,
94 | Messages: []openai.ChatCompletionMessage{
95 | {
96 | Role: openai.ChatMessageRoleSystem,
97 | Content: d.systemPrompt,
98 | },
99 | {
100 | Role: openai.ChatMessageRoleUser,
101 | Content: prompt,
102 | },
103 | },
104 | Tools: []openai.Tool{d.ti},
105 | ToolChoice: openai.ToolChoice{
106 | Type: "function",
107 | Function: openai.ToolFunction{
108 | Name: d.ti.Function.Name,
109 | }},
110 | Temperature: temperature,
111 | MaxTokens: maxTokens,
112 | }
113 |
114 | resp, err := d.completer.CreateChatCompletion(ctx, req)
115 | if err != nil {
116 | return output, err
117 | }
118 |
119 | toolOutput := resp.Choices[0].Message.ToolCalls[0].Function.Arguments
120 | output, err = d.parser.Parse(ctx, []byte(toolOutput))
121 | if err != nil {
122 | return output, err
123 | }
124 |
125 | return output, nil
126 | }
127 |
128 | // PromptTemplate generates an object of type T from a given template.
129 | // This is mostly a convenience wrapper around Prompt.
130 | func (d *OpenAIDispatcher[T]) PromptTemplate(ctx context.Context, template *template.Template, values interface{}) (T, error) {
131 | var t T
132 | var sb strings.Builder
133 | err := template.Execute(&sb, values)
134 | if err != nil {
135 | return t, fmt.Errorf("error executing template: %w", err)
136 | }
137 | return d.Prompt(ctx, sb.String())
138 | }
139 |
--------------------------------------------------------------------------------
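A sketch of the typed dispatch flow: define a struct, hand it to the dispatcher, and get a parsed, validated value back. The struct, names, and prompts here are illustrative:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/sashabaranov/go-openai"
	"github.com/stillmatic/gollum/packages/dispatch"
)

type recipe struct {
	Name        string   `json:"name" jsonschema:"required"`
	Ingredients []string `json:"ingredients" jsonschema:"required"`
}

func main() {
	completer := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	d := dispatch.NewOpenAIDispatcher[recipe](
		"extract_recipe", "extract a recipe from text",
		"When prompted, use the tool.", completer, nil,
	)
	r, err := d.Prompt(context.Background(), "a simple pancake recipe")
	if err != nil {
		panic(err)
	}
	fmt.Println(r.Name, r.Ingredients)
}
```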
/packages/hyde/hyde_test.go:
--------------------------------------------------------------------------------
1 | package hyde_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/stillmatic/gollum/packages/hyde"
7 | "github.com/stillmatic/gollum/packages/vectorstore"
8 | "testing"
9 |
10 | "github.com/stillmatic/gollum"
11 | mock_gollum "github.com/stillmatic/gollum/internal/mocks"
12 | "github.com/stillmatic/gollum/internal/testutil"
13 | "github.com/stretchr/testify/assert"
14 | "go.uber.org/mock/gomock"
15 | )
16 |
17 | func TestHyde(t *testing.T) {
18 | ctrl := gomock.NewController(t)
19 | embedder := mock_gollum.NewMockEmbedder(ctrl)
20 | completer := mock_gollum.NewMockChatCompleter(ctrl)
21 | prompter := hyde.NewZeroShotPrompter(
22 | "Roleplay as a character. Write a short biographical answer to the question.\nQ: %s\nA:",
23 | )
24 | generator := hyde.NewLLMGenerator(completer)
25 | encoder := hyde.NewLLMEncoder(embedder)
26 | vs := vectorstore.NewMemoryVectorStore(embedder)
27 | for i := range make([]int, 10) {
28 | embedder.EXPECT().CreateEmbeddings(context.Background(), gomock.Any()).Return(testutil.GetRandomEmbeddingResponse(1, 1536), nil)
29 | vs.Insert(context.Background(), gollum.NewDocumentFromString(fmt.Sprintf("hey %d", i)))
30 | }
31 | assert.Equal(t, 10, len(vs.Documents))
32 | searcher := hyde.NewVectorSearcher(
33 | vs,
34 | )
35 |
36 | hyde := hyde.NewHyde(prompter, generator, encoder, searcher)
37 | t.Run("prompter", func(t *testing.T) {
38 | prompt := prompter.BuildPrompt(context.Background(), "What is your name?")
39 | assert.Equal(t, "Roleplay as a character. Write a short biographical answer to the question.\nQ: What is your name?\nA:", prompt)
40 | })
41 |
42 | t.Run("generator", func(t *testing.T) {
43 | ctx := context.Background()
44 | k := 10
45 | completer.EXPECT().CreateChatCompletion(ctx, gomock.Any()).Return(testutil.GetRandomChatCompletionResponse(k), nil)
46 | res, err := generator.Generate(ctx, "What is your name?", k)
47 | assert.NoError(t, err)
48 | assert.Equal(t, 10, len(res))
49 | })
50 |
51 | t.Run("encoder", func(t *testing.T) {
52 | ctx := context.Background()
53 | embedder.EXPECT().CreateEmbeddings(ctx, gomock.Any()).Return(testutil.GetRandomEmbeddingResponse(1, 1536), nil)
54 | res, err := encoder.Encode(ctx, "What is your name?")
55 | assert.NoError(t, err)
56 | assert.Equal(t, 1536, len(res))
57 |
58 | embedder.EXPECT().CreateEmbeddings(ctx, gomock.Any()).Return(testutil.GetRandomEmbeddingResponse(2, 1536), nil)
59 | res2, err := encoder.EncodeBatch(ctx, []string{"What is your name?", "What is your quest?"})
60 | assert.NoError(t, err)
61 | assert.Equal(t, 2, len(res2))
62 | assert.Equal(t, 1536, len(res2[0]))
63 | })
64 |
65 | t.Run("e2e", func(t *testing.T) {
66 | ctx := context.Background()
67 | completer.EXPECT().CreateChatCompletion(ctx, gomock.Any()).Return(testutil.GetRandomChatCompletionResponse(3), nil)
68 | embedder.EXPECT().CreateEmbeddings(ctx, gomock.Any()).Return(testutil.GetRandomEmbeddingResponse(4, 1536), nil)
69 | res, err := hyde.SearchEndToEnd(ctx, "What is your name?", 3)
70 | assert.NoError(t, err)
71 | assert.Equal(t, 3, len(res))
72 | })
73 | }
74 |
75 | func BenchmarkHyde(b *testing.B) {
76 | prompter := hyde.NewZeroShotPrompter(
77 | "Roleplay as a character. Write a short biographical answer to the question.\nQ: %s\nA:",
78 | )
79 | ctrl := gomock.NewController(b)
80 | embedder := mock_gollum.NewMockEmbedder(ctrl)
81 | completer := mock_gollum.NewMockChatCompleter(ctrl)
82 | generator := hyde.NewLLMGenerator(completer)
83 | encoder := hyde.NewLLMEncoder(embedder)
84 | vs := vectorstore.NewMemoryVectorStore(embedder)
85 | searcher := hyde.NewVectorSearcher(
86 | vs,
87 | )
88 | hyde := hyde.NewHyde(prompter, generator, encoder, searcher)
89 | docNums := []int{10, 100, 1000, 10000, 100_000, 1_000_000}
90 | for _, docNum := range docNums {
91 | b.Run(fmt.Sprintf("docs=%v", docNum), func(b *testing.B) {
92 | 			for range make([]int, docNum) {
93 | embedder.EXPECT().CreateEmbeddings(gomock.Any(), gomock.Any()).Return(testutil.GetRandomEmbeddingResponse(1, 1536), nil)
94 | vs.Insert(context.Background(), gollum.NewDocumentFromString("hey"))
95 | }
96 | b.ResetTimer()
97 | for i := 0; i < b.N; i++ {
98 | k := 8
99 | completer.EXPECT().CreateChatCompletion(context.Background(), gomock.Any()).Return(testutil.GetRandomChatCompletionResponse(k), nil)
100 | embedder.EXPECT().CreateEmbeddings(context.Background(), gomock.Any()).Return(testutil.GetRandomEmbeddingResponse(k+1, 1536), nil)
101 | _, err := hyde.SearchEndToEnd(context.Background(), "What is your name?", k)
102 | assert.NoError(b, err)
103 | }
104 | })
105 | }
106 | }
107 |
--------------------------------------------------------------------------------
/packages/llm/providers/anthropic/anthropic.go:
--------------------------------------------------------------------------------
1 | package anthropic
2 |
3 | import (
4 | "context"
5 | "encoding/base64"
6 | "log/slog"
7 | "slices"
8 |
9 | "github.com/stillmatic/gollum/packages/llm"
10 |
11 | "github.com/liushuangls/go-anthropic/v2"
12 | "github.com/pkg/errors"
13 | )
14 |
15 | type Provider struct {
16 | client *anthropic.Client
17 | cacheEnabled bool
18 | }
19 |
20 | func NewAnthropicProvider(apiKey string) *Provider {
21 | return &Provider{
22 | client: anthropic.NewClient(apiKey),
23 | }
24 | }
25 |
26 | func NewAnthropicProviderWithCache(apiKey string) *Provider {
27 | client := anthropic.NewClient(apiKey, anthropic.WithBetaVersion(anthropic.BetaPromptCaching20240731))
28 | return &Provider{
29 | client: client,
30 | cacheEnabled: true,
31 | }
32 | }
33 |
34 | func reqToMessages(req llm.InferRequest) ([]anthropic.Message, []anthropic.MessageSystemPart, error) {
35 | msgs := make([]anthropic.Message, 0)
36 | systemMsgs := make([]anthropic.MessageSystemPart, 0)
37 | for _, m := range req.Messages {
38 | if m.Role == "system" {
39 | msgContent := anthropic.MessageSystemPart{
40 | Type: "text",
41 | Text: m.Content,
42 | }
43 | if m.ShouldCache {
44 | msgContent.CacheControl = &anthropic.MessageCacheControl{
45 | Type: anthropic.CacheControlTypeEphemeral,
46 | }
47 | }
48 | systemMsgs = append(systemMsgs, msgContent)
49 | continue
50 | }
51 |
52 | // only allow user and assistant roles
53 | // TODO: this should be a little cleaner...
54 | 		if !slices.Contains([]string{string(anthropic.RoleUser), string(anthropic.RoleAssistant)}, m.Role) {
55 | return nil, nil, errors.New("invalid role")
56 | }
57 | content := make([]anthropic.MessageContent, 0)
58 | txtContent := anthropic.NewTextMessageContent(m.Content)
59 | // this will fail if the model is not configured to cache
60 | if m.ShouldCache {
61 | txtContent.SetCacheControl()
62 | }
63 | content = append(content, txtContent)
64 | if m.Image != nil && len(m.Image) > 0 {
65 | b64Image := base64.StdEncoding.EncodeToString(m.Image)
66 | // TODO: support other image types
67 | content = append(content, anthropic.NewImageMessageContent(
68 | anthropic.MessageContentSource{Type: "base64", MediaType: "image/png", Data: b64Image}))
69 | }
70 | newMsg := anthropic.Message{
71 | Role: anthropic.ChatRole(m.Role),
72 | Content: content,
73 | }
74 |
75 | msgs = append(msgs, newMsg)
76 | }
77 |
78 | return msgs, systemMsgs, nil
79 | }
80 |
81 | func (p *Provider) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) {
82 | msgs, systemPrompt, err := reqToMessages(req)
83 | if err != nil {
84 | return "", errors.Wrap(err, "invalid messages")
85 | }
86 | msgsReq := anthropic.MessagesRequest{
87 | Model: anthropic.Model(req.ModelConfig.ModelName),
88 | Messages: msgs,
89 | MaxTokens: req.MessageOptions.MaxTokens,
90 | Temperature: &req.MessageOptions.Temperature,
91 | }
92 | if systemPrompt != nil && len(systemPrompt) > 0 {
93 | msgsReq.MultiSystem = systemPrompt
94 | }
95 | res, err := p.client.CreateMessagesStream(ctx, anthropic.MessagesStreamRequest{
96 | MessagesRequest: msgsReq,
97 | })
98 | if err != nil {
99 | return "", errors.Wrap(err, "anthropic messages stream error")
100 | }
101 |
102 | return res.GetFirstContentText(), nil
103 | }
104 |
105 | func (p *Provider) GenerateResponseAsync(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) {
106 | outChan := make(chan llm.StreamDelta)
107 | go func() {
108 | defer close(outChan)
109 | msgs, systemPrompt, err := reqToMessages(req)
110 | if err != nil {
111 | slog.Error("invalid messages", "err", err)
112 | return
113 | }
114 | msgsReq := anthropic.MessagesRequest{
115 | Model: anthropic.Model(req.ModelConfig.ModelName),
116 | Messages: msgs,
117 | MaxTokens: req.MessageOptions.MaxTokens,
118 | Temperature: &req.MessageOptions.Temperature,
119 | }
120 | if systemPrompt != nil {
121 | msgsReq.MultiSystem = systemPrompt
122 | }
123 |
124 | _, err = p.client.CreateMessagesStream(ctx, anthropic.MessagesStreamRequest{
125 | MessagesRequest: msgsReq,
126 | OnContentBlockDelta: func(data anthropic.MessagesEventContentBlockDeltaData) {
127 | if data.Delta.Text == nil {
128 | outChan <- llm.StreamDelta{
129 | EOF: true,
130 | }
131 | return
132 | }
133 |
134 | outChan <- llm.StreamDelta{
135 | Text: *data.Delta.Text,
136 | }
137 | },
138 | })
139 | if err != nil {
140 | slog.Error("anthropic messages stream error", "err", err)
141 | return
142 | }
143 | }()
144 |
145 | return outChan, nil
146 | }
147 |
--------------------------------------------------------------------------------
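A sketch of a single blocking completion through this provider. The model name, and the assumption that the options struct is exported as `llm.MessageOptions`, are mine, not the source's:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/stillmatic/gollum/packages/llm"
	"github.com/stillmatic/gollum/packages/llm/providers/anthropic"
)

func main() {
	p := anthropic.NewAnthropicProvider(os.Getenv("ANTHROPIC_API_KEY"))
	out, err := p.GenerateResponse(context.Background(), llm.InferRequest{
		Messages: []llm.InferMessage{
			{Role: "system", Content: "Answer in one short sentence."},
			{Role: "user", Content: "What is the capital of France?"},
		},
		ModelConfig:    llm.ModelConfig{ModelName: "claude-3-5-sonnet-latest"}, // assumed model name
		MessageOptions: llm.MessageOptions{MaxTokens: 64, Temperature: 0.0},    // assumed struct name
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(out)
}
```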
/packages/dispatch/dispatch_test.go:
--------------------------------------------------------------------------------
1 | package dispatch_test
2 |
3 | import (
4 | "context"
5 | "os"
6 | "testing"
7 | "text/template"
8 |
9 | "github.com/sashabaranov/go-openai"
10 | mock_gollum "github.com/stillmatic/gollum/internal/mocks"
11 | "github.com/stillmatic/gollum/packages/dispatch"
12 | "github.com/stretchr/testify/assert"
13 | "go.uber.org/mock/gomock"
14 | )
15 |
16 | type testInput struct {
17 | Topic string `json:"topic" jsonschema:"required" jsonschema_description:"The topic of the conversation"`
18 | RandomWords []string `json:"random_words" jsonschema:"required" jsonschema_description:"Random words to prime the conversation"`
19 | }
20 |
21 | type templateInput struct {
22 | Topic string
23 | }
24 |
25 | type wordCountOutput struct {
26 | Count int `json:"count" jsonschema:"required" jsonschema_description:"The number of words in the sentence"`
27 | }
28 |
29 | func TestDummyDispatcher(t *testing.T) {
30 | d := dispatch.NewDummyDispatcher[testInput]()
31 |
32 | t.Run("prompt", func(t *testing.T) {
33 | output, err := d.Prompt(context.Background(), "Talk to me about Dinosaurs")
34 |
35 | assert.NoError(t, err)
36 | assert.Equal(t, testInput{}, output)
37 | })
38 | t.Run("promptTemplate", func(t *testing.T) {
39 | te, err := template.New("").Parse("Talk to me about {{.Topic}}")
40 | assert.NoError(t, err)
41 | tempInp := templateInput{
42 | Topic: "Dinosaurs",
43 | }
44 |
45 | output, err := d.PromptTemplate(context.Background(), te, tempInp)
46 | assert.NoError(t, err)
47 | assert.Equal(t, testInput{}, output)
48 | })
49 | }
50 |
51 | func TestOpenAIDispatcher(t *testing.T) {
52 | ctrl := gomock.NewController(t)
53 | completer := mock_gollum.NewMockChatCompleter(ctrl)
54 | systemPrompt := "When prompted, use the tool."
55 | d := dispatch.NewOpenAIDispatcher[testInput]("random_conversation", "Given a topic, return random words", systemPrompt, completer, nil)
56 |
57 | ctx := context.Background()
58 | expected := testInput{
59 | Topic: "dinosaurs",
60 | RandomWords: []string{"dinosaur", "fossil", "extinct"},
61 | }
62 | inpStr := `{"topic": "dinosaurs", "random_words": ["dinosaur", "fossil", "extinct"]}`
63 |
64 | fi := openai.FunctionDefinition(dispatch.StructToJsonSchema("random_conversation", "Given a topic, return random words", testInput{}))
65 | ti := openai.Tool{Type: "function", Function: &fi}
66 | expectedRequest := openai.ChatCompletionRequest{
67 | Model: openai.GPT3Dot5Turbo1106,
68 | Messages: []openai.ChatCompletionMessage{
69 | {
70 | Role: openai.ChatMessageRoleUser,
71 | Content: "Tell me about dinosaurs",
72 | },
73 | },
74 | Tools: []openai.Tool{ti},
75 | ToolChoice: openai.ToolChoice{
76 | Type: "function",
77 | Function: openai.ToolFunction{
78 | Name: "random_conversation",
79 | }},
80 | MaxTokens: 512,
81 | Temperature: 0.0,
82 | }
83 |
84 | t.Run("prompt", func(t *testing.T) {
85 | queryStr := "Tell me about dinosaurs"
86 | completer.EXPECT().CreateChatCompletion(gomock.Any(), expectedRequest).Return(openai.ChatCompletionResponse{
87 | Choices: []openai.ChatCompletionChoice{
88 | {
89 | Message: openai.ChatCompletionMessage{
90 | Role: openai.ChatMessageRoleSystem,
91 | Content: "Hello there!",
92 | ToolCalls: []openai.ToolCall{
93 | {
94 | Type: "function",
95 | Function: openai.FunctionCall{
96 | Name: "random_conversation",
97 | Arguments: inpStr,
98 | },
99 | }},
100 | },
101 | },
102 | },
103 | }, nil)
104 |
105 | output, err := d.Prompt(ctx, queryStr)
106 | assert.NoError(t, err)
107 |
108 | assert.Equal(t, expected, output)
109 | })
110 | t.Run("promptTemplate", func(t *testing.T) {
111 | completer.EXPECT().CreateChatCompletion(gomock.Any(), expectedRequest).Return(openai.ChatCompletionResponse{
112 | Choices: []openai.ChatCompletionChoice{
113 | {
114 | Message: openai.ChatCompletionMessage{
115 | Role: openai.ChatMessageRoleSystem,
116 | Content: "Hello there!",
117 | ToolCalls: []openai.ToolCall{
118 | {
119 | Type: "function",
120 | Function: openai.FunctionCall{
121 | Name: "random_conversation",
122 | Arguments: inpStr,
123 | },
124 | }},
125 | },
126 | },
127 | },
128 | }, nil)
129 |
130 | te, err := template.New("").Parse("Tell me about {{.Topic}}")
131 | assert.NoError(t, err)
132 |
133 | output, err := d.PromptTemplate(ctx, te, templateInput{
134 | Topic: "dinosaurs",
135 | })
136 | assert.NoError(t, err)
137 | assert.Equal(t, expected, output)
138 | })
139 | }
140 |
141 | func TestDispatchIntegration(t *testing.T) {
142 | t.Skip("Skipping integration test")
143 | completer := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
144 | systemPrompt := "When prompted, use the tool on the user's input."
145 | d := dispatch.NewOpenAIDispatcher[wordCountOutput]("wordCounter", "count the number of words in a sentence", systemPrompt, completer, nil)
146 | output, err := d.Prompt(context.Background(), "I like dinosaurs")
147 | assert.NoError(t, err)
148 | assert.Equal(t, 3, output.Count)
149 | }
150 |
--------------------------------------------------------------------------------
/packages/llm/providers/openai/openai.go:
--------------------------------------------------------------------------------
1 | package openai
2 |
3 | import (
4 | "context"
5 | "encoding/base64"
6 | "github.com/stillmatic/gollum/packages/llm"
7 | "io"
8 | "log/slog"
9 |
10 | "github.com/pkg/errors"
11 | "github.com/sashabaranov/go-openai"
12 | )
13 |
14 | type Provider struct {
15 | client *openai.Client
16 | }
17 |
18 | func NewOpenAIProvider(apiKey string) *Provider {
19 | return &Provider{
20 | client: openai.NewClient(apiKey),
21 | }
22 | }
23 |
24 | func NewGenericProvider(apiKey string, baseURL string) *Provider {
25 | genericConfig := openai.DefaultConfig(apiKey)
26 | genericConfig.BaseURL = baseURL
27 | return &Provider{
28 | client: openai.NewClientWithConfig(genericConfig),
29 | }
30 | }
31 |
32 | func NewTogetherProvider(apiKey string) *Provider {
33 | return NewGenericProvider(apiKey, "https://api.together.xyz/v1")
34 | }
35 |
36 | func NewGroqProvider(apiKey string) *Provider {
37 | return NewGenericProvider(apiKey, "https://api.groq.com/openai/v1/")
38 | }
39 |
40 | func NewHyperbolicProvider(apiKey string) *Provider {
41 | return NewGenericProvider(apiKey, "https://api.hyperbolic.xyz/v1")
42 | }
43 |
44 | func NewDeepseekProvider(apiKey string) *Provider {
45 | return NewGenericProvider(apiKey, "https://api.deepseek.com/v1")
46 | }
47 |
48 | func (p *Provider) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) {
49 | msgs := inferReqToOpenAIMessages(req.Messages)
50 |
51 | oaiReq := openai.ChatCompletionRequest{
52 | Model: req.ModelConfig.ModelName,
53 | Messages: msgs,
54 | MaxTokens: req.MessageOptions.MaxTokens,
55 | Temperature: req.MessageOptions.Temperature,
56 | }
57 |
58 | res, err := p.client.CreateChatCompletion(ctx, oaiReq)
59 | if err != nil {
60 | slog.Error("error from openai", "err", err, "req", req.Messages, "model", req.ModelConfig.ModelName)
61 | return "", errors.Wrap(err, "openai chat completion error")
62 | }
63 |
64 | return res.Choices[0].Message.Content, nil
65 | }
66 |
67 | func inferReqToOpenAIMessages(req []llm.InferMessage) []openai.ChatCompletionMessage {
68 | msgs := make([]openai.ChatCompletionMessage, 0)
69 |
70 | for _, m := range req {
71 | msg := openai.ChatCompletionMessage{
72 | Role: m.Role,
73 | Content: m.Content,
74 | }
75 | if m.Image != nil && len(m.Image) > 0 {
76 | b64Image := base64.StdEncoding.EncodeToString(m.Image)
77 | msg.MultiContent = []openai.ChatMessagePart{
78 | {
79 | Type: openai.ChatMessagePartTypeImageURL,
80 | // TODO: support other image types
81 | ImageURL: &openai.ChatMessageImageURL{
82 | URL: "data:image/png;base64," + b64Image,
83 | Detail: openai.ImageURLDetailAuto,
84 | },
85 | },
86 | {
87 | Type: openai.ChatMessagePartTypeText,
88 | Text: m.Content,
89 | },
90 | }
91 | msg.Content = ""
92 | }
93 | msgs = append(msgs, msg)
94 | }
95 | return msgs
96 | }
97 |
98 | func (p *Provider) GenerateResponseAsync(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) {
99 | outChan := make(chan llm.StreamDelta)
100 | go func() {
101 | defer close(outChan)
102 | msgs := inferReqToOpenAIMessages(req.Messages)
103 | oaiReq := openai.ChatCompletionRequest{
104 | Model: req.ModelConfig.ModelName,
105 | Messages: msgs,
106 | MaxTokens: req.MessageOptions.MaxTokens,
107 | Temperature: req.MessageOptions.Temperature,
108 | }
109 |
110 | stream, err := p.client.CreateChatCompletionStream(ctx, oaiReq)
111 | if err != nil {
112 | slog.Error("error from openai", "err", err, "req", req.Messages, "model", req.ModelConfig.ModelName)
113 | return
114 | }
115 | defer stream.Close()
116 |
117 | 		// read deltas until the stream is exhausted
118 | 		for {
119 | 			response, err := stream.Recv()
120 | 			if err != nil {
121 | 				if err == io.EOF {
122 | 					outChan <- llm.StreamDelta{EOF: true}
123 | 					return
124 | 				}
125 | 				slog.Error("error receiving from openai stream", "err", err)
126 | 				return
127 | 			}
128 | 			if len(response.Choices) == 0 {
129 | 				continue
130 | 			}
131 | 			content := response.Choices[0].Delta.Content
132 | 			if content == "" {
133 | 				continue
134 | 			}
135 | 			select {
136 | 			case <-ctx.Done():
137 | 				return
138 | 			case outChan <- llm.StreamDelta{Text: content}:
139 | 			}
140 | 		}
141 | }()
142 |
143 | return outChan, nil
144 | }
145 |
146 | func (p *Provider) GenerateEmbedding(ctx context.Context, req llm.EmbedRequest) (*llm.EmbeddingResponse, error) {
147 | // TODO: this only supports openai models, not other providers using the same interface
148 | oaiReq := openai.EmbeddingRequest{
149 | Input: req.Input,
150 | Model: openai.EmbeddingModel(req.ModelConfig.ModelName),
151 | Dimensions: req.Dimensions,
152 | }
153 |
154 | res, err := p.client.CreateEmbeddings(ctx, oaiReq)
155 | if err != nil {
156 | slog.Error("error from openai", "err", err, "req", req.Input, "model", req.ModelConfig.ModelName)
157 | return nil, errors.Wrap(err, "openai embedding error")
158 | }
159 |
160 | respVectors := make([]llm.Embedding, len(res.Data))
161 | for i, v := range res.Data {
162 | respVectors[i] = llm.Embedding{
163 | Values: v.Embedding,
164 | }
165 | }
166 |
167 | return &llm.EmbeddingResponse{
168 | Data: respVectors,
169 | }, nil
170 | }
171 |
--------------------------------------------------------------------------------
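A sketch of consuming the streaming API; the model name and the `llm.MessageOptions` struct name are assumptions:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/stillmatic/gollum/packages/llm"
	openaiprovider "github.com/stillmatic/gollum/packages/llm/providers/openai"
)

func main() {
	p := openaiprovider.NewOpenAIProvider(os.Getenv("OPENAI_API_KEY"))
	deltas, err := p.GenerateResponseAsync(context.Background(), llm.InferRequest{
		Messages:       []llm.InferMessage{{Role: "user", Content: "Tell me a short joke."}},
		ModelConfig:    llm.ModelConfig{ModelName: "gpt-4o-mini"}, // assumed model name
		MessageOptions: llm.MessageOptions{MaxTokens: 128, Temperature: 0.7},
	})
	if err != nil {
		panic(err)
	}
	// read until the provider signals end of stream
	for delta := range deltas {
		if delta.EOF {
			break
		}
		fmt.Print(delta.Text)
	}
	fmt.Println()
}
```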
/packages/vectorstore/vectorstore.md:
--------------------------------------------------------------------------------
1 | # Document Stores
2 |
3 | Gollum provides several implementations of document stores. Document stores solve a common problem: we have lots of documents and want to find the ones most relevant to a particular query.
4 |
5 | Names are TBD.
6 |
7 | # docstore
8 |
9 | Docstore is a simple document store which provides no indexing. It only supports insert and retrieve; you must know the ID.
10 |
11 | Think of it as essentially a key-value store. In the future, we will probably extend this to have a KV interface and provide an implementation backed by Redis or DragonflyDB.
12 |
13 | # memory vector store
14 |
15 | This is a simple document store that takes an embedding model and embeds documents on insert. At retrieval time, it embeds the search query and does a simple KNN lookup.
16 |
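17 | A minimal usage sketch, mirroring the tests - here `embedder` is anything satisfying `gollum.Embedder` (e.g. an OpenAI client), and error handling is elided:
18 | 
19 | ```go
20 | ctx := context.Background()
21 | mvs := vectorstore.NewMemoryVectorStore(embedder)
22 | _ = mvs.Insert(ctx, gollum.NewDocumentFromString("Apple"))
23 | _ = mvs.Insert(ctx, gollum.NewDocumentFromString("Basketball"))
24 | // embeds "favorite fruit?" and returns the single nearest document
25 | docs, _ := mvs.Query(ctx, vectorstore.QueryRequest{Query: "favorite fruit?", K: 1})
26 | ```
27 | 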
17 | # xyz vector store
18 |
19 | I haven't gotten around to actually writing any of these implementations but it should be simple to imagine clients for Weaviate or Pinecone following the interface. I don't actually use them though :)
20 |
21 |
22 |
23 | # compressed store
24 |
25 | This is inspired by [link] - basically we can use gzip to compress the documents, then at query time compare `enc(term + doc)` against `enc(term)` and `enc(doc)` - in this case, `enc(doc)` is computed once on insert. The idea is that the more similar your term and document are, the shorter the joint encoding is, because less new entropy needs to be included in the compressed representation. A sketch of the scoring follows below.
26 |
27 |
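28 | Concretely, the score is the normalized compression distance (NCD) that `Query` computes in `vectorstore_compressed.go` - a rough sketch, where `compress` stands in for the store's `Compressor` and `query` is the raw query bytes:
29 | 
30 | ```go
31 | cx1 := float64(len(compress(query)))        // compressed query length
32 | cx2 := float64(len(doc.Encoded))            // compressed doc length, cached at insert
33 | // the store actually joins query and doc with a space before compressing
34 | cx1x2 := float64(len(compress(append(query, doc.Unencoded...))))
35 | min, max := minMax(cx1, cx2)
36 | ncd := (cx1x2 - min) / max // smaller = more similar
37 | ```
38 | 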
28 | On an M1 Max
29 |
30 | ```
31 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-10-1-10 25092 49245 ns/op 288 B/op 3 allocs/op
32 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-10-10-10 13317 107361 ns/op 2880 B/op 6 allocs/op
33 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-100-1-10 2582 657344 ns/op 300 B/op 3 allocs/op
34 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-100-10-10 1051 1286167 ns/op 2880 B/op 6 allocs/op
35 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-100-100-10 688 1879785 ns/op 25887 B/op 9 allocs/op
36 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-1000-1-10 202 7782683 ns/op 569 B/op 3 allocs/op
37 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-1000-10-10 92 11442531 ns/op 3461 B/op 6 allocs/op
38 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-1000-100-10 73 15614635 ns/op 25765 B/op 9 allocs/op
39 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-10000-1-10 32 52570010 ns/op 288 B/op 3 allocs/op
40 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-10000-10-10 15 95620514 ns/op 3161 B/op 6 allocs/op
41 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-10000-100-10 9 138641384 ns/op 28024 B/op 10 allocs/op
42 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-100000-1-10 2 523147917 ns/op 6624 B/op 5 allocs/op
43 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-100000-10-10 2 985437625 ns/op 9280 B/op 10 allocs/op
44 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-100000-100-10 1 1123811333 ns/op 33600 B/op 14 allocs/op
45 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-10-1-10 18867 107243 ns/op 288 B/op 3 allocs/op
46 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-10-10-10 8036 202150 ns/op 2880 B/op 6 allocs/op
47 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-100-1-10 1106 1541100 ns/op 288 B/op 3 allocs/op
48 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-100-10-10 511 2303053 ns/op 2880 B/op 6 allocs/op
49 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-100-100-10 327 3784640 ns/op 25664 B/op 9 allocs/op
50 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-1000-1-10 100 12470110 ns/op 288 B/op 3 allocs/op
51 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-1000-10-10 70 21186270 ns/op 2880 B/op 6 allocs/op
52 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-1000-100-10 58 26575966 ns/op 25664 B/op 9 allocs/op
53 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-10000-1-10 16 102998362 ns/op 288 B/op 3 allocs/op
54 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-10000-10-10 7 200548024 ns/op 2880 B/op 6 allocs/op
55 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-10000-100-10 4 290161719 ns/op 25664 B/op 9 allocs/op
56 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-100000-1-10 2 1115429562 ns/op 288 B/op 3 allocs/op
57 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-100000-10-10 1 1561993375 ns/op 2880 B/op 6 allocs/op
58 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-100000-100-10 1 1751069333 ns/op 25664 B/op 9 allocs/op
59 | PASS
60 | ```
61 |
62 | I have done a fair amount of allocation chasing and I think there is a bit more work to do, but honestly, it's still pretty slow.
63 |
64 | I also think adding [Lempel-Ziv Jaccard Distance](https://arxiv.org/pdf/1708.03346.pdf) is quite promising. We would need to write it in Go or add Cgo bindings to the C version.
--------------------------------------------------------------------------------
/packages/vectorstore/vectorstore_compressed.go:
--------------------------------------------------------------------------------
1 | package vectorstore
2 |
3 | import (
4 | "bytes"
5 | stdgzip "compress/gzip"
6 | "context"
7 | "github.com/stillmatic/gollum/packages/syncpool"
8 | "io"
9 | "sync"
10 |
11 | gzip "github.com/klauspost/compress/gzip"
12 | "github.com/klauspost/compress/zstd"
13 | "github.com/stillmatic/gollum"
14 | )
15 |
16 | // Compressor is a single method interface that returns a compressed representation of an object.
17 | type Compressor interface {
18 | Compress(src []byte) []byte
19 | }
20 |
21 | // GzipCompressor uses the klauspost/compress gzip compressor.
22 | // We generally suggest using this optimized implementation over the stdlib.
23 | type GzipCompressor struct {
24 | pool syncpool.Pool[*gzip.Writer]
25 | }
26 |
27 | // ZstdCompressor uses the klauspost/compress zstd compressor.
28 | type ZstdCompressor struct {
29 | pool syncpool.Pool[*zstd.Encoder]
30 | enc *zstd.Encoder
31 | }
32 |
33 | // StdGzipCompressor uses the std gzip compressor.
34 | type StdGzipCompressor struct {
35 | pool syncpool.Pool[*stdgzip.Writer]
36 | }
37 |
38 | type DummyCompressor struct {
39 | }
40 |
41 | func (g *GzipCompressor) Compress(src []byte) []byte {
42 | gz := g.pool.Get()
43 | var b bytes.Buffer
44 | defer g.pool.Put(gz)
45 |
46 | gz.Reset(&b)
47 | if _, err := gz.Write(src); err != nil {
48 | panic(err)
49 | }
50 | if err := gz.Close(); err != nil {
51 | panic(err)
52 | }
53 | return b.Bytes()
54 | }
55 |
56 | func (g *ZstdCompressor) Compress(src []byte) []byte {
57 | // return g.enc.EncodeAll(src, make([]byte, 0, len(src)))
58 | var b bytes.Buffer
59 | b.Reset()
60 | enc := g.enc
61 | // zstd := g.pool.Get()
62 | enc.Reset(&b)
63 | if _, err := enc.Write(src); err != nil {
64 | panic(err)
65 | }
66 | if err := enc.Flush(); err != nil {
67 | panic(err)
68 | }
69 |
70 | // g.pool.Put(zstd)
71 | return b.Bytes()
72 | }
73 |
74 | func (g *DummyCompressor) Compress(src []byte) []byte {
75 | return src
76 | }
77 |
78 | func (g *StdGzipCompressor) Compress(src []byte) []byte {
79 | var b bytes.Buffer
80 | gz := g.pool.Get()
81 | defer g.pool.Put(gz)
82 | gz.Reset(&b)
83 |
84 | if _, err := gz.Write(src); err != nil {
85 | panic(err)
86 | }
87 | if err := gz.Close(); err != nil {
88 | panic(err)
89 | }
90 | return b.Bytes()
91 | }
92 |
93 | type CompressedDocument struct {
94 | *gollum.Document
95 | Encoded []byte
96 | Unencoded []byte
97 | }
98 |
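99 | // CompressedVectorStore ranks documents by compression distance rather than
100 | // by embeddings; see vectorstore.md for the idea and benchmarks.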
99 | type CompressedVectorStore struct {
100 | Data []CompressedDocument
101 | Compressor Compressor
102 | }
103 |
104 | // Insert compresses the document and inserts it into the store.
105 | // An alternative implementation would ONLY store the compressed representation and decompress as necessary.
106 | func (ts *CompressedVectorStore) Insert(ctx context.Context, d gollum.Document) error {
107 | 	// Copy the content into its own slice rather than keeping bb.Bytes() from a
108 | 	// pooled buffer: that slice aliases the pool's backing array, which later
109 | 	// inserts overwrite, silently corrupting the stored Unencoded bytes.
110 | 	docBytes := []byte(d.Content)
111 | 	encoded := ts.Compressor.Compress(docBytes)
112 | 	ts.Data = append(ts.Data, CompressedDocument{Document: &d, Encoded: encoded, Unencoded: docBytes})
113 | 	return nil
114 | }
116 |
117 | func minMax(val1, val2 float64) (float64, float64) {
118 | if val1 < val2 {
119 | return val1, val2
120 | }
121 | return val2, val1
122 | }
123 |
124 | var bufPool = sync.Pool{
125 | New: func() any {
126 | return new(bytes.Buffer)
127 | },
128 | }
129 |
130 | var spaceBytes = []byte(" ")
131 |
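132 | // Query ranks every stored document by normalized compression distance (NCD)
133 | // against the query and returns the k most similar documents.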
132 | func (cvs *CompressedVectorStore) Query(ctx context.Context, qb QueryRequest) ([]*gollum.Document, error) {
133 | bb := bufPool.Get().(*bytes.Buffer)
134 | defer bufPool.Put(bb)
135 | bb.Reset()
136 | queryBytes := make([]byte, len(qb.Query))
137 | copy(queryBytes, qb.Query)
138 | searchTermEncoded := cvs.Compressor.Compress(queryBytes)
139 |
140 | k := qb.K
141 | if k > len(cvs.Data) {
142 | k = len(cvs.Data)
143 | }
144 | h := make(Heap, 0, k+1)
145 | h.Init(k + 1)
146 |
147 | for _, doc := range cvs.Data {
148 | Cx1 := float64(len(searchTermEncoded))
149 | Cx2 := float64(len(doc.Encoded))
150 | bb.Write(queryBytes)
151 | bb.Write(spaceBytes)
152 | bb.Write(doc.Unencoded)
153 | x1x2 := cvs.Compressor.Compress(bb.Bytes())
154 | Cx1x2 := float64(len(x1x2))
155 | min, max := minMax(Cx1, Cx2)
156 | ncd := (Cx1x2 - min) / (max)
157 | // ncd := 0.5
158 |
159 | node := NodeSimilarity{
160 | Document: doc.Document,
161 | Similarity: float32(ncd),
162 | }
163 |
164 | h.Push(node)
165 | if h.Len() > k {
166 | h.Pop()
167 | }
168 | bb.Reset()
169 | }
170 |
171 | docs := make([]*gollum.Document, k)
172 | for i := range docs {
173 | docs[k-i-1] = h.Pop().Document
174 | }
175 |
176 | return docs, nil
177 | }
178 |
179 | func (cvs *CompressedVectorStore) RetrieveAll(ctx context.Context) ([]gollum.Document, error) {
180 | docs := make([]gollum.Document, len(cvs.Data))
181 | for i, doc := range cvs.Data {
182 | docs[i] = *doc.Document
183 | }
184 | return docs, nil
185 | }
186 |
187 | func NewStdGzipVectorStore() *CompressedVectorStore {
188 | w := io.Discard
189 | return &CompressedVectorStore{
190 | Compressor: &StdGzipCompressor{
191 | pool: syncpool.New[*stdgzip.Writer](func() *stdgzip.Writer {
192 | return stdgzip.NewWriter(w)
193 | }),
194 | },
195 | }
196 | }
197 |
198 | func NewGzipVectorStore() *CompressedVectorStore {
199 | w := io.Discard
200 | return &CompressedVectorStore{
201 | Compressor: &GzipCompressor{
202 | pool: syncpool.New[*gzip.Writer](func() *gzip.Writer {
203 | return gzip.NewWriter(w)
204 | }),
205 | },
206 | }
207 | }
208 |
209 | func NewZstdVectorStore() *CompressedVectorStore {
210 | w := io.Discard
211 | enc, err := zstd.NewWriter(w, zstd.WithEncoderCRC(false))
212 | if err != nil {
213 | panic(err)
214 | }
215 | return &CompressedVectorStore{
216 | Compressor: &ZstdCompressor{
217 | enc: enc,
218 | pool: syncpool.New[*zstd.Encoder](func() *zstd.Encoder {
219 | enc, err := zstd.NewWriter(w, zstd.WithEncoderCRC(false))
220 | if err != nil {
221 | panic(err)
222 | }
223 | return enc
224 | }),
225 | },
226 | }
227 | }
228 |
229 | func NewDummyVectorStore() *CompressedVectorStore {
230 | return &CompressedVectorStore{
231 | Compressor: &DummyCompressor{},
232 | }
233 | }
234 |
--------------------------------------------------------------------------------
/internal/mocks/llm.go:
--------------------------------------------------------------------------------
1 | // Code generated by MockGen. DO NOT EDIT.
2 | // Source: llm.go
3 |
4 | // Package mock_gollum is a generated GoMock package.
5 | package mock_gollum
6 |
7 | import (
8 | context "context"
9 | reflect "reflect"
10 |
11 | openai "github.com/sashabaranov/go-openai"
12 | gomock "go.uber.org/mock/gomock"
13 | )
14 |
15 | // MockCompleter is a mock of Completer interface.
16 | type MockCompleter struct {
17 | ctrl *gomock.Controller
18 | recorder *MockCompleterMockRecorder
19 | }
20 |
21 | // MockCompleterMockRecorder is the mock recorder for MockCompleter.
22 | type MockCompleterMockRecorder struct {
23 | mock *MockCompleter
24 | }
25 |
26 | // NewMockCompleter creates a new mock instance.
27 | func NewMockCompleter(ctrl *gomock.Controller) *MockCompleter {
28 | mock := &MockCompleter{ctrl: ctrl}
29 | mock.recorder = &MockCompleterMockRecorder{mock}
30 | return mock
31 | }
32 |
33 | // EXPECT returns an object that allows the caller to indicate expected use.
34 | func (m *MockCompleter) EXPECT() *MockCompleterMockRecorder {
35 | return m.recorder
36 | }
37 |
38 | // CreateCompletion mocks base method.
39 | func (m *MockCompleter) CreateCompletion(arg0 context.Context, arg1 openai.CompletionRequest) (openai.CompletionResponse, error) {
40 | m.ctrl.T.Helper()
41 | ret := m.ctrl.Call(m, "CreateCompletion", arg0, arg1)
42 | ret0, _ := ret[0].(openai.CompletionResponse)
43 | ret1, _ := ret[1].(error)
44 | return ret0, ret1
45 | }
46 |
47 | // CreateCompletion indicates an expected call of CreateCompletion.
48 | func (mr *MockCompleterMockRecorder) CreateCompletion(arg0, arg1 interface{}) *gomock.Call {
49 | mr.mock.ctrl.T.Helper()
50 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCompletion", reflect.TypeOf((*MockCompleter)(nil).CreateCompletion), arg0, arg1)
51 | }
52 |
53 | // MockChatCompleter is a mock of ChatCompleter interface.
54 | type MockChatCompleter struct {
55 | ctrl *gomock.Controller
56 | recorder *MockChatCompleterMockRecorder
57 | }
58 |
59 | // MockChatCompleterMockRecorder is the mock recorder for MockChatCompleter.
60 | type MockChatCompleterMockRecorder struct {
61 | mock *MockChatCompleter
62 | }
63 |
64 | // NewMockChatCompleter creates a new mock instance.
65 | func NewMockChatCompleter(ctrl *gomock.Controller) *MockChatCompleter {
66 | mock := &MockChatCompleter{ctrl: ctrl}
67 | mock.recorder = &MockChatCompleterMockRecorder{mock}
68 | return mock
69 | }
70 |
71 | // EXPECT returns an object that allows the caller to indicate expected use.
72 | func (m *MockChatCompleter) EXPECT() *MockChatCompleterMockRecorder {
73 | return m.recorder
74 | }
75 |
76 | // CreateChatCompletion mocks base method.
77 | func (m *MockChatCompleter) CreateChatCompletion(arg0 context.Context, arg1 openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
78 | m.ctrl.T.Helper()
79 | ret := m.ctrl.Call(m, "CreateChatCompletion", arg0, arg1)
80 | ret0, _ := ret[0].(openai.ChatCompletionResponse)
81 | ret1, _ := ret[1].(error)
82 | return ret0, ret1
83 | }
84 |
85 | // CreateChatCompletion indicates an expected call of CreateChatCompletion.
86 | func (mr *MockChatCompleterMockRecorder) CreateChatCompletion(arg0, arg1 interface{}) *gomock.Call {
87 | mr.mock.ctrl.T.Helper()
88 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateChatCompletion", reflect.TypeOf((*MockChatCompleter)(nil).CreateChatCompletion), arg0, arg1)
89 | }
90 |
91 | // MockEmbedder is a mock of Embedder interface.
92 | type MockEmbedder struct {
93 | ctrl *gomock.Controller
94 | recorder *MockEmbedderMockRecorder
95 | }
96 |
97 | // MockEmbedderMockRecorder is the mock recorder for MockEmbedder.
98 | type MockEmbedderMockRecorder struct {
99 | mock *MockEmbedder
100 | }
101 |
102 | // NewMockEmbedder creates a new mock instance.
103 | func NewMockEmbedder(ctrl *gomock.Controller) *MockEmbedder {
104 | mock := &MockEmbedder{ctrl: ctrl}
105 | mock.recorder = &MockEmbedderMockRecorder{mock}
106 | return mock
107 | }
108 |
109 | // EXPECT returns an object that allows the caller to indicate expected use.
110 | func (m *MockEmbedder) EXPECT() *MockEmbedderMockRecorder {
111 | return m.recorder
112 | }
113 |
114 | // CreateEmbeddings mocks base method.
115 | func (m *MockEmbedder) CreateEmbeddings(arg0 context.Context, arg1 openai.EmbeddingRequest) (openai.EmbeddingResponse, error) {
116 | m.ctrl.T.Helper()
117 | ret := m.ctrl.Call(m, "CreateEmbeddings", arg0, arg1)
118 | ret0, _ := ret[0].(openai.EmbeddingResponse)
119 | ret1, _ := ret[1].(error)
120 | return ret0, ret1
121 | }
122 |
123 | // CreateEmbeddings indicates an expected call of CreateEmbeddings.
124 | func (mr *MockEmbedderMockRecorder) CreateEmbeddings(arg0, arg1 interface{}) *gomock.Call {
125 | mr.mock.ctrl.T.Helper()
126 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEmbeddings", reflect.TypeOf((*MockEmbedder)(nil).CreateEmbeddings), arg0, arg1)
127 | }
128 |
129 | // MockModerator is a mock of Moderator interface.
130 | type MockModerator struct {
131 | ctrl *gomock.Controller
132 | recorder *MockModeratorMockRecorder
133 | }
134 |
135 | // MockModeratorMockRecorder is the mock recorder for MockModerator.
136 | type MockModeratorMockRecorder struct {
137 | mock *MockModerator
138 | }
139 |
140 | // NewMockModerator creates a new mock instance.
141 | func NewMockModerator(ctrl *gomock.Controller) *MockModerator {
142 | mock := &MockModerator{ctrl: ctrl}
143 | mock.recorder = &MockModeratorMockRecorder{mock}
144 | return mock
145 | }
146 |
147 | // EXPECT returns an object that allows the caller to indicate expected use.
148 | func (m *MockModerator) EXPECT() *MockModeratorMockRecorder {
149 | return m.recorder
150 | }
151 |
152 | // Moderations mocks base method.
153 | func (m *MockModerator) Moderations(arg0 context.Context, arg1 openai.ModerationRequest) (openai.ModerationResponse, error) {
154 | m.ctrl.T.Helper()
155 | ret := m.ctrl.Call(m, "Moderations", arg0, arg1)
156 | ret0, _ := ret[0].(openai.ModerationResponse)
157 | ret1, _ := ret[1].(error)
158 | return ret0, ret1
159 | }
160 |
161 | // Moderations indicates an expected call of Moderations.
162 | func (mr *MockModeratorMockRecorder) Moderations(arg0, arg1 interface{}) *gomock.Call {
163 | mr.mock.ctrl.T.Helper()
164 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Moderations", reflect.TypeOf((*MockModerator)(nil).Moderations), arg0, arg1)
165 | }
166 |
--------------------------------------------------------------------------------
/math_test.go:
--------------------------------------------------------------------------------
1 | package gollum_test
2 |
3 | import (
4 | "fmt"
5 | "testing"
6 |
7 | math "github.com/chewxy/math32"
8 | "github.com/stillmatic/gollum/internal/testutil"
9 | "github.com/viterin/vek/vek32"
10 | )
11 |
12 | func weaviateCosineSimilarity(a []float32, b []float32) (float32, error) {
13 | if len(a) != len(b) {
14 | return 0, fmt.Errorf("vectors have different dimensions")
15 | }
16 |
17 | var (
18 | sumProduct float32
19 | sumASquare float32
20 | sumBSquare float32
21 | )
22 |
23 | for i := range a {
24 | sumProduct += (a[i] * b[i])
25 | sumASquare += (a[i] * a[i])
26 | sumBSquare += (b[i] * b[i])
27 | }
28 |
29 | return sumProduct / (math.Sqrt(sumASquare) * math.Sqrt(sumBSquare)), nil
30 | }
31 |
32 | func BenchmarkCosSim(b *testing.B) {
33 | ns := []int{256, 512, 768, 1024}
34 | b.Run("weaviate", func(b *testing.B) {
35 | for _, n := range ns {
36 | A := testutil.GetRandomEmbedding(n)
37 | B := testutil.GetRandomEmbedding(n)
38 | b.ResetTimer()
39 | b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
40 | for i := 0; i < b.N; i++ {
41 | f, err := weaviateCosineSimilarity(A, B)
42 | if err != nil {
43 | panic(err)
44 | }
45 | _ = f
46 | }
47 | })
48 | }
49 | })
50 | b.Run("vek", func(b *testing.B) {
51 | for _, n := range ns {
52 | A := testutil.GetRandomEmbedding(n)
53 | B := testutil.GetRandomEmbedding(n)
54 | b.ResetTimer()
55 | b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
56 | for i := 0; i < b.N; i++ {
57 | f := vek32.CosineSimilarity(A, B)
58 | _ = f
59 | }
60 | })
61 | }
62 | })
63 | }
64 |
65 | // func BenchmarkGonum32(b *testing.B) {
66 | // vs := []int{256, 512, 768, 1024}
67 | // for _, n := range vs {
68 | // A := getRandomEmbedding(n)
69 | // B := getRandomEmbedding(n)
70 | // b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
71 | // for i := 0; i < b.N; i++ {
72 | // f, err := gallant.GonumSim(A, B)
73 | // if err != nil {
74 | // panic(err)
75 | // }
76 | // _ = f
77 | // }
78 | // })
79 | // }
80 | // }
81 |
82 | // func BenchmarkGonum64(b *testing.B) {
83 | // vs := []int{256, 512, 768, 1024}
84 | // for _, n := range vs {
85 | // A := getRandomEmbedding(n)
86 | // B := getRandomEmbedding(n)
87 | // a_ := make([]float64, len(A))
88 | // b_ := make([]float64, len(B))
89 | // for i := range A {
90 | // a_[i] = float64(A[i])
91 | // b_[i] = float64(B[i])
92 | // }
93 | // b.Run("Naive", func(b *testing.B) {
94 | // for i := 0; i < b.N; i++ {
95 | // f, err := gallant.GonumSim64(a_, b_)
96 | // if err != nil {
97 | // panic(err)
98 | // }
99 | // _ = f
100 | // }
101 | // })
102 | // am := mat.NewDense(1, len(a_), a_)
103 | // bm := mat.NewDense(1, len(b_), b_)
104 | // b.Run("prealloc", func(b *testing.B) {
105 | // for i := 0; i < b.N; i++ {
106 | // var dot mat.Dense
107 | // dot.Mul(am, bm.T())
108 | // _ = dot.At(0, 0)
109 | // }
110 | // })
111 | // }
112 | // }
113 |
114 | // goos: linux
115 | // goarch: amd64
116 | // cpu: AMD Ryzen 9 7950X 16-Core Processor
117 | // BenchmarkWeaviateCosSim/256-32 8028375 154.8 ns/op 0 B/op 0 allocs/op
118 | // BenchmarkWeaviateCosSim/512-32 3958342 300.1 ns/op 0 B/op 0 allocs/op
119 | // BenchmarkWeaviateCosSim/768-32 2677993 456.1 ns/op 0 B/op 0 allocs/op
120 | // BenchmarkWeaviateCosSim/1024-32 2002258 601.6 ns/op 0 B/op 0 allocs/op
121 | // BenchmarkVek32CosSim/256-32 81166414 15.43 ns/op 0 B/op 0 allocs/op
122 | // BenchmarkVek32CosSim/512-32 46376474 26.72 ns/op 0 B/op 0 allocs/op
123 | // BenchmarkVek32CosSim/768-32 30476739 39.65 ns/op 0 B/op 0 allocs/op
124 | // BenchmarkVek32CosSim/1024-32 22698370 51.63 ns/op 0 B/op 0 allocs/op
125 | // BenchmarkGonum32/256-32 1664224 738.9 ns/op 4312 B/op 7 allocs/op
126 | // BenchmarkGonum32/512-32 914896 1212 ns/op 8408 B/op 7 allocs/op
127 | // BenchmarkGonum32/768-32 718142 1634 ns/op 12504 B/op 7 allocs/op
128 | // BenchmarkGonum32/1024-32 573609 2885 ns/op 16600 B/op 7 allocs/op
129 | // BenchmarkGonum64/Naive-32 6329708 189.8 ns/op 216 B/op 5 allocs/op
130 | // BenchmarkGonum64/prealloc-32 9126764 144.8 ns/op 88 B/op 3 allocs/op
131 | // BenchmarkGonum64/Naive#01-32 5176160 222.0 ns/op 216 B/op 5 allocs/op
132 | // BenchmarkGonum64/prealloc#01-32 6484794 185.2 ns/op 88 B/op 3 allocs/op
133 | // BenchmarkGonum64/Naive#02-32 4569102 266.8 ns/op 216 B/op 5 allocs/op
134 | // BenchmarkGonum64/prealloc#02-32 5566400 225.0 ns/op 88 B/op 3 allocs/op
135 | // BenchmarkGonum64/Naive#03-32 3910498 300.0 ns/op 216 B/op 5 allocs/op
136 | // BenchmarkGonum64/prealloc#03-32 4585336 265.9 ns/op 88 B/op 3 allocs/op
137 |
138 | // goos: darwin
139 | // goarch: arm64
140 | // pkg: github.com/stillmatic/gollum
141 | // BenchmarkCosSim/weaviate/256-10 4330524 274.5 ns/op 0 B/op 0 allocs/op
142 | // BenchmarkCosSim/weaviate/512-10 1995426 605.6 ns/op 0 B/op 0 allocs/op
143 | // BenchmarkCosSim/weaviate/768-10 1312820 917.6 ns/op 0 B/op 0 allocs/op
144 | // BenchmarkCosSim/weaviate/1024-10 973432 1232 ns/op 0 B/op 0 allocs/op
145 | // BenchmarkCosSim/vek/256-10 4335747 272.0 ns/op 0 B/op 0 allocs/op
146 | // BenchmarkCosSim/vek/512-10 2027366 596.2 ns/op 0 B/op 0 allocs/op
147 | // BenchmarkCosSim/vek/768-10 1310983 925.2 ns/op 0 B/op 0 allocs/op
148 | // BenchmarkCosSim/vek/1024-10 969460 1233 ns/op 0 B/op 0 allocs/op
149 | // PASS
150 |
151 | // General high level takeaways:
152 | // - vek32 is best option if SIMD is available
153 | // - gonum64 is fast and scales better than weaviate but requires allocs and using f64
154 | // - gonum64 probably has SIMD intrinsics?
155 | // - weaviate implementation is 'fine', appears to have linear scaling with vector size
156 | // - for something in the hotpath I think it's worth it to use vek32 when possible
157 | // - on mac, vek32 and weaviate are more or less identical
158 |
--------------------------------------------------------------------------------
/packages/hyde/hyde.go:
--------------------------------------------------------------------------------
1 | package hyde
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/stillmatic/gollum/packages/vectorstore"
7 |
8 | "github.com/pkg/errors"
9 | "github.com/sashabaranov/go-openai"
10 | "github.com/stillmatic/gollum"
11 | "github.com/viterin/vek/vek32"
12 | )
13 |
14 | type Prompter interface {
15 | BuildPrompt(context.Context, string) string
16 | }
17 |
18 | type Generator interface {
19 | Generate(ctx context.Context, input string, n int) ([]string, error)
20 | }
21 |
22 | type Encoder interface {
23 | Encode(context.Context, string) ([]float32, error)
24 | EncodeBatch(context.Context, []string) ([][]float32, error)
25 | }
26 |
27 | type Searcher interface {
28 | Search(context.Context, []float32, int) ([]*gollum.Document, error)
29 | }
30 |
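31 | // Hyde implements hypothetical document embeddings (HyDE): an LLM generates
32 | // hypothetical answers to a query, those answers are embedded together with the
33 | // query and mean-pooled, and the pooled vector is used to search the store.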
31 | type Hyde struct {
32 | prompter Prompter
33 | generator Generator
34 | encoder Encoder
35 | searcher Searcher
36 | }
37 |
38 | type ZeroShotPrompter struct {
39 | template string
40 | }
41 |
42 | func NewZeroShotPrompter(template string) *ZeroShotPrompter {
43 | // something like
44 | // Roleplay as a character. Write a short biographical answer to the question.
45 | // Q: %s
46 | // A:
47 | return &ZeroShotPrompter{
48 | template: template,
49 | }
50 | }
51 |
52 | func (z *ZeroShotPrompter) BuildPrompt(ctx context.Context, prompt string) string {
53 | // fill in template values
54 | return fmt.Sprintf(z.template, prompt)
55 | }
56 |
57 | type LLMGenerator struct {
58 | Model gollum.ChatCompleter
59 | }
60 |
61 | func NewLLMGenerator(model gollum.ChatCompleter) *LLMGenerator {
62 | return &LLMGenerator{
63 | Model: model,
64 | }
65 | }
66 |
67 | func (l *LLMGenerator) Generate(ctx context.Context, prompt string, n int) ([]string, error) {
68 | createChatCompletionReq := openai.ChatCompletionRequest{
69 | Messages: []openai.ChatCompletionMessage{
70 | {
71 | Role: openai.ChatMessageRoleUser,
72 | Content: prompt,
73 | },
74 | },
75 | Model: openai.GPT3Dot5Turbo,
76 | N: n,
77 | // hyperparams from https://github.com/texttron/hyde/blob/74101c5157e04f7b57559e7da8ef4a4e5b6da82b/src/hyde/generator.py#LL15C121-L15C121
78 | Temperature: 0.9,
79 | MaxTokens: 512,
80 | TopP: 1,
81 | FrequencyPenalty: 0,
82 | PresencePenalty: 0,
83 | Stop: []string{"\n\n\n"},
84 | }
85 | resp, err := l.Model.CreateChatCompletion(ctx, createChatCompletionReq)
86 | if err != nil {
87 | return make([]string, 0), err
88 | }
89 | replies := make([]string, n)
90 | for i, choice := range resp.Choices {
91 | replies[i] = choice.Message.Content
92 | }
93 | return replies, nil
94 | }
95 |
96 | type LLMEncoder struct {
97 | Model gollum.Embedder
98 | }
99 |
100 | func NewLLMEncoder(model gollum.Embedder) *LLMEncoder {
101 | return &LLMEncoder{
102 | Model: model,
103 | }
104 | }
105 |
106 | func (l *LLMEncoder) Encode(ctx context.Context, query string) ([]float32, error) {
107 | createEmbeddingReq := openai.EmbeddingRequest{
108 | Input: []string{query},
109 | // TODO: allow customization
110 | Model: openai.AdaEmbeddingV2,
111 | }
112 | resp, err := l.Model.CreateEmbeddings(ctx, createEmbeddingReq)
113 | if err != nil {
114 | return make([]float32, 0), err
115 | }
116 | return resp.Data[0].Embedding, nil
117 | }
118 |
119 | func (l *LLMEncoder) EncodeBatch(ctx context.Context, docs []string) ([][]float32, error) {
120 | createEmbeddingReq := openai.EmbeddingRequest{
121 | Input: docs,
122 | // TODO: allow customization
123 | Model: openai.AdaEmbeddingV2,
124 | }
125 | resp, err := l.Model.CreateEmbeddings(ctx, createEmbeddingReq)
126 | if err != nil {
127 | return make([][]float32, 0), err
128 | }
129 | embeddings := make([][]float32, len(resp.Data))
130 | for i, data := range resp.Data {
131 | embeddings[i] = data.Embedding
132 | }
133 | return embeddings, nil
134 | }
135 |
136 | type VectorSearcher struct {
137 | vs vectorstore.VectorStore
138 | }
139 |
140 | func NewVectorSearcher(vs vectorstore.VectorStore) *VectorSearcher {
141 | return &VectorSearcher{
142 | vs: vs,
143 | }
144 | }
145 |
146 | func (v *VectorSearcher) Search(ctx context.Context, query []float32, n int) ([]*gollum.Document, error) {
147 | qb := vectorstore.QueryRequest{
148 | EmbeddingFloats: query,
149 | K: n,
150 | }
151 | return v.vs.Query(ctx, qb)
152 | }
153 |
154 | func NewHyde(prompter Prompter, generator Generator, encoder Encoder, searcher Searcher) *Hyde {
155 | return &Hyde{
156 | prompter: prompter,
157 | generator: generator,
158 | encoder: encoder,
159 | searcher: searcher,
160 | }
161 | }
162 |
163 | // Prompt builds a prompt from the given string.
164 | func (h *Hyde) Prompt(ctx context.Context, prompt string) string {
165 | return h.prompter.BuildPrompt(ctx, prompt)
166 | }
167 |
168 | // Generate generates n hypothesis documents.
169 | func (h *Hyde) Generate(ctx context.Context, prompt string, n int) ([]string, error) {
170 | searchPrompt := h.prompter.BuildPrompt(ctx, prompt)
171 | return h.generator.Generate(ctx, searchPrompt, n)
172 | }
173 |
174 | // Encode encodes a query and a list of documents into a single embedding.
175 | func (h *Hyde) Encode(ctx context.Context, query string, docs []string) ([]float32, error) {
176 | docs = append(docs, query)
177 | embeddings, err := h.encoder.EncodeBatch(ctx, docs)
178 | if err != nil {
179 | return make([]float32, 0), errors.Wrap(err, "error encoding batch")
180 | }
181 | embedDim := len(embeddings[0])
182 | numEmbeddings := float32(len(embeddings))
183 | // mean pooling of response embeddings
184 | avgEmbedding := make([]float32, embedDim)
185 | for _, embedding := range embeddings {
186 | vek32.Add_Inplace(avgEmbedding, embedding)
187 | }
188 | // unclear if this is faster or slower than naive approach. a little cleaner code though.
189 | vek32.DivNumber_Inplace(avgEmbedding, numEmbeddings)
190 | return avgEmbedding, nil
191 | }
192 |
193 | func (h *Hyde) Search(ctx context.Context, hydeVector []float32, k int) ([]*gollum.Document, error) {
194 | return h.searcher.Search(ctx, hydeVector, k)
195 | }
196 |
197 | func (h *Hyde) SearchEndToEnd(ctx context.Context, query string, k int) ([]*gollum.Document, error) {
198 | // generate n hypothesis documents
199 | docs, err := h.Generate(ctx, query, k)
200 | if err != nil {
201 | return nil, errors.Wrap(err, "error generating hypothetical documents")
202 | }
203 | // encode query and hypothesis documents
204 | hydeVector, err := h.Encode(ctx, query, docs)
205 | if err != nil {
206 | return make([]*gollum.Document, 0), errors.Wrap(err, "error encoding hypothetical documents")
207 | }
208 | // search for the most similar documents
209 | return h.Search(ctx, hydeVector, k)
210 | }
211 |
--------------------------------------------------------------------------------
/packages/vectorstore/vectorstore_test.go:
--------------------------------------------------------------------------------
1 | package vectorstore_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | vectorstore2 "github.com/stillmatic/gollum/packages/vectorstore"
7 | "math/rand"
8 | "testing"
9 |
10 | "github.com/sashabaranov/go-openai"
11 | "github.com/stillmatic/gollum"
12 | mock_gollum "github.com/stillmatic/gollum/internal/mocks"
13 | "github.com/stretchr/testify/assert"
14 | "go.uber.org/mock/gomock"
15 | "gocloud.dev/blob/fileblob"
16 | )
17 |
18 | func getRandomEmbedding(n int) []float32 {
19 | vec := make([]float32, n)
20 | for i := range vec {
21 | vec[i] = rand.Float32()
22 | }
23 | return vec
24 | }
25 |
26 | // setup with godotenv load
27 | func initialize(tb testing.TB) (*mock_gollum.MockEmbedder, *vectorstore2.MemoryVectorStore) {
28 | tb.Helper()
29 |
30 | ctrl := gomock.NewController(tb)
31 | oai := mock_gollum.NewMockEmbedder(ctrl)
32 | ctx := context.Background()
33 | bucket, err := fileblob.OpenBucket("testdata", nil)
34 | assert.NoError(tb, err)
35 | mvs, err := vectorstore2.NewMemoryVectorStoreFromDisk(ctx, bucket, "simple_store.json", oai)
36 | if err != nil {
37 | fmt.Println(err)
38 | mvs = vectorstore2.NewMemoryVectorStore(oai)
39 | testStrs := []string{"Apple", "Orange", "Basketball"}
40 | for i, s := range testStrs {
41 | mv := gollum.NewDocumentFromString(s)
42 | expectedReq := openai.EmbeddingRequest{
43 | Input: []string{s},
44 | Model: openai.AdaEmbeddingV2,
45 | }
46 | val := float64(i) + 0.1
47 | expectedResp := openai.EmbeddingResponse{
48 | Data: []openai.Embedding{{Embedding: []float32{float32(0.1), float32(val), float32(val)}}},
49 | }
50 | oai.EXPECT().CreateEmbeddings(ctx, expectedReq).Return(expectedResp, nil)
51 | err := mvs.Insert(ctx, mv)
52 | assert.NoError(tb, err)
53 | }
54 | err := mvs.Persist(ctx, bucket, "simple_store.json")
55 | assert.NoError(tb, err)
56 | }
57 | return oai, mvs
58 | }
59 |
60 | // TestRetrieval tests inserting embeddings and retrieving them
61 | func TestMemoryVectorStore(t *testing.T) {
62 | mockllm, mvs := initialize(t)
63 | ctx := context.Background()
64 | t.Run("LoadFromDisk", func(t *testing.T) {
65 | t.Log(mvs.Documents)
66 | assert.Equal(t, 3, len(mvs.Documents))
67 | // check that an ID is in the map
68 | testStrs := []string{"Apple", "Orange", "Basketball"}
69 | // test that all the strings are in the documents
70 | for _, s := range mvs.Documents {
71 | found := false
72 | for _, t := range testStrs {
73 | if s.Content == t {
74 | found = true
75 | }
76 | }
77 | assert.True(t, found)
78 | }
79 | })
80 |
81 | // should return apple and orange first.
82 | t.Run("QueryWithQuery", func(t *testing.T) {
83 | k := 2
84 | qb := vectorstore2.QueryRequest{
85 | Query: "favorite fruit?",
86 | K: k,
87 | }
88 | expectedCreateReq := openai.EmbeddingRequest{
89 | Input: []string{"favorite fruit?"},
90 | Model: openai.AdaEmbeddingV2,
91 | }
92 | expectedCreateResp := openai.EmbeddingResponse{
93 | Data: []openai.Embedding{
94 | {
95 | Embedding: []float32{float32(0.1), float32(0.1), float32(0.1)},
96 | },
97 | },
98 | }
99 | mockllm.EXPECT().CreateEmbeddings(ctx, expectedCreateReq).Return(expectedCreateResp, nil)
100 | resp, err := mvs.Query(ctx, qb)
101 | assert.NoError(t, err)
102 | assert.Equal(t, k, len(resp))
103 | assert.Equal(t, "Apple", resp[0].Content)
104 | assert.Equal(t, "Orange", resp[1].Content)
105 | })
106 |
107 | // This should return basketball because the embedding str should override the query
108 | t.Run("QueryWithEmbedding", func(t *testing.T) {
109 | k := 1
110 | qb := vectorstore2.QueryRequest{
111 | Query: "What is your favorite fruit",
112 | EmbeddingStrings: []string{"favorite sport?"},
113 | K: k,
114 | }
115 | expectedCreateReq := openai.EmbeddingRequest{
116 | Input: []string{"favorite sport?"},
117 | Model: openai.AdaEmbeddingV2,
118 | }
119 | expectedCreateResp := openai.EmbeddingResponse{
120 | Data: []openai.Embedding{
121 | {Embedding: []float32{float32(0.1), float32(2.11), float32(2.11)}},
122 | }}
123 | mockllm.EXPECT().CreateEmbeddings(ctx, expectedCreateReq).Return(expectedCreateResp, nil)
124 | resp, err := mvs.Query(ctx, qb)
125 | assert.NoError(t, err)
126 | assert.Equal(t, k, len(resp))
127 | assert.Equal(t, "Basketball", resp[0].Content)
128 | })
129 | }
130 |
131 | type MockEmbedder struct{}
132 |
133 | func (m MockEmbedder) CreateEmbeddings(ctx context.Context, req openai.EmbeddingRequest) (openai.EmbeddingResponse, error) {
134 | resp := openai.EmbeddingResponse{
135 | Data: []openai.Embedding{
136 | {Embedding: getRandomEmbedding(1536)},
137 | },
138 | }
139 | return resp, nil
140 | }
141 |
142 | func BenchmarkMemoryVectorStore(b *testing.B) {
143 | llm := mock_gollum.NewMockEmbedder(gomock.NewController(b))
144 | ctx := context.Background()
145 |
146 | nValues := []int{10, 100, 1_000, 10_000, 100_000, 1_000_000}
147 | kValues := []int{1, 10, 100}
148 | dim := 768
149 | for _, n := range nValues {
150 | b.Run(fmt.Sprintf("BenchmarkInsert-n=%v", n), func(b *testing.B) {
151 | mvs := vectorstore2.NewMemoryVectorStore(llm)
152 | for i := 0; i < b.N; i++ {
153 | for j := 0; j < n; j++ {
154 | mv := gollum.Document{
155 | ID: fmt.Sprintf("%v", j),
156 | Content: "test",
157 | Embedding: getRandomEmbedding(dim),
158 | }
159 | mvs.Insert(ctx, mv)
160 | }
161 | }
162 | })
163 | for _, k := range kValues {
164 | if k <= n {
165 | b.Run(fmt.Sprintf("BenchmarkQuery-n=%v-k=%v", n, k), func(b *testing.B) {
166 | mvs := vectorstore2.NewMemoryVectorStore(llm)
167 | for j := 0; j < n; j++ {
168 | mv := gollum.Document{
169 | ID: fmt.Sprintf("%v", j),
170 | Content: "test",
171 | Embedding: getRandomEmbedding(dim),
172 | }
173 | mvs.Insert(ctx, mv)
174 | }
175 | qb := vectorstore2.QueryRequest{
176 | EmbeddingFloats: getRandomEmbedding(dim),
177 | K: k,
178 | }
179 | b.ResetTimer()
180 | for i := 0; i < b.N; i++ {
181 | _, err := mvs.Query(ctx, qb)
182 | assert.NoError(b, err)
183 | }
184 | })
185 | }
186 | }
187 | }
188 | }
189 |
190 | func BenchmarkHeap(b *testing.B) {
191 | // Create a sample Heap.
192 |
193 | ks := []int{1, 10, 100}
194 |
195 | for _, k := range ks {
196 | var h vectorstore2.Heap
197 | h.Init(k)
198 | b.Run(fmt.Sprintf("BenchmarkHeapPush-k=%v", k), func(b *testing.B) {
199 | for i := 0; i < b.N; i++ {
200 | doc := &gollum.Document{}
201 | similarity := rand.Float32()
202 | ns := vectorstore2.NodeSimilarity{Document: doc, Similarity: similarity}
203 | h.Push(ns)
204 | if h.Len() > k {
205 | h.Pop()
206 | }
207 | }
208 | })
209 | }
210 | }
211 |
--------------------------------------------------------------------------------
/packages/llm/providers/vertex/vertex.go:
--------------------------------------------------------------------------------
1 | // Package vertex implements the Vertex AI api
2 | // it is largely similar to the Google "ai studio" provider but uses a different library...
3 | package vertex
4 |
5 | import (
6 | "cloud.google.com/go/vertexai/genai"
7 | "context"
8 | "fmt"
9 | "github.com/pkg/errors"
10 | "github.com/stillmatic/gollum/packages/llm"
11 | "google.golang.org/api/iterator"
12 | "log"
13 | )
14 |
15 | type VertexAIProvider struct {
16 | client *genai.Client
17 | }
18 |
19 | func NewVertexAIProvider(ctx context.Context, projectID, location string) (*VertexAIProvider, error) {
20 | client, err := genai.NewClient(ctx, projectID, location)
21 | if err != nil {
22 | return nil, errors.Wrap(err, "failed to create Vertex AI client")
23 | }
24 |
25 | ccIter := client.ListCachedContents(ctx)
26 | for {
27 | cc, err := ccIter.Next()
28 | if err == iterator.Done {
29 | break
30 | }
31 | if err != nil {
32 | return nil, errors.Wrap(err, "failed to list cached contents")
33 | }
34 | log.Printf("Cached content: %v", cc)
35 | }
36 |
37 | return &VertexAIProvider{
38 | client: client,
39 | }, nil
40 | }
41 |
42 | func (p *VertexAIProvider) getModel(req llm.InferRequest) *genai.GenerativeModel {
43 | // this does NOT validate if the model name is valid, that is done at inference time.
44 | model := p.client.GenerativeModel(req.ModelConfig.ModelName)
45 | model.SetTemperature(req.MessageOptions.Temperature)
46 | model.SetMaxOutputTokens(int32(req.MessageOptions.MaxTokens))
47 |
48 | return model
49 | }
50 |
51 | func (p *VertexAIProvider) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) {
52 | if len(req.Messages) > 1 {
53 | return p.generateResponseMultiTurn(ctx, req)
54 | }
55 | return p.generateResponseSingleTurn(ctx, req)
56 | }
57 |
58 | func (p *VertexAIProvider) generateResponseSingleTurn(ctx context.Context, req llm.InferRequest) (string, error) {
59 | model := p.getModel(req)
60 | parts := messageToParts(req.Messages[0])
61 |
62 | resp, err := model.GenerateContent(ctx, parts...)
63 | if err != nil {
64 | return "", errors.Wrap(err, "failed to generate content")
65 | }
66 |
67 | return flattenResponse(resp), nil
68 | }
69 |
70 | func (p *VertexAIProvider) generateResponseMultiTurn(ctx context.Context, req llm.InferRequest) (string, error) {
71 | model := p.getModel(req)
72 |
73 | msgs, sysInstr := multiTurnMessageToParts(req.Messages[:len(req.Messages)-1])
74 | if sysInstr != nil {
75 | model.SystemInstruction = sysInstr
76 | }
77 |
78 | cs := model.StartChat()
79 | cs.History = msgs
80 | mostRecentMessage := req.Messages[len(req.Messages)-1]
81 |
82 | // Send the last message
83 | resp, err := cs.SendMessage(ctx, genai.Text(mostRecentMessage.Content))
84 | if err != nil {
85 | return "", errors.Wrap(err, "failed to send message in chat")
86 | }
87 |
88 | return flattenResponse(resp), nil
89 | }
90 |
91 | func (p *VertexAIProvider) GenerateResponseAsync(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) {
92 | if len(req.Messages) > 1 {
93 | return p.generateResponseAsyncMultiTurn(ctx, req)
94 | }
95 | return p.generateResponseAsyncSingleTurn(ctx, req)
96 | }
97 |
98 | func (p *VertexAIProvider) generateResponseAsyncSingleTurn(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) {
99 | outChan := make(chan llm.StreamDelta)
100 |
101 | go func() {
102 | defer close(outChan)
103 |
104 | model := p.getModel(req)
105 | parts := messageToParts(req.Messages[0])
106 |
107 | iter := model.GenerateContentStream(ctx, parts...)
108 |
109 | for {
110 | resp, err := iter.Next()
111 | if errors.Is(err, iterator.Done) {
112 | outChan <- llm.StreamDelta{EOF: true}
113 | break
114 | }
115 | if err != nil {
116 | log.Printf("Error from Vertex AI stream: %v", err)
117 | return
118 | }
119 |
120 | content := flattenResponse(resp)
121 | if content != "" {
122 | select {
123 | case <-ctx.Done():
124 | return
125 | case outChan <- llm.StreamDelta{Text: content}:
126 | }
127 | }
128 | }
129 | }()
130 |
131 | return outChan, nil
132 | }
133 |
134 | func (p *VertexAIProvider) generateResponseAsyncMultiTurn(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) {
135 | outChan := make(chan llm.StreamDelta)
136 |
137 | go func() {
138 | defer close(outChan)
139 |
140 | model := p.getModel(req)
141 | cs := model.StartChat()
142 |
143 | // Add previous messages to chat history
144 | for _, msg := range req.Messages[:len(req.Messages)-1] {
145 | parts := messageToParts(msg)
146 | cs.History = append(cs.History, &genai.Content{
147 | Parts: parts,
148 | Role: msg.Role,
149 | })
150 | }
151 |
152 | // Send the last message
153 | lastMsg := req.Messages[len(req.Messages)-1]
154 | iter := cs.SendMessageStream(ctx, messageToParts(lastMsg)...)
155 |
156 | for {
157 | resp, err := iter.Next()
158 | if errors.Is(err, iterator.Done) {
159 | outChan <- llm.StreamDelta{EOF: true}
160 | break
161 | }
162 | if err != nil {
163 | log.Printf("Error from Vertex AI stream: %v", err)
164 | return
165 | }
166 |
167 | content := flattenResponse(resp)
168 | if content != "" {
169 | select {
170 | case <-ctx.Done():
171 | return
172 | case outChan <- llm.StreamDelta{Text: content}:
173 | }
174 | }
175 | }
176 | }()
177 |
178 | return outChan, nil
179 | }
180 |
181 | func messageToParts(message llm.InferMessage) []genai.Part {
182 | parts := []genai.Part{genai.Text(message.Content)}
183 | if message.Image != nil && len(message.Image) > 0 {
184 | parts = append(parts, genai.ImageData("png", message.Image))
185 | }
186 | return parts
187 | }
188 |
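189 | // multiTurnMessageToParts converts prior messages into Vertex chat history,
190 | // pulling any "system" messages out into a separate system instruction.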
189 | func multiTurnMessageToParts(messages []llm.InferMessage) ([]*genai.Content, *genai.Content) {
190 | sysInstructionParts := make([]genai.Part, 0)
191 | hist := make([]*genai.Content, 0, len(messages))
192 | for _, message := range messages {
193 | parts := []genai.Part{genai.Text(message.Content)}
194 | if message.Image != nil && len(message.Image) > 0 {
195 | parts = append(parts, genai.ImageData("png", message.Image))
196 | }
197 | if message.Role == "system" {
198 | sysInstructionParts = append(sysInstructionParts, parts...)
199 | continue
200 | }
201 | hist = append(hist, &genai.Content{
202 | Parts: parts,
203 | Role: message.Role,
204 | })
205 | }
206 | if len(sysInstructionParts) > 0 {
207 | return hist, &genai.Content{
208 | Parts: sysInstructionParts,
209 | }
210 | }
211 |
212 | return hist, nil
213 | }
214 |
215 | func flattenResponse(resp *genai.GenerateContentResponse) string {
216 | var result string
217 | for _, cand := range resp.Candidates {
218 | for _, part := range cand.Content.Parts {
219 | result += fmt.Sprintf("%v", part)
220 | }
221 | }
222 | return result
223 | }
224 |
225 | var _ llm.Responder = &VertexAIProvider{}
226 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GOLLuM
2 |
3 | Production-grade LLM tooling. At least, in theory -- stuff changes fast so don't expect stability from this library so much as ideas for your own apps.
4 |
5 | ## Features
6 |
7 | - Sane LLM provider abstraction.
8 | - Automated function dispatch
9 | - Parses arbitrary Go structs into JSONSchema for OpenAI - and validates when unmarshaling back to your structs
10 | - Simplified API to generate results from a single prompt or template
11 | - Highly performant vector store solution with exact search
12 | - SIMD acceleration for 10x better perf than naive approach, constant memory usage
13 | - Drop-in integration with OpenAI and other embedding providers
14 | - Carefully mocked, tested, and benchmarked.
15 | - Implementation of HyDE (hypothetical documents embeddings) for enhanced retrieval
16 | - MIT License
17 |
18 | # Examples
19 |
20 | ## Dispatch
21 |
22 | Function dispatch is a highly simplified and easy way to generate filled structs via an LLM.
23 |
24 | ```go
25 | type dinnerParty struct {
26 | Topic string `json:"topic" jsonschema:"required" jsonschema_description:"The topic of the conversation"`
27 | RandomWords []string `json:"random_words" jsonschema:"required" jsonschema_description:"Random words to prime the conversation"`
28 | }
29 | completer := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
30 | d := gollum.NewOpenAIDispatcher[dinnerParty]("dinner_party", "Given a topic, return random words", completer, nil)
31 | output, _ := d.Prompt(context.Background(), "Talk to me about dinosaurs")
32 | ```
33 |
34 | The result should be a filled `dinnerParty` struct.
35 |
36 | ```go
37 | expected := dinnerParty{
38 | Topic: "dinosaurs",
39 | RandomWords: []string{"dinosaur", "fossil", "extinct"},
40 | }
41 | ```
42 |
43 | Some similar libraries / ideas:
44 |
45 | - Rust: [grantslatton/ai-functions](https://github.com/grantslatton/ai-functions/blob/main/ai_bin/src/main.rs)
46 | - Python: [jxnl/openai_function_call](https://github.com/jxnl/openai_function_call)
47 |
48 | ## Parsing
49 |
50 | ### Simplest
51 |
52 | Imagine you have a function `GetWeather` --
53 |
54 | ```go
55 | type getWeatherInput struct {
56 | Location string `json:"location" jsonschema_description:"The city and state, e.g. San Francisco, CA" jsonschema:"required"`
57 | Unit string `json:"unit,omitempty" jsonschema:"enum=celsius,enum=fahrenheit" jsonschema_description:"The unit of temperature"`
58 | }
59 |
60 | type getWeatherOutput struct {
61 | // ...
62 | }
63 |
64 | // GetWeather does something; this docstring is annoying but theoretically possible to get
65 | func GetWeather(ctx context.Context, inp getWeatherInput) (out getWeatherOutput, err error) {
66 | return out, err
67 | }
68 | ```
69 |
70 | This is a common pattern for API design, as it is easy to share the `getWeatherInput` struct (well, imagine if it were public). See, for example, the [GRPC service definitions](https://github.com/grpc/grpc-go/blob/master/examples/helloworld/greeter_server/main.go#L43), or the [Connect RPC implementation](https://github.com/bufbuild/connect-go/blob/main/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go#LL155C6-L155C24). This means we can simplify the logic greatly by assuming a single input struct.
71 |
72 | Now, we can construct the responses:
73 |
74 | ```go
75 | type getWeatherInput struct {
76 | Location string `json:"location" jsonschema_description:"The city and state, e.g. San Francisco, CA" jsonschema:"required"`
77 | Unit string `json:"unit,omitempty" jsonschema:"enum=celsius,enum=fahrenheit" jsonschema_description:"The unit of temperature"`
78 | }
79 |
80 | fi := gollum.StructToJsonSchema("weather", "Get the current weather in a given location", getWeatherInput{})
81 |
82 | chatRequest := openai.ChatCompletionRequest{
83 | Model: "gpt-3.5-turbo-0613",
84 | Messages: []openai.ChatCompletionMessage{
85 | {
86 | Role: "user",
87 | Content: "Whats the temperature in Boston?",
88 | },
89 | },
90 | MaxTokens: 256,
91 | Temperature: 0.0,
92 | Tools: []openai.Tool{{Type: "function", Function: openai.FunctionDefinition(fi)}},
93 | ToolChoice: "weather",
94 | }
95 |
96 | ctx := context.Background()
97 | resp, err := api.SendRequest(ctx, chatRequest)
98 | parser := gollum.NewJSONParser[getWeatherInput](false)
99 | input, err := parser.Parse(ctx, resp.Choices[0].Message.ToolCalls[0].Function.Arguments)
100 | ```
101 |
102 | This example steps through all that, end to end. Some of this is 'sort of' pseudo-code, as the OpenAI clients I use haven't yet implemented support for functions, but it should also show that only minimal modifications to upstream libraries are necessary.
103 |
104 | It is also possible to go from just the function definition to a fully formed OpenAI FunctionCall. Reflection gives the name of the function for free (see the sketch below), and godoc parsing can get the function description too. In practice, though, I think it's fairly unlikely that you need to change the name/description of the function often - the inputs change more often. Using this pattern and compiling once makes the most sense to me.
105 |
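106 | For example, recovering the function name via reflection is just the standard library - a sketch, not part of this package:
107 | 
108 | ```go
109 | // functionName returns the bare name of fn, e.g. "GetWeather".
110 | func functionName(fn any) string {
111 | 	full := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() // e.g. "mypkg.GetWeather"
112 | 	parts := strings.Split(full, ".")
113 | 	return parts[len(parts)-1]
114 | }
115 | ```
116 | 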
106 | We should be able to chain the call for both the single-input and the ctx + single-input cases and return the result easily; a sketch of that follows.
107 |
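108 | Here's a rough illustration of that chaining, using plain `encoding/json` - the helper name is hypothetical, not part of gollum:
109 | 
110 | ```go
111 | // Call parses the tool-call arguments into I and invokes fn with them.
112 | func Call[I, O any](ctx context.Context, fn func(context.Context, I) (O, error), args string) (O, error) {
113 | 	var in I
114 | 	var out O
115 | 	if err := json.Unmarshal([]byte(args), &in); err != nil {
116 | 		return out, err
117 | 	}
118 | 	return fn(ctx, in)
119 | }
120 | 
121 | // usage: out, err := Call(ctx, GetWeather, resp.Choices[0].Message.ToolCalls[0].Function.Arguments)
122 | ```
123 | 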
108 | ### Recursion on arbitrary structs without explicit definitions
109 |
110 | Say you have a struct that has JSON tags defined.
111 |
112 | ```go
113 | fi := gollum.StructToJsonSchema("ChatCompletion", "Call the OpenAI chat completion API", chatCompletionRequest{})
114 |
115 | chatRequest := chatCompletionRequest{
116 | ChatCompletionRequest: openai.ChatCompletionRequest{
117 | Model: "gpt-3.5-turbo-0613",
118 | Messages: []openai.ChatCompletionMessage{
119 | {
120 | Role: openai.ChatMessageRoleSystem,
121 | Content: "Construct a ChatCompletionRequest to answer the user's question, but using Kirby references. Do not answer the question directly using prior knowledge, you must generate a ChatCompletionRequest that will answer the question.",
122 | },
123 | {
124 | Role: openai.ChatMessageRoleUser,
125 | Content: "What is the definition of recursion?",
126 | },
127 | },
128 | MaxTokens: 256,
129 | Temperature: 0.0,
130 | },
131 | Tools: []openai.Tool{
132 | {
133 | Type: "function",
134 | Function: fi,
135 | 		},
136 | 	},
137 | }
138 | parser := gollum.NewJSONParser[openai.ChatCompletionRequest](false)
139 | input, err := parser.Parse(ctx, resp.Choices[0].Message.ToolCalls[0].Function.Arguments)
140 | ```
141 |
142 | On the first try, this yielded the following result:
143 |
144 | ```json
145 | {
146 | "model": "gpt-3.5-turbo",
147 | "messages": [
148 | {"role": "system", "content": "You are Kirby, a friendly virtual assistant."},
149 | {"role": "user", "content": "What is the definition of recursion?"}
150 | ]
151 | }
152 | ```
153 |
154 | That's really sick considering that _no_ effort was put into manually creating a new JSON struct, and the original struct didn't have any JSONSchema tags - just plain JSON serialization tags.
155 |
--------------------------------------------------------------------------------
/bench.md:
--------------------------------------------------------------------------------
1 |
2 | 2023-07-09
3 |
4 | ```
5 | goos: linux
6 | goarch: amd64
7 | pkg: github.com/stillmatic/gollum
8 | cpu: AMD Ryzen 9 7950X 16-Core Processor
9 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=10-32 14752 83971 ns/op 64912 B/op 10 allocs/op
10 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10-k=1-32 810657 1256 ns/op 288 B/op 3 allocs/op
11 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10-k=10-32 663574 1639 ns/op 2880 B/op 6 allocs/op
12 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=100-32 1263 1042804 ns/op 646807 B/op 190 allocs/op
13 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=1-32 113851 10399 ns/op 288 B/op 3 allocs/op
14 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=10-32 93428 12625 ns/op 2880 B/op 6 allocs/op
15 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=100-32 69256 17569 ns/op 25664 B/op 9 allocs/op
16 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=1000-32 147 8100071 ns/op 6505065 B/op 2734 allocs/op
17 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=1-32 10000 104921 ns/op 288 B/op 3 allocs/op
18 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=10-32 9831 123464 ns/op 2880 B/op 6 allocs/op
19 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=100-32 7848 152007 ns/op 25664 B/op 9 allocs/op
20 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=10000-32 13 82907557 ns/op 64727761 B/op 29740 allocs/op
21 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=1-32 783 1514925 ns/op 288 B/op 3 allocs/op
22 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=10-32 679 1692251 ns/op 2880 B/op 6 allocs/op
23 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=100-32 650 2095042 ns/op 25664 B/op 9 allocs/op
24 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=100000-32 2 814309670 ns/op 648192728 B/op 299774 allocs/op
25 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=1-32 82 16656264 ns/op 288 B/op 3 allocs/op
26 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=10-32 68 16402470 ns/op 2880 B/op 6 allocs/op
27 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=100-32 64 18205266 ns/op 25664 B/op 9 allocs/op
28 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=1000000-32 1 9578485965 ns/op 6552089784 B/op 2999874 allocs/op
29 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=1-32 7 161260588 ns/op 288 B/op 3 allocs/op
30 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=10-32 7 212760511 ns/op 2880 B/op 6 allocs/op
31 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=100-32 4 290261365 ns/op 25664 B/op 9 allocs/op
32 | PASS
33 | ok github.com/stillmatic/gollum 111.224s
34 | ```
35 |
36 | post perf improvements, 2023-07-17
37 | changes are most pronounced when k is large: more efficient memory reuse makes query time nearly constant in k, roughly a 2x improvement with large k.
38 |
39 | ```
40 | goos: linux
41 | goarch: amd64
42 | pkg: github.com/stillmatic/gollum
43 | cpu: AMD Ryzen 9 7950X 16-Core Processor
44 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=10-32 12562 86781 ns/op 64680 B/op 10 allocs/op
45 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10-k=1-32 1259478 967.6 ns/op 120 B/op 4 allocs/op
46 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10-k=10-32 1201736 993.5 ns/op 304 B/op 3 allocs/op
47 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=100-32 1335 847394 ns/op 652949 B/op 190 allocs/op
48 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=1-32 143896 8509 ns/op 120 B/op 4 allocs/op
49 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=10-32 122300 9787 ns/op 624 B/op 4 allocs/op
50 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=100-32 111930 11152 ns/op 2752 B/op 3 allocs/op
51 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=1000-32 127 8455112 ns/op 6477091 B/op 2734 allocs/op
52 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=1-32 13416 88695 ns/op 120 B/op 4 allocs/op
53 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=10-32 12246 97176 ns/op 624 B/op 4 allocs/op
54 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=100-32 9746 114710 ns/op 5952 B/op 4 allocs/op
55 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=10000-32 14 88090856 ns/op 65255787 B/op 29740 allocs/op
56 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=1-32 769 1357818 ns/op 120 B/op 4 allocs/op
57 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=10-32 727 1555869 ns/op 624 B/op 4 allocs/op
58 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=100-32 752 1574506 ns/op 5952 B/op 4 allocs/op
59 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=100000-32 2 888843642 ns/op 648192776 B/op 299774 allocs/op
60 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=1-32 69 15276284 ns/op 120 B/op 4 allocs/op
61 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=10-32 79 14270086 ns/op 624 B/op 4 allocs/op
62 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=100-32 68 15162731 ns/op 5952 B/op 4 allocs/op
63 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=1000000-32 1 8644221139 ns/op 6552072176 B/op 2999841 allocs/op
64 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=1-32 8 141239584 ns/op 120 B/op 4 allocs/op
65 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=10-32 1 1354937045 ns/op 624 B/op 4 allocs/op
66 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=100-32 7 156217518 ns/op 5952 B/op 4 allocs/op
67 | PASS
68 | ok github.com/stillmatic/gollum 105.535s
69 | ```
70 |
71 | post perf improvement - mac. stabilizes the allocations
72 |
73 | ```
74 | goos: darwin
75 | goarch: arm64
76 | pkg: github.com/stillmatic/gollum
77 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=10-10 5341 220817 ns/op 65223 B/op 10 allocs/op
78 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10-k=1-10 60616 19622 ns/op 120 B/op 4 allocs/op
79 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10-k=10-10 60388 20033 ns/op 304 B/op 3 allocs/op
80 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=100-10 536 2202933 ns/op 652278 B/op 190 allocs/op
81 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=1-10 6152 194476 ns/op 120 B/op 4 allocs/op
82 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=10-10 6094 198124 ns/op 624 B/op 4 allocs/op
83 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=100-10 5946 199925 ns/op 2752 B/op 3 allocs/op
84 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=1000-10 55 22152592 ns/op 6523947 B/op 2735 allocs/op
85 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=1-10 613 1953824 ns/op 120 B/op 4 allocs/op
86 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=10-10 610 1987216 ns/op 624 B/op 4 allocs/op
87 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=100-10 580 2051436 ns/op 5952 B/op 4 allocs/op
88 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=10000-10 5 222244750 ns/op 64782620 B/op 29747 allocs/op
89 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=1-10 61 19383620 ns/op 120 B/op 4 allocs/op
90 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=10-10 60 19823898 ns/op 624 B/op 4 allocs/op
91 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=100-10 57 20027584 ns/op 5952 B/op 4 allocs/op
92 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=100000-10 1 2207505500 ns/op 648271208 B/op 299808 allocs/op
93 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=1-10 6 196473680 ns/op 120 B/op 4 allocs/op
94 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=10-10 6 197389812 ns/op 624 B/op 4 allocs/op
95 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=100-10 5 200068883 ns/op 5952 B/op 4 allocs/op
96 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=1000000-10 1 22239769458 ns/op 6552038696 B/op 2999849 allocs/op
97 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=1-10 1 1966544833 ns/op 120 B/op 4 allocs/op
98 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=10-10 1 1963972417 ns/op 624 B/op 4 allocs/op
99 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=100-10 1 1988149583 ns/op 5952 B/op 4 allocs/op
100 | PASS
101 | ok github.com/stillmatic/gollum 142.897s
102 | ```
103 |
104 | mac is expected to be slower. however, post change, memory usage is much more stable - consistently 4 allocs per operation and much less memory overall, with allocation size proportional to `k`. the desktop chip is faster and has SIMD-enhanced distance calculation, so the gap is not unexpected. runtime is also consistently linear with `n`, where `n` is the number of values in the db.
105 |
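to reproduce these numbers, something like the following works (run from the package containing the benchmarks; benchstat is the golang.org/x/perf tool):

```
go test -run '^$' -bench 'BenchmarkMemoryVectorStore' -benchmem . | tee new.txt
benchstat old.txt new.txt
```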
106 |
--------------------------------------------------------------------------------
/packages/vectorstore/vectorstore_compressed_test.go:
--------------------------------------------------------------------------------
1 | package vectorstore_test
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "context"
7 | "crypto/rand"
8 | "fmt"
9 | vectorstore2 "github.com/stillmatic/gollum/packages/vectorstore"
10 | mathrand "math/rand"
11 | "os"
12 | "testing"
13 |
14 | "github.com/google/uuid"
15 | "github.com/stillmatic/gollum"
16 | "github.com/stretchr/testify/assert"
17 | )
18 |
19 | func TestCompressedVectorStore(t *testing.T) {
20 | vs := vectorstore2.NewGzipVectorStore()
21 | t.Run("implements interface", func(t *testing.T) {
22 | var vs2 vectorstore2.VectorStore
23 | vs2 = vs
24 | assert.NotNil(t, vs2)
25 | })
26 | ctx := context.Background()
27 | testStrings := []string{
28 | "Japan's Seiko Epson Corp. has developed a 12-gram flying microrobot.",
29 | "The latest tiny flying robot has been unveiled in Japan.",
30 | "Michael Phelps won the gold medal in the 400 individual medley.",
31 | }
32 | t.Run("testInsert", func(t *testing.T) {
33 | for _, str := range testStrings {
34 | vs.Insert(ctx, gollum.Document{
35 | ID: uuid.NewString(),
36 | Content: str,
37 | })
38 | }
39 | docs, err := vs.RetrieveAll(ctx)
40 | assert.NoError(t, err)
41 | assert.Equal(t, 5, len(docs))
42 | })
43 | t.Run("correctness", func(t *testing.T) {
44 | for _, str := range testStrings {
45 | vs.Insert(ctx, gollum.Document{
46 | ID: uuid.NewString(),
47 | Content: str,
48 | })
49 | }
50 | docs, err := vs.Query(ctx, vectorstore2.QueryRequest{
51 | Query: "Where was the new robot unveiled?",
52 | K: 5,
53 | })
54 | assert.NoError(t, err)
55 | assert.Equal(t, 3, len(docs))
56 | assert.Equal(t, "Japan's Seiko Epson Corp. has developed a 12-gram flying microrobot.", docs[0].Content)
57 | assert.Equal(t, "The latest tiny flying robot has been unveiled in Japan.", docs[1].Content)
58 | })
59 | }
60 |
61 | func BenchmarkCompressedVectorStore(b *testing.B) {
62 | ctx := context.Background()
63 | // Test different sizes
64 | sizes := []int{10, 100, 1000, 10_000, 100_000}
65 | 	// note that runtime doesn't really depend on K - every stored doc is scored regardless, K only affects the final selection
66 | ks := []int{1, 10, 100}
67 | // benchmark inserts
68 | stores := map[string]vectorstore2.VectorStore{
69 | "DummyVectorStore": vectorstore2.NewDummyVectorStore(),
70 | "StdGzipVectorStore": vectorstore2.NewStdGzipVectorStore(),
71 | "ZstdVectorStore": vectorstore2.NewZstdVectorStore(),
72 | "GzipVectorStore": vectorstore2.NewGzipVectorStore(),
73 | }
74 |
75 | for vsName, vs := range stores {
76 | // for _, size := range sizes {
77 | // b.Run(fmt.Sprintf("%s-Insert-%d", vsName, size), func(b *testing.B) {
78 | // // Create vector store using live compression
79 | // docs := make([]gollum.Document, size)
80 | // // Generate synthetic docs
81 | // for i := range docs {
82 | // docs[i] = syntheticDoc()
83 | // }
84 | // b.ReportAllocs()
85 | // b.ResetTimer()
86 | // for n := 0; n < b.N; n++ {
87 | // // Insert docs
88 | // for _, doc := range docs {
89 | // vs.Insert(ctx, doc)
90 | // }
91 | // }
92 | // })
93 | // }
94 | // // Concurrent writes to a slice are ok
95 | // for _, size := range sizes {
96 | // b.Run(fmt.Sprintf("%s-InsertConcurrent-%d", vsName, size), func(b *testing.B) {
97 | // // Create vector store using live compression
98 | // docs := make([]gollum.Document, size)
99 | // // Generate synthetic docs
100 | // for i := range docs {
101 | // docs[i] = syntheticDoc()
102 | // }
103 | // var wg sync.WaitGroup
104 | // sem := make(chan struct{}, 8)
105 | // b.ReportAllocs()
106 | // b.ResetTimer()
107 | // for n := 0; n < b.N; n++ {
108 | // // Insert docs
109 | // for _, doc := range docs {
110 | // wg.Add(1)
111 | // sem <- struct{}{}
112 | // go func(doc gollum.Document) {
113 | // defer wg.Done()
114 | // defer func() { <-sem }()
115 | // vs.Insert(ctx, doc)
116 | // }(doc)
117 | // }
118 | // wg.Wait()
119 | // }
120 | // })
121 | // }
122 | // benchmark queries
123 | for _, size := range sizes {
124 | f, err := os.Open("testdata/enwik8")
125 | if err != nil {
126 | panic(err)
127 | }
128 | defer f.Close()
129 | lines := make([]string, 0)
130 | scanner := bufio.NewScanner(f)
131 | for scanner.Scan() {
132 | lines = append(lines, scanner.Text())
133 | }
134 | for _, k := range ks {
135 | if k <= size {
136 | b.Run(fmt.Sprintf("%s-Query-%d-%d", vsName, size, k), func(b *testing.B) {
137 | // Create vector store and insert docs
138 | for i := 0; i < size; i++ {
139 | vs.Insert(ctx, gollum.NewDocumentFromString(lines[i]))
140 | }
141 | query := vectorstore2.QueryRequest{
142 | 						Query: lines[size+1], K: k,
143 | }
144 | b.ReportAllocs()
145 | b.ResetTimer()
146 | // Create query
147 | for n := 0; n < b.N; n++ {
148 | vs.Query(ctx, query)
149 | }
150 | })
151 | }
152 | }
153 | }
154 | }
155 | }
156 |
157 | // Helper functions
158 | func syntheticString() string {
159 | // Random length between 8 and 32
160 | randLength := mathrand.Intn(32-8+1) + 8
161 |
162 | // Generate random bytes
163 | randBytes := make([]byte, randLength)
164 | rand.Read(randBytes)
165 |
166 | // Format as hex string
167 | return fmt.Sprintf("%x", randBytes)
168 | }
169 |
170 | // syntheticQuery returns a query request with a random query string
171 | func syntheticQuery(k int) vectorstore2.QueryRequest {
172 | return vectorstore2.QueryRequest{
173 | Query: syntheticString(),
174 | K: k,
175 | }
176 | }
177 |
178 | func BenchmarkStringToBytes(b *testing.B) {
179 | st := syntheticString()
180 | b.ResetTimer()
181 | b.Run("byteSlice", func(b *testing.B) {
182 | for n := 0; n < b.N; n++ {
183 | _ = []byte(st)
184 | }
185 | })
186 | b.Run("byteSliceCopy", func(b *testing.B) {
187 | for n := 0; n < b.N; n++ {
188 | bts := make([]byte, len(st))
189 | copy(bts, st)
190 | }
191 | })
192 | b.Run("byteSliceCopyAppend", func(b *testing.B) {
193 | for n := 0; n < b.N; n++ {
194 | bts := make([]byte, 0)
195 | bts = append(bts, st...)
196 | _ = bts
197 | }
198 | })
199 | b.Run("bytesBuffer", func(b *testing.B) {
200 | for n := 0; n < b.N; n++ {
201 | bb := bytes.NewBufferString(st)
202 | _ = bb.Bytes()
203 | }
204 | })
205 | b.Run("bytesBufferEmpty", func(b *testing.B) {
206 | for n := 0; n < b.N; n++ {
207 | var bb bytes.Buffer
208 | bb.WriteString(st)
209 | _ = bb.Bytes()
210 | }
211 | })
212 | }
213 |
214 | func dummyCompress(src []byte) []byte {
215 | return src
216 | }
217 |
218 | func minMax(val1, val2 float64) (float64, float64) {
219 | if val1 < val2 {
220 | return val1, val2
221 | }
222 | return val2, val1
223 | }
224 |
225 | func BenchmarkE2E(b *testing.B) {
226 | f, err := os.Open("testdata/enwik8")
227 | if err != nil {
228 | panic(err)
229 | }
230 | defer f.Close()
231 | lines := make([]string, 0)
232 | scanner := bufio.NewScanner(f)
233 | for scanner.Scan() {
234 | lines = append(lines, scanner.Text())
235 | }
236 | // st1 := syntheticString()
237 | // st2 := syntheticString()
238 | st1 := lines[1]
239 | st2 := lines[2]
240 | b.ResetTimer()
241 | b.Run("minMax", func(b *testing.B) {
242 | for n := 0; n < b.N; n++ {
243 | Cx1 := float64(len(st1))
244 | Cx2 := float64(len(st2))
245 | min, max := minMax(Cx1, Cx2)
246 | _ = min
247 | _ = max
248 | }
249 | })
250 | var bb bytes.Buffer
251 | b.Run("resetBytesBufferBytes", func(b *testing.B) {
252 | st1b := []byte(st1)
253 | st2b := []byte(st2)
254 | spb := []byte(" ")
255 | for n := 0; n < b.N; n++ {
256 | Cx1 := float64(len(st1b))
257 | Cx2 := float64(len(st2b))
258 | bb.Reset()
259 | bb.Write(st1b)
260 | bb.Write(spb)
261 | bb.Write(st2b)
262 | b_ := bb.Bytes()
263 | x1x2 := dummyCompress(b_)
264 | Cx1x2 := float64(len(x1x2))
265 | min, max := minMax(Cx1, Cx2)
266 | 			ncd := (Cx1x2 - min) / (max) // normalized compression distance: (C(x1x2) - min(C(x1),C(x2))) / max(C(x1),C(x2))
267 | _ = ncd
268 | }
269 | })
270 | }
271 |
272 | func BenchmarkConcatenateStrings(b *testing.B) {
273 | f, err := os.Open("testdata/enwik8")
274 | if err != nil {
275 | panic(err)
276 | }
277 | defer f.Close()
278 | lines := make([]string, 0)
279 | scanner := bufio.NewScanner(f)
280 | for scanner.Scan() {
281 | lines = append(lines, scanner.Text())
282 | }
283 | // st1 := syntheticString()
284 | // st2 := syntheticString()
285 | st1 := lines[1]
286 | st2 := lines[2]
287 | b.ResetTimer()
288 | b.Run("minMax", func(b *testing.B) {
289 | for n := 0; n < b.N; n++ {
290 | Cx1 := float64(len(st1))
291 | Cx2 := float64(len(st2))
292 | min, max := minMax(Cx1, Cx2)
293 | _ = min
294 | _ = max
295 | }
296 | })
297 | b.Run("concatenate", func(b *testing.B) {
298 | for n := 0; n < b.N; n++ {
299 | _ = []byte(st1 + " " + st2)
300 | }
301 | })
302 | b.Run("bytesBuffer", func(b *testing.B) {
303 | for n := 0; n < b.N; n++ {
304 | bb := bytes.NewBufferString(st1)
305 | bb.WriteString(" ")
306 | bb.WriteString(st2)
307 | _ = bb.Bytes()
308 | }
309 | })
310 | b.Run("bytesBufferEmpty", func(b *testing.B) {
311 | for n := 0; n < b.N; n++ {
312 | var bb bytes.Buffer
313 | bb.WriteString(st1)
314 | bb.WriteString(" ")
315 | bb.WriteString(st2)
316 | _ = bb.Bytes()
317 | }
318 | })
319 | var bb bytes.Buffer
320 | b.Run("resetBytesBuffer", func(b *testing.B) {
321 | for n := 0; n < b.N; n++ {
322 | bb.Reset()
323 | bb.WriteString(st1)
324 | bb.WriteString(" ")
325 | bb.WriteString(st2)
326 | _ = bb.Bytes()
327 | }
328 | })
329 |
330 | }
331 |
332 | func BenchmarkCompress(b *testing.B) {
333 | compressors := map[string]vectorstore2.Compressor{
334 | "DummyCompressor": vectorstore2.NewDummyVectorStore().Compressor,
335 | "StdGzipCompressor": vectorstore2.NewStdGzipVectorStore().Compressor,
336 | "ZstdCompressor": vectorstore2.NewZstdVectorStore().Compressor,
337 | "GzipCompressor": vectorstore2.NewGzipVectorStore().Compressor,
338 | }
339 | str := syntheticString()
340 | b.ResetTimer()
341 | for name, compressor := range compressors {
342 | b.Run(name, func(b *testing.B) {
343 | for n := 0; n < b.N; n++ {
344 | _ = compressor.Compress([]byte(str))
345 | }
346 | })
347 | }
348 | }
349 |
--------------------------------------------------------------------------------
/packages/llm/configs.go:
--------------------------------------------------------------------------------
1 | package llm
2 |
3 | const (
4 | ProviderAnthropic ProviderType = "anthropic"
5 | ProviderGoogle ProviderType = "google"
6 | ProviderVertex ProviderType = "vertex"
7 |
8 | ProviderOpenAI ProviderType = "openai"
9 | ProviderGroq ProviderType = "groq"
10 | ProviderTogether ProviderType = "together"
11 | ProviderHyperbolic ProviderType = "hyperbolic"
12 | ProviderDeepseek ProviderType = "deepseek"
13 |
14 | ProviderVoyage ProviderType = "voyage"
15 | ProviderMixedBread ProviderType = "mixedbread"
16 | )
17 |
18 | // configs are user declared; here are some useful defaults
19 | const (
20 | // LLM models
21 |
22 | // aka claude 3.6
23 | ConfigClaude3Dot6Sonnet = "claude-3.6-sonnet"
24 | // the traditional 3.5
25 | ConfigClaude3Dot5Sonnet = "claude-3.5-sonnet"
26 | ConfigClaude3Dot7Sonnet = "claude-3.7-sonnet"
27 |
28 | ConfigGPT4Mini = "gpt-4-mini"
29 | ConfigGPT4o = "gpt-4o"
30 | ConfigOpenAIO1 = "oai-o1"
31 | ConfigOpenAIO1Mini = "oai-o1-mini"
32 | ConfigOpenAIO1Preview = "oai-o1-preview"
33 | ConfigOpenAIO3Mini = "oai-o3-mini"
34 |
35 | ConfigGroqLlama70B = "groq-llama-70b"
36 | ConfigGroqLlama8B = "groq-llama-8b"
37 | ConfigGroqGemma9B = "groq-gemma2-9b"
38 | ConfigGroqMixtral = "groq-mixtral"
39 |
40 | ConfigTogetherGemma27B = "together-gemma-27b"
41 | ConfigTogetherDeepseekCoder33B = "together-deepseek-coder-33b"
42 |
43 | ConfigGemini1Dot5Flash8B = "gemini-flash-8b"
44 | ConfigGemini1Dot5Flash = "gemini-flash"
45 | ConfigGemini1Dot5Pro = "gemini-pro"
46 | ConfigGemini2Flash = "gemini-2-flash"
47 |
48 | ConfigHyperbolicLlama405B = "hyperbolic-llama-405b"
49 | ConfigHyperbolicLlama405BBase = "hyperbolic-llama-405b-base"
50 | ConfigHyperbolicLlama70B = "hyperbolic-llama-70b"
51 | ConfigHyperbolicLlama8B = "hyperbolic-llama-8b"
52 |
53 | ConfigDeepseekChat = "deepseek-chat"
54 | ConfigDeepseekCoder = "deepseek-coder"
55 |
56 | // Vertex
57 | ConfigClaude3Dot5SonnetVertex = "claude-3.5-sonnet-vertex"
58 | ConfigLlama405BVertex = "llama-405b-vertex"
59 |
60 | // Embedding models
61 | ConfigOpenAITextEmbedding3Small = "openai-text-embedding-3-small"
62 | ConfigOpenAITextEmbedding3Large = "openai-text-embedding-3-large"
63 | ConfigOpenAITextEmbeddingAda002 = "openai-text-embedding-ada-002"
64 |
65 | ConfigGeminiTextEmbedding4 = "gemini-text-embedding-004"
66 |
67 | ConfigMxbaiEmbedLargeV1 = "mxbai-embed-large-v1"
68 | ConfigVoyageLarge2Instruct = "voyage-large-2-instruct"
69 | )
70 |
71 | var configs = map[string]ModelConfig{
72 | ConfigClaude3Dot5Sonnet: {
73 | ProviderType: ProviderAnthropic,
74 | ModelName: "claude-3-5-sonnet-20240620",
75 | ModelType: ModelTypeLLM,
76 | CentiCentsPerMillionInputTokens: 30000,
77 | CentiCentsPerMillionOutputTokens: 150000,
78 | },
79 | ConfigClaude3Dot5SonnetVertex: {
80 | ProviderType: ProviderVertex,
81 | ModelName: "claude-3-5-sonnet@20240620",
82 | ModelType: ModelTypeLLM,
83 | CentiCentsPerMillionInputTokens: 30000,
84 | CentiCentsPerMillionOutputTokens: 150000,
85 | },
86 | ConfigLlama405BVertex: {
87 | ProviderType: ProviderVertex,
88 | ModelName: "llama3-405b-instruct-maas",
89 | ModelType: ModelTypeLLM,
90 | // The Llama 3.1 API service is at no cost during public preview, and will be priced as per dollar-per-1M-tokens at GA.
91 | },
92 | ConfigClaude3Dot6Sonnet: {
93 | ProviderType: ProviderAnthropic,
94 | ModelName: "claude-3-5-sonnet-20241022",
95 | },
96 | ConfigClaude3Dot7Sonnet: {
97 | ProviderType: ProviderAnthropic,
98 | ModelName: "claude-3-7-sonnet-latest",
99 | },
100 | ConfigGPT4Mini: {
101 | ProviderType: ProviderOpenAI,
102 | ModelName: "gpt-4o-mini",
103 | ModelType: ModelTypeLLM,
104 | CentiCentsPerMillionInputTokens: 1_500,
105 | CentiCentsPerMillionOutputTokens: 6_000,
106 | },
107 | ConfigGPT4o: {
108 | ProviderType: ProviderOpenAI,
109 | // NB 2025-01-08: 2024-08-06 remains 'latest', but there is also 'gpt-4o-2024-11-20'
110 | ModelName: "gpt-4o-2024-08-06",
111 | CentiCentsPerMillionInputTokens: 25_000,
112 | CentiCentsPerMillionOutputTokens: 100_000,
113 | },
114 | ConfigOpenAIO1: {
115 | ProviderType: ProviderOpenAI,
116 | ModelName: "o1",
117 | CentiCentsPerMillionInputTokens: 150_000,
118 | CentiCentsPerMillionOutputTokens: 600_000,
119 | },
120 | ConfigOpenAIO1Mini: {
121 | ProviderType: ProviderOpenAI,
122 | ModelName: "o1-mini",
123 | CentiCentsPerMillionInputTokens: 30_000,
124 | CentiCentsPerMillionOutputTokens: 120_000,
125 | },
126 | ConfigOpenAIO3Mini: {
127 | ProviderType: ProviderOpenAI,
128 | ModelName: "o3-mini",
129 | CentiCentsPerMillionInputTokens: 150_000,
130 | CentiCentsPerMillionOutputTokens: 600_000,
131 | },
132 | ConfigOpenAIO1Preview: {
133 | ProviderType: ProviderOpenAI,
134 | ModelName: "o1-preview",
135 | CentiCentsPerMillionInputTokens: 150_000,
136 | CentiCentsPerMillionOutputTokens: 600_000,
137 | },
138 | ConfigGroqLlama70B: {
139 | ProviderType: ProviderGroq,
140 | ModelName: "llama-3.3-70b-versatile",
141 | CentiCentsPerMillionInputTokens: 5900,
142 | CentiCentsPerMillionOutputTokens: 7900,
143 | },
144 | ConfigGroqMixtral: {
145 | ProviderType: ProviderGroq,
146 | ModelName: "mixtral-8x7b-32768",
147 | ModelType: ModelTypeLLM,
148 | CentiCentsPerMillionInputTokens: 2400,
149 | CentiCentsPerMillionOutputTokens: 2400,
150 | },
151 | ConfigGroqGemma9B: {
152 | ProviderType: ProviderGroq,
153 | ModelName: "gemma2-9b-it",
154 | ModelType: ModelTypeLLM,
155 | CentiCentsPerMillionInputTokens: 2000,
156 | CentiCentsPerMillionOutputTokens: 2000,
157 | },
158 | ConfigGroqLlama8B: {
159 | ProviderType: ProviderGroq,
160 | ModelType: ModelTypeLLM,
161 | CentiCentsPerMillionInputTokens: 500,
162 | CentiCentsPerMillionOutputTokens: 800,
163 | ModelName: "llama-3.1-8b-instant",
164 | },
165 | ConfigTogetherGemma27B: {
166 | ProviderType: ProviderTogether,
167 | ModelName: "google/gemma-2-27b-it",
168 | ModelType: ModelTypeLLM,
169 | CentiCentsPerMillionOutputTokens: 8000,
170 | CentiCentsPerMillionInputTokens: 8000,
171 | },
172 | ConfigTogetherDeepseekCoder33B: {
173 | ProviderType: ProviderTogether,
174 | ModelName: "deepseek-ai/deepseek-coder-33b-instruct",
175 | ModelType: ModelTypeLLM,
176 | CentiCentsPerMillionOutputTokens: 8000,
177 | CentiCentsPerMillionInputTokens: 8000,
178 | },
179 | ConfigGemini1Dot5Flash: {
180 | ProviderType: ProviderGoogle,
181 | ModelName: "gemini-1.5-flash",
182 | ModelType: ModelTypeLLM,
183 | // assumes < 128k
184 | CentiCentsPerMillionOutputTokens: 3000,
185 | CentiCentsPerMillionInputTokens: 750,
186 | },
187 | ConfigGemini1Dot5Flash8B: {
188 | ProviderType: ProviderGoogle,
189 | ModelName: "gemini-1.5-flash-8b",
190 | ModelType: ModelTypeLLM,
191 | // assumes < 128k
192 | CentiCentsPerMillionOutputTokens: 1500,
193 | CentiCentsPerMillionInputTokens: 375,
194 | },
195 | ConfigGemini1Dot5Pro: {
196 | ProviderType: ProviderGoogle,
197 | ModelName: "gemini-1.5-pro",
198 | ModelType: ModelTypeLLM,
199 | // assumes < 128k
200 | CentiCentsPerMillionOutputTokens: 50000,
201 | CentiCentsPerMillionInputTokens: 12500,
202 | },
203 | ConfigGemini2Flash: {
204 | ProviderType: ProviderGoogle,
205 | ModelName: "gemini-2.0-flash",
206 | },
207 | ConfigHyperbolicLlama405B: {
208 | ProviderType: ProviderHyperbolic,
209 | ModelName: "meta-llama/Meta-Llama-3.1-405B-Instruct",
210 | ModelType: ModelTypeLLM,
211 |
212 | CentiCentsPerMillionInputTokens: 40000,
213 | CentiCentsPerMillionOutputTokens: 40000,
214 | },
215 | ConfigHyperbolicLlama405BBase: {
216 | ProviderType: ProviderHyperbolic,
217 | ModelName: "meta-llama/Meta-Llama-3.1-405B",
218 | ModelType: ModelTypeLLM,
219 |
220 | CentiCentsPerMillionInputTokens: 40000,
221 | CentiCentsPerMillionOutputTokens: 40000,
222 | },
223 | ConfigHyperbolicLlama70B: {
224 | ProviderType: ProviderHyperbolic,
225 | ModelName: "meta-llama/Meta-Llama-3.1-70B-Instruct",
226 | ModelType: ModelTypeLLM,
227 |
228 | CentiCentsPerMillionInputTokens: 4000,
229 | CentiCentsPerMillionOutputTokens: 4000,
230 | },
231 | ConfigHyperbolicLlama8B: {
232 | ProviderType: ProviderHyperbolic,
233 | ModelName: "meta-llama/Meta-Llama-3.1-8B-Instruct",
234 | ModelType: ModelTypeLLM,
235 |
236 | CentiCentsPerMillionInputTokens: 1000,
237 | CentiCentsPerMillionOutputTokens: 1000,
238 | },
239 |
240 | ConfigDeepseekChat: {
241 | ProviderType: ProviderDeepseek,
242 | ModelName: "deepseek-chat",
243 | ModelType: ModelTypeLLM,
244 |
245 | // assume cache miss
246 | CentiCentsPerMillionInputTokens: 1400,
247 | CentiCentsPerMillionOutputTokens: 2800,
248 | },
249 | ConfigDeepseekCoder: {
250 | ProviderType: ProviderDeepseek,
251 | ModelName: "deepseek-coder",
252 | ModelType: ModelTypeLLM,
253 |
254 | // assume cache miss
255 | CentiCentsPerMillionInputTokens: 1400,
256 | CentiCentsPerMillionOutputTokens: 2800,
257 | },
258 |
259 | ConfigOpenAITextEmbedding3Small: {
260 | ProviderType: ProviderOpenAI,
261 | ModelName: "text-embedding-3-small",
262 | ModelType: ModelTypeEmbedding,
263 | },
264 | ConfigOpenAITextEmbedding3Large: {
265 | ProviderType: ProviderOpenAI,
266 | ModelName: "text-embedding-3-large",
267 | ModelType: ModelTypeEmbedding,
268 | },
269 | ConfigOpenAITextEmbeddingAda002: {
270 | ProviderType: ProviderOpenAI,
271 | ModelName: "text-embedding-ada-002",
272 | ModelType: ModelTypeEmbedding,
273 | },
274 |
275 | ConfigGeminiTextEmbedding4: {
276 | ProviderType: ProviderGoogle,
277 | ModelName: "text-embedding-004",
278 | ModelType: ModelTypeEmbedding,
279 | },
280 |
281 | ConfigMxbaiEmbedLargeV1: {
282 | ProviderType: ProviderMixedBread,
283 | ModelName: "mxbai-embed-large-v1",
284 | ModelType: ModelTypeEmbedding,
285 | },
286 |
287 | ConfigVoyageLarge2Instruct: {
288 | ProviderType: ProviderVoyage,
289 | ModelName: "voyage-large-2-instruct",
290 | ModelType: ModelTypeEmbedding,
291 | },
292 | }
293 |
--------------------------------------------------------------------------------
/packages/dispatch/functions_test.go:
--------------------------------------------------------------------------------
1 | package dispatch_test
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "os"
7 | "testing"
8 |
9 | "github.com/stillmatic/gollum/packages/dispatch"
10 | "github.com/stillmatic/gollum/packages/jsonparser"
11 |
12 | "github.com/joho/godotenv"
13 | "github.com/sashabaranov/go-openai"
14 | "github.com/stillmatic/gollum/internal/testutil"
15 | "github.com/stretchr/testify/assert"
16 | )
17 |
18 | type addInput struct {
19 | A int `json:"a" json_schema:"required"`
20 | B int `json:"b" json_schema:"required"`
21 | }
22 |
23 | type getWeatherInput struct {
24 | Location string `json:"location" jsonschema_description:"The city and state." jsonschema:"required,example=San Francisco, CA"`
25 | Unit string `json:"unit,omitempty" jsonschema:"enum=celsius,enum=fahrenheit" jsonschema_description:"The unit of temperature,default=fahrenheit"`
26 | }
27 |
28 | type counter struct {
29 | Count int `json:"count" jsonschema:"required" jsonschema_description:"total number of words in sentence"`
30 | Words []string `json:"words" jsonschema:"required" jsonschema_description:"list of words in sentence"`
31 | }
32 |
33 | type blobNode struct {
34 | Name string `json:"name" jsonschema:"required"`
35 | Children []blobNode `json:"children,omitempty" jsonschema_description:"list of child nodes - only applicable if this is a directory"`
36 | NodeType string `json:"node_type" jsonschema:"required,enum=file,enum=folder" jsonschema_description:"type of node, inferred from name"`
37 | }
38 |
39 | type queryNode struct {
40 | Question string `json:"question" jsonschema:"required" jsonschema_description:"question to ask - questions can use information from children questions"`
41 | // NodeType string `json:"node_type" jsonschema:"required,enum=single_question,enum=merge_responses" jsonschema_description:"type of question. Either a single question or a multi question merge when there are multiple questions."`
42 | Children []queryNode `json:"children,omitempty" jsonschema_description:"list of child questions that need to be answered before this question can be answered. Use a subquery when anything may be unknown, and we need to ask multiple questions to get the answer. Dependences must only be other queries."`
43 | }
44 |
45 | func TestConstructJSONSchema(t *testing.T) {
46 | t.Run("add_", func(t *testing.T) {
47 | res := dispatch.StructToJsonSchema("add_", "adds two numbers", addInput{})
48 | expectedStr := `{"name":"add_","description":"adds two numbers","parameters":{"properties":{"a":{"type":"integer"},"b":{"type":"integer"}},"type":"object","required":["a","b"]}}`
49 | b, err := json.Marshal(res)
50 | assert.NoError(t, err)
51 | assert.Equal(t, expectedStr, string(b))
52 | })
53 | t.Run("getWeather", func(t *testing.T) {
54 | res := dispatch.StructToJsonSchema("getWeather", "Get the current weather in a given location", getWeatherInput{})
55 | assert.Equal(t, res.Name, "getWeather")
56 | assert.Equal(t, res.Description, "Get the current weather in a given location")
57 | // assert.Equal(t, res.Parameters.Type, "object")
58 | expectedStr := `{"name":"getWeather","description":"Get the current weather in a given location","parameters":{"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"],"description":"The unit of temperature"}},"type":"object","required":["location"]}}`
59 | b, err := json.Marshal(res)
60 | assert.NoError(t, err)
61 | assert.Equal(t, expectedStr, string(b))
62 | })
63 | }
64 |
65 | func ptr[T any](v T) *T {
66 | return &v
67 | }
68 |
69 | func TestEndToEnd(t *testing.T) {
70 | godotenv.Load()
71 | baseAPIURL := "https://api.openai.com/v1/chat/completions"
72 | openAIKey := os.Getenv("OPENAI_API_KEY")
73 | assert.NotEmpty(t, openAIKey)
74 |
75 | api := testutil.NewTestAPI(baseAPIURL, openAIKey)
76 | t.Run("weather", func(t *testing.T) {
77 | t.Skip("somewhat flaky - word counter is more reliable")
78 | fi := dispatch.StructToJsonSchema("weather", "Get the current weather in a given location", getWeatherInput{})
79 |
80 | chatRequest := openai.ChatCompletionRequest{
81 | Model: openai.GPT3Dot5Turbo1106,
82 | Messages: []openai.ChatCompletionMessage{
83 | {
84 | Role: "user",
85 | Content: "Whats the temperature in Boston?",
86 | },
87 | },
88 | MaxTokens: 256,
89 | Temperature: 0.0,
90 | Tools: []openai.Tool{{Type: "function", Function: ptr(openai.FunctionDefinition(fi))}},
91 | ToolChoice: "weather",
92 | }
93 |
94 | ctx := context.Background()
95 | resp, err := api.SendRequest(ctx, chatRequest)
96 | assert.NoError(t, err)
97 |
98 | assert.Equal(t, resp.Model, "gpt-3.5-turbo-0613")
99 | assert.NotEmpty(t, resp.Choices)
100 | assert.Empty(t, resp.Choices[0].Message.Content)
101 | assert.NotNil(t, resp.Choices[0].Message.ToolCalls)
102 | assert.Equal(t, resp.Choices[0].Message.ToolCalls[0].Function.Name, "weather")
103 |
104 | // this is somewhat flaky - about 20% of the time it returns 'Boston'
105 | expectedArg := []byte(`{"location": "Boston, MA"}`)
106 | parser := jsonparser.NewJSONParserGeneric[getWeatherInput](false)
107 | expectedStruct, err := parser.Parse(ctx, expectedArg)
108 | assert.NoError(t, err)
109 | input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
110 | assert.NoError(t, err)
111 | assert.Equal(t, expectedStruct, input)
112 | })
113 |
114 | t.Run("counter", func(t *testing.T) {
115 | fi := dispatch.StructToJsonSchema("split_word", "Break sentences into words", counter{})
116 | chatRequest := openai.ChatCompletionRequest{
117 | Model: "gpt-3.5-turbo-0613",
118 | Messages: []openai.ChatCompletionMessage{
119 | {
120 | Role: "user",
121 | Content: "「What is the weather like in Boston?」Break down the above sentence into words",
122 | },
123 | },
124 | MaxTokens: 256,
125 | Temperature: 0.0,
126 | Tools: []openai.Tool{
127 | {Type: "function", Function: ptr(openai.FunctionDefinition(fi))},
128 | },
129 | }
130 | ctx := context.Background()
131 | resp, err := api.SendRequest(ctx, chatRequest)
132 | assert.NoError(t, err)
133 |
134 | assert.Equal(t, resp.Model, "gpt-3.5-turbo-0613")
135 | assert.NotEmpty(t, resp.Choices)
136 | assert.Empty(t, resp.Choices[0].Message.Content)
137 | assert.NotNil(t, resp.Choices[0].Message.ToolCalls)
138 | assert.Equal(t, resp.Choices[0].Message.ToolCalls[0].Function.Name, "split_word")
139 |
140 | expectedStruct := counter{
141 | Count: 7,
142 | Words: []string{"What", "is", "the", "weather", "like", "in", "Boston?"},
143 | }
144 | parser := jsonparser.NewJSONParserGeneric[counter](false)
145 | input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
146 | assert.NoError(t, err)
147 | assert.Equal(t, expectedStruct, input)
148 | })
149 |
150 | t.Run("callOpenAI", func(t *testing.T) {
151 | fi := dispatch.StructToJsonSchema("ChatCompletion", "Call the OpenAI chat completion API", openai.ChatCompletionRequest{})
152 |
153 | chatRequest := openai.ChatCompletionRequest{
154 | Model: "gpt-3.5-turbo-0613",
155 | Messages: []openai.ChatCompletionMessage{
156 | {
157 | Role: openai.ChatMessageRoleSystem,
158 | Content: "Construct a ChatCompletionRequest to answer the user's question, but using Kirby references. Do not answer the question directly using prior knowledge, you must generate a ChatCompletionRequest that will answer the question.",
159 | },
160 | {
161 | Role: openai.ChatMessageRoleUser,
162 | Content: "What is the definition of recursion?",
163 | },
164 | },
165 | MaxTokens: 256,
166 | Temperature: 0.0,
167 | Tools: []openai.Tool{
168 | {Type: "function", Function: ptr(openai.FunctionDefinition(fi))},
169 | },
170 | }
171 |
172 | ctx := context.Background()
173 | resp, err := api.SendRequest(ctx, chatRequest)
174 | assert.NoError(t, err)
175 | assert.Equal(t, resp.Model, "gpt-3.5-turbo-0613")
176 | assert.NotEmpty(t, resp.Choices)
177 | assert.Empty(t, resp.Choices[0].Message.Content)
178 | assert.NotNil(t, resp.Choices[0].Message.ToolCalls)
179 | assert.Equal(t, resp.Choices[0].Message.ToolCalls[0].Function.Name, "ChatCompletion")
180 |
181 | parser := jsonparser.NewJSONParserGeneric[openai.ChatCompletionRequest](false)
182 | input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
183 | assert.NoError(t, err)
184 | assert.NotEmpty(t, input)
185 |
186 | // an example output:
187 | // "{
188 | // "model": "gpt-3.5-turbo",
189 | // "messages": [
190 | // {"role": "system", "content": "You are Kirby, a friendly virtual assistant."},
191 | // {"role": "user", "content": "What is the definition of recursion?"}
192 | // ]
193 | // }"
194 | })
195 |
196 | t.Run("directory", func(t *testing.T) {
197 | fi := dispatch.StructToJsonSchema("directory", "Get the contents of a directory", blobNode{})
198 | inp := `root
199 | ├── dir1
200 | │ ├── file1.txt
201 | │ └── file2.txt
202 | └── dir2
203 | ├── file3.txt
204 | └── subfolder
205 | └── file4.txt`
206 |
207 | chatRequest := openai.ChatCompletionRequest{
208 | Model: "gpt-3.5-turbo-0613",
209 | Messages: []openai.ChatCompletionMessage{
210 | {
211 | Role: "user",
212 | Content: inp,
213 | },
214 | },
215 | MaxTokens: 256,
216 | Temperature: 0.0,
217 | Tools: []openai.Tool{
218 | {Type: "function", Function: ptr(openai.FunctionDefinition(fi))},
219 | },
220 | }
221 | ctx := context.Background()
222 | resp, err := api.SendRequest(ctx, chatRequest)
223 | assert.NoError(t, err)
224 | t.Log(resp)
225 | 		assert.Equal(t, 0, 1) // NB: always fails, so the logged response above is printed
226 |
227 | parser := jsonparser.NewJSONParserGeneric[blobNode](false)
228 | input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
229 | assert.NoError(t, err)
230 | assert.NotEmpty(t, input)
231 | assert.Equal(t, input, blobNode{
232 | Name: "root",
233 | Children: []blobNode{
234 | {
235 | Name: "dir1",
236 | Children: []blobNode{
237 | {
238 | Name: "file1.txt",
239 | NodeType: "file",
240 | },
241 | {
242 | Name: "file2.txt",
243 | NodeType: "file",
244 | },
245 | },
246 | NodeType: "folder",
247 | },
248 | {
249 | Name: "dir2",
250 | Children: []blobNode{
251 | {
252 | Name: "file3.txt",
253 | NodeType: "file",
254 | },
255 | {
256 | Name: "subfolder",
257 | Children: []blobNode{
258 | {
259 | Name: "file4.txt",
260 | NodeType: "file",
261 | },
262 | },
263 | NodeType: "folder",
264 | },
265 | },
266 | NodeType: "folder",
267 | },
268 | },
269 | NodeType: "folder",
270 | })
271 | })
272 | t.Run("planner", func(t *testing.T) {
273 | fi := dispatch.StructToJsonSchema("queryPlanner", "Plan a multi-step query", queryNode{})
274 | // inp := `Jason is from Canada`
275 |
276 | chatRequest := openai.ChatCompletionRequest{
277 | Model: "gpt-3.5-turbo-0613",
278 | Messages: []openai.ChatCompletionMessage{
279 | {
280 | Role: "system",
281 | Content: `When a user asks a question, you must use the 'queryPlanner' function to answer the question. If you are at all unsure, break the question into multiple smaller questions
282 | Example:
283 | Input: What is the population of Jason's home country?
284 |
285 | Output: What is the population of Jason's home country?
286 | │ ├── What is Jason's home country?
287 | │ ├── What is the population of that country?`,
288 | },
289 | {
290 | Role: "user",
291 | Content: "What's on the flag of Jason's home country?",
292 | },
293 | },
294 | MaxTokens: 256,
295 | Temperature: 0.0,
296 | Tools: []openai.Tool{
297 | {Type: "function", Function: ptr(openai.FunctionDefinition(fi))},
298 | },
299 | }
300 | ctx := context.Background()
301 | resp, err := api.SendRequest(ctx, chatRequest)
302 | assert.NoError(t, err)
303 | t.Log(resp)
304 | 		assert.Equal(t, 0, 1) // NB: always fails, so the logged response above is printed
305 |
306 | })
307 | }
308 |
309 | func BenchmarkStructToJsonSchema(b *testing.B) {
310 | b.Run("basic", func(b *testing.B) {
311 | for i := 0; i < b.N; i++ {
312 | dispatch.StructToJsonSchema("queryPlanner", "Plan a multi-step query", queryNode{})
313 | }
314 | })
315 | b.Run("generic", func(b *testing.B) {
316 | for i := 0; i < b.N; i++ {
317 | dispatch.StructToJsonSchemaGeneric[queryNode]("queryPlanner", "Plan a multi-step query")
318 | }
319 | })
320 | }
321 |
--------------------------------------------------------------------------------
/packages/llm/providers/google/google.go:
--------------------------------------------------------------------------------
1 | package google
2 |
3 | import (
4 | "context"
5 | "log/slog"
6 | 	"strconv"
	"strings"
7 | "time"
8 |
9 | "github.com/cespare/xxhash/v2"
10 | "github.com/stillmatic/gollum/packages/llm"
11 | "google.golang.org/api/option"
12 |
13 | "github.com/google/generative-ai-go/genai"
14 | "github.com/pkg/errors"
15 | "google.golang.org/api/iterator"
16 | )
17 |
18 | type Provider struct {
19 | client *genai.Client
20 | cachedFileMap map[string]string
21 | cachedContentMap map[string]struct{}
22 | }
23 |
24 | func NewGoogleProvider(ctx context.Context, apiKey string) (*Provider, error) {
25 | client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
26 | if err != nil {
27 | return nil, errors.Wrap(err, "google client error")
28 | }
29 |
30 | // load cached content map
31 | p := &Provider{client: client}
32 | err = p.refreshCachedContentMap(ctx)
33 | if err != nil {
34 | return nil, errors.Wrap(err, "google refresh cached content map error")
35 | }
36 |
37 | return p, nil
38 | }
39 |
40 | func (p *Provider) refreshCachedFileMap(ctx context.Context) error {
41 | iter := p.client.ListFiles(ctx)
42 | cachedFileMap := make(map[string]string)
43 | for {
44 | cachedFile, err := iter.Next()
45 | if errors.Is(err, iterator.Done) {
46 | break
47 | }
48 | if err != nil {
49 | return errors.Wrap(err, "google list cached files error")
50 | }
51 | cachedFileMap[cachedFile.Name] = cachedFile.URI
52 | }
53 | p.cachedFileMap = cachedFileMap
54 | return nil
55 | }
56 |
57 | func (p *Provider) refreshCachedContentMap(ctx context.Context) error {
58 | iter := p.client.ListCachedContents(ctx)
59 | cachedContentMap := make(map[string]struct{})
60 | for {
61 | cachedContent, err := iter.Next()
62 | if errors.Is(err, iterator.Done) {
63 | break
64 | }
65 | if err != nil {
66 | return errors.Wrap(err, "google list cached content error")
67 | }
68 | cachedContentMap[cachedContent.Name] = struct{}{}
69 | }
70 | p.cachedContentMap = cachedContentMap
71 | return nil
72 | }
73 |
74 | func getHash(value string) string {
75 | 	return strconv.FormatUint(xxhash.Sum64String(value), 16)
76 | }
77 |
78 | func getHashBytes(value []byte) string {
79 | 	return strconv.FormatUint(xxhash.Sum64(value), 16)
80 | }
81 |
82 | func (p *Provider) uploadFile(ctx context.Context, key string, value string) (*genai.File, error) {
83 | // check if the file is already cached
84 | if _, ok := p.cachedFileMap[key]; ok {
85 | // if so, load the cached file and return it
86 | cachedFile, err := p.client.GetFile(ctx, key)
87 | if err != nil {
88 | return nil, errors.Wrap(err, "google get file error")
89 | }
90 | return cachedFile, nil
91 | }
92 |
93 | r := strings.NewReader(value)
94 | file, err := p.client.UploadFile(ctx, key, r, nil)
95 | if err != nil {
96 | return nil, errors.Wrap(err, "google upload file error")
97 | }
98 | p.cachedFileMap[key] = file.URI
99 |
100 | return file, nil
101 | }
102 |
103 | func (p *Provider) createCachedContent(ctx context.Context, value string, modelName string) (*genai.CachedContent, error) {
104 | key := getHash(value)
105 |
106 | // check if the content is already cached
107 | if _, ok := p.cachedContentMap[key]; ok {
108 | // if so, load the cached content and return it
109 | cachedContent, err := p.client.GetCachedContent(ctx, key)
110 | if err != nil {
111 | return nil, errors.Wrap(err, "google get cached content error")
112 | }
113 | return cachedContent, nil
114 | }
115 |
116 | file, err := p.uploadFile(ctx, key, value)
117 | if err != nil {
118 | return nil, errors.Wrap(err, "error uploading file")
119 | }
120 | fd := genai.FileData{URI: file.URI}
121 | cc := &genai.CachedContent{
122 | Name: key,
123 | Model: modelName,
124 | Contents: []*genai.Content{genai.NewUserContent(fd)},
125 | // TODO: make this configurable
126 | // maybe something like an optional field, 'ephemeral' / 'hour'?
127 | // default matches Anthropic's 5 minute TTL
128 | Expiration: genai.ExpireTimeOrTTL{TTL: 5 * time.Minute},
129 | }
130 | content, err := p.client.CreateCachedContent(ctx, cc)
131 | if err != nil {
132 | return nil, errors.Wrap(err, "error creating cached content")
133 | }
134 |
135 | return content, nil
136 | }
137 |
138 | func (p *Provider) getModel(req llm.InferRequest) *genai.GenerativeModel {
139 | model := p.client.GenerativeModel(req.ModelConfig.ModelName)
140 | model.SetTemperature(req.MessageOptions.Temperature)
141 | model.SetMaxOutputTokens(int32(req.MessageOptions.MaxTokens))
142 | 	// lol... disable all safety filters so responses aren't blocked
143 | model.SafetySettings = []*genai.SafetySetting{
144 | {Category: genai.HarmCategoryHarassment, Threshold: genai.HarmBlockNone},
145 | {Category: genai.HarmCategoryHateSpeech, Threshold: genai.HarmBlockNone},
146 | {Category: genai.HarmCategorySexuallyExplicit, Threshold: genai.HarmBlockNone},
147 | {Category: genai.HarmCategoryDangerousContent, Threshold: genai.HarmBlockNone},
148 | }
149 | model.SetCandidateCount(1)
150 | return model
151 | }
152 |
153 | func (p *Provider) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) {
154 | if len(req.Messages) > 1 {
155 | return p.generateResponseChat(ctx, req)
156 | }
157 |
158 | // it is slightly better to build a trie, indexed on hashes of each message
159 | // since we can quickly get based on the prefix (i.e. existing messages)
160 | // but ... your number of messages is probably not THAT high to justify the complexity.
161 | // NB: trie also not useful for images/audio
162 | messagesToCache := make([]llm.InferMessage, 0)
163 | for _, message := range req.Messages {
164 | if message.ShouldCache {
165 | messagesToCache = append(messagesToCache, message)
166 | }
167 | }
168 | model := p.getModel(req)
169 | if len(messagesToCache) > 0 {
170 | // hash the messages and check if the overall object is cached.
171 | // we choose to do this because you may have a later message identical to an earlier message
172 | // if we find exact match for this set of messages, load it.
173 | hashKeys := make([]string, 0, len(messagesToCache))
174 | for _, message := range messagesToCache {
175 | // it is possible to have collision between user + assistant content being identical
176 | // this feels like a rare case especially given that we are ordering sensitive in the hash.
177 | hashKeys = append(hashKeys, getHash(message.Content))
178 | if len(message.Image) > 0 {
179 | hashKeys = append(hashKeys, getHashBytes(message.Image))
180 | }
181 | if len(message.Audio) > 0 {
182 | hashKeys = append(hashKeys, getHashBytes(message.Audio))
183 | }
184 | }
185 | joinedKey := strings.Join(hashKeys, "/")
186 | var cachedContent *genai.CachedContent
187 | // if the cached object exists, load it
188 | if _, ok := p.cachedContentMap[joinedKey]; ok {
189 | cachedContent, _ = p.client.GetCachedContent(ctx, joinedKey)
190 | model = p.client.GenerativeModelFromCachedContent(cachedContent)
191 | } else {
192 | // otherwise, create a new cached object
193 | cc, err := p.createCachedContent(ctx, joinedKey, req.ModelConfig.ModelName)
194 | if err != nil {
195 | 				return "", errors.Wrap(err, "google create cached content error")
196 | }
197 | model = p.client.GenerativeModelFromCachedContent(cc)
198 | }
199 | }
200 |
201 | parts := singleTurnMessageToParts(req.Messages[0])
202 |
203 | resp, err := model.GenerateContent(ctx, parts...)
204 | if err != nil {
205 | return "", errors.Wrap(err, "google generate content error")
206 | }
207 | respStr := flattenResponse(resp)
208 |
209 | return respStr, nil
210 | }
211 |
212 | func singleTurnMessageToParts(message llm.InferMessage) []genai.Part {
213 | parts := []genai.Part{genai.Text(message.Content)}
214 | if len(message.Image) > 0 {
215 | // TODO: set the image type based on the actual image type
216 | parts = append(parts, genai.ImageData("png", message.Image))
217 | }
218 | // if len(message.Audio) > 0 {
219 | // // TODO: set the audio type based on the actual audio type
220 | // parts = append(parts, genai.("wav", message.Audio))
221 | // }
222 |
223 | return parts
224 | }
225 |
226 | func multiTurnMessageToParts(messages []llm.InferMessage) ([]*genai.Content, *genai.Content) {
227 | sysInstructionParts := make([]genai.Part, 0)
228 | hist := make([]*genai.Content, 0, len(messages))
229 | for _, message := range messages {
230 | parts := []genai.Part{genai.Text(message.Content)}
231 | if len(message.Image) > 0 {
232 | parts = append(parts, genai.ImageData("png", message.Image))
233 | }
234 | if message.Role == "system" {
235 | sysInstructionParts = append(sysInstructionParts, parts...)
236 | continue
237 | }
238 | hist = append(hist, &genai.Content{
239 | Parts: parts,
240 | Role: message.Role,
241 | })
242 | }
243 | if len(sysInstructionParts) > 0 {
244 | return hist, &genai.Content{
245 | Parts: sysInstructionParts,
246 | }
247 | }
248 |
249 | return hist, nil
250 | }
251 |
252 | func (p *Provider) generateResponseChat(ctx context.Context, req llm.InferRequest) (string, error) {
253 | model := p.getModel(req)
254 |
255 | // annoyingly, the last message is the one we want to generate a response to, so we need to split it out
256 | msgs, sysInstr := multiTurnMessageToParts(req.Messages[:len(req.Messages)-1])
257 | if sysInstr != nil {
258 | model.SystemInstruction = sysInstr
259 | }
260 |
261 | cs := model.StartChat()
262 | cs.History = msgs
263 | mostRecentMessage := req.Messages[len(req.Messages)-1]
264 |
265 | // NB chua: this might be a bug but Google doesn't seem to accept multiple parts in the same message
266 | // in the chat API. So can't send text + image if it exists.
267 | //mostRecentMessagePart := []genai.Part{genai.Text(mostRecentMessage.Content)}
268 | //if mostRecentMessage.Image != nil && len(*mostRecentMessage.Image) > 0 {
269 | // mostRecentMessagePart = append(mostRecentMessagePart, genai.ImageData("png", *mostRecentMessage.Image))
270 | //}
271 |
272 | resp, err := cs.SendMessage(ctx, genai.Text(mostRecentMessage.Content))
273 | if err != nil {
274 | return "", errors.Wrap(err, "google generate content error")
275 | }
276 | respStr := flattenResponse(resp)
277 |
278 | return respStr, nil
279 | }
280 |
281 | func (p *Provider) GenerateResponseAsync(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) {
282 | if len(req.Messages) > 1 {
283 | return p.generateResponseAsyncChat(ctx, req)
284 | }
285 | return p.generateResponseAsyncSingle(ctx, req)
286 | }
287 |
288 | func (p *Provider) generateResponseAsyncSingle(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) {
289 | outChan := make(chan llm.StreamDelta)
290 |
291 | go func() {
292 | defer close(outChan)
293 |
294 | model := p.getModel(req)
295 |
296 | parts := singleTurnMessageToParts(req.Messages[0])
297 | iter := model.GenerateContentStream(ctx, parts...)
298 |
299 | for {
300 | resp, err := iter.Next()
301 | if errors.Is(err, iterator.Done) {
302 | outChan <- llm.StreamDelta{EOF: true}
303 | break
304 | }
305 | if err != nil {
306 | slog.Error("error from gemini stream", "err", err, "req", req.Messages[0].Content, "model", req.ModelConfig.ModelName)
307 | return
308 | }
309 |
310 | content := flattenResponse(resp)
311 | if content != "" {
312 | select {
313 | case <-ctx.Done():
314 | return
315 | case outChan <- llm.StreamDelta{Text: content}:
316 | }
317 | }
318 | }
319 | }()
320 |
321 | return outChan, nil
322 | }
323 |
324 | func (p *Provider) generateResponseAsyncChat(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) {
325 | outChan := make(chan llm.StreamDelta)
326 |
327 | go func() {
328 | defer close(outChan)
329 |
330 | model := p.getModel(req)
331 | msgs, sysInstr := multiTurnMessageToParts(req.Messages[:len(req.Messages)-1])
332 | if sysInstr != nil {
333 | model.SystemInstruction = sysInstr
334 | }
335 | cs := model.StartChat()
336 | cs.History = msgs
337 |
338 | mostRecentMessage := req.Messages[len(req.Messages)-1]
339 |
340 | iter := cs.SendMessageStream(ctx, genai.Text(mostRecentMessage.Content))
341 |
342 | for {
343 | resp, err := iter.Next()
344 | if errors.Is(err, iterator.Done) {
345 | outChan <- llm.StreamDelta{EOF: true}
346 | break
347 | }
348 | if err != nil {
349 | slog.Error("error from gemini stream", "err", err, "req", mostRecentMessage.Content, "model", req.ModelConfig.ModelName)
350 | return
351 | }
352 |
353 | content := flattenResponse(resp)
354 | if content != "" {
355 | select {
356 | case <-ctx.Done():
357 | return
358 | case outChan <- llm.StreamDelta{Text: content}:
359 | }
360 | }
361 | }
362 | }()
363 |
364 | return outChan, nil
365 | }
366 |
367 | // flattenResponse flattens the response from the Gemini API into a single string.
368 | func flattenResponse(resp *genai.GenerateContentResponse) string {
369 | var rtn strings.Builder
370 | for i, part := range resp.Candidates[0].Content.Parts {
371 | switch part := part.(type) {
372 | case genai.Text:
373 | if i > 0 {
374 | rtn.WriteString(" ")
375 | }
376 | rtn.WriteString(string(part))
377 | }
378 | }
379 | return rtn.String()
380 | }
381 |
382 | // GenerateEmbedding generates embeddings for the given input.
383 | //
384 | // NB chua: This is a confusing method in the docs.
385 | // - There are two separate API methods and it's unclear which you should use. Is batch with 1 the same as single?
386 | // - What's the maximum number of docs to embed at once?
387 | // - TaskType is automatically set..? I don't see how to configure it ...?
388 | // see also https://pkg.go.dev/github.com/google/generative-ai-go/genai#TaskType
389 | func (p *Provider) GenerateEmbedding(ctx context.Context, req llm.EmbedRequest) (*llm.EmbeddingResponse, error) {
390 | em := p.client.EmbeddingModel(req.ModelConfig.ModelName)
391 |
392 | // if there is only one input, use the single API
393 | if len(req.Input) == 1 {
394 | resp, err := em.EmbedContent(ctx, genai.Text(req.Input[0]))
395 | if err != nil {
396 | return nil, errors.Wrap(err, "google embedding error")
397 | }
398 |
399 | return &llm.EmbeddingResponse{Data: []llm.Embedding{{Values: resp.Embedding.Values}}}, nil
400 | }
401 | // otherwise, use the batch API. I'm not sure there's much difference though...
402 | batchReq := em.NewBatch()
403 | for _, input := range req.Input {
404 | batchReq.AddContent(genai.Text(input))
405 | }
406 | resp, err := em.BatchEmbedContents(ctx, batchReq)
407 | if err != nil {
408 | return nil, errors.Wrap(err, "google batch embedding error")
409 | }
410 |
411 | respVectors := make([]llm.Embedding, len(resp.Embeddings))
412 | for i, v := range resp.Embeddings {
413 | respVectors[i] = llm.Embedding{
414 | Values: v.Values,
415 | }
416 | }
417 |
418 | return &llm.EmbeddingResponse{
419 | Data: respVectors,
420 | }, nil
421 | }
422 |
--------------------------------------------------------------------------------