├── bin ├── hermit.hcl ├── go ├── gofmt ├── .go-1.23.4.pkg ├── README.hermit.md ├── activate-hermit └── hermit ├── testdata ├── docstore.json └── simple_store.json ├── renovate.json ├── .idea ├── vcs.xml ├── .gitignore ├── modules.xml └── gollum.iml ├── packages ├── agents │ ├── agent.go │ ├── calcagent_test.go │ └── calcagent.go ├── tools │ ├── tools.go │ ├── calculator_test.go │ ├── readme.md │ └── calculator.go ├── llm │ ├── readme.md │ ├── modelconfigs.go │ ├── cache │ │ └── cache.go │ ├── llm.go │ ├── providers │ │ ├── cached │ │ │ ├── cached_provider.go │ │ │ ├── cached_embedder.go │ │ │ ├── sqlitecache │ │ │ │ └── sqlite.go │ │ │ └── cached_provider_test.go │ │ ├── voyage │ │ │ └── voyage.go │ │ ├── mixedbread │ │ │ └── mixedbread.go │ │ ├── anthropic │ │ │ └── anthropic.go │ │ ├── openai │ │ │ └── openai.go │ │ ├── vertex │ │ │ └── vertex.go │ │ └── google │ │ │ └── google.go │ ├── internal │ │ └── mocks │ │ │ └── llm.go │ └── configs.go ├── hyde │ ├── README.md │ ├── hyde_test.go │ └── hyde.go ├── syncpool │ └── syncpool.go ├── queryplanner │ ├── queryplanner.go │ └── queryplanner_test.go ├── docstore │ ├── docstore_test.go │ └── docstore.go ├── dispatch │ ├── functions.go │ ├── dispatch.go │ ├── dispatch_test.go │ └── functions_test.go ├── jsonparser │ ├── parser.go │ └── parser_test.go └── vectorstore │ ├── vectorstore.go │ ├── vectorstore_memory.go │ ├── vectorstore.md │ ├── vectorstore_compressed.go │ ├── vectorstore_test.go │ └── vectorstore_compressed_test.go ├── gollum.go ├── docs └── best_practices.md ├── .gitignore ├── llm.go ├── LICENSE ├── internal ├── hash │ ├── hash_bench_test.go │ └── readme.md ├── testutil │ └── utils.go └── mocks │ └── llm.go ├── go.mod ├── math_test.go ├── README.md └── bench.md /bin/hermit.hcl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bin/go: -------------------------------------------------------------------------------- 1 | .go-1.23.4.pkg -------------------------------------------------------------------------------- /bin/gofmt: -------------------------------------------------------------------------------- 1 | .go-1.23.4.pkg -------------------------------------------------------------------------------- /bin/.go-1.23.4.pkg: -------------------------------------------------------------------------------- 1 | hermit -------------------------------------------------------------------------------- /testdata/docstore.json: -------------------------------------------------------------------------------- 1 | {"1":{"id":"1","content":"test data"},"2":{"id":"2","content":"test data 2"}} -------------------------------------------------------------------------------- /renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json", 3 | "extends": [ 4 | "config:base" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /packages/agents/agent.go: -------------------------------------------------------------------------------- 1 | package agents 2 | 3 | import "context" 4 | 5 | type Agent interface { 6 | Name() string 7 | Description() string 8 | Run(context.Context, interface{}) 
(interface{}, error) 9 | } 10 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /testdata/simple_store.json: -------------------------------------------------------------------------------- 1 | [{"id":"02a4d737-121b-4845-9649-64781326fdfd","content":"Apple","embedding":[0.1,0.1,0.1]},{"id":"15d85177-1b7e-46fe-9489-77ef27f10ead","content":"Orange","embedding":[0.1,1.1,1.1]},{"id":"f5c42d8f-9c17-424d-934c-88192d4c336f","content":"Basketball","embedding":[0.1,2.1,2.1]}] -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /packages/tools/tools.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import "context" 4 | 5 | // Tool is an incredibly generic interface for a tool that an LLM can use. 6 | type Tool interface { 7 | Name() string 8 | Description() string 9 | Run(ctx context.Context, input interface{}) (interface{}, error) 10 | } 11 | -------------------------------------------------------------------------------- /bin/README.hermit.md: -------------------------------------------------------------------------------- 1 | # Hermit environment 2 | 3 | This is a [Hermit](https://github.com/cashapp/hermit) bin directory. 4 | 5 | The symlinks in this directory are managed by Hermit and will automatically 6 | download and install Hermit itself as well as packages. These packages are 7 | local to this environment. 8 | -------------------------------------------------------------------------------- /.idea/gollum.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /packages/llm/readme.md: -------------------------------------------------------------------------------- 1 | # llm 2 | 3 | Useful abstractions for working with LLMs. 4 | 5 | Features: 6 | 7 | - synchronous and async wrappers (streaming output) 8 | - prompt caching for supported providers 9 | - automatically load supported providers from environment variables 10 | 11 | We support 12 | 13 | - Anthropic 14 | - Google Gemini 15 | - OpenAI 16 | - OpenAI compatible providers (Together, Groq, Hyperbolic, Deepseek, ...) 
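Below is a minimal, hedged usage sketch of the synchronous and streaming paths, using only the `Responder`, `InferRequest`, and `StreamDelta` types defined in `packages/llm/llm.go`; the concrete provider value is assumed to come from one of the provider packages, and the provider/model names shown are placeholders, not defaults.

```go
package example

import (
	"context"
	"fmt"

	"github.com/stillmatic/gollum/packages/llm"
)

// generate exercises both the blocking and the streaming interfaces
// against any llm.Responder implementation.
func generate(ctx context.Context, provider llm.Responder) error {
	req := llm.InferRequest{
		Messages: []llm.InferMessage{
			{Role: "user", Content: "Say hello in one short sentence."},
		},
		ModelConfig: llm.ModelConfig{
			ProviderType: "openai",      // placeholder provider name
			ModelName:    "gpt-4o-mini", // placeholder model name
			ModelType:    llm.ModelTypeLLM,
		},
		MessageOptions: llm.MessageOptions{MaxTokens: 64, Temperature: 0.2},
	}

	// Synchronous call: returns the full response as a string.
	out, err := provider.GenerateResponse(ctx, req)
	if err != nil {
		return err
	}
	fmt.Println(out)

	// Streaming call: read deltas from the channel until EOF.
	stream, err := provider.GenerateResponseAsync(ctx, req)
	if err != nil {
		return err
	}
	for delta := range stream {
		if delta.EOF {
			break
		}
		fmt.Print(delta.Text)
	}
	return nil
}
```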
17 | -------------------------------------------------------------------------------- /gollum.go: -------------------------------------------------------------------------------- 1 | package gollum 2 | 3 | import ( 4 | "github.com/google/uuid" 5 | ) 6 | 7 | type Document struct { 8 | ID string `json:"id"` 9 | Content string `json:"content,omitempty"` 10 | Embedding []float32 `json:"embedding,omitempty"` 11 | Metadata map[string]interface{} `json:"metadata,omitempty"` 12 | } 13 | 14 | func NewDocumentFromString(content string) Document { 15 | return Document{ 16 | ID: uuid.New().String(), 17 | Content: content, 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /docs/best_practices.md: -------------------------------------------------------------------------------- 1 | # Best Practices 2 | 3 | Collection of tips and tricks for working with LLM's in Go. 4 | 5 | ## Retry 6 | 7 | OpenAI has a somewhat flaky rate limit, and it's easy to hit it. You can use a [retryablehttp](https://pkg.go.dev/github.com/hashicorp/go-retryablehttp) client to retry requests automatically: 8 | 9 | ```go 10 | retryableClient := retryablehttp.NewClient() 11 | retryableClient.RetryMax = 5 12 | oaiCfg := openai.DefaultConfig(mustGetEnv("OPENAI_API_KEY")) 13 | oaiCfg.HTTPClient = retryableClient.StandardClient() 14 | oai := openai.NewClientWithConfig(oaiCfg) 15 | ``` 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .hermit/ 2 | enwik8 3 | 4 | # If you prefer the allow list template instead of the deny list, see community template: 5 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 6 | # 7 | # Binaries for programs and plugins 8 | *.exe 9 | *.exe~ 10 | *.dll 11 | *.so 12 | *.dylib 13 | 14 | # Test binary, built with `go test -c` 15 | *.test 16 | 17 | # Output of the go coverage tool, specifically when used with LiteIDE 18 | *.out 19 | 20 | # Dependency directories (remove the comment below to include it) 21 | # vendor/ 22 | 23 | # Go workspace file 24 | go.work 25 | -------------------------------------------------------------------------------- /bin/activate-hermit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This file must be used with "source bin/activate-hermit" from bash or zsh. 3 | # You cannot run it directly 4 | # 5 | # THIS FILE IS GENERATED; DO NOT MODIFY 6 | 7 | if [ "${BASH_SOURCE-}" = "$0" ]; then 8 | echo "You must source this script: \$ source $0" >&2 9 | exit 33 10 | fi 11 | 12 | BIN_DIR="$(dirname "${BASH_SOURCE[0]:-${(%):-%x}}")" 13 | if "${BIN_DIR}/hermit" noop > /dev/null; then 14 | eval "$("${BIN_DIR}/hermit" activate "${BIN_DIR}/..")" 15 | 16 | if [ -n "${BASH-}" ] || [ -n "${ZSH_VERSION-}" ]; then 17 | hash -r 2>/dev/null 18 | fi 19 | 20 | echo "Hermit environment $("${HERMIT_ENV}"/bin/hermit env HERMIT_ENV) activated" 21 | fi 22 | -------------------------------------------------------------------------------- /packages/hyde/README.md: -------------------------------------------------------------------------------- 1 | # HyDE 2 | 3 | This module is an implementation of [HyDE: Precise Zero-Shot Dense Retrieval without Relevance Labels](https://github.com/texttron/hyde). 
4 | 5 | We differ from the reference Python implementation by using an in-memory exact search (instead of FAISS) and OpenAI Ada embeddings instead of Contriever. Realistically you should expect slightly worse performance (exact search is slower than approximate search, calling OpenAI is probably slower than a local model for smaller batch sizes). However, the results should still be valid. 6 | 7 | However, this modules _only_ expects the interfaces, not the actual implementations -- so you could implement a FAISS interface that connects to a docker instance. Or the same for an embedding service, just make a wrapper compatible with the OpenAI interface. -------------------------------------------------------------------------------- /packages/syncpool/syncpool.go: -------------------------------------------------------------------------------- 1 | // Package syncpool provides a generic wrapper around sync.Pool 2 | // Copied from https://github.com/mkmik/syncpool 3 | package syncpool 4 | 5 | import ( 6 | "sync" 7 | ) 8 | 9 | // A Pool is a generic wrapper around a sync.Pool. 10 | type Pool[T any] struct { 11 | pool sync.Pool 12 | } 13 | 14 | // New creates a new Pool with the provided new function. 15 | // 16 | // The equivalent sync.Pool construct is "sync.Pool{New: fn}" 17 | func New[T any](fn func() T) Pool[T] { 18 | return Pool[T]{ 19 | pool: sync.Pool{New: func() interface{} { return fn() }}, 20 | } 21 | } 22 | 23 | // Get is a generic wrapper around sync.Pool's Get method. 24 | func (p *Pool[T]) Get() T { 25 | return p.pool.Get().(T) 26 | } 27 | 28 | // Get is a generic wrapper around sync.Pool's Put method. 29 | func (p *Pool[T]) Put(x T) { 30 | p.pool.Put(x) 31 | } 32 | -------------------------------------------------------------------------------- /llm.go: -------------------------------------------------------------------------------- 1 | //go:generate mockgen -source llm.go -destination internal/mocks/llm.go 2 | 3 | package gollum 4 | 5 | import ( 6 | "context" 7 | 8 | "github.com/sashabaranov/go-openai" 9 | ) 10 | 11 | // Deprecated: who even uses this anymore? 
12 | type Completer interface { 13 | CreateCompletion(context.Context, openai.CompletionRequest) (openai.CompletionResponse, error) 14 | } 15 | 16 | // Deprecated: use packages/llm/llm.go implementation 17 | type ChatCompleter interface { 18 | CreateChatCompletion(context.Context, openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) 19 | } 20 | 21 | type Embedder interface { 22 | CreateEmbeddings(context.Context, openai.EmbeddingRequest) (openai.EmbeddingResponse, error) 23 | } 24 | 25 | type Moderator interface { 26 | Moderations(context.Context, openai.ModerationRequest) (openai.ModerationResponse, error) 27 | } 28 | -------------------------------------------------------------------------------- /packages/llm/modelconfigs.go: -------------------------------------------------------------------------------- 1 | package llm 2 | 3 | type ModelConfigStore struct { 4 | configs map[string]ModelConfig 5 | } 6 | 7 | func NewModelConfigStore() *ModelConfigStore { 8 | return &ModelConfigStore{ 9 | configs: configs, 10 | } 11 | } 12 | 13 | func NewModelConfigStoreWithConfigs(configs map[string]ModelConfig) *ModelConfigStore { 14 | return &ModelConfigStore{ 15 | configs: configs, 16 | } 17 | } 18 | 19 | func (m *ModelConfigStore) GetConfig(configName string) (ModelConfig, bool) { 20 | config, ok := m.configs[configName] 21 | return config, ok 22 | } 23 | 24 | func (m *ModelConfigStore) GetConfigNames() []string { 25 | var configNames []string 26 | for k := range m.configs { 27 | configNames = append(configNames, k) 28 | } 29 | return configNames 30 | } 31 | 32 | func (m *ModelConfigStore) AddConfig(configName string, config ModelConfig) { 33 | m.configs[configName] = config 34 | } 35 | -------------------------------------------------------------------------------- /packages/tools/calculator_test.go: -------------------------------------------------------------------------------- 1 | package tools_test 2 | 3 | import ( 4 | "context" 5 | tools2 "github.com/stillmatic/gollum/packages/tools" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestCalculator(t *testing.T) { 12 | calc := tools2.CalculatorTool{} 13 | var _ tools2.Tool = &calc 14 | ctx := context.Background() 15 | t.Run("simple", func(t *testing.T) { 16 | calcInput := tools2.CalculatorInput{ 17 | Expression: "1 + 1", 18 | } 19 | output, err := calc.Run(ctx, &calcInput) 20 | assert.NoError(t, err) 21 | assert.Equal(t, "2", output) 22 | }) 23 | 24 | t.Run("simple with env", func(t *testing.T) { 25 | env := map[string]interface{}{ 26 | "foo": 1, 27 | "bar": "baz", 28 | } 29 | 30 | calcInput := tools2.CalculatorInput{ 31 | Expression: "foo + foo", 32 | Environment: env, 33 | } 34 | output, err := calc.Run(ctx, calcInput) 35 | assert.NoError(t, err) 36 | assert.Equal(t, "2", output) 37 | }) 38 | } 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Christopher Hua 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice 
shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | 9 | -------------------------------------------------------------------------------- /packages/queryplanner/queryplanner.go: -------------------------------------------------------------------------------- 1 | package queryplanner 2 | 3 | type QueryNode struct { 4 | ID int `json:"id" jsonschema_description:"unique query id - incrementing integer"` 5 | Question string `json:"question" jsonschema:"required" jsonschema_description:"Question we are asking using a question answer system, if we are asking multiple questions, this question is asked by also providing the answers to the sub questions"` 6 | // NodeType string `json:"node_type" jsonschema:"required,enum=single_question,enum=merge_responses" jsonschema_description:"type of question. Either a single question or a multi question merge when there are multiple questions."` 7 | DependencyIDs []int `json:"dependency_ids,omitempty" jsonschema_description:"list of sub-question ID's that need to be answered before this question can be answered. Use a subquery when anything may be unknown, and we need to ask multiple questions to get the answer. Dependencies must only be other query IDs."` 8 | } 9 | 10 | type QueryPlan struct { 11 | Nodes []QueryNode `json:"nodes" jsonschema_description:"list of questions to ask"` 12 | } 13 | -------------------------------------------------------------------------------- /packages/tools/readme.md: -------------------------------------------------------------------------------- 1 | # tools 2 | 3 | This is an implementation of a Tools interface for LLM's to interact with. Tools are entirely arbitrary and can do whatever you want, and just have a very simple API: 4 | 5 | ```go 6 | type Tool interface { 7 | Name() string 8 | Description() string 9 | Run(ctx context.Context, input interface{}) (interface{}, error) 10 | } 11 | ``` 12 | 13 | A simple implementation would then be something like 14 | 15 | ```go 16 | type CalculatorInput struct { 17 | Expression string `json:"expression" jsonschema:"required" jsonschema_description:"mathematical expression to evaluate"` 18 | Environment map[string]interface{} `json:"environment,omitempty" jsonschema_description:"optional environment variables to use when evaluating the expression"` 19 | } 20 | 21 | type CalculatorTool struct{} 22 | 23 | // Run evaluates a mathematical expression and returns it as a string. 24 | func (c *CalculatorTool) Run(ctx context.Context, input interface{}) (interface{}, error) { 25 | // do stuff 26 | } 27 | ``` 28 | 29 | Note that the input struct has a lot of JSON Schema mappings -- this is so that we can feed the description directly to OpenAI! 
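To make that concrete, here is a short sketch of feeding the reflected schema to OpenAI using the `dispatch` helpers that live elsewhere in this repo (`StructToJsonSchemaGeneric` and `FunctionInputToTool`); the model constant is illustrative.

```go
package main

import (
	"fmt"

	"github.com/sashabaranov/go-openai"
	"github.com/stillmatic/gollum/packages/dispatch"
	"github.com/stillmatic/gollum/packages/tools"
)

func main() {
	// Reflect CalculatorInput's jsonschema tags into a function definition.
	fi := dispatch.StructToJsonSchemaGeneric[tools.CalculatorInput](
		"calculator", "evaluate mathematical expressions",
	)

	// Wrap it as an openai.Tool and attach it to a chat request.
	req := openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo0613,
		Tools: []openai.Tool{dispatch.FunctionInputToTool(fi)},
	}
	fmt.Println(fi.Name, len(req.Tools))
}
```

This mirrors the pattern `packages/agents/calcagent.go` uses to hand the calculator tool to the model.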
-------------------------------------------------------------------------------- /packages/llm/cache/cache.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/stillmatic/gollum/packages/llm" 7 | ) 8 | 9 | // Cache defines the interface for caching LLM responses and embeddings 10 | type Cache interface { 11 | // GetResponse retrieves a cached response for a given request 12 | GetResponse(ctx context.Context, req llm.InferRequest) (string, error) 13 | 14 | // SetResponse stores a response for a given request 15 | SetResponse(ctx context.Context, req llm.InferRequest, response string) error 16 | 17 | // GetEmbedding retrieves a cached embedding for a given input and model config 18 | GetEmbedding(ctx context.Context, modelConfig string, input string) ([]float32, error) 19 | 20 | // SetEmbedding stores an embedding for a given input and model config 21 | SetEmbedding(ctx context.Context, modelConfig string, input string, embedding []float32) error 22 | 23 | // Close closes the cache, releasing any resources 24 | Close() error 25 | 26 | // GetStats returns cache statistics (e.g., number of requests, cache hits) 27 | GetStats() CacheStats 28 | } 29 | 30 | // CacheStats represents cache usage statistics 31 | type CacheStats struct { 32 | NumRequests int 33 | NumCacheHits int 34 | } 35 | -------------------------------------------------------------------------------- /packages/tools/calculator.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "context" 5 | "strconv" 6 | 7 | "github.com/antonmedv/expr" 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | type CalculatorInput struct { 12 | Expression string `json:"expression" jsonschema:"required" jsonschema_description:"mathematical expression to evaluate"` 13 | Environment map[string]interface{} `json:"environment,omitempty" jsonschema_description:"optional environment variables to use when evaluating the expression"` 14 | } 15 | 16 | type CalculatorTool struct{} 17 | 18 | func (c *CalculatorTool) Name() string { 19 | return "calculator" 20 | } 21 | 22 | func (c *CalculatorTool) Description() string { 23 | return "evaluate mathematical expressions" 24 | } 25 | 26 | // Run evaluates a mathematical expression and returns it as a string. 
27 | func (c *CalculatorTool) Run(ctx context.Context, input interface{}) (interface{}, error) { 28 | cinput, ok := input.(CalculatorInput) 29 | if !ok { 30 | return "", errors.New("invalid input") 31 | } 32 | 33 | output, err := expr.Eval(cinput.Expression, cinput.Environment) 34 | if err != nil { 35 | return "", errors.Wrap(err, "couldn't run expression") 36 | } 37 | switch t := output.(type) { 38 | case string: 39 | return t, nil 40 | case int: 41 | return strconv.Itoa(t), nil 42 | case float64: 43 | return strconv.FormatFloat(t, 'f', -1, 64), nil 44 | default: 45 | return "", errors.New("invalid output") 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /bin/hermit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # THIS FILE IS GENERATED; DO NOT MODIFY 4 | 5 | set -eo pipefail 6 | 7 | export HERMIT_USER_HOME=~ 8 | 9 | if [ -z "${HERMIT_STATE_DIR}" ]; then 10 | case "$(uname -s)" in 11 | Darwin) 12 | export HERMIT_STATE_DIR="${HERMIT_USER_HOME}/Library/Caches/hermit" 13 | ;; 14 | Linux) 15 | export HERMIT_STATE_DIR="${XDG_CACHE_HOME:-${HERMIT_USER_HOME}/.cache}/hermit" 16 | ;; 17 | esac 18 | fi 19 | 20 | export HERMIT_DIST_URL="${HERMIT_DIST_URL:-https://github.com/cashapp/hermit/releases/download/stable}" 21 | HERMIT_CHANNEL="$(basename "${HERMIT_DIST_URL}")" 22 | export HERMIT_CHANNEL 23 | export HERMIT_EXE=${HERMIT_EXE:-${HERMIT_STATE_DIR}/pkg/hermit@${HERMIT_CHANNEL}/hermit} 24 | 25 | if [ ! -x "${HERMIT_EXE}" ]; then 26 | echo "Bootstrapping ${HERMIT_EXE} from ${HERMIT_DIST_URL}" 1>&2 27 | INSTALL_SCRIPT="$(mktemp)" 28 | # This value must match that of the install script 29 | INSTALL_SCRIPT_SHA256="180e997dd837f839a3072a5e2f558619b6d12555cd5452d3ab19d87720704e38" 30 | if [ "${INSTALL_SCRIPT_SHA256}" = "BYPASS" ]; then 31 | curl -fsSL "${HERMIT_DIST_URL}/install.sh" -o "${INSTALL_SCRIPT}" 32 | else 33 | # Install script is versioned by its sha256sum value 34 | curl -fsSL "${HERMIT_DIST_URL}/install-${INSTALL_SCRIPT_SHA256}.sh" -o "${INSTALL_SCRIPT}" 35 | # Verify install script's sha256sum 36 | openssl dgst -sha256 "${INSTALL_SCRIPT}" | \ 37 | awk -v EXPECTED="$INSTALL_SCRIPT_SHA256" \ 38 | '$2!=EXPECTED {print "Install script sha256 " $2 " does not match " EXPECTED; exit 1}' 39 | fi 40 | /bin/bash "${INSTALL_SCRIPT}" 1>&2 41 | fi 42 | 43 | exec "${HERMIT_EXE}" --level=fatal exec "$0" -- "$@" 44 | -------------------------------------------------------------------------------- /packages/docstore/docstore_test.go: -------------------------------------------------------------------------------- 1 | package docstore_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | . 
"github.com/stillmatic/gollum" 8 | "github.com/stillmatic/gollum/packages/docstore" 9 | "github.com/stretchr/testify/assert" 10 | "gocloud.dev/blob/fileblob" 11 | ) 12 | 13 | func TestMemoryDocStore(t *testing.T) { 14 | ctx := context.Background() 15 | store := docstore.NewMemoryDocStore() 16 | doc := Document{ID: "1", Content: "test data"} 17 | doc2 := Document{ID: "2", Content: "test data 2"} 18 | 19 | // ensure store implements the DocStore interface 20 | var _ docstore.DocStore = store 21 | 22 | t.Run("Insert document", func(t *testing.T) { 23 | err := store.Insert(ctx, doc) 24 | assert.NoError(t, err) 25 | err = store.Insert(ctx, doc2) 26 | assert.NoError(t, err) 27 | }) 28 | 29 | t.Run("Retrieve document", func(t *testing.T) { 30 | retrievedDoc, err := store.Retrieve(ctx, "1") 31 | assert.NoError(t, err) 32 | assert.Equal(t, doc, retrievedDoc) 33 | }) 34 | 35 | t.Run("Retrieve non-existing document", func(t *testing.T) { 36 | _, err := store.Retrieve(ctx, "non-existing-id") 37 | assert.Error(t, err) 38 | }) 39 | 40 | t.Run("Persist document store", func(t *testing.T) { 41 | // persist to testdata/docstore.json 42 | bucket, err := fileblob.OpenBucket("testdata", nil) 43 | assert.NoError(t, err) 44 | err = store.Persist(ctx, bucket, "docstore.json") 45 | assert.NoError(t, err) 46 | }) 47 | 48 | t.Run("Load document store from disk", func(t *testing.T) { 49 | // load from testdata/docstore.json 50 | bucket, err := fileblob.OpenBucket("testdata", nil) 51 | assert.NoError(t, err) 52 | loadedStore, err := docstore.NewMemoryDocStoreFromDisk(ctx, bucket, "docstore.json") 53 | assert.NoError(t, err) 54 | assert.Equal(t, store, loadedStore) 55 | }) 56 | 57 | } 58 | -------------------------------------------------------------------------------- /internal/hash/hash_bench_test.go: -------------------------------------------------------------------------------- 1 | package hash 2 | 3 | // copied from https://gist.github.com/wizjin/e103e1040db0c4c427db4104cce67566 4 | 5 | import ( 6 | "crypto/md5" 7 | "crypto/rand" 8 | "crypto/sha1" 9 | "crypto/sha256" 10 | "crypto/sha512" 11 | "github.com/cespare/xxhash/v2" 12 | "hash" 13 | "hash/fnv" 14 | "testing" 15 | ) 16 | 17 | // NB chua: we care about caching / hashing lots of tokens. 18 | // That probably starts mattering at O(1k), O(10k), O(100k) tokens. 19 | // note that there are ~4 characters per token, and 1-4 bytes per character in UTF-8 20 | // so 1k tokens is 4k-16k bytes, 10k tokens is 40k-160k bytes, 100k tokens is 400k-1.6M bytes. 
21 | const ( 22 | K = 1024 23 | DATALEN = 512 * K 24 | ) 25 | 26 | func runHash(b *testing.B, h hash.Hash, n int) { 27 | var data = make([]byte, n) 28 | rand.Read(data) 29 | b.ResetTimer() 30 | 31 | for i := 0; i < b.N; i++ { 32 | h.Write(data) 33 | h.Sum(nil) 34 | } 35 | } 36 | 37 | func BenchmarkFNV32(b *testing.B) { 38 | runHash(b, fnv.New32(), DATALEN) 39 | } 40 | 41 | func BenchmarkFNV64(b *testing.B) { 42 | runHash(b, fnv.New64(), DATALEN) 43 | } 44 | 45 | func BenchmarkFNV128(b *testing.B) { 46 | runHash(b, fnv.New128(), DATALEN) 47 | } 48 | 49 | func BenchmarkMD5(b *testing.B) { 50 | runHash(b, md5.New(), DATALEN) 51 | } 52 | 53 | func BenchmarkSHA1(b *testing.B) { 54 | runHash(b, sha1.New(), DATALEN) 55 | } 56 | 57 | func BenchmarkSHA224(b *testing.B) { 58 | runHash(b, sha256.New224(), DATALEN) 59 | } 60 | 61 | func BenchmarkSHA256(b *testing.B) { 62 | runHash(b, sha256.New(), DATALEN) 63 | } 64 | 65 | func BenchmarkSHA512(b *testing.B) { 66 | runHash(b, sha512.New(), DATALEN) 67 | } 68 | 69 | // func BenchmarkMurmur3(b *testing.B) { 70 | // runHash(b, murmur3.New32(), DATALEN) 71 | // } 72 | func BenchmarkXxhash(b *testing.B) { 73 | runHash(b, xxhash.New(), DATALEN) 74 | } 75 | -------------------------------------------------------------------------------- /packages/agents/calcagent_test.go: -------------------------------------------------------------------------------- 1 | package agents_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "github.com/stillmatic/gollum/packages/agents" 7 | "github.com/stillmatic/gollum/packages/tools" 8 | "os" 9 | "testing" 10 | 11 | "github.com/sashabaranov/go-openai" 12 | mock_gollum "github.com/stillmatic/gollum/internal/mocks" 13 | "github.com/stretchr/testify/assert" 14 | "go.uber.org/mock/gomock" 15 | ) 16 | 17 | func TestCalcAgentMocked(t *testing.T) { 18 | ctrl := gomock.NewController(t) 19 | llm := mock_gollum.NewMockChatCompleter(ctrl) 20 | agent := agents.NewCalcAgent(llm) 21 | assert.NotNil(t, agent) 22 | ctx := context.Background() 23 | inp := agents.CalcAgentInput{ 24 | Content: "What's 2 + 2?", 25 | } 26 | expectedResp := tools.CalculatorInput{ 27 | Expression: "2 + 2", 28 | } 29 | expectedBytes, err := json.Marshal(expectedResp) 30 | assert.NoError(t, err) 31 | expectedChatCompletionResp := openai.ChatCompletionResponse{ 32 | Choices: []openai.ChatCompletionChoice{ 33 | { 34 | Message: openai.ChatCompletionMessage{ 35 | Role: openai.ChatMessageRoleAssistant, 36 | Content: "", 37 | ToolCalls: []openai.ToolCall{ 38 | {Function: openai.FunctionCall{Name: "calc", Arguments: string(expectedBytes)}}, 39 | }, 40 | }, 41 | }, 42 | }, 43 | } 44 | llm.EXPECT().CreateChatCompletion(ctx, gomock.Any()).Return(expectedChatCompletionResp, nil) 45 | 46 | resp, err := agent.Run(ctx, inp) 47 | assert.NoError(t, err) 48 | assert.Equal(t, "4", resp.(string)) 49 | } 50 | 51 | func TestCalcAgentReal(t *testing.T) { 52 | openai_key := os.Getenv("OPENAI_API_KEY") 53 | assert.NotEmpty(t, openai_key) 54 | llm := openai.NewClient(openai_key) 55 | 56 | agent := agents.NewCalcAgent(llm) 57 | assert.NotNil(t, agent) 58 | ctx := context.Background() 59 | inp := agents.CalcAgentInput{ 60 | Content: "What's 2 + 2?", 61 | } 62 | 63 | resp, err := agent.Run(ctx, inp) 64 | assert.NoError(t, err) 65 | assert.Equal(t, "4", resp.(string)) 66 | } 67 | -------------------------------------------------------------------------------- /packages/llm/llm.go: -------------------------------------------------------------------------------- 1 | //go:generate mockgen 
-source llm.go -destination internal/mocks/llm.go 2 | package llm 3 | 4 | import ( 5 | "context" 6 | ) 7 | 8 | type ProviderType string 9 | type ModelType string 10 | 11 | // yuck sorry 12 | const ( 13 | ModelTypeLLM ModelType = "llm" 14 | ModelTypeEmbedding ModelType = "embedding" 15 | ) 16 | 17 | type ModelConfig struct { 18 | ProviderType ProviderType 19 | ModelName string 20 | BaseURL string 21 | 22 | ModelType ModelType 23 | CentiCentsPerMillionInputTokens int 24 | CentiCentsPerMillionOutputTokens int 25 | } 26 | 27 | // MessageOptions are options that can be passed to the model for generating a response. 28 | // NB chua: these are the only ones I use, I assume others are useful too... 29 | type MessageOptions struct { 30 | MaxTokens int 31 | Temperature float32 32 | } 33 | 34 | type InferMessage struct { 35 | Content string 36 | Role string 37 | Image []byte 38 | Audio []byte 39 | 40 | ShouldCache bool 41 | } 42 | 43 | type InferRequest struct { 44 | Messages []InferMessage 45 | 46 | // ModelConfig describes the model to use for generating a response. 47 | ModelConfig ModelConfig 48 | // MessageOptions are options that can be passed to the model for generating a response. 49 | MessageOptions MessageOptions 50 | } 51 | 52 | type StreamDelta struct { 53 | Text string 54 | EOF bool 55 | } 56 | 57 | type Responder interface { 58 | GenerateResponse(ctx context.Context, req InferRequest) (string, error) 59 | GenerateResponseAsync(ctx context.Context, req InferRequest) (<-chan StreamDelta, error) 60 | } 61 | 62 | type EmbedRequest struct { 63 | Input []string 64 | Image []byte 65 | 66 | // Prompt is an instruction applied to all the input strings in this request. 67 | // Ignored unless the model specifically supports it 68 | Prompt string 69 | 70 | ModelConfig ModelConfig 71 | // only supported for openai (matryoshka) models 72 | Dimensions int 73 | } 74 | 75 | type Embedding struct { 76 | Values []float32 77 | } 78 | 79 | type EmbeddingResponse struct { 80 | Data []Embedding 81 | } 82 | 83 | type Embedder interface { 84 | GenerateEmbedding(ctx context.Context, req EmbedRequest) (*EmbeddingResponse, error) 85 | } 86 | -------------------------------------------------------------------------------- /packages/dispatch/functions.go: -------------------------------------------------------------------------------- 1 | package dispatch 2 | 3 | import ( 4 | "reflect" 5 | 6 | "github.com/invopop/jsonschema" 7 | "github.com/sashabaranov/go-openai" 8 | ) 9 | 10 | type FunctionInput struct { 11 | Name string `json:"name"` 12 | Description string `json:"description,omitempty"` 13 | Parameters any `json:"parameters"` 14 | } 15 | 16 | type OAITool struct { 17 | // Type is always "function" for now. 
18 | Type string `json:"type"` 19 | Function FunctionInput `json:"function"` 20 | } 21 | 22 | func FunctionInputToTool(fi FunctionInput) openai.Tool { 23 | f_ := openai.FunctionDefinition(fi) 24 | return openai.Tool{ 25 | Type: "function", 26 | Function: &f_, 27 | } 28 | } 29 | 30 | func StructToJsonSchema(functionName string, functionDescription string, inputStruct interface{}) FunctionInput { 31 | t := reflect.TypeOf(inputStruct) 32 | schema := jsonschema.ReflectFromType(reflect.Type(t)) 33 | inputStructName := t.Name() 34 | // only get the single struct we care about 35 | inputProperties, ok := schema.Definitions[inputStructName] 36 | if !ok { 37 | // this should not happen 38 | panic("could not find input struct in schema") 39 | } 40 | parameters := jsonschema.Schema{ 41 | Type: "object", 42 | Properties: inputProperties.Properties, 43 | Required: inputProperties.Required, 44 | } 45 | return FunctionInput{ 46 | Name: functionName, 47 | Description: functionDescription, 48 | Parameters: parameters, 49 | } 50 | } 51 | 52 | func StructToJsonSchemaGeneric[T any](functionName string, functionDescription string) FunctionInput { 53 | var tArr [0]T 54 | tt := reflect.TypeOf(tArr).Elem() 55 | schema := jsonschema.ReflectFromType(reflect.Type(tt)) 56 | inputStructName := tt.Name() 57 | // only get the single struct we care about 58 | inputProperties, ok := schema.Definitions[inputStructName] 59 | if !ok { 60 | // this should not happen 61 | panic("could not find input struct in schema") 62 | } 63 | parameters := jsonschema.Schema{ 64 | Type: "object", 65 | Properties: inputProperties.Properties, 66 | Required: inputProperties.Required, 67 | } 68 | return FunctionInput{ 69 | Name: functionName, 70 | Description: functionDescription, 71 | Parameters: parameters, 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /packages/queryplanner/queryplanner_test.go: -------------------------------------------------------------------------------- 1 | package queryplanner_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/stillmatic/gollum/packages/dispatch" 7 | "github.com/stillmatic/gollum/packages/jsonparser" 8 | . "github.com/stillmatic/gollum/packages/queryplanner" 9 | "os" 10 | "testing" 11 | 12 | "github.com/joho/godotenv" 13 | "github.com/sashabaranov/go-openai" 14 | "github.com/stillmatic/gollum/internal/testutil" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | func TestQueryPlanner(t *testing.T) { 19 | godotenv.Load() 20 | baseAPIURL := "https://api.openai.com/v1/chat/completions" 21 | openAIKey := os.Getenv("OPENAI_API_KEY") 22 | assert.NotEmpty(t, openAIKey) 23 | 24 | api := testutil.NewTestAPI(baseAPIURL, openAIKey) 25 | fi := dispatch.StructToJsonSchemaGeneric[QueryPlan]("QueryPlan", "Use this to plan a query.") 26 | question := "What is the difference between populations of Canada and Jason's home country?" 27 | 28 | messages := []openai.ChatCompletionMessage{ 29 | { 30 | Role: openai.ChatMessageRoleSystem, 31 | Content: "You are a world class query planning algorithm capable of breaking apart questions into its depenencies queries such that the answers can be used to inform the parent question. Do not answer the questions, simply provide correct compute graph with good specific questions to ask and relevant dependencies. 
Before you call the function, think step by step to get a better understanding the problem.", 32 | }, 33 | { 34 | Role: openai.ChatMessageRoleUser, 35 | Content: fmt.Sprintf("Consider: %s\nGenerate the correct query plan.", question), 36 | }, 37 | } 38 | f_ := openai.FunctionDefinition(fi) 39 | chatRequest := openai.ChatCompletionRequest{ 40 | Messages: messages, 41 | Model: "gpt-3.5-turbo-0613", 42 | Temperature: 0.0, 43 | Tools: []openai.Tool{ 44 | { 45 | Type: "function", 46 | Function: &f_, 47 | }, 48 | }, 49 | } 50 | ctx := context.Background() 51 | resp, err := api.SendRequest(ctx, chatRequest) 52 | 53 | assert.NoError(t, err) 54 | assert.NotNil(t, resp) 55 | parser := jsonparser.NewJSONParserGeneric[QueryPlan](false) 56 | queryPlan, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments)) 57 | assert.NoError(t, err) 58 | assert.NotNil(t, queryPlan) 59 | t.Log(queryPlan) 60 | assert.Equal(t, 0, 1) 61 | } 62 | -------------------------------------------------------------------------------- /internal/testutil/utils.go: -------------------------------------------------------------------------------- 1 | package testutil 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "math/rand" 9 | "net/http" 10 | 11 | "github.com/sashabaranov/go-openai" 12 | ) 13 | 14 | type TestAPI struct { 15 | baseAPIURL string 16 | apiKey string 17 | client *http.Client 18 | } 19 | 20 | func NewTestAPI(baseAPIURL, apiKey string) *TestAPI { 21 | return &TestAPI{ 22 | baseAPIURL: baseAPIURL, 23 | apiKey: apiKey, 24 | client: &http.Client{}, 25 | } 26 | } 27 | 28 | func (api *TestAPI) SendRequest(ctx context.Context, chatRequest openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error) { 29 | b, err := json.Marshal(chatRequest) 30 | if err != nil { 31 | return nil, err 32 | } 33 | 34 | req, err := http.NewRequestWithContext(ctx, "POST", api.baseAPIURL, bytes.NewReader(b)) 35 | if err != nil { 36 | return nil, err 37 | } 38 | req.Header.Set("Authorization", "Bearer "+api.apiKey) 39 | req.Header.Set("Content-Type", "application/json") 40 | 41 | resp, err := api.client.Do(req) 42 | if err != nil { 43 | return nil, err 44 | } 45 | defer resp.Body.Close() 46 | 47 | if resp.StatusCode != 200 { 48 | return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) 49 | } 50 | 51 | var chatResponse openai.ChatCompletionResponse 52 | err = json.NewDecoder(resp.Body).Decode(&chatResponse) 53 | if err != nil { 54 | return nil, err 55 | } 56 | 57 | return &chatResponse, nil 58 | } 59 | 60 | func GetRandomEmbedding(n int) []float32 { 61 | vec := make([]float32, n) 62 | for i := range vec { 63 | vec[i] = rand.Float32() 64 | } 65 | return vec 66 | } 67 | 68 | func GetRandomEmbeddingResponse(n int, dim int) openai.EmbeddingResponse { 69 | data := make([]openai.Embedding, n) 70 | for i := range data { 71 | data[i] = openai.Embedding{ 72 | Embedding: GetRandomEmbedding(dim), 73 | } 74 | } 75 | resp := openai.EmbeddingResponse{ 76 | Data: data, 77 | } 78 | return resp 79 | } 80 | 81 | func GetRandomChatCompletionResponse(n int) openai.ChatCompletionResponse { 82 | choices := make([]openai.ChatCompletionChoice, n) 83 | for i := range choices { 84 | choices[i] = openai.ChatCompletionChoice{ 85 | Message: openai.ChatCompletionMessage{ 86 | Role: openai.ChatMessageRoleSystem, 87 | Content: fmt.Sprintf("test? 
%d", i), 88 | }, 89 | } 90 | } 91 | return openai.ChatCompletionResponse{ 92 | Choices: choices, 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /packages/jsonparser/parser.go: -------------------------------------------------------------------------------- 1 | package jsonparser 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "reflect" 7 | 8 | "github.com/invopop/jsonschema" 9 | "github.com/pkg/errors" 10 | val "github.com/santhosh-tekuri/jsonschema/v5" 11 | ) 12 | 13 | // Parser is an interface for parsing strings into structs 14 | // It is threadsafe and can be used concurrently. The underlying validator is threadsafe as well. 15 | type Parser[T any] interface { 16 | Parse(ctx context.Context, input []byte) (T, error) 17 | } 18 | 19 | // JSONParser is a parser that parses arbitrary JSON structs 20 | // It is threadsafe and can be used concurrently. The underlying validator is threadsafe as well. 21 | type JSONParserGeneric[T any] struct { 22 | validate bool 23 | schema *val.Schema 24 | } 25 | 26 | // NewJSONParser returns a new JSONParser 27 | // validation is done via jsonschema 28 | func NewJSONParserGeneric[T any](validate bool) *JSONParserGeneric[T] { 29 | var sch *val.Schema 30 | var t_ T 31 | if validate { 32 | // reflect T into schema 33 | t := reflect.TypeOf(t_) 34 | schema := jsonschema.ReflectFromType(t) 35 | // compile schema 36 | b, err := schema.MarshalJSON() 37 | if err != nil { 38 | panic(errors.Wrap(err, "could not marshal schema")) 39 | } 40 | schemaStr := string(b) 41 | sch, err = val.CompileString("schema.json", schemaStr) 42 | if err != nil { 43 | panic(errors.Wrap(err, "could not compile schema")) 44 | } 45 | } 46 | 47 | return &JSONParserGeneric[T]{ 48 | validate: validate, 49 | schema: sch, 50 | } 51 | } 52 | 53 | func (p *JSONParserGeneric[T]) Parse(ctx context.Context, input []byte) (T, error) { 54 | var t T 55 | // but also unmarshal to the struct because it's easy to get a type conversion error 56 | // e.g. 
if go struct name doesn't match the json name 57 | err := json.Unmarshal(input, &t) 58 | if err != nil { 59 | return t, errors.Wrap(err, "could not unmarshal input json to struct") 60 | } 61 | 62 | // annoying, must pass an interface to the validate function 63 | // so we have to unmarshal to interface{} twice 64 | if p.validate { 65 | var v interface{} 66 | err := json.Unmarshal(input, &v) 67 | if err != nil { 68 | return t, errors.Wrap(err, "could not unmarshal input json to interface") 69 | } 70 | err = p.schema.Validate(v) 71 | if err != nil { 72 | return t, errors.Wrap(err, "error validating input json") 73 | } 74 | } 75 | 76 | return t, nil 77 | } 78 | -------------------------------------------------------------------------------- /packages/vectorstore/vectorstore.go: -------------------------------------------------------------------------------- 1 | package vectorstore 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/stillmatic/gollum" 7 | ) 8 | 9 | // QueryRequest is a struct that contains the query and optional query strings or embeddings 10 | type QueryRequest struct { 11 | // Query is the text to query 12 | Query string 13 | // EmbeddingStrings is a list of strings to concatenate and embed instead of Query 14 | EmbeddingStrings []string 15 | // EmbeddingFloats is a query vector to use instead of Query 16 | EmbeddingFloats []float32 17 | // K is the number of results to return 18 | K int 19 | } 20 | 21 | type VectorStore interface { 22 | Insert(context.Context, gollum.Document) error 23 | Query(ctx context.Context, qb QueryRequest) ([]*gollum.Document, error) 24 | RetrieveAll(ctx context.Context) ([]gollum.Document, error) 25 | } 26 | 27 | type NodeSimilarity struct { 28 | Document *gollum.Document 29 | Similarity float32 30 | } 31 | 32 | // Heap is a custom heap implementation, to avoid interface{} conversion. 33 | // I _think_ theoretically that a memory arena would be useful here, but that feels a bit beyond the pale, even for me. 34 | // In benchmarking, we see that allocations are limited by scale according to k -- 35 | // since K is known, we should be able to allocate a fixed-size arena and use that. 36 | // That being said... 
let's revisit in the future :) 37 | type Heap []NodeSimilarity 38 | 39 | func (h *Heap) Init(k int) { 40 | *h = make(Heap, 0, k) 41 | } 42 | 43 | func (h Heap) down(u int) { 44 | v := u 45 | if 2*u+1 < len(h) && h[2*u+1].Similarity < h[v].Similarity { 46 | v = 2*u + 1 47 | } 48 | if 2*u+2 < len(h) && h[2*u+2].Similarity < h[v].Similarity { 49 | v = 2*u + 2 50 | } 51 | if v != u { 52 | h[v], h[u] = h[u], h[v] 53 | h.down(v) 54 | } 55 | } 56 | 57 | func (h Heap) up(u int) { 58 | for u != 0 && h[(u-1)/2].Similarity > h[u].Similarity { 59 | h[(u-1)/2], h[u] = h[u], h[(u-1)/2] 60 | u = (u - 1) / 2 61 | } 62 | } 63 | 64 | func (h *Heap) Push(e NodeSimilarity) { 65 | *h = append(*h, e) 66 | h.up(len(*h) - 1) 67 | } 68 | 69 | func (h *Heap) Pop() NodeSimilarity { 70 | x := (*h)[0] 71 | n := len(*h) 72 | (*h)[0], (*h)[n-1] = (*h)[n-1], (*h)[0] 73 | *h = (*h)[:n-1] 74 | h.down(0) 75 | return x 76 | } 77 | 78 | func (h Heap) Less(i, j int) bool { 79 | return h[i].Similarity < h[j].Similarity 80 | } 81 | 82 | func (h Heap) Swap(i, j int) { 83 | h[i], h[j] = h[j], h[i] 84 | } 85 | 86 | func (h *Heap) Len() int { 87 | return len(*h) 88 | } 89 | -------------------------------------------------------------------------------- /packages/llm/providers/cached/cached_provider.go: -------------------------------------------------------------------------------- 1 | package cached 2 | 3 | import ( 4 | "context" 5 | "crypto/sha256" 6 | "fmt" 7 | "log" 8 | 9 | "github.com/stillmatic/gollum/packages/llm" 10 | "github.com/stillmatic/gollum/packages/llm/cache" 11 | "github.com/stillmatic/gollum/packages/llm/providers/cached/sqlitecache" 12 | 13 | "hash" 14 | 15 | _ "modernc.org/sqlite" 16 | ) 17 | 18 | // CachedResponder implements the Responder interface with caching 19 | type CachedResponder struct { 20 | underlying llm.Responder 21 | cache cache.Cache 22 | hasher hash.Hash 23 | } 24 | 25 | // NewLocalCachedResponder creates a new CachedResponder with a local SQLite cache 26 | // For example, initialize an OpenAI provider and then wrap it with this cache. 
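// A hedged sketch of that wiring (the variable names and db path are illustrative):
//
//	var base llm.Responder // e.g. an OpenAI-backed provider
//	cr, err := NewLocalCachedResponder(base, "llm-cache.db")
//	resp, err := cr.GenerateResponse(ctx, req)
//
// Repeated identical requests are then served from the local SQLite cache
// instead of hitting the underlying provider again.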
27 | func NewLocalCachedResponder(underlying llm.Responder, dbPath string) (*CachedResponder, error) { 28 | cache, err := sqlitecache.NewSQLiteCache(dbPath) 29 | if err != nil { 30 | return nil, fmt.Errorf("failed to create cache: %w", err) 31 | } 32 | 33 | // we use sha256 to avoid pulling down the xxhash dep if you don't need to 34 | // these are small strings to cache so shouldn't make a big diff 35 | hasher := sha256.New() 36 | 37 | return &CachedResponder{ 38 | underlying: underlying, 39 | cache: cache, 40 | hasher: hasher, 41 | }, nil 42 | } 43 | 44 | func (cr *CachedResponder) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) { 45 | // Check cache 46 | cachedResponse, err := cr.cache.GetResponse(ctx, req) 47 | if err == nil { 48 | return cachedResponse, nil 49 | } 50 | 51 | // If not in cache, call underlying provider 52 | response, err := cr.underlying.GenerateResponse(ctx, req) 53 | if err != nil { 54 | return "", err 55 | } 56 | 57 | // Cache the result 58 | if err := cr.cache.SetResponse(ctx, req, response); err != nil { 59 | log.Printf("Failed to cache response: %v", err) 60 | } 61 | 62 | return response, nil 63 | } 64 | 65 | func (cr *CachedResponder) GenerateResponseAsync(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) { 66 | // For async responses, we don't cache and just pass through to the underlying provider 67 | // TODO: think about if we should just cache final response and return immediately 68 | return cr.underlying.GenerateResponseAsync(ctx, req) 69 | } 70 | 71 | func (cr *CachedResponder) Close() error { 72 | return cr.cache.Close() 73 | } 74 | 75 | func (cr *CachedResponder) GetCacheStats() cache.CacheStats { 76 | return cr.cache.GetStats() 77 | } 78 | -------------------------------------------------------------------------------- /internal/hash/readme.md: -------------------------------------------------------------------------------- 1 | with `go test -bench=. 
-cpu 1 -benchmem -benchtime=1s` 2 | 3 | with n=32 4 | 5 | ``` 6 | goos: linux 7 | goarch: amd64 8 | pkg: github.com/stillmatic/gollum/internal/hash 9 | cpu: AMD Ryzen 9 7950X 16-Core Processor 10 | BenchmarkFNV32 48111 25479 ns/op 8 B/op 1 allocs/op 11 | BenchmarkFNV64 47475 25640 ns/op 8 B/op 1 allocs/op 12 | BenchmarkFNV128 38230 31021 ns/op 16 B/op 1 allocs/op 13 | BenchmarkMD5 44103 27324 ns/op 16 B/op 1 allocs/op 14 | BenchmarkSHA1 69037 17428 ns/op 24 B/op 1 allocs/op 15 | BenchmarkSHA224 98686 12203 ns/op 32 B/op 1 allocs/op 16 | BenchmarkSHA256 97653 12591 ns/op 32 B/op 1 allocs/op 17 | BenchmarkSHA512 39342 29469 ns/op 64 B/op 1 allocs/op 18 | BenchmarkMurmur3 170241 6990 ns/op 8 B/op 1 allocs/op 19 | BenchmarkXxhash 756992 1552 ns/op 8 B/op 1 allocs/op 20 | PASS 21 | ok github.com/stillmatic/gollum/internal/hash 13.927s 22 | ``` 23 | 24 | with n=512 25 | 26 | ``` 27 | goos: linux 28 | goarch: amd64 29 | pkg: github.com/stillmatic/gollum/internal/hash 30 | cpu: AMD Ryzen 9 7950X 16-Core Processor 31 | BenchmarkFNV32 3097 407335 ns/op 8 B/op 1 allocs/op 32 | BenchmarkFNV64 3090 389704 ns/op 8 B/op 1 allocs/op 33 | BenchmarkFNV128 2314 499748 ns/op 16 B/op 1 allocs/op 34 | BenchmarkMD5 2682 443034 ns/op 16 B/op 1 allocs/op 35 | BenchmarkSHA1 4224 280997 ns/op 24 B/op 1 allocs/op 36 | BenchmarkSHA224 5968 198692 ns/op 32 B/op 1 allocs/op 37 | BenchmarkSHA256 5952 201876 ns/op 32 B/op 1 allocs/op 38 | BenchmarkSHA512 2544 489953 ns/op 64 B/op 1 allocs/op 39 | BenchmarkMurmur3 9975 119104 ns/op 8 B/op 1 allocs/op 40 | BenchmarkXxhash 46592 24758 ns/op 8 B/op 1 allocs/op 41 | PASS 42 | ok github.com/stillmatic/gollum/internal/hash 12.580s 43 | ``` 44 | 45 | Conclusion: `xxhash` is an order of magnitude improvement on large strings. `murmur3` is good for medium size but doesn't seem to scale as well. -------------------------------------------------------------------------------- /packages/docstore/docstore.go: -------------------------------------------------------------------------------- 1 | package docstore 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "github.com/stillmatic/gollum" 7 | 8 | "github.com/pkg/errors" 9 | "gocloud.dev/blob" 10 | ) 11 | 12 | type DocStore interface { 13 | Insert(context.Context, gollum.Document) error 14 | Retrieve(ctx context.Context, id string) (gollum.Document, error) 15 | } 16 | 17 | // MemoryDocStore is a simple in-memory document store. 18 | // It's functionally a hashmap / inverted-index. 19 | type MemoryDocStore struct { 20 | Documents map[string]gollum.Document 21 | } 22 | 23 | func NewMemoryDocStore() *MemoryDocStore { 24 | return &MemoryDocStore{ 25 | Documents: make(map[string]gollum.Document), 26 | } 27 | } 28 | 29 | func NewMemoryDocStoreFromDisk(ctx context.Context, bucket *blob.Bucket, path string) (*MemoryDocStore, error) { 30 | // load documents from disk 31 | data, err := bucket.ReadAll(ctx, path) 32 | if err != nil { 33 | return nil, errors.Wrap(err, "failed to read documents from disk") 34 | } 35 | var nodes map[string]gollum.Document 36 | err = json.Unmarshal(data, &nodes) 37 | if err != nil { 38 | return nil, errors.Wrap(err, "failed to unmarshal documents from JSON") 39 | } 40 | return &MemoryDocStore{ 41 | Documents: nodes, 42 | }, nil 43 | } 44 | 45 | // Insert adds a node to the document store. It overwrites duplicates. 
46 | func (m *MemoryDocStore) Insert(ctx context.Context, d gollum.Document) error { 47 | m.Documents[d.ID] = d 48 | return nil 49 | } 50 | 51 | // Retrieve returns a node from the document store matching an ID. 52 | func (m *MemoryDocStore) Retrieve(ctx context.Context, id string) (gollum.Document, error) { 53 | v, ok := m.Documents[id] 54 | if !ok { 55 | return gollum.Document{}, errors.New("document not found") 56 | } 57 | return v, nil 58 | } 59 | 60 | // Persist saves the document store to disk. 61 | func (m *MemoryDocStore) Persist(ctx context.Context, bucket *blob.Bucket, path string) error { 62 | // save documents to disk 63 | data, err := json.Marshal(m.Documents) 64 | if err != nil { 65 | return errors.Wrap(err, "failed to marshal documents to JSON") 66 | } 67 | err = bucket.WriteAll(ctx, path, data, nil) 68 | if err != nil { 69 | return errors.Wrap(err, "failed to write documents to disk") 70 | } 71 | return nil 72 | } 73 | 74 | // Load loads the document store from disk. 75 | func (m *MemoryDocStore) Load(ctx context.Context, bucket *blob.Bucket, path string) error { 76 | // load documents from disk 77 | data, err := bucket.ReadAll(ctx, path) 78 | if err != nil { 79 | return errors.Wrap(err, "failed to read documents from disk") 80 | } 81 | var nodes map[string]gollum.Document 82 | err = json.Unmarshal(data, &nodes) 83 | if err != nil { 84 | return errors.Wrap(err, "failed to unmarshal documents from JSON") 85 | } 86 | m.Documents = nodes 87 | return nil 88 | } 89 | -------------------------------------------------------------------------------- /packages/llm/providers/voyage/voyage.go: -------------------------------------------------------------------------------- 1 | package voyage 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "github.com/stillmatic/gollum/packages/llm" 9 | "io" 10 | "net/http" 11 | ) 12 | 13 | const ( 14 | apiURL = "https://api.voyageai.com/v1/embeddings" 15 | ) 16 | 17 | type VoyageAIEmbedder struct { 18 | APIKey string 19 | } 20 | 21 | type voyageAIRequest struct { 22 | Input []string `json:"input"` 23 | Model string `json:"model"` 24 | } 25 | 26 | type voyageAIResponse struct { 27 | Object string `json:"object"` 28 | Data []struct { 29 | Object string `json:"object"` 30 | Embedding []float32 `json:"embedding"` 31 | Index int `json:"index"` 32 | } `json:"data"` 33 | Model string `json:"model"` 34 | Usage struct { 35 | TotalTokens int `json:"total_tokens"` 36 | } `json:"usage"` 37 | } 38 | 39 | func NewVoyageAIEmbedder(apiKey string) *VoyageAIEmbedder { 40 | return &VoyageAIEmbedder{APIKey: apiKey} 41 | } 42 | 43 | func (e *VoyageAIEmbedder) GenerateEmbedding(ctx context.Context, req llm.EmbedRequest) (*llm.EmbeddingResponse, error) { 44 | if len(req.Image) > 0 { 45 | return nil, fmt.Errorf("image embedding not supported by Voyage AI") 46 | } 47 | 48 | if req.Dimensions != 0 { 49 | return nil, fmt.Errorf("custom dimensions not supported by Voyage AI") 50 | } 51 | 52 | voyageReq := voyageAIRequest{ 53 | Input: req.Input, 54 | Model: req.ModelConfig.ModelName, 55 | } 56 | 57 | jsonData, err := json.Marshal(voyageReq) 58 | if err != nil { 59 | return nil, fmt.Errorf("failed to marshal request: %w", err) 60 | } 61 | 62 | httpReq, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonData)) 63 | if err != nil { 64 | return nil, fmt.Errorf("failed to create request: %w", err) 65 | } 66 | 67 | httpReq.Header.Set("Content-Type", "application/json") 68 | httpReq.Header.Set("Authorization", "Bearer "+e.APIKey) 69 | 
70 | client := &http.Client{} 71 | resp, err := client.Do(httpReq) 72 | if err != nil { 73 | return nil, fmt.Errorf("failed to send request: %w", err) 74 | } 75 | defer resp.Body.Close() 76 | 77 | body, err := io.ReadAll(resp.Body) 78 | if err != nil { 79 | return nil, fmt.Errorf("failed to read response body: %w", err) 80 | } 81 | 82 | if resp.StatusCode != http.StatusOK { 83 | return nil, fmt.Errorf("API request failed with status code %d: %s", resp.StatusCode, string(body)) 84 | } 85 | 86 | var voyageResp voyageAIResponse 87 | err = json.Unmarshal(body, &voyageResp) 88 | if err != nil { 89 | return nil, fmt.Errorf("failed to unmarshal response: %w", err) 90 | } 91 | 92 | embeddings := make([]llm.Embedding, len(voyageResp.Data)) 93 | for i, data := range voyageResp.Data { 94 | embeddings[i] = llm.Embedding{Values: data.Embedding} 95 | } 96 | 97 | return &llm.EmbeddingResponse{Data: embeddings}, nil 98 | } 99 | -------------------------------------------------------------------------------- /packages/jsonparser/parser_test.go: -------------------------------------------------------------------------------- 1 | package jsonparser_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "testing" 7 | 8 | "github.com/stillmatic/gollum/packages/jsonparser" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | type employee struct { 13 | Name string `json:"name" yaml:"name" jsonschema:"required,minLength=1,maxLength=100"` 14 | Age int `json:"age" yaml:"age" jsonschema:"minimum=18,maximum=80,required"` 15 | } 16 | 17 | type company struct { 18 | Name string `json:"name" yaml:"name"` 19 | Employees []employee `json:"employees,omitempty" yaml:"employees"` 20 | } 21 | 22 | type bulletList []string 23 | 24 | var testCo = company{ 25 | Name: "Acme", 26 | Employees: []employee{ 27 | { 28 | Name: "John", 29 | Age: 30, 30 | }, 31 | { 32 | Name: "Jane", 33 | Age: 25, 34 | }, 35 | }, 36 | } 37 | var badEmployees = []employee{ 38 | { 39 | Name: "John", 40 | Age: 0, 41 | }, 42 | { 43 | Name: "", 44 | Age: 25, 45 | }, 46 | } 47 | 48 | func TestParsers(t *testing.T) { 49 | t.Run("JSONParser", func(t *testing.T) { 50 | jsonParser := jsonparser.NewJSONParserGeneric[company](false) 51 | input, err := json.Marshal(testCo) 52 | assert.NoError(t, err) 53 | 54 | actual, err := jsonParser.Parse(context.Background(), input) 55 | assert.NoError(t, err) 56 | assert.Equal(t, testCo, actual) 57 | 58 | // test failure 59 | employeeParser := jsonparser.NewJSONParserGeneric[employee](true) 60 | input2, err := json.Marshal(badEmployees) 61 | assert.NoError(t, err) 62 | _, err = employeeParser.Parse(context.Background(), input2) 63 | assert.Error(t, err) 64 | }) 65 | 66 | t.Run("testbenchmark", func(t *testing.T) { 67 | jsonParser := jsonparser.NewJSONParserGeneric[company](true) 68 | input, err := json.Marshal(testCo) 69 | assert.NoError(t, err) 70 | actual, err := jsonParser.Parse(context.Background(), input) 71 | assert.NoError(t, err) 72 | assert.Equal(t, testCo.Name, actual.Name) 73 | }) 74 | } 75 | 76 | func BenchmarkParser(b *testing.B) { 77 | b.Run("JSONParser-NoValidate", func(b *testing.B) { 78 | jsonParser := jsonparser.NewJSONParserGeneric[company](false) 79 | input, err := json.Marshal(testCo) 80 | assert.NoError(b, err) 81 | b.ResetTimer() 82 | for i := 0; i < b.N; i++ { 83 | actual, err := jsonParser.Parse(context.Background(), input) 84 | assert.NoError(b, err) 85 | assert.Equal(b, testCo, actual) 86 | } 87 | }) 88 | b.Run("JSONParser-Validate", func(b *testing.B) { 89 | jsonParser := 
jsonparser.NewJSONParserGeneric[company](true) 90 | input, err := json.Marshal(testCo) 91 | assert.NoError(b, err) 92 | b.ResetTimer() 93 | for i := 0; i < b.N; i++ { 94 | actual, err := jsonParser.Parse(context.Background(), input) 95 | assert.NoError(b, err) 96 | assert.Equal(b, testCo, actual) 97 | } 98 | }) 99 | } 100 | -------------------------------------------------------------------------------- /packages/agents/calcagent.go: -------------------------------------------------------------------------------- 1 | package agents 2 | 3 | import ( 4 | "context" 5 | "github.com/stillmatic/gollum/packages/dispatch" 6 | "github.com/stillmatic/gollum/packages/jsonparser" 7 | "github.com/stillmatic/gollum/packages/tools" 8 | "strconv" 9 | 10 | "github.com/pkg/errors" 11 | "github.com/sashabaranov/go-openai" 12 | "github.com/stillmatic/gollum" 13 | ) 14 | 15 | type CalcAgent struct { 16 | tool tools.CalculatorTool 17 | env map[string]interface{} 18 | llm gollum.ChatCompleter 19 | functionInput openai.FunctionDefinition 20 | parser jsonparser.Parser[tools.CalculatorInput] 21 | } 22 | 23 | func NewCalcAgent(llm gollum.ChatCompleter) *CalcAgent { 24 | // might as well pre-compute it 25 | fi := dispatch.StructToJsonSchemaGeneric[tools.CalculatorInput]("calculator", "evaluate mathematical expressions") 26 | parser := jsonparser.NewJSONParserGeneric[tools.CalculatorInput](true) 27 | return &CalcAgent{ 28 | tool: tools.CalculatorTool{}, 29 | env: make(map[string]interface{}), 30 | llm: llm, 31 | functionInput: openai.FunctionDefinition(fi), 32 | parser: parser, 33 | } 34 | } 35 | 36 | type CalcAgentInput struct { 37 | Content string `json:"content" jsonschema:"required" jsonschema_description:"Natural language input to the calculator"` 38 | } 39 | 40 | func (c *CalcAgent) Name() string { 41 | return "calcagent" 42 | } 43 | 44 | func (c *CalcAgent) Description() string { 45 | return "convert natural language and evaluate mathematical expressions" 46 | } 47 | 48 | func (c *CalcAgent) Run(ctx context.Context, input interface{}) (interface{}, error) { 49 | cinput, ok := input.(CalcAgentInput) 50 | if !ok { 51 | return "", errors.New("invalid input") 52 | } 53 | // call LLM to convert natural language to mathematical expression 54 | resp, err := c.llm.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ 55 | Model: openai.GPT3Dot5Turbo0613, 56 | Messages: []openai.ChatCompletionMessage{ 57 | { 58 | Role: openai.ChatMessageRoleSystem, 59 | Content: "Convert the user's natural language input to a mathematical expression and input to a calculator function. 
Do not use prior knowledge.", 60 | }, 61 | { 62 | Role: openai.ChatMessageRoleUser, 63 | Content: cinput.Content, 64 | }, 65 | }, 66 | MaxTokens: 128, 67 | Tools: []openai.Tool{{ 68 | Type: "function", 69 | Function: &c.functionInput, 70 | }}, 71 | ToolChoice: "calculator", 72 | }) 73 | if err != nil { 74 | return "", errors.Wrap(err, "couldn't call the LLM") 75 | } 76 | // parse response 77 | parsed, err := c.parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments)) 78 | if err != nil { 79 | return "", errors.Wrap(err, "couldn't parse response") 80 | } 81 | output, err := c.tool.Run(ctx, parsed) 82 | if err != nil { 83 | return "", errors.Wrap(err, "couldn't run expression") 84 | } 85 | switch t := output.(type) { 86 | case string: 87 | return t, nil 88 | case int: 89 | return strconv.Itoa(t), nil 90 | case float64: 91 | return strconv.FormatFloat(t, 'f', -1, 64), nil 92 | default: 93 | return "", errors.New("invalid output") 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /packages/llm/providers/cached/cached_embedder.go: -------------------------------------------------------------------------------- 1 | package cached 2 | 3 | import ( 4 | "context" 5 | "crypto/sha256" 6 | "fmt" 7 | "hash" 8 | "log" 9 | 10 | "github.com/stillmatic/gollum/packages/llm" 11 | "github.com/stillmatic/gollum/packages/llm/cache" 12 | "github.com/stillmatic/gollum/packages/llm/providers/cached/sqlitecache" 13 | ) 14 | 15 | // CachedEmbedder implements the llm.Embedder interface with caching 16 | type CachedEmbedder struct { 17 | underlying llm.Embedder 18 | cache cache.Cache 19 | hasher hash.Hash 20 | } 21 | 22 | // NewLocalCachedEmbedder creates a new CachedEmbedder with a local SQLite cache 23 | func NewLocalCachedEmbedder(underlying llm.Embedder, dbPath string) (*CachedEmbedder, error) { 24 | cache, err := sqlitecache.NewSQLiteCache(dbPath) 25 | if err != nil { 26 | return nil, fmt.Errorf("failed to create cache: %w", err) 27 | } 28 | 29 | return &CachedEmbedder{ 30 | underlying: underlying, 31 | cache: cache, 32 | hasher: sha256.New(), 33 | }, nil 34 | } 35 | 36 | func (ce *CachedEmbedder) GenerateEmbedding(ctx context.Context, req llm.EmbedRequest) (*llm.EmbeddingResponse, error) { 37 | cachedEmbeddings := make([]llm.Embedding, 0, len(req.Input)) 38 | uncachedIndices := make([]int, 0) 39 | uncachedInputs := make([]string, 0) 40 | 41 | // Check cache for each input string 42 | for i, input := range req.Input { 43 | embedding, err := ce.cache.GetEmbedding(ctx, req.ModelConfig.ModelName, input) 44 | if err == nil { 45 | cachedEmbeddings = append(cachedEmbeddings, llm.Embedding{Values: embedding}) 46 | } else { 47 | uncachedIndices = append(uncachedIndices, i) 48 | uncachedInputs = append(uncachedInputs, input) 49 | } 50 | } 51 | 52 | // If all embeddings were cached, return immediately 53 | if len(uncachedInputs) == 0 { 54 | return &llm.EmbeddingResponse{ 55 | Data: cachedEmbeddings, 56 | }, nil 57 | } 58 | 59 | // Generate embeddings for uncached inputs 60 | uncachedReq := llm.EmbedRequest{ 61 | ModelConfig: req.ModelConfig, 62 | Input: uncachedInputs, 63 | } 64 | uncachedResponse, err := ce.underlying.GenerateEmbedding(ctx, uncachedReq) 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | // Cache the new embeddings 70 | for i, embedding := range uncachedResponse.Data { 71 | if err := ce.cache.SetEmbedding(ctx, req.ModelConfig.ModelName, uncachedInputs[i], embedding.Values); err != nil { 72 | log.Printf("Failed to cache 
embedding: %v", err) 73 | } 74 | } 75 | 76 | // Merge cached and new embeddings 77 | finalEmbeddings := make([]llm.Embedding, len(req.Input)) 78 | cachedIndex, uncachedIndex := 0, 0 79 | for i := range req.Input { 80 | if contains(uncachedIndices, i) { 81 | finalEmbeddings[i] = uncachedResponse.Data[uncachedIndex] 82 | uncachedIndex++ 83 | } else { 84 | finalEmbeddings[i] = cachedEmbeddings[cachedIndex] 85 | cachedIndex++ 86 | } 87 | } 88 | 89 | return &llm.EmbeddingResponse{ 90 | Data: finalEmbeddings, 91 | }, nil 92 | } 93 | func (ce *CachedEmbedder) Close() error { 94 | return ce.cache.Close() 95 | } 96 | 97 | func (ce *CachedEmbedder) GetCacheStats() cache.CacheStats { 98 | return ce.cache.GetStats() 99 | } 100 | 101 | func contains(slice []int, val int) bool { 102 | for _, item := range slice { 103 | if item == val { 104 | return true 105 | } 106 | } 107 | return false 108 | } 109 | -------------------------------------------------------------------------------- /packages/llm/providers/mixedbread/mixedbread.go: -------------------------------------------------------------------------------- 1 | package mixedbread 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "github.com/stillmatic/gollum/packages/llm" 9 | "io" 10 | "net/http" 11 | ) 12 | 13 | const ( 14 | apiURL = "https://api.mixedbread.ai/v1/embeddings" 15 | ) 16 | 17 | type MixedbreadEmbedder struct { 18 | APIKey string 19 | } 20 | 21 | type mixedbreadRequest struct { 22 | Input interface{} `json:"input"` 23 | Model string `json:"model"` 24 | Prompt string `json:"prompt,omitempty"` 25 | Normalized *bool `json:"normalized,omitempty"` 26 | Dimensions *int `json:"dimensions,omitempty"` 27 | EncodingFormat string `json:"encoding_format,omitempty"` 28 | TruncationStrategy string `json:"truncation_strategy,omitempty"` 29 | } 30 | 31 | type mixedbreadResponse struct { 32 | Model string `json:"model"` 33 | Object string `json:"object"` 34 | Data []struct { 35 | Embedding interface{} `json:"embedding"` 36 | Index int `json:"index"` 37 | Object string `json:"object"` 38 | } `json:"data"` 39 | Usage struct { 40 | PromptTokens int `json:"prompt_tokens"` 41 | TotalTokens int `json:"total_tokens"` 42 | } `json:"usage"` 43 | Normalized bool `json:"normalized"` 44 | } 45 | 46 | func NewMixedbreadEmbedder(apiKey string) *MixedbreadEmbedder { 47 | return &MixedbreadEmbedder{APIKey: apiKey} 48 | } 49 | 50 | func ptr[T any](x T) *T { 51 | return &x 52 | } 53 | 54 | func (e *MixedbreadEmbedder) GenerateEmbedding(ctx context.Context, req llm.EmbedRequest) (*llm.EmbeddingResponse, error) { 55 | if len(req.Image) > 0 { 56 | return nil, fmt.Errorf("image embedding not supported by Mixedbread API") 57 | } 58 | 59 | mixedReq := mixedbreadRequest{ 60 | Input: req.Input, 61 | Model: req.ModelConfig.ModelName, 62 | Normalized: ptr(true), 63 | } 64 | if req.Prompt != "" { 65 | mixedReq.Prompt = req.Prompt 66 | } 67 | 68 | if req.Dimensions != 0 { 69 | mixedReq.Dimensions = &req.Dimensions 70 | } 71 | 72 | jsonData, err := json.Marshal(mixedReq) 73 | if err != nil { 74 | return nil, fmt.Errorf("failed to marshal request: %w", err) 75 | } 76 | 77 | httpReq, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonData)) 78 | if err != nil { 79 | return nil, fmt.Errorf("failed to create request: %w", err) 80 | } 81 | 82 | httpReq.Header.Set("Content-Type", "application/json") 83 | httpReq.Header.Set("Authorization", "Bearer "+e.APIKey) 84 | 85 | client := &http.Client{} 86 | resp, err := client.Do(httpReq) 87 | if 
err != nil { 88 | return nil, fmt.Errorf("failed to send request: %w", err) 89 | } 90 | defer resp.Body.Close() 91 | 92 | body, err := io.ReadAll(resp.Body) 93 | if err != nil { 94 | return nil, fmt.Errorf("failed to read response body: %w", err) 95 | } 96 | 97 | if resp.StatusCode != http.StatusOK { 98 | return nil, fmt.Errorf("API request failed with status code %d: %s", resp.StatusCode, string(body)) 99 | } 100 | 101 | var mixedResp mixedbreadResponse 102 | err = json.Unmarshal(body, &mixedResp) 103 | if err != nil { 104 | return nil, fmt.Errorf("failed to unmarshal response: %w", err) 105 | } 106 | 107 | embeddings := make([]llm.Embedding, len(mixedResp.Data)) 108 | for i, data := range mixedResp.Data { 109 | switch v := data.Embedding.(type) { 110 | case []interface{}: 111 | values := make([]float32, len(v)) 112 | for j, val := range v { 113 | if f, ok := val.(float64); ok { 114 | values[j] = float32(f) 115 | } 116 | } 117 | embeddings[i] = llm.Embedding{Values: values} 118 | default: 119 | return nil, fmt.Errorf("unexpected embedding format") 120 | } 121 | } 122 | 123 | return &llm.EmbeddingResponse{Data: embeddings}, nil 124 | } 125 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/stillmatic/gollum 2 | 3 | go 1.23 4 | 5 | require ( 6 | cloud.google.com/go/vertexai v0.13.0 7 | github.com/antonmedv/expr v1.15.3 8 | github.com/cespare/xxhash/v2 v2.3.0 9 | github.com/chewxy/math32 v1.10.1 10 | github.com/google/generative-ai-go v0.19.0 11 | github.com/google/uuid v1.6.0 12 | github.com/invopop/jsonschema v0.12.0 13 | github.com/joho/godotenv v1.5.1 14 | github.com/klauspost/compress v1.17.2 15 | github.com/liushuangls/go-anthropic/v2 v2.13.0 16 | github.com/pkg/errors v0.9.1 17 | github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 18 | github.com/sashabaranov/go-openai v1.36.1 19 | github.com/stretchr/testify v1.9.0 20 | github.com/viterin/vek v0.4.2 21 | go.uber.org/mock v0.3.0 22 | gocloud.dev v0.38.0 23 | google.golang.org/api v0.215.0 24 | modernc.org/sqlite v1.32.0 25 | ) 26 | 27 | require ( 28 | cloud.google.com/go v0.115.1 // indirect 29 | cloud.google.com/go/ai v0.8.0 // indirect 30 | cloud.google.com/go/aiplatform v1.68.0 // indirect 31 | cloud.google.com/go/auth v0.13.0 // indirect 32 | cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect 33 | cloud.google.com/go/compute/metadata v0.6.0 // indirect 34 | cloud.google.com/go/iam v1.1.12 // indirect 35 | cloud.google.com/go/longrunning v0.5.11 // indirect 36 | github.com/bahlo/generic-list-go v0.2.0 // indirect 37 | github.com/buger/jsonparser v1.1.1 // indirect 38 | github.com/davecgh/go-spew v1.1.1 // indirect 39 | github.com/dustin/go-humanize v1.0.1 // indirect 40 | github.com/felixge/httpsnoop v1.0.4 // indirect 41 | github.com/go-logr/logr v1.4.2 // indirect 42 | github.com/go-logr/stdr v1.2.2 // indirect 43 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 44 | github.com/google/s2a-go v0.1.8 // indirect 45 | github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect 46 | github.com/googleapis/gax-go/v2 v2.14.1 // indirect 47 | github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect 48 | github.com/kr/text v0.2.0 // indirect 49 | github.com/mailru/easyjson v0.7.7 // indirect 50 | github.com/mattn/go-isatty v0.0.20 // indirect 51 | github.com/ncruces/go-strftime v0.1.9 // indirect 52 | github.com/pmezard/go-difflib v1.0.0 // indirect 53 | 
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect 54 | github.com/viterin/partial v1.1.0 // indirect 55 | github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect 56 | go.opencensus.io v0.24.0 // indirect 57 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect 58 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect 59 | go.opentelemetry.io/otel v1.29.0 // indirect 60 | go.opentelemetry.io/otel/metric v1.29.0 // indirect 61 | go.opentelemetry.io/otel/trace v1.29.0 // indirect 62 | golang.org/x/crypto v0.31.0 // indirect 63 | golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 // indirect 64 | golang.org/x/net v0.33.0 // indirect 65 | golang.org/x/oauth2 v0.24.0 // indirect 66 | golang.org/x/sync v0.10.0 // indirect 67 | golang.org/x/sys v0.28.0 // indirect 68 | golang.org/x/text v0.21.0 // indirect 69 | golang.org/x/time v0.8.0 // indirect 70 | golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect 71 | google.golang.org/genproto v0.0.0-20240814211410-ddb44dafa142 // indirect 72 | google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 // indirect 73 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect 74 | google.golang.org/grpc v1.67.1 // indirect 75 | google.golang.org/protobuf v1.36.1 // indirect 76 | gopkg.in/yaml.v3 v3.0.1 // indirect 77 | modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect 78 | modernc.org/libc v1.55.3 // indirect 79 | modernc.org/mathutil v1.6.0 // indirect 80 | modernc.org/memory v1.8.0 // indirect 81 | modernc.org/strutil v1.2.0 // indirect 82 | modernc.org/token v1.1.0 // indirect 83 | ) 84 | -------------------------------------------------------------------------------- /packages/llm/providers/cached/sqlitecache/sqlite.go: -------------------------------------------------------------------------------- 1 | package sqlitecache 2 | 3 | import ( 4 | "context" 5 | "crypto/sha256" 6 | "database/sql" 7 | "encoding/json" 8 | "fmt" 9 | "hash" 10 | 11 | "github.com/stillmatic/gollum/packages/llm" 12 | "github.com/stillmatic/gollum/packages/llm/cache" 13 | _ "modernc.org/sqlite" 14 | ) 15 | 16 | type SQLiteCache struct { 17 | db *sql.DB 18 | hasher hash.Hash 19 | numRequests int 20 | numCacheHits int 21 | } 22 | 23 | func NewSQLiteCache(dbPath string) (*SQLiteCache, error) { 24 | db, err := sql.Open("sqlite", dbPath) 25 | if err != nil { 26 | return nil, fmt.Errorf("failed to open database: %w", err) 27 | } 28 | 29 | if err := initDB(db); err != nil { 30 | return nil, fmt.Errorf("failed to initialize database: %w", err) 31 | } 32 | 33 | return &SQLiteCache{ 34 | db: db, 35 | hasher: sha256.New(), 36 | }, nil 37 | } 38 | 39 | func (c *SQLiteCache) GetResponse(ctx context.Context, req llm.InferRequest) (string, error) { 40 | c.numRequests++ 41 | requestJSON, err := json.Marshal(req) 42 | if err != nil { 43 | return "", err 44 | } 45 | 46 | hashedRequest := c.hasher.Sum(requestJSON) 47 | 48 | var response string 49 | err = c.db.QueryRowContext(ctx, "SELECT response FROM response_cache WHERE request = ?", hashedRequest).Scan(&response) 50 | if err != nil { 51 | return "", err 52 | } 53 | c.numCacheHits++ 54 | 55 | return response, nil 56 | } 57 | 58 | func (c *SQLiteCache) SetResponse(ctx context.Context, req llm.InferRequest, response string) error { 59 | requestJSON, err := json.Marshal(req) 60 | if err != nil { 61 | return err 62 | } 63 | hashedRequest := c.hasher.Sum(requestJSON) 64 | 65 
| _, err = c.db.ExecContext(ctx, "INSERT INTO response_cache (request, response) VALUES (?, ?)", hashedRequest, response) 66 | return err 67 | } 68 | 69 | func (c *SQLiteCache) GetEmbedding(ctx context.Context, modelConfig string, input string) ([]float32, error) { 70 | c.numRequests++ 71 | var embeddingBlob []byte 72 | err := c.db.QueryRowContext(ctx, "SELECT embedding FROM embedding_cache WHERE model_config = ? AND input_string = ?", modelConfig, input).Scan(&embeddingBlob) 73 | if err != nil { 74 | return nil, err 75 | } 76 | c.numCacheHits++ 77 | 78 | var embedding []float32 79 | err = json.Unmarshal(embeddingBlob, &embedding) 80 | if err != nil { 81 | return nil, err 82 | } 83 | 84 | return embedding, nil 85 | } 86 | 87 | func (c *SQLiteCache) SetEmbedding(ctx context.Context, modelConfig string, input string, embedding []float32) error { 88 | embeddingBlob, err := json.Marshal(embedding) 89 | if err != nil { 90 | return err 91 | } 92 | 93 | _, err = c.db.ExecContext(ctx, "INSERT OR REPLACE INTO embedding_cache (model_config, input_string, embedding) VALUES (?, ?, ?)", modelConfig, input, embeddingBlob) 94 | return err 95 | } 96 | 97 | func (c *SQLiteCache) Close() error { 98 | return c.db.Close() 99 | } 100 | 101 | func (c *SQLiteCache) GetStats() cache.CacheStats { 102 | return cache.CacheStats{ 103 | NumRequests: c.numRequests, 104 | NumCacheHits: c.numCacheHits, 105 | } 106 | } 107 | 108 | func initDB(db *sql.DB) error { 109 | _, err := db.Exec(` 110 | CREATE TABLE IF NOT EXISTS response_cache ( 111 | id INTEGER PRIMARY KEY AUTOINCREMENT, 112 | request BLOB, 113 | response TEXT 114 | ); 115 | CREATE TABLE IF NOT EXISTS embedding_cache ( 116 | id INTEGER PRIMARY KEY AUTOINCREMENT, 117 | model_config TEXT, 118 | input_string TEXT, 119 | embedding BLOB, 120 | UNIQUE(model_config, input_string) 121 | ); 122 | `) 123 | if err != nil { 124 | return err 125 | } 126 | 127 | // Set to WAL mode for better performance 128 | _, err = db.Exec("PRAGMA journal_mode=WAL;") 129 | return err 130 | } 131 | 132 | // Ensure SQLiteCache implements the Cache interface 133 | var _ cache.Cache = (*SQLiteCache)(nil) 134 | -------------------------------------------------------------------------------- /packages/llm/internal/mocks/llm.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: llm.go 3 | 4 | // Package mock_llm is a generated GoMock package. 5 | package mock_llm 6 | 7 | import ( 8 | context "context" 9 | reflect "reflect" 10 | 11 | llm "github.com/stillmatic/gollum/packages/llm" 12 | gomock "go.uber.org/mock/gomock" 13 | ) 14 | 15 | // MockResponder is a mock of Responder interface. 16 | type MockResponder struct { 17 | ctrl *gomock.Controller 18 | recorder *MockResponderMockRecorder 19 | } 20 | 21 | // MockResponderMockRecorder is the mock recorder for MockResponder. 22 | type MockResponderMockRecorder struct { 23 | mock *MockResponder 24 | } 25 | 26 | // NewMockResponder creates a new mock instance. 27 | func NewMockResponder(ctrl *gomock.Controller) *MockResponder { 28 | mock := &MockResponder{ctrl: ctrl} 29 | mock.recorder = &MockResponderMockRecorder{mock} 30 | return mock 31 | } 32 | 33 | // EXPECT returns an object that allows the caller to indicate expected use. 34 | func (m *MockResponder) EXPECT() *MockResponderMockRecorder { 35 | return m.recorder 36 | } 37 | 38 | // GenerateResponse mocks base method. 
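// Typical wiring in tests (see cached_provider_test.go): create a gomock.Controller, construct the mock with NewMockResponder or NewMockEmbedder, and register expected calls via EXPECT() before exercising the code under test.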
39 | func (m *MockResponder) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) { 40 | m.ctrl.T.Helper() 41 | ret := m.ctrl.Call(m, "GenerateResponse", ctx, req) 42 | ret0, _ := ret[0].(string) 43 | ret1, _ := ret[1].(error) 44 | return ret0, ret1 45 | } 46 | 47 | // GenerateResponse indicates an expected call of GenerateResponse. 48 | func (mr *MockResponderMockRecorder) GenerateResponse(ctx, req interface{}) *gomock.Call { 49 | mr.mock.ctrl.T.Helper() 50 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateResponse", reflect.TypeOf((*MockResponder)(nil).GenerateResponse), ctx, req) 51 | } 52 | 53 | // GenerateResponseAsync mocks base method. 54 | func (m *MockResponder) GenerateResponseAsync(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) { 55 | m.ctrl.T.Helper() 56 | ret := m.ctrl.Call(m, "GenerateResponseAsync", ctx, req) 57 | ret0, _ := ret[0].(<-chan llm.StreamDelta) 58 | ret1, _ := ret[1].(error) 59 | return ret0, ret1 60 | } 61 | 62 | // GenerateResponseAsync indicates an expected call of GenerateResponseAsync. 63 | func (mr *MockResponderMockRecorder) GenerateResponseAsync(ctx, req interface{}) *gomock.Call { 64 | mr.mock.ctrl.T.Helper() 65 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateResponseAsync", reflect.TypeOf((*MockResponder)(nil).GenerateResponseAsync), ctx, req) 66 | } 67 | 68 | // MockEmbedder is a mock of Embedder interface. 69 | type MockEmbedder struct { 70 | ctrl *gomock.Controller 71 | recorder *MockEmbedderMockRecorder 72 | } 73 | 74 | // MockEmbedderMockRecorder is the mock recorder for MockEmbedder. 75 | type MockEmbedderMockRecorder struct { 76 | mock *MockEmbedder 77 | } 78 | 79 | // NewMockEmbedder creates a new mock instance. 80 | func NewMockEmbedder(ctrl *gomock.Controller) *MockEmbedder { 81 | mock := &MockEmbedder{ctrl: ctrl} 82 | mock.recorder = &MockEmbedderMockRecorder{mock} 83 | return mock 84 | } 85 | 86 | // EXPECT returns an object that allows the caller to indicate expected use. 87 | func (m *MockEmbedder) EXPECT() *MockEmbedderMockRecorder { 88 | return m.recorder 89 | } 90 | 91 | // GenerateEmbedding mocks base method. 92 | func (m *MockEmbedder) GenerateEmbedding(ctx context.Context, req llm.EmbedRequest) (*llm.EmbeddingResponse, error) { 93 | m.ctrl.T.Helper() 94 | ret := m.ctrl.Call(m, "GenerateEmbedding", ctx, req) 95 | ret0, _ := ret[0].(*llm.EmbeddingResponse) 96 | ret1, _ := ret[1].(error) 97 | return ret0, ret1 98 | } 99 | 100 | // GenerateEmbedding indicates an expected call of GenerateEmbedding. 
101 | func (mr *MockEmbedderMockRecorder) GenerateEmbedding(ctx, req interface{}) *gomock.Call { 102 | mr.mock.ctrl.T.Helper() 103 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateEmbedding", reflect.TypeOf((*MockEmbedder)(nil).GenerateEmbedding), ctx, req) 104 | } 105 | -------------------------------------------------------------------------------- /packages/vectorstore/vectorstore_memory.go: -------------------------------------------------------------------------------- 1 | package vectorstore 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "strings" 7 | 8 | "github.com/pkg/errors" 9 | "github.com/sashabaranov/go-openai" 10 | "github.com/stillmatic/gollum" 11 | "github.com/viterin/vek/vek32" 12 | "gocloud.dev/blob" 13 | ) 14 | 15 | // MemoryVectorStore embeds documents on insert and stores them in memory 16 | type MemoryVectorStore struct { 17 | Documents []gollum.Document 18 | LLM gollum.Embedder 19 | } 20 | 21 | func NewMemoryVectorStore(llm gollum.Embedder) *MemoryVectorStore { 22 | return &MemoryVectorStore{ 23 | Documents: make([]gollum.Document, 0), 24 | LLM: llm, 25 | } 26 | } 27 | 28 | func (m *MemoryVectorStore) Insert(ctx context.Context, d gollum.Document) error { 29 | // replace newlines with spaces and strip whitespace, per OpenAI's recommendation 30 | if d.Embedding == nil { 31 | cleanText := strings.ReplaceAll(d.Content, "\n", " ") 32 | cleanText = strings.TrimSpace(cleanText) 33 | 34 | embedding, err := m.LLM.CreateEmbeddings(ctx, openai.EmbeddingRequest{ 35 | Input: []string{cleanText}, 36 | // TODO: make this configurable -- may require forking the base library, this expects an enum 37 | Model: openai.AdaEmbeddingV2, 38 | }) 39 | if err != nil { 40 | return errors.Wrap(err, "failed to create embedding") 41 | } 42 | d.Embedding = embedding.Data[0].Embedding 43 | } 44 | 45 | m.Documents = append(m.Documents, d) 46 | return nil 47 | } 48 | 49 | func (m *MemoryVectorStore) Persist(ctx context.Context, bucket *blob.Bucket, path string) error { 50 | // save documents to disk 51 | data, err := json.Marshal(m.Documents) 52 | if err != nil { 53 | return errors.Wrap(err, "failed to marshal documents to JSON") 54 | } 55 | err = bucket.WriteAll(ctx, path, data, nil) 56 | if err != nil { 57 | return errors.Wrap(err, "failed to write documents to file") 58 | } 59 | return nil 60 | } 61 | 62 | func NewMemoryVectorStoreFromDisk(ctx context.Context, bucket *blob.Bucket, path string, llm gollum.Embedder) (*MemoryVectorStore, error) { 63 | data, err := bucket.ReadAll(ctx, path) 64 | if err != nil { 65 | return nil, errors.Wrap(err, "failed to read file") 66 | } 67 | var documents []gollum.Document 68 | err = json.Unmarshal(data, &documents) 69 | if err != nil { 70 | return nil, errors.Wrap(err, "failed to unmarshal JSON") 71 | } 72 | return &MemoryVectorStore{ 73 | Documents: documents, 74 | LLM: llm, 75 | }, nil 76 | } 77 | 78 | func (m *MemoryVectorStore) Query(ctx context.Context, qb QueryRequest) ([]*gollum.Document, error) { 79 | if len(m.Documents) == 0 { 80 | return nil, errors.New("no documents in store") 81 | } 82 | if len(qb.EmbeddingStrings) > 0 { 83 | // concatenate strings and set query 84 | qb.Query = strings.Join(qb.EmbeddingStrings, " ") 85 | } 86 | if len(qb.EmbeddingFloats) == 0 { 87 | // create embedding 88 | embedding, err := m.LLM.CreateEmbeddings(ctx, openai.EmbeddingRequest{ 89 | Input: []string{qb.Query}, 90 | // TODO: make this configurable 91 | Model: openai.AdaEmbeddingV2, 92 | }) 93 | if err != nil { 94 | return nil, errors.Wrap(err, 
"failed to create embedding") 95 | } 96 | qb.EmbeddingFloats = embedding.Data[0].Embedding 97 | } 98 | scores := Heap{} 99 | k := qb.K 100 | scores.Init(k) 101 | 102 | for _, doc := range m.Documents { 103 | score := vek32.CosineSimilarity(qb.EmbeddingFloats, doc.Embedding) 104 | doc := doc 105 | ns := NodeSimilarity{ 106 | Document: &doc, 107 | Similarity: score, 108 | } 109 | // maintain a max-heap of size k 110 | scores.Push(ns) 111 | if scores.Len() > k { 112 | scores.Pop() 113 | } 114 | } 115 | 116 | result := make([]*gollum.Document, k) 117 | for i := 0; i < k; i++ { 118 | ns := scores.Pop() 119 | doc := ns.Document 120 | result[k-i-1] = doc 121 | } 122 | return result, nil 123 | } 124 | 125 | // RetrieveAll returns all documents 126 | func (m *MemoryVectorStore) RetrieveAll(ctx context.Context) ([]gollum.Document, error) { 127 | return m.Documents, nil 128 | } 129 | -------------------------------------------------------------------------------- /packages/llm/providers/cached/cached_provider_test.go: -------------------------------------------------------------------------------- 1 | package cached_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/stillmatic/gollum/packages/llm" 8 | mock_llm "github.com/stillmatic/gollum/packages/llm/internal/mocks" 9 | "github.com/stillmatic/gollum/packages/llm/providers/cached" 10 | "github.com/stretchr/testify/assert" 11 | "go.uber.org/mock/gomock" 12 | ) 13 | 14 | func TestCachedProvider(t *testing.T) { 15 | ctrl := gomock.NewController(t) 16 | t.Run("responder", func(t *testing.T) { 17 | mockProvider := mock_llm.NewMockResponder(ctrl) 18 | ctx := context.Background() 19 | req := llm.InferRequest{ 20 | Messages: []llm.InferMessage{ 21 | {Content: "hello world", 22 | Role: "user", 23 | }, 24 | }, 25 | ModelConfig: llm.ModelConfig{ 26 | ModelName: "fake_model", 27 | ProviderType: llm.ProviderAnthropic, 28 | }, 29 | } 30 | 31 | mockProvider.EXPECT().GenerateResponse(ctx, req).Return("hello user", nil) 32 | 33 | cachedProvider, err := cached.NewLocalCachedResponder(mockProvider, ":memory:") 34 | assert.NoError(t, err) 35 | resp, err := cachedProvider.GenerateResponse(ctx, req) 36 | assert.NoError(t, err) 37 | assert.Equal(t, "hello user", resp) 38 | 39 | cs := cachedProvider.GetCacheStats() 40 | assert.Equal(t, 1, cs.NumRequests) 41 | assert.Equal(t, 0, cs.NumCacheHits) 42 | 43 | resp, err = cachedProvider.GenerateResponse(ctx, req) 44 | assert.NoError(t, err) 45 | assert.Equal(t, "hello user", resp) 46 | 47 | cs = cachedProvider.GetCacheStats() 48 | assert.Equal(t, 2, cs.NumRequests) 49 | assert.Equal(t, 1, cs.NumCacheHits) 50 | }) 51 | 52 | t.Run("embedder", func(t *testing.T) { 53 | mockProvider := mock_llm.NewMockEmbedder(ctrl) 54 | ctx := context.Background() 55 | req := llm.EmbedRequest{ 56 | Input: []string{"abc"}, 57 | ModelConfig: llm.ModelConfig{ 58 | ModelName: "fake_model", 59 | ProviderType: llm.ProviderAnthropic, 60 | }, 61 | } 62 | 63 | // we call the function twice but it returns the cached value second time, so 64 | // the provider should only be called once 65 | mockProvider.EXPECT().GenerateEmbedding(ctx, req).Return(&llm.EmbeddingResponse{ 66 | Data: []llm.Embedding{{Values: []float32{1.0, 2.0, 3.0}}}}, nil).Times(1) 67 | 68 | cachedProvider, err := cached.NewLocalCachedEmbedder(mockProvider, ":memory:") 69 | assert.NoError(t, err) 70 | resp, err := cachedProvider.GenerateEmbedding(ctx, req) 71 | assert.NoError(t, err) 72 | assert.Equal(t, []float32{1.0, 2.0, 3.0}, resp.Data[0].Values) 73 | 74 | cs := 
cachedProvider.GetCacheStats() 75 | assert.Equal(t, 1, cs.NumRequests) 76 | assert.Equal(t, 0, cs.NumCacheHits) 77 | 78 | resp, err = cachedProvider.GenerateEmbedding(ctx, req) 79 | assert.NoError(t, err) 80 | assert.Equal(t, []float32{1.0, 2.0, 3.0}, resp.Data[0].Values) 81 | 82 | cs = cachedProvider.GetCacheStats() 83 | assert.Equal(t, 2, cs.NumRequests) 84 | assert.Equal(t, 1, cs.NumCacheHits) 85 | 86 | req = llm.EmbedRequest{ 87 | Input: []string{"abc", "def"}, 88 | ModelConfig: llm.ModelConfig{ 89 | ModelName: "fake_model", 90 | ProviderType: llm.ProviderAnthropic, 91 | }, 92 | } 93 | // this should be the request to the provider since we don't have the second embedding cached 94 | cachedReq := llm.EmbedRequest{ 95 | Input: []string{"def"}, 96 | ModelConfig: llm.ModelConfig{ 97 | ModelName: "fake_model", 98 | ProviderType: llm.ProviderAnthropic, 99 | }, 100 | } 101 | // and the provider only returns one 102 | mockProvider.EXPECT().GenerateEmbedding(ctx, cachedReq).Return(&llm.EmbeddingResponse{ 103 | Data: []llm.Embedding{{Values: []float32{2.0, 3.0, 4.0}}}}, nil).Times(1) 104 | 105 | //but we should expect to get both embeddings back 106 | resp, err = cachedProvider.GenerateEmbedding(ctx, req) 107 | assert.NoError(t, err) 108 | assert.Equal(t, []float32{1.0, 2.0, 3.0}, resp.Data[0].Values) 109 | assert.Equal(t, []float32{2.0, 3.0, 4.0}, resp.Data[1].Values) 110 | 111 | cs = cachedProvider.GetCacheStats() 112 | assert.Equal(t, 4, cs.NumRequests) 113 | assert.Equal(t, 2, cs.NumCacheHits) 114 | }) 115 | } 116 | -------------------------------------------------------------------------------- /packages/dispatch/dispatch.go: -------------------------------------------------------------------------------- 1 | package dispatch 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/stillmatic/gollum" 7 | "github.com/stillmatic/gollum/packages/jsonparser" 8 | "strings" 9 | "text/template" 10 | 11 | "github.com/sashabaranov/go-openai" 12 | ) 13 | 14 | type Dispatcher[T any] interface { 15 | // Prompt generates an object of type T from the given prompt. 16 | Prompt(ctx context.Context, prompt string) (T, error) 17 | // PromptTemplate generates an object of type T from a given template. 18 | // The prompt is then a template string that is rendered with the given values. 19 | PromptTemplate(ctx context.Context, template *template.Template, values interface{}) (T, error) 20 | } 21 | 22 | type DummyDispatcher[T any] struct{} 23 | 24 | func NewDummyDispatcher[T any]() *DummyDispatcher[T] { 25 | return &DummyDispatcher[T]{} 26 | } 27 | 28 | func (d *DummyDispatcher[T]) Prompt(ctx context.Context, prompt string) (T, error) { 29 | var t T 30 | return t, nil 31 | } 32 | 33 | func (d *DummyDispatcher[T]) PromptTemplate(ctx context.Context, template *template.Template, values interface{}) (T, error) { 34 | var t T 35 | var sb strings.Builder 36 | err := template.Execute(&sb, values) 37 | if err != nil { 38 | return t, fmt.Errorf("error executing template: %w", err) 39 | } 40 | return t, nil 41 | } 42 | 43 | type OpenAIDispatcherConfig struct { 44 | Model *string 45 | Temperature *float32 46 | MaxTokens *int 47 | } 48 | 49 | // OpenAIDispatcher dispatches to any OpenAI compatible model. 50 | // For any type T and prompt, it will generate and parse the response into T. 
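// An illustrative call pattern (a sketch only: ctx, apiKey, and the recipe struct are hypothetical,
// while the constructor and Prompt signatures are the ones defined below):
//
//	type recipe struct {
//		Name        string   `json:"name" jsonschema:"required"`
//		Ingredients []string `json:"ingredients" jsonschema:"required"`
//	}
//	client := openai.NewClient(apiKey) // any gollum.ChatCompleter works
//	d := NewOpenAIDispatcher[recipe]("extract_recipe", "extract a recipe from text", "Use the tool.", client, nil)
//	out, err := d.Prompt(ctx, "Pancakes need flour, milk, and eggs.")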
51 | type OpenAIDispatcher[T any] struct { 52 | *OpenAIDispatcherConfig 53 | completer gollum.ChatCompleter 54 | ti openai.Tool 55 | systemPrompt string 56 | parser jsonparser.Parser[T] 57 | } 58 | 59 | func NewOpenAIDispatcher[T any](name, description, systemPrompt string, completer gollum.ChatCompleter, cfg *OpenAIDispatcherConfig) *OpenAIDispatcher[T] { 60 | // note: name must not have spaces - valid json 61 | // we won't check here but the openai client will throw an error 62 | var t T 63 | fi := StructToJsonSchema(name, description, t) 64 | ti := FunctionInputToTool(fi) 65 | parser := jsonparser.NewJSONParserGeneric[T](true) 66 | return &OpenAIDispatcher[T]{ 67 | OpenAIDispatcherConfig: cfg, 68 | completer: completer, 69 | ti: ti, 70 | parser: parser, 71 | systemPrompt: systemPrompt, 72 | } 73 | } 74 | 75 | func (d *OpenAIDispatcher[T]) Prompt(ctx context.Context, prompt string) (T, error) { 76 | var output T 77 | model := openai.GPT3Dot5Turbo1106 78 | temperature := float32(0.0) 79 | maxTokens := 512 80 | if d.OpenAIDispatcherConfig != nil { 81 | if d.Model != nil { 82 | model = *d.Model 83 | } 84 | if d.Temperature != nil { 85 | temperature = *d.Temperature 86 | } 87 | if d.MaxTokens != nil { 88 | maxTokens = *d.MaxTokens 89 | } 90 | } 91 | 92 | req := openai.ChatCompletionRequest{ 93 | Model: model, 94 | Messages: []openai.ChatCompletionMessage{ 95 | { 96 | Role: openai.ChatMessageRoleSystem, 97 | Content: d.systemPrompt, 98 | }, 99 | { 100 | Role: openai.ChatMessageRoleUser, 101 | Content: prompt, 102 | }, 103 | }, 104 | Tools: []openai.Tool{d.ti}, 105 | ToolChoice: openai.ToolChoice{ 106 | Type: "function", 107 | Function: openai.ToolFunction{ 108 | Name: d.ti.Function.Name, 109 | }}, 110 | Temperature: temperature, 111 | MaxTokens: maxTokens, 112 | } 113 | 114 | resp, err := d.completer.CreateChatCompletion(ctx, req) 115 | if err != nil { 116 | return output, err 117 | } 118 | 119 | toolOutput := resp.Choices[0].Message.ToolCalls[0].Function.Arguments 120 | output, err = d.parser.Parse(ctx, []byte(toolOutput)) 121 | if err != nil { 122 | return output, err 123 | } 124 | 125 | return output, nil 126 | } 127 | 128 | // PromptTemplate generates an object of type T from a given template. 129 | // This is mostly a convenience wrapper around Prompt. 130 | func (d *OpenAIDispatcher[T]) PromptTemplate(ctx context.Context, template *template.Template, values interface{}) (T, error) { 131 | var t T 132 | var sb strings.Builder 133 | err := template.Execute(&sb, values) 134 | if err != nil { 135 | return t, fmt.Errorf("error executing template: %w", err) 136 | } 137 | return d.Prompt(ctx, sb.String()) 138 | } 139 | -------------------------------------------------------------------------------- /packages/hyde/hyde_test.go: -------------------------------------------------------------------------------- 1 | package hyde_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/stillmatic/gollum/packages/hyde" 7 | "github.com/stillmatic/gollum/packages/vectorstore" 8 | "testing" 9 | 10 | "github.com/stillmatic/gollum" 11 | mock_gollum "github.com/stillmatic/gollum/internal/mocks" 12 | "github.com/stillmatic/gollum/internal/testutil" 13 | "github.com/stretchr/testify/assert" 14 | "go.uber.org/mock/gomock" 15 | ) 16 | 17 | func TestHyde(t *testing.T) { 18 | ctrl := gomock.NewController(t) 19 | embedder := mock_gollum.NewMockEmbedder(ctrl) 20 | completer := mock_gollum.NewMockChatCompleter(ctrl) 21 | prompter := hyde.NewZeroShotPrompter( 22 | "Roleplay as a character. 
Write a short biographical answer to the question.\nQ: %s\nA:", 23 | ) 24 | generator := hyde.NewLLMGenerator(completer) 25 | encoder := hyde.NewLLMEncoder(embedder) 26 | vs := vectorstore.NewMemoryVectorStore(embedder) 27 | for i := range make([]int, 10) { 28 | embedder.EXPECT().CreateEmbeddings(context.Background(), gomock.Any()).Return(testutil.GetRandomEmbeddingResponse(1, 1536), nil) 29 | vs.Insert(context.Background(), gollum.NewDocumentFromString(fmt.Sprintf("hey %d", i))) 30 | } 31 | assert.Equal(t, 10, len(vs.Documents)) 32 | searcher := hyde.NewVectorSearcher( 33 | vs, 34 | ) 35 | 36 | hyde := hyde.NewHyde(prompter, generator, encoder, searcher) 37 | t.Run("prompter", func(t *testing.T) { 38 | prompt := prompter.BuildPrompt(context.Background(), "What is your name?") 39 | assert.Equal(t, "Roleplay as a character. Write a short biographical answer to the question.\nQ: What is your name?\nA:", prompt) 40 | }) 41 | 42 | t.Run("generator", func(t *testing.T) { 43 | ctx := context.Background() 44 | k := 10 45 | completer.EXPECT().CreateChatCompletion(ctx, gomock.Any()).Return(testutil.GetRandomChatCompletionResponse(k), nil) 46 | res, err := generator.Generate(ctx, "What is your name?", k) 47 | assert.NoError(t, err) 48 | assert.Equal(t, 10, len(res)) 49 | }) 50 | 51 | t.Run("encoder", func(t *testing.T) { 52 | ctx := context.Background() 53 | embedder.EXPECT().CreateEmbeddings(ctx, gomock.Any()).Return(testutil.GetRandomEmbeddingResponse(1, 1536), nil) 54 | res, err := encoder.Encode(ctx, "What is your name?") 55 | assert.NoError(t, err) 56 | assert.Equal(t, 1536, len(res)) 57 | 58 | embedder.EXPECT().CreateEmbeddings(ctx, gomock.Any()).Return(testutil.GetRandomEmbeddingResponse(2, 1536), nil) 59 | res2, err := encoder.EncodeBatch(ctx, []string{"What is your name?", "What is your quest?"}) 60 | assert.NoError(t, err) 61 | assert.Equal(t, 2, len(res2)) 62 | assert.Equal(t, 1536, len(res2[0])) 63 | }) 64 | 65 | t.Run("e2e", func(t *testing.T) { 66 | ctx := context.Background() 67 | completer.EXPECT().CreateChatCompletion(ctx, gomock.Any()).Return(testutil.GetRandomChatCompletionResponse(3), nil) 68 | embedder.EXPECT().CreateEmbeddings(ctx, gomock.Any()).Return(testutil.GetRandomEmbeddingResponse(4, 1536), nil) 69 | res, err := hyde.SearchEndToEnd(ctx, "What is your name?", 3) 70 | assert.NoError(t, err) 71 | assert.Equal(t, 3, len(res)) 72 | }) 73 | } 74 | 75 | func BenchmarkHyde(b *testing.B) { 76 | prompter := hyde.NewZeroShotPrompter( 77 | "Roleplay as a character. 
Write a short biographical answer to the question.\nQ: %s\nA:", 78 | ) 79 | ctrl := gomock.NewController(b) 80 | embedder := mock_gollum.NewMockEmbedder(ctrl) 81 | completer := mock_gollum.NewMockChatCompleter(ctrl) 82 | generator := hyde.NewLLMGenerator(completer) 83 | encoder := hyde.NewLLMEncoder(embedder) 84 | vs := vectorstore.NewMemoryVectorStore(embedder) 85 | searcher := hyde.NewVectorSearcher( 86 | vs, 87 | ) 88 | hyde := hyde.NewHyde(prompter, generator, encoder, searcher) 89 | docNums := []int{10, 100, 1000, 10000, 100_000, 1_000_000} 90 | for _, docNum := range docNums { 91 | b.Run(fmt.Sprintf("docs=%v", docNum), func(b *testing.B) { 92 | for _ = range make([]int, docNum) { 93 | embedder.EXPECT().CreateEmbeddings(gomock.Any(), gomock.Any()).Return(testutil.GetRandomEmbeddingResponse(1, 1536), nil) 94 | vs.Insert(context.Background(), gollum.NewDocumentFromString("hey")) 95 | } 96 | b.ResetTimer() 97 | for i := 0; i < b.N; i++ { 98 | k := 8 99 | completer.EXPECT().CreateChatCompletion(context.Background(), gomock.Any()).Return(testutil.GetRandomChatCompletionResponse(k), nil) 100 | embedder.EXPECT().CreateEmbeddings(context.Background(), gomock.Any()).Return(testutil.GetRandomEmbeddingResponse(k+1, 1536), nil) 101 | _, err := hyde.SearchEndToEnd(context.Background(), "What is your name?", k) 102 | assert.NoError(b, err) 103 | } 104 | }) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /packages/llm/providers/anthropic/anthropic.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "context" 5 | "encoding/base64" 6 | "log/slog" 7 | "slices" 8 | 9 | "github.com/stillmatic/gollum/packages/llm" 10 | 11 | "github.com/liushuangls/go-anthropic/v2" 12 | "github.com/pkg/errors" 13 | ) 14 | 15 | type Provider struct { 16 | client *anthropic.Client 17 | cacheEnabled bool 18 | } 19 | 20 | func NewAnthropicProvider(apiKey string) *Provider { 21 | return &Provider{ 22 | client: anthropic.NewClient(apiKey), 23 | } 24 | } 25 | 26 | func NewAnthropicProviderWithCache(apiKey string) *Provider { 27 | client := anthropic.NewClient(apiKey, anthropic.WithBetaVersion(anthropic.BetaPromptCaching20240731)) 28 | return &Provider{ 29 | client: client, 30 | cacheEnabled: true, 31 | } 32 | } 33 | 34 | func reqToMessages(req llm.InferRequest) ([]anthropic.Message, []anthropic.MessageSystemPart, error) { 35 | msgs := make([]anthropic.Message, 0) 36 | systemMsgs := make([]anthropic.MessageSystemPart, 0) 37 | for _, m := range req.Messages { 38 | if m.Role == "system" { 39 | msgContent := anthropic.MessageSystemPart{ 40 | Type: "text", 41 | Text: m.Content, 42 | } 43 | if m.ShouldCache { 44 | msgContent.CacheControl = &anthropic.MessageCacheControl{ 45 | Type: anthropic.CacheControlTypeEphemeral, 46 | } 47 | } 48 | systemMsgs = append(systemMsgs, msgContent) 49 | continue 50 | } 51 | 52 | // only allow user and assistant roles 53 | // TODO: this should be a little cleaner... 
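// (for example, slices.Contains([]string{string(anthropic.RoleUser), string(anthropic.RoleAssistant)}, m.Role) expresses the same check more directly; the stdlib slices package is already imported)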
54 | if !(slices.Index([]string{string(anthropic.RoleUser), string(anthropic.RoleAssistant)}, m.Role) > -1) { 55 | return nil, nil, errors.New("invalid role") 56 | } 57 | content := make([]anthropic.MessageContent, 0) 58 | txtContent := anthropic.NewTextMessageContent(m.Content) 59 | // this will fail if the model is not configured to cache 60 | if m.ShouldCache { 61 | txtContent.SetCacheControl() 62 | } 63 | content = append(content, txtContent) 64 | if m.Image != nil && len(m.Image) > 0 { 65 | b64Image := base64.StdEncoding.EncodeToString(m.Image) 66 | // TODO: support other image types 67 | content = append(content, anthropic.NewImageMessageContent( 68 | anthropic.MessageContentSource{Type: "base64", MediaType: "image/png", Data: b64Image})) 69 | } 70 | newMsg := anthropic.Message{ 71 | Role: anthropic.ChatRole(m.Role), 72 | Content: content, 73 | } 74 | 75 | msgs = append(msgs, newMsg) 76 | } 77 | 78 | return msgs, systemMsgs, nil 79 | } 80 | 81 | func (p *Provider) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) { 82 | msgs, systemPrompt, err := reqToMessages(req) 83 | if err != nil { 84 | return "", errors.Wrap(err, "invalid messages") 85 | } 86 | msgsReq := anthropic.MessagesRequest{ 87 | Model: anthropic.Model(req.ModelConfig.ModelName), 88 | Messages: msgs, 89 | MaxTokens: req.MessageOptions.MaxTokens, 90 | Temperature: &req.MessageOptions.Temperature, 91 | } 92 | if systemPrompt != nil && len(systemPrompt) > 0 { 93 | msgsReq.MultiSystem = systemPrompt 94 | } 95 | res, err := p.client.CreateMessagesStream(ctx, anthropic.MessagesStreamRequest{ 96 | MessagesRequest: msgsReq, 97 | }) 98 | if err != nil { 99 | return "", errors.Wrap(err, "anthropic messages stream error") 100 | } 101 | 102 | return res.GetFirstContentText(), nil 103 | } 104 | 105 | func (p *Provider) GenerateResponseAsync(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) { 106 | outChan := make(chan llm.StreamDelta) 107 | go func() { 108 | defer close(outChan) 109 | msgs, systemPrompt, err := reqToMessages(req) 110 | if err != nil { 111 | slog.Error("invalid messages", "err", err) 112 | return 113 | } 114 | msgsReq := anthropic.MessagesRequest{ 115 | Model: anthropic.Model(req.ModelConfig.ModelName), 116 | Messages: msgs, 117 | MaxTokens: req.MessageOptions.MaxTokens, 118 | Temperature: &req.MessageOptions.Temperature, 119 | } 120 | if systemPrompt != nil { 121 | msgsReq.MultiSystem = systemPrompt 122 | } 123 | 124 | _, err = p.client.CreateMessagesStream(ctx, anthropic.MessagesStreamRequest{ 125 | MessagesRequest: msgsReq, 126 | OnContentBlockDelta: func(data anthropic.MessagesEventContentBlockDeltaData) { 127 | if data.Delta.Text == nil { 128 | outChan <- llm.StreamDelta{ 129 | EOF: true, 130 | } 131 | return 132 | } 133 | 134 | outChan <- llm.StreamDelta{ 135 | Text: *data.Delta.Text, 136 | } 137 | }, 138 | }) 139 | if err != nil { 140 | slog.Error("anthropic messages stream error", "err", err) 141 | return 142 | } 143 | }() 144 | 145 | return outChan, nil 146 | } 147 | -------------------------------------------------------------------------------- /packages/dispatch/dispatch_test.go: -------------------------------------------------------------------------------- 1 | package dispatch_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | "text/template" 8 | 9 | "github.com/sashabaranov/go-openai" 10 | mock_gollum "github.com/stillmatic/gollum/internal/mocks" 11 | "github.com/stillmatic/gollum/packages/dispatch" 12 | "github.com/stretchr/testify/assert" 
13 | "go.uber.org/mock/gomock" 14 | ) 15 | 16 | type testInput struct { 17 | Topic string `json:"topic" jsonschema:"required" jsonschema_description:"The topic of the conversation"` 18 | RandomWords []string `json:"random_words" jsonschema:"required" jsonschema_description:"Random words to prime the conversation"` 19 | } 20 | 21 | type templateInput struct { 22 | Topic string 23 | } 24 | 25 | type wordCountOutput struct { 26 | Count int `json:"count" jsonschema:"required" jsonschema_description:"The number of words in the sentence"` 27 | } 28 | 29 | func TestDummyDispatcher(t *testing.T) { 30 | d := dispatch.NewDummyDispatcher[testInput]() 31 | 32 | t.Run("prompt", func(t *testing.T) { 33 | output, err := d.Prompt(context.Background(), "Talk to me about Dinosaurs") 34 | 35 | assert.NoError(t, err) 36 | assert.Equal(t, testInput{}, output) 37 | }) 38 | t.Run("promptTemplate", func(t *testing.T) { 39 | te, err := template.New("").Parse("Talk to me about {{.Topic}}") 40 | assert.NoError(t, err) 41 | tempInp := templateInput{ 42 | Topic: "Dinosaurs", 43 | } 44 | 45 | output, err := d.PromptTemplate(context.Background(), te, tempInp) 46 | assert.NoError(t, err) 47 | assert.Equal(t, testInput{}, output) 48 | }) 49 | } 50 | 51 | func TestOpenAIDispatcher(t *testing.T) { 52 | ctrl := gomock.NewController(t) 53 | completer := mock_gollum.NewMockChatCompleter(ctrl) 54 | systemPrompt := "When prompted, use the tool." 55 | d := dispatch.NewOpenAIDispatcher[testInput]("random_conversation", "Given a topic, return random words", systemPrompt, completer, nil) 56 | 57 | ctx := context.Background() 58 | expected := testInput{ 59 | Topic: "dinosaurs", 60 | RandomWords: []string{"dinosaur", "fossil", "extinct"}, 61 | } 62 | inpStr := `{"topic": "dinosaurs", "random_words": ["dinosaur", "fossil", "extinct"]}` 63 | 64 | fi := openai.FunctionDefinition(dispatch.StructToJsonSchema("random_conversation", "Given a topic, return random words", testInput{})) 65 | ti := openai.Tool{Type: "function", Function: &fi} 66 | expectedRequest := openai.ChatCompletionRequest{ 67 | Model: openai.GPT3Dot5Turbo1106, 68 | Messages: []openai.ChatCompletionMessage{ 69 | { 70 | Role: openai.ChatMessageRoleUser, 71 | Content: "Tell me about dinosaurs", 72 | }, 73 | }, 74 | Tools: []openai.Tool{ti}, 75 | ToolChoice: openai.ToolChoice{ 76 | Type: "function", 77 | Function: openai.ToolFunction{ 78 | Name: "random_conversation", 79 | }}, 80 | MaxTokens: 512, 81 | Temperature: 0.0, 82 | } 83 | 84 | t.Run("prompt", func(t *testing.T) { 85 | queryStr := "Tell me about dinosaurs" 86 | completer.EXPECT().CreateChatCompletion(gomock.Any(), expectedRequest).Return(openai.ChatCompletionResponse{ 87 | Choices: []openai.ChatCompletionChoice{ 88 | { 89 | Message: openai.ChatCompletionMessage{ 90 | Role: openai.ChatMessageRoleSystem, 91 | Content: "Hello there!", 92 | ToolCalls: []openai.ToolCall{ 93 | { 94 | Type: "function", 95 | Function: openai.FunctionCall{ 96 | Name: "random_conversation", 97 | Arguments: inpStr, 98 | }, 99 | }}, 100 | }, 101 | }, 102 | }, 103 | }, nil) 104 | 105 | output, err := d.Prompt(ctx, queryStr) 106 | assert.NoError(t, err) 107 | 108 | assert.Equal(t, expected, output) 109 | }) 110 | t.Run("promptTemplate", func(t *testing.T) { 111 | completer.EXPECT().CreateChatCompletion(gomock.Any(), expectedRequest).Return(openai.ChatCompletionResponse{ 112 | Choices: []openai.ChatCompletionChoice{ 113 | { 114 | Message: openai.ChatCompletionMessage{ 115 | Role: openai.ChatMessageRoleSystem, 116 | Content: "Hello there!", 117 | 
ToolCalls: []openai.ToolCall{ 118 | { 119 | Type: "function", 120 | Function: openai.FunctionCall{ 121 | Name: "random_conversation", 122 | Arguments: inpStr, 123 | }, 124 | }}, 125 | }, 126 | }, 127 | }, 128 | }, nil) 129 | 130 | te, err := template.New("").Parse("Tell me about {{.Topic}}") 131 | assert.NoError(t, err) 132 | 133 | output, err := d.PromptTemplate(ctx, te, templateInput{ 134 | Topic: "dinosaurs", 135 | }) 136 | assert.NoError(t, err) 137 | assert.Equal(t, expected, output) 138 | }) 139 | } 140 | 141 | func TestDispatchIntegration(t *testing.T) { 142 | t.Skip("Skipping integration test") 143 | completer := openai.NewClient(os.Getenv("OPENAI_API_KEY")) 144 | systemPrompt := "When prompted, use the tool on the user's input." 145 | d := dispatch.NewOpenAIDispatcher[wordCountOutput]("wordCounter", "count the number of words in a sentence", systemPrompt, completer, nil) 146 | output, err := d.Prompt(context.Background(), "I like dinosaurs") 147 | assert.NoError(t, err) 148 | assert.Equal(t, 3, output.Count) 149 | } 150 | -------------------------------------------------------------------------------- /packages/llm/providers/openai/openai.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "encoding/base64" 6 | "github.com/stillmatic/gollum/packages/llm" 7 | "io" 8 | "log/slog" 9 | 10 | "github.com/pkg/errors" 11 | "github.com/sashabaranov/go-openai" 12 | ) 13 | 14 | type Provider struct { 15 | client *openai.Client 16 | } 17 | 18 | func NewOpenAIProvider(apiKey string) *Provider { 19 | return &Provider{ 20 | client: openai.NewClient(apiKey), 21 | } 22 | } 23 | 24 | func NewGenericProvider(apiKey string, baseURL string) *Provider { 25 | genericConfig := openai.DefaultConfig(apiKey) 26 | genericConfig.BaseURL = baseURL 27 | return &Provider{ 28 | client: openai.NewClientWithConfig(genericConfig), 29 | } 30 | } 31 | 32 | func NewTogetherProvider(apiKey string) *Provider { 33 | return NewGenericProvider(apiKey, "https://api.together.xyz/v1") 34 | } 35 | 36 | func NewGroqProvider(apiKey string) *Provider { 37 | return NewGenericProvider(apiKey, "https://api.groq.com/openai/v1/") 38 | } 39 | 40 | func NewHyperbolicProvider(apiKey string) *Provider { 41 | return NewGenericProvider(apiKey, "https://api.hyperbolic.xyz/v1") 42 | } 43 | 44 | func NewDeepseekProvider(apiKey string) *Provider { 45 | return NewGenericProvider(apiKey, "https://api.deepseek.com/v1") 46 | } 47 | 48 | func (p *Provider) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) { 49 | msgs := inferReqToOpenAIMessages(req.Messages) 50 | 51 | oaiReq := openai.ChatCompletionRequest{ 52 | Model: req.ModelConfig.ModelName, 53 | Messages: msgs, 54 | MaxTokens: req.MessageOptions.MaxTokens, 55 | Temperature: req.MessageOptions.Temperature, 56 | } 57 | 58 | res, err := p.client.CreateChatCompletion(ctx, oaiReq) 59 | if err != nil { 60 | slog.Error("error from openai", "err", err, "req", req.Messages, "model", req.ModelConfig.ModelName) 61 | return "", errors.Wrap(err, "openai chat completion error") 62 | } 63 | 64 | return res.Choices[0].Message.Content, nil 65 | } 66 | 67 | func inferReqToOpenAIMessages(req []llm.InferMessage) []openai.ChatCompletionMessage { 68 | msgs := make([]openai.ChatCompletionMessage, 0) 69 | 70 | for _, m := range req { 71 | msg := openai.ChatCompletionMessage{ 72 | Role: m.Role, 73 | Content: m.Content, 74 | } 75 | if m.Image != nil && len(m.Image) > 0 { 76 | b64Image := 
base64.StdEncoding.EncodeToString(m.Image) 77 | msg.MultiContent = []openai.ChatMessagePart{ 78 | { 79 | Type: openai.ChatMessagePartTypeImageURL, 80 | // TODO: support other image types 81 | ImageURL: &openai.ChatMessageImageURL{ 82 | URL: "data:image/png;base64," + b64Image, 83 | Detail: openai.ImageURLDetailAuto, 84 | }, 85 | }, 86 | { 87 | Type: openai.ChatMessagePartTypeText, 88 | Text: m.Content, 89 | }, 90 | } 91 | msg.Content = "" 92 | } 93 | msgs = append(msgs, msg) 94 | } 95 | return msgs 96 | }
97 | 98 | func (p *Provider) GenerateResponseAsync(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) { 99 | outChan := make(chan llm.StreamDelta) 100 | go func() { 101 | defer close(outChan) 102 | msgs := inferReqToOpenAIMessages(req.Messages) 103 | oaiReq := openai.ChatCompletionRequest{ 104 | Model: req.ModelConfig.ModelName, 105 | Messages: msgs, 106 | MaxTokens: req.MessageOptions.MaxTokens, 107 | Temperature: req.MessageOptions.Temperature, 108 | } 109 | 110 | stream, err := p.client.CreateChatCompletionStream(ctx, oaiReq) 111 | if err != nil { 112 | slog.Error("error from openai", "err", err, "req", req.Messages, "model", req.ModelConfig.ModelName) 113 | return 114 | } 115 | defer stream.Close() 116 |
117 | for { 118 | response, err := stream.Recv() 119 | if err == io.EOF { 120 | break 121 | } 122 | if err != nil { 123 | slog.Error("error receiving from openai stream", "err", err) 124 | return 125 | } 126 | if len(response.Choices) == 0 { 127 | continue 128 | } 129 | content := response.Choices[0].Delta.Content 130 | if content == "" { 131 | continue 132 | } 133 | select { 134 | case <-ctx.Done(): 135 | return 136 | case outChan <- llm.StreamDelta{Text: content}: 137 | } 138 | } 139 | // signal the end of the stream before the deferred close(outChan) runs 140 | outChan <- llm.StreamDelta{EOF: true} 141 | }() 142 | 143 | return outChan, nil 144 | }
145 | 146 | func (p *Provider) GenerateEmbedding(ctx context.Context, req llm.EmbedRequest) (*llm.EmbeddingResponse, error) { 147 | // TODO: this only supports openai models, not other providers using the same interface 148 | oaiReq := openai.EmbeddingRequest{ 149 | Input: req.Input, 150 | Model: openai.EmbeddingModel(req.ModelConfig.ModelName), 151 | Dimensions: req.Dimensions, 152 | } 153 | 154 | res, err := p.client.CreateEmbeddings(ctx, oaiReq) 155 | if err != nil { 156 | slog.Error("error from openai", "err", err, "req", req.Input, "model", req.ModelConfig.ModelName) 157 | return nil, errors.Wrap(err, "openai embedding error") 158 | } 159 | 160 | respVectors := make([]llm.Embedding, len(res.Data)) 161 | for i, v := range res.Data { 162 | respVectors[i] = llm.Embedding{ 163 | Values: v.Embedding, 164 | } 165 | } 166 | 167 | return &llm.EmbeddingResponse{ 168 | Data: respVectors, 169 | }, nil 170 | } 171 |
-------------------------------------------------------------------------------- /packages/vectorstore/vectorstore.md: -------------------------------------------------------------------------------- 1 | # Document Stores 2 | 3 | Gollum provides several implementations of document stores. Document stores solve the problem of having lots of documents and needing to find the ones most relevant to a particular query. 4 | 5 | Names are TBD. 6 | 7 | # docstore 8 | 9 | Docstore is a simple document store which provides no indexing. It simply has insert and retrieve; you must know the ID. 10 | 11 | Think of it as essentially a key-value store. In the future, we will probably extend this to have a KV interface and provide an implementation backed by Redis or DragonflyDB.
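Roughly, the intended usage looks like the following sketch. It assumes a NewMemoryDocStore constructor (not shown here); Document's ID and Content fields and the Insert/Retrieve signatures match docstore.go.

```
package main

import (
	"context"
	"fmt"

	"github.com/stillmatic/gollum"
	"github.com/stillmatic/gollum/packages/docstore"
)

func main() {
	ctx := context.Background()
	// assumed constructor; otherwise initialize the Documents map directly
	ds := docstore.NewMemoryDocStore()
	_ = ds.Insert(ctx, gollum.Document{ID: "1", Content: "test data"})
	doc, err := ds.Retrieve(ctx, "1") // retrieval requires knowing the ID
	fmt.Println(doc.Content, err)
}
```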
12 | 13 | # memory vector store 14 | 15 | This is a simple document store that takes an embedding model and embeds documents on insert. At retrieval time, it embeds the search query and does a simple KNN lookup. 16 | 17 | # xyz vector store 18 | 19 | I haven't gotten around to actually writing any of these implementations but it should be simple to imagine clients for Weaviate or Pinecone following the interface. I don't actually use them though :) 20 | 21 | 22 | 23 | # compressed store 24 | 25 | This is inspired by [link] - basically we can use gzip to compress the documents, then at query time, compute `enc(term) + enc(doc)` - in this case, `enc(doc)` is computed on insert. The idea is that the more similar your term and document are, the shorter the encoded representation is - because there is less entropy that needs to be included in the compressed representation. 26 | 27 | 28 | On an M1 Max 29 | 30 | ``` 31 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-10-1-10 25092 49245 ns/op 288 B/op 3 allocs/op 32 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-10-10-10 13317 107361 ns/op 2880 B/op 6 allocs/op 33 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-100-1-10 2582 657344 ns/op 300 B/op 3 allocs/op 34 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-100-10-10 1051 1286167 ns/op 2880 B/op 6 allocs/op 35 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-100-100-10 688 1879785 ns/op 25887 B/op 9 allocs/op 36 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-1000-1-10 202 7782683 ns/op 569 B/op 3 allocs/op 37 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-1000-10-10 92 11442531 ns/op 3461 B/op 6 allocs/op 38 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-1000-100-10 73 15614635 ns/op 25765 B/op 9 allocs/op 39 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-10000-1-10 32 52570010 ns/op 288 B/op 3 allocs/op 40 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-10000-10-10 15 95620514 ns/op 3161 B/op 6 allocs/op 41 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-10000-100-10 9 138641384 ns/op 28024 B/op 10 allocs/op 42 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-100000-1-10 2 523147917 ns/op 6624 B/op 5 allocs/op 43 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-100000-10-10 2 985437625 ns/op 9280 B/op 10 allocs/op 44 | BenchmarkCompressedVectorStore/ZstdVectorStore-Query-100000-100-10 1 1123811333 ns/op 33600 B/op 14 allocs/op 45 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-10-1-10 18867 107243 ns/op 288 B/op 3 allocs/op 46 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-10-10-10 8036 202150 ns/op 2880 B/op 6 allocs/op 47 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-100-1-10 1106 1541100 ns/op 288 B/op 3 allocs/op 48 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-100-10-10 511 2303053 ns/op 2880 B/op 6 allocs/op 49 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-100-100-10 327 3784640 ns/op 25664 B/op 9 allocs/op 50 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-1000-1-10 100 12470110 ns/op 288 B/op 3 allocs/op 51 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-1000-10-10 70 21186270 ns/op 2880 B/op 6 allocs/op 52 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-1000-100-10 58 26575966 ns/op 25664 B/op 9 allocs/op 53 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-10000-1-10 16 102998362 ns/op 288 B/op 3 allocs/op 54 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-10000-10-10 7 200548024 ns/op 2880 B/op 6 allocs/op 55 | 
BenchmarkCompressedVectorStore/GzipVectorStore-Query-10000-100-10 4 290161719 ns/op 25664 B/op 9 allocs/op 56 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-100000-1-10 2 1115429562 ns/op 288 B/op 3 allocs/op 57 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-100000-10-10 1 1561993375 ns/op 2880 B/op 6 allocs/op 58 | BenchmarkCompressedVectorStore/GzipVectorStore-Query-100000-100-10 1 1751069333 ns/op 25664 B/op 9 allocs/op 59 | PASS 60 | ``` 61 | 62 | I have done a fair amount of allocation chasing, I think there is a bit more work to do, but actually, it's pretty slow. 63 | 64 | I also think adding [Lempel-Ziv Jaccard Distance](https://arxiv.org/pdf/1708.03346.pdf) is quite promising. We would need to write it in Go or add Cgo bindings to the C version. -------------------------------------------------------------------------------- /packages/vectorstore/vectorstore_compressed.go: -------------------------------------------------------------------------------- 1 | package vectorstore 2 | 3 | import ( 4 | "bytes" 5 | stdgzip "compress/gzip" 6 | "context" 7 | "github.com/stillmatic/gollum/packages/syncpool" 8 | "io" 9 | "sync" 10 | 11 | gzip "github.com/klauspost/compress/gzip" 12 | "github.com/klauspost/compress/zstd" 13 | "github.com/stillmatic/gollum" 14 | ) 15 | 16 | // Compressor is a single method interface that returns a compressed representation of an object. 17 | type Compressor interface { 18 | Compress(src []byte) []byte 19 | } 20 | 21 | // GzipCompressor uses the klauspost/compress gzip compressor. 22 | // We generally suggest using this optimized implementation over the stdlib. 23 | type GzipCompressor struct { 24 | pool syncpool.Pool[*gzip.Writer] 25 | } 26 | 27 | // ZstdCompressor uses the klauspost/compress zstd compressor. 28 | type ZstdCompressor struct { 29 | pool syncpool.Pool[*zstd.Encoder] 30 | enc *zstd.Encoder 31 | } 32 | 33 | // StdGzipCompressor uses the std gzip compressor. 34 | type StdGzipCompressor struct { 35 | pool syncpool.Pool[*stdgzip.Writer] 36 | } 37 | 38 | type DummyCompressor struct { 39 | } 40 | 41 | func (g *GzipCompressor) Compress(src []byte) []byte { 42 | gz := g.pool.Get() 43 | var b bytes.Buffer 44 | defer g.pool.Put(gz) 45 | 46 | gz.Reset(&b) 47 | if _, err := gz.Write(src); err != nil { 48 | panic(err) 49 | } 50 | if err := gz.Close(); err != nil { 51 | panic(err) 52 | } 53 | return b.Bytes() 54 | } 55 | 56 | func (g *ZstdCompressor) Compress(src []byte) []byte { 57 | // return g.enc.EncodeAll(src, make([]byte, 0, len(src))) 58 | var b bytes.Buffer 59 | b.Reset() 60 | enc := g.enc 61 | // zstd := g.pool.Get() 62 | enc.Reset(&b) 63 | if _, err := enc.Write(src); err != nil { 64 | panic(err) 65 | } 66 | if err := enc.Flush(); err != nil { 67 | panic(err) 68 | } 69 | 70 | // g.pool.Put(zstd) 71 | return b.Bytes() 72 | } 73 | 74 | func (g *DummyCompressor) Compress(src []byte) []byte { 75 | return src 76 | } 77 | 78 | func (g *StdGzipCompressor) Compress(src []byte) []byte { 79 | var b bytes.Buffer 80 | gz := g.pool.Get() 81 | defer g.pool.Put(gz) 82 | gz.Reset(&b) 83 | 84 | if _, err := gz.Write(src); err != nil { 85 | panic(err) 86 | } 87 | if err := gz.Close(); err != nil { 88 | panic(err) 89 | } 90 | return b.Bytes() 91 | } 92 | 93 | type CompressedDocument struct { 94 | *gollum.Document 95 | Encoded []byte 96 | Unencoded []byte 97 | } 98 | 99 | type CompressedVectorStore struct { 100 | Data []CompressedDocument 101 | Compressor Compressor 102 | } 103 | 104 | // Insert compresses the document and inserts it into the store. 
105 | // An alternative implementation would ONLY store the compressed representation and decompress as necessary. 106 | func (ts *CompressedVectorStore) Insert(ctx context.Context, d gollum.Document) error { 107 | bb := bufPool.Get().(*bytes.Buffer) 108 | defer bufPool.Put(bb) 109 | bb.Reset() 110 | bb.WriteString(d.Content) 111 | docBytes := bb.Bytes() 112 | encoded := ts.Compressor.Compress(docBytes) 113 | ts.Data = append(ts.Data, CompressedDocument{Document: &d, Encoded: encoded, Unencoded: docBytes}) 114 | return nil 115 | } 116 | 117 | func minMax(val1, val2 float64) (float64, float64) { 118 | if val1 < val2 { 119 | return val1, val2 120 | } 121 | return val2, val1 122 | } 123 | 124 | var bufPool = sync.Pool{ 125 | New: func() any { 126 | return new(bytes.Buffer) 127 | }, 128 | } 129 | 130 | var spaceBytes = []byte(" ") 131 | 132 | func (cvs *CompressedVectorStore) Query(ctx context.Context, qb QueryRequest) ([]*gollum.Document, error) { 133 | bb := bufPool.Get().(*bytes.Buffer) 134 | defer bufPool.Put(bb) 135 | bb.Reset() 136 | queryBytes := make([]byte, len(qb.Query)) 137 | copy(queryBytes, qb.Query) 138 | searchTermEncoded := cvs.Compressor.Compress(queryBytes) 139 | 140 | k := qb.K 141 | if k > len(cvs.Data) { 142 | k = len(cvs.Data) 143 | } 144 | h := make(Heap, 0, k+1) 145 | h.Init(k + 1) 146 | 147 | for _, doc := range cvs.Data { 148 | Cx1 := float64(len(searchTermEncoded)) 149 | Cx2 := float64(len(doc.Encoded)) 150 | bb.Write(queryBytes) 151 | bb.Write(spaceBytes) 152 | bb.Write(doc.Unencoded) 153 | x1x2 := cvs.Compressor.Compress(bb.Bytes()) 154 | Cx1x2 := float64(len(x1x2)) 155 | min, max := minMax(Cx1, Cx2) 156 | ncd := (Cx1x2 - min) / (max) 157 | // ncd := 0.5 158 | 159 | node := NodeSimilarity{ 160 | Document: doc.Document, 161 | Similarity: float32(ncd), 162 | } 163 | 164 | h.Push(node) 165 | if h.Len() > k { 166 | h.Pop() 167 | } 168 | bb.Reset() 169 | } 170 | 171 | docs := make([]*gollum.Document, k) 172 | for i := range docs { 173 | docs[k-i-1] = h.Pop().Document 174 | } 175 | 176 | return docs, nil 177 | } 178 | 179 | func (cvs *CompressedVectorStore) RetrieveAll(ctx context.Context) ([]gollum.Document, error) { 180 | docs := make([]gollum.Document, len(cvs.Data)) 181 | for i, doc := range cvs.Data { 182 | docs[i] = *doc.Document 183 | } 184 | return docs, nil 185 | } 186 | 187 | func NewStdGzipVectorStore() *CompressedVectorStore { 188 | w := io.Discard 189 | return &CompressedVectorStore{ 190 | Compressor: &StdGzipCompressor{ 191 | pool: syncpool.New[*stdgzip.Writer](func() *stdgzip.Writer { 192 | return stdgzip.NewWriter(w) 193 | }), 194 | }, 195 | } 196 | } 197 | 198 | func NewGzipVectorStore() *CompressedVectorStore { 199 | w := io.Discard 200 | return &CompressedVectorStore{ 201 | Compressor: &GzipCompressor{ 202 | pool: syncpool.New[*gzip.Writer](func() *gzip.Writer { 203 | return gzip.NewWriter(w) 204 | }), 205 | }, 206 | } 207 | } 208 | 209 | func NewZstdVectorStore() *CompressedVectorStore { 210 | w := io.Discard 211 | enc, err := zstd.NewWriter(w, zstd.WithEncoderCRC(false)) 212 | if err != nil { 213 | panic(err) 214 | } 215 | return &CompressedVectorStore{ 216 | Compressor: &ZstdCompressor{ 217 | enc: enc, 218 | pool: syncpool.New[*zstd.Encoder](func() *zstd.Encoder { 219 | enc, err := zstd.NewWriter(w, zstd.WithEncoderCRC(false)) 220 | if err != nil { 221 | panic(err) 222 | } 223 | return enc 224 | }), 225 | }, 226 | } 227 | } 228 | 229 | func NewDummyVectorStore() *CompressedVectorStore { 230 | return &CompressedVectorStore{ 231 | Compressor: 
&DummyCompressor{}, 232 | } 233 | } 234 | -------------------------------------------------------------------------------- /internal/mocks/llm.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: llm.go 3 | 4 | // Package mock_gollum is a generated GoMock package. 5 | package mock_gollum 6 | 7 | import ( 8 | context "context" 9 | reflect "reflect" 10 | 11 | openai "github.com/sashabaranov/go-openai" 12 | gomock "go.uber.org/mock/gomock" 13 | ) 14 | 15 | // MockCompleter is a mock of Completer interface. 16 | type MockCompleter struct { 17 | ctrl *gomock.Controller 18 | recorder *MockCompleterMockRecorder 19 | } 20 | 21 | // MockCompleterMockRecorder is the mock recorder for MockCompleter. 22 | type MockCompleterMockRecorder struct { 23 | mock *MockCompleter 24 | } 25 | 26 | // NewMockCompleter creates a new mock instance. 27 | func NewMockCompleter(ctrl *gomock.Controller) *MockCompleter { 28 | mock := &MockCompleter{ctrl: ctrl} 29 | mock.recorder = &MockCompleterMockRecorder{mock} 30 | return mock 31 | } 32 | 33 | // EXPECT returns an object that allows the caller to indicate expected use. 34 | func (m *MockCompleter) EXPECT() *MockCompleterMockRecorder { 35 | return m.recorder 36 | } 37 | 38 | // CreateCompletion mocks base method. 39 | func (m *MockCompleter) CreateCompletion(arg0 context.Context, arg1 openai.CompletionRequest) (openai.CompletionResponse, error) { 40 | m.ctrl.T.Helper() 41 | ret := m.ctrl.Call(m, "CreateCompletion", arg0, arg1) 42 | ret0, _ := ret[0].(openai.CompletionResponse) 43 | ret1, _ := ret[1].(error) 44 | return ret0, ret1 45 | } 46 | 47 | // CreateCompletion indicates an expected call of CreateCompletion. 48 | func (mr *MockCompleterMockRecorder) CreateCompletion(arg0, arg1 interface{}) *gomock.Call { 49 | mr.mock.ctrl.T.Helper() 50 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCompletion", reflect.TypeOf((*MockCompleter)(nil).CreateCompletion), arg0, arg1) 51 | } 52 | 53 | // MockChatCompleter is a mock of ChatCompleter interface. 54 | type MockChatCompleter struct { 55 | ctrl *gomock.Controller 56 | recorder *MockChatCompleterMockRecorder 57 | } 58 | 59 | // MockChatCompleterMockRecorder is the mock recorder for MockChatCompleter. 60 | type MockChatCompleterMockRecorder struct { 61 | mock *MockChatCompleter 62 | } 63 | 64 | // NewMockChatCompleter creates a new mock instance. 65 | func NewMockChatCompleter(ctrl *gomock.Controller) *MockChatCompleter { 66 | mock := &MockChatCompleter{ctrl: ctrl} 67 | mock.recorder = &MockChatCompleterMockRecorder{mock} 68 | return mock 69 | } 70 | 71 | // EXPECT returns an object that allows the caller to indicate expected use. 72 | func (m *MockChatCompleter) EXPECT() *MockChatCompleterMockRecorder { 73 | return m.recorder 74 | } 75 | 76 | // CreateChatCompletion mocks base method. 77 | func (m *MockChatCompleter) CreateChatCompletion(arg0 context.Context, arg1 openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { 78 | m.ctrl.T.Helper() 79 | ret := m.ctrl.Call(m, "CreateChatCompletion", arg0, arg1) 80 | ret0, _ := ret[0].(openai.ChatCompletionResponse) 81 | ret1, _ := ret[1].(error) 82 | return ret0, ret1 83 | } 84 | 85 | // CreateChatCompletion indicates an expected call of CreateChatCompletion. 
86 | func (mr *MockChatCompleterMockRecorder) CreateChatCompletion(arg0, arg1 interface{}) *gomock.Call { 87 | mr.mock.ctrl.T.Helper() 88 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateChatCompletion", reflect.TypeOf((*MockChatCompleter)(nil).CreateChatCompletion), arg0, arg1) 89 | } 90 | 91 | // MockEmbedder is a mock of Embedder interface. 92 | type MockEmbedder struct { 93 | ctrl *gomock.Controller 94 | recorder *MockEmbedderMockRecorder 95 | } 96 | 97 | // MockEmbedderMockRecorder is the mock recorder for MockEmbedder. 98 | type MockEmbedderMockRecorder struct { 99 | mock *MockEmbedder 100 | } 101 | 102 | // NewMockEmbedder creates a new mock instance. 103 | func NewMockEmbedder(ctrl *gomock.Controller) *MockEmbedder { 104 | mock := &MockEmbedder{ctrl: ctrl} 105 | mock.recorder = &MockEmbedderMockRecorder{mock} 106 | return mock 107 | } 108 | 109 | // EXPECT returns an object that allows the caller to indicate expected use. 110 | func (m *MockEmbedder) EXPECT() *MockEmbedderMockRecorder { 111 | return m.recorder 112 | } 113 | 114 | // CreateEmbeddings mocks base method. 115 | func (m *MockEmbedder) CreateEmbeddings(arg0 context.Context, arg1 openai.EmbeddingRequest) (openai.EmbeddingResponse, error) { 116 | m.ctrl.T.Helper() 117 | ret := m.ctrl.Call(m, "CreateEmbeddings", arg0, arg1) 118 | ret0, _ := ret[0].(openai.EmbeddingResponse) 119 | ret1, _ := ret[1].(error) 120 | return ret0, ret1 121 | } 122 | 123 | // CreateEmbeddings indicates an expected call of CreateEmbeddings. 124 | func (mr *MockEmbedderMockRecorder) CreateEmbeddings(arg0, arg1 interface{}) *gomock.Call { 125 | mr.mock.ctrl.T.Helper() 126 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEmbeddings", reflect.TypeOf((*MockEmbedder)(nil).CreateEmbeddings), arg0, arg1) 127 | } 128 | 129 | // MockModerator is a mock of Moderator interface. 130 | type MockModerator struct { 131 | ctrl *gomock.Controller 132 | recorder *MockModeratorMockRecorder 133 | } 134 | 135 | // MockModeratorMockRecorder is the mock recorder for MockModerator. 136 | type MockModeratorMockRecorder struct { 137 | mock *MockModerator 138 | } 139 | 140 | // NewMockModerator creates a new mock instance. 141 | func NewMockModerator(ctrl *gomock.Controller) *MockModerator { 142 | mock := &MockModerator{ctrl: ctrl} 143 | mock.recorder = &MockModeratorMockRecorder{mock} 144 | return mock 145 | } 146 | 147 | // EXPECT returns an object that allows the caller to indicate expected use. 148 | func (m *MockModerator) EXPECT() *MockModeratorMockRecorder { 149 | return m.recorder 150 | } 151 | 152 | // Moderations mocks base method. 153 | func (m *MockModerator) Moderations(arg0 context.Context, arg1 openai.ModerationRequest) (openai.ModerationResponse, error) { 154 | m.ctrl.T.Helper() 155 | ret := m.ctrl.Call(m, "Moderations", arg0, arg1) 156 | ret0, _ := ret[0].(openai.ModerationResponse) 157 | ret1, _ := ret[1].(error) 158 | return ret0, ret1 159 | } 160 | 161 | // Moderations indicates an expected call of Moderations. 
162 | func (mr *MockModeratorMockRecorder) Moderations(arg0, arg1 interface{}) *gomock.Call { 163 | mr.mock.ctrl.T.Helper() 164 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Moderations", reflect.TypeOf((*MockModerator)(nil).Moderations), arg0, arg1) 165 | } 166 | -------------------------------------------------------------------------------- /math_test.go: -------------------------------------------------------------------------------- 1 | package gollum_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | math "github.com/chewxy/math32" 8 | "github.com/stillmatic/gollum/internal/testutil" 9 | "github.com/viterin/vek/vek32" 10 | ) 11 | 12 | func weaviateCosineSimilarity(a []float32, b []float32) (float32, error) { 13 | if len(a) != len(b) { 14 | return 0, fmt.Errorf("vectors have different dimensions") 15 | } 16 | 17 | var ( 18 | sumProduct float32 19 | sumASquare float32 20 | sumBSquare float32 21 | ) 22 | 23 | for i := range a { 24 | sumProduct += (a[i] * b[i]) 25 | sumASquare += (a[i] * a[i]) 26 | sumBSquare += (b[i] * b[i]) 27 | } 28 | 29 | return sumProduct / (math.Sqrt(sumASquare) * math.Sqrt(sumBSquare)), nil 30 | } 31 | 32 | func BenchmarkCosSim(b *testing.B) { 33 | ns := []int{256, 512, 768, 1024} 34 | b.Run("weaviate", func(b *testing.B) { 35 | for _, n := range ns { 36 | A := testutil.GetRandomEmbedding(n) 37 | B := testutil.GetRandomEmbedding(n) 38 | b.ResetTimer() 39 | b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { 40 | for i := 0; i < b.N; i++ { 41 | f, err := weaviateCosineSimilarity(A, B) 42 | if err != nil { 43 | panic(err) 44 | } 45 | _ = f 46 | } 47 | }) 48 | } 49 | }) 50 | b.Run("vek", func(b *testing.B) { 51 | for _, n := range ns { 52 | A := testutil.GetRandomEmbedding(n) 53 | B := testutil.GetRandomEmbedding(n) 54 | b.ResetTimer() 55 | b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { 56 | for i := 0; i < b.N; i++ { 57 | f := vek32.CosineSimilarity(A, B) 58 | _ = f 59 | } 60 | }) 61 | } 62 | }) 63 | } 64 | 65 | // func BenchmarkGonum32(b *testing.B) { 66 | // vs := []int{256, 512, 768, 1024} 67 | // for _, n := range vs { 68 | // A := getRandomEmbedding(n) 69 | // B := getRandomEmbedding(n) 70 | // b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { 71 | // for i := 0; i < b.N; i++ { 72 | // f, err := gallant.GonumSim(A, B) 73 | // if err != nil { 74 | // panic(err) 75 | // } 76 | // _ = f 77 | // } 78 | // }) 79 | // } 80 | // } 81 | 82 | // func BenchmarkGonum64(b *testing.B) { 83 | // vs := []int{256, 512, 768, 1024} 84 | // for _, n := range vs { 85 | // A := getRandomEmbedding(n) 86 | // B := getRandomEmbedding(n) 87 | // a_ := make([]float64, len(A)) 88 | // b_ := make([]float64, len(B)) 89 | // for i := range A { 90 | // a_[i] = float64(A[i]) 91 | // b_[i] = float64(B[i]) 92 | // } 93 | // b.Run("Naive", func(b *testing.B) { 94 | // for i := 0; i < b.N; i++ { 95 | // f, err := gallant.GonumSim64(a_, b_) 96 | // if err != nil { 97 | // panic(err) 98 | // } 99 | // _ = f 100 | // } 101 | // }) 102 | // am := mat.NewDense(1, len(a_), a_) 103 | // bm := mat.NewDense(1, len(b_), b_) 104 | // b.Run("prealloc", func(b *testing.B) { 105 | // for i := 0; i < b.N; i++ { 106 | // var dot mat.Dense 107 | // dot.Mul(am, bm.T()) 108 | // _ = dot.At(0, 0) 109 | // } 110 | // }) 111 | // } 112 | // } 113 | 114 | // goos: linux 115 | // goarch: amd64 116 | // cpu: AMD Ryzen 9 7950X 16-Core Processor 117 | // BenchmarkWeaviateCosSim/256-32 8028375 154.8 ns/op 0 B/op 0 allocs/op 118 | // BenchmarkWeaviateCosSim/512-32 3958342 300.1 ns/op 0 B/op 0 allocs/op 119 | // 
BenchmarkWeaviateCosSim/768-32 2677993 456.1 ns/op 0 B/op 0 allocs/op 120 | // BenchmarkWeaviateCosSim/1024-32 2002258 601.6 ns/op 0 B/op 0 allocs/op 121 | // BenchmarkVek32CosSim/256-32 81166414 15.43 ns/op 0 B/op 0 allocs/op 122 | // BenchmarkVek32CosSim/512-32 46376474 26.72 ns/op 0 B/op 0 allocs/op 123 | // BenchmarkVek32CosSim/768-32 30476739 39.65 ns/op 0 B/op 0 allocs/op 124 | // BenchmarkVek32CosSim/1024-32 22698370 51.63 ns/op 0 B/op 0 allocs/op 125 | // BenchmarkGonum32/256-32 1664224 738.9 ns/op 4312 B/op 7 allocs/op 126 | // BenchmarkGonum32/512-32 914896 1212 ns/op 8408 B/op 7 allocs/op 127 | // BenchmarkGonum32/768-32 718142 1634 ns/op 12504 B/op 7 allocs/op 128 | // BenchmarkGonum32/1024-32 573609 2885 ns/op 16600 B/op 7 allocs/op 129 | // BenchmarkGonum64/Naive-32 6329708 189.8 ns/op 216 B/op 5 allocs/op 130 | // BenchmarkGonum64/prealloc-32 9126764 144.8 ns/op 88 B/op 3 allocs/op 131 | // BenchmarkGonum64/Naive#01-32 5176160 222.0 ns/op 216 B/op 5 allocs/op 132 | // BenchmarkGonum64/prealloc#01-32 6484794 185.2 ns/op 88 B/op 3 allocs/op 133 | // BenchmarkGonum64/Naive#02-32 4569102 266.8 ns/op 216 B/op 5 allocs/op 134 | // BenchmarkGonum64/prealloc#02-32 5566400 225.0 ns/op 88 B/op 3 allocs/op 135 | // BenchmarkGonum64/Naive#03-32 3910498 300.0 ns/op 216 B/op 5 allocs/op 136 | // BenchmarkGonum64/prealloc#03-32 4585336 265.9 ns/op 88 B/op 3 allocs/op 137 | 138 | // goos: darwin 139 | // goarch: arm64 140 | // pkg: github.com/stillmatic/gollum 141 | // BenchmarkCosSim/weaviate/256-10 4330524 274.5 ns/op 0 B/op 0 allocs/op 142 | // BenchmarkCosSim/weaviate/512-10 1995426 605.6 ns/op 0 B/op 0 allocs/op 143 | // BenchmarkCosSim/weaviate/768-10 1312820 917.6 ns/op 0 B/op 0 allocs/op 144 | // BenchmarkCosSim/weaviate/1024-10 973432 1232 ns/op 0 B/op 0 allocs/op 145 | // BenchmarkCosSim/vek/256-10 4335747 272.0 ns/op 0 B/op 0 allocs/op 146 | // BenchmarkCosSim/vek/512-10 2027366 596.2 ns/op 0 B/op 0 allocs/op 147 | // BenchmarkCosSim/vek/768-10 1310983 925.2 ns/op 0 B/op 0 allocs/op 148 | // BenchmarkCosSim/vek/1024-10 969460 1233 ns/op 0 B/op 0 allocs/op 149 | // PASS 150 | 151 | // General high level takeaways: 152 | // - vek32 is best option if SIMD is available 153 | // - gonum64 is fast and scales better than weaviate but requires allocs and using f64 154 | // - gonum64 probably has SIMD intrinsics? 
155 | // - weaviate implementation is 'fine', appears to have linear scaling with vector size 156 | // - for something in the hotpath I think it's worth it to use vek32 when possible 157 | // - on mac, vek32 and weaviate are more or less identical 158 | -------------------------------------------------------------------------------- /packages/hyde/hyde.go: -------------------------------------------------------------------------------- 1 | package hyde 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/stillmatic/gollum/packages/vectorstore" 7 | 8 | "github.com/pkg/errors" 9 | "github.com/sashabaranov/go-openai" 10 | "github.com/stillmatic/gollum" 11 | "github.com/viterin/vek/vek32" 12 | ) 13 | 14 | type Prompter interface { 15 | BuildPrompt(context.Context, string) string 16 | } 17 | 18 | type Generator interface { 19 | Generate(ctx context.Context, input string, n int) ([]string, error) 20 | } 21 | 22 | type Encoder interface { 23 | Encode(context.Context, string) ([]float32, error) 24 | EncodeBatch(context.Context, []string) ([][]float32, error) 25 | } 26 | 27 | type Searcher interface { 28 | Search(context.Context, []float32, int) ([]*gollum.Document, error) 29 | } 30 | 31 | type Hyde struct { 32 | prompter Prompter 33 | generator Generator 34 | encoder Encoder 35 | searcher Searcher 36 | } 37 | 38 | type ZeroShotPrompter struct { 39 | template string 40 | } 41 | 42 | func NewZeroShotPrompter(template string) *ZeroShotPrompter { 43 | // something like 44 | // Roleplay as a character. Write a short biographical answer to the question. 45 | // Q: %s 46 | // A: 47 | return &ZeroShotPrompter{ 48 | template: template, 49 | } 50 | } 51 | 52 | func (z *ZeroShotPrompter) BuildPrompt(ctx context.Context, prompt string) string { 53 | // fill in template values 54 | return fmt.Sprintf(z.template, prompt) 55 | } 56 | 57 | type LLMGenerator struct { 58 | Model gollum.ChatCompleter 59 | } 60 | 61 | func NewLLMGenerator(model gollum.ChatCompleter) *LLMGenerator { 62 | return &LLMGenerator{ 63 | Model: model, 64 | } 65 | } 66 | 67 | func (l *LLMGenerator) Generate(ctx context.Context, prompt string, n int) ([]string, error) { 68 | createChatCompletionReq := openai.ChatCompletionRequest{ 69 | Messages: []openai.ChatCompletionMessage{ 70 | { 71 | Role: openai.ChatMessageRoleUser, 72 | Content: prompt, 73 | }, 74 | }, 75 | Model: openai.GPT3Dot5Turbo, 76 | N: n, 77 | // hyperparams from https://github.com/texttron/hyde/blob/74101c5157e04f7b57559e7da8ef4a4e5b6da82b/src/hyde/generator.py#LL15C121-L15C121 78 | Temperature: 0.9, 79 | MaxTokens: 512, 80 | TopP: 1, 81 | FrequencyPenalty: 0, 82 | PresencePenalty: 0, 83 | Stop: []string{"\n\n\n"}, 84 | } 85 | resp, err := l.Model.CreateChatCompletion(ctx, createChatCompletionReq) 86 | if err != nil { 87 | return make([]string, 0), err 88 | } 89 | replies := make([]string, n) 90 | for i, choice := range resp.Choices { 91 | replies[i] = choice.Message.Content 92 | } 93 | return replies, nil 94 | } 95 | 96 | type LLMEncoder struct { 97 | Model gollum.Embedder 98 | } 99 | 100 | func NewLLMEncoder(model gollum.Embedder) *LLMEncoder { 101 | return &LLMEncoder{ 102 | Model: model, 103 | } 104 | } 105 | 106 | func (l *LLMEncoder) Encode(ctx context.Context, query string) ([]float32, error) { 107 | createEmbeddingReq := openai.EmbeddingRequest{ 108 | Input: []string{query}, 109 | // TODO: allow customization 110 | Model: openai.AdaEmbeddingV2, 111 | } 112 | resp, err := l.Model.CreateEmbeddings(ctx, createEmbeddingReq) 113 | if err != nil { 114 | return 
make([]float32, 0), err 115 | } 116 | return resp.Data[0].Embedding, nil 117 | } 118 | 119 | func (l *LLMEncoder) EncodeBatch(ctx context.Context, docs []string) ([][]float32, error) { 120 | createEmbeddingReq := openai.EmbeddingRequest{ 121 | Input: docs, 122 | // TODO: allow customization 123 | Model: openai.AdaEmbeddingV2, 124 | } 125 | resp, err := l.Model.CreateEmbeddings(ctx, createEmbeddingReq) 126 | if err != nil { 127 | return make([][]float32, 0), err 128 | } 129 | embeddings := make([][]float32, len(resp.Data)) 130 | for i, data := range resp.Data { 131 | embeddings[i] = data.Embedding 132 | } 133 | return embeddings, nil 134 | } 135 | 136 | type VectorSearcher struct { 137 | vs vectorstore.VectorStore 138 | } 139 | 140 | func NewVectorSearcher(vs vectorstore.VectorStore) *VectorSearcher { 141 | return &VectorSearcher{ 142 | vs: vs, 143 | } 144 | } 145 | 146 | func (v *VectorSearcher) Search(ctx context.Context, query []float32, n int) ([]*gollum.Document, error) { 147 | qb := vectorstore.QueryRequest{ 148 | EmbeddingFloats: query, 149 | K: n, 150 | } 151 | return v.vs.Query(ctx, qb) 152 | } 153 | 154 | func NewHyde(prompter Prompter, generator Generator, encoder Encoder, searcher Searcher) *Hyde { 155 | return &Hyde{ 156 | prompter: prompter, 157 | generator: generator, 158 | encoder: encoder, 159 | searcher: searcher, 160 | } 161 | } 162 | 163 | // Prompt builds a prompt from the given string. 164 | func (h *Hyde) Prompt(ctx context.Context, prompt string) string { 165 | return h.prompter.BuildPrompt(ctx, prompt) 166 | } 167 | 168 | // Generate generates n hypothesis documents. 169 | func (h *Hyde) Generate(ctx context.Context, prompt string, n int) ([]string, error) { 170 | searchPrompt := h.prompter.BuildPrompt(ctx, prompt) 171 | return h.generator.Generate(ctx, searchPrompt, n) 172 | } 173 | 174 | // Encode encodes a query and a list of documents into a single embedding. 175 | func (h *Hyde) Encode(ctx context.Context, query string, docs []string) ([]float32, error) { 176 | docs = append(docs, query) 177 | embeddings, err := h.encoder.EncodeBatch(ctx, docs) 178 | if err != nil { 179 | return make([]float32, 0), errors.Wrap(err, "error encoding batch") 180 | } 181 | embedDim := len(embeddings[0]) 182 | numEmbeddings := float32(len(embeddings)) 183 | // mean pooling of response embeddings 184 | avgEmbedding := make([]float32, embedDim) 185 | for _, embedding := range embeddings { 186 | vek32.Add_Inplace(avgEmbedding, embedding) 187 | } 188 | // unclear if this is faster or slower than naive approach. a little cleaner code though. 
189 | vek32.DivNumber_Inplace(avgEmbedding, numEmbeddings) 190 | return avgEmbedding, nil 191 | } 192 | 193 | func (h *Hyde) Search(ctx context.Context, hydeVector []float32, k int) ([]*gollum.Document, error) { 194 | return h.searcher.Search(ctx, hydeVector, k) 195 | } 196 | 197 | func (h *Hyde) SearchEndToEnd(ctx context.Context, query string, k int) ([]*gollum.Document, error) { 198 | // generate n hypothesis documents 199 | docs, err := h.Generate(ctx, query, k) 200 | if err != nil { 201 | return nil, errors.Wrap(err, "error generating hypothetical documents") 202 | } 203 | // encode query and hypothesis documents 204 | hydeVector, err := h.Encode(ctx, query, docs) 205 | if err != nil { 206 | return make([]*gollum.Document, 0), errors.Wrap(err, "error encoding hypothetical documents") 207 | } 208 | // search for the most similar documents 209 | return h.Search(ctx, hydeVector, k) 210 | } 211 | -------------------------------------------------------------------------------- /packages/vectorstore/vectorstore_test.go: -------------------------------------------------------------------------------- 1 | package vectorstore_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | vectorstore2 "github.com/stillmatic/gollum/packages/vectorstore" 7 | "math/rand" 8 | "testing" 9 | 10 | "github.com/sashabaranov/go-openai" 11 | "github.com/stillmatic/gollum" 12 | mock_gollum "github.com/stillmatic/gollum/internal/mocks" 13 | "github.com/stretchr/testify/assert" 14 | "go.uber.org/mock/gomock" 15 | "gocloud.dev/blob/fileblob" 16 | ) 17 | 18 | func getRandomEmbedding(n int) []float32 { 19 | vec := make([]float32, n) 20 | for i := range vec { 21 | vec[i] = rand.Float32() 22 | } 23 | return vec 24 | } 25 | 26 | // setup with godotenv load 27 | func initialize(tb testing.TB) (*mock_gollum.MockEmbedder, *vectorstore2.MemoryVectorStore) { 28 | tb.Helper() 29 | 30 | ctrl := gomock.NewController(tb) 31 | oai := mock_gollum.NewMockEmbedder(ctrl) 32 | ctx := context.Background() 33 | bucket, err := fileblob.OpenBucket("testdata", nil) 34 | assert.NoError(tb, err) 35 | mvs, err := vectorstore2.NewMemoryVectorStoreFromDisk(ctx, bucket, "simple_store.json", oai) 36 | if err != nil { 37 | fmt.Println(err) 38 | mvs = vectorstore2.NewMemoryVectorStore(oai) 39 | testStrs := []string{"Apple", "Orange", "Basketball"} 40 | for i, s := range testStrs { 41 | mv := gollum.NewDocumentFromString(s) 42 | expectedReq := openai.EmbeddingRequest{ 43 | Input: []string{s}, 44 | Model: openai.AdaEmbeddingV2, 45 | } 46 | val := float64(i) + 0.1 47 | expectedResp := openai.EmbeddingResponse{ 48 | Data: []openai.Embedding{{Embedding: []float32{float32(0.1), float32(val), float32(val)}}}, 49 | } 50 | oai.EXPECT().CreateEmbeddings(ctx, expectedReq).Return(expectedResp, nil) 51 | err := mvs.Insert(ctx, mv) 52 | assert.NoError(tb, err) 53 | } 54 | err := mvs.Persist(ctx, bucket, "simple_store.json") 55 | assert.NoError(tb, err) 56 | } 57 | return oai, mvs 58 | } 59 | 60 | // TestRetrieval tests inserting embeddings and retrieving them 61 | func TestMemoryVectorStore(t *testing.T) { 62 | mockllm, mvs := initialize(t) 63 | ctx := context.Background() 64 | t.Run("LoadFromDisk", func(t *testing.T) { 65 | t.Log(mvs.Documents) 66 | assert.Equal(t, 3, len(mvs.Documents)) 67 | // check that an ID is in the map 68 | testStrs := []string{"Apple", "Orange", "Basketball"} 69 | // test that all the strings are in the documents 70 | for _, s := range mvs.Documents { 71 | found := false 72 | for _, t := range testStrs { 73 | if s.Content == t { 74 | found = 
true 75 | } 76 | } 77 | assert.True(t, found) 78 | } 79 | }) 80 | 81 | // should return apple and orange first. 82 | t.Run("QueryWithQuery", func(t *testing.T) { 83 | k := 2 84 | qb := vectorstore2.QueryRequest{ 85 | Query: "favorite fruit?", 86 | K: k, 87 | } 88 | expectedCreateReq := openai.EmbeddingRequest{ 89 | Input: []string{"favorite fruit?"}, 90 | Model: openai.AdaEmbeddingV2, 91 | } 92 | expectedCreateResp := openai.EmbeddingResponse{ 93 | Data: []openai.Embedding{ 94 | { 95 | Embedding: []float32{float32(0.1), float32(0.1), float32(0.1)}, 96 | }, 97 | }, 98 | } 99 | mockllm.EXPECT().CreateEmbeddings(ctx, expectedCreateReq).Return(expectedCreateResp, nil) 100 | resp, err := mvs.Query(ctx, qb) 101 | assert.NoError(t, err) 102 | assert.Equal(t, k, len(resp)) 103 | assert.Equal(t, "Apple", resp[0].Content) 104 | assert.Equal(t, "Orange", resp[1].Content) 105 | }) 106 | 107 | // This should return basketball because the embedding str should override the query 108 | t.Run("QueryWithEmbedding", func(t *testing.T) { 109 | k := 1 110 | qb := vectorstore2.QueryRequest{ 111 | Query: "What is your favorite fruit", 112 | EmbeddingStrings: []string{"favorite sport?"}, 113 | K: k, 114 | } 115 | expectedCreateReq := openai.EmbeddingRequest{ 116 | Input: []string{"favorite sport?"}, 117 | Model: openai.AdaEmbeddingV2, 118 | } 119 | expectedCreateResp := openai.EmbeddingResponse{ 120 | Data: []openai.Embedding{ 121 | {Embedding: []float32{float32(0.1), float32(2.11), float32(2.11)}}, 122 | }} 123 | mockllm.EXPECT().CreateEmbeddings(ctx, expectedCreateReq).Return(expectedCreateResp, nil) 124 | resp, err := mvs.Query(ctx, qb) 125 | assert.NoError(t, err) 126 | assert.Equal(t, k, len(resp)) 127 | assert.Equal(t, "Basketball", resp[0].Content) 128 | }) 129 | } 130 | 131 | type MockEmbedder struct{} 132 | 133 | func (m MockEmbedder) CreateEmbeddings(ctx context.Context, req openai.EmbeddingRequest) (openai.EmbeddingResponse, error) { 134 | resp := openai.EmbeddingResponse{ 135 | Data: []openai.Embedding{ 136 | {Embedding: getRandomEmbedding(1536)}, 137 | }, 138 | } 139 | return resp, nil 140 | } 141 | 142 | func BenchmarkMemoryVectorStore(b *testing.B) { 143 | llm := mock_gollum.NewMockEmbedder(gomock.NewController(b)) 144 | ctx := context.Background() 145 | 146 | nValues := []int{10, 100, 1_000, 10_000, 100_000, 1_000_000} 147 | kValues := []int{1, 10, 100} 148 | dim := 768 149 | for _, n := range nValues { 150 | b.Run(fmt.Sprintf("BenchmarkInsert-n=%v", n), func(b *testing.B) { 151 | mvs := vectorstore2.NewMemoryVectorStore(llm) 152 | for i := 0; i < b.N; i++ { 153 | for j := 0; j < n; j++ { 154 | mv := gollum.Document{ 155 | ID: fmt.Sprintf("%v", j), 156 | Content: "test", 157 | Embedding: getRandomEmbedding(dim), 158 | } 159 | mvs.Insert(ctx, mv) 160 | } 161 | } 162 | }) 163 | for _, k := range kValues { 164 | if k <= n { 165 | b.Run(fmt.Sprintf("BenchmarkQuery-n=%v-k=%v", n, k), func(b *testing.B) { 166 | mvs := vectorstore2.NewMemoryVectorStore(llm) 167 | for j := 0; j < n; j++ { 168 | mv := gollum.Document{ 169 | ID: fmt.Sprintf("%v", j), 170 | Content: "test", 171 | Embedding: getRandomEmbedding(dim), 172 | } 173 | mvs.Insert(ctx, mv) 174 | } 175 | qb := vectorstore2.QueryRequest{ 176 | EmbeddingFloats: getRandomEmbedding(dim), 177 | K: k, 178 | } 179 | b.ResetTimer() 180 | for i := 0; i < b.N; i++ { 181 | _, err := mvs.Query(ctx, qb) 182 | assert.NoError(b, err) 183 | } 184 | }) 185 | } 186 | } 187 | } 188 | } 189 | 190 | func BenchmarkHeap(b *testing.B) { 191 | // Create a sample Heap. 
192 | 193 | ks := []int{1, 10, 100} 194 | 195 | for _, k := range ks { 196 | var h vectorstore2.Heap 197 | h.Init(k) 198 | b.Run(fmt.Sprintf("BenchmarkHeapPush-k=%v", k), func(b *testing.B) { 199 | for i := 0; i < b.N; i++ { 200 | doc := &gollum.Document{} 201 | similarity := rand.Float32() 202 | ns := vectorstore2.NodeSimilarity{Document: doc, Similarity: similarity} 203 | h.Push(ns) 204 | if h.Len() > k { 205 | h.Pop() 206 | } 207 | } 208 | }) 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /packages/llm/providers/vertex/vertex.go: -------------------------------------------------------------------------------- 1 | // Package vertex implements the Vertex AI api 2 | // it is largely similar to the Google "ai studio" provider but uses a different library... 3 | package vertex 4 | 5 | import ( 6 | "cloud.google.com/go/vertexai/genai" 7 | "context" 8 | "fmt" 9 | "github.com/pkg/errors" 10 | "github.com/stillmatic/gollum/packages/llm" 11 | "google.golang.org/api/iterator" 12 | "log" 13 | ) 14 | 15 | type VertexAIProvider struct { 16 | client *genai.Client 17 | } 18 | 19 | func NewVertexAIProvider(ctx context.Context, projectID, location string) (*VertexAIProvider, error) { 20 | client, err := genai.NewClient(ctx, projectID, location) 21 | if err != nil { 22 | return nil, errors.Wrap(err, "failed to create Vertex AI client") 23 | } 24 | 25 | ccIter := client.ListCachedContents(ctx) 26 | for { 27 | cc, err := ccIter.Next() 28 | if err == iterator.Done { 29 | break 30 | } 31 | if err != nil { 32 | return nil, errors.Wrap(err, "failed to list cached contents") 33 | } 34 | log.Printf("Cached content: %v", cc) 35 | } 36 | 37 | return &VertexAIProvider{ 38 | client: client, 39 | }, nil 40 | } 41 | 42 | func (p *VertexAIProvider) getModel(req llm.InferRequest) *genai.GenerativeModel { 43 | // this does NOT validate if the model name is valid, that is done at inference time. 44 | model := p.client.GenerativeModel(req.ModelConfig.ModelName) 45 | model.SetTemperature(req.MessageOptions.Temperature) 46 | model.SetMaxOutputTokens(int32(req.MessageOptions.MaxTokens)) 47 | 48 | return model 49 | } 50 | 51 | func (p *VertexAIProvider) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) { 52 | if len(req.Messages) > 1 { 53 | return p.generateResponseMultiTurn(ctx, req) 54 | } 55 | return p.generateResponseSingleTurn(ctx, req) 56 | } 57 | 58 | func (p *VertexAIProvider) generateResponseSingleTurn(ctx context.Context, req llm.InferRequest) (string, error) { 59 | model := p.getModel(req) 60 | parts := messageToParts(req.Messages[0]) 61 | 62 | resp, err := model.GenerateContent(ctx, parts...) 
63 | if err != nil { 64 | return "", errors.Wrap(err, "failed to generate content") 65 | } 66 | 67 | return flattenResponse(resp), nil 68 | } 69 | 70 | func (p *VertexAIProvider) generateResponseMultiTurn(ctx context.Context, req llm.InferRequest) (string, error) { 71 | model := p.getModel(req) 72 | 73 | msgs, sysInstr := multiTurnMessageToParts(req.Messages[:len(req.Messages)-1]) 74 | if sysInstr != nil { 75 | model.SystemInstruction = sysInstr 76 | } 77 | 78 | cs := model.StartChat() 79 | cs.History = msgs 80 | mostRecentMessage := req.Messages[len(req.Messages)-1] 81 | 82 | // Send the last message 83 | resp, err := cs.SendMessage(ctx, genai.Text(mostRecentMessage.Content)) 84 | if err != nil { 85 | return "", errors.Wrap(err, "failed to send message in chat") 86 | } 87 | 88 | return flattenResponse(resp), nil 89 | } 90 | 91 | func (p *VertexAIProvider) GenerateResponseAsync(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) { 92 | if len(req.Messages) > 1 { 93 | return p.generateResponseAsyncMultiTurn(ctx, req) 94 | } 95 | return p.generateResponseAsyncSingleTurn(ctx, req) 96 | } 97 | 98 | func (p *VertexAIProvider) generateResponseAsyncSingleTurn(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) { 99 | outChan := make(chan llm.StreamDelta) 100 | 101 | go func() { 102 | defer close(outChan) 103 | 104 | model := p.getModel(req) 105 | parts := messageToParts(req.Messages[0]) 106 | 107 | iter := model.GenerateContentStream(ctx, parts...) 108 | 109 | for { 110 | resp, err := iter.Next() 111 | if errors.Is(err, iterator.Done) { 112 | outChan <- llm.StreamDelta{EOF: true} 113 | break 114 | } 115 | if err != nil { 116 | log.Printf("Error from Vertex AI stream: %v", err) 117 | return 118 | } 119 | 120 | content := flattenResponse(resp) 121 | if content != "" { 122 | select { 123 | case <-ctx.Done(): 124 | return 125 | case outChan <- llm.StreamDelta{Text: content}: 126 | } 127 | } 128 | } 129 | }() 130 | 131 | return outChan, nil 132 | } 133 | 134 | func (p *VertexAIProvider) generateResponseAsyncMultiTurn(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) { 135 | outChan := make(chan llm.StreamDelta) 136 | 137 | go func() { 138 | defer close(outChan) 139 | 140 | model := p.getModel(req) 141 | cs := model.StartChat() 142 | 143 | // Add previous messages to chat history 144 | for _, msg := range req.Messages[:len(req.Messages)-1] { 145 | parts := messageToParts(msg) 146 | cs.History = append(cs.History, &genai.Content{ 147 | Parts: parts, 148 | Role: msg.Role, 149 | }) 150 | } 151 | 152 | // Send the last message 153 | lastMsg := req.Messages[len(req.Messages)-1] 154 | iter := cs.SendMessageStream(ctx, messageToParts(lastMsg)...) 
155 | 156 | for { 157 | resp, err := iter.Next() 158 | if errors.Is(err, iterator.Done) { 159 | outChan <- llm.StreamDelta{EOF: true} 160 | break 161 | } 162 | if err != nil { 163 | log.Printf("Error from Vertex AI stream: %v", err) 164 | return 165 | } 166 | 167 | content := flattenResponse(resp) 168 | if content != "" { 169 | select { 170 | case <-ctx.Done(): 171 | return 172 | case outChan <- llm.StreamDelta{Text: content}: 173 | } 174 | } 175 | } 176 | }() 177 | 178 | return outChan, nil 179 | } 180 | 181 | func messageToParts(message llm.InferMessage) []genai.Part { 182 | parts := []genai.Part{genai.Text(message.Content)} 183 | if message.Image != nil && len(message.Image) > 0 { 184 | parts = append(parts, genai.ImageData("png", message.Image)) 185 | } 186 | return parts 187 | } 188 | 189 | func multiTurnMessageToParts(messages []llm.InferMessage) ([]*genai.Content, *genai.Content) { 190 | sysInstructionParts := make([]genai.Part, 0) 191 | hist := make([]*genai.Content, 0, len(messages)) 192 | for _, message := range messages { 193 | parts := []genai.Part{genai.Text(message.Content)} 194 | if message.Image != nil && len(message.Image) > 0 { 195 | parts = append(parts, genai.ImageData("png", message.Image)) 196 | } 197 | if message.Role == "system" { 198 | sysInstructionParts = append(sysInstructionParts, parts...) 199 | continue 200 | } 201 | hist = append(hist, &genai.Content{ 202 | Parts: parts, 203 | Role: message.Role, 204 | }) 205 | } 206 | if len(sysInstructionParts) > 0 { 207 | return hist, &genai.Content{ 208 | Parts: sysInstructionParts, 209 | } 210 | } 211 | 212 | return hist, nil 213 | } 214 | 215 | func flattenResponse(resp *genai.GenerateContentResponse) string { 216 | var result string 217 | for _, cand := range resp.Candidates { 218 | for _, part := range cand.Content.Parts { 219 | result += fmt.Sprintf("%v", part) 220 | } 221 | } 222 | return result 223 | } 224 | 225 | var _ llm.Responder = &VertexAIProvider{} 226 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GOLLuM 2 | 3 | Production-grade LLM tooling. At least, in theory -- stuff changes fast so don't expect stability from this library so much as ideas for your own apps. 4 | 5 | ## Features 6 | 7 | - Sane LLM provider abstraction. 8 | - Automated function dispatch 9 | - Parses arbitrary Go structs into JSONSchema for OpenAI - and validates when unmarshaling back to your structs 10 | - Simplified API to generate results from a single prompt or template 11 | - Highly performant vector store solution with exact search 12 | - SIMD acceleration for 10x better perf than naive approach, constant memory usage 13 | - Drop-in integration with OpenAI and other embedding providers 14 | - Carefully mocked, tested, and benchmarked. 15 | - Implementation of HyDE (hypothetical documents embeddings) for enhanced retrieval 16 | - MIT License 17 | 18 | # Examples 19 | 20 | ## Dispatch 21 | 22 | Function dispatch is a highly simplified and easy way to generate filled structs via an LLM. 
23 | 24 | ```go 25 | type dinnerParty struct { 26 | Topic string `json:"topic" jsonschema:"required" jsonschema_description:"The topic of the conversation"` 27 | RandomWords []string `json:"random_words" jsonschema:"required" jsonschema_description:"Random words to prime the conversation"` 28 | } 29 | completer := openai.NewClient(os.Getenv("OPENAI_API_KEY")) 30 | d := gollum.NewOpenAIDispatcher[dinnerParty]("dinner_party", "Given a topic, return random words", completer, nil) 31 | output, _ := d.Prompt(context.Background(), "Talk to me about dinosaurs") 32 | ``` 33 | 34 | The result should be a filled `dinnerParty` struct. 35 | 36 | ```go 37 | expected := dinnerParty{ 38 | Topic: "dinosaurs", 39 | RandomWords: []string{"dinosaur", "fossil", "extinct"}, 40 | } 41 | ``` 42 | 43 | Some similar libraries / ideas: 44 | 45 | - Rust: [grantslatton/ai-functions](https://github.com/grantslatton/ai-functions/blob/main/ai_bin/src/main.rs) 46 | - Python: [jxnl/openai_function_call](https://github.com/jxnl/openai_function_call) 47 | 48 | ## Parsing 49 | 50 | ### Simplest 51 | 52 | Imagine you have a function `GetWeather` -- 53 | 54 | ```go 55 | type getWeatherInput struct { 56 | Location string `json:"location" jsonschema_description:"The city and state, e.g. San Francisco, CA" jsonschema:"required"` 57 | Unit string `json:"unit,omitempty" jsonschema:"enum=celsius,enum=fahrenheit" jsonschema_description:"The unit of temperature"` 58 | } 59 | 60 | type getWeatherOutput struct { 61 | // ... 62 | } 63 | 64 | // GetWeather does something, this docstring is annoying but theoretically possible to get 65 | func GetWeather(ctx context.Context, inp getWeatherInput) (out getWeatherOutput, err error) { 66 | return out, err 67 | } 68 | ``` 69 | 70 | This is a common pattern for API design, as it is easy to share the `getWeatherInput` struct (well, imagine if it were public). See, for example, the [GRPC service definitions](https://github.com/grpc/grpc-go/blob/master/examples/helloworld/greeter_server/main.go#L43), or the [Connect RPC implementation](https://github.com/bufbuild/connect-go/blob/main/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go#LL155C6-L155C24). This means we can simplify the logic greatly by assuming a single input struct. 71 | 72 | Now, we can construct the responses: 73 | 74 | ```go 75 | type getWeatherInput struct { 76 | Location string `json:"location" jsonschema_description:"The city and state, e.g. San Francisco, CA" jsonschema:"required"` 77 | Unit string `json:"unit,omitempty" jsonschema:"enum=celsius,enum=fahrenheit" jsonschema_description:"The unit of temperature"` 78 | } 79 | 80 | fi := gollum.StructToJsonSchema("weather", "Get the current weather in a given location", getWeatherInput{}) 81 | 82 | chatRequest := openai.ChatCompletionRequest{ 83 | Model: "gpt-3.5-turbo-0613", 84 | Messages: []openai.ChatCompletionMessage{ 85 | { 86 | Role: "user", 87 | Content: "What's the temperature in Boston?", 88 | }, 89 | }, 90 | MaxTokens: 256, 91 | Temperature: 0.0, 92 | Tools: []openai.Tool{{Type: "function", Function: openai.FunctionDefinition(fi)}}, 93 | ToolChoice: "weather", 94 | } 95 | 96 | ctx := context.Background() 97 | resp, err := api.SendRequest(ctx, chatRequest) 98 | parser := gollum.NewJSONParser[getWeatherInput](false) 99 | input, err := parser.Parse(ctx, resp.Choices[0].Message.ToolCalls[0].Function.Arguments) 100 | ``` 101 | 102 | This example steps through all that, end to end.
Some of this is 'sort of' pseudo-code, as the OpenAI clients I use haven't yet implemented support for functions, but it should also show that only minimal modifications to upstream libraries are necessary. 103 | 104 | It is also possible to go from just the function definition to a fully formed OpenAI FunctionCall. Reflection gives the name of the function for free, and godoc parsing can get the function description too. In practice, though, I think it's fairly unlikely that you need to change the name/description of the function very often; the inputs change more often. Using this pattern and compiling once makes the most sense to me. 105 | 106 | We should be able to chain the call for the single input and for the ctx + single input case and return it easily. 107 | 108 | ### Recursion on arbitrary structs without explicit definitions 109 | 110 | Say you have a struct that has JSON tags defined. 111 | 112 | ```go 113 | fi := gollum.StructToJsonSchema("ChatCompletion", "Call the OpenAI chat completion API", chatCompletionRequest{}) 114 | 115 | chatRequest := chatCompletionRequest{ 116 | ChatCompletionRequest: openai.ChatCompletionRequest{ 117 | Model: "gpt-3.5-turbo-0613", 118 | Messages: []openai.ChatCompletionMessage{ 119 | { 120 | Role: openai.ChatMessageRoleSystem, 121 | Content: "Construct a ChatCompletionRequest to answer the user's question, but using Kirby references. Do not answer the question directly using prior knowledge, you must generate a ChatCompletionRequest that will answer the question.", 122 | }, 123 | { 124 | Role: openai.ChatMessageRoleUser, 125 | Content: "What is the definition of recursion?", 126 | }, 127 | }, 128 | MaxTokens: 256, 129 | Temperature: 0.0, 130 | }, 131 | Tools: []openai.Tool{ 132 | { 133 | Type: "function", 134 | Function: fi, 135 | }, 136 | }, 137 | } 138 | parser := gollum.NewJSONParser[openai.ChatCompletionRequest](false) 139 | input, err := parser.Parse(ctx, resp.Choices[0].Message.ToolCalls[0].Function.Arguments) 140 | ``` 141 | 142 | On the first try, this yielded the following result: 143 | 144 | ```json 145 | { 146 | "model": "gpt-3.5-turbo", 147 | "messages": [ 148 | {"role": "system", "content": "You are Kirby, a friendly virtual assistant."}, 149 | {"role": "user", "content": "What is the definition of recursion?"} 150 | ] 151 | } 152 | ``` 153 | 154 | That's really sick considering that _no_ effort was put into manually creating a new JSON struct, and the original struct didn't have any JSONSchema tags - just JSON serdes comments.
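The feature list above also mentions an exact-search vector store. For completeness, here is a minimal sketch of how it is used, adapted from the package's tests; it assumes you already have a `gollum.Embedder` implementation (for example, an OpenAI client wrapper):

```go
// Insert a few documents into the in-memory store, then query for the top k.
// Adapted from the vectorstore package tests.
func searchExample(ctx context.Context, embedder gollum.Embedder) ([]*gollum.Document, error) {
	vs := vectorstore.NewMemoryVectorStore(embedder)
	for _, s := range []string{"Apple", "Orange", "Basketball"} {
		if err := vs.Insert(ctx, gollum.NewDocumentFromString(s)); err != nil {
			return nil, err
		}
	}
	// Embeds the query, runs an exact KNN scan, and returns the k best matches.
	return vs.Query(ctx, vectorstore.QueryRequest{Query: "favorite fruit?", K: 2})
}
```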
155 | -------------------------------------------------------------------------------- /bench.md: -------------------------------------------------------------------------------- 1 | 2 | 2023-07-09 3 | 4 | ``` 5 | goos: linux 6 | goarch: amd64 7 | pkg: github.com/stillmatic/gollum 8 | cpu: AMD Ryzen 9 7950X 16-Core Processor 9 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=10-32 14752 83971 ns/op 64912 B/op 10 allocs/op 10 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10-k=1-32 810657 1256 ns/op 288 B/op 3 allocs/op 11 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10-k=10-32 663574 1639 ns/op 2880 B/op 6 allocs/op 12 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=100-32 1263 1042804 ns/op 646807 B/op 190 allocs/op 13 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=1-32 113851 10399 ns/op 288 B/op 3 allocs/op 14 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=10-32 93428 12625 ns/op 2880 B/op 6 allocs/op 15 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=100-32 69256 17569 ns/op 25664 B/op 9 allocs/op 16 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=1000-32 147 8100071 ns/op 6505065 B/op 2734 allocs/op 17 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=1-32 10000 104921 ns/op 288 B/op 3 allocs/op 18 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=10-32 9831 123464 ns/op 2880 B/op 6 allocs/op 19 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=100-32 7848 152007 ns/op 25664 B/op 9 allocs/op 20 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=10000-32 13 82907557 ns/op 64727761 B/op 29740 allocs/op 21 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=1-32 783 1514925 ns/op 288 B/op 3 allocs/op 22 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=10-32 679 1692251 ns/op 2880 B/op 6 allocs/op 23 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=100-32 650 2095042 ns/op 25664 B/op 9 allocs/op 24 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=100000-32 2 814309670 ns/op 648192728 B/op 299774 allocs/op 25 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=1-32 82 16656264 ns/op 288 B/op 3 allocs/op 26 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=10-32 68 16402470 ns/op 2880 B/op 6 allocs/op 27 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=100-32 64 18205266 ns/op 25664 B/op 9 allocs/op 28 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=1000000-32 1 9578485965 ns/op 6552089784 B/op 2999874 allocs/op 29 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=1-32 7 161260588 ns/op 288 B/op 3 allocs/op 30 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=10-32 7 212760511 ns/op 2880 B/op 6 allocs/op 31 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=100-32 4 290261365 ns/op 25664 B/op 9 allocs/op 32 | PASS 33 | ok github.com/stillmatic/gollum 111.224s 34 | ``` 35 | 36 | post perf improvements, 2023-07-17 37 | changes are most pronounced when k is large, more efficient memory reuse makes time nearly constant with k, about 2x improvement with large k. 
38 | 39 | ``` 40 | goos: linux 41 | goarch: amd64 42 | pkg: github.com/stillmatic/gollum 43 | cpu: AMD Ryzen 9 7950X 16-Core Processor 44 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=10-32 12562 86781 ns/op 64680 B/op 10 allocs/op 45 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10-k=1-32 1259478 967.6 ns/op 120 B/op 4 allocs/op 46 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10-k=10-32 1201736 993.5 ns/op 304 B/op 3 allocs/op 47 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=100-32 1335 847394 ns/op 652949 B/op 190 allocs/op 48 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=1-32 143896 8509 ns/op 120 B/op 4 allocs/op 49 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=10-32 122300 9787 ns/op 624 B/op 4 allocs/op 50 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=100-32 111930 11152 ns/op 2752 B/op 3 allocs/op 51 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=1000-32 127 8455112 ns/op 6477091 B/op 2734 allocs/op 52 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=1-32 13416 88695 ns/op 120 B/op 4 allocs/op 53 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=10-32 12246 97176 ns/op 624 B/op 4 allocs/op 54 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=100-32 9746 114710 ns/op 5952 B/op 4 allocs/op 55 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=10000-32 14 88090856 ns/op 65255787 B/op 29740 allocs/op 56 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=1-32 769 1357818 ns/op 120 B/op 4 allocs/op 57 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=10-32 727 1555869 ns/op 624 B/op 4 allocs/op 58 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=100-32 752 1574506 ns/op 5952 B/op 4 allocs/op 59 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=100000-32 2 888843642 ns/op 648192776 B/op 299774 allocs/op 60 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=1-32 69 15276284 ns/op 120 B/op 4 allocs/op 61 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=10-32 79 14270086 ns/op 624 B/op 4 allocs/op 62 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=100-32 68 15162731 ns/op 5952 B/op 4 allocs/op 63 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=1000000-32 1 8644221139 ns/op 6552072176 B/op 2999841 allocs/op 64 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=1-32 8 141239584 ns/op 120 B/op 4 allocs/op 65 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=10-32 1 1354937045 ns/op 624 B/op 4 allocs/op 66 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=100-32 7 156217518 ns/op 5952 B/op 4 allocs/op 67 | PASS 68 | ok github.com/stillmatic/gollum 105.535s 69 | ``` 70 | 71 | post perf improvement - mac. 
stabilizes the allocations 72 | 73 | ``` 74 | goos: darwin 75 | goarch: arm64 76 | pkg: github.com/stillmatic/gollum 77 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=10-10 5341 220817 ns/op 65223 B/op 10 allocs/op 78 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10-k=1-10 60616 19622 ns/op 120 B/op 4 allocs/op 79 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10-k=10-10 60388 20033 ns/op 304 B/op 3 allocs/op 80 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=100-10 536 2202933 ns/op 652278 B/op 190 allocs/op 81 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=1-10 6152 194476 ns/op 120 B/op 4 allocs/op 82 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=10-10 6094 198124 ns/op 624 B/op 4 allocs/op 83 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100-k=100-10 5946 199925 ns/op 2752 B/op 3 allocs/op 84 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=1000-10 55 22152592 ns/op 6523947 B/op 2735 allocs/op 85 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=1-10 613 1953824 ns/op 120 B/op 4 allocs/op 86 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=10-10 610 1987216 ns/op 624 B/op 4 allocs/op 87 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000-k=100-10 580 2051436 ns/op 5952 B/op 4 allocs/op 88 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=10000-10 5 222244750 ns/op 64782620 B/op 29747 allocs/op 89 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=1-10 61 19383620 ns/op 120 B/op 4 allocs/op 90 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=10-10 60 19823898 ns/op 624 B/op 4 allocs/op 91 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=10000-k=100-10 57 20027584 ns/op 5952 B/op 4 allocs/op 92 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=100000-10 1 2207505500 ns/op 648271208 B/op 299808 allocs/op 93 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=1-10 6 196473680 ns/op 120 B/op 4 allocs/op 94 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=10-10 6 197389812 ns/op 624 B/op 4 allocs/op 95 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=100000-k=100-10 5 200068883 ns/op 5952 B/op 4 allocs/op 96 | BenchmarkMemoryVectorStore/BenchmarkInsert-n=1000000-10 1 22239769458 ns/op 6552038696 B/op 2999849 allocs/op 97 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=1-10 1 1966544833 ns/op 120 B/op 4 allocs/op 98 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=10-10 1 1963972417 ns/op 624 B/op 4 allocs/op 99 | BenchmarkMemoryVectorStore/BenchmarkQuery-n=1000000-k=100-10 1 1988149583 ns/op 5952 B/op 4 allocs/op 100 | PASS 101 | ok github.com/stillmatic/gollum 142.897s 102 | ``` 103 | 104 | mac is expected to be slower. however, post change, memory usage is much more stable - consistently 4 allocs per operation and much less memory overall. the memory characteristics are proportional to `k`. the desktop chip is faster and has a SIMD-enhanced distance calculation, so the gap is not unexpected. the runtime is also consistently linear with `n`, where `n` is the number of values in the db.
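For context, the stable allocation counts come from the bounded-heap top-k selection in the query path. A simplified sketch of that pattern (adapted from the compressed vector store's `Query`; scoring is assumed to have happened already):

```go
// topK scans every candidate once (linear in n) and keeps at most k of them
// in a bounded heap, so extra memory stays proportional to k.
func topK(candidates []vectorstore.NodeSimilarity, k int) []*gollum.Document {
	if k > len(candidates) {
		k = len(candidates)
	}
	h := make(vectorstore.Heap, 0, k+1)
	h.Init(k + 1)
	for _, c := range candidates {
		h.Push(c)
		if h.Len() > k {
			h.Pop() // evict so the heap never grows past k entries
		}
	}
	out := make([]*gollum.Document, k)
	for i := range out {
		out[k-i-1] = h.Pop().Document // fill from the back, mirroring the source
	}
	return out
}
```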
105 | 106 | -------------------------------------------------------------------------------- /packages/vectorstore/vectorstore_compressed_test.go: -------------------------------------------------------------------------------- 1 | package vectorstore_test 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "crypto/rand" 8 | "fmt" 9 | vectorstore2 "github.com/stillmatic/gollum/packages/vectorstore" 10 | mathrand "math/rand" 11 | "os" 12 | "testing" 13 | 14 | "github.com/google/uuid" 15 | "github.com/stillmatic/gollum" 16 | "github.com/stretchr/testify/assert" 17 | ) 18 | 19 | func TestCompressedVectorStore(t *testing.T) { 20 | vs := vectorstore2.NewGzipVectorStore() 21 | t.Run("implements interface", func(t *testing.T) { 22 | var vs2 vectorstore2.VectorStore 23 | vs2 = vs 24 | assert.NotNil(t, vs2) 25 | }) 26 | ctx := context.Background() 27 | testStrings := []string{ 28 | "Japan's Seiko Epson Corp. has developed a 12-gram flying microrobot.", 29 | "The latest tiny flying robot has been unveiled in Japan.", 30 | "Michael Phelps won the gold medal in the 400 individual medley.", 31 | } 32 | t.Run("testInsert", func(t *testing.T) { 33 | for _, str := range testStrings { 34 | vs.Insert(ctx, gollum.Document{ 35 | ID: uuid.NewString(), 36 | Content: str, 37 | }) 38 | } 39 | docs, err := vs.RetrieveAll(ctx) 40 | assert.NoError(t, err) 41 | assert.Equal(t, 5, len(docs)) 42 | }) 43 | t.Run("correctness", func(t *testing.T) { 44 | for _, str := range testStrings { 45 | vs.Insert(ctx, gollum.Document{ 46 | ID: uuid.NewString(), 47 | Content: str, 48 | }) 49 | } 50 | docs, err := vs.Query(ctx, vectorstore2.QueryRequest{ 51 | Query: "Where was the new robot unveiled?", 52 | K: 5, 53 | }) 54 | assert.NoError(t, err) 55 | assert.Equal(t, 3, len(docs)) 56 | assert.Equal(t, "Japan's Seiko Epson Corp. 
has developed a 12-gram flying microrobot.", docs[0].Content) 57 | assert.Equal(t, "The latest tiny flying robot has been unveiled in Japan.", docs[1].Content) 58 | }) 59 | } 60 | 61 | func BenchmarkCompressedVectorStore(b *testing.B) { 62 | ctx := context.Background() 63 | // Test different sizes 64 | sizes := []int{10, 100, 1000, 10_000, 100_000} 65 | // note that runtime doesn't really depend on K - 66 | ks := []int{1, 10, 100} 67 | // benchmark inserts 68 | stores := map[string]vectorstore2.VectorStore{ 69 | "DummyVectorStore": vectorstore2.NewDummyVectorStore(), 70 | "StdGzipVectorStore": vectorstore2.NewStdGzipVectorStore(), 71 | "ZstdVectorStore": vectorstore2.NewZstdVectorStore(), 72 | "GzipVectorStore": vectorstore2.NewGzipVectorStore(), 73 | } 74 | 75 | for vsName, vs := range stores { 76 | // for _, size := range sizes { 77 | // b.Run(fmt.Sprintf("%s-Insert-%d", vsName, size), func(b *testing.B) { 78 | // // Create vector store using live compression 79 | // docs := make([]gollum.Document, size) 80 | // // Generate synthetic docs 81 | // for i := range docs { 82 | // docs[i] = syntheticDoc() 83 | // } 84 | // b.ReportAllocs() 85 | // b.ResetTimer() 86 | // for n := 0; n < b.N; n++ { 87 | // // Insert docs 88 | // for _, doc := range docs { 89 | // vs.Insert(ctx, doc) 90 | // } 91 | // } 92 | // }) 93 | // } 94 | // // Concurrent writes to a slice are ok 95 | // for _, size := range sizes { 96 | // b.Run(fmt.Sprintf("%s-InsertConcurrent-%d", vsName, size), func(b *testing.B) { 97 | // // Create vector store using live compression 98 | // docs := make([]gollum.Document, size) 99 | // // Generate synthetic docs 100 | // for i := range docs { 101 | // docs[i] = syntheticDoc() 102 | // } 103 | // var wg sync.WaitGroup 104 | // sem := make(chan struct{}, 8) 105 | // b.ReportAllocs() 106 | // b.ResetTimer() 107 | // for n := 0; n < b.N; n++ { 108 | // // Insert docs 109 | // for _, doc := range docs { 110 | // wg.Add(1) 111 | // sem <- struct{}{} 112 | // go func(doc gollum.Document) { 113 | // defer wg.Done() 114 | // defer func() { <-sem }() 115 | // vs.Insert(ctx, doc) 116 | // }(doc) 117 | // } 118 | // wg.Wait() 119 | // } 120 | // }) 121 | // } 122 | // benchmark queries 123 | for _, size := range sizes { 124 | f, err := os.Open("testdata/enwik8") 125 | if err != nil { 126 | panic(err) 127 | } 128 | defer f.Close() 129 | lines := make([]string, 0) 130 | scanner := bufio.NewScanner(f) 131 | for scanner.Scan() { 132 | lines = append(lines, scanner.Text()) 133 | } 134 | for _, k := range ks { 135 | if k <= size { 136 | b.Run(fmt.Sprintf("%s-Query-%d-%d", vsName, size, k), func(b *testing.B) { 137 | // Create vector store and insert docs 138 | for i := 0; i < size; i++ { 139 | vs.Insert(ctx, gollum.NewDocumentFromString(lines[i])) 140 | } 141 | query := vectorstore2.QueryRequest{ 142 | Query: lines[size+1], 143 | } 144 | b.ReportAllocs() 145 | b.ResetTimer() 146 | // Create query 147 | for n := 0; n < b.N; n++ { 148 | vs.Query(ctx, query) 149 | } 150 | }) 151 | } 152 | } 153 | } 154 | } 155 | } 156 | 157 | // Helper functions 158 | func syntheticString() string { 159 | // Random length between 8 and 32 160 | randLength := mathrand.Intn(32-8+1) + 8 161 | 162 | // Generate random bytes 163 | randBytes := make([]byte, randLength) 164 | rand.Read(randBytes) 165 | 166 | // Format as hex string 167 | return fmt.Sprintf("%x", randBytes) 168 | } 169 | 170 | // syntheticQuery return query request with random embedding 171 | func syntheticQuery(k int) vectorstore2.QueryRequest { 172 | return 
vectorstore2.QueryRequest{ 173 | Query: syntheticString(), 174 | K: k, 175 | } 176 | } 177 | 178 | func BenchmarkStringToBytes(b *testing.B) { 179 | st := syntheticString() 180 | b.ResetTimer() 181 | b.Run("byteSlice", func(b *testing.B) { 182 | for n := 0; n < b.N; n++ { 183 | _ = []byte(st) 184 | } 185 | }) 186 | b.Run("byteSliceCopy", func(b *testing.B) { 187 | for n := 0; n < b.N; n++ { 188 | bts := make([]byte, len(st)) 189 | copy(bts, st) 190 | } 191 | }) 192 | b.Run("byteSliceCopyAppend", func(b *testing.B) { 193 | for n := 0; n < b.N; n++ { 194 | bts := make([]byte, 0) 195 | bts = append(bts, st...) 196 | _ = bts 197 | } 198 | }) 199 | b.Run("bytesBuffer", func(b *testing.B) { 200 | for n := 0; n < b.N; n++ { 201 | bb := bytes.NewBufferString(st) 202 | _ = bb.Bytes() 203 | } 204 | }) 205 | b.Run("bytesBufferEmpty", func(b *testing.B) { 206 | for n := 0; n < b.N; n++ { 207 | var bb bytes.Buffer 208 | bb.WriteString(st) 209 | _ = bb.Bytes() 210 | } 211 | }) 212 | } 213 | 214 | func dummyCompress(src []byte) []byte { 215 | return src 216 | } 217 | 218 | func minMax(val1, val2 float64) (float64, float64) { 219 | if val1 < val2 { 220 | return val1, val2 221 | } 222 | return val2, val1 223 | } 224 | 225 | func BenchmarkE2E(b *testing.B) { 226 | f, err := os.Open("testdata/enwik8") 227 | if err != nil { 228 | panic(err) 229 | } 230 | defer f.Close() 231 | lines := make([]string, 0) 232 | scanner := bufio.NewScanner(f) 233 | for scanner.Scan() { 234 | lines = append(lines, scanner.Text()) 235 | } 236 | // st1 := syntheticString() 237 | // st2 := syntheticString() 238 | st1 := lines[1] 239 | st2 := lines[2] 240 | b.ResetTimer() 241 | b.Run("minMax", func(b *testing.B) { 242 | for n := 0; n < b.N; n++ { 243 | Cx1 := float64(len(st1)) 244 | Cx2 := float64(len(st2)) 245 | min, max := minMax(Cx1, Cx2) 246 | _ = min 247 | _ = max 248 | } 249 | }) 250 | var bb bytes.Buffer 251 | b.Run("resetBytesBufferBytes", func(b *testing.B) { 252 | st1b := []byte(st1) 253 | st2b := []byte(st2) 254 | spb := []byte(" ") 255 | for n := 0; n < b.N; n++ { 256 | Cx1 := float64(len(st1b)) 257 | Cx2 := float64(len(st2b)) 258 | bb.Reset() 259 | bb.Write(st1b) 260 | bb.Write(spb) 261 | bb.Write(st2b) 262 | b_ := bb.Bytes() 263 | x1x2 := dummyCompress(b_) 264 | Cx1x2 := float64(len(x1x2)) 265 | min, max := minMax(Cx1, Cx2) 266 | ncd := (Cx1x2 - min) / (max) 267 | _ = ncd 268 | } 269 | }) 270 | } 271 | 272 | func BenchmarkConcatenateStrings(b *testing.B) { 273 | f, err := os.Open("testdata/enwik8") 274 | if err != nil { 275 | panic(err) 276 | } 277 | defer f.Close() 278 | lines := make([]string, 0) 279 | scanner := bufio.NewScanner(f) 280 | for scanner.Scan() { 281 | lines = append(lines, scanner.Text()) 282 | } 283 | // st1 := syntheticString() 284 | // st2 := syntheticString() 285 | st1 := lines[1] 286 | st2 := lines[2] 287 | b.ResetTimer() 288 | b.Run("minMax", func(b *testing.B) { 289 | for n := 0; n < b.N; n++ { 290 | Cx1 := float64(len(st1)) 291 | Cx2 := float64(len(st2)) 292 | min, max := minMax(Cx1, Cx2) 293 | _ = min 294 | _ = max 295 | } 296 | }) 297 | b.Run("concatenate", func(b *testing.B) { 298 | for n := 0; n < b.N; n++ { 299 | _ = []byte(st1 + " " + st2) 300 | } 301 | }) 302 | b.Run("bytesBuffer", func(b *testing.B) { 303 | for n := 0; n < b.N; n++ { 304 | bb := bytes.NewBufferString(st1) 305 | bb.WriteString(" ") 306 | bb.WriteString(st2) 307 | _ = bb.Bytes() 308 | } 309 | }) 310 | b.Run("bytesBufferEmpty", func(b *testing.B) { 311 | for n := 0; n < b.N; n++ { 312 | var bb bytes.Buffer 313 | 
bb.WriteString(st1) 314 | bb.WriteString(" ") 315 | bb.WriteString(st2) 316 | _ = bb.Bytes() 317 | } 318 | }) 319 | var bb bytes.Buffer 320 | b.Run("resetBytesBuffer", func(b *testing.B) { 321 | for n := 0; n < b.N; n++ { 322 | bb.Reset() 323 | bb.WriteString(st1) 324 | bb.WriteString(" ") 325 | bb.WriteString(st2) 326 | _ = bb.Bytes() 327 | } 328 | }) 329 | 330 | } 331 | 332 | func BenchmarkCompress(b *testing.B) { 333 | compressors := map[string]vectorstore2.Compressor{ 334 | "DummyCompressor": vectorstore2.NewDummyVectorStore().Compressor, 335 | "StdGzipCompressor": vectorstore2.NewStdGzipVectorStore().Compressor, 336 | "ZstdCompressor": vectorstore2.NewZstdVectorStore().Compressor, 337 | "GzipCompressor": vectorstore2.NewGzipVectorStore().Compressor, 338 | } 339 | str := syntheticString() 340 | b.ResetTimer() 341 | for name, compressor := range compressors { 342 | b.Run(name, func(b *testing.B) { 343 | for n := 0; n < b.N; n++ { 344 | _ = compressor.Compress([]byte(str)) 345 | } 346 | }) 347 | } 348 | } 349 | -------------------------------------------------------------------------------- /packages/llm/configs.go: -------------------------------------------------------------------------------- 1 | package llm 2 | 3 | const ( 4 | ProviderAnthropic ProviderType = "anthropic" 5 | ProviderGoogle ProviderType = "google" 6 | ProviderVertex ProviderType = "vertex" 7 | 8 | ProviderOpenAI ProviderType = "openai" 9 | ProviderGroq ProviderType = "groq" 10 | ProviderTogether ProviderType = "together" 11 | ProviderHyperbolic ProviderType = "hyperbolic" 12 | ProviderDeepseek ProviderType = "deepseek" 13 | 14 | ProviderVoyage ProviderType = "voyage" 15 | ProviderMixedBread ProviderType = "mixedbread" 16 | ) 17 | 18 | // configs are user declared, here's some useful defaults 19 | const ( 20 | // LLM models 21 | 22 | // aka claude 3.6 23 | ConfigClaude3Dot6Sonnet = "claude-3.6-sonnet" 24 | // the traditional 3.5 25 | ConfigClaude3Dot5Sonnet = "claude-3.5-sonnet" 26 | ConfigClaude3Dot7Sonnet = "claude-3.7-sonnet" 27 | 28 | ConfigGPT4Mini = "gpt-4-mini" 29 | ConfigGPT4o = "gpt-4o" 30 | ConfigOpenAIO1 = "oai-o1" 31 | ConfigOpenAIO1Mini = "oai-o1-mini" 32 | ConfigOpenAIO1Preview = "oai-o1-preview" 33 | ConfigOpenAIO3Mini = "oai-o3-mini" 34 | 35 | ConfigGroqLlama70B = "groq-llama-70b" 36 | ConfigGroqLlama8B = "groq-llama-8b" 37 | ConfigGroqGemma9B = "groq-gemma2-9b" 38 | ConfigGroqMixtral = "groq-mixtral" 39 | 40 | ConfigTogetherGemma27B = "together-gemma-27b" 41 | ConfigTogetherDeepseekCoder33B = "together-deepseek-coder-33b" 42 | 43 | ConfigGemini1Dot5Flash8B = "gemini-flash-8b" 44 | ConfigGemini1Dot5Flash = "gemini-flash" 45 | ConfigGemini1Dot5Pro = "gemini-pro" 46 | ConfigGemini2Flash = "gemini-2-flash" 47 | 48 | ConfigHyperbolicLlama405B = "hyperbolic-llama-405b" 49 | ConfigHyperbolicLlama405BBase = "hyperbolic-llama-405b-base" 50 | ConfigHyperbolicLlama70B = "hyperbolic-llama-70b" 51 | ConfigHyperbolicLlama8B = "hyperbolic-llama-8b" 52 | 53 | ConfigDeepseekChat = "deepseek-chat" 54 | ConfigDeepseekCoder = "deepseek-coder" 55 | 56 | // Vertex 57 | ConfigClaude3Dot5SonnetVertex = "claude-3.5-sonnet-vertex" 58 | ConfigLlama405BVertex = "llama-405b-vertex" 59 | 60 | // Embedding models 61 | ConfigOpenAITextEmbedding3Small = "openai-text-embedding-3-small" 62 | ConfigOpenAITextEmbedding3Large = "openai-text-embedding-3-large" 63 | ConfigOpenAITextEmbeddingAda002 = "openai-text-embedding-ada-002" 64 | 65 | ConfigGeminiTextEmbedding4 = "gemini-text-embedding-004" 66 | 67 | ConfigMxbaiEmbedLargeV1 = 
"mxbai-embed-large-v1" 68 | ConfigVoyageLarge2Instruct = "voyage-large-2-instruct" 69 | ) 70 | 71 | var configs = map[string]ModelConfig{ 72 | ConfigClaude3Dot5Sonnet: { 73 | ProviderType: ProviderAnthropic, 74 | ModelName: "claude-3-5-sonnet-20240620", 75 | ModelType: ModelTypeLLM, 76 | CentiCentsPerMillionInputTokens: 30000, 77 | CentiCentsPerMillionOutputTokens: 150000, 78 | }, 79 | ConfigClaude3Dot5SonnetVertex: { 80 | ProviderType: ProviderVertex, 81 | ModelName: "claude-3-5-sonnet@20240620", 82 | ModelType: ModelTypeLLM, 83 | CentiCentsPerMillionInputTokens: 30000, 84 | CentiCentsPerMillionOutputTokens: 150000, 85 | }, 86 | ConfigLlama405BVertex: { 87 | ProviderType: ProviderVertex, 88 | ModelName: "llama3-405b-instruct-maas", 89 | ModelType: ModelTypeLLM, 90 | // The Llama 3.1 API service is at no cost during public preview, and will be priced as per dollar-per-1M-tokens at GA. 91 | }, 92 | ConfigClaude3Dot6Sonnet: { 93 | ProviderType: ProviderAnthropic, 94 | ModelName: "claude-3-5-sonnet-20241022", 95 | }, 96 | ConfigClaude3Dot7Sonnet: { 97 | ProviderType: ProviderAnthropic, 98 | ModelName: "claude-3-7-sonnet-latest", 99 | }, 100 | ConfigGPT4Mini: { 101 | ProviderType: ProviderOpenAI, 102 | ModelName: "gpt-4o-mini", 103 | ModelType: ModelTypeLLM, 104 | CentiCentsPerMillionInputTokens: 1_500, 105 | CentiCentsPerMillionOutputTokens: 6_000, 106 | }, 107 | ConfigGPT4o: { 108 | ProviderType: ProviderOpenAI, 109 | // NB 2025-01-08: 2024-08-06 remains 'latest', but there is also 'gpt-4o-2024-11-20' 110 | ModelName: "gpt-4o-2024-08-06", 111 | CentiCentsPerMillionInputTokens: 25_000, 112 | CentiCentsPerMillionOutputTokens: 100_000, 113 | }, 114 | ConfigOpenAIO1: { 115 | ProviderType: ProviderOpenAI, 116 | ModelName: "o1", 117 | CentiCentsPerMillionInputTokens: 150_000, 118 | CentiCentsPerMillionOutputTokens: 600_000, 119 | }, 120 | ConfigOpenAIO1Mini: { 121 | ProviderType: ProviderOpenAI, 122 | ModelName: "o1-mini", 123 | CentiCentsPerMillionInputTokens: 30_000, 124 | CentiCentsPerMillionOutputTokens: 120_000, 125 | }, 126 | ConfigOpenAIO3Mini: { 127 | ProviderType: ProviderOpenAI, 128 | ModelName: "o3-mini", 129 | CentiCentsPerMillionInputTokens: 150_000, 130 | CentiCentsPerMillionOutputTokens: 600_000, 131 | }, 132 | ConfigOpenAIO1Preview: { 133 | ProviderType: ProviderOpenAI, 134 | ModelName: "o1-preview", 135 | CentiCentsPerMillionInputTokens: 150_000, 136 | CentiCentsPerMillionOutputTokens: 600_000, 137 | }, 138 | ConfigGroqLlama70B: { 139 | ProviderType: ProviderGroq, 140 | ModelName: "llama-3.3-70b-versatile", 141 | CentiCentsPerMillionInputTokens: 5900, 142 | CentiCentsPerMillionOutputTokens: 7900, 143 | }, 144 | ConfigGroqMixtral: { 145 | ProviderType: ProviderGroq, 146 | ModelName: "mixtral-8x7b-32768", 147 | ModelType: ModelTypeLLM, 148 | CentiCentsPerMillionInputTokens: 2400, 149 | CentiCentsPerMillionOutputTokens: 2400, 150 | }, 151 | ConfigGroqGemma9B: { 152 | ProviderType: ProviderGroq, 153 | ModelName: "gemma2-9b-it", 154 | ModelType: ModelTypeLLM, 155 | CentiCentsPerMillionInputTokens: 2000, 156 | CentiCentsPerMillionOutputTokens: 2000, 157 | }, 158 | ConfigGroqLlama8B: { 159 | ProviderType: ProviderGroq, 160 | ModelType: ModelTypeLLM, 161 | CentiCentsPerMillionInputTokens: 500, 162 | CentiCentsPerMillionOutputTokens: 800, 163 | ModelName: "llama-3.1-8b-instant", 164 | }, 165 | ConfigTogetherGemma27B: { 166 | ProviderType: ProviderTogether, 167 | ModelName: "google/gemma-2-27b-it", 168 | ModelType: ModelTypeLLM, 169 | CentiCentsPerMillionOutputTokens: 8000, 170 | 
CentiCentsPerMillionInputTokens: 8000, 171 | }, 172 | ConfigTogetherDeepseekCoder33B: { 173 | ProviderType: ProviderTogether, 174 | ModelName: "deepseek-ai/deepseek-coder-33b-instruct", 175 | ModelType: ModelTypeLLM, 176 | CentiCentsPerMillionOutputTokens: 8000, 177 | CentiCentsPerMillionInputTokens: 8000, 178 | }, 179 | ConfigGemini1Dot5Flash: { 180 | ProviderType: ProviderGoogle, 181 | ModelName: "gemini-1.5-flash", 182 | ModelType: ModelTypeLLM, 183 | // assumes < 128k 184 | CentiCentsPerMillionOutputTokens: 3000, 185 | CentiCentsPerMillionInputTokens: 750, 186 | }, 187 | ConfigGemini1Dot5Flash8B: { 188 | ProviderType: ProviderGoogle, 189 | ModelName: "gemini-1.5-flash-8b", 190 | ModelType: ModelTypeLLM, 191 | // assumes < 128k 192 | CentiCentsPerMillionOutputTokens: 1500, 193 | CentiCentsPerMillionInputTokens: 375, 194 | }, 195 | ConfigGemini1Dot5Pro: { 196 | ProviderType: ProviderGoogle, 197 | ModelName: "gemini-1.5-pro", 198 | ModelType: ModelTypeLLM, 199 | // assumes < 128k 200 | CentiCentsPerMillionOutputTokens: 50000, 201 | CentiCentsPerMillionInputTokens: 12500, 202 | }, 203 | ConfigGemini2Flash: { 204 | ProviderType: ProviderGoogle, 205 | ModelName: "gemini-2.0-flash", 206 | }, 207 | ConfigHyperbolicLlama405B: { 208 | ProviderType: ProviderHyperbolic, 209 | ModelName: "meta-llama/Meta-Llama-3.1-405B-Instruct", 210 | ModelType: ModelTypeLLM, 211 | 212 | CentiCentsPerMillionInputTokens: 40000, 213 | CentiCentsPerMillionOutputTokens: 40000, 214 | }, 215 | ConfigHyperbolicLlama405BBase: { 216 | ProviderType: ProviderHyperbolic, 217 | ModelName: "meta-llama/Meta-Llama-3.1-405B", 218 | ModelType: ModelTypeLLM, 219 | 220 | CentiCentsPerMillionInputTokens: 40000, 221 | CentiCentsPerMillionOutputTokens: 40000, 222 | }, 223 | ConfigHyperbolicLlama70B: { 224 | ProviderType: ProviderHyperbolic, 225 | ModelName: "meta-llama/Meta-Llama-3.1-70B-Instruct", 226 | ModelType: ModelTypeLLM, 227 | 228 | CentiCentsPerMillionInputTokens: 4000, 229 | CentiCentsPerMillionOutputTokens: 4000, 230 | }, 231 | ConfigHyperbolicLlama8B: { 232 | ProviderType: ProviderHyperbolic, 233 | ModelName: "meta-llama/Meta-Llama-3.1-8B-Instruct", 234 | ModelType: ModelTypeLLM, 235 | 236 | CentiCentsPerMillionInputTokens: 1000, 237 | CentiCentsPerMillionOutputTokens: 1000, 238 | }, 239 | 240 | ConfigDeepseekChat: { 241 | ProviderType: ProviderDeepseek, 242 | ModelName: "deepseek-chat", 243 | ModelType: ModelTypeLLM, 244 | 245 | // assume cache miss 246 | CentiCentsPerMillionInputTokens: 1400, 247 | CentiCentsPerMillionOutputTokens: 2800, 248 | }, 249 | ConfigDeepseekCoder: { 250 | ProviderType: ProviderDeepseek, 251 | ModelName: "deepseek-coder", 252 | ModelType: ModelTypeLLM, 253 | 254 | // assume cache miss 255 | CentiCentsPerMillionInputTokens: 1400, 256 | CentiCentsPerMillionOutputTokens: 2800, 257 | }, 258 | 259 | ConfigOpenAITextEmbedding3Small: { 260 | ProviderType: ProviderOpenAI, 261 | ModelName: "text-embedding-3-small", 262 | ModelType: ModelTypeEmbedding, 263 | }, 264 | ConfigOpenAITextEmbedding3Large: { 265 | ProviderType: ProviderOpenAI, 266 | ModelName: "text-embedding-3-large", 267 | ModelType: ModelTypeEmbedding, 268 | }, 269 | ConfigOpenAITextEmbeddingAda002: { 270 | ProviderType: ProviderOpenAI, 271 | ModelName: "text-embedding-ada-002", 272 | ModelType: ModelTypeEmbedding, 273 | }, 274 | 275 | ConfigGeminiTextEmbedding4: { 276 | ProviderType: ProviderGoogle, 277 | ModelName: "text-embedding-004", 278 | ModelType: ModelTypeEmbedding, 279 | }, 280 | 281 | ConfigMxbaiEmbedLargeV1: { 282 | ProviderType: 
ProviderMixedBread, 283 | ModelName: "mxbai-embed-large-v1", 284 | ModelType: ModelTypeEmbedding, 285 | }, 286 | 287 | ConfigVoyageLarge2Instruct: { 288 | ProviderType: ProviderVoyage, 289 | ModelName: "voyage-large-2-instruct", 290 | ModelType: ModelTypeEmbedding, 291 | }, 292 | } 293 | -------------------------------------------------------------------------------- /packages/dispatch/functions_test.go: -------------------------------------------------------------------------------- 1 | package dispatch_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "os" 7 | "testing" 8 | 9 | "github.com/stillmatic/gollum/packages/dispatch" 10 | "github.com/stillmatic/gollum/packages/jsonparser" 11 | 12 | "github.com/joho/godotenv" 13 | "github.com/sashabaranov/go-openai" 14 | "github.com/stillmatic/gollum/internal/testutil" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | type addInput struct { 19 | A int `json:"a" json_schema:"required"` 20 | B int `json:"b" json_schema:"required"` 21 | } 22 | 23 | type getWeatherInput struct { 24 | Location string `json:"location" jsonschema_description:"The city and state." jsonschema:"required,example=San Francisco, CA"` 25 | Unit string `json:"unit,omitempty" jsonschema:"enum=celsius,enum=fahrenheit" jsonschema_description:"The unit of temperature,default=fahrenheit"` 26 | } 27 | 28 | type counter struct { 29 | Count int `json:"count" jsonschema:"required" jsonschema_description:"total number of words in sentence"` 30 | Words []string `json:"words" jsonschema:"required" jsonschema_description:"list of words in sentence"` 31 | } 32 | 33 | type blobNode struct { 34 | Name string `json:"name" jsonschema:"required"` 35 | Children []blobNode `json:"children,omitempty" jsonschema_description:"list of child nodes - only applicable if this is a directory"` 36 | NodeType string `json:"node_type" jsonschema:"required,enum=file,enum=folder" jsonschema_description:"type of node, inferred from name"` 37 | } 38 | 39 | type queryNode struct { 40 | Question string `json:"question" jsonschema:"required" jsonschema_description:"question to ask - questions can use information from children questions"` 41 | // NodeType string `json:"node_type" jsonschema:"required,enum=single_question,enum=merge_responses" jsonschema_description:"type of question. Either a single question or a multi question merge when there are multiple questions."` 42 | Children []queryNode `json:"children,omitempty" jsonschema_description:"list of child questions that need to be answered before this question can be answered. Use a subquery when anything may be unknown, and we need to ask multiple questions to get the answer. 
Dependences must only be other queries."` 43 | } 44 | 45 | func TestConstructJSONSchema(t *testing.T) { 46 | t.Run("add_", func(t *testing.T) { 47 | res := dispatch.StructToJsonSchema("add_", "adds two numbers", addInput{}) 48 | expectedStr := `{"name":"add_","description":"adds two numbers","parameters":{"properties":{"a":{"type":"integer"},"b":{"type":"integer"}},"type":"object","required":["a","b"]}}` 49 | b, err := json.Marshal(res) 50 | assert.NoError(t, err) 51 | assert.Equal(t, expectedStr, string(b)) 52 | }) 53 | t.Run("getWeather", func(t *testing.T) { 54 | res := dispatch.StructToJsonSchema("getWeather", "Get the current weather in a given location", getWeatherInput{}) 55 | assert.Equal(t, res.Name, "getWeather") 56 | assert.Equal(t, res.Description, "Get the current weather in a given location") 57 | // assert.Equal(t, res.Parameters.Type, "object") 58 | expectedStr := `{"name":"getWeather","description":"Get the current weather in a given location","parameters":{"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"],"description":"The unit of temperature"}},"type":"object","required":["location"]}}` 59 | b, err := json.Marshal(res) 60 | assert.NoError(t, err) 61 | assert.Equal(t, expectedStr, string(b)) 62 | }) 63 | } 64 | 65 | func ptr[T any](v T) *T { 66 | return &v 67 | } 68 | 69 | func TestEndToEnd(t *testing.T) { 70 | godotenv.Load() 71 | baseAPIURL := "https://api.openai.com/v1/chat/completions" 72 | openAIKey := os.Getenv("OPENAI_API_KEY") 73 | assert.NotEmpty(t, openAIKey) 74 | 75 | api := testutil.NewTestAPI(baseAPIURL, openAIKey) 76 | t.Run("weather", func(t *testing.T) { 77 | t.Skip("somewhat flaky - word counter is more reliable") 78 | fi := dispatch.StructToJsonSchema("weather", "Get the current weather in a given location", getWeatherInput{}) 79 | 80 | chatRequest := openai.ChatCompletionRequest{ 81 | Model: openai.GPT3Dot5Turbo1106, 82 | Messages: []openai.ChatCompletionMessage{ 83 | { 84 | Role: "user", 85 | Content: "Whats the temperature in Boston?", 86 | }, 87 | }, 88 | MaxTokens: 256, 89 | Temperature: 0.0, 90 | Tools: []openai.Tool{{Type: "function", Function: ptr(openai.FunctionDefinition(fi))}}, 91 | ToolChoice: "weather", 92 | } 93 | 94 | ctx := context.Background() 95 | resp, err := api.SendRequest(ctx, chatRequest) 96 | assert.NoError(t, err) 97 | 98 | assert.Equal(t, resp.Model, "gpt-3.5-turbo-0613") 99 | assert.NotEmpty(t, resp.Choices) 100 | assert.Empty(t, resp.Choices[0].Message.Content) 101 | assert.NotNil(t, resp.Choices[0].Message.ToolCalls) 102 | assert.Equal(t, resp.Choices[0].Message.ToolCalls[0].Function.Name, "weather") 103 | 104 | // this is somewhat flaky - about 20% of the time it returns 'Boston' 105 | expectedArg := []byte(`{"location": "Boston, MA"}`) 106 | parser := jsonparser.NewJSONParserGeneric[getWeatherInput](false) 107 | expectedStruct, err := parser.Parse(ctx, expectedArg) 108 | assert.NoError(t, err) 109 | input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments)) 110 | assert.NoError(t, err) 111 | assert.Equal(t, expectedStruct, input) 112 | }) 113 | 114 | t.Run("counter", func(t *testing.T) { 115 | fi := dispatch.StructToJsonSchema("split_word", "Break sentences into words", counter{}) 116 | chatRequest := openai.ChatCompletionRequest{ 117 | Model: "gpt-3.5-turbo-0613", 118 | Messages: []openai.ChatCompletionMessage{ 119 | { 120 | Role: "user", 121 | Content: "「What is the weather 
like in Boston?」Break down the above sentence into words", 122 | }, 123 | }, 124 | MaxTokens: 256, 125 | Temperature: 0.0, 126 | Tools: []openai.Tool{ 127 | {Type: "function", Function: ptr(openai.FunctionDefinition(fi))}, 128 | }, 129 | } 130 | ctx := context.Background() 131 | resp, err := api.SendRequest(ctx, chatRequest) 132 | assert.NoError(t, err) 133 | 134 | assert.Equal(t, resp.Model, "gpt-3.5-turbo-0613") 135 | assert.NotEmpty(t, resp.Choices) 136 | assert.Empty(t, resp.Choices[0].Message.Content) 137 | assert.NotNil(t, resp.Choices[0].Message.ToolCalls) 138 | assert.Equal(t, resp.Choices[0].Message.ToolCalls[0].Function.Name, "split_word") 139 | 140 | expectedStruct := counter{ 141 | Count: 7, 142 | Words: []string{"What", "is", "the", "weather", "like", "in", "Boston?"}, 143 | } 144 | parser := jsonparser.NewJSONParserGeneric[counter](false) 145 | input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments)) 146 | assert.NoError(t, err) 147 | assert.Equal(t, expectedStruct, input) 148 | }) 149 | 150 | t.Run("callOpenAI", func(t *testing.T) { 151 | fi := dispatch.StructToJsonSchema("ChatCompletion", "Call the OpenAI chat completion API", openai.ChatCompletionRequest{}) 152 | 153 | chatRequest := openai.ChatCompletionRequest{ 154 | Model: "gpt-3.5-turbo-0613", 155 | Messages: []openai.ChatCompletionMessage{ 156 | { 157 | Role: openai.ChatMessageRoleSystem, 158 | Content: "Construct a ChatCompletionRequest to answer the user's question, but using Kirby references. Do not answer the question directly using prior knowledge, you must generate a ChatCompletionRequest that will answer the question.", 159 | }, 160 | { 161 | Role: openai.ChatMessageRoleUser, 162 | Content: "What is the definition of recursion?", 163 | }, 164 | }, 165 | MaxTokens: 256, 166 | Temperature: 0.0, 167 | Tools: []openai.Tool{ 168 | {Type: "function", Function: ptr(openai.FunctionDefinition(fi))}, 169 | }, 170 | } 171 | 172 | ctx := context.Background() 173 | resp, err := api.SendRequest(ctx, chatRequest) 174 | assert.NoError(t, err) 175 | assert.Equal(t, resp.Model, "gpt-3.5-turbo-0613") 176 | assert.NotEmpty(t, resp.Choices) 177 | assert.Empty(t, resp.Choices[0].Message.Content) 178 | assert.NotNil(t, resp.Choices[0].Message.ToolCalls) 179 | assert.Equal(t, resp.Choices[0].Message.ToolCalls[0].Function.Name, "ChatCompletion") 180 | 181 | parser := jsonparser.NewJSONParserGeneric[openai.ChatCompletionRequest](false) 182 | input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments)) 183 | assert.NoError(t, err) 184 | assert.NotEmpty(t, input) 185 | 186 | // an example output: 187 | // "{ 188 | // "model": "gpt-3.5-turbo", 189 | // "messages": [ 190 | // {"role": "system", "content": "You are Kirby, a friendly virtual assistant."}, 191 | // {"role": "user", "content": "What is the definition of recursion?"} 192 | // ] 193 | // }" 194 | }) 195 | 196 | t.Run("directory", func(t *testing.T) { 197 | fi := dispatch.StructToJsonSchema("directory", "Get the contents of a directory", blobNode{}) 198 | inp := `root 199 | ├── dir1 200 | │ ├── file1.txt 201 | │ └── file2.txt 202 | └── dir2 203 | ├── file3.txt 204 | └── subfolder 205 | └── file4.txt` 206 | 207 | chatRequest := openai.ChatCompletionRequest{ 208 | Model: "gpt-3.5-turbo-0613", 209 | Messages: []openai.ChatCompletionMessage{ 210 | { 211 | Role: "user", 212 | Content: inp, 213 | }, 214 | }, 215 | MaxTokens: 256, 216 | Temperature: 0.0, 217 | Tools: []openai.Tool{ 218 | {Type: "function", 
Function: ptr(openai.FunctionDefinition(fi))}, 219 | }, 220 | } 221 | ctx := context.Background() 222 | resp, err := api.SendRequest(ctx, chatRequest) 223 | assert.NoError(t, err) 224 | t.Log(resp) 225 | assert.Equal(t, 0, 1) 226 | 227 | parser := jsonparser.NewJSONParserGeneric[blobNode](false) 228 | input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments)) 229 | assert.NoError(t, err) 230 | assert.NotEmpty(t, input) 231 | assert.Equal(t, input, blobNode{ 232 | Name: "root", 233 | Children: []blobNode{ 234 | { 235 | Name: "dir1", 236 | Children: []blobNode{ 237 | { 238 | Name: "file1.txt", 239 | NodeType: "file", 240 | }, 241 | { 242 | Name: "file2.txt", 243 | NodeType: "file", 244 | }, 245 | }, 246 | NodeType: "folder", 247 | }, 248 | { 249 | Name: "dir2", 250 | Children: []blobNode{ 251 | { 252 | Name: "file3.txt", 253 | NodeType: "file", 254 | }, 255 | { 256 | Name: "subfolder", 257 | Children: []blobNode{ 258 | { 259 | Name: "file4.txt", 260 | NodeType: "file", 261 | }, 262 | }, 263 | NodeType: "folder", 264 | }, 265 | }, 266 | NodeType: "folder", 267 | }, 268 | }, 269 | NodeType: "folder", 270 | }) 271 | }) 272 | t.Run("planner", func(t *testing.T) { 273 | fi := dispatch.StructToJsonSchema("queryPlanner", "Plan a multi-step query", queryNode{}) 274 | // inp := `Jason is from Canada` 275 | 276 | chatRequest := openai.ChatCompletionRequest{ 277 | Model: "gpt-3.5-turbo-0613", 278 | Messages: []openai.ChatCompletionMessage{ 279 | { 280 | Role: "system", 281 | Content: `When a user asks a question, you must use the 'queryPlanner' function to answer the question. If you are at all unsure, break the question into multiple smaller questions 282 | Example: 283 | Input: What is the population of Jason's home country? 284 | 285 | Output: What is the population of Jason's home country? 286 | │ ├── What is Jason's home country? 
287 | │ ├── What is the population of that country?`, 288 | }, 289 | { 290 | Role: "user", 291 | Content: "What's on the flag of Jason's home country?", 292 | }, 293 | }, 294 | MaxTokens: 256, 295 | Temperature: 0.0, 296 | Tools: []openai.Tool{ 297 | {Type: "function", Function: ptr(openai.FunctionDefinition(fi))}, 298 | }, 299 | } 300 | ctx := context.Background() 301 | resp, err := api.SendRequest(ctx, chatRequest) 302 | assert.NoError(t, err) 303 | t.Log(resp) 304 | assert.Equal(t, 0, 1) 305 | 306 | }) 307 | } 308 | 309 | func BenchmarkStructToJsonSchem(b *testing.B) { 310 | b.Run("basic", func(b *testing.B) { 311 | for i := 0; i < b.N; i++ { 312 | dispatch.StructToJsonSchema("queryPlanner", "Plan a multi-step query", queryNode{}) 313 | } 314 | }) 315 | b.Run("generic", func(b *testing.B) { 316 | for i := 0; i < b.N; i++ { 317 | dispatch.StructToJsonSchemaGeneric[queryNode]("queryPlanner", "Plan a multi-step query") 318 | } 319 | }) 320 | } 321 | -------------------------------------------------------------------------------- /packages/llm/providers/google/google.go: -------------------------------------------------------------------------------- 1 | package google 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "strings" 7 | "time" 8 | 9 | "github.com/cespare/xxhash/v2" 10 | "github.com/stillmatic/gollum/packages/llm" 11 | "google.golang.org/api/option" 12 | 13 | "github.com/google/generative-ai-go/genai" 14 | "github.com/pkg/errors" 15 | "google.golang.org/api/iterator" 16 | ) 17 | 18 | type Provider struct { 19 | client *genai.Client 20 | cachedFileMap map[string]string 21 | cachedContentMap map[string]struct{} 22 | } 23 | 24 | func NewGoogleProvider(ctx context.Context, apiKey string) (*Provider, error) { 25 | client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) 26 | if err != nil { 27 | return nil, errors.Wrap(err, "google client error") 28 | } 29 | 30 | // load cached content map 31 | p := &Provider{client: client} 32 | err = p.refreshCachedContentMap(ctx) 33 | if err != nil { 34 | return nil, errors.Wrap(err, "google refresh cached content map error") 35 | } 36 | 37 | return p, nil 38 | } 39 | 40 | func (p *Provider) refreshCachedFileMap(ctx context.Context) error { 41 | iter := p.client.ListFiles(ctx) 42 | cachedFileMap := make(map[string]string) 43 | for { 44 | cachedFile, err := iter.Next() 45 | if errors.Is(err, iterator.Done) { 46 | break 47 | } 48 | if err != nil { 49 | return errors.Wrap(err, "google list cached files error") 50 | } 51 | cachedFileMap[cachedFile.Name] = cachedFile.URI 52 | } 53 | p.cachedFileMap = cachedFileMap 54 | return nil 55 | } 56 | 57 | func (p *Provider) refreshCachedContentMap(ctx context.Context) error { 58 | iter := p.client.ListCachedContents(ctx) 59 | cachedContentMap := make(map[string]struct{}) 60 | for { 61 | cachedContent, err := iter.Next() 62 | if errors.Is(err, iterator.Done) { 63 | break 64 | } 65 | if err != nil { 66 | return errors.Wrap(err, "google list cached content error") 67 | } 68 | cachedContentMap[cachedContent.Name] = struct{}{} 69 | } 70 | p.cachedContentMap = cachedContentMap 71 | return nil 72 | } 73 | 74 | func getHash(value string) string { 75 | return string(xxhash.New().Sum([]byte(value))) 76 | } 77 | 78 | func getHashBytes(value []byte) string { 79 | return string(xxhash.New().Sum(value)) 80 | } 81 | 82 | func (p *Provider) uploadFile(ctx context.Context, key string, value string) (*genai.File, error) { 83 | // check if the file is already cached 84 | if _, ok := p.cachedFileMap[key]; ok { 85 | // if so, 
load the cached file and return it 86 | cachedFile, err := p.client.GetFile(ctx, key) 87 | if err != nil { 88 | return nil, errors.Wrap(err, "google get file error") 89 | } 90 | return cachedFile, nil 91 | } 92 | 93 | r := strings.NewReader(value) 94 | file, err := p.client.UploadFile(ctx, key, r, nil) 95 | if err != nil { 96 | return nil, errors.Wrap(err, "google upload file error") 97 | } 98 | p.cachedFileMap[key] = file.URI 99 | 100 | return file, nil 101 | } 102 | 103 | func (p *Provider) createCachedContent(ctx context.Context, value string, modelName string) (*genai.CachedContent, error) { 104 | key := getHash(value) 105 | 106 | // check if the content is already cached 107 | if _, ok := p.cachedContentMap[key]; ok { 108 | // if so, load the cached content and return it 109 | cachedContent, err := p.client.GetCachedContent(ctx, key) 110 | if err != nil { 111 | return nil, errors.Wrap(err, "google get cached content error") 112 | } 113 | return cachedContent, nil 114 | } 115 | 116 | file, err := p.uploadFile(ctx, key, value) 117 | if err != nil { 118 | return nil, errors.Wrap(err, "error uploading file") 119 | } 120 | fd := genai.FileData{URI: file.URI} 121 | cc := &genai.CachedContent{ 122 | Name: key, 123 | Model: modelName, 124 | Contents: []*genai.Content{genai.NewUserContent(fd)}, 125 | // TODO: make this configurable 126 | // maybe something like an optional field, 'ephemeral' / 'hour'? 127 | // default matches Anthropic's 5 minute TTL 128 | Expiration: genai.ExpireTimeOrTTL{TTL: 5 * time.Minute}, 129 | } 130 | content, err := p.client.CreateCachedContent(ctx, cc) 131 | if err != nil { 132 | return nil, errors.Wrap(err, "error creating cached content") 133 | } 134 | 135 | return content, nil 136 | } 137 | 138 | func (p *Provider) getModel(req llm.InferRequest) *genai.GenerativeModel { 139 | model := p.client.GenerativeModel(req.ModelConfig.ModelName) 140 | model.SetTemperature(req.MessageOptions.Temperature) 141 | model.SetMaxOutputTokens(int32(req.MessageOptions.MaxTokens)) 142 | // lol... 143 | model.SafetySettings = []*genai.SafetySetting{ 144 | {Category: genai.HarmCategoryHarassment, Threshold: genai.HarmBlockNone}, 145 | {Category: genai.HarmCategoryHateSpeech, Threshold: genai.HarmBlockNone}, 146 | {Category: genai.HarmCategorySexuallyExplicit, Threshold: genai.HarmBlockNone}, 147 | {Category: genai.HarmCategoryDangerousContent, Threshold: genai.HarmBlockNone}, 148 | } 149 | model.SetCandidateCount(1) 150 | return model 151 | } 152 | 153 | func (p *Provider) GenerateResponse(ctx context.Context, req llm.InferRequest) (string, error) { 154 | if len(req.Messages) > 1 { 155 | return p.generateResponseChat(ctx, req) 156 | } 157 | 158 | // it is slightly better to build a trie, indexed on hashes of each message 159 | // since we can quickly get based on the prefix (i.e. existing messages) 160 | // but ... your number of messages is probably not THAT high to justify the complexity. 161 | // NB: trie also not useful for images/audio 162 | messagesToCache := make([]llm.InferMessage, 0) 163 | for _, message := range req.Messages { 164 | if message.ShouldCache { 165 | messagesToCache = append(messagesToCache, message) 166 | } 167 | } 168 | model := p.getModel(req) 169 | if len(messagesToCache) > 0 { 170 | // hash the messages and check if the overall object is cached. 171 | // we choose to do this because you may have a later message identical to an earlier message 172 | // if we find exact match for this set of messages, load it. 
173 | hashKeys := make([]string, 0, len(messagesToCache)) 174 | for _, message := range messagesToCache { 175 | // it is possible to have collision between user + assistant content being identical 176 | // this feels like a rare case especially given that we are ordering sensitive in the hash. 177 | hashKeys = append(hashKeys, getHash(message.Content)) 178 | if len(message.Image) > 0 { 179 | hashKeys = append(hashKeys, getHashBytes(message.Image)) 180 | } 181 | if len(message.Audio) > 0 { 182 | hashKeys = append(hashKeys, getHashBytes(message.Audio)) 183 | } 184 | } 185 | joinedKey := strings.Join(hashKeys, "/") 186 | var cachedContent *genai.CachedContent 187 | // if the cached object exists, load it 188 | if _, ok := p.cachedContentMap[joinedKey]; ok { 189 | cachedContent, _ = p.client.GetCachedContent(ctx, joinedKey) 190 | model = p.client.GenerativeModelFromCachedContent(cachedContent) 191 | } else { 192 | // otherwise, create a new cached object 193 | cc, err := p.createCachedContent(ctx, joinedKey, req.ModelConfig.ModelName) 194 | if err != nil { 195 | return "", errors.Wrap(err, "google upload file error") 196 | } 197 | model = p.client.GenerativeModelFromCachedContent(cc) 198 | } 199 | } 200 | 201 | parts := singleTurnMessageToParts(req.Messages[0]) 202 | 203 | resp, err := model.GenerateContent(ctx, parts...) 204 | if err != nil { 205 | return "", errors.Wrap(err, "google generate content error") 206 | } 207 | respStr := flattenResponse(resp) 208 | 209 | return respStr, nil 210 | } 211 | 212 | func singleTurnMessageToParts(message llm.InferMessage) []genai.Part { 213 | parts := []genai.Part{genai.Text(message.Content)} 214 | if len(message.Image) > 0 { 215 | // TODO: set the image type based on the actual image type 216 | parts = append(parts, genai.ImageData("png", message.Image)) 217 | } 218 | // if len(message.Audio) > 0 { 219 | // // TODO: set the audio type based on the actual audio type 220 | // parts = append(parts, genai.("wav", message.Audio)) 221 | // } 222 | 223 | return parts 224 | } 225 | 226 | func multiTurnMessageToParts(messages []llm.InferMessage) ([]*genai.Content, *genai.Content) { 227 | sysInstructionParts := make([]genai.Part, 0) 228 | hist := make([]*genai.Content, 0, len(messages)) 229 | for _, message := range messages { 230 | parts := []genai.Part{genai.Text(message.Content)} 231 | if len(message.Image) > 0 { 232 | parts = append(parts, genai.ImageData("png", message.Image)) 233 | } 234 | if message.Role == "system" { 235 | sysInstructionParts = append(sysInstructionParts, parts...) 236 | continue 237 | } 238 | hist = append(hist, &genai.Content{ 239 | Parts: parts, 240 | Role: message.Role, 241 | }) 242 | } 243 | if len(sysInstructionParts) > 0 { 244 | return hist, &genai.Content{ 245 | Parts: sysInstructionParts, 246 | } 247 | } 248 | 249 | return hist, nil 250 | } 251 | 252 | func (p *Provider) generateResponseChat(ctx context.Context, req llm.InferRequest) (string, error) { 253 | model := p.getModel(req) 254 | 255 | // annoyingly, the last message is the one we want to generate a response to, so we need to split it out 256 | msgs, sysInstr := multiTurnMessageToParts(req.Messages[:len(req.Messages)-1]) 257 | if sysInstr != nil { 258 | model.SystemInstruction = sysInstr 259 | } 260 | 261 | cs := model.StartChat() 262 | cs.History = msgs 263 | mostRecentMessage := req.Messages[len(req.Messages)-1] 264 | 265 | // NB chua: this might be a bug but Google doesn't seem to accept multiple parts in the same message 266 | // in the chat API. 
So can't send text + image if it exists. 267 | //mostRecentMessagePart := []genai.Part{genai.Text(mostRecentMessage.Content)} 268 | //if mostRecentMessage.Image != nil && len(*mostRecentMessage.Image) > 0 { 269 | // mostRecentMessagePart = append(mostRecentMessagePart, genai.ImageData("png", *mostRecentMessage.Image)) 270 | //} 271 | 272 | resp, err := cs.SendMessage(ctx, genai.Text(mostRecentMessage.Content)) 273 | if err != nil { 274 | return "", errors.Wrap(err, "google generate content error") 275 | } 276 | respStr := flattenResponse(resp) 277 | 278 | return respStr, nil 279 | } 280 | 281 | func (p *Provider) GenerateResponseAsync(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) { 282 | if len(req.Messages) > 1 { 283 | return p.generateResponseAsyncChat(ctx, req) 284 | } 285 | return p.generateResponseAsyncSingle(ctx, req) 286 | } 287 | 288 | func (p *Provider) generateResponseAsyncSingle(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) { 289 | outChan := make(chan llm.StreamDelta) 290 | 291 | go func() { 292 | defer close(outChan) 293 | 294 | model := p.getModel(req) 295 | 296 | parts := singleTurnMessageToParts(req.Messages[0]) 297 | iter := model.GenerateContentStream(ctx, parts...) 298 | 299 | for { 300 | resp, err := iter.Next() 301 | if errors.Is(err, iterator.Done) { 302 | outChan <- llm.StreamDelta{EOF: true} 303 | break 304 | } 305 | if err != nil { 306 | slog.Error("error from gemini stream", "err", err, "req", req.Messages[0].Content, "model", req.ModelConfig.ModelName) 307 | return 308 | } 309 | 310 | content := flattenResponse(resp) 311 | if content != "" { 312 | select { 313 | case <-ctx.Done(): 314 | return 315 | case outChan <- llm.StreamDelta{Text: content}: 316 | } 317 | } 318 | } 319 | }() 320 | 321 | return outChan, nil 322 | } 323 | 324 | func (p *Provider) generateResponseAsyncChat(ctx context.Context, req llm.InferRequest) (<-chan llm.StreamDelta, error) { 325 | outChan := make(chan llm.StreamDelta) 326 | 327 | go func() { 328 | defer close(outChan) 329 | 330 | model := p.getModel(req) 331 | msgs, sysInstr := multiTurnMessageToParts(req.Messages[:len(req.Messages)-1]) 332 | if sysInstr != nil { 333 | model.SystemInstruction = sysInstr 334 | } 335 | cs := model.StartChat() 336 | cs.History = msgs 337 | 338 | mostRecentMessage := req.Messages[len(req.Messages)-1] 339 | 340 | iter := cs.SendMessageStream(ctx, genai.Text(mostRecentMessage.Content)) 341 | 342 | for { 343 | resp, err := iter.Next() 344 | if errors.Is(err, iterator.Done) { 345 | outChan <- llm.StreamDelta{EOF: true} 346 | break 347 | } 348 | if err != nil { 349 | slog.Error("error from gemini stream", "err", err, "req", mostRecentMessage.Content, "model", req.ModelConfig.ModelName) 350 | return 351 | } 352 | 353 | content := flattenResponse(resp) 354 | if content != "" { 355 | select { 356 | case <-ctx.Done(): 357 | return 358 | case outChan <- llm.StreamDelta{Text: content}: 359 | } 360 | } 361 | } 362 | }() 363 | 364 | return outChan, nil 365 | } 366 | 367 | // flattenResponse flattens the response from the Gemini API into a single string. 
368 | func flattenResponse(resp *genai.GenerateContentResponse) string { 369 | var rtn strings.Builder 370 | for i, part := range resp.Candidates[0].Content.Parts { 371 | switch part := part.(type) { 372 | case genai.Text: 373 | if i > 0 { 374 | rtn.WriteString(" ") 375 | } 376 | rtn.WriteString(string(part)) 377 | } 378 | } 379 | return rtn.String() 380 | } 381 | 382 | // GenerateEmbedding generates embeddings for the given input. 383 | // 384 | // NB chua: This is a confusing method in the docs. 385 | // - There are two separate API methods and it's unclear which you should use. Is batch with 1 the same as single? 386 | // - What's the maximum number of docs to embed at once? 387 | // - TaskType is automatically set..? I don't see how to configure it ...? 388 | // see also https://pkg.go.dev/github.com/google/generative-ai-go/genai#TaskType 389 | func (p *Provider) GenerateEmbedding(ctx context.Context, req llm.EmbedRequest) (*llm.EmbeddingResponse, error) { 390 | em := p.client.EmbeddingModel(req.ModelConfig.ModelName) 391 | 392 | // if there is only one input, use the single API 393 | if len(req.Input) == 1 { 394 | resp, err := em.EmbedContent(ctx, genai.Text(req.Input[0])) 395 | if err != nil { 396 | return nil, errors.Wrap(err, "google embedding error") 397 | } 398 | 399 | return &llm.EmbeddingResponse{Data: []llm.Embedding{{Values: resp.Embedding.Values}}}, nil 400 | } 401 | // otherwise, use the batch API. I'm not sure there's much difference though... 402 | batchReq := em.NewBatch() 403 | for _, input := range req.Input { 404 | batchReq.AddContent(genai.Text(input)) 405 | } 406 | resp, err := em.BatchEmbedContents(ctx, batchReq) 407 | if err != nil { 408 | return nil, errors.Wrap(err, "google batch embedding error") 409 | } 410 | 411 | respVectors := make([]llm.Embedding, len(resp.Embeddings)) 412 | for i, v := range resp.Embeddings { 413 | respVectors[i] = llm.Embedding{ 414 | Values: v.Values, 415 | } 416 | } 417 | 418 | return &llm.EmbeddingResponse{ 419 | Data: respVectors, 420 | }, nil 421 | } 422 | --------------------------------------------------------------------------------
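The `GenerateEmbedding` comment above notes the ambiguity between the Gemini SDK's single and batch embedding calls; the provider picks one automatically based on the number of inputs. For reference, here is a hypothetical caller-side sketch using only names visible in this dump (`google.NewGoogleProvider`, `llm.EmbedRequest`, `llm.ModelConfig`, `resp.Data[0].Values`); the environment variable name and the error handling are illustrative assumptions, not part of the library.

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/stillmatic/gollum/packages/llm"
	"github.com/stillmatic/gollum/packages/llm/providers/google"
)

func main() {
	ctx := context.Background()
	// GEMINI_API_KEY is an assumed env var name; use whatever your setup provides.
	provider, err := google.NewGoogleProvider(ctx, os.Getenv("GEMINI_API_KEY"))
	if err != nil {
		panic(err)
	}
	// One input uses the single EmbedContent path; multiple inputs use the batch API.
	resp, err := provider.GenerateEmbedding(ctx, llm.EmbedRequest{
		Input: []string{"tiny flying robot", "gold medal in the 400 IM"},
		ModelConfig: llm.ModelConfig{
			ProviderType: llm.ProviderGoogle,
			ModelName:    "text-embedding-004",
			ModelType:    llm.ModelTypeEmbedding,
		},
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(len(resp.Data), len(resp.Data[0].Values))
}
```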