├── .dockerignore ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── Taskfile.llama.yml ├── Taskfile.nomic.yml ├── Taskfile.qdrant.yml ├── Taskfile.reranker.yml ├── Taskfile.tts.yml ├── Taskfile.unstructured.yml ├── Taskfile.whisper.yml ├── Taskfile.yml ├── cmd ├── client │ └── main.go ├── ingest │ └── main.go └── server │ └── main.go ├── config ├── config.go ├── config_api.go ├── config_authorizer.go ├── config_chain.go ├── config_completer.go ├── config_embedder.go ├── config_extractor.go ├── config_index.go ├── config_message.go ├── config_model.go ├── config_provider.go ├── config_proxy.go ├── config_renderer.go ├── config_reranker.go ├── config_router.go ├── config_segmenter.go ├── config_summarizer.go ├── config_synthesizer.go ├── config_template.go ├── config_tool.go ├── config_transcriber.go └── config_translator.go ├── docs ├── architecture.png ├── dashboard.json └── icon.png ├── examples ├── anthropic-chat │ ├── README.md │ └── compose.yaml ├── custom-completer │ ├── golang │ │ └── main.go │ └── python │ │ ├── README.md │ │ ├── main.py │ │ ├── provider.proto │ │ ├── provider_pb2.py │ │ ├── provider_pb2.pyi │ │ └── provider_pb2_grpc.py ├── custom-tool │ ├── golang │ │ └── main.go │ └── python │ │ ├── README.md │ │ ├── main.py │ │ ├── tool.proto │ │ ├── tool_pb2.py │ │ ├── tool_pb2.pyi │ │ └── tool_pb2_grpc.py ├── local-chat │ ├── README.md │ └── compose.yaml ├── local-jupyter │ ├── .gitignore │ ├── Dockerfile │ ├── README.md │ ├── compose.yaml │ └── work │ │ ├── chat.ipynb │ │ └── jupyter_ai.ipynb ├── local-openwebui │ ├── .env │ └── compose.yaml ├── local-rag │ ├── README.md │ └── compose.yaml └── openai-chat │ ├── README.md │ └── compose.yaml ├── go.mod ├── go.sum ├── pkg ├── api │ ├── api.go │ └── json │ │ ├── config.go │ │ └── handler.go ├── authorizer │ ├── oidc │ │ └── oidc.go │ ├── provider.go │ └── static │ │ └── static.go ├── chain │ ├── agent │ │ └── chain.go │ ├── assistant │ │ └── chain.go │ ├── chain.go │ └── rag │ │ ├── chain.go │ │ 
├── prompt.go │ │ └── prompt.tmpl ├── client │ ├── client.go │ ├── client_completions.go │ ├── client_documents.go │ ├── client_embeddings.go │ ├── client_extractions.go │ ├── client_models.go │ ├── client_renderings.go │ ├── client_reranks.go │ ├── client_segments.go │ ├── client_summaries.go │ ├── client_synthesis.go │ ├── client_transcriptions.go │ └── option.go ├── extractor │ ├── azure │ │ ├── client.go │ │ ├── config.go │ │ └── model.go │ ├── custom │ │ ├── Taskfile.yaml │ │ ├── client.go │ │ ├── config.go │ │ ├── extractor.pb.go │ │ ├── extractor.proto │ │ └── extractor_grpc.pb.go │ ├── exa │ │ ├── client.go │ │ ├── config.go │ │ └── models.go │ ├── extractor.go │ ├── jina │ │ ├── client.go │ │ ├── client_test.go │ │ └── config.go │ ├── multi │ │ └── multi.go │ ├── tavily │ │ ├── client.go │ │ ├── config.go │ │ ├── models.go │ │ └── utils.go │ ├── text │ │ ├── config.go │ │ └── text.go │ ├── tika │ │ ├── client.go │ │ ├── client_test.go │ │ ├── config.go │ │ └── model.go │ └── unstructured │ │ ├── client.go │ │ ├── client_test.go │ │ ├── config.go │ │ └── model.go ├── index │ ├── azure │ │ ├── client.go │ │ ├── client_delete.go │ │ ├── client_index.go │ │ ├── client_list.go │ │ ├── client_query.go │ │ ├── client_test.go │ │ ├── config.go │ │ └── models.go │ ├── bing │ │ ├── client.go │ │ ├── config.go │ │ └── models.go │ ├── chroma │ │ ├── client.go │ │ ├── client_test.go │ │ ├── config.go │ │ └── models.go │ ├── custom │ │ ├── Taskfile.yaml │ │ ├── client.go │ │ ├── index.pb.go │ │ ├── index.proto │ │ └── index_grpc.pb.go │ ├── duckduckgo │ │ ├── client.go │ │ ├── client_test.go │ │ └── config.go │ ├── elasticsearch │ │ ├── client.go │ │ ├── client_test.go │ │ ├── config.go │ │ └── models.go │ ├── exa │ │ ├── client.go │ │ ├── config.go │ │ └── models.go │ ├── index.go │ ├── memory │ │ ├── client.go │ │ ├── client_test.go │ │ └── config.go │ ├── postgrest │ │ ├── client.go │ │ ├── client_delete.go │ │ ├── client_index.go │ │ ├── client_list.go │ │ ├── 
client_query.go │ │ ├── client_test.go │ │ ├── config.go │ │ └── types.go │ ├── qdrant │ │ ├── client.go │ │ ├── client_test.go │ │ ├── config.go │ │ └── models.go │ ├── searxng │ │ ├── client.go │ │ ├── config.go │ │ └── models.go │ ├── tavily │ │ ├── client.go │ │ ├── config.go │ │ ├── models.go │ │ └── utils.go │ └── weaviate │ │ ├── client.go │ │ ├── client_test.go │ │ ├── config.go │ │ ├── models.go │ │ ├── query.go │ │ └── query.tmpl ├── limiter │ ├── limiter.go │ ├── provider_chain.go │ ├── provider_completer.go │ ├── provider_embedder.go │ ├── provider_extractor.go │ ├── provider_renderer.go │ ├── provider_reranker.go │ ├── provider_segmenter.go │ ├── provider_synthesizer.go │ ├── provider_transcriber.go │ └── provider_translator.go ├── otel │ ├── otel.go │ ├── otel_http.go │ ├── otel_logger.go │ ├── otel_meter.go │ ├── otel_propagator.go │ ├── otel_tracer.go │ ├── otel_utils.go │ ├── provider.go │ ├── provider_chain.go │ ├── provider_completer.go │ ├── provider_embedder.go │ ├── provider_extractor.go │ ├── provider_index.go │ ├── provider_renderer.go │ ├── provider_reranker.go │ ├── provider_segmenter.go │ ├── provider_synthesizer.go │ ├── provider_tool.go │ ├── provider_transcriber.go │ └── provider_translator.go ├── provider │ ├── adapter │ │ └── reranker │ │ │ └── adapter.go │ ├── anthropic │ │ ├── completer.go │ │ ├── config.go │ │ └── util.go │ ├── azure │ │ ├── completer.go │ │ ├── config.go │ │ └── embedder.go │ ├── bedrock │ │ ├── completer.go │ │ └── config.go │ ├── cohere │ │ ├── completer.go │ │ ├── config.go │ │ ├── embedder.go │ │ └── util.go │ ├── completer.go │ ├── custom │ │ ├── Taskfile.yaml │ │ ├── client.go │ │ ├── client_completer.go │ │ ├── client_embedder.go │ │ ├── provider.pb.go │ │ ├── provider.proto │ │ └── provider_grpc.pb.go │ ├── embedder.go │ ├── gemini │ │ ├── completer.go │ │ ├── config.go │ │ ├── embedder.go │ │ └── util.go │ ├── groq │ │ ├── completer.go │ │ ├── completer_test.go │ │ ├── config.go │ │ ├── transcriber.go │ 
│ └── transcriber_test.go │ ├── huggingface │ │ ├── completer.go │ │ ├── config.go │ │ ├── embedder.go │ │ ├── embedder_test.go │ │ ├── reranker.go │ │ ├── reranker_test.go │ │ └── util.go │ ├── jina │ │ ├── config.go │ │ ├── embedder.go │ │ ├── embedder_test.go │ │ ├── reranker.go │ │ ├── reranker_test.go │ │ └── util.go │ ├── llama │ │ ├── completer.go │ │ ├── completer_test.go │ │ ├── config.go │ │ ├── embedder.go │ │ ├── embedder_test.go │ │ ├── reranker.go │ │ └── reranker_test.go │ ├── mistral │ │ ├── completer.go │ │ ├── config.go │ │ └── util.go │ ├── mistralrs │ │ ├── completer.go │ │ ├── config.go │ │ └── embedder.go │ ├── ollama │ │ ├── completer.go │ │ ├── config.go │ │ └── embedder.go │ ├── openai │ │ ├── completer.go │ │ ├── config.go │ │ ├── embedder.go │ │ ├── renderer.go │ │ ├── synthesizer.go │ │ ├── transcriber.go │ │ └── util.go │ ├── provider.go │ ├── renderer.go │ ├── replicate │ │ ├── client.go │ │ ├── config.go │ │ └── flux │ │ │ └── renderer.go │ ├── reranker.go │ ├── synthesizer.go │ ├── transcriber.go │ ├── whisper │ │ ├── config.go │ │ ├── transcriber.go │ │ └── util.go │ └── xai │ │ ├── completer.go │ │ └── config.go ├── router │ └── roundrobin │ │ └── router.go ├── segmenter │ ├── jina │ │ ├── client.go │ │ ├── client_test.go │ │ ├── config.go │ │ └── model.go │ ├── segmenter.go │ ├── text │ │ ├── text.go │ │ └── text_separators.go │ └── unstructured │ │ ├── client.go │ │ ├── client_test.go │ │ ├── config.go │ │ └── model.go ├── summarizer │ ├── adapter │ │ └── adapter.go │ ├── custom │ │ ├── Taskfile.yaml │ │ ├── client.go │ │ ├── config.go │ │ ├── summarizer.pb.go │ │ ├── summarizer.proto │ │ └── summarizer_grpc.pb.go │ └── summarizer.go ├── template │ ├── template.go │ ├── template_date.go │ ├── template_include.go │ └── template_message.go ├── text │ ├── normalize.go │ └── splitter.go ├── to │ └── to.go ├── tool │ ├── custom │ │ ├── Taskfile.yaml │ │ ├── client.go │ │ ├── config.go │ │ ├── tool.pb.go │ │ ├── tool.proto │ │ └── 
tool_grpc.pb.go │ ├── extract │ │ ├── client.go │ │ └── config.go │ ├── mcp │ │ ├── client.go │ │ └── config.go │ ├── render │ │ ├── client.go │ │ ├── config.go │ │ └── models.go │ ├── retrieve │ │ ├── client.go │ │ ├── config.go │ │ └── models.go │ ├── search │ │ ├── client.go │ │ ├── config.go │ │ └── models.go │ ├── synthesize │ │ ├── client.go │ │ ├── config.go │ │ └── models.go │ ├── tool.go │ └── translate │ │ ├── client.go │ │ ├── config.go │ │ └── models.go └── translator │ ├── azure │ ├── client.go │ ├── config.go │ └── util.go │ ├── custom │ ├── Taskfile.yaml │ ├── client.go │ ├── config.go │ ├── translator.pb.go │ ├── translator.proto │ └── translator_grpc.pb.go │ ├── deepl │ ├── client.go │ ├── config.go │ └── util.go │ ├── llm │ └── client.go │ └── translator.go ├── server ├── api │ ├── handler.go │ ├── handler_extract.go │ ├── handler_rerank.go │ ├── handler_segment.go │ ├── handler_summarize.go │ ├── handler_transcribe.go │ ├── handler_translate.go │ ├── handler_util.go │ └── models.go ├── index │ ├── handler.go │ ├── handler_delete.go │ ├── handler_index.go │ ├── handler_list.go │ ├── handler_query.go │ └── models.go ├── openai │ ├── handler.go │ ├── handler_audio_speach.go │ ├── handler_audio_transcription.go │ ├── handler_chat_completion.go │ ├── handler_embeddings.go │ ├── handler_image_edit.go.go │ ├── handler_image_generation.go │ ├── handler_models.go │ └── models.go ├── server.go └── unstructured │ ├── handler.go │ ├── handler_partition.go │ └── models.go └── test ├── test.go └── test_index.go /.dockerignore: -------------------------------------------------------------------------------- 1 | /models -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /models 2 | /config.yaml 3 | 4 | .env 5 | .vscode 6 | .DS_Store 7 | -------------------------------------------------------------------------------- /Dockerfile: 
-------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | 3 | FROM golang:1-alpine AS build 4 | 5 | WORKDIR /src 6 | 7 | COPY go.* ./ 8 | RUN go mod download 9 | 10 | COPY . . 11 | RUN CGO_ENABLED=0 go build -o /server /src/cmd/server 12 | RUN CGO_ENABLED=0 go build -o /client /src/cmd/client 13 | RUN CGO_ENABLED=0 go build -o /ingest /src/cmd/ingest 14 | 15 | 16 | FROM alpine 17 | 18 | RUN apk add --no-cache tini ca-certificates mailcap 19 | 20 | COPY --from=build /server /client /ingest / 21 | 22 | EXPOSE 8080 23 | 24 | ENTRYPOINT ["/sbin/tini", "--"] 25 | CMD ["/server"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Adrian Liechti 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Taskfile.llama.yml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | tasks: 6 | server: 7 | deps: [ download-model ] 8 | cmds: 9 | - llama-server 10 | --port 9081 11 | --log-disable 12 | --ctx-size 32768 13 | --flash-attn 14 | --model ./models/llama-3.2-3b-instruct.gguf 15 | 16 | download-model: 17 | cmds: 18 | - mkdir -p models 19 | - curl -s -L -o models/llama-3.2-3b-instruct.gguf https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_K_L.gguf?download=true 20 | 21 | status: 22 | - test -f models/llama-3.2-3b-instruct.gguf 23 | 24 | test: 25 | cmds: 26 | - | 27 | curl http://localhost:9081/v1/chat/completions \ 28 | -H "Content-Type: application/json" \ 29 | -d '{ 30 | "model": "llama", 31 | "messages": [ 32 | { 33 | "role": "user", 34 | "content": "Hello!" 35 | } 36 | ] 37 | }' -------------------------------------------------------------------------------- /Taskfile.nomic.yml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | tasks: 6 | server: 7 | deps: [ download-model ] 8 | cmds: 9 | - llama-server 10 | --port 9082 11 | --log-disable 12 | --embedding 13 | --ctx-size 8192 14 | --batch-size 8192 15 | --rope-scaling yarn 16 | --rope-freq-scale .75 17 | --model ./models/nomic-embed-text-v1.5.gguf 18 | 19 | download-model: 20 | cmds: 21 | - mkdir -p models 22 | - curl -s -L -o models/nomic-embed-text-v1.5.gguf https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF/resolve/main/nomic-embed-text-v1.5.f16.gguf?download=true 23 | 24 | status: 25 | - test -f models/nomic-embed-text-v1.5.gguf 26 | 27 | test: 28 | cmds: 29 | - | 30 | curl http://localhost:9082/v1/embeddings \ 31 | -H "Content-Type: application/json" \ 32 | -d '{ 33 | "model": "nomic-embed-text", 34 | 
"input": "Hello!" 35 | }' -------------------------------------------------------------------------------- /Taskfile.qdrant.yml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | tasks: 6 | server: 7 | cmds: 8 | - docker run -it --rm -p 6333:6333 -v qdrant-data:/qdrant/storage qdrant/qdrant:v1.14.0 9 | 10 | webui: 11 | cmds: 12 | - open http://localhost:6333/dashboard 13 | -------------------------------------------------------------------------------- /Taskfile.reranker.yml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | tasks: 6 | server: 7 | deps: [ download-model ] 8 | cmds: 9 | - llama-server 10 | --port 9082 11 | --log-disable 12 | --reranking 13 | --model ./models/bge-reranker-v2-m3.gguf 14 | 15 | download-model: 16 | cmds: 17 | - mkdir -p models 18 | - curl -s -L -o models/bge-reranker-v2-m3.gguf https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF/resolve/main/bge-reranker-v2-m3-Q4_K_M.gguf?download=true 19 | 20 | status: 21 | - test -f models/bge-reranker-v2-m3.gguf 22 | 23 | test: 24 | cmds: 25 | - | 26 | curl http://localhost:9082/v1/rerank \ 27 | -H "Content-Type: application/json" \ 28 | -d '{ 29 | "model": "bge-reranker-v2-m3", 30 | "query": "What is panda?", 31 | "top_n": 3, 32 | "documents": [ 33 | "hi", 34 | "it is a bear", 35 | "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." 
36 | ] 37 | }' -------------------------------------------------------------------------------- /Taskfile.tts.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # https://taskfile.dev 4 | 5 | version: "3" 6 | 7 | tasks: 8 | server: 9 | cmds: 10 | - docker run -it --rm -p 9085:8000 -v openedai-data:/app/voices ghcr.io/matatonic/openedai-speech-min -------------------------------------------------------------------------------- /Taskfile.unstructured.yml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | tasks: 6 | server: 7 | cmds: 8 | - docker run -it -p 9085:8000 -v unstructured-cache:/home/notebook-user/.cache quay.io/unstructured-io/unstructured-api:0.0.80 9 | -------------------------------------------------------------------------------- /Taskfile.whisper.yml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | tasks: 6 | server: 7 | deps: [ download-model ] 8 | cmds: 9 | - whisper-server 10 | --port 9083 11 | --convert 12 | --model ./models/whisper-large-v3-turbo.bin 13 | 14 | download-model: 15 | cmds: 16 | - mkdir -p models 17 | - curl -s -L -o models/whisper-large-v3-turbo.bin https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo-q5_0.bin?download=true 18 | 19 | status: 20 | - test -f models/whisper-large-v3-turbo.bin 21 | 22 | test: 23 | cmds: 24 | - curl -Lo jfk.wav https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav 25 | - | 26 | curl http://localhost:9083/inference \ 27 | -H "Content-Type: multipart/form-data" \ 28 | -F file="@jfk.wav" \ 29 | -F response_format="json" 30 | - rm jfk.wav -------------------------------------------------------------------------------- /Taskfile.yml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | 
version: "3" 4 | 5 | vars: 6 | REPOSITORY: ghcr.io/adrianliechti/wingman-platform 7 | #REPOSITORY: ghcr.io/adrianliechti/wingman-platform:nightly 8 | 9 | includes: 10 | llama: 11 | taskfile: ./Taskfile.llama.yml 12 | 13 | nomic: 14 | taskfile: ./Taskfile.nomic.yml 15 | 16 | reranker: 17 | taskfile: ./Taskfile.reranker.yml 18 | 19 | whisper: 20 | taskfile: ./Taskfile.whisper.yml 21 | 22 | tts: 23 | taskfile: ./Taskfile.tts.yml 24 | 25 | unstructured: 26 | taskfile: ./Taskfile.unstructured.yml 27 | 28 | qdrant: 29 | taskfile: ./Taskfile.qdrant.yml 30 | 31 | tasks: 32 | publish: 33 | cmds: 34 | - docker buildx build . --push --platform linux/amd64,linux/arm64 --tag {{.REPOSITORY}} 35 | 36 | server: 37 | dotenv: ['.env' ] 38 | 39 | cmds: 40 | - go run cmd/server/main.go 41 | 42 | client: 43 | cmds: 44 | - go run cmd/client/main.go 45 | 46 | webui: 47 | cmds: 48 | - docker run -it --rm --pull always -p 8000:8000 -e OPENAI_BASE_URL=http://host.docker.internal:8080/v1 ghcr.io/adrianliechti/wingman-chat 49 | 50 | start: 51 | cmds: 52 | - docker run --name wingman -d --rm --pull always -p 8080:8080 --env-file .env -v ./config.yaml:/config.yaml -v ./prompts:/prompts ghcr.io/adrianliechti/wingman 53 | 54 | otel: 55 | cmds: 56 | - docker run --rm -it -p 3000:3000 -p 4317:4317 -p 4318:4318 grafana/otel-lgtm 57 | 58 | stop: 59 | cmds: 60 | - docker rm wingman -f -------------------------------------------------------------------------------- /cmd/server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "net" 6 | "strconv" 7 | 8 | "github.com/adrianliechti/wingman/config" 9 | "github.com/adrianliechti/wingman/server" 10 | 11 | "github.com/adrianliechti/wingman/pkg/otel" 12 | ) 13 | 14 | // main parses the command-line flags, loads the configuration file and 15 | // serves requests until ListenAndServe returns an error. 16 | func main() { 17 | portFlag := flag.Int("port", 8080, "server port") 18 | addressFlag := flag.String("address", "", "server address") 19 | configFlag := flag.String("config", "config.yaml", "configuration path") 20 | 21 | 
flag.Parse() 22 | 23 | cfg, err := config.Parse(*configFlag) 24 | 25 | if err != nil { 26 | panic(err) 27 | } 28 | 29 | // net.JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:8080"), which a 30 | // plain "%s:%d" format would turn into an invalid listen address. 31 | cfg.Address = net.JoinHostPort(*addressFlag, strconv.Itoa(*portFlag)) 32 | 33 | s, err := server.New(cfg) 34 | 35 | if err != nil { 36 | panic(err) 37 | } 38 | 39 | // NOTE(review): "llama" looks like a legacy service name; confirm dashboards before renaming. 40 | if err := otel.Setup("llama", "0.0.1"); err != nil { 41 | panic(err) 42 | } 43 | 44 | if err := s.ListenAndServe(); err != nil { 45 | panic(err) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /config/config_authorizer.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/authorizer" 8 | "github.com/adrianliechti/wingman/pkg/authorizer/oidc" 9 | "github.com/adrianliechti/wingman/pkg/authorizer/static" 10 | ) 11 | 12 | // authorizerConfig holds the settings of one configured authorizer: Token is 13 | // used by the "static" type, Issuer and Audience by the "oidc" type. 14 | type authorizerConfig struct { 15 | Type string `yaml:"type"` 16 | 17 | Token string `yaml:"token"` 18 | 19 | Issuer string `yaml:"issuer"` 20 | Audience string `yaml:"audience"` 21 | } 22 | 23 | // registerAuthorizer builds every authorizer listed in f and appends it to 24 | // c.Authorizers, returning the first construction error. 25 | func (c *Config) registerAuthorizer(f *configFile) error { 26 | for _, a := range f.Authorizers { 27 | // Named p so the local does not shadow the imported authorizer package. 28 | p, err := createAuthorizer(a) 29 | 30 | if err != nil { 31 | return err 32 | } 33 | 34 | c.Authorizers = append(c.Authorizers, p) 35 | } 36 | 37 | return nil 38 | } 39 | 40 | // createAuthorizer selects the implementation by cfg.Type (case-insensitive). 41 | func createAuthorizer(cfg authorizerConfig) (authorizer.Provider, error) { 42 | switch strings.ToLower(cfg.Type) { 43 | case "static": 44 | return staticAuthorizer(cfg) 45 | 46 | case "oidc": 47 | return oidcAuthorizer(cfg) 48 | 49 | default: 50 | return nil, errors.New("invalid authorizer type: " + cfg.Type) 51 | } 52 | } 53 | 54 | func staticAuthorizer(cfg authorizerConfig) (authorizer.Provider, error) { 55 | return static.New(cfg.Token) 56 | } 57 | 58 | func oidcAuthorizer(cfg authorizerConfig) (authorizer.Provider, error) { 59 | return oidc.New(cfg.Issuer, cfg.Audience) 60 | } 61 | 
-------------------------------------------------------------------------------- /config/config_message.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | ) 9 | 10 | // message is a chat message as written in the configuration file. 11 | type message struct { 12 | Role string `yaml:"role"` 13 | Content string `yaml:"content"` 14 | } 15 | 16 | // parseMessages converts configured messages into provider messages and stops 17 | // at the first invalid one. On success the result is non-nil, even when empty. 18 | func parseMessages(messages []message) ([]provider.Message, error) { 19 | result := make([]provider.Message, 0, len(messages)) 20 | 21 | for _, m := range messages { 22 | msg, err := parseMessage(m) 23 | 24 | if err != nil { 25 | return nil, err 26 | } 27 | 28 | result = append(result, *msg) 29 | } 30 | 31 | return result, nil 32 | } 33 | 34 | // parseMessage maps a configured message onto a provider message with a single 35 | // text part. The role is matched case-insensitively against the system, user 36 | // and assistant roles; any other value is rejected. 37 | func parseMessage(m message) (*provider.Message, error) { 38 | var role provider.MessageRole 39 | 40 | switch { 41 | case strings.EqualFold(m.Role, string(provider.MessageRoleSystem)): 42 | role = provider.MessageRoleSystem 43 | 44 | case strings.EqualFold(m.Role, string(provider.MessageRoleUser)): 45 | role = provider.MessageRoleUser 46 | 47 | case strings.EqualFold(m.Role, string(provider.MessageRoleAssistant)): 48 | role = provider.MessageRoleAssistant 49 | 50 | default: 51 | return nil, errors.New("invalid message role: " + m.Role) 52 | } 53 | 54 | return &provider.Message{ 55 | Role: role, 56 | 57 | Content: []provider.Content{ 58 | provider.TextContent(m.Content), 59 | }, 60 | }, nil 61 | } 62 | -------------------------------------------------------------------------------- /config/config_proxy.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | type proxyConfig struct { 4 | URL string `yaml:"url"` 5 | } 6 | -------------------------------------------------------------------------------- /config/config_synthesizer.go: 
-------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | "github.com/adrianliechti/wingman/pkg/provider/openai" 9 | ) 10 | 11 | // RegisterSynthesizer registers p under id and records id as a model. The 12 | // first synthesizer registered also becomes the default, stored under "". 13 | func (cfg *Config) RegisterSynthesizer(id string, p provider.Synthesizer) { 14 | cfg.RegisterModel(id) 15 | 16 | if cfg.synthesizer == nil { 17 | cfg.synthesizer = make(map[string]provider.Synthesizer) 18 | } 19 | 20 | if _, ok := cfg.synthesizer[""]; !ok { 21 | cfg.synthesizer[""] = p 22 | } 23 | 24 | cfg.synthesizer[id] = p 25 | } 26 | 27 | // Synthesizer returns the synthesizer registered under id; the empty id 28 | // selects the default. 29 | func (cfg *Config) Synthesizer(id string) (provider.Synthesizer, error) { 30 | // Indexing a nil map yields the zero value, so no nil guard is needed. 31 | if s, ok := cfg.synthesizer[id]; ok { 32 | return s, nil 33 | } 34 | 35 | return nil, errors.New("synthesizer not found: " + id) 36 | } 37 | 38 | // createSynthesizer builds a synthesizer for cfg.Type (case-insensitive); 39 | // only "openai" is supported. 40 | func createSynthesizer(cfg providerConfig, model modelContext) (provider.Synthesizer, error) { 41 | switch strings.ToLower(cfg.Type) { 42 | case "openai": 43 | return openaiSynthesizer(cfg, model) 44 | 45 | default: 46 | return nil, errors.New("invalid synthesizer type: " + cfg.Type) 47 | } 48 | } 49 | 50 | // openaiSynthesizer creates an OpenAI synthesizer for model.ID at cfg.URL, 51 | // passing cfg.Token through openai.WithToken when it is set. 52 | func openaiSynthesizer(cfg providerConfig, model modelContext) (provider.Synthesizer, error) { 53 | var options []openai.Option 54 | 55 | if cfg.Token != "" { 56 | options = append(options, openai.WithToken(cfg.Token)) 57 | } 58 | 59 | return openai.NewSynthesizer(cfg.URL, model.ID, options...) 
60 | } 61 | -------------------------------------------------------------------------------- /config/config_template.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "errors" 5 | "os" 6 | 7 | "github.com/adrianliechti/wingman/pkg/template" 8 | ) 9 | 10 | // parseTemplate parses the contents of the file at val when that file can be 11 | // read; otherwise val itself is used as the template text. Note that an 12 | // unreadable or mistyped path is therefore not reported but parsed verbatim. 13 | func parseTemplate(val string) (*template.Template, error) { 14 | if val == "" { 15 | return nil, errors.New("empty template") 16 | } 17 | 18 | if data, err := os.ReadFile(val); err == nil { 19 | return template.NewTemplate(string(data)) 20 | } 21 | 22 | return template.NewTemplate(val) 23 | } 24 | -------------------------------------------------------------------------------- /docs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adrianliechti/wingman/506347beaa1133a05370938a686041bfd1798b30/docs/architecture.png -------------------------------------------------------------------------------- /docs/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adrianliechti/wingman/506347beaa1133a05370938a686041bfd1798b30/docs/icon.png -------------------------------------------------------------------------------- /examples/anthropic-chat/README.md: -------------------------------------------------------------------------------- 1 | # Anthropic Adapter 2 | 3 | ```shell 4 | export ANTHROPIC_API_KEY=sk-ant-...... 5 | 6 | docker compose up --force-recreate --remove-orphans 7 | ``` 8 | 9 | open [localhost:8000](http://localhost:8000) in your favorite browser 10 | 11 | ## Completion API 12 | 13 | ```shell 14 | curl http://localhost:8080/v1/chat/completions \ 15 | -H "Content-Type: application/json" \ 16 | -d '{ 17 | "model": "claude-sonnet", 18 | "messages": [ 19 | { 20 | "role": "system", 21 | "content": "You are a helpful assistant." 22 | }, 23 | { 24 | "role": "user", 25 | "content": "Hello!" 
26 | } 27 | ] 28 | }' 29 | ``` 30 | 31 | ## Vision API 32 | 33 | ```shell 34 | curl http://localhost:8080/v1/chat/completions \ 35 | -H "Content-Type: application/json" \ 36 | -d '{ 37 | "model": "claude-sonnet", 38 | "messages": [ 39 | { 40 | "role": "user", 41 | "content": [ 42 | { 43 | "type": "text", 44 | "text": "What’s in this image?" 45 | }, 46 | { 47 | "type": "image_url", 48 | "image_url": { 49 | "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" 50 | } 51 | } 52 | ] 53 | } 54 | ] 55 | }' 56 | ``` -------------------------------------------------------------------------------- /examples/anthropic-chat/compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | platform: 3 | image: ghcr.io/adrianliechti/wingman 4 | pull_policy: always 5 | build: 6 | context: ../../ 7 | dockerfile: Dockerfile 8 | ports: 9 | - 8080:8080 10 | configs: 11 | - source: platform 12 | target: /config.yaml 13 | 14 | web: 15 | image: ghcr.io/adrianliechti/wingman-chat 16 | pull_policy: always 17 | ports: 18 | - 8000:8000 19 | environment: 20 | - OPENAI_BASE_URL=http://platform:8080/v1 21 | depends_on: 22 | - platform 23 | 24 | configs: 25 | platform: 26 | content: | 27 | providers: 28 | - type: anthropic 29 | token: ${ANTHROPIC_API_KEY} 30 | 31 | # https://docs.anthropic.com/en/docs/models-overview 32 | models: 33 | claude-sonnet: 34 | id: claude-3-5-sonnet-latest 35 | 36 | claude-haiku: 37 | id: claude-3-5-haiku-latest 38 | -------------------------------------------------------------------------------- /examples/custom-completer/golang/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "strings" 7 | "time" 8 | 9 | "github.com/adrianliechti/wingman/pkg/provider/custom" 10 | "github.com/google/uuid" 11 | 12 | 
"google.golang.org/grpc" 13 | ) 14 | 15 | // main serves the example completer over gRPC on localhost:50051. 16 | func main() { 17 | l, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", 50051)) 18 | 19 | if err != nil { 20 | panic(err) 21 | } 22 | 23 | s := grpc.NewServer() 24 | custom.RegisterCompleterServer(s, newServer()) 25 | 26 | if err := s.Serve(l); err != nil { 27 | panic(err) 28 | } 29 | } 30 | 31 | // server is a toy completer that streams a fixed reply. 32 | type server struct { 33 | custom.UnsafeCompleterServer 34 | } 35 | 36 | func newServer() *server { 37 | return &server{} 38 | } 39 | 40 | // Complete ignores the request and streams a canned reply: one delta per word, 41 | // 300ms apart, then a final completion carrying the whole text. It returns the 42 | // first Send error, so it stops streaming once the stream is broken. 43 | func (s *server) Complete(r *custom.CompleteRequest, stream grpc.ServerStreamingServer[custom.Completion]) error { 44 | text := "Please provide me more information about the topic." 45 | 46 | words := strings.Split(text, " ") 47 | 48 | for _, word := range words { 49 | content := word + " " 50 | 51 | time.Sleep(300 * time.Millisecond) 52 | 53 | if err := stream.Send(&custom.Completion{ 54 | Id: uuid.NewString(), 55 | Model: "test", 56 | 57 | Delta: &custom.Message{ 58 | Role: "assistant", 59 | 60 | Content: []*custom.Content{ 61 | { 62 | Text: &content, 63 | }, 64 | }, 65 | }, 66 | }); err != nil { 67 | return err 68 | } 69 | } 70 | 71 | return stream.Send(&custom.Completion{ 72 | Id: uuid.NewString(), 73 | Model: "test", 74 | 75 | Message: &custom.Message{ 76 | Role: "assistant", 77 | 78 | Content: []*custom.Content{ 79 | { 80 | Text: &text, 81 | }, 82 | }, 83 | }, 84 | }) 85 | } 86 | -------------------------------------------------------------------------------- /examples/custom-completer/python/README.md: -------------------------------------------------------------------------------- 1 | ### Generate gRPC Server & Messages 2 | 3 | ```shell 4 | pip install grpcio-tools grpcio-reflection 5 | ``` 6 | 7 | ```shell 8 | $ curl -Lo provider.proto https://raw.githubusercontent.com/adrianliechti/wingman/refs/heads/main/pkg/provider/custom/provider.proto 9 | $ python -m grpc_tools.protoc -I . --python_out=. --pyi_out=. --grpc_python_out=. provider.proto 10 | ``` 11 | 12 | ### Run this Tool 13 | 14 | ```shell 15 | $ python main.py 16 | > Tool Server started.
Listening on port 50051 17 | ``` 18 | 19 | ### Example Configuration 20 | 21 | ```yaml 22 | providers: 23 | - type: custom 24 | url: grpc://localhost:50051 25 | models: 26 | test: 27 | type: completer 28 | ``` -------------------------------------------------------------------------------- /examples/custom-tool/python/tool.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option go_package = "github.com/adrianliechti/wingman/pkg/tool/custom;custom"; 4 | 5 | package tool; 6 | 7 | service Tool { 8 | rpc Tools (ToolsRequest) returns (ToolsResponse) {} 9 | rpc Execute (ExecuteRequest) returns (ResultResponse) {} 10 | } 11 | 12 | message ToolsRequest { 13 | } 14 | 15 | message ToolsResponse { 16 | repeated Definition definitions = 1; 17 | } 18 | 19 | message Definition { 20 | string name = 1; 21 | string description = 2; 22 | 23 | string parameters = 3; 24 | } 25 | 26 | message ExecuteRequest { 27 | string name = 1; 28 | string parameters = 2; 29 | } 30 | 31 | message ResultResponse { 32 | string data = 1; 33 | } -------------------------------------------------------------------------------- /examples/local-chat/README.md: -------------------------------------------------------------------------------- 1 | # Local Chat 2 | 3 | ## Run Example 4 | - [Docker Desktop](https://www.docker.com/products/docker-desktop/) 5 | 6 | Start Example Application 7 | 8 | ```shell 9 | docker compose up --force-recreate --remove-orphans 10 | ``` 11 | 12 | ## Open Web UI 13 | 14 | ```shell 15 | $ open http://localhost:8000 16 | ``` 17 | 18 | ## Completion API 19 | 20 | The Completion API provides compatibility for the OpenAI API standard, allowing easier integrations into existing applications. 
(Documentation: https://platform.openai.com/docs/api-reference/chat/create) 21 | 22 | ```shell 23 | curl http://localhost:8080/v1/chat/completions \ 24 | -H "Content-Type: application/json" \ 25 | -d '{ 26 | "model": "llama", 27 | "messages": [ 28 | { 29 | "role": "system", 30 | "content": "You are a helpful assistant." 31 | }, 32 | { 33 | "role": "user", 34 | "content": "Hello!" 35 | } 36 | ] 37 | }' 38 | ``` 39 | 40 | ## Embedding API 41 | 42 | ```shell 43 | curl http://localhost:8080/v1/embeddings \ 44 | -H "Content-Type: application/json" \ 45 | -d '{ 46 | "model": "nomic-embed", 47 | "input": "Your text string goes here" 48 | }' 49 | ``` -------------------------------------------------------------------------------- /examples/local-chat/compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | platform: 3 | image: ghcr.io/adrianliechti/wingman 4 | pull_policy: always 5 | build: 6 | context: ../../ 7 | dockerfile: Dockerfile 8 | ports: 9 | - 8080:8080 10 | configs: 11 | - source: platform 12 | target: /config.yaml 13 | depends_on: 14 | - ollama 15 | - ollama-companion 16 | 17 | ollama: 18 | image: ollama/ollama:0.6.7 19 | pull_policy: always 20 | volumes: 21 | - ollama-data:/root/.ollama 22 | 23 | ollama-companion: 24 | image: ghcr.io/adrianliechti/ollama-companion 25 | pull_policy: always 26 | restart: on-failure 27 | environment: 28 | - OLLAMA_HOST=ollama:11434 29 | - OLLAMA_MODELS=llama3.2:1b,nomic-embed-text:v1.5 30 | 31 | web: 32 | image: ghcr.io/adrianliechti/wingman-chat 33 | pull_policy: always 34 | ports: 35 | - 8000:8000 36 | environment: 37 | - OPENAI_BASE_URL=http://platform:8080/v1 38 | depends_on: 39 | - platform 40 | 41 | configs: 42 | platform: 43 | content: | 44 | providers: 45 | - type: ollama 46 | url: http://ollama:11434 47 | 48 | # https://ollama.com/library 49 | models: 50 | llama: 51 | id: llama3.2:1b 52 | 53 | nomic-embed: 54 | id: nomic-embed-text:v1.5 55 | 56 | volumes: 57 |
ollama-data: -------------------------------------------------------------------------------- /examples/local-jupyter/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints -------------------------------------------------------------------------------- /examples/local-jupyter/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM quay.io/jupyter/minimal-notebook:2024-12-23 2 | 3 | RUN pip install openai jupyter-ai[all] 4 | 5 | CMD [ "start-notebook.sh", "--NotebookApp.token=''", "-ServerApp.root_dir=/home/jovyan/work" ] -------------------------------------------------------------------------------- /examples/local-jupyter/README.md: -------------------------------------------------------------------------------- 1 | # Local Chat using LangChain 2 | 3 | ## Run Example 4 | 5 | - [Docker Desktop](https://www.docker.com/products/docker-desktop/) 6 | 7 | Start Example Application 8 | 9 | ```shell 10 | docker compose up --force-recreate --remove-orphans 11 | ``` 12 | 13 | ## Open Jupyter UI 14 | 15 | ```shell 16 | $ open http://localhost:8888 17 | ``` -------------------------------------------------------------------------------- /examples/local-jupyter/work/chat.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "45e12094-010e-47f2-9993-8a206bb097ad", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import openai\n", 11 | "\n", 12 | "completion = openai.chat.completions.create(\n", 13 | " messages=[\n", 14 | " {\n", 15 | " \"role\": \"user\",\n", 16 | " \"content\": \"Say this is a test\",\n", 17 | " }\n", 18 | " ],\n", 19 | " model=\"llama\",\n", 20 | ")\n", 21 | "\n", 22 | "print(completion.choices[0].message.content)" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": 
"f519fdeb-92a3-4c64-88a9-485e105ed15a", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [] 32 | } 33 | ], 34 | "metadata": { 35 | "kernelspec": { 36 | "display_name": "Python 3 (ipykernel)", 37 | "language": "python", 38 | "name": "python3" 39 | }, 40 | "language_info": { 41 | "codemirror_mode": { 42 | "name": "ipython", 43 | "version": 3 44 | }, 45 | "file_extension": ".py", 46 | "mimetype": "text/x-python", 47 | "name": "python", 48 | "nbconvert_exporter": "python", 49 | "pygments_lexer": "ipython3", 50 | "version": "3.12.8" 51 | } 52 | }, 53 | "nbformat": 4, 54 | "nbformat_minor": 5 55 | } 56 | -------------------------------------------------------------------------------- /examples/local-jupyter/work/jupyter_ai.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "efd136aa-cac0-4e6e-be32-539b39b41aca", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext jupyter_ai_magics" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "37ad48b9-fd8f-4d8d-b273-37d19805ecc5", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "%%ai chatgpt\n", 21 | "write a program to calculate pi" 22 | ] 23 | } 24 | ], 25 | "metadata": { 26 | "kernelspec": { 27 | "display_name": "Python 3 (ipykernel)", 28 | "language": "python", 29 | "name": "python3" 30 | }, 31 | "language_info": { 32 | "codemirror_mode": { 33 | "name": "ipython", 34 | "version": 3 35 | }, 36 | "file_extension": ".py", 37 | "mimetype": "text/x-python", 38 | "name": "python", 39 | "nbconvert_exporter": "python", 40 | "pygments_lexer": "ipython3", 41 | "version": "3.12.8" 42 | } 43 | }, 44 | "nbformat": 4, 45 | "nbformat_minor": 5 46 | } 47 | -------------------------------------------------------------------------------- /examples/local-openwebui/.env: 
-------------------------------------------------------------------------------- 1 | WEBUI_URL=http://localhost:3000 2 | WEBUI_NAME=Platform Chat 3 | 4 | CHAT_MODEL=gpt-4o 5 | TASK_MODEL=gpt-4o-turbo 6 | EMBEDDING_MODEL=text-embedding-3-large 7 | STT_MODEL=whisper-1 8 | TTS_MODEL=tts-1 9 | TTS_VOICE=alloy 10 | IMAGE_MODEL=dall-e-3 11 | 12 | OPENAI_API_BASE=http://host.docker.internal:8080/v1 13 | OPENAI_API_KEY=- -------------------------------------------------------------------------------- /examples/local-rag/README.md: -------------------------------------------------------------------------------- 1 | # Local RAG 2 | 3 | ## Run Example 4 | - [Docker Desktop](https://www.docker.com/products/docker-desktop/) 5 | 6 | Start Example Application 7 | 8 | ```shell 9 | docker compose up --force-recreate --remove-orphans 10 | ``` 11 | 12 | ## Index Documents 13 | 14 | ```shell 15 | docker run -it --rm -v ./:/data -w /data --pull=always ghcr.io/adrianliechti/wingman /ingest -url http://host.docker.internal:8080 -token - 16 | ``` 17 | 18 | ## Verify Documents 19 | 20 | ``` 21 | open http://localhost:8080/v1/index/docs 22 | open http://localhost:6333/dashboard#/collections/docs 23 | ``` 24 | 25 | ## Open Web UI 26 | 27 | ```shell 28 | $ open http://localhost:8000 29 | ``` 30 | -------------------------------------------------------------------------------- /examples/openai-chat/compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | platform: 3 | image: ghcr.io/adrianliechti/wingman 4 | pull_policy: always 5 | build: 6 | context: ../../ 7 | dockerfile: Dockerfile 8 | ports: 9 | - 8080:8080 10 | configs: 11 | - source: platform 12 | target: /config.yaml 13 | 14 | web: 15 | image: ghcr.io/adrianliechti/wingman-chat 16 | pull_policy: always 17 | ports: 18 | - 8000:8000 19 | environment: 20 | - OPENAI_BASE_URL=http://platform:8080/v1 21 | depends_on: 22 | - platform 23 | 24 | configs: 25 | platform: 26 | content: | 27 | 
providers: 28 | - type: openai 29 | token: ${OPENAI_API_KEY} 30 | 31 | # https://platform.openai.com/docs/models 32 | models: 33 | - gpt-4o 34 | - gpt-4o-mini 35 | - text-embedding-3-small 36 | - text-embedding-3-large 37 | - whisper-1 38 | - dall-e-3 39 | - tts-1 40 | - tts-1-hd 41 | -------------------------------------------------------------------------------- /pkg/api/api.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | ) 8 | 9 | type Provider interface { 10 | http.Handler 11 | } 12 | 13 | type Schema = provider.Schema 14 | -------------------------------------------------------------------------------- /pkg/api/json/config.go: -------------------------------------------------------------------------------- 1 | package json 2 | 3 | import ( 4 | "github.com/adrianliechti/wingman/pkg/api" 5 | "github.com/adrianliechti/wingman/pkg/provider" 6 | ) 7 | 8 | type Option func(*Handler) 9 | 10 | func WithCompleter(p provider.Completer) Option { 11 | return func(c *Handler) { 12 | c.completer = p 13 | } 14 | } 15 | 16 | func WithInputSchema(schema api.Schema) Option { 17 | return func(c *Handler) { 18 | c.input = &schema 19 | } 20 | } 21 | 22 | func WithOutputSchema(schema api.Schema) Option { 23 | return func(c *Handler) { 24 | c.output = &schema 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /pkg/authorizer/oidc/oidc.go: -------------------------------------------------------------------------------- 1 | package oidc 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "strings" 8 | 9 | "github.com/coreos/go-oidc/v3/oidc" 10 | ) 11 | 12 | type Provider struct { 13 | provider *oidc.Provider 14 | verifier *oidc.IDTokenVerifier 15 | } 16 | 17 | func New(issuer, audience string) (*Provider, error) { 18 | cfg := &oidc.Config{ 19 | ClientID: audience, 20 | } 21 | 22 |
provider, err := oidc.NewProvider(context.Background(), issuer) 23 | 24 | if err != nil { 25 | return nil, err 26 | } 27 | 28 | verifier := provider.Verifier(cfg) 29 | 30 | return &Provider{ 31 | provider: provider, 32 | verifier: verifier, 33 | }, nil 34 | } 35 | 36 | func (p *Provider) Verify(ctx context.Context, r *http.Request) error { 37 | header := r.Header.Get("Authorization") 38 | 39 | if header == "" { 40 | return errors.New("missing authorization header") 41 | } 42 | 43 | if !strings.HasPrefix(header, "Bearer ") { 44 | return errors.New("invalid authorization header") 45 | } 46 | 47 | token := strings.TrimPrefix(header, "Bearer ") 48 | 49 | idtoken, err := p.verifier.Verify(ctx, token) 50 | 51 | if err != nil { 52 | return err 53 | } 54 | 55 | _ = idtoken 56 | return nil 57 | } 58 | -------------------------------------------------------------------------------- /pkg/authorizer/provider.go: -------------------------------------------------------------------------------- 1 | package authorizer 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | type Provider interface { 9 | Verify(ctx context.Context, r *http.Request) error 10 | } 11 | -------------------------------------------------------------------------------- /pkg/authorizer/static/static.go: -------------------------------------------------------------------------------- 1 | package static 2 | 3 | import ( 4 | "context" 5 | "crypto/subtle" 6 | "errors" 7 | "net/http" 8 | "strings" 9 | ) 10 | 11 | type Provider struct { 12 | token string 13 | } 14 | 15 | func New(token string) (*Provider, error) { 16 | return &Provider{ 17 | token: token, 18 | }, nil 19 | } 20 | 21 | // Verify checks for an "Authorization: Bearer <token>" header matching the 22 | // configured token. An empty configured token disables the check. 23 | func (p *Provider) Verify(ctx context.Context, r *http.Request) error { 24 | if p.token == "" { 25 | return nil 26 | } 27 | 28 | header := r.Header.Get("Authorization") 29 | 30 | if header == "" { 31 | return errors.New("missing authorization header") 32 | } 33 | 34 | if !strings.HasPrefix(header, "Bearer ") { 35 | return errors.New("invalid
authorization header") 36 | } 37 | 38 | token := strings.TrimPrefix(header, "Bearer ") 39 | 40 | // Tokens are secrets: compare them exactly (not case-insensitively) and in 41 | // constant time so the configured value cannot be probed via response timing. 42 | if subtle.ConstantTimeCompare([]byte(token), []byte(p.token)) != 1 { 43 | return errors.New("invalid token") 44 | } 45 | 46 | return nil 47 | } 48 | -------------------------------------------------------------------------------- /pkg/chain/chain.go: -------------------------------------------------------------------------------- 1 | package chain 2 | 3 | import ( 4 | "github.com/adrianliechti/wingman/pkg/provider" 5 | ) 6 | 7 | type Provider interface { 8 | provider.Completer 9 | } 10 | -------------------------------------------------------------------------------- /pkg/chain/rag/prompt.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | _ "embed" 5 | ) 6 | 7 | var ( 8 | //go:embed prompt.tmpl 9 | promptTemplate string 10 | ) 11 | 12 | type promptData struct { 13 | Input string 14 | Results []promptResult 15 | } 16 | 17 | type promptResult struct { 18 | Title string 19 | Source string 20 | Content string 21 | 22 | Metadata map[string]string 23 | } 24 | -------------------------------------------------------------------------------- /pkg/chain/rag/prompt.tmpl: -------------------------------------------------------------------------------- 1 | {{- if .Results -}} 2 | Use the provided documents to answer questions: 3 | {{ range .Results }} 4 | --- 5 | {{- if .Title }} 6 | Title: {{ .Title }} 7 | {{- end }} 8 | {{- if .Source }} 9 | Source: {{ .Source }} 10 | {{- end }} 11 | {{ .Content }} 12 | {{ end }} 13 | --- 14 | {{- end -}} 15 | 16 | Question: {{ .Input }} -------------------------------------------------------------------------------- /pkg/client/client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Client struct { 8 | Models ModelService 9 | 10 | Embeddings EmbeddingService 11 | Completions CompletionService 12 | 13
| Syntheses SynthesisService 14 | Transcriptions TranscriptionService 15 | Renderings RenderingService 16 | 17 | Segments SegmentService 18 | Extractions ExtractionService 19 | 20 | Documents DocumentService 21 | Summaries SummaryService 22 | } 23 | 24 | // New returns a Client whose services all send requests to the server at url. 25 | func New(url string, opts ...RequestOption) *Client { 26 | // Copy the options into a slice with len == cap so that neither the caller's 27 | // backing array nor this slice (shared by every service; e.g. EmbeddingService.New 28 | // does append(r.Options, opts...)) can be overwritten by a later append. 29 | opts = append(opts[:len(opts):len(opts)], WithURL(url)) 30 | opts = opts[:len(opts):len(opts)] 31 | 32 | return &Client{ 33 | Models: NewModelService(opts...), 34 | 35 | Embeddings: NewEmbeddingService(opts...), 36 | Completions: NewCompletionService(opts...), 37 | 38 | Syntheses: NewSynthesisService(opts...), 39 | Transcriptions: NewTranscriptionService(opts...), 40 | Renderings: NewRenderingService(opts...), 41 | 42 | Segments: NewSegmentService(opts...), 43 | Extractions: NewExtractionService(opts...), 44 | 45 | Documents: NewDocumentService(opts...), 46 | Summaries: NewSummaryService(opts...), 47 | } 48 | } 49 | 50 | func newRequestConfig(opts ...RequestOption) *RequestConfig { 51 | c := &RequestConfig{ 52 | Client: http.DefaultClient, 53 | } 54 | 55 | for _, opt := range opts { 56 | opt(c) 57 | } 58 | 59 | return c 60 | } 61 | 62 | func Ptr[T any](v T) *T { 63 | return &v 64 | } 65 | -------------------------------------------------------------------------------- /pkg/client/client_embeddings.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | "github.com/adrianliechti/wingman/pkg/provider/openai" 9 | ) 10 | 11 | type EmbeddingService struct { 12 | Options []RequestOption 13 | } 14 | 15 | func NewEmbeddingService(opts ...RequestOption) EmbeddingService { 16 | return EmbeddingService{ 17 | Options: opts, 18 | } 19 | } 20 | 21 | type Embedding = provider.Embedding 22 | 23 | type EmbeddingsRequest struct { 24 | Model string 25 | 26 | Texts []string 27 | } 28 | 29 | func (r *EmbeddingService) New(ctx context.Context, input EmbeddingsRequest, opts
...RequestOption) (*Embedding, error) { 30 | cfg := newRequestConfig(append(r.Options, opts...)...) 31 | url := strings.TrimRight(cfg.URL, "/") + "/v1/" 32 | 33 | options := []openai.Option{} 34 | 35 | if cfg.Token != "" { 36 | options = append(options, openai.WithToken(cfg.Token)) 37 | } 38 | 39 | if cfg.Client != nil { 40 | options = append(options, openai.WithClient(cfg.Client)) 41 | } 42 | 43 | p, err := openai.NewEmbedder(url, input.Model, options...) 44 | 45 | if err != nil { 46 | return nil, err 47 | } 48 | 49 | return p.Embed(ctx, input.Texts) 50 | } 51 | -------------------------------------------------------------------------------- /pkg/client/client_models.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "net/http" 8 | "strings" 9 | 10 | "github.com/adrianliechti/wingman/pkg/provider" 11 | "github.com/adrianliechti/wingman/server/openai" 12 | ) 13 | 14 | type ModelService struct { 15 | Options []RequestOption 16 | } 17 | 18 | func NewModelService(opts ...RequestOption) ModelService { 19 | return ModelService{ 20 | Options: opts, 21 | } 22 | } 23 | 24 | type Model = provider.Model 25 | 26 | // List returns the models reported by the server's /v1/models endpoint. 27 | func (r *ModelService) List(ctx context.Context, opts ...RequestOption) ([]Model, error) { 28 | c := newRequestConfig(append(r.Options, opts...)...)
29 | 30 | // trim a trailing slash like the other services, so "http://host/" also works 31 | url := strings.TrimRight(c.URL, "/") + "/v1/models" 32 | 33 | req, err := http.NewRequestWithContext(ctx, "GET", url, nil) 34 | 35 | if err != nil { 36 | return nil, err 37 | } 38 | 39 | req.Header.Set("Content-Type", "application/json") 40 | 41 | if c.Token != "" { 42 | req.Header.Set("Authorization", "Bearer "+c.Token) 43 | } 44 | 45 | resp, err := c.Client.Do(req) 46 | 47 | if err != nil { 48 | return nil, err 49 | } 50 | 51 | defer resp.Body.Close() 52 | 53 | if resp.StatusCode != http.StatusOK { 54 | return nil, errors.New(resp.Status) 55 | } 56 | 57 | var result openai.ModelList 58 | 59 | if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { 60 | return nil, err 61 | } 62 | 63 | var models []provider.Model 64 | 65 | for _, m := range result.Models { 66 | models = append(models, provider.Model{ 67 | ID: m.ID, 68 | }) 69 | } 70 | 71 | return models, nil 72 | } 73 | -------------------------------------------------------------------------------- /pkg/client/client_renderings.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | "github.com/adrianliechti/wingman/pkg/provider/openai" 9 | ) 10 | 11 | type RenderingService struct { 12 | Options []RequestOption 13 | } 14 | 15 | func NewRenderingService(opts ...RequestOption) RenderingService { 16 | return RenderingService{ 17 | Options: opts, 18 | } 19 | } 20 | 21 | type Rendering = provider.Rendering 22 | 23 | type RenderOptions = provider.RenderOptions 24 | 25 | type RenderingRequest struct { 26 | RenderOptions 27 | 28 | Model string 29 | 30 | Input string 31 | } 32 | 33 | func (r *RenderingService) New(ctx context.Context, input RenderingRequest, opts ...RequestOption) (*Rendering, error) { 34 | cfg := newRequestConfig(append(r.Options, opts...)...)
35 | url := strings.TrimRight(cfg.URL, "/") + "/v1/" 36 | 37 | options := []openai.Option{} 38 | 39 | if cfg.Token != "" { 40 | options = append(options, openai.WithToken(cfg.Token)) 41 | } 42 | 43 | if cfg.Client != nil { 44 | options = append(options, openai.WithClient(cfg.Client)) 45 | } 46 | 47 | p, err := openai.NewRenderer(url, input.Model, options...) 48 | 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | return p.Render(ctx, input.Input, &input.RenderOptions) 54 | } 55 | -------------------------------------------------------------------------------- /pkg/client/client_reranks.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "net/http" 9 | "strings" 10 | 11 | "github.com/adrianliechti/wingman/server/api" 12 | ) 13 | 14 | type RerankService struct { 15 | Options []RequestOption 16 | } 17 | 18 | func NewRerankService(opts ...RequestOption) *RerankService { 19 | return &RerankService{ 20 | Options: opts, 21 | } 22 | } 23 | 24 | type Rerank = api.Result 25 | type RerankRequest = api.RerankRequest 26 | 27 | // New posts input as JSON to the server's /v1/rerank endpoint and returns the response's results. 28 | func (r *RerankService) New(ctx context.Context, input RerankRequest, opts ...RequestOption) ([]Rerank, error) { 29 | c := newRequestConfig(append(r.Options, opts...)...)
30 | 31 | var data bytes.Buffer 32 | 33 | if err := json.NewEncoder(&data).Encode(input); err != nil { 34 | return nil, err 35 | } 36 | 37 | url := strings.TrimRight(c.URL, "/") + "/v1/rerank" 38 | 39 | req, err := http.NewRequestWithContext(ctx, "POST", url, &data) 40 | 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | req.Header.Set("Content-Type", "application/json") 46 | 47 | if c.Token != "" { 48 | req.Header.Set("Authorization", "Bearer "+c.Token) 49 | } 50 | 51 | resp, err := c.Client.Do(req) 52 | 53 | if err != nil { 54 | return nil, err 55 | } 56 | 57 | defer resp.Body.Close() 58 | 59 | if resp.StatusCode != http.StatusOK { 60 | return nil, errors.New(resp.Status) 61 | } 62 | 63 | var result api.RerankResponse 64 | 65 | if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { 66 | return nil, err 67 | } 68 | 69 | return result.Results, nil 70 | } 71 | -------------------------------------------------------------------------------- /pkg/client/client_synthesis.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | "github.com/adrianliechti/wingman/pkg/provider/openai" 9 | ) 10 | 11 | type SynthesisService struct { 12 | Options []RequestOption 13 | } 14 | 15 | func NewSynthesisService(opts ...RequestOption) SynthesisService { 16 | return SynthesisService{ 17 | Options: opts, 18 | } 19 | } 20 | 21 | type Synthesis = provider.Synthesis 22 | type SynthesizeOptions = provider.SynthesizeOptions 23 | 24 | type SynthesizeRequest struct { 25 | SynthesizeOptions 26 | 27 | Model string 28 | 29 | Input string 30 | } 31 | 32 | func (r *SynthesisService) New(ctx context.Context, input SynthesizeRequest, opts ...RequestOption) (*Synthesis, error) { 33 | cfg := newRequestConfig(append(r.Options, opts...)...)
34 | url := strings.TrimRight(cfg.URL, "/") + "/v1/" 35 | 36 | options := []openai.Option{} 37 | 38 | if cfg.Token != "" { 39 | options = append(options, openai.WithToken(cfg.Token)) 40 | } 41 | 42 | if cfg.Client != nil { 43 | options = append(options, openai.WithClient(cfg.Client)) 44 | } 45 | 46 | p, err := openai.NewSynthesizer(url, input.Model, options...) 47 | 48 | if err != nil { 49 | return nil, err 50 | } 51 | 52 | return p.Synthesize(ctx, input.Input, &input.SynthesizeOptions) 53 | } 54 | -------------------------------------------------------------------------------- /pkg/client/client_transcriptions.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "strings" 7 | 8 | "github.com/adrianliechti/wingman/pkg/provider" 9 | "github.com/adrianliechti/wingman/pkg/provider/openai" 10 | ) 11 | 12 | type TranscriptionService struct { 13 | Options []RequestOption 14 | } 15 | 16 | func NewTranscriptionService(opts ...RequestOption) TranscriptionService { 17 | return TranscriptionService{ 18 | Options: opts, 19 | } 20 | } 21 | 22 | type Transcription = provider.Transcription 23 | type TranscribeOptions = provider.TranscribeOptions 24 | 25 | type TranscribeRequest struct { 26 | TranscribeOptions 27 | 28 | Model string 29 | 30 | Name string 31 | Reader io.Reader 32 | } 33 | 34 | func (r *TranscriptionService) New(ctx context.Context, input TranscribeRequest, opts ...RequestOption) (*Transcription, error) { 35 | cfg := newRequestConfig(append(r.Options, opts...)...) 36 | url := strings.TrimRight(cfg.URL, "/") + "/v1/" 37 | 38 | options := []openai.Option{} 39 | 40 | if cfg.Token != "" { 41 | options = append(options, openai.WithToken(cfg.Token)) 42 | } 43 | 44 | if cfg.Client != nil { 45 | options = append(options, openai.WithClient(cfg.Client)) 46 | } 47 | 48 | p, err := openai.NewTranscriber(url, input.Model, options...)
49 | 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | data, err := io.ReadAll(input.Reader) 55 | 56 | if err != nil { 57 | return nil, err 58 | } 59 | 60 | file := provider.File{ 61 | Name: input.Name, 62 | Content: data, 63 | } 64 | 65 | return p.Transcribe(ctx, file, &input.TranscribeOptions) 66 | } 67 | -------------------------------------------------------------------------------- /pkg/client/option.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type RequestOption = func(*RequestConfig) error 8 | 9 | type RequestConfig struct { 10 | Client *http.Client 11 | 12 | URL string 13 | Token string 14 | } 15 | 16 | func WithClient(client *http.Client) RequestOption { 17 | return func(c *RequestConfig) error { 18 | c.Client = client 19 | return nil 20 | } 21 | } 22 | 23 | func WithURL(url string) RequestOption { 24 | return func(c *RequestConfig) error { 25 | c.URL = url 26 | return nil 27 | } 28 | } 29 | 30 | func WithToken(token string) RequestOption { 31 | return func(c *RequestConfig) error { 32 | c.Token = token 33 | return nil 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /pkg/extractor/azure/config.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | 15 | func WithToken(token string) Option { 16 | return func(c *Client) { 17 | c.token = token 18 | } 19 | } 20 | 21 | // https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/concept-layout?view=doc-intel-4.0.0&tabs=sample-code#input-requirements 22 | var SupportedExtensions = []string{ 23 | ".pdf", 24 | 25 | ".jpeg", ".jpg", 26 | ".png", 27 | ".bmp", 28 | ".tiff", 29 | ".heif", 30 | 31 |
".docx", 32 | ".pptx", 33 | ".xlsx", 34 | } 35 | 36 | // https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/concept-layout?view=doc-intel-4.0.0&tabs=sample-code#input-requirements 37 | var SupportedMimeTypes = []string{ 38 | "application/pdf", 39 | 40 | "image/jpeg", 41 | "image/png", 42 | "image/bmp", 43 | "image/tiff", 44 | "image/heif", 45 | 46 | "application/vnd.openxmlformats-officedocument.wordprocessingml.document", 47 | "application/vnd.openxmlformats-officedocument.presentationml.presentation", 48 | "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", 49 | } 50 | -------------------------------------------------------------------------------- /pkg/extractor/azure/model.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | type OperationStatus string 4 | 5 | const ( 6 | OperationStatusSucceeded OperationStatus = "succeeded" 7 | OperationStatusRunning OperationStatus = "running" 8 | OperationStatusNotStarted OperationStatus = "notStarted" 9 | ) 10 | 11 | type AnalyzeOperation struct { 12 | Status OperationStatus `json:"status"` 13 | 14 | Result AnalyzeResult `json:"analyzeResult"` 15 | } 16 | 17 | type AnalyzeResult struct { 18 | ModelID string `json:"modelId"` 19 | 20 | Content string `json:"content"` 21 | } 22 | -------------------------------------------------------------------------------- /pkg/extractor/custom/Taskfile.yaml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | tasks: 6 | generate: 7 | cmds: 8 | - protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=.
--go-grpc_opt=paths=source_relative extractor.proto -------------------------------------------------------------------------------- /pkg/extractor/custom/config.go: -------------------------------------------------------------------------------- 1 | package custom 2 | 3 | type Option func(*Client) 4 | -------------------------------------------------------------------------------- /pkg/extractor/custom/extractor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option go_package = "github.com/adrianliechti/wingman/pkg/extractor/custom;custom"; 4 | 5 | package extractor; 6 | 7 | service Extractor { 8 | rpc Extract (ExtractRequest) returns (File) {} 9 | } 10 | 11 | enum Format { 12 | FORMAT_TEXT = 0; 13 | FORMAT_IMAGE = 1; 14 | FORMAT_PDF = 2; 15 | } 16 | 17 | message ExtractRequest { 18 | optional File file = 1; 19 | optional string url = 2; 20 | 21 | optional Format format = 3; 22 | } 23 | 24 | message File { 25 | string name = 1; 26 | 27 | bytes content = 2; 28 | string content_type = 3; 29 | } -------------------------------------------------------------------------------- /pkg/extractor/exa/config.go: -------------------------------------------------------------------------------- 1 | package exa 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /pkg/extractor/exa/models.go: -------------------------------------------------------------------------------- 1 | package exa 2 | 3 | type ContentsRequest struct { 4 | URLs []string `json:"urls"` 5 | 6 | LiveCrawl LiveCrawl `json:"livecrawl,omitempty"` 7 | } 8 | 9 | type LiveCrawl string 10 | 11 | const ( 12 | LiveCrawlAuto LiveCrawl = "auto" 13 | LiveCrawlAlways LiveCrawl = "always" 14 | LiveCrawlNever LiveCrawl = "never" 15
| LiveCrawlFallback LiveCrawl = "fallback" 16 | ) 17 | 18 | type ContentsResponse struct { 19 | Results []ContentsResult `json:"results"` 20 | } 21 | 22 | type ContentsResult struct { 23 | ID string `json:"id"` 24 | 25 | URL string `json:"url"` 26 | Title string `json:"title"` 27 | 28 | Text string `json:"text"` 29 | } 30 | -------------------------------------------------------------------------------- /pkg/extractor/extractor.go: -------------------------------------------------------------------------------- 1 | package extractor 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | ) 9 | 10 | type Provider interface { 11 | Extract(ctx context.Context, input Input, options *ExtractOptions) (*Document, error) 12 | } 13 | 14 | var ( 15 | ErrUnsupported = errors.New("unsupported type") 16 | ) 17 | 18 | type Format string 19 | 20 | const ( 21 | FormatText Format = "text" 22 | FormatImage Format = "image" 23 | FormatPDF Format = "pdf" 24 | ) 25 | 26 | type ExtractOptions struct { 27 | Format *Format 28 | } 29 | 30 | type Input struct { 31 | URL *string 32 | 33 | File *provider.File 34 | } 35 | 36 | type Document struct { 37 | Content []byte 38 | ContentType string 39 | } 40 | -------------------------------------------------------------------------------- /pkg/extractor/jina/client_test.go: -------------------------------------------------------------------------------- 1 | package jina_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/adrianliechti/wingman/pkg/extractor" 8 | "github.com/adrianliechti/wingman/pkg/extractor/jina" 9 | "github.com/adrianliechti/wingman/pkg/to" 10 | 11 | "github.com/stretchr/testify/require" 12 | "github.com/testcontainers/testcontainers-go" 13 | "github.com/testcontainers/testcontainers-go/wait" 14 | ) 15 | 16 | func TestExtract(t *testing.T) { 17 | ctx := context.Background() 18 | 19 | server, err := testcontainers.GenericContainer(ctx,
testcontainers.GenericContainerRequest{ 20 | Started: true, 21 | 22 | ContainerRequest: testcontainers.ContainerRequest{ 23 | Image: "ghcr.io/adrianliechti/wingman-reader", 24 | ExposedPorts: []string{"8080/tcp"}, 25 | WaitingFor: wait.ForExposedPort(), 26 | }, 27 | }) 28 | 29 | require.NoError(t, err) 30 | 31 | t.Cleanup(func() { 32 | // stop and remove the reader container instead of leaking it after the test 33 | require.NoError(t, server.Terminate(ctx)) 34 | }) 35 | 36 | url, err := server.Endpoint(ctx, "") 37 | require.NoError(t, err) 38 | 39 | c, err := jina.New("http://" + url) 40 | require.NoError(t, err) 41 | 42 | input := extractor.Input{ 43 | URL: to.Ptr("https://example.org"), 44 | } 45 | 46 | result, err := c.Extract(ctx, input, nil) 47 | require.NoError(t, err) 48 | 49 | require.NotEmpty(t, result.Content) 50 | } 51 | -------------------------------------------------------------------------------- /pkg/extractor/jina/config.go: -------------------------------------------------------------------------------- 1 | package jina 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | 15 | func WithToken(token string) Option { 16 | return func(c *Client) { 17 | c.token = token 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /pkg/extractor/multi/multi.go: -------------------------------------------------------------------------------- 1 | package multi 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | 7 | "github.com/adrianliechti/wingman/pkg/extractor" 8 | ) 9 | 10 | var _ extractor.Provider = &Extractor{} 11 | 12 | type Extractor struct { 13 | providers []extractor.Provider 14 | } 15 | 16 | func New(provider ...extractor.Provider) *Extractor { 17 | return &Extractor{ 18 | providers: provider, 19 | } 20 | } 21 | 22 | // Extract tries each provider in order and returns the first successful 23 | // result. If none succeeds, the error matches extractor.ErrUnsupported via 24 | // errors.Is and also carries every provider's individual error. 25 | func (e *Extractor) Extract(ctx context.Context, input extractor.Input, options *extractor.ExtractOptions) (*extractor.Document, error) { 26 | if options == nil { 27 | options = new(extractor.ExtractOptions) 28 | } 29 | 30 | errs := []error{extractor.ErrUnsupported} 31 | 32 | for _,
p := range e.providers { 33 | result, err := p.Extract(ctx, input, options) 34 | 35 | if err != nil { 36 | // remember why this provider failed and fall through to the next one 37 | errs = append(errs, err) 38 | continue 39 | } 40 | 41 | return result, nil 42 | } 43 | 44 | if len(errs) == 1 { 45 | // no provider was tried; keep returning the bare sentinel 46 | return nil, extractor.ErrUnsupported 47 | } 48 | 49 | return nil, errors.Join(errs...) 50 | } 51 | -------------------------------------------------------------------------------- /pkg/extractor/tavily/config.go: -------------------------------------------------------------------------------- 1 | package tavily 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /pkg/extractor/tavily/models.go: -------------------------------------------------------------------------------- 1 | package tavily 2 | 3 | type extractResult struct { 4 | Results []struct { 5 | URL string `json:"url"` 6 | 7 | Content string `json:"raw_content"` 8 | } `json:"results"` 9 | } 10 | -------------------------------------------------------------------------------- /pkg/extractor/tavily/utils.go: -------------------------------------------------------------------------------- 1 | package tavily 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "io" 8 | "net/http" 9 | ) 10 | 11 | func jsonReader(v any) io.Reader { 12 | b := new(bytes.Buffer) 13 | 14 | enc := json.NewEncoder(b) 15 | enc.SetEscapeHTML(false) 16 | 17 | enc.Encode(v) 18 | return b 19 | } 20 | 21 | func convertError(resp *http.Response) error { 22 | return errors.New(http.StatusText(resp.StatusCode)) 23 | } 24 | -------------------------------------------------------------------------------- /pkg/extractor/text/config.go: -------------------------------------------------------------------------------- 1 | package text 2 | 3 | var SupportedExtensions = []string{ 4 | ".txt",
5 | ".csv", 6 | ".tsv", 7 | 8 | ".json", 9 | ".xml", 10 | ".yaml", 11 | ".yml", 12 | 13 | ".ini", 14 | ".log", 15 | ".md", 16 | ".rst", 17 | } 18 | 19 | var SupportedMimeTypes = []string{ 20 | "text/plain", 21 | "text/markdown", 22 | 23 | "text/csv", 24 | "text/tab-separated-values", 25 | 26 | "application/json", 27 | "application/xml", 28 | "application/yaml", 29 | } 30 | -------------------------------------------------------------------------------- /pkg/extractor/tika/client_test.go: -------------------------------------------------------------------------------- 1 | package tika_test 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net/http" 7 | "testing" 8 | 9 | "github.com/adrianliechti/wingman/pkg/extractor" 10 | "github.com/adrianliechti/wingman/pkg/extractor/tika" 11 | "github.com/adrianliechti/wingman/pkg/provider" 12 | 13 | "github.com/stretchr/testify/require" 14 | "github.com/testcontainers/testcontainers-go" 15 | "github.com/testcontainers/testcontainers-go/wait" 16 | ) 17 | 18 | func TestExtract(t *testing.T) { 19 | ctx := context.Background() 20 | 21 | server, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ 22 | Started: true, 23 | 24 | ContainerRequest: testcontainers.ContainerRequest{ 25 | Image: "apache/tika:3.0.0.0-BETA2-full", 26 | ExposedPorts: []string{"9998/tcp"}, 27 | WaitingFor: wait.ForLog("Started Apache Tika server"), 28 | }, 29 | }) 30 | 31 | require.NoError(t, err) 32 | 33 | url, err := server.Endpoint(ctx, "") 34 | require.NoError(t, err) 35 | 36 | c, err := tika.New("http://" + url) 37 | require.NoError(t, err) 38 | 39 | resp, err := http.Get("https://helpx.adobe.com/pdf/acrobat_reference.pdf") 40 | require.NoError(t, err) 41 | defer resp.Body.Close() 42 | 43 | data, err := io.ReadAll(resp.Body) 44 | require.NoError(t, err) 45 | 46 | input := extractor.Input{ 47 | File: &provider.File{ 48 | Content: data, 49 | ContentType: "application/pdf", 50 | }, 51 | } 52 | 53 | result, err := c.Extract(ctx, 
input, nil) 54 | require.NoError(t, err) 55 | 56 | require.NotEmpty(t, result.Content) 57 | } 58 | -------------------------------------------------------------------------------- /pkg/extractor/tika/config.go: -------------------------------------------------------------------------------- 1 | package tika 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | var SupportedExtensions = []string{ 8 | ".pdf", 9 | 10 | ".jpg", ".jpeg", 11 | ".png", 12 | 13 | ".doc", ".docx", 14 | ".ppt", ".pptx", 15 | ".xls", ".xlsx", 16 | } 17 | 18 | var SupportedMimeTypes = []string{ 19 | "application/pdf", 20 | 21 | "image/jpeg", 22 | "image/png", 23 | 24 | "application/msword", 25 | "application/vnd.openxmlformats-officedocument.wordprocessingml.document", 26 | 27 | "application/vnd.ms-powerpoint", 28 | "application/vnd.openxmlformats-officedocument.presentationml.presentation", 29 | 30 | "application/vnd.ms-excel", 31 | "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", 32 | } 33 | 34 | type Option func(*Client) 35 | 36 | func WithClient(client *http.Client) Option { 37 | return func(c *Client) { 38 | c.client = client 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /pkg/extractor/tika/model.go: -------------------------------------------------------------------------------- 1 | package tika 2 | 3 | type TikaResponse struct { 4 | Content string `json:"X-TIKA:content"` 5 | 6 | //ContentType string `json:"Content-Type"` 7 | //ContentLength string `json:"Content-Length"` 8 | } 9 | -------------------------------------------------------------------------------- /pkg/extractor/unstructured/model.go: -------------------------------------------------------------------------------- 1 | package unstructured 2 | 3 | type Strategy string 4 | 5 | const ( 6 | StrategyAuto Strategy = "auto" 7 | StrategyFast Strategy = "fast" 8 | StrategyHiRes Strategy = "hi_res" 9 | ) 10 | 11 | type Element struct { 12 | ID string 
`json:"element_id"` 13 | 14 | Type string `json:"type"` 15 | Text string `json:"text"` 16 | 17 | Metadata ElementMetadata `json:"metadata"` 18 | } 19 | 20 | type ElementMetadata struct { 21 | FileName string `json:"filename"` 22 | FileType string `json:"filetype"` 23 | 24 | Languages []string `json:"languages"` 25 | 26 | // PageName string `json:"page_name"` 27 | // PageNumber int `json:"page_number"` 28 | 29 | // MailSender string `json:"sent_from"` 30 | // MailRecipient string `json:"sent_to"` 31 | // MailSubject string `json:"subject"` 32 | } 33 | -------------------------------------------------------------------------------- /pkg/index/azure/client_delete.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | func (c *Client) Delete(ctx context.Context, ids ...string) error { 9 | if err := c.ensureCollection(ctx, c.namespace); err != nil { 10 | return err 11 | } 12 | 13 | items := []map[string]any{} 14 | 15 | for _, id := range ids { 16 | item := map[string]any{ 17 | "@search.action": "delete", 18 | 19 | "id": id, 20 | } 21 | 22 | items = append(items, item) 23 | } 24 | 25 | body := map[string]any{ 26 | "value": items, 27 | } 28 | 29 | req, _ := http.NewRequestWithContext(ctx, "POST", c.requestURL("/indexes/"+c.namespace+"/docs/index", nil), jsonReader(body)) 30 | req.Header.Set("Content-Type", "application/json") 31 | req.Header.Set("api-key", c.token) 32 | 33 | resp, err := c.client.Do(req) 34 | 35 | if err != nil { 36 | return err 37 | } 38 | 39 | if resp.StatusCode != http.StatusOK { 40 | return convertError(resp) 41 | } 42 | 43 | return nil 44 | } 45 | -------------------------------------------------------------------------------- /pkg/index/azure/client_index.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | 
"github.com/adrianliechti/wingman/pkg/index" 8 | "github.com/google/uuid" 9 | ) 10 | 11 | func (c *Client) Index(ctx context.Context, documents ...index.Document) error { 12 | if err := c.ensureCollection(ctx, c.namespace); err != nil { 13 | return err 14 | } 15 | 16 | items := []map[string]any{} 17 | 18 | for _, d := range documents { 19 | id := d.ID 20 | 21 | if id == "" { 22 | id = uuid.New().String() 23 | } 24 | 25 | item := map[string]any{ 26 | "@search.action": "upload", 27 | 28 | "id": id, 29 | 30 | "title": d.Title, 31 | "source": d.Source, 32 | "content": d.Content, 33 | } 34 | 35 | if len(d.Metadata) > 0 { 36 | metadata := []map[string]string{} 37 | 38 | for k, v := range d.Metadata { 39 | metadata = append(metadata, map[string]string{ 40 | "key": k, 41 | "value": v, 42 | }) 43 | } 44 | 45 | item["metadata"] = metadata 46 | } 47 | 48 | items = append(items, item) 49 | } 50 | 51 | body := map[string]any{ 52 | "value": items, 53 | } 54 | 55 | req, _ := http.NewRequestWithContext(ctx, "POST", c.requestURL("/indexes/"+c.namespace+"/docs/index", nil), jsonReader(body)) 56 | req.Header.Set("Content-Type", "application/json") 57 | req.Header.Set("api-key", c.token) 58 | 59 | resp, err := c.client.Do(req) 60 | 61 | if err != nil { 62 | return err 63 | } 64 | 65 | if resp.StatusCode != http.StatusOK { 66 | return convertError(resp) 67 | } 68 | 69 | return nil 70 | } 71 | -------------------------------------------------------------------------------- /pkg/index/azure/client_list.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/index" 7 | ) 8 | 9 | func (c *Client) List(ctx context.Context, options *index.ListOptions) (*index.Page[index.Document], error) { 10 | results, err := c.Query(ctx, "*", &index.QueryOptions{}) 11 | 12 | if err != nil { 13 | return nil, err 14 | } 15 | 16 | var items []index.Document 17 | 18 | for _, r := range 
results { 19 | items = append(items, r.Document) 20 | } 21 | 22 | page := index.Page[index.Document]{ 23 | Items: items, 24 | } 25 | 26 | return &page, nil 27 | } 28 | -------------------------------------------------------------------------------- /pkg/index/azure/client_query.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | 9 | "github.com/adrianliechti/wingman/pkg/index" 10 | "github.com/adrianliechti/wingman/pkg/to" 11 | ) 12 | 13 | func (c *Client) Query(ctx context.Context, query string, options *index.QueryOptions) ([]index.Result, error) { 14 | if options == nil { 15 | options = new(index.QueryOptions) 16 | } 17 | 18 | if options.Limit == nil { 19 | options.Limit = to.Ptr(10) 20 | } 21 | 22 | queries := map[string]string{ 23 | "search": query, 24 | } 25 | 26 | if options.Limit != nil { 27 | queries["$top"] = fmt.Sprintf("%d", *options.Limit) 28 | } 29 | 30 | req, _ := http.NewRequestWithContext(ctx, "GET", c.requestURL("/indexes/"+c.namespace+"/docs", queries), nil) 31 | req.Header.Set("api-key", c.token) 32 | 33 | resp, err := c.client.Do(req) 34 | 35 | if err != nil { 36 | return nil, err 37 | } 38 | 39 | defer resp.Body.Close() 40 | 41 | var result Results 42 | 43 | if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { 44 | return nil, err 45 | } 46 | 47 | var results []index.Result 48 | 49 | for _, r := range result.Value { 50 | result := index.Result{ 51 | Document: index.Document{ 52 | ID: r.ID(), 53 | 54 | Title: r.Title(), 55 | Source: r.Source(), 56 | Content: r.Content(), 57 | 58 | Metadata: r.Metadata(), 59 | }, 60 | } 61 | 62 | results = append(results, result) 63 | } 64 | 65 | return results, nil 66 | } 67 | -------------------------------------------------------------------------------- /pkg/index/azure/client_test.go: -------------------------------------------------------------------------------- 
1 | package azure_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/adrianliechti/wingman/pkg/index/azure" 8 | "github.com/adrianliechti/wingman/test" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestAzure(t *testing.T) { 14 | context := test.NewContext() 15 | 16 | url := os.Getenv("AZURE_SEARCH_ENDPOINT") 17 | token := os.Getenv("AZURE_SEARCH_API_KEY") 18 | index := os.Getenv("AZURE_SEARCH_INDEX_NAME") 19 | 20 | require.NotEmpty(t, url) 21 | require.NotEmpty(t, token) 22 | require.NotEmpty(t, index) 23 | 24 | c, err := azure.New(url, index, token) 25 | 26 | if err != nil { 27 | t.Fatal(err) 28 | } 29 | 30 | test.TestIndex(t, context, c) 31 | } 32 | -------------------------------------------------------------------------------- /pkg/index/azure/config.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /pkg/index/bing/config.go: -------------------------------------------------------------------------------- 1 | package bing 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /pkg/index/bing/models.go: -------------------------------------------------------------------------------- 1 | package bing 2 | 3 | type response struct { 4 | WebPages struct { 5 | Value []page `json:"value"` 6 | } `json:"webPages"` 7 | } 8 | 9 | type page struct { 10 | ID string `json:"id"` 11 | URL string `json:"url"` 12 | 13 | Name string `json:"name"` 14 | Snippet string `json:"snippet"` 15 | } 16 | 
-------------------------------------------------------------------------------- /pkg/index/chroma/client_test.go: -------------------------------------------------------------------------------- 1 | package chroma_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/adrianliechti/wingman/pkg/index/chroma" 7 | "github.com/adrianliechti/wingman/test" 8 | 9 | "github.com/stretchr/testify/require" 10 | "github.com/testcontainers/testcontainers-go" 11 | "github.com/testcontainers/testcontainers-go/wait" 12 | ) 13 | 14 | func TestChroma(t *testing.T) { 15 | context := test.NewContext() 16 | 17 | server, err := testcontainers.GenericContainer(context.Context, testcontainers.GenericContainerRequest{ 18 | Started: true, 19 | 20 | ContainerRequest: testcontainers.ContainerRequest{ 21 | Image: "ghcr.io/chroma-core/chroma:0.5.5", 22 | ExposedPorts: []string{"8000/tcp"}, 23 | WaitingFor: wait.ForLog("Application startup complete"), 24 | }, 25 | }) 26 | 27 | require.NoError(t, err) 28 | 29 | url, err := server.Endpoint(context.Context, "") 30 | require.NoError(t, err) 31 | 32 | c, err := chroma.New("http://"+url, "test", chroma.WithEmbedder(context.Embedder)) 33 | require.NoError(t, err) 34 | 35 | test.TestIndex(t, context, c) 36 | } 37 | -------------------------------------------------------------------------------- /pkg/index/chroma/config.go: -------------------------------------------------------------------------------- 1 | package chroma 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/adrianliechti/wingman/pkg/index" 7 | ) 8 | 9 | type Option func(*Client) 10 | 11 | func WithClient(client *http.Client) Option { 12 | return func(c *Client) { 13 | c.client = client 14 | } 15 | } 16 | 17 | func WithEmbedder(embedder index.Embedder) Option { 18 | return func(c *Client) { 19 | c.embedder = embedder 20 | } 21 | } 22 | 23 | func WithReranker(reranker index.Reranker) Option { 24 | return func(c *Client) { 25 | c.reranker = reranker 26 | } 27 | } 28 | 
-------------------------------------------------------------------------------- /pkg/index/chroma/models.go: -------------------------------------------------------------------------------- 1 | package chroma 2 | 3 | type collection struct { 4 | ID string `json:"id,omitempty"` 5 | 6 | Tenant string `json:"tenant,omitempty"` 7 | Database string `json:"database,omitempty"` 8 | 9 | Name string `json:"name,omitempty"` 10 | Metadata map[string]any `json:"metadata,omitempty"` 11 | } 12 | 13 | type embeddings struct { 14 | IDs []string `json:"ids"` 15 | 16 | Embeddings [][]float32 `json:"embeddings"` 17 | 18 | Metadatas []map[string]string `json:"metadatas"` 19 | Documents []string `json:"documents"` 20 | } 21 | 22 | type getResult struct { 23 | IDs []string `json:"ids"` 24 | 25 | Distances []float32 `json:"distances,omitempty"` 26 | 27 | Embeddings [][]float64 `json:"embeddings"` 28 | 29 | Metadatas []map[string]string `json:"metadatas"` 30 | Documents []string `json:"documents"` 31 | } 32 | 33 | type queryResult struct { 34 | IDs [][]string `json:"ids"` 35 | 36 | Distances [][]float32 `json:"distances,omitempty"` 37 | 38 | Embeddings [][][]float64 `json:"embeddings"` 39 | 40 | Metadatas [][]map[string]string `json:"metadatas"` 41 | Documents [][]string `json:"documents"` 42 | } 43 | 44 | type errorDetail struct { 45 | Message string `json:"msg"` 46 | } 47 | -------------------------------------------------------------------------------- /pkg/index/custom/Taskfile.yaml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | tasks: 6 | generate: 7 | cmds: 8 | - protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. 
--go-grpc_opt=paths=source_relative index.proto -------------------------------------------------------------------------------- /pkg/index/custom/index.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "google/protobuf/empty.proto"; 4 | 5 | option go_package = "github.com/adrianliechti/wingman/pkg/index/custom;custom"; 6 | 7 | package index; 8 | 9 | service index { 10 | rpc List (ListRequest) returns (Documents) {} 11 | rpc Delete(DeleteRequest) returns (google.protobuf.Empty) {} 12 | rpc Index(IndexRequest) returns (google.protobuf.Empty) {} 13 | rpc Query (QueryRequest) returns (Results) {} 14 | } 15 | 16 | message ListRequest { 17 | } 18 | 19 | message DeleteRequest { 20 | repeated string ids = 1; 21 | } 22 | 23 | message IndexRequest { 24 | repeated Document documents = 1; 25 | } 26 | 27 | message QueryRequest { 28 | string query = 1; 29 | 30 | optional int32 limit = 2; 31 | 32 | map filters = 3; 33 | } 34 | 35 | message Documents { 36 | repeated Document documents = 1; 37 | } 38 | 39 | message Document { 40 | string id = 1; 41 | 42 | string title = 2; 43 | string source = 3; 44 | string content = 4; 45 | 46 | map metadata = 5; 47 | 48 | repeated float embedding = 6; 49 | } 50 | 51 | message Results { 52 | repeated Result results = 1; 53 | } 54 | 55 | message Result { 56 | Document document = 1; 57 | 58 | float score = 2; 59 | } -------------------------------------------------------------------------------- /pkg/index/duckduckgo/client_test.go: -------------------------------------------------------------------------------- 1 | package duckduckgo_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/adrianliechti/wingman/pkg/index/duckduckgo" 7 | "github.com/adrianliechti/wingman/test" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestSearch(t *testing.T) { 12 | context := test.NewContext() 13 | 14 | c, err := duckduckgo.New() 15 | require.NoError(t, err) 16 | 17 | 
result, err := c.Query(context.Context, "Meta LLAMA", nil) 18 | require.NoError(t, err) 19 | 20 | println(result) 21 | } 22 | -------------------------------------------------------------------------------- /pkg/index/duckduckgo/config.go: -------------------------------------------------------------------------------- 1 | package duckduckgo 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /pkg/index/elasticsearch/client_test.go: -------------------------------------------------------------------------------- 1 | package elasticsearch_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/adrianliechti/wingman/pkg/index/elasticsearch" 7 | "github.com/adrianliechti/wingman/test" 8 | 9 | "github.com/stretchr/testify/require" 10 | "github.com/testcontainers/testcontainers-go" 11 | "github.com/testcontainers/testcontainers-go/wait" 12 | ) 13 | 14 | func TestElasticsearch(t *testing.T) { 15 | context := test.NewContext() 16 | 17 | server, err := testcontainers.GenericContainer(context.Context, testcontainers.GenericContainerRequest{ 18 | Started: true, 19 | 20 | ContainerRequest: testcontainers.ContainerRequest{ 21 | Image: "docker.elastic.co/elasticsearch/elasticsearch:8.15.1", 22 | Env: map[string]string{ 23 | "ES_JAVA_OPTS": "-Xms1g -Xmx1g", 24 | "discovery.type": "single-node", 25 | "xpack.security.enabled": "false", 26 | "node.name": "test", 27 | "cluster.name": "test", 28 | }, 29 | ExposedPorts: []string{"9200/tcp"}, 30 | WaitingFor: wait.ForExposedPort(), 31 | }, 32 | }) 33 | 34 | require.NoError(t, err) 35 | 36 | url, err := server.Endpoint(context.Context, "") 37 | require.NoError(t, err) 38 | 39 | c, err := elasticsearch.New("http://"+url, "test") 40 | require.NoError(t, err) 41 | 42 | test.TestIndex(t, context, c) 43 | } 44 | 
-------------------------------------------------------------------------------- /pkg/index/elasticsearch/config.go: -------------------------------------------------------------------------------- 1 | package elasticsearch 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /pkg/index/elasticsearch/models.go: -------------------------------------------------------------------------------- 1 | package elasticsearch 2 | 3 | type Document struct { 4 | ID string `json:"id"` 5 | 6 | Title string `json:"title"` 7 | Source string `json:"source"` 8 | Content string `json:"content"` 9 | 10 | Metadata map[string]string `json:"metadata"` 11 | } 12 | 13 | type SearchResult struct { 14 | Hits SearchHits `json:"hits"` 15 | } 16 | 17 | type SearchHits struct { 18 | Hits []SearchHit `json:"hits"` 19 | } 20 | 21 | type SearchHit struct { 22 | Score float32 `json:"_score"` 23 | Document Document `json:"_source"` 24 | } 25 | -------------------------------------------------------------------------------- /pkg/index/exa/config.go: -------------------------------------------------------------------------------- 1 | package exa 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /pkg/index/exa/models.go: -------------------------------------------------------------------------------- 1 | package exa 2 | 3 | type SearchRequest struct { 4 | Query string `json:"query"` 5 | 6 | Contents SearchContents `json:"contents,omitempty"` 7 | } 8 | 9 | type SearchContents struct { 10 | Text bool `json:"text,omitempty"` 11 | 12 | LiveCrawl LiveCrawl 
`json:"livecrawl,omitempty"` 13 | } 14 | 15 | type LiveCrawl string 16 | 17 | const ( 18 | LiveCrawlAuto LiveCrawl = "auto" 19 | LiveCrawlAlways LiveCrawl = "always" 20 | LiveCrawlNever LiveCrawl = "never" 21 | LiveCrawlFallback LiveCrawl = "fallback" 22 | ) 23 | 24 | type SearchResponse struct { 25 | Results []SearchResult `json:"results"` 26 | } 27 | 28 | type SearchResult struct { 29 | ID string `json:"id"` 30 | 31 | URL string `json:"url"` 32 | Title string `json:"title"` 33 | 34 | Text string `json:"text"` 35 | } 36 | -------------------------------------------------------------------------------- /pkg/index/index.go: -------------------------------------------------------------------------------- 1 | package index 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | ) 8 | 9 | type Embedder = provider.Embedder 10 | type Reranker = provider.Reranker 11 | 12 | type Provider interface { 13 | List(ctx context.Context, options *ListOptions) (*Page[Document], error) 14 | 15 | Index(ctx context.Context, documents ...Document) error 16 | Delete(ctx context.Context, ids ...string) error 17 | 18 | Query(ctx context.Context, query string, options *QueryOptions) ([]Result, error) 19 | } 20 | 21 | type ListOptions struct { 22 | Limit *int 23 | Cursor string 24 | } 25 | 26 | type QueryOptions struct { 27 | Limit *int 28 | 29 | Filters map[string]string 30 | } 31 | 32 | type Page[T Document] struct { 33 | Items []T 34 | 35 | Cursor string 36 | } 37 | 38 | type Document struct { 39 | ID string 40 | 41 | Title string 42 | Source string 43 | Content string 44 | 45 | Metadata map[string]string 46 | 47 | Embedding []float32 48 | } 49 | 50 | type Result struct { 51 | Document 52 | Score float32 53 | } 54 | -------------------------------------------------------------------------------- /pkg/index/memory/client_test.go: -------------------------------------------------------------------------------- 1 | package memory_test 2 | 3 | import ( 4 | 
"testing" 5 | 6 | "github.com/adrianliechti/wingman/pkg/index/memory" 7 | "github.com/adrianliechti/wingman/test" 8 | 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestMemory(t *testing.T) { 13 | context := test.NewContext() 14 | 15 | c, err := memory.New(memory.WithEmbedder(context.Embedder)) 16 | require.NoError(t, err) 17 | 18 | test.TestIndex(t, context, c) 19 | } 20 | -------------------------------------------------------------------------------- /pkg/index/memory/config.go: -------------------------------------------------------------------------------- 1 | package memory 2 | 3 | import ( 4 | "github.com/adrianliechti/wingman/pkg/index" 5 | ) 6 | 7 | type Option func(*Provider) 8 | 9 | func WithEmbedder(embedder index.Embedder) Option { 10 | return func(p *Provider) { 11 | p.embedder = embedder 12 | } 13 | } 14 | 15 | func WithReranker(reranker index.Reranker) Option { 16 | return func(p *Provider) { 17 | p.reranker = reranker 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /pkg/index/postgrest/client.go: -------------------------------------------------------------------------------- 1 | package postgrest 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "io" 8 | "net/http" 9 | 10 | "github.com/adrianliechti/wingman/pkg/index" 11 | ) 12 | 13 | var _ index.Provider = &Client{} 14 | 15 | type Client struct { 16 | client *http.Client 17 | 18 | url string 19 | 20 | namespace string 21 | 22 | embedder index.Embedder 23 | reranker index.Reranker 24 | } 25 | 26 | func New(url string, namespace string, options ...Option) (*Client, error) { 27 | c := &Client{ 28 | client: http.DefaultClient, 29 | 30 | url: url, 31 | 32 | namespace: namespace, 33 | } 34 | 35 | for _, option := range options { 36 | option(c) 37 | } 38 | 39 | if c.embedder == nil { 40 | return nil, errors.New("embedder is required") 41 | } 42 | 43 | if c.namespace == "" { 44 | return nil, errors.New("namespace is 
required") 45 | } 46 | 47 | return c, nil 48 | } 49 | 50 | func jsonReader(v any) io.Reader { 51 | b := new(bytes.Buffer) 52 | 53 | enc := json.NewEncoder(b) 54 | enc.SetEscapeHTML(false) 55 | 56 | enc.Encode(v) 57 | return b 58 | } 59 | 60 | func convertError(resp *http.Response) error { 61 | data, _ := io.ReadAll(resp.Body) 62 | 63 | if len(data) == 0 { 64 | return errors.New(http.StatusText(resp.StatusCode)) 65 | } 66 | 67 | return errors.New(string(data)) 68 | } 69 | -------------------------------------------------------------------------------- /pkg/index/postgrest/client_delete.go: -------------------------------------------------------------------------------- 1 | package postgrest 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "net/url" 7 | "strings" 8 | ) 9 | 10 | func (c *Client) Delete(ctx context.Context, ids ...string) error { 11 | url, _ := url.JoinPath(c.url, "/docs") 12 | url += "?id=in.(" + strings.Join(ids, ",") + ")" 13 | 14 | req, _ := http.NewRequestWithContext(ctx, "DELETE", url, nil) 15 | 16 | resp, err := c.client.Do(req) 17 | 18 | if err != nil { 19 | return err 20 | } 21 | 22 | if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { 23 | return convertError(resp) 24 | } 25 | 26 | return nil 27 | } 28 | -------------------------------------------------------------------------------- /pkg/index/postgrest/client_index.go: -------------------------------------------------------------------------------- 1 | package postgrest 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "net/url" 7 | 8 | "github.com/adrianliechti/wingman/pkg/index" 9 | ) 10 | 11 | func (c *Client) Index(ctx context.Context, documents ...index.Document) error { 12 | body := []Document{} 13 | 14 | for _, d := range documents { 15 | if len(d.Embedding) == 0 && c.embedder != nil { 16 | embedding, err := c.embedder.Embed(ctx, []string{d.Content}) 17 | 18 | if err != nil { 19 | return err 20 | } 21 | 22 | d.Embedding = embedding.Embeddings[0] 23 | } 
24 | 25 | item := Document{ 26 | ID: d.ID, 27 | 28 | Title: d.Title, 29 | Source: d.Source, 30 | Content: d.Content, 31 | 32 | Embedding: d.Embedding, 33 | } 34 | 35 | body = append(body, item) 36 | } 37 | 38 | url, _ := url.JoinPath(c.url, "docs") 39 | 40 | req, _ := http.NewRequestWithContext(ctx, "POST", url, jsonReader(body)) 41 | req.Header.Set("Content-Type", "application/json") 42 | req.Header.Set("Prefer", "resolution=merge-duplicates") 43 | 44 | resp, err := c.client.Do(req) 45 | 46 | if err != nil { 47 | return err 48 | } 49 | 50 | if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { 51 | return convertError(resp) 52 | } 53 | 54 | return nil 55 | } 56 | -------------------------------------------------------------------------------- /pkg/index/postgrest/client_list.go: -------------------------------------------------------------------------------- 1 | package postgrest 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "net/http" 7 | "net/url" 8 | 9 | "github.com/adrianliechti/wingman/pkg/index" 10 | ) 11 | 12 | func (c *Client) List(ctx context.Context, options *index.ListOptions) (*index.Page[index.Document], error) { 13 | url, _ := url.JoinPath(c.url, "/docs") 14 | 15 | req, _ := http.NewRequestWithContext(ctx, "GET", url, nil) 16 | 17 | resp, err := c.client.Do(req) 18 | 19 | if err != nil { 20 | return nil, err 21 | } 22 | 23 | defer resp.Body.Close() 24 | 25 | if resp.StatusCode != http.StatusOK { 26 | return nil, convertError(resp) 27 | } 28 | 29 | var documents []Document 30 | 31 | if err := json.NewDecoder(resp.Body).Decode(&documents); err != nil { 32 | return nil, err 33 | } 34 | 35 | var items []index.Document 36 | 37 | for _, doc := range documents { 38 | items = append(items, index.Document{ 39 | ID: doc.ID, 40 | 41 | Title: doc.Title, 42 | Source: doc.Source, 43 | Content: doc.Content, 44 | }) 45 | } 46 | 47 | page := index.Page[index.Document]{ 48 | Items: items, 49 | } 50 | 51 | return &page, nil 52 | } 53 | 
-------------------------------------------------------------------------------- /pkg/index/postgrest/client_test.go: -------------------------------------------------------------------------------- 1 | package postgrest_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/adrianliechti/wingman/pkg/index/postgrest" 7 | "github.com/adrianliechti/wingman/test" 8 | ) 9 | 10 | func TestQdrant(t *testing.T) { 11 | context := test.NewContext() 12 | 13 | url := "localhost:3000" 14 | 15 | c, err := postgrest.New("http://"+url, "docs", postgrest.WithEmbedder(context.Embedder)) 16 | 17 | if err != nil { 18 | t.Fatal(err) 19 | } 20 | 21 | test.TestIndex(t, context, c) 22 | } 23 | -------------------------------------------------------------------------------- /pkg/index/postgrest/config.go: -------------------------------------------------------------------------------- 1 | package postgrest 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/adrianliechti/wingman/pkg/index" 7 | ) 8 | 9 | type Option func(*Client) 10 | 11 | func WithClient(client *http.Client) Option { 12 | return func(c *Client) { 13 | c.client = client 14 | } 15 | } 16 | 17 | func WithEmbedder(embedder index.Embedder) Option { 18 | return func(c *Client) { 19 | c.embedder = embedder 20 | } 21 | } 22 | 23 | func WithReranker(reranker index.Reranker) Option { 24 | return func(c *Client) { 25 | c.reranker = reranker 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /pkg/index/postgrest/types.go: -------------------------------------------------------------------------------- 1 | package postgrest 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | type Document struct { 10 | ID string `json:"id"` 11 | 12 | Title string `json:"title"` 13 | Source string `json:"source"` 14 | Content string `json:"content"` 15 | 16 | Embedding []float32 `json:"embedding"` 17 | } 18 | 19 | func (d *Document) UnmarshalJSON(data []byte) error { 20 | var 
alias struct { 21 | ID string `json:"id"` 22 | 23 | Title string `json:"title"` 24 | Source string `json:"source"` 25 | Content string `json:"content"` 26 | 27 | Embedding string `json:"embedding"` 28 | } 29 | 30 | if err := json.Unmarshal(data, &alias); err != nil { 31 | return err 32 | } 33 | 34 | d.ID = alias.ID 35 | d.Title = alias.Title 36 | d.Source = alias.Source 37 | 38 | d.Content = alias.Content 39 | 40 | slices := strings.Split(strings.Trim(alias.Embedding, "[]"), ",") 41 | 42 | for _, slice := range slices { 43 | var value float32 44 | 45 | if _, err := fmt.Sscanf(slice, "%f", &value); err != nil { 46 | return err 47 | } 48 | 49 | d.Embedding = append(d.Embedding, value) 50 | } 51 | 52 | return nil 53 | } 54 | -------------------------------------------------------------------------------- /pkg/index/qdrant/client_test.go: -------------------------------------------------------------------------------- 1 | package qdrant_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/adrianliechti/wingman/pkg/index/qdrant" 7 | "github.com/adrianliechti/wingman/test" 8 | 9 | "github.com/stretchr/testify/require" 10 | "github.com/testcontainers/testcontainers-go" 11 | "github.com/testcontainers/testcontainers-go/wait" 12 | ) 13 | 14 | func TestQdrant(t *testing.T) { 15 | context := test.NewContext() 16 | 17 | server, err := testcontainers.GenericContainer(context.Context, testcontainers.GenericContainerRequest{ 18 | Started: true, 19 | 20 | ContainerRequest: testcontainers.ContainerRequest{ 21 | Image: "qdrant/qdrant:v1.13.3", 22 | ExposedPorts: []string{"6333/tcp"}, 23 | WaitingFor: wait.ForLog("Qdrant HTTP listening on 6333"), 24 | }, 25 | }) 26 | 27 | require.NoError(t, err) 28 | 29 | url, err := server.Endpoint(context.Context, "") 30 | require.NoError(t, err) 31 | 32 | c, err := qdrant.New("http://"+url, "test", qdrant.WithEmbedder(context.Embedder)) 33 | 34 | if err != nil { 35 | t.Fatal(err) 36 | } 37 | 38 | test.TestIndex(t, context, c) 39 | } 40 | 
-------------------------------------------------------------------------------- /pkg/index/qdrant/config.go: -------------------------------------------------------------------------------- 1 | package qdrant 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/adrianliechti/wingman/pkg/index" 7 | ) 8 | 9 | type Option func(*Client) 10 | 11 | func WithClient(client *http.Client) Option { 12 | return func(c *Client) { 13 | c.client = client 14 | } 15 | } 16 | 17 | func WithEmbedder(embedder index.Embedder) Option { 18 | return func(c *Client) { 19 | c.embedder = embedder 20 | } 21 | } 22 | 23 | func WithReranker(reranker index.Reranker) Option { 24 | return func(c *Client) { 25 | c.reranker = reranker 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /pkg/index/qdrant/models.go: -------------------------------------------------------------------------------- 1 | package qdrant 2 | 3 | type payload struct { 4 | Title string `json:"title,omitempty"` 5 | Source string `json:"source,omitempty"` 6 | Content string `json:"content,omitempty"` 7 | 8 | Metadata map[string]string `json:"metadata,omitempty"` 9 | } 10 | 11 | type point struct { 12 | ID string `json:"id"` 13 | 14 | Vector []float32 `json:"vector"` 15 | 16 | Payload payload `json:"payload"` 17 | } 18 | 19 | type result struct { 20 | ID string `json:"id"` 21 | 22 | Version int `json:"version"` 23 | Score float32 `json:"score"` 24 | 25 | Vector []float32 `json:"vector"` 26 | 27 | Payload payload `json:"payload"` 28 | } 29 | 30 | type queryResult struct { 31 | Result []result `json:"result"` 32 | } 33 | 34 | type scrollResult struct { 35 | Result struct { 36 | Points []point `json:"points"` 37 | 38 | NextPageOffset string `json:"next_page_offset"` 39 | } `json:"result"` 40 | } 41 | -------------------------------------------------------------------------------- /pkg/index/searxng/config.go: 
-------------------------------------------------------------------------------- 1 | package searxng 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /pkg/index/searxng/models.go: -------------------------------------------------------------------------------- 1 | package searxng 2 | 3 | type searchResponse struct { 4 | Results []result `json:"results"` 5 | } 6 | 7 | type result struct { 8 | URL string `json:"url"` 9 | 10 | Engine string `json:"engine"` 11 | 12 | Title string `json:"title"` 13 | Content string `json:"content"` 14 | 15 | Score float32 `json:"score"` 16 | } 17 | -------------------------------------------------------------------------------- /pkg/index/tavily/config.go: -------------------------------------------------------------------------------- 1 | package tavily 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /pkg/index/tavily/models.go: -------------------------------------------------------------------------------- 1 | package tavily 2 | 3 | type searchResult struct { 4 | Query string `json:"query"` 5 | 6 | Answer string `json:"answer"` 7 | 8 | Results []struct { 9 | URL string `json:"url"` 10 | 11 | Title string `json:"title"` 12 | Content string `json:"content"` 13 | 14 | Score float64 `json:"score"` 15 | } `json:"results"` 16 | } 17 | -------------------------------------------------------------------------------- /pkg/index/tavily/utils.go: -------------------------------------------------------------------------------- 1 | package tavily 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 
6 | "errors" 7 | "io" 8 | "net/http" 9 | ) 10 | 11 | func jsonReader(v any) io.Reader { 12 | b := new(bytes.Buffer) 13 | 14 | enc := json.NewEncoder(b) 15 | enc.SetEscapeHTML(false) 16 | 17 | enc.Encode(v) 18 | return b 19 | } 20 | 21 | func convertError(resp *http.Response) error { 22 | return errors.New(http.StatusText(resp.StatusCode)) 23 | } 24 | -------------------------------------------------------------------------------- /pkg/index/weaviate/client_test.go: -------------------------------------------------------------------------------- 1 | package weaviate_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/adrianliechti/wingman/pkg/index/weaviate" 7 | "github.com/adrianliechti/wingman/test" 8 | 9 | "github.com/stretchr/testify/require" 10 | "github.com/testcontainers/testcontainers-go" 11 | "github.com/testcontainers/testcontainers-go/wait" 12 | ) 13 | 14 | func TestWeaviate(t *testing.T) { 15 | context := test.NewContext() 16 | 17 | server, err := testcontainers.GenericContainer(context.Context, testcontainers.GenericContainerRequest{ 18 | Started: true, 19 | 20 | ContainerRequest: testcontainers.ContainerRequest{ 21 | Image: "cr.weaviate.io/semitechnologies/weaviate:1.26.4", 22 | Env: map[string]string{ 23 | "CLUSTER_HOSTNAME": "node1", 24 | "DEFAULT_VECTORIZER_MODULE": "none", 25 | }, 26 | ExposedPorts: []string{"8080/tcp"}, 27 | WaitingFor: wait.ForLog("node reporting ready"), 28 | }, 29 | }) 30 | 31 | require.NoError(t, err) 32 | 33 | url, err := server.Endpoint(context.Context, "") 34 | require.NoError(t, err) 35 | 36 | c, err := weaviate.New("http://"+url, "Test", weaviate.WithEmbedder(context.Embedder)) 37 | require.NoError(t, err) 38 | 39 | test.TestIndex(t, context, c) 40 | } 41 | -------------------------------------------------------------------------------- /pkg/index/weaviate/config.go: -------------------------------------------------------------------------------- 1 | package weaviate 2 | 3 | import ( 4 | "net/http" 5 | 6 | 
"github.com/adrianliechti/wingman/pkg/index" 7 | ) 8 | 9 | type Option func(*Client) 10 | 11 | func WithClient(client *http.Client) Option { 12 | return func(c *Client) { 13 | c.client = client 14 | } 15 | } 16 | 17 | func WithEmbedder(embedder index.Embedder) Option { 18 | return func(c *Client) { 19 | c.embedder = embedder 20 | } 21 | } 22 | 23 | func WithReranker(reranker index.Reranker) Option { 24 | return func(c *Client) { 25 | c.reranker = reranker 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /pkg/index/weaviate/models.go: -------------------------------------------------------------------------------- 1 | package weaviate 2 | 3 | type Object struct { 4 | ID string `json:"id"` 5 | 6 | Created int64 `json:"creationTimeUnix"` 7 | Updated int64 `json:"lastUpdateTimeUnix"` 8 | 9 | Properties map[string]string `json:"properties"` 10 | } 11 | -------------------------------------------------------------------------------- /pkg/index/weaviate/query.go: -------------------------------------------------------------------------------- 1 | package weaviate 2 | 3 | import ( 4 | "bytes" 5 | _ "embed" 6 | "text/template" 7 | ) 8 | 9 | var ( 10 | //go:embed query.tmpl 11 | queryTemplateText string 12 | queryTemplate = template.Must(template.New("query").Parse(queryTemplateText)) 13 | ) 14 | 15 | type queryData struct { 16 | Class string 17 | 18 | Query string 19 | Vector []float32 20 | 21 | Limit *int 22 | Where map[string]string 23 | } 24 | 25 | func executeQueryTemplate(data queryData) string { 26 | var buffer bytes.Buffer 27 | queryTemplate.Execute(&buffer, data) 28 | 29 | return buffer.String() 30 | } 31 | -------------------------------------------------------------------------------- /pkg/index/weaviate/query.tmpl: -------------------------------------------------------------------------------- 1 | { 2 | Get { 3 | {{ .Class }} ( 4 | {{- if .Limit }} 5 | limit: {{ .Limit }} 6 | {{ end }} 7 | 8 | {{- if .Where }} 9 
| where: { 10 | operator: And, 11 | operands: [ 12 | {{- $sep := "" }} 13 | {{- range $key, $value := .Where }} 14 | { 15 | path: ["{{ $key }}"], 16 | operator: Equal, 17 | valueText: "{{ $value }}", 18 | } 19 | {{- $sep = "," }} 20 | {{- end }} 21 | ] 22 | } 23 | {{- end }} 24 | 25 | hybrid: { 26 | query: "{{ .Query }}" 27 | vector: {{ .Vector }} 28 | } 29 | ) { 30 | key 31 | title 32 | source 33 | content 34 | _additional { 35 | id 36 | distance 37 | certainty 38 | } 39 | } 40 | } 41 | } -------------------------------------------------------------------------------- /pkg/limiter/limiter.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | type Limiter interface { 4 | limiterSetup() 5 | } 6 | -------------------------------------------------------------------------------- /pkg/limiter/provider_chain.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/chain" 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | 9 | "golang.org/x/time/rate" 10 | ) 11 | 12 | type Chain interface { 13 | Limiter 14 | chain.Provider 15 | } 16 | 17 | type limitedChain struct { 18 | limiter *rate.Limiter 19 | provider chain.Provider 20 | } 21 | 22 | func NewChain(l *rate.Limiter, p chain.Provider) Chain { 23 | return &limitedChain{ 24 | limiter: l, 25 | provider: p, 26 | } 27 | } 28 | 29 | func (p *limitedChain) limiterSetup() { 30 | } 31 | 32 | func (p *limitedChain) Complete(ctx context.Context, messages []provider.Message, options *provider.CompleteOptions) (*provider.Completion, error) { 33 | if p.limiter != nil { 34 | p.limiter.Wait(ctx) 35 | } 36 | 37 | return p.provider.Complete(ctx, messages, options) 38 | } 39 | -------------------------------------------------------------------------------- /pkg/limiter/provider_completer.go: 
-------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | 8 | "golang.org/x/time/rate" 9 | ) 10 | 11 | type Completer interface { 12 | Limiter 13 | provider.Completer 14 | } 15 | 16 | type limitedCompleter struct { 17 | limiter *rate.Limiter 18 | provider provider.Completer 19 | } 20 | 21 | func NewCompleter(l *rate.Limiter, p provider.Completer) Completer { 22 | return &limitedCompleter{ 23 | limiter: l, 24 | provider: p, 25 | } 26 | } 27 | 28 | func (p *limitedCompleter) limiterSetup() { 29 | } 30 | 31 | func (p *limitedCompleter) Complete(ctx context.Context, messages []provider.Message, options *provider.CompleteOptions) (*provider.Completion, error) { 32 | if p.limiter != nil { 33 | p.limiter.Wait(ctx) 34 | } 35 | 36 | return p.provider.Complete(ctx, messages, options) 37 | } 38 | -------------------------------------------------------------------------------- /pkg/limiter/provider_embedder.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | 8 | "golang.org/x/time/rate" 9 | ) 10 | 11 | type Embedder interface { 12 | Limiter 13 | provider.Embedder 14 | } 15 | 16 | type limitedEmbedder struct { 17 | limiter *rate.Limiter 18 | provider provider.Embedder 19 | } 20 | 21 | func NewEmbedder(l *rate.Limiter, p provider.Embedder) Embedder { 22 | return &limitedEmbedder{ 23 | limiter: l, 24 | provider: p, 25 | } 26 | } 27 | 28 | func (p *limitedEmbedder) limiterSetup() { 29 | } 30 | 31 | func (p *limitedEmbedder) Embed(ctx context.Context, texts []string) (*provider.Embedding, error) { 32 | if p.limiter != nil { 33 | p.limiter.Wait(ctx) 34 | } 35 | 36 | return p.provider.Embed(ctx, texts) 37 | } 38 | -------------------------------------------------------------------------------- 
/pkg/limiter/provider_extractor.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/extractor" 7 | 8 | "golang.org/x/time/rate" 9 | ) 10 | 11 | type Extractor interface { 12 | Limiter 13 | extractor.Provider 14 | } 15 | 16 | type limitedExtractor struct { 17 | limiter *rate.Limiter 18 | provider extractor.Provider 19 | } 20 | 21 | func NewExtractor(l *rate.Limiter, p extractor.Provider) Extractor { 22 | return &limitedExtractor{ 23 | limiter: l, 24 | provider: p, 25 | } 26 | } 27 | 28 | func (p *limitedExtractor) limiterSetup() { 29 | } 30 | 31 | func (p *limitedExtractor) Extract(ctx context.Context, input extractor.Input, options *extractor.ExtractOptions) (*extractor.Document, error) { 32 | if p.limiter != nil { 33 | p.limiter.Wait(ctx) 34 | } 35 | 36 | return p.provider.Extract(ctx, input, options) 37 | } 38 | -------------------------------------------------------------------------------- /pkg/limiter/provider_renderer.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | 8 | "golang.org/x/time/rate" 9 | ) 10 | 11 | type Renderer interface { 12 | Limiter 13 | provider.Renderer 14 | } 15 | 16 | type limitedRenderer struct { 17 | limiter *rate.Limiter 18 | provider provider.Renderer 19 | } 20 | 21 | func NewRenderer(l *rate.Limiter, p provider.Renderer) Renderer { 22 | return &limitedRenderer{ 23 | limiter: l, 24 | provider: p, 25 | } 26 | } 27 | 28 | func (p *limitedRenderer) limiterSetup() { 29 | } 30 | 31 | func (p *limitedRenderer) Render(ctx context.Context, input string, options *provider.RenderOptions) (*provider.Rendering, error) { 32 | if p.limiter != nil { 33 | p.limiter.Wait(ctx) 34 | } 35 | 36 | return p.provider.Render(ctx, input, options) 37 | } 38 | 
-------------------------------------------------------------------------------- /pkg/limiter/provider_reranker.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | 8 | "golang.org/x/time/rate" 9 | ) 10 | 11 | type Reranker interface { 12 | Limiter 13 | provider.Reranker 14 | } 15 | 16 | type limitedReranker struct { 17 | limiter *rate.Limiter 18 | provider provider.Reranker 19 | } 20 | 21 | func NewReranker(l *rate.Limiter, p provider.Reranker) Reranker { 22 | return &limitedReranker{ 23 | limiter: l, 24 | provider: p, 25 | } 26 | } 27 | 28 | func (p *limitedReranker) limiterSetup() { 29 | } 30 | 31 | func (p *limitedReranker) Rerank(ctx context.Context, query string, inputs []string, options *provider.RerankOptions) ([]provider.Ranking, error) { 32 | if p.limiter != nil { 33 | p.limiter.Wait(ctx) 34 | } 35 | 36 | return p.provider.Rerank(ctx, query, inputs, options) 37 | } 38 | -------------------------------------------------------------------------------- /pkg/limiter/provider_segmenter.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/segmenter" 7 | 8 | "golang.org/x/time/rate" 9 | ) 10 | 11 | type Segmenter interface { 12 | Limiter 13 | segmenter.Provider 14 | } 15 | 16 | type limitedSegmenter struct { 17 | limiter *rate.Limiter 18 | provider segmenter.Provider 19 | } 20 | 21 | func NewSegmenter(l *rate.Limiter, p segmenter.Provider) Segmenter { 22 | return &limitedSegmenter{ 23 | limiter: l, 24 | provider: p, 25 | } 26 | } 27 | 28 | func (p *limitedSegmenter) limiterSetup() { 29 | } 30 | 31 | func (p *limitedSegmenter) Segment(ctx context.Context, input string, options *segmenter.SegmentOptions) ([]segmenter.Segment, error) { 32 | if p.limiter != nil { 33 | p.limiter.Wait(ctx) 34 | } 35 | 
36 | return p.provider.Segment(ctx, input, options) 37 | } 38 | -------------------------------------------------------------------------------- /pkg/limiter/provider_synthesizer.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | 8 | "golang.org/x/time/rate" 9 | ) 10 | 11 | type Synthesizer interface { 12 | Limiter 13 | provider.Synthesizer 14 | } 15 | 16 | type limitedSynthesizer struct { 17 | limiter *rate.Limiter 18 | provider provider.Synthesizer 19 | } 20 | 21 | func NewSynthesizer(l *rate.Limiter, p provider.Synthesizer) Synthesizer { 22 | return &limitedSynthesizer{ 23 | limiter: l, 24 | provider: p, 25 | } 26 | } 27 | 28 | func (p *limitedSynthesizer) limiterSetup() { 29 | } 30 | 31 | func (p *limitedSynthesizer) Synthesize(ctx context.Context, content string, options *provider.SynthesizeOptions) (*provider.Synthesis, error) { 32 | if p.limiter != nil { 33 | p.limiter.Wait(ctx) 34 | } 35 | 36 | return p.provider.Synthesize(ctx, content, options) 37 | } 38 | -------------------------------------------------------------------------------- /pkg/limiter/provider_transcriber.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | 8 | "golang.org/x/time/rate" 9 | ) 10 | 11 | type Transcriber interface { 12 | Limiter 13 | provider.Transcriber 14 | } 15 | 16 | type limitedTranscriber struct { 17 | limiter *rate.Limiter 18 | provider provider.Transcriber 19 | } 20 | 21 | func NewTranscriber(l *rate.Limiter, p provider.Transcriber) Transcriber { 22 | return &limitedTranscriber{ 23 | limiter: l, 24 | provider: p, 25 | } 26 | } 27 | 28 | func (p *limitedTranscriber) limiterSetup() { 29 | } 30 | 31 | func (p *limitedTranscriber) Transcribe(ctx context.Context, input provider.File, options 
*provider.TranscribeOptions) (*provider.Transcription, error) { 32 | if p.limiter != nil { 33 | p.limiter.Wait(ctx) 34 | } 35 | 36 | return p.provider.Transcribe(ctx, input, options) 37 | } 38 | -------------------------------------------------------------------------------- /pkg/limiter/provider_translator.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/translator" 7 | 8 | "golang.org/x/time/rate" 9 | ) 10 | 11 | type Translator interface { 12 | Limiter 13 | translator.Provider 14 | } 15 | 16 | type limitedTranslator struct { 17 | limiter *rate.Limiter 18 | provider translator.Provider 19 | } 20 | 21 | func NewTranslator(l *rate.Limiter, p translator.Provider) Translator { 22 | return &limitedTranslator{ 23 | limiter: l, 24 | provider: p, 25 | } 26 | } 27 | 28 | func (p *limitedTranslator) limiterSetup() { 29 | } 30 | 31 | func (p *limitedTranslator) Translate(ctx context.Context, content string, options *translator.TranslateOptions) (*translator.Translation, error) { 32 | if p.limiter != nil { 33 | p.limiter.Wait(ctx) 34 | } 35 | 36 | return p.provider.Translate(ctx, content, options) 37 | } 38 | -------------------------------------------------------------------------------- /pkg/otel/otel.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | 6 | "go.opentelemetry.io/otel/sdk/resource" 7 | semconv "go.opentelemetry.io/otel/semconv/v1.26.0" 8 | ) 9 | 10 | func Setup(serviceName, serviceVersion string) error { 11 | if !EnableTelemetry { 12 | return nil 13 | } 14 | 15 | ctx := context.Background() 16 | 17 | resource, err := resource.Merge( 18 | resource.Default(), 19 | resource.NewWithAttributes( 20 | semconv.SchemaURL, 21 | semconv.ServiceName(serviceName), 22 | semconv.ServiceVersion(serviceVersion), 23 | ), 24 | ) 25 | 26 | if err != nil { 27 | 
return err 28 | } 29 | 30 | if err := newPropagator(resource); err != nil { 31 | return err 32 | } 33 | 34 | if err := setupTracer(ctx, resource); err != nil { 35 | return err 36 | } 37 | 38 | if err := setupMeter(ctx, resource); err != nil { 39 | return err 40 | } 41 | 42 | if err := setupLogger(ctx, resource); err != nil { 43 | return err 44 | } 45 | 46 | if err := setupHTTP(ctx, resource); err != nil { 47 | return err 48 | } 49 | 50 | return nil 51 | } 52 | -------------------------------------------------------------------------------- /pkg/otel/otel_http.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | sdkresource "go.opentelemetry.io/otel/sdk/resource" 8 | 9 | "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" 10 | ) 11 | 12 | func setupHTTP(_ context.Context, _ *sdkresource.Resource) error { 13 | rt := Transport(http.DefaultTransport) 14 | 15 | http.DefaultTransport = rt 16 | 17 | return nil 18 | } 19 | 20 | func Transport(rt http.RoundTripper) http.RoundTripper { 21 | if rt == nil { 22 | rt = http.DefaultTransport 23 | } 24 | 25 | return otelhttp.NewTransport(rt) 26 | } 27 | -------------------------------------------------------------------------------- /pkg/otel/otel_logger.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | 7 | "go.opentelemetry.io/otel/log/global" 8 | 9 | sdklog "go.opentelemetry.io/otel/sdk/log" 10 | sdkresource "go.opentelemetry.io/otel/sdk/resource" 11 | 12 | "go.opentelemetry.io/contrib/bridges/otelslog" 13 | "go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp" 14 | ) 15 | 16 | func setupLogger(ctx context.Context, resource *sdkresource.Resource) error { 17 | exporter, err := otlploghttp.New(ctx) 18 | 19 | if err != nil { 20 | return err 21 | } 22 | 23 | provider := sdklog.NewLoggerProvider( 24 | 
sdklog.WithProcessor(sdklog.NewBatchProcessor(exporter)), 25 | sdklog.WithResource(resource), 26 | ) 27 | 28 | global.SetLoggerProvider(provider) 29 | 30 | logger := otelslog.NewLogger("", otelslog.WithLoggerProvider(provider)) 31 | slog.SetDefault(logger) 32 | 33 | return nil 34 | } 35 | -------------------------------------------------------------------------------- /pkg/otel/otel_meter.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "go.opentelemetry.io/otel" 8 | 9 | sdkmetric "go.opentelemetry.io/otel/sdk/metric" 10 | sdkresource "go.opentelemetry.io/otel/sdk/resource" 11 | 12 | "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp" 13 | ) 14 | 15 | func setupMeter(ctx context.Context, resource *sdkresource.Resource) error { 16 | exporter, err := otlpmetrichttp.New(ctx) 17 | 18 | if err != nil { 19 | return err 20 | } 21 | 22 | provider := sdkmetric.NewMeterProvider( 23 | sdkmetric.WithReader(sdkmetric.NewPeriodicReader(exporter, sdkmetric.WithInterval(3*time.Second))), 24 | sdkmetric.WithResource(resource), 25 | ) 26 | 27 | otel.SetMeterProvider(provider) 28 | 29 | return nil 30 | } 31 | -------------------------------------------------------------------------------- /pkg/otel/otel_propagator.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "go.opentelemetry.io/otel" 5 | "go.opentelemetry.io/otel/propagation" 6 | 7 | sdkresource "go.opentelemetry.io/otel/sdk/resource" 8 | ) 9 | 10 | func newPropagator(_ *sdkresource.Resource) error { 11 | propagator := propagation.NewCompositeTextMapPropagator( 12 | propagation.TraceContext{}, 13 | propagation.Baggage{}, 14 | ) 15 | 16 | otel.SetTextMapPropagator(propagator) 17 | 18 | return nil 19 | } 20 | -------------------------------------------------------------------------------- /pkg/otel/otel_tracer.go: 
-------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "go.opentelemetry.io/otel" 8 | 9 | sdkresource "go.opentelemetry.io/otel/sdk/resource" 10 | sdktrace "go.opentelemetry.io/otel/sdk/trace" 11 | 12 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" 13 | ) 14 | 15 | func setupTracer(ctx context.Context, resource *sdkresource.Resource) error { 16 | exporter, err := otlptracehttp.New(ctx) 17 | 18 | if err != nil { 19 | return err 20 | } 21 | 22 | provider := sdktrace.NewTracerProvider( 23 | sdktrace.WithSampler(sdktrace.AlwaysSample()), 24 | sdktrace.WithBatcher(exporter, sdktrace.WithBatchTimeout(time.Second)), 25 | sdktrace.WithResource(resource), 26 | ) 27 | 28 | otel.SetTracerProvider(provider) 29 | 30 | return nil 31 | } 32 | -------------------------------------------------------------------------------- /pkg/otel/otel_utils.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "go.opentelemetry.io/otel/attribute" 5 | ) 6 | 7 | type KeyValue = attribute.KeyValue 8 | 9 | func String(key string, val string) KeyValue { 10 | return attribute.String(key, val) 11 | } 12 | 13 | func Strings(key string, val []string) KeyValue { 14 | return attribute.StringSlice(key, val) 15 | } 16 | -------------------------------------------------------------------------------- /pkg/otel/provider.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "strings" 7 | 8 | "go.opentelemetry.io/otel" 9 | "go.opentelemetry.io/otel/attribute" 10 | "go.opentelemetry.io/otel/metric" 11 | ) 12 | 13 | var ( 14 | EnableDebug = false 15 | EnableTelemetry = false 16 | ) 17 | 18 | func init() { 19 | EnableDebug = os.Getenv("DEBUG") != "" 20 | EnableTelemetry = os.Getenv("TELEMETRY") != "" 21 | } 22 | 23 | type Observable interface { 
24 | otelSetup() 25 | } 26 | 27 | func meterRequest(ctx context.Context, library, provider, operation, model string) { 28 | meter, _ := otel.Meter(library).Int64Counter("llm_requests_total") 29 | 30 | meter.Add(ctx, 1, metric.WithAttributes( 31 | attribute.String("provider", strings.ToLower(provider)), 32 | attribute.String("operation", strings.ToLower(operation)), 33 | attribute.String("model", strings.ToLower(model)), 34 | )) 35 | } 36 | 37 | func meterTokens(ctx context.Context, library, provider, operation, model string, val int64) { 38 | meter, _ := otel.Meter(library).Int64Counter("llm_tokens_total") 39 | 40 | meter.Add(ctx, val, metric.WithAttributes( 41 | attribute.String("provider", strings.ToLower(provider)), 42 | attribute.String("operation", strings.ToLower(operation)), 43 | attribute.String("model", strings.ToLower(model)), 44 | )) 45 | } 46 | -------------------------------------------------------------------------------- /pkg/otel/provider_embedder.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/attribute" 11 | ) 12 | 13 | type Embedder interface { 14 | Observable 15 | provider.Embedder 16 | } 17 | 18 | type observableEmbedder struct { 19 | name string 20 | library string 21 | 22 | model string 23 | provider string 24 | 25 | embedder provider.Embedder 26 | } 27 | 28 | func NewEmbedder(provider, model string, p provider.Embedder) Embedder { 29 | library := strings.ToLower(provider) 30 | 31 | return &observableEmbedder{ 32 | embedder: p, 33 | 34 | name: strings.TrimSuffix(strings.ToLower(provider), "-embedder") + "-embedder", 35 | library: library, 36 | 37 | model: model, 38 | provider: provider, 39 | } 40 | } 41 | 42 | func (p *observableEmbedder) otelSetup() { 43 | } 44 | 45 | func (p *observableEmbedder) Embed(ctx context.Context, 
texts []string) (*provider.Embedding, error) { 46 | ctx, span := otel.Tracer(p.library).Start(ctx, p.name) 47 | defer span.End() 48 | 49 | result, err := p.embedder.Embed(ctx, texts) 50 | 51 | meterRequest(ctx, p.library, p.provider, "embed", p.model) 52 | 53 | if EnableDebug { 54 | span.SetAttributes(attribute.StringSlice("texts", texts)) 55 | } 56 | 57 | if result != nil { 58 | if result.Usage != nil { 59 | tokens := int64(result.Usage.InputTokens) + int64(result.Usage.OutputTokens) 60 | meterTokens(ctx, p.library, p.provider, "embed", p.model, tokens) 61 | } 62 | } 63 | 64 | return result, err 65 | } 66 | -------------------------------------------------------------------------------- /pkg/otel/provider_extractor.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/extractor" 8 | "go.opentelemetry.io/otel" 9 | ) 10 | 11 | type Extractor interface { 12 | Observable 13 | extractor.Provider 14 | } 15 | 16 | type observableExtractor struct { 17 | name string 18 | library string 19 | 20 | provider string 21 | 22 | extractor extractor.Provider 23 | } 24 | 25 | func NewExtractor(provider string, p extractor.Provider) Extractor { 26 | library := strings.ToLower(provider) 27 | 28 | return &observableExtractor{ 29 | extractor: p, 30 | 31 | name: strings.TrimSuffix(strings.ToLower(provider), "-extractor") + "-extractor", 32 | library: library, 33 | 34 | provider: provider, 35 | } 36 | } 37 | 38 | func (p *observableExtractor) otelSetup() { 39 | } 40 | 41 | func (p *observableExtractor) Extract(ctx context.Context, input extractor.Input, options *extractor.ExtractOptions) (*extractor.Document, error) { 42 | ctx, span := otel.Tracer(p.library).Start(ctx, p.name) 43 | defer span.End() 44 | 45 | result, err := p.extractor.Extract(ctx, input, options) 46 | 47 | return result, err 48 | } 49 | 
-------------------------------------------------------------------------------- /pkg/otel/provider_renderer.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/attribute" 11 | ) 12 | 13 | type Renderer interface { 14 | Observable 15 | provider.Renderer 16 | } 17 | 18 | type observableRenderer struct { 19 | name string 20 | library string 21 | 22 | model string 23 | provider string 24 | 25 | renderer provider.Renderer 26 | } 27 | 28 | func NewRenderer(provider, model string, p provider.Renderer) Renderer { 29 | library := strings.ToLower(provider) 30 | 31 | return &observableRenderer{ 32 | renderer: p, 33 | 34 | name: strings.TrimSuffix(strings.ToLower(provider), "-renderer") + "-renderer", 35 | library: library, 36 | 37 | model: model, 38 | provider: provider, 39 | } 40 | } 41 | 42 | func (p *observableRenderer) otelSetup() { 43 | } 44 | 45 | func (p *observableRenderer) Render(ctx context.Context, input string, options *provider.RenderOptions) (*provider.Rendering, error) { 46 | ctx, span := otel.Tracer(p.library).Start(ctx, p.name) 47 | defer span.End() 48 | 49 | result, err := p.renderer.Render(ctx, input, options) 50 | 51 | meterRequest(ctx, p.library, p.provider, "render", p.model) 52 | 53 | if EnableDebug { 54 | span.SetAttributes(attribute.String("input", input)) 55 | } 56 | 57 | return result, err 58 | } 59 | -------------------------------------------------------------------------------- /pkg/otel/provider_reranker.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/attribute" 11 | ) 12 | 13 | type Reranker interface { 14 | 
Observable 15 | provider.Reranker 16 | } 17 | // observableReranker wraps a provider.Reranker with tracing and request metering. 18 | type observableReranker struct { 19 | name string 20 | library string 21 | 22 | model string 23 | provider string 24 | 25 | reranker provider.Reranker 26 | } 27 | // NewReranker wraps p. The tracer is named after the lower-cased provider; the span name also ends in -reranker (appended only when not already present). 28 | func NewReranker(provider, model string, p provider.Reranker) Reranker { 29 | library := strings.ToLower(provider) 30 | 31 | return &observableReranker{ 32 | reranker: p, 33 | 34 | name: strings.TrimSuffix(strings.ToLower(provider), "-reranker") + "-reranker", 35 | library: library, 36 | 37 | model: model, 38 | provider: provider, 39 | } 40 | } 41 | // otelSetup is empty: this wrapper registers nothing up front. 42 | func (p *observableReranker) otelSetup() { 43 | } 44 | // Rerank calls the wrapped reranker inside a span, reports the call via meterRequest even when it fails, and attaches the query and inputs to the span when EnableDebug is set. 45 | func (p *observableReranker) Rerank(ctx context.Context, query string, inputs []string, options *provider.RerankOptions) ([]provider.Ranking, error) { 46 | ctx, span := otel.Tracer(p.library).Start(ctx, p.name) 47 | defer span.End() 48 | 49 | result, err := p.reranker.Rerank(ctx, query, inputs, options) 50 | 51 | meterRequest(ctx, p.library, p.provider, "rerank", p.model) 52 | 53 | if EnableDebug { 54 | span.SetAttributes(attribute.String("query", query)) 55 | span.SetAttributes(attribute.StringSlice("inputs", inputs)) 56 | } 57 | 58 | return result, err 59 | } 60 | -------------------------------------------------------------------------------- /pkg/otel/provider_segmenter.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/segmenter" 8 | "go.opentelemetry.io/otel" 9 | ) 10 | // Segmenter is a segmenter.Provider that is also Observable. 11 | type Segmenter interface { 12 | Observable 13 | segmenter.Provider 14 | } 15 | // observableSegmenter wraps a segmenter.Provider with a tracing span only (no metrics). 16 | type observableSegmenter struct { 17 | name string 18 | library string 19 | 20 | provider string 21 | 22 | segmenter segmenter.Provider 23 | } 24 | // NewSegmenter wraps p. The tracer is named after the lower-cased provider; the span name also ends in -segmenter (appended only when not already present). 25 | func NewSegmenter(provider string, p segmenter.Provider) Segmenter { 26 | library := strings.ToLower(provider) 27 | 28 | return &observableSegmenter{ 29 | segmenter: p, 30 | 31 | name:
strings.TrimSuffix(strings.ToLower(provider), "-segmenter") + "-segmenter", 32 | library: library, 33 | 34 | provider: provider, 35 | } 36 | } 37 | // otelSetup is empty: this wrapper registers nothing up front. 38 | func (p *observableSegmenter) otelSetup() { 39 | } 40 | // Segment calls the wrapped segmenter inside a span; it records no request metric and no debug attributes. 41 | func (p *observableSegmenter) Segment(ctx context.Context, input string, options *segmenter.SegmentOptions) ([]segmenter.Segment, error) { 42 | ctx, span := otel.Tracer(p.library).Start(ctx, p.name) 43 | defer span.End() 44 | 45 | result, err := p.segmenter.Segment(ctx, input, options) 46 | 47 | return result, err 48 | } 49 | -------------------------------------------------------------------------------- /pkg/otel/provider_synthesizer.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/attribute" 11 | ) 12 | // Synthesizer is a provider.Synthesizer that is also Observable. 13 | type Synthesizer interface { 14 | Observable 15 | provider.Synthesizer 16 | } 17 | // observableSynthesizer wraps a provider.Synthesizer with tracing and request metering. 18 | type observableSynthesizer struct { 19 | name string 20 | library string 21 | 22 | model string 23 | provider string 24 | 25 | synthesizer provider.Synthesizer 26 | } 27 | // NewSynthesizer wraps p. The tracer is named after the lower-cased provider; the span name also ends in -synthesizer (appended only when not already present). 28 | func NewSynthesizer(provider, model string, p provider.Synthesizer) Synthesizer { 29 | library := strings.ToLower(provider) 30 | 31 | return &observableSynthesizer{ 32 | synthesizer: p, 33 | 34 | name: strings.TrimSuffix(strings.ToLower(provider), "-synthesizer") + "-synthesizer", 35 | library: library, 36 | 37 | model: model, 38 | provider: provider, 39 | } 40 | } 41 | // otelSetup is empty: this wrapper registers nothing up front. 42 | func (p *observableSynthesizer) otelSetup() { 43 | } 44 | // Synthesize calls the wrapped synthesizer inside a span, reports the call via meterRequest even when it fails, and attaches the input text to the span when EnableDebug is set. 45 | func (p *observableSynthesizer) Synthesize(ctx context.Context, content string, options *provider.SynthesizeOptions) (*provider.Synthesis, error) { 46 | ctx, span := otel.Tracer(p.library).Start(ctx, p.name) 47 | defer span.End() 48 | 49 | result, err := p.synthesizer.Synthesize(ctx, content, options) 50 | 51 | meterRequest(ctx, p.library,
p.provider, "synthesize", p.model) 52 | 53 | if EnableDebug { 54 | span.SetAttributes(attribute.String("input", content)) 55 | } 56 | 57 | return result, err 58 | } 59 | -------------------------------------------------------------------------------- /pkg/otel/provider_tool.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/tool" 8 | 9 | "go.opentelemetry.io/otel" 10 | ) 11 | // Tool is a tool.Provider that is also Observable. 12 | type Tool interface { 13 | Observable 14 | tool.Provider 15 | } 16 | // observableTool wraps a tool.Provider with a tracing span only (no metrics). 17 | type observableTool struct { 18 | name string 19 | library string 20 | 21 | provider string 22 | 23 | tool tool.Provider 24 | } 25 | // NewTool wraps p. The tracer is named after the lower-cased provider; the span name also ends in -tool (appended only when not already present). 26 | func NewTool(provider string, p tool.Provider) Tool { 27 | library := strings.ToLower(provider) 28 | 29 | return &observableTool{ 30 | tool: p, 31 | 32 | name: strings.TrimSuffix(strings.ToLower(provider), "-tool") + "-tool", 33 | library: library, 34 | 35 | provider: provider, 36 | } 37 | } 38 | // otelSetup is empty: this wrapper registers nothing up front. 39 | func (p *observableTool) otelSetup() { 40 | } 41 | // Tools lists the wrapped provider's tools inside a span. 42 | func (p *observableTool) Tools(ctx context.Context) ([]tool.Tool, error) { 43 | ctx, span := otel.Tracer(p.library).Start(ctx, p.name) 44 | defer span.End() 45 | 46 | tools, err := p.tool.Tools(ctx) 47 | 48 | return tools, err 49 | } 50 | // Execute runs the named tool inside a span that shares its name with the Tools span; the tool name and parameters are not recorded. 51 | func (p *observableTool) Execute(ctx context.Context, tool string, parameters map[string]any) (any, error) { 52 | ctx, span := otel.Tracer(p.library).Start(ctx, p.name) 53 | defer span.End() 54 | 55 | result, err := p.tool.Execute(ctx, tool, parameters) 56 | 57 | return result, err 58 | } 59 | -------------------------------------------------------------------------------- /pkg/otel/provider_transcriber.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | 9 | "go.opentelemetry.io/otel" 10 | ) 11 | // Transcriber is a provider.Transcriber that is also Observable.
12 | type Transcriber interface { 13 | Observable 14 | provider.Transcriber 15 | } 16 | // observableTranscriber wraps a provider.Transcriber with tracing and request metering. 17 | type observableTranscriber struct { 18 | name string 19 | library string 20 | 21 | model string 22 | provider string 23 | 24 | transcriber provider.Transcriber 25 | } 26 | // NewTranscriber wraps p. The tracer is named after the lower-cased provider; the span name also ends in -transcriber (appended only when not already present). 27 | func NewTranscriber(provider, model string, p provider.Transcriber) Transcriber { 28 | library := strings.ToLower(provider) 29 | 30 | return &observableTranscriber{ 31 | transcriber: p, 32 | 33 | name: strings.TrimSuffix(strings.ToLower(provider), "-transcriber") + "-transcriber", 34 | library: library, 35 | 36 | model: model, 37 | provider: provider, 38 | } 39 | } 40 | // otelSetup is empty: this wrapper registers nothing up front. 41 | func (p *observableTranscriber) otelSetup() { 42 | } 43 | // Transcribe calls the wrapped transcriber inside a span and reports the call via meterRequest even when it fails; no debug attributes are recorded. 44 | func (p *observableTranscriber) Transcribe(ctx context.Context, input provider.File, options *provider.TranscribeOptions) (*provider.Transcription, error) { 45 | ctx, span := otel.Tracer(p.library).Start(ctx, p.name) 46 | defer span.End() 47 | 48 | result, err := p.transcriber.Transcribe(ctx, input, options) 49 | 50 | meterRequest(ctx, p.library, p.provider, "transcribe", p.model) 51 | 52 | return result, err 53 | } 54 | -------------------------------------------------------------------------------- /pkg/otel/provider_translator.go: -------------------------------------------------------------------------------- 1 | package otel 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/translator" 8 | 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/attribute" 11 | ) 12 | // Translator is a translator.Provider that is also Observable. 13 | type Translator interface { 14 | Observable 15 | translator.Provider 16 | } 17 | // observableTranslator wraps a translator.Provider with tracing and request metering. 18 | type observableTranslator struct { 19 | name string 20 | library string 21 | 22 | model string 23 | provider string 24 | 25 | translator translator.Provider 26 | } 27 | // NewTranslator wraps p. The tracer is named after the lower-cased provider; the span name also ends in -translator (appended only when not already present). 28 | func NewTranslator(provider, model string, p translator.Provider) Translator { 29 | library := strings.ToLower(provider) 30 | 31 | return &observableTranslator{ 32 | translator: p, 33 | 34 |
name: strings.TrimSuffix(strings.ToLower(provider), "-translator") + "-translator", 35 | library: library, 36 | 37 | model: model, 38 | provider: provider, 39 | } 40 | } 41 | // otelSetup is empty: this wrapper registers nothing up front. 42 | func (p *observableTranslator) otelSetup() { 43 | } 44 | // Translate calls the wrapped translator inside a span and reports the call via meterRequest even when it fails; with EnableDebug it attaches the input and, if non-empty, the translated text. 45 | func (p *observableTranslator) Translate(ctx context.Context, content string, options *translator.TranslateOptions) (*translator.Translation, error) { 46 | ctx, span := otel.Tracer(p.library).Start(ctx, p.name) 47 | defer span.End() 48 | 49 | result, err := p.translator.Translate(ctx, content, options) 50 | 51 | meterRequest(ctx, p.library, p.provider, "translate", p.model) 52 | 53 | if EnableDebug { 54 | span.SetAttributes(attribute.String("input", content)) 55 | 56 | if result != nil { 57 | if result.Text != "" { 58 | span.SetAttributes(attribute.String("output", result.Text)) 59 | } 60 | } 61 | } 62 | 63 | return result, err 64 | } 65 | -------------------------------------------------------------------------------- /pkg/provider/anthropic/config.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | 7 | "github.com/anthropics/anthropic-sdk-go/option" 8 | ) 9 | // Config collects the settings applied by Option functions. 10 | type Config struct { 11 | url string 12 | 13 | token string 14 | model string 15 | 16 | client *http.Client 17 | } 18 | // Option mutates a Config. 19 | type Option func(*Config) 20 | // WithClient sets the HTTP client handed to the SDK via option.WithHTTPClient. 21 | func WithClient(client *http.Client) Option { 22 | return func(c *Config) { 23 | c.client = client 24 | } 25 | } 26 | // WithToken sets the API key handed to the SDK via option.WithAPIKey. 27 | func WithToken(token string) Option { 28 | return func(c *Config) { 29 | c.token = token 30 | } 31 | } 32 | // Options converts the config into SDK request options. It defaults c.url to https://api.anthropic.com/ and normalizes it to exactly one trailing slash, writing the result back into c.url. 33 | func (c *Config) Options() []option.RequestOption { 34 | if c.url == "" { 35 | c.url = "https://api.anthropic.com/" 36 | } 37 | 38 | c.url = strings.TrimRight(c.url, "/") + "/" 39 | 40 | options := []option.RequestOption{ 41 | option.WithBaseURL(c.url), 42 | } 43 | 44 | if c.client != nil { 45 | options = append(options, option.WithHTTPClient(c.client)) 46 | } 47 | 48 | if
c.token != "" { 49 | options = append(options, option.WithAPIKey(c.token)) 50 | } 51 | 52 | return options 53 | } 54 | -------------------------------------------------------------------------------- /pkg/provider/anthropic/util.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/anthropics/anthropic-sdk-go" 7 | ) 8 | // convertError currently returns err unchanged; the errors.As branch only holds commented-out dumps of the failed request and response for debugging. 9 | func convertError(err error) error { 10 | var apierr *anthropic.Error 11 | 12 | if errors.As(err, &apierr) { 13 | //println(string(apierr.DumpRequest(true))) // Prints the serialized HTTP request 14 | //println(string(apierr.DumpResponse(true))) // Prints the serialized HTTP response 15 | } 16 | 17 | return err 18 | } 19 | -------------------------------------------------------------------------------- /pkg/provider/azure/completer.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "github.com/adrianliechti/wingman/pkg/provider/openai" 5 | ) 6 | // Completer is an alias for openai.Completer. 7 | type Completer = openai.Completer 8 | // NewCompleter builds an OpenAI-compatible completer, defaulting url to https://models.inference.ai.azure.com. 9 | func NewCompleter(url, model string, options ...Option) (*Completer, error) { 10 | if url == "" { 11 | url = "https://models.inference.ai.azure.com" 12 | } 13 | 14 | cfg := &Config{} 15 | 16 | for _, option := range options { 17 | option(cfg) 18 | } 19 | 20 | return openai.NewCompleter(url, model, cfg.options...)
21 | } 22 | -------------------------------------------------------------------------------- /pkg/provider/azure/config.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider/openai" 7 | ) 8 | // Config accumulates openai options that the azure constructors forward unchanged. 9 | type Config struct { 10 | options []openai.Option 11 | } 12 | // Option mutates a Config. 13 | type Option func(*Config) 14 | // WithClient forwards client via openai.WithClient. 15 | func WithClient(client *http.Client) Option { 16 | return func(c *Config) { 17 | c.options = append(c.options, openai.WithClient(client)) 18 | } 19 | } 20 | // WithToken forwards token via openai.WithToken. 21 | func WithToken(token string) Option { 22 | return func(c *Config) { 23 | c.options = append(c.options, openai.WithToken(token)) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /pkg/provider/azure/embedder.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "github.com/adrianliechti/wingman/pkg/provider/openai" 5 | ) 6 | // Embedder is an alias for openai.Embedder. 7 | type Embedder = openai.Embedder 8 | // NewEmbedder builds an OpenAI-compatible embedder, defaulting url to https://models.inference.ai.azure.com. 9 | func NewEmbedder(url, model string, options ...Option) (*Embedder, error) { 10 | if url == "" { 11 | url = "https://models.inference.ai.azure.com" 12 | } 13 | 14 | cfg := &Config{} 15 | 16 | for _, option := range options { 17 | option(cfg) 18 | } 19 | 20 | return openai.NewEmbedder(url, model, cfg.options...)
21 | } 22 | -------------------------------------------------------------------------------- /pkg/provider/bedrock/config.go: -------------------------------------------------------------------------------- 1 | package bedrock 2 | // Config holds bedrock provider settings (currently just the model). 3 | type Config struct { 4 | model string 5 | } 6 | // Option mutates a Config. 7 | type Option func(*Config) 8 | -------------------------------------------------------------------------------- /pkg/provider/cohere/config.go: -------------------------------------------------------------------------------- 1 | package cohere 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/cohere-ai/cohere-go/v2/option" 7 | ) 8 | // Config collects the settings applied by Option functions. 9 | type Config struct { 10 | token string 11 | model string 12 | 13 | client *http.Client 14 | } 15 | // Option mutates a Config. 16 | type Option func(*Config) 17 | // WithClient sets the HTTP client handed to the Cohere SDK via option.WithHTTPClient. 18 | func WithClient(client *http.Client) Option { 19 | return func(c *Config) { 20 | c.client = client 21 | } 22 | } 23 | // WithToken sets the token handed to the Cohere SDK via option.WithToken. 24 | func WithToken(token string) Option { 25 | return func(c *Config) { 26 | c.token = token 27 | } 28 | } 29 | // Options converts the config into Cohere SDK request options, skipping unset fields. 30 | func (c *Config) Options() []option.RequestOption { 31 | options := []option.RequestOption{} 32 | 33 | if c.client != nil { 34 | options = append(options, option.WithHTTPClient(c.client)) 35 | } 36 | 37 | if c.token != "" { 38 | options = append(options, option.WithToken(c.token)) 39 | } 40 | 41 | return options 42 | } 43 | -------------------------------------------------------------------------------- /pkg/provider/cohere/embedder.go: -------------------------------------------------------------------------------- 1 | package cohere 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | 8 | v2 "github.com/cohere-ai/cohere-go/v2" 9 | client "github.com/cohere-ai/cohere-go/v2/v2" 10 | ) 11 | // Compile-time check that *Embedder implements provider.Embedder. 12 | var _ provider.Embedder = (*Embedder)(nil) 13 | // Embedder embeds texts with the Cohere v2 API client. 14 | type Embedder struct { 15 | *Config 16 | client *client.Client 17 | } 18 | // NewEmbedder creates a Cohere embedder for model, applying options to a fresh Config. 19 | func NewEmbedder(model string, options ...Option) (*Embedder, error) { 20 | cfg := &Config{ 21 | model: model, 22 | } 23
| 24 | for _, option := range options { 25 | option(cfg) 26 | } 27 | 28 | return &Embedder{ 29 | Config: cfg, 30 | client: client.NewClient(cfg.Options()...), 31 | }, nil 32 | } 33 | // Embed requests float embeddings for texts and returns them in response order; Usage is not populated. 34 | func (e *Embedder) Embed(ctx context.Context, texts []string) (*provider.Embedding, error) { 35 | req := &v2.V2EmbedRequest{ 36 | Model: e.model, 37 | 38 | Texts: texts, 39 | // NOTE(review): every text is embedded as a search document, queries included - confirm this is intended. 40 | InputType: v2.EmbedInputTypeSearchDocument, 41 | 42 | EmbeddingTypes: []v2.EmbeddingType{ 43 | v2.EmbeddingTypeFloat, 44 | }, 45 | } 46 | 47 | resp, err := e.client.Embed(ctx, req) 48 | 49 | if err != nil { 50 | return nil, convertError(err) 51 | } 52 | 53 | result := &provider.Embedding{ 54 | Model: e.model, 55 | } 56 | // Note: the loop variable e shadows the receiver e inside this loop. 57 | for _, e := range resp.Embeddings.Float { 58 | result.Embeddings = append(result.Embeddings, toFloat32(e)) 59 | } 60 | 61 | return result, nil 62 | } 63 | -------------------------------------------------------------------------------- /pkg/provider/cohere/util.go: -------------------------------------------------------------------------------- 1 | package cohere 2 | // convertError currently returns err unchanged. 3 | func convertError(err error) error { 4 | return err 5 | } 6 | // toFloat32 converts a float64 vector to float32 element by element. 7 | func toFloat32(input []float64) []float32 { 8 | result := make([]float32, len(input)) 9 | 10 | for i, v := range input { 11 | result[i] = float32(v) 12 | } 13 | 14 | return result 15 | } 16 | -------------------------------------------------------------------------------- /pkg/provider/custom/Taskfile.yaml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | tasks: 6 | generate: 7 | cmds: 8 | - protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=.
--go-grpc_opt=paths=source_relative provider.proto -------------------------------------------------------------------------------- /pkg/provider/custom/client.go: -------------------------------------------------------------------------------- 1 | package custom 2 | // Config currently has no fields. 3 | type Config struct { 4 | } 5 | // Option mutates a Config. 6 | type Option func(*Config) 7 | -------------------------------------------------------------------------------- /pkg/provider/embedder.go: -------------------------------------------------------------------------------- 1 | package provider 2 | 3 | import ( 4 | "context" 5 | ) 6 | // Embedder turns a batch of texts into embedding vectors. 7 | type Embedder interface { 8 | Embed(ctx context.Context, texts []string) (*Embedding, error) 9 | } 10 | // Embedding is the result of an Embed call; Usage may be nil when a provider does not report it. 11 | type Embedding struct { 12 | Model string 13 | 14 | Embeddings [][]float32 15 | 16 | Usage *Usage 17 | } 18 | -------------------------------------------------------------------------------- /pkg/provider/gemini/config.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "net/http" 5 | 6 | "google.golang.org/api/option" 7 | ) 8 | // Config collects the settings applied by Option functions. 9 | type Config struct { 10 | token string 11 | model string 12 | 13 | client *http.Client 14 | } 15 | // Option mutates a Config. 16 | type Option func(*Config) 17 | // WithClient sets the HTTP client handed to the Google API client via option.WithHTTPClient. 18 | func WithClient(client *http.Client) Option { 19 | return func(c *Config) { 20 | c.client = client 21 | } 22 | } 23 | // WithToken sets the API key handed to the Google API client via option.WithAPIKey. 24 | func WithToken(token string) Option { 25 | return func(c *Config) { 26 | c.token = token 27 | } 28 | } 29 | // Options converts the config into Google API client options, skipping unset fields. 30 | func (c *Config) Options() []option.ClientOption { 31 | options := []option.ClientOption{} 32 | 33 | if c.client != nil { 34 | options = append(options, option.WithHTTPClient(c.client)) 35 | } 36 | 37 | if c.token != "" { 38 | options = append(options, option.WithAPIKey(c.token)) 39 | } 40 | 41 | return options 42 | } 43 | -------------------------------------------------------------------------------- /pkg/provider/gemini/embedder.go:
-------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | "github.com/google/generative-ai-go/genai" 8 | ) 9 | // Compile-time check that *Embedder implements provider.Embedder. 10 | var _ provider.Embedder = (*Embedder)(nil) 11 | // Embedder embeds texts with a Gemini embedding model through the genai client. 12 | type Embedder struct { 13 | *Config 14 | } 15 | // NewEmbedder creates a Gemini embedder for model, applying options to a fresh Config. 16 | func NewEmbedder(model string, options ...Option) (*Embedder, error) { 17 | cfg := &Config{ 18 | model: model, 19 | } 20 | 21 | for _, option := range options { 22 | option(cfg) 23 | } 24 | 25 | return &Embedder{ 26 | Config: cfg, 27 | }, nil 28 | } 29 | // Embed sends all texts in one batch request and returns the vectors in response order; a new genai client is created and closed on every call, and Usage is not populated. 30 | func (e *Embedder) Embed(ctx context.Context, texts []string) (*provider.Embedding, error) { 31 | client, err := genai.NewClient(ctx, e.Options()...) 32 | 33 | if err != nil { 34 | return nil, err 35 | } 36 | 37 | defer client.Close() 38 | 39 | model := client.EmbeddingModel(e.model) 40 | 41 | batch := model.NewBatch() 42 | 43 | for _, text := range texts { 44 | batch.AddContent(genai.Text(text)) 45 | } 46 | 47 | resp, err := model.BatchEmbedContents(ctx, batch) 48 | 49 | if err != nil { 50 | return nil, convertError(err) 51 | } 52 | 53 | result := &provider.Embedding{ 54 | Model: e.model, 55 | } 56 | // Note: the loop variable e shadows the receiver e inside this loop. 57 | for _, e := range resp.Embeddings { 58 | result.Embeddings = append(result.Embeddings, e.Values) 59 | } 60 | 61 | return result, nil 62 | } 63 | -------------------------------------------------------------------------------- /pkg/provider/gemini/util.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "errors" 5 | 6 | "google.golang.org/api/googleapi" 7 | ) 8 | // convertError replaces a googleapi.Error that carries a response body with an error holding just that body; all other errors are returned unchanged. 9 | func convertError(err error) error { 10 | var apierr *googleapi.Error 11 | // Only unwrap when Body is non-empty: errors.New("") would hide the status and message of the original error. 12 | if errors.As(err, &apierr) && apierr.Body != "" { 13 | return errors.New(apierr.Body) 14 | } 15 | 16 | return err 17 | } 18 | -------------------------------------------------------------------------------- /pkg/provider/groq/completer.go:
-------------------------------------------------------------------------------- 1 | package groq 2 | 3 | import ( 4 | "github.com/adrianliechti/wingman/pkg/provider/openai" 5 | ) 6 | // Completer is an alias for openai.Completer. 7 | type Completer = openai.Completer 8 | // NewCompleter builds an OpenAI-compatible completer for the fixed Groq endpoint; unlike NewTranscriber it accepts no url override. 9 | func NewCompleter(model string, options ...Option) (*Completer, error) { 10 | url := "https://api.groq.com/openai/v1" 11 | 12 | cfg := &Config{} 13 | 14 | for _, option := range options { 15 | option(cfg) 16 | } 17 | 18 | return openai.NewCompleter(url, model, cfg.options...) 19 | } 20 | -------------------------------------------------------------------------------- /pkg/provider/groq/completer_test.go: -------------------------------------------------------------------------------- 1 | package groq_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/adrianliechti/wingman/pkg/provider" 9 | "github.com/adrianliechti/wingman/pkg/provider/groq" 10 | 11 | "github.com/stretchr/testify/require" 12 | ) 13 | // TestCompleter sends one chat message to Groq; it is skipped unless GROQ_API_TOKEN is set. 14 | func TestCompleter(t *testing.T) { 15 | ctx := context.Background() 16 | token := os.Getenv("GROQ_API_TOKEN") 17 | model := "llama-3.1-8b-instant" 18 | 19 | if token == "" { 20 | t.Skip("GROQ_API_TOKEN required for this test") 21 | } 22 | 23 | c, err := groq.NewCompleter(model, groq.WithToken(token)) 24 | require.NoError(t, err) 25 | 26 | result, err := c.Complete(ctx, []provider.Message{ 27 | provider.UserMessage("Hello!"), 28 | }, nil) 29 | 30 | require.NoError(t, err) 31 | require.NotEmpty(t, result.Message.Content) 32 | 33 | t.Log(result.Message.Content) 34 | } 35 | -------------------------------------------------------------------------------- /pkg/provider/groq/config.go: -------------------------------------------------------------------------------- 1 | package groq 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider/openai" 7 | ) 8 | // Config accumulates openai options that the groq constructors forward unchanged. 9 | type Config struct { 10 | options []openai.Option 11 | } 12 | // Option mutates a Config. 13 | type Option func(*Config) 14 | // WithClient forwards client via openai.WithClient. 15 | func WithClient(client
*http.Client) Option { 16 | return func(c *Config) { 17 | c.options = append(c.options, openai.WithClient(client)) 18 | } 19 | } 20 | // WithToken forwards token via openai.WithToken. 21 | func WithToken(token string) Option { 22 | return func(c *Config) { 23 | c.options = append(c.options, openai.WithToken(token)) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /pkg/provider/groq/transcriber.go: -------------------------------------------------------------------------------- 1 | package groq 2 | 3 | import ( 4 | "github.com/adrianliechti/wingman/pkg/provider/openai" 5 | ) 6 | // Transcriber is an alias for openai.Transcriber. 7 | type Transcriber = openai.Transcriber 8 | // NewTranscriber builds an OpenAI-compatible transcriber, defaulting url to https://api.groq.com/openai/v1. 9 | func NewTranscriber(url, model string, options ...Option) (*Transcriber, error) { 10 | if url == "" { 11 | url = "https://api.groq.com/openai/v1" 12 | } 13 | 14 | cfg := &Config{} 15 | 16 | for _, option := range options { 17 | option(cfg) 18 | } 19 | 20 | return openai.NewTranscriber(url, model, cfg.options...) 21 | } 22 | -------------------------------------------------------------------------------- /pkg/provider/groq/transcriber_test.go: -------------------------------------------------------------------------------- 1 | package groq_test 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net/http" 7 | "os" 8 | "testing" 9 | 10 | "github.com/adrianliechti/wingman/pkg/provider" 11 | "github.com/adrianliechti/wingman/pkg/provider/groq" 12 | 13 | "github.com/stretchr/testify/require" 14 | ) 15 | // TestTranscriber transcribes a sample WAV downloaded from GitHub; it is skipped unless GROQ_API_TOKEN is set. 16 | func TestTranscriber(t *testing.T) { 17 | ctx := context.Background() 18 | token := os.Getenv("GROQ_API_TOKEN") 19 | model := "whisper-large-v3" 20 | 21 | if token == "" { 22 | t.Skip("GROQ_API_TOKEN required for this test") 23 | } 24 | 25 | p, err := groq.NewTranscriber("", model, groq.WithToken(token)) 26 | require.NoError(t, err) 27 | // NOTE(review): the download's HTTP status code is not checked before its body is used. 28 | resp, err := http.Get("https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav") 29 | require.NoError(t, err) 30 | defer resp.Body.Close() 31 | 32 | data, err := io.ReadAll(resp.Body) 33 |
require.NoError(t, err) 34 | 35 | result, err := p.Transcribe(ctx, provider.File{ 36 | Name: "jfk.wav", 37 | 38 | Content: data, 39 | ContentType: "audio/wav", 40 | }, nil) 41 | 42 | require.NoError(t, err) 43 | require.NotEmpty(t, result.Text) 44 | 45 | t.Log(result.Text) 46 | } 47 | -------------------------------------------------------------------------------- /pkg/provider/huggingface/completer.go: -------------------------------------------------------------------------------- 1 | package huggingface 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider/openai" 8 | ) 9 | // Completer is an alias for openai.Completer. 10 | type Completer = openai.Completer 11 | // NewCompleter builds an OpenAI-compatible completer, defaulting url to the Hugging Face inference API for model; trailing slashes and a /v1 suffix are stripped before /v1 is appended. 12 | func NewCompleter(url, model string, options ...Option) (*Completer, error) { 13 | if url == "" { 14 | url = "https://api-inference.huggingface.co/models/" + model 15 | } 16 | // TrimSuffix, not TrimRight: TrimRight treats "/v1" as a character set and would also strip trailing v and 1 characters, e.g. from a port like :8081 or a model name ending in 1. 17 | url = strings.TrimRight(url, "/") 18 | url = strings.TrimSuffix(url, "/v1") 19 | // NOTE(review): cfg.url and cfg.model are set below but never read; the caller's model is what gets forwarded - confirm which model name is intended. 20 | cfg := &Config{ 21 | client: http.DefaultClient, 22 | 23 | url: url, 24 | token: "-", 25 | 26 | model: "tgi", 27 | } 28 | 29 | for _, option := range options { 30 | option(cfg) 31 | } 32 | 33 | ops := []openai.Option{} 34 | 35 | if cfg.client != nil { 36 | ops = append(ops, openai.WithClient(cfg.client)) 37 | } 38 | 39 | if cfg.token != "" { 40 | ops = append(ops, openai.WithToken(cfg.token)) 41 | } 42 | 43 | return openai.NewCompleter(url+"/v1", model, ops...)
44 | } 45 | -------------------------------------------------------------------------------- /pkg/provider/huggingface/config.go: -------------------------------------------------------------------------------- 1 | package huggingface 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | // Config collects the settings applied by Option functions. 7 | type Config struct { 8 | url string 9 | 10 | token string 11 | model string 12 | 13 | client *http.Client 14 | } 15 | // Option mutates a Config. 16 | type Option func(*Config) 17 | // WithClient sets the Config's HTTP client. 18 | func WithClient(client *http.Client) Option { 19 | return func(c *Config) { 20 | c.client = client 21 | } 22 | } 23 | // WithToken sets the Config's API token. 24 | func WithToken(token string) Option { 25 | return func(c *Config) { 26 | c.token = token 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /pkg/provider/huggingface/embedder_test.go: -------------------------------------------------------------------------------- 1 | package huggingface_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider/huggingface" 8 | 9 | "github.com/stretchr/testify/require" 10 | "github.com/testcontainers/testcontainers-go" 11 | "github.com/testcontainers/testcontainers-go/wait" 12 | ) 13 | // TestEmbedder starts a text-embeddings-inference container (needs a Docker-compatible runtime) and embeds two strings. 14 | func TestEmbedder(t *testing.T) { 15 | ctx := context.Background() 16 | 17 | server, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ 18 | Started: true, 19 | 20 | ContainerRequest: testcontainers.ContainerRequest{ 21 | Image: "ghcr.io/huggingface/text-embeddings-inference:cpu-1.6", 22 | ImagePlatform: "linux/amd64", 23 | 24 | Cmd: []string{"--model-id", "BAAI/bge-large-en-v1.5"}, 25 | 26 | Mounts: testcontainers.Mounts( 27 | testcontainers.ContainerMount{ 28 | Target: "/data", 29 | Source: testcontainers.DockerVolumeMountSource{ 30 | Name: "huggingface", 31 | }, 32 | }, 33 | ), 34 | 35 | ExposedPorts: []string{"80/tcp"}, 36 | 37 | WaitingFor: wait.ForLog("Ready"), 38 | }, 39 | }) 40 | // NOTE(review): the container is never terminated in this test; consider deferring its cleanup. 41 | require.NoError(t, err) 42 | 43 | url, err := server.Endpoint(ctx, "") 44 |
require.NoError(t, err) 45 | 46 | e, err := huggingface.NewEmbedder("http://"+url, "") 47 | require.NoError(t, err) 48 | 49 | result, err := e.Embed(ctx, []string{"Hello, World!", "Hello Welt!"}) 50 | require.NoError(t, err) 51 | 52 | require.Len(t, result.Embeddings, 2) 53 | 54 | require.NotEmpty(t, result.Embeddings[0]) 55 | require.NotEmpty(t, result.Embeddings[1]) 56 | } 57 | -------------------------------------------------------------------------------- /pkg/provider/huggingface/reranker_test.go: -------------------------------------------------------------------------------- 1 | package huggingface_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider/huggingface" 8 | 9 | "github.com/stretchr/testify/require" 10 | "github.com/testcontainers/testcontainers-go" 11 | "github.com/testcontainers/testcontainers-go/wait" 12 | ) 13 | // TestReranker starts a text-embeddings-inference reranker container (needs a Docker-compatible runtime) and reranks two passages. 14 | func TestReranker(t *testing.T) { 15 | ctx := context.Background() 16 | 17 | server, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ 18 | Started: true, 19 | // NOTE(review): this pins TEI cpu-1.5 while the embedder test uses cpu-1.6 - confirm whether the versions should match. 20 | ContainerRequest: testcontainers.ContainerRequest{ 21 | Image: "ghcr.io/huggingface/text-embeddings-inference:cpu-1.5", 22 | ImagePlatform: "linux/amd64", 23 | 24 | Cmd: []string{"--model-id", "BAAI/bge-reranker-base"}, 25 | 26 | Mounts: testcontainers.Mounts( 27 | testcontainers.ContainerMount{ 28 | Target: "/data", 29 | Source: testcontainers.DockerVolumeMountSource{ 30 | Name: "huggingface", 31 | }, 32 | }, 33 | ), 34 | 35 | ExposedPorts: []string{"80/tcp"}, 36 | 37 | WaitingFor: wait.ForLog("Ready"), 38 | }, 39 | }) 40 | 41 | require.NoError(t, err) 42 | 43 | url, err := server.Endpoint(ctx, "") 44 | require.NoError(t, err) 45 | 46 | e, err := huggingface.NewReranker("http://"+url, "") 47 | require.NoError(t, err) 48 | 49 | result, err := e.Rerank(ctx, "What is Deep Learning", []string{"Deep learning is a type of machine learning that uses artificial neural networks to learn from
data.", "Deep Learning is Not All You Need"}, nil) 50 | require.NoError(t, err) 51 | 52 | require.NotEmpty(t, result) 53 | } 54 | -------------------------------------------------------------------------------- /pkg/provider/huggingface/util.go: -------------------------------------------------------------------------------- 1 | package huggingface 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "io" 8 | "net/http" 9 | ) 10 | // convertError turns a failed response into an error carrying the body text, or the status text when the body is empty; read errors are ignored. 11 | func convertError(resp *http.Response) error { 12 | data, _ := io.ReadAll(resp.Body) 13 | 14 | if len(data) == 0 { 15 | return errors.New(http.StatusText(resp.StatusCode)) 16 | } 17 | 18 | return errors.New(string(data)) 19 | } 20 | // jsonReader JSON-encodes v without HTML escaping and returns the bytes as a reader. 21 | func jsonReader(v any) io.Reader { 22 | b := new(bytes.Buffer) 23 | 24 | enc := json.NewEncoder(b) 25 | enc.SetEscapeHTML(false) 26 | // NOTE(review): the Encode error is discarded, so an unencodable v yields an empty body. 27 | enc.Encode(v) 28 | return b 29 | } 30 | -------------------------------------------------------------------------------- /pkg/provider/jina/config.go: -------------------------------------------------------------------------------- 1 | package jina 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | // Config collects the settings applied by Option functions. 7 | type Config struct { 8 | url string 9 | 10 | token string 11 | model string 12 | 13 | client *http.Client 14 | } 15 | // Option mutates a Config. 16 | type Option func(*Config) 17 | // WithClient sets the Config's HTTP client. 18 | func WithClient(client *http.Client) Option { 19 | return func(c *Config) { 20 | c.client = client 21 | } 22 | } 23 | // WithToken sets the Config's API token. 24 | func WithToken(token string) Option { 25 | return func(c *Config) { 26 | c.token = token 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /pkg/provider/jina/embedder_test.go: -------------------------------------------------------------------------------- 1 | package jina_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider/jina" 8 | 9 | "github.com/stretchr/testify/require" 10 | "github.com/testcontainers/testcontainers-go" 11 | "github.com/testcontainers/testcontainers-go/wait" 12 |
) 13 | // TestEmbedder starts the wingman-embeddings container (needs a Docker-compatible runtime) and embeds two strings. 14 | func TestEmbedder(t *testing.T) { 15 | ctx := context.Background() 16 | 17 | server, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ 18 | Started: true, 19 | 20 | ContainerRequest: testcontainers.ContainerRequest{ 21 | Image: "ghcr.io/adrianliechti/wingman-embeddings", 22 | 23 | Mounts: testcontainers.Mounts( 24 | testcontainers.ContainerMount{ 25 | Target: "/app/.cache/huggingface", 26 | Source: testcontainers.DockerVolumeMountSource{ 27 | Name: "huggingface", 28 | }, 29 | }, 30 | ), 31 | 32 | ExposedPorts: []string{"8000/tcp"}, 33 | 34 | WaitingFor: wait.ForLog("Application startup complete"), 35 | }, 36 | }) 37 | 38 | require.NoError(t, err) 39 | 40 | url, err := server.Endpoint(ctx, "") 41 | require.NoError(t, err) 42 | 43 | e, err := jina.NewEmbedder("http://"+url, "") 44 | require.NoError(t, err) 45 | 46 | result, err := e.Embed(ctx, []string{"Hello, World!", "Hello Welt!"}) 47 | require.NoError(t, err) 48 | 49 | require.Len(t, result.Embeddings, 2) 50 | 51 | require.NotEmpty(t, result.Embeddings[0]) 52 | require.NotEmpty(t, result.Embeddings[1]) 53 | } 54 | -------------------------------------------------------------------------------- /pkg/provider/jina/reranker_test.go: -------------------------------------------------------------------------------- 1 | package jina_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider/jina" 8 | 9 | "github.com/stretchr/testify/require" 10 | "github.com/testcontainers/testcontainers-go" 11 | "github.com/testcontainers/testcontainers-go/wait" 12 | ) 13 | // TestReranker starts the wingman-reranker container (needs a Docker-compatible runtime) and reranks three candidates. 14 | func TestReranker(t *testing.T) { 15 | ctx := context.Background() 16 | 17 | server, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ 18 | Started: true, 19 | 20 | ContainerRequest: testcontainers.ContainerRequest{ 21 | Image: "ghcr.io/adrianliechti/wingman-reranker", 22 | 23 | Mounts: testcontainers.Mounts( 24 |
testcontainers.ContainerMount{ 25 | Target: "/app/.cache/huggingface", 26 | Source: testcontainers.DockerVolumeMountSource{ 27 | Name: "huggingface", 28 | }, 29 | }, 30 | ), 31 | 32 | ExposedPorts: []string{"8000/tcp"}, 33 | 34 | WaitingFor: wait.ForLog("Application startup complete"), 35 | }, 36 | }) 37 | 38 | require.NoError(t, err) 39 | 40 | url, err := server.Endpoint(ctx, "") 41 | require.NoError(t, err) 42 | 43 | r, err := jina.NewReranker("http://"+url, "") 44 | require.NoError(t, err) 45 | 46 | result, err := r.Rerank(ctx, "Hello, World!", []string{"World", "Sun", "Moon"}, nil) 47 | require.NoError(t, err) 48 | 49 | require.NotEmpty(t, result) 50 | } 51 | -------------------------------------------------------------------------------- /pkg/provider/jina/util.go: -------------------------------------------------------------------------------- 1 | package jina 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "io" 8 | "net/http" 9 | ) 10 | 11 | func convertError(resp *http.Response) error { 12 | data, _ := io.ReadAll(resp.Body) 13 | 14 | if len(data) == 0 { 15 | return errors.New(http.StatusText(resp.StatusCode)) 16 | } 17 | 18 | return errors.New(string(data)) 19 | } 20 | 21 | func jsonReader(v any) io.Reader { 22 | b := new(bytes.Buffer) 23 | 24 | enc := json.NewEncoder(b) 25 | enc.SetEscapeHTML(false) 26 | 27 | enc.Encode(v) 28 | return b 29 | } 30 | -------------------------------------------------------------------------------- /pkg/provider/llama/completer.go: -------------------------------------------------------------------------------- 1 | package llama 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider/openai" 8 | ) 9 | 10 | type Completer = openai.Completer 11 | 12 | func NewCompleter(url, model string, options ...Option) (*Completer, error) { 13 | if url == "" { 14 | return nil, errors.New("url is required") 15 | } 16 | 17 | url = strings.TrimRight(url, "/") 18 | url = 
strings.TrimSuffix(url, "/v1") 19 | 20 | cfg := &Config{} 21 | 22 | for _, option := range options { 23 | option(cfg) 24 | } 25 | 26 | opts := []openai.Option{} 27 | 28 | if cfg.client != nil { 29 | opts = append(opts, openai.WithClient(cfg.client)) 30 | } 31 | 32 | return openai.NewCompleter(url+"/v1", model, opts...) 33 | } 34 | -------------------------------------------------------------------------------- /pkg/provider/llama/config.go: -------------------------------------------------------------------------------- 1 | package llama 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Config struct { 8 | client *http.Client 9 | } 10 | 11 | type Option func(*Config) 12 | 13 | func WithClient(client *http.Client) Option { 14 | return func(c *Config) { 15 | c.client = client 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /pkg/provider/llama/embedder.go: -------------------------------------------------------------------------------- 1 | package llama 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider/openai" 8 | ) 9 | 10 | type Embedder = openai.Embedder 11 | 12 | func NewEmbedder(url, model string, options ...Option) (*Embedder, error) { 13 | if url == "" { 14 | return nil, errors.New("url is required") 15 | } 16 | 17 | url = strings.TrimRight(url, "/") 18 | url = strings.TrimSuffix(url, "/v1") 19 | 20 | cfg := &Config{} 21 | 22 | for _, option := range options { 23 | option(cfg) 24 | } 25 | 26 | opts := []openai.Option{} 27 | 28 | if cfg.client != nil { 29 | opts = append(opts, openai.WithClient(cfg.client)) 30 | } 31 | 32 | return openai.NewEmbedder(url+"/v1", model, opts...) 
33 | } 34 | -------------------------------------------------------------------------------- /pkg/provider/llama/reranker.go: -------------------------------------------------------------------------------- 1 | package llama 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider/jina" 8 | ) 9 | 10 | type Reranker = jina.Reranker 11 | 12 | func NewReranker(url, model string, options ...Option) (*Reranker, error) { 13 | if url == "" { 14 | return nil, errors.New("url is required") 15 | } 16 | 17 | url = strings.TrimRight(url, "/") 18 | url = strings.TrimSuffix(url, "/v1") 19 | 20 | cfg := &Config{} 21 | 22 | for _, option := range options { 23 | option(cfg) 24 | } 25 | 26 | opts := []jina.Option{} 27 | 28 | if cfg.client != nil { 29 | opts = append(opts, jina.WithClient(cfg.client)) 30 | } 31 | 32 | return jina.NewReranker(url, model, opts...) 33 | } 34 | -------------------------------------------------------------------------------- /pkg/provider/mistral/completer.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "github.com/adrianliechti/wingman/pkg/provider/openai" 5 | ) 6 | 7 | type Completer = openai.Completer 8 | 9 | func NewCompleter(model string, options ...Option) (*Completer, error) { 10 | url := "https://api.mistral.ai/v1/" 11 | 12 | cfg := &Config{} 13 | 14 | for _, option := range options { 15 | option(cfg) 16 | } 17 | 18 | return openai.NewCompleter(url, model, cfg.options...) 
19 | } 20 | -------------------------------------------------------------------------------- /pkg/provider/mistral/config.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider/openai" 7 | ) 8 | 9 | type Config struct { 10 | options []openai.Option 11 | } 12 | 13 | type Option func(*Config) 14 | 15 | func WithClient(client *http.Client) Option { 16 | return func(c *Config) { 17 | c.options = append(c.options, openai.WithClient(client)) 18 | } 19 | } 20 | 21 | func WithToken(token string) Option { 22 | return func(c *Config) { 23 | c.options = append(c.options, openai.WithToken(token)) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /pkg/provider/mistral/util.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "io" 8 | "net/http" 9 | ) 10 | 11 | func convertError(resp *http.Response) error { 12 | data, _ := io.ReadAll(resp.Body) 13 | 14 | if len(data) == 0 { 15 | return errors.New(http.StatusText(resp.StatusCode)) 16 | } 17 | 18 | return errors.New(string(data)) 19 | } 20 | 21 | func jsonReader(v any) io.Reader { 22 | b := new(bytes.Buffer) 23 | 24 | enc := json.NewEncoder(b) 25 | enc.SetEscapeHTML(false) 26 | 27 | enc.Encode(v) 28 | return b 29 | } 30 | -------------------------------------------------------------------------------- /pkg/provider/mistralrs/completer.go: -------------------------------------------------------------------------------- 1 | package mistralrs 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider/openai" 8 | ) 9 | 10 | type Completer = openai.Completer 11 | 12 | func NewCompleter(url, model string, options ...Option) (*Completer, error) { 13 | if url == "" { 14 | return nil, errors.New("url 
is required") 15 | } 16 | 17 | url = strings.TrimRight(url, "/") 18 | url = strings.TrimSuffix(url, "/v1") 19 | 20 | cfg := &Config{} 21 | 22 | for _, option := range options { 23 | option(cfg) 24 | } 25 | 26 | return openai.NewCompleter(url+"/v1", model, cfg.options...) 27 | } 28 | -------------------------------------------------------------------------------- /pkg/provider/mistralrs/config.go: -------------------------------------------------------------------------------- 1 | package mistralrs 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider/openai" 7 | ) 8 | 9 | type Config struct { 10 | options []openai.Option 11 | } 12 | 13 | type Option func(*Config) 14 | 15 | func WithClient(client *http.Client) Option { 16 | return func(c *Config) { 17 | c.options = append(c.options, openai.WithClient(client)) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /pkg/provider/mistralrs/embedder.go: -------------------------------------------------------------------------------- 1 | package mistralrs 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider/openai" 8 | ) 9 | 10 | type Embedder = openai.Embedder 11 | 12 | func NewEmbedder(url, model string, options ...Option) (*Embedder, error) { 13 | if url == "" { 14 | return nil, errors.New("url is required") 15 | } 16 | 17 | url = strings.TrimRight(url, "/") 18 | url = strings.TrimSuffix(url, "/v1") 19 | 20 | cfg := &Config{} 21 | 22 | for _, option := range options { 23 | option(cfg) 24 | } 25 | 26 | return openai.NewEmbedder(url+"/v1", model, cfg.options...) 
27 | } 28 | -------------------------------------------------------------------------------- /pkg/provider/ollama/completer.go: -------------------------------------------------------------------------------- 1 | package ollama 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider/openai" 7 | ) 8 | 9 | type Completer = openai.Completer 10 | 11 | func NewCompleter(url, model string, options ...Option) (*Completer, error) { 12 | if url == "" { 13 | url = "http://localhost:11434" 14 | } 15 | 16 | url = strings.TrimRight(url, "/") 17 | url = strings.TrimSuffix(url, "/v1") 18 | 19 | cfg := &Config{} 20 | 21 | for _, option := range options { 22 | option(cfg) 23 | } 24 | 25 | return openai.NewCompleter(url+"/v1", model, cfg.options...) 26 | } 27 | -------------------------------------------------------------------------------- /pkg/provider/ollama/config.go: -------------------------------------------------------------------------------- 1 | package ollama 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider/openai" 7 | ) 8 | 9 | type Config struct { 10 | options []openai.Option 11 | } 12 | 13 | type Option func(*Config) 14 | 15 | func WithClient(client *http.Client) Option { 16 | return func(c *Config) { 17 | c.options = append(c.options, openai.WithClient(client)) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /pkg/provider/ollama/embedder.go: -------------------------------------------------------------------------------- 1 | package ollama 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider/openai" 7 | ) 8 | 9 | type Embedder = openai.Embedder 10 | 11 | func NewEmbedder(url, model string, options ...Option) (*Embedder, error) { 12 | if url == "" { 13 | url = "http://localhost:11434" 14 | } 15 | 16 | url = strings.TrimRight(url, "/") 17 | url = strings.TrimSuffix(url, "/v1") 18 | 19 | cfg := &Config{} 20 | 21 | for _, 
option := range options { 22 | option(cfg) 23 | } 24 | 25 | return openai.NewEmbedder(url+"/v1", model, cfg.options...) 26 | } 27 | -------------------------------------------------------------------------------- /pkg/provider/openai/config.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | 7 | "github.com/openai/openai-go/azure" 8 | "github.com/openai/openai-go/option" 9 | ) 10 | 11 | type Config struct { 12 | url string 13 | 14 | token string 15 | model string 16 | 17 | client *http.Client 18 | } 19 | 20 | type Option func(*Config) 21 | 22 | func WithClient(client *http.Client) Option { 23 | return func(c *Config) { 24 | c.client = client 25 | } 26 | } 27 | 28 | func WithToken(token string) Option { 29 | return func(c *Config) { 30 | c.token = token 31 | } 32 | } 33 | 34 | func (c *Config) Options() []option.RequestOption { 35 | if c.url == "" { 36 | c.url = "https://api.openai.com/v1/" 37 | } 38 | 39 | if c.client == nil { 40 | c.client = http.DefaultClient 41 | } 42 | 43 | c.url = strings.TrimRight(c.url, "/") + "/" 44 | 45 | if strings.Contains(c.url, "openai.azure.com") || strings.Contains(c.url, "cognitiveservices.azure.com") { 46 | options := make([]option.RequestOption, 0) 47 | 48 | options = append(options, 49 | option.WithHTTPClient(c.client), 50 | azure.WithEndpoint(c.url, "2025-04-01-preview"), 51 | ) 52 | 53 | if c.token != "" { 54 | options = append(options, azure.WithAPIKey(c.token)) 55 | } 56 | 57 | return options 58 | } 59 | 60 | options := []option.RequestOption{ 61 | option.WithBaseURL(c.url), 62 | option.WithHTTPClient(c.client), 63 | } 64 | 65 | if c.token != "" { 66 | options = append(options, option.WithAPIKey(c.token)) 67 | } 68 | 69 | return options 70 | } 71 | -------------------------------------------------------------------------------- /pkg/provider/openai/synthesizer.go: 
-------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "io" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | 9 | "github.com/google/uuid" 10 | "github.com/openai/openai-go" 11 | ) 12 | 13 | var _ provider.Synthesizer = (*Synthesizer)(nil) 14 | 15 | type Synthesizer struct { 16 | *Config 17 | speech openai.AudioSpeechService 18 | } 19 | 20 | func NewSynthesizer(url, model string, options ...Option) (*Synthesizer, error) { 21 | cfg := &Config{ 22 | url: url, 23 | model: model, 24 | } 25 | 26 | for _, option := range options { 27 | option(cfg) 28 | } 29 | 30 | return &Synthesizer{ 31 | Config: cfg, 32 | speech: openai.NewAudioSpeechService(cfg.Options()...), 33 | }, nil 34 | } 35 | 36 | func (s *Synthesizer) Synthesize(ctx context.Context, content string, options *provider.SynthesizeOptions) (*provider.Synthesis, error) { 37 | if options == nil { 38 | options = new(provider.SynthesizeOptions) 39 | } 40 | 41 | result, err := s.speech.New(ctx, openai.AudioSpeechNewParams{ 42 | Model: s.model, 43 | Input: content, 44 | 45 | Voice: openai.AudioSpeechNewParamsVoiceAlloy, 46 | 47 | ResponseFormat: openai.AudioSpeechNewParamsResponseFormatMP3, 48 | }) 49 | 50 | if err != nil { 51 | return nil, convertError(err) 52 | } 53 | 54 | data, err := io.ReadAll(result.Body) 55 | 56 | if err != nil { 57 | return nil, err 58 | } 59 | 60 | return &provider.Synthesis{ 61 | ID: uuid.NewString(), 62 | Model: s.model, 63 | 64 | Content: data, 65 | ContentType: "audio/mpeg", 66 | }, nil 67 | } 68 | -------------------------------------------------------------------------------- /pkg/provider/openai/util.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/openai/openai-go" 7 | ) 8 | 9 | func convertError(err error) error { 10 | var apierr *openai.Error 11 | 12 | if errors.As(err, &apierr) { 13 | 
//println(string(apierr.DumpRequest(true))) // Prints the serialized HTTP request 14 | //println(string(apierr.DumpResponse(true))) // Prints the serialized HTTP response 15 | } 16 | 17 | return err 18 | } 19 | -------------------------------------------------------------------------------- /pkg/provider/provider.go: -------------------------------------------------------------------------------- 1 | package provider 2 | 3 | type Provider = any 4 | 5 | type Model struct { 6 | ID string 7 | } 8 | 9 | type File struct { 10 | Name string 11 | 12 | Content []byte 13 | ContentType string 14 | } 15 | 16 | type Tool struct { 17 | Name string 18 | Description string 19 | 20 | Strict *bool 21 | 22 | Parameters map[string]any 23 | } 24 | 25 | type ToolResult struct { 26 | ID string 27 | 28 | Data string 29 | } 30 | 31 | type Schema struct { 32 | Name string 33 | Description string 34 | 35 | Strict *bool 36 | 37 | Schema map[string]any // TODO: Rename to Properties 38 | } 39 | 40 | type Usage struct { 41 | InputTokens int 42 | OutputTokens int 43 | } 44 | -------------------------------------------------------------------------------- /pkg/provider/renderer.go: -------------------------------------------------------------------------------- 1 | package provider 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type Renderer interface { 8 | Render(ctx context.Context, input string, options *RenderOptions) (*Rendering, error) 9 | } 10 | 11 | type RenderOptions struct { 12 | Images []File 13 | } 14 | 15 | type Rendering struct { 16 | ID string 17 | Model string 18 | 19 | Content []byte 20 | ContentType string 21 | } 22 | -------------------------------------------------------------------------------- /pkg/provider/replicate/client.go: -------------------------------------------------------------------------------- 1 | package replicate 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/replicate/replicate-go" 7 | ) 8 | 9 | type Client struct { 10 | *Config 11 | client 
*replicate.Client 12 | } 13 | 14 | type PredictionInput = replicate.PredictionInput 15 | type PredictionOutput = replicate.PredictionOutput 16 | 17 | type FileOutput = replicate.FileOutput 18 | 19 | func New(model string, options ...Option) (*Client, error) { 20 | cfg := &Config{ 21 | model: model, 22 | } 23 | 24 | for _, option := range options { 25 | option(cfg) 26 | } 27 | 28 | client, err := replicate.NewClient(cfg.Options()...) 29 | 30 | if err != nil { 31 | return nil, err 32 | } 33 | 34 | return &Client{ 35 | Config: cfg, 36 | client: client, 37 | }, nil 38 | } 39 | 40 | func (c *Client) Run(ctx context.Context, input PredictionInput) (PredictionOutput, error) { 41 | return c.client.RunWithOptions(ctx, c.model, input, nil, replicate.WithBlockUntilDone(), replicate.WithFileOutput()) 42 | } 43 | -------------------------------------------------------------------------------- /pkg/provider/replicate/config.go: -------------------------------------------------------------------------------- 1 | package replicate 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/replicate/replicate-go" 7 | ) 8 | 9 | type Config struct { 10 | token string 11 | model string 12 | 13 | client *http.Client 14 | } 15 | 16 | type Option func(*Config) 17 | 18 | func WithClient(client *http.Client) Option { 19 | return func(c *Config) { 20 | c.client = client 21 | } 22 | } 23 | 24 | func WithToken(token string) Option { 25 | return func(c *Config) { 26 | c.token = token 27 | } 28 | } 29 | 30 | func WithModel(model string) Option { 31 | return func(c *Config) { 32 | c.model = model 33 | } 34 | } 35 | 36 | func (c *Config) Options() []replicate.ClientOption { 37 | options := []replicate.ClientOption{} 38 | 39 | if c.token != "" { 40 | options = append(options, replicate.WithToken(c.token)) 41 | } 42 | 43 | return options 44 | } 45 | -------------------------------------------------------------------------------- /pkg/provider/reranker.go: 
-------------------------------------------------------------------------------- 1 | package provider 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type Reranker interface { 8 | Rerank(ctx context.Context, query string, texts []string, options *RerankOptions) ([]Ranking, error) 9 | } 10 | 11 | type RerankOptions struct { 12 | Limit *int 13 | } 14 | 15 | type Ranking struct { 16 | Text string 17 | Score float64 18 | } 19 | -------------------------------------------------------------------------------- /pkg/provider/synthesizer.go: -------------------------------------------------------------------------------- 1 | package provider 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type Synthesizer interface { 8 | Synthesize(ctx context.Context, input string, options *SynthesizeOptions) (*Synthesis, error) 9 | } 10 | 11 | type SynthesizeOptions struct { 12 | Voice string 13 | } 14 | 15 | type Synthesis struct { 16 | ID string 17 | Model string 18 | 19 | Content []byte 20 | ContentType string 21 | } 22 | -------------------------------------------------------------------------------- /pkg/provider/transcriber.go: -------------------------------------------------------------------------------- 1 | package provider 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type Transcriber interface { 8 | Transcribe(ctx context.Context, input File, options *TranscribeOptions) (*Transcription, error) 9 | } 10 | 11 | type TranscribeOptions struct { 12 | Language string 13 | Temperature *float32 14 | } 15 | 16 | type Transcription struct { 17 | ID string 18 | Model string 19 | 20 | Text string 21 | 22 | // Language string 23 | // Duration float64 24 | } 25 | -------------------------------------------------------------------------------- /pkg/provider/whisper/config.go: -------------------------------------------------------------------------------- 1 | package whisper 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Config struct { 8 | url string 9 | 10 | client *http.Client 11 | } 12 | 13 
| type Option func(*Config) 14 | 15 | func WithClient(client *http.Client) Option { 16 | return func(c *Config) { 17 | c.client = client 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /pkg/provider/whisper/util.go: -------------------------------------------------------------------------------- 1 | package whisper 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "net/http" 7 | ) 8 | 9 | func convertError(resp *http.Response) error { 10 | data, _ := io.ReadAll(resp.Body) 11 | 12 | if len(data) == 0 { 13 | return errors.New(http.StatusText(resp.StatusCode)) 14 | } 15 | 16 | return errors.New(string(data)) 17 | } 18 | -------------------------------------------------------------------------------- /pkg/provider/xai/completer.go: -------------------------------------------------------------------------------- 1 | package xai 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider/openai" 7 | ) 8 | 9 | type Completer = openai.Completer 10 | 11 | func NewCompleter(url, model string, options ...Option) (*Completer, error) { 12 | if url == "" { 13 | url = "https://api.x.ai" 14 | } 15 | 16 | url = strings.TrimRight(url, "/") 17 | url = strings.TrimSuffix(url, "/v1") 18 | 19 | cfg := &Config{} 20 | 21 | for _, option := range options { 22 | option(cfg) 23 | } 24 | 25 | return openai.NewCompleter(url+"/v1", model, cfg.options...) 
26 | } 27 | -------------------------------------------------------------------------------- /pkg/provider/xai/config.go: -------------------------------------------------------------------------------- 1 | package xai 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider/openai" 7 | ) 8 | 9 | type Config struct { 10 | options []openai.Option 11 | } 12 | 13 | type Option func(*Config) 14 | 15 | func WithClient(client *http.Client) Option { 16 | return func(c *Config) { 17 | c.options = append(c.options, openai.WithClient(client)) 18 | } 19 | } 20 | 21 | func WithToken(token string) Option { 22 | return func(c *Config) { 23 | c.options = append(c.options, openai.WithToken(token)) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /pkg/router/roundrobin/router.go: -------------------------------------------------------------------------------- 1 | package roundrobin 2 | 3 | import ( 4 | "context" 5 | "math/rand" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | ) 9 | 10 | type Completer struct { 11 | completers []provider.Completer 12 | } 13 | 14 | func NewCompleter(completer ...provider.Completer) (provider.Completer, error) { 15 | c := &Completer{ 16 | completers: completer, 17 | } 18 | 19 | return c, nil 20 | } 21 | 22 | func (c *Completer) Complete(ctx context.Context, messages []provider.Message, options *provider.CompleteOptions) (*provider.Completion, error) { 23 | index := rand.Intn(len(c.completers)) 24 | provider := c.completers[index] 25 | 26 | return provider.Complete(ctx, messages, options) 27 | } 28 | -------------------------------------------------------------------------------- /pkg/segmenter/jina/client_test.go: -------------------------------------------------------------------------------- 1 | package jina_test 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/adrianliechti/wingman/pkg/segmenter/jina" 9 | 10 | 
"github.com/stretchr/testify/require" 11 | "github.com/testcontainers/testcontainers-go" 12 | "github.com/testcontainers/testcontainers-go/wait" 13 | ) 14 | 15 | func TestExtract(t *testing.T) { 16 | ctx := context.Background() 17 | 18 | server, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ 19 | Started: true, 20 | 21 | ContainerRequest: testcontainers.ContainerRequest{ 22 | Image: "ghcr.io/adrianliechti/wingman-segmenter", 23 | ExposedPorts: []string{"8000/tcp"}, 24 | 25 | WaitingFor: wait.ForLog("Application startup complete"), 26 | }, 27 | }) 28 | 29 | require.NoError(t, err) 30 | 31 | url, err := server.Endpoint(ctx, "") 32 | require.NoError(t, err) 33 | 34 | s, err := jina.New("http://" + url) 35 | require.NoError(t, err) 36 | 37 | input := strings.Repeat("Hello, World! ", 2000) 38 | 39 | segments, err := s.Segment(ctx, input, nil) 40 | require.NoError(t, err) 41 | 42 | require.NotEmpty(t, segments) 43 | } 44 | -------------------------------------------------------------------------------- /pkg/segmenter/jina/config.go: -------------------------------------------------------------------------------- 1 | package jina 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | 15 | func WithToken(token string) Option { 16 | return func(c *Client) { 17 | c.token = token 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /pkg/segmenter/jina/model.go: -------------------------------------------------------------------------------- 1 | package jina 2 | 3 | type SegmentRequest struct { 4 | Content string `json:"content"` 5 | 6 | ReturnChunks bool `json:"return_chunks,omitempty"` 7 | 8 | MaxChunkLength int `json:"max_chunk_length,omitempty"` 9 | } 10 | 11 | type SegmentResponse struct { 12 | Chunks []string `json:"chunks"` 13 | } 14 | 
-------------------------------------------------------------------------------- /pkg/segmenter/segmenter.go: -------------------------------------------------------------------------------- 1 | package segmenter 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | ) 7 | 8 | type Provider interface { 9 | Segment(ctx context.Context, text string, options *SegmentOptions) ([]Segment, error) 10 | } 11 | 12 | var ( 13 | ErrUnsupported = errors.New("unsupported type") 14 | ) 15 | 16 | type SegmentOptions struct { 17 | FileName string 18 | 19 | SegmentLength *int 20 | SegmentOverlap *int 21 | } 22 | 23 | type Segment struct { 24 | Text string 25 | } 26 | -------------------------------------------------------------------------------- /pkg/segmenter/unstructured/client_test.go: -------------------------------------------------------------------------------- 1 | package unstructured_test 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/adrianliechti/wingman/pkg/segmenter/unstructured" 9 | 10 | "github.com/stretchr/testify/require" 11 | "github.com/testcontainers/testcontainers-go" 12 | "github.com/testcontainers/testcontainers-go/wait" 13 | ) 14 | 15 | func TestExtract(t *testing.T) { 16 | ctx := context.Background() 17 | 18 | server, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ 19 | Started: true, 20 | 21 | ContainerRequest: testcontainers.ContainerRequest{ 22 | Image: "quay.io/unstructured-io/unstructured-api:0.0.80", 23 | ExposedPorts: []string{"8000/tcp"}, 24 | WaitingFor: wait.ForLog("Application startup complete"), 25 | }, 26 | }) 27 | 28 | require.NoError(t, err) 29 | 30 | url, err := server.Endpoint(ctx, "") 31 | require.NoError(t, err) 32 | 33 | s, err := unstructured.New("http://" + url) 34 | require.NoError(t, err) 35 | 36 | input := strings.Repeat("Hello, World! 
", 2000) 37 | 38 | segments, err := s.Segment(ctx, input, nil) 39 | require.NoError(t, err) 40 | 41 | require.NotEmpty(t, segments) 42 | } 43 | -------------------------------------------------------------------------------- /pkg/segmenter/unstructured/config.go: -------------------------------------------------------------------------------- 1 | package unstructured 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type Option func(*Client) 8 | 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | 15 | func WithToken(token string) Option { 16 | return func(c *Client) { 17 | c.token = token 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /pkg/segmenter/unstructured/model.go: -------------------------------------------------------------------------------- 1 | package unstructured 2 | 3 | type Element struct { 4 | ID string `json:"element_id"` 5 | 6 | Type string `json:"type"` 7 | Text string `json:"text"` 8 | 9 | Metadata ElementMetadata `json:"metadata"` 10 | } 11 | 12 | type ElementMetadata struct { 13 | FileName string `json:"filename"` 14 | FileType string `json:"filetype"` 15 | 16 | Languages []string `json:"languages"` 17 | 18 | // PageName string `json:"page_name"` 19 | // PageNumber int `json:"page_number"` 20 | 21 | // MailSender string `json:"sent_from"` 22 | // MailRecipient string `json:"sent_to"` 23 | // MailSubject string `json:"subject"` 24 | } 25 | -------------------------------------------------------------------------------- /pkg/summarizer/adapter/adapter.go: -------------------------------------------------------------------------------- 1 | package adapter 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | "github.com/adrianliechti/wingman/pkg/summarizer" 9 | "github.com/adrianliechti/wingman/pkg/text" 10 | ) 11 | 12 | var _ summarizer.Provider = (*Adapter)(nil) 13 | 14 | type 
Adapter struct { 15 | completer provider.Completer 16 | } 17 | 18 | func FromCompleter(completer provider.Completer) *Adapter { 19 | return &Adapter{ 20 | completer: completer, 21 | } 22 | } 23 | 24 | func (a *Adapter) Summarize(ctx context.Context, content string, options *summarizer.SummarizerOptions) (*summarizer.Summary, error) { 25 | splitter := text.NewSplitter() 26 | splitter.ChunkSize = 16000 27 | splitter.ChunkOverlap = 0 28 | 29 | var segments []string 30 | 31 | for _, part := range splitter.Split(content) { 32 | completion, err := a.completer.Complete(ctx, []provider.Message{ 33 | provider.UserMessage("Write a concise summary of the following: \n" + part), 34 | }, nil) 35 | 36 | if err != nil { 37 | return nil, err 38 | } 39 | 40 | segments = append(segments, completion.Message.Text()) 41 | } 42 | 43 | completion, err := a.completer.Complete(ctx, []provider.Message{ 44 | provider.UserMessage("Distill the following parts into a consolidated summary: \n" + strings.Join(segments, "\n\n")), 45 | }, nil) 46 | 47 | if err != nil { 48 | return nil, err 49 | } 50 | 51 | result := &summarizer.Summary{ 52 | Text: completion.Message.Text(), 53 | } 54 | 55 | return result, nil 56 | } 57 | -------------------------------------------------------------------------------- /pkg/summarizer/custom/Taskfile.yaml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | tasks: 6 | generate: 7 | cmds: 8 | - protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. 
--go-grpc_opt=paths=source_relative summarizer.proto -------------------------------------------------------------------------------- /pkg/summarizer/custom/client.go: -------------------------------------------------------------------------------- 1 | package custom 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "strings" 7 | 8 | "github.com/adrianliechti/wingman/pkg/summarizer" 9 | 10 | "google.golang.org/grpc" 11 | "google.golang.org/grpc/credentials/insecure" 12 | ) 13 | 14 | var ( 15 | _ summarizer.Provider = (*Client)(nil) 16 | ) 17 | /* Client summarizes text by calling a remote Summarizer gRPC service. */ 18 | type Client struct { 19 | url string 20 | client SummarizerClient 21 | } 22 | /* New creates a Client for a grpc:// URL. Options are applied before the gRPC client is created; the connection uses insecure (plaintext) transport credentials and accepts responses of up to 100MB. */ 23 | func New(url string, options ...Option) (*Client, error) { 24 | if url == "" || !strings.HasPrefix(url, "grpc://") { 25 | return nil, errors.New("invalid url") 26 | } 27 | 28 | c := &Client{ 29 | url: url, 30 | } 31 | 32 | for _, option := range options { 33 | option(c) 34 | } 35 | 36 | client, err := grpc.NewClient(strings.TrimPrefix(c.url, "grpc://"), 37 | grpc.WithTransportCredentials(insecure.NewCredentials()), 38 | grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(100*1024*1024)), // 100MB max receive message size 39 | ) 40 | 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | c.client = NewSummarizerClient(client) 46 | 47 | return c, nil 48 | } 49 | /* Summarize sends text to the remote service. options is not forwarded: SummarizeRequest only carries Text. */ 50 | func (c *Client) Summarize(ctx context.Context, text string, options *summarizer.SummarizerOptions) (*summarizer.Summary, error) { 51 | if options == nil { 52 | options = new(summarizer.SummarizerOptions) 53 | } 54 | 55 | resp, err := c.client.Summarize(ctx, &SummarizeRequest{ 56 | Text: text, 57 | }) 58 | 59 | if err != nil { 60 | return nil, err 61 | } 62 | 63 | return &summarizer.Summary{ 64 | Text: resp.Text, 65 | }, nil 66 | } 67 | -------------------------------------------------------------------------------- /pkg/summarizer/custom/config.go: -------------------------------------------------------------------------------- 1 | package custom 2 | /* Option configures a Client. */ 3 | type Option 
func(*Client) 4 | -------------------------------------------------------------------------------- /pkg/summarizer/custom/summarizer.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option go_package = "github.com/adrianliechti/wingman/pkg/summarizer/custom;custom"; 4 | 5 | package summarizer; 6 | 7 | service Summarizer { 8 | rpc Summarize (SummarizeRequest) returns (Summary) {} 9 | } 10 | 11 | message SummarizeRequest { 12 | string text = 1; 13 | } 14 | 15 | message Summary { 16 | string text = 1; 17 | } -------------------------------------------------------------------------------- /pkg/summarizer/summarizer.go: -------------------------------------------------------------------------------- 1 | package summarizer 2 | 3 | import "context" 4 | /* Provider is the interface summarizer backends implement. */ 5 | type Provider interface { 6 | Summarize(ctx context.Context, text string, options *SummarizerOptions) (*Summary, error) 7 | } 8 | /* SummarizerOptions currently has no fields. */ 9 | type SummarizerOptions struct { 10 | } 11 | /* Summary is the result of a Summarize call. */ 12 | type Summary struct { 13 | Text string 14 | } 15 | -------------------------------------------------------------------------------- /pkg/template/template.go: -------------------------------------------------------------------------------- 1 | package template 2 | 3 | import ( 4 | "bytes" 5 | "text/template" 6 | ) 7 | /* Template wraps a parsed text/template that has this package's helper functions registered. */ 8 | type Template struct { 9 | tmpl *template.Template 10 | } 11 | /* MustTemplate is like NewTemplate but panics if text cannot be parsed. */ 12 | func MustTemplate(text string) *Template { 13 | prompt, err := NewTemplate(text) 14 | 15 | if err != nil { 16 | panic(err) 17 | } 18 | 19 | return prompt 20 | } 21 | /* NewTemplate parses text as a template named "prompt" with the now, date, dateInZone and include functions available. */ 22 | func NewTemplate(text string) (*Template, error) { 23 | tmpl, err := template. 24 | New("prompt"). 25 | Funcs(map[string]any{ 26 | "now": now, 27 | "date": date, 28 | "dateInZone": dateInZone, 29 | "include": include, 30 | }). 
 31 | Parse(text) 32 | 33 | if err != nil { 34 | return nil, err 35 | } 36 | 37 | return &Template{ 38 | tmpl: tmpl, 39 | }, nil 40 | } 41 | /* Execute renders the template with data (a nil data is replaced by an empty map) and returns the output. */ 42 | func (t *Template) Execute(data any) (string, error) { 43 | if data == nil { 44 | data = map[string]any{} 45 | } 46 | 47 | var buffer bytes.Buffer 48 | 49 | if err := t.tmpl.Execute(&buffer, data); err != nil { 50 | return "", err 51 | } 52 | 53 | return buffer.String(), nil 54 | } 55 | -------------------------------------------------------------------------------- /pkg/template/template_date.go: -------------------------------------------------------------------------------- 1 | package template 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | // https://github.com/Masterminds/sprig/blob/master/date.go 8 | /* now returns the current local time. */ 9 | func now() time.Time { 10 | return time.Now() 11 | } 12 | /* date formats date with the layout fmt in the "Local" zone; see dateInZone. */ 13 | func date(fmt string, date interface{}) string { 14 | return dateInZone(fmt, date, "Local") 15 | } 16 | /* dateInZone formats date with the layout fmt in the named zone. date may be a time.Time, a *time.Time, or Unix seconds as int, int32 or int64; any other value formats the current time. An unknown zone falls back to UTC. NOTE(review): a nil *time.Time panics on dereference. */ 17 | func dateInZone(fmt string, date interface{}, zone string) string { 18 | var t time.Time 19 | switch date := date.(type) { 20 | default: 21 | t = time.Now() 22 | case time.Time: 23 | t = date 24 | case *time.Time: 25 | t = *date 26 | case int64: 27 | t = time.Unix(date, 0) 28 | case int: 29 | t = time.Unix(int64(date), 0) 30 | case int32: 31 | t = time.Unix(int64(date), 0) 32 | } 33 | 34 | loc, err := time.LoadLocation(zone) 35 | if err != nil { 36 | loc, _ = time.LoadLocation("UTC") 37 | } 38 | 39 | return t.In(loc).Format(fmt) 40 | } 41 | -------------------------------------------------------------------------------- /pkg/template/template_include.go: -------------------------------------------------------------------------------- 1 | package template 2 | 3 | import ( 4 | "os" 5 | ) 6 | /* include returns the contents of the file at path and panics if it cannot be read. NOTE(review): path is not restricted, so a template can read any file the process can access; text/template presumably reports the panic as an Execute error - confirm. */ 7 | func include(path string) string { 8 | data, err := os.ReadFile(path) 9 | 10 | if err != nil { 11 | panic(err) 12 | } 13 | 14 | return string(data) 15 | } 16 | -------------------------------------------------------------------------------- 
/pkg/template/template_message.go: -------------------------------------------------------------------------------- 1 | package template 2 | 3 | import ( 4 | "slices" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | ) 8 | // Message renders the text of message as a template with data and replaces the message content with a single text part holding the result. On error the unmodified message is returned together with the error. 9 | func Message(message provider.Message, data any) (provider.Message, error) { 10 | t, err := NewTemplate(message.Text()) 11 | 12 | if err != nil { 13 | return message, err 14 | } 15 | 16 | content, err := t.Execute(data) 17 | 18 | if err != nil { 19 | return message, err 20 | } 21 | 22 | message.Content = []provider.Content{ 23 | { 24 | Text: content, 25 | }, 26 | } 27 | 28 | return message, nil 29 | } 30 | // Messages applies Message to each element of a shallow copy of messages; the first error aborts and returns nil. 31 | func Messages(messages []provider.Message, data any) ([]provider.Message, error) { 32 | result := slices.Clone(messages) 33 | 34 | for i, m := range result { 35 | message, err := Message(m, data) 36 | 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | result[i] = message 42 | } 43 | 44 | return result, nil 45 | } 46 | -------------------------------------------------------------------------------- /pkg/text/normalize.go: -------------------------------------------------------------------------------- 1 | package text 2 | 3 | import ( 4 | "regexp" 5 | "strings" 6 | ) 7 | 8 | // The patterns are compiled once at package initialization instead of on every Normalize call; a *regexp.Regexp is safe for concurrent use. 9 | var ( 10 | normalizeParagraphRe = regexp.MustCompile(`\n\s*\n\s*`) 11 | normalizeLineRe = regexp.MustCompile(`\n\s*`) 12 | ) 13 | 14 | // Normalize canonicalizes whitespace in text: CRLF line endings become LF, a whitespace run that starts at a line break and contains another line break becomes a blank line ("\n\n"), any other line break together with the whitespace after it becomes "\n", and every remaining whitespace run collapses to a single space (runs at the very start or end of the text are dropped). 15 | // NOTE(review): a literal BEL (\a) already present in text is also turned into "\n". 16 | func Normalize(text string) string { 17 | text = strings.ReplaceAll(text, "\r\n", "\n") 18 | // \a (BEL) stands in for line breaks because strings.Fields does not treat it as whitespace, so the breaks survive the space-collapsing step. 19 | text = normalizeParagraphRe.ReplaceAllString(text, "\a\a") 20 | text = normalizeLineRe.ReplaceAllString(text, "\a") 21 | text = strings.Join(strings.Fields(text), " ") 22 | text = strings.ReplaceAll(text, "\a", "\n") 23 | 24 | return text 25 | } 26 | -------------------------------------------------------------------------------- /pkg/to/to.go: -------------------------------------------------------------------------------- 1 | package to 2 | // Ptr returns a pointer to a copy of v. 3 | func Ptr[T any](v T) *T { 4 | return &v 5 | } 6 | -------------------------------------------------------------------------------- /pkg/tool/custom/Taskfile.yaml: 
-------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | tasks: 6 | generate: 7 | cmds: 8 | - protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative tool.proto -------------------------------------------------------------------------------- /pkg/tool/custom/config.go: -------------------------------------------------------------------------------- 1 | package custom 2 | /* Option configures a Client. */ 3 | type Option func(*Client) 4 | -------------------------------------------------------------------------------- /pkg/tool/custom/tool.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option go_package = "github.com/adrianliechti/wingman/pkg/tool/custom;custom"; 4 | 5 | package tool; 6 | 7 | service Tool { 8 | rpc Tools (ToolsRequest) returns (ToolsResponse) {} 9 | rpc Execute (ExecuteRequest) returns (ResultResponse) {} 10 | } 11 | 12 | message ToolsRequest { 13 | } 14 | 15 | message ToolsResponse { 16 | repeated Definition definitions = 1; 17 | } 18 | 19 | message Definition { 20 | string name = 1; 21 | string description = 2; 22 | 23 | string parameters = 3; 24 | } 25 | 26 | message ExecuteRequest { 27 | string name = 1; 28 | string parameters = 2; 29 | } 30 | 31 | message ResultResponse { 32 | string data = 1; 33 | } -------------------------------------------------------------------------------- /pkg/tool/extract/config.go: -------------------------------------------------------------------------------- 1 | package extract 2 | /* Option configures a Client. */ 3 | type Option func(*Client) 4 | -------------------------------------------------------------------------------- /pkg/tool/mcp/config.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | /* Option configures a Client. */ 3 | type Option func(*Client) 4 | -------------------------------------------------------------------------------- 
/pkg/tool/render/config.go: -------------------------------------------------------------------------------- 1 | package render 2 | /* Option configures a Client. */ 3 | type Option func(*Client) 4 | -------------------------------------------------------------------------------- /pkg/tool/render/models.go: -------------------------------------------------------------------------------- 1 | package render 2 | 3 | type Result struct { 4 | URL string `json:"url"` 5 | 6 | // Style string `json:"style"` 7 | // Prompt string `json:"prompt"` 8 | } 9 | -------------------------------------------------------------------------------- /pkg/tool/retrieve/config.go: -------------------------------------------------------------------------------- 1 | package retrieve 2 | /* Option configures a Client. */ 3 | type Option func(*Client) 4 | -------------------------------------------------------------------------------- /pkg/tool/retrieve/models.go: -------------------------------------------------------------------------------- 1 | package retrieve 2 | 3 | type Result struct { 4 | Title string `json:"title,omitempty"` 5 | Source string `json:"source,omitempty"` 6 | Content string `json:"content,omitempty"` 7 | } 8 | -------------------------------------------------------------------------------- /pkg/tool/search/config.go: -------------------------------------------------------------------------------- 1 | package search 2 | /* Option configures a Client. */ 3 | type Option func(*Client) 4 | -------------------------------------------------------------------------------- /pkg/tool/search/models.go: -------------------------------------------------------------------------------- 1 | package search 2 | 3 | type Result struct { 4 | Title string `json:"title,omitempty"` 5 | Source string `json:"source,omitempty"` 6 | Content string `json:"content,omitempty"` 7 | } 8 | -------------------------------------------------------------------------------- /pkg/tool/synthesize/config.go: -------------------------------------------------------------------------------- 1 | package synthesize 2
| 3 | type Option func(*Client) 4 | -------------------------------------------------------------------------------- /pkg/tool/synthesize/models.go: -------------------------------------------------------------------------------- 1 | package synthesize 2 | 3 | type Result struct { 4 | URL string `json:"url"` 5 | } 6 | -------------------------------------------------------------------------------- /pkg/tool/tool.go: -------------------------------------------------------------------------------- 1 | package tool 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | ) 9 | /* Tool is an alias for provider.Tool. */ 10 | type Tool = provider.Tool 11 | 12 | var ( 13 | ErrInvalidTool = errors.New("invalid tool") 14 | ) 15 | /* Provider lists the tools it offers and executes one by name, receiving its parameters as a map. */ 16 | type Provider interface { 17 | Tools(ctx context.Context) ([]Tool, error) 18 | Execute(ctx context.Context, name string, parameters map[string]any) (any, error) 19 | } 20 | /* KeyToolFiles is the context key under which WithFiles stores files. NOTE(review): it is a plain string, which staticcheck (SA1029) flags as collision-prone; an unexported key type would be safer but would change the type of this exported variable. */ 21 | var ( 22 | KeyToolFiles = "tool_files" 23 | ) 24 | /* WithFiles returns a copy of ctx that carries files. */ 25 | func WithFiles(ctx context.Context, files []provider.File) context.Context { 26 | return context.WithValue(ctx, KeyToolFiles, files) 27 | } 28 | /* FilesFromContext returns the files stored by WithFiles; the boolean is false when nothing is stored or the stored value has a different type. */ 29 | func FilesFromContext(ctx context.Context) ([]provider.File, bool) { 30 | val := ctx.Value(KeyToolFiles) 31 | 32 | if val == nil { 33 | return nil, false 34 | } 35 | 36 | files, ok := val.([]provider.File) 37 | return files, ok 38 | } 39 | -------------------------------------------------------------------------------- /pkg/tool/translate/config.go: -------------------------------------------------------------------------------- 1 | package translate 2 | /* Option configures a Client. */ 3 | type Option func(*Client) 4 | -------------------------------------------------------------------------------- /pkg/tool/translate/models.go: -------------------------------------------------------------------------------- 1 | package translate 2 | 3 | type Result struct { 4 | Text string `json:"text,omitempty"` 5 | Language string `json:"language,omitempty"` 6 | } 7 | 
-------------------------------------------------------------------------------- /pkg/translator/azure/config.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | /* Option configures a Client. */ 7 | type Option func(*Client) 8 | /* WithClient sets the Client's http.Client. */ 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | /* WithToken sets the Client's token. */ 15 | func WithToken(token string) Option { 16 | return func(c *Client) { 17 | c.token = token 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /pkg/translator/azure/util.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "io" 8 | "net/http" 9 | ) 10 | /* convertError turns an HTTP response into an error carrying the body text, or the status text when the body is empty; a body read error is ignored. */ 11 | func convertError(resp *http.Response) error { 12 | data, _ := io.ReadAll(resp.Body) 13 | 14 | if len(data) == 0 { 15 | return errors.New(http.StatusText(resp.StatusCode)) 16 | } 17 | 18 | return errors.New(string(data)) 19 | } 20 | /* jsonReader JSON-encodes v without HTML escaping into an in-memory reader. NOTE(review): the Encode error is discarded, so for an unencodable v callers get an empty body and no error. */ 21 | func jsonReader(v any) io.Reader { 22 | b := new(bytes.Buffer) 23 | 24 | enc := json.NewEncoder(b) 25 | enc.SetEscapeHTML(false) 26 | 27 | enc.Encode(v) 28 | return b 29 | } 30 | -------------------------------------------------------------------------------- /pkg/translator/custom/Taskfile.yaml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: "3" 4 | 5 | tasks: 6 | generate: 7 | cmds: 8 | - protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. 
--go-grpc_opt=paths=source_relative translator.proto -------------------------------------------------------------------------------- /pkg/translator/custom/client.go: -------------------------------------------------------------------------------- 1 | package custom 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "strings" 7 | 8 | "github.com/adrianliechti/wingman/pkg/translator" 9 | 10 | "google.golang.org/grpc" 11 | "google.golang.org/grpc/credentials/insecure" 12 | ) 13 | 14 | var ( 15 | _ translator.Provider = (*Client)(nil) 16 | ) 17 | /* Client translates text by calling a remote Translator gRPC service. */ 18 | type Client struct { 19 | url string 20 | client TranslatorClient 21 | } 22 | /* New creates a Client for a grpc:// URL. Options are applied before the gRPC client is created; the connection uses insecure (plaintext) transport credentials and accepts responses of up to 100MB. */ 23 | func New(url string, options ...Option) (*Client, error) { 24 | if url == "" || !strings.HasPrefix(url, "grpc://") { 25 | return nil, errors.New("invalid url") 26 | } 27 | 28 | c := &Client{ 29 | url: url, 30 | } 31 | 32 | for _, option := range options { 33 | option(c) 34 | } 35 | 36 | client, err := grpc.NewClient(strings.TrimPrefix(c.url, "grpc://"), 37 | grpc.WithTransportCredentials(insecure.NewCredentials()), 38 | grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(100*1024*1024)), // 100MB max receive message size 39 | ) 40 | 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | c.client = NewTranslatorClient(client) 46 | 47 | return c, nil 48 | } 49 | /* Translate sends text and the target language to the remote service; nil options send an empty language. */ 50 | func (c *Client) Translate(ctx context.Context, text string, options *translator.TranslateOptions) (*translator.Translation, error) { 51 | if options == nil { 52 | options = new(translator.TranslateOptions) 53 | } 54 | 55 | resp, err := c.client.Translate(ctx, &TranslateRequest{ 56 | Text: text, 57 | 58 | Language: options.Language, 59 | }) 60 | 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | return &translator.Translation{ 66 | Text: resp.Text, 67 | }, nil 68 | } 69 | -------------------------------------------------------------------------------- /pkg/translator/custom/config.go: -------------------------------------------------------------------------------- 1 | 
package custom 2 | /* Option configures a Client. */ 3 | type Option func(*Client) 4 | -------------------------------------------------------------------------------- /pkg/translator/custom/translator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option go_package = "github.com/adrianliechti/wingman/pkg/translator/custom;custom"; 4 | 5 | package translator; 6 | 7 | service Translator { 8 | rpc Translate (TranslateRequest) returns (Translation) {} 9 | } 10 | 11 | message TranslateRequest { 12 | string text = 1; 13 | string language = 2; 14 | } 15 | 16 | message Translation { 17 | string text = 1; 18 | } -------------------------------------------------------------------------------- /pkg/translator/deepl/config.go: -------------------------------------------------------------------------------- 1 | package deepl 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | /* Option configures a Client. */ 7 | type Option func(*Client) 8 | /* WithClient sets the Client's http.Client. */ 9 | func WithClient(client *http.Client) Option { 10 | return func(c *Client) { 11 | c.client = client 12 | } 13 | } 14 | /* WithToken sets the Client's token. */ 15 | func WithToken(token string) Option { 16 | return func(c *Client) { 17 | c.token = token 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /pkg/translator/deepl/util.go: -------------------------------------------------------------------------------- 1 | package deepl 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "io" 8 | "net/http" 9 | ) 10 | /* convertError turns an HTTP response into an error carrying the body text, or the status text when the body is empty; a body read error is ignored. */ 11 | func convertError(resp *http.Response) error { 12 | data, _ := io.ReadAll(resp.Body) 13 | 14 | if len(data) == 0 { 15 | return errors.New(http.StatusText(resp.StatusCode)) 16 | } 17 | 18 | return errors.New(string(data)) 19 | } 20 | /* jsonReader JSON-encodes v without HTML escaping into an in-memory reader. NOTE(review): the Encode error is discarded, so for an unencodable v callers get an empty body and no error. */ 21 | func jsonReader(v any) io.Reader { 22 | b := new(bytes.Buffer) 23 | 24 | enc := json.NewEncoder(b) 25 | enc.SetEscapeHTML(false) 26 | 27 | enc.Encode(v) 28 | return b 29 | } 30 | 
/pkg/translator/llm/client.go: -------------------------------------------------------------------------------- 1 | package llm 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | "github.com/adrianliechti/wingman/pkg/translator" 8 | ) 9 | 10 | type Client struct { 11 | completer provider.Completer 12 | } 13 | 14 | func New(completer provider.Completer) (*Client, error) { 15 | c := &Client{ 16 | completer: completer, 17 | } 18 | 19 | return c, nil 20 | } 21 | 22 | func (c *Client) Translate(ctx context.Context, content string, options *translator.TranslateOptions) (*translator.Translation, error) { 23 | if options == nil { 24 | options = new(translator.TranslateOptions) 25 | } 26 | 27 | messages := []provider.Message{ 28 | provider.SystemMessage("Act as a translator. Translate the following text to `" + options.Language + "`. Only return the translation, no other text."), 29 | provider.UserMessage(content), 30 | } 31 | 32 | completion, err := c.completer.Complete(ctx, messages, nil) 33 | 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | result := &translator.Translation{ 39 | Text: completion.Message.Text(), 40 | } 41 | 42 | return result, nil 43 | } 44 | -------------------------------------------------------------------------------- /pkg/translator/translator.go: -------------------------------------------------------------------------------- 1 | package translator 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type Provider interface { 8 | Translate(ctx context.Context, text string, options *TranslateOptions) (*Translation, error) 9 | } 10 | 11 | type TranslateOptions struct { 12 | Language string 13 | } 14 | 15 | type Translation struct { 16 | Text string 17 | } 18 | -------------------------------------------------------------------------------- /server/api/handler.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 
6 | 7 | "github.com/adrianliechti/wingman/config" 8 | 9 | "github.com/go-chi/chi/v5" 10 | ) 11 | 12 | type Handler struct { 13 | *config.Config 14 | http.Handler 15 | } 16 | 17 | func New(cfg *config.Config) (*Handler, error) { 18 | mux := chi.NewMux() 19 | 20 | h := &Handler{ 21 | Config: cfg, 22 | Handler: mux, 23 | } 24 | 25 | h.Attach(mux) 26 | return h, nil 27 | } 28 | 29 | func (h *Handler) Attach(r chi.Router) { 30 | r.Post("/extract", h.handleExtract) 31 | r.Post("/rerank", h.handleRerank) 32 | r.Post("/segment", h.handleSegment) 33 | r.Post("/summarize", h.handleSummarize) 34 | r.Post("/translate", h.handleTranslate) 35 | r.Post("/transcribe", h.handleTranscribe) 36 | } 37 | 38 | func writeJson(w http.ResponseWriter, v any) { 39 | w.Header().Set("Content-Type", "application/json") 40 | 41 | enc := json.NewEncoder(w) 42 | enc.SetEscapeHTML(false) 43 | 44 | enc.Encode(v) 45 | } 46 | 47 | func writeError(w http.ResponseWriter, code int, err error) { 48 | w.WriteHeader(code) 49 | w.Write([]byte(err.Error())) 50 | } 51 | -------------------------------------------------------------------------------- /server/api/handler_segment.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "net/http" 5 | "strconv" 6 | 7 | "github.com/adrianliechti/wingman/pkg/segmenter" 8 | ) 9 | 10 | func (h *Handler) handleSegment(w http.ResponseWriter, r *http.Request) { 11 | model := valueModel(r) 12 | 13 | p, err := h.Segmenter(model) 14 | 15 | if err != nil { 16 | writeError(w, http.StatusBadRequest, err) 17 | return 18 | } 19 | 20 | text, err := h.readText(r) 21 | 22 | if err != nil { 23 | writeError(w, http.StatusBadRequest, err) 24 | return 25 | } 26 | 27 | options := &segmenter.SegmentOptions{ 28 | SegmentLength: valueSegmentLength(r), 29 | SegmentOverlap: valueSegmentOverlap(r), 30 | } 31 | 32 | segments, err := p.Segment(r.Context(), text, options) 33 | 34 | if err != nil { 35 | writeError(w, 
http.StatusBadRequest, err) 36 | return 37 | } 38 | 39 | result := make([]Segment, 0) 40 | 41 | for _, s := range segments { 42 | segment := Segment{ 43 | Text: s.Text, 44 | } 45 | 46 | result = append(result, segment) 47 | } 48 | 49 | writeJson(w, result) 50 | } 51 | 52 | func valueSegmentLength(r *http.Request) *int { 53 | if val := r.FormValue("segment_length"); val != "" { 54 | if val, err := strconv.Atoi(val); err == nil { 55 | return &val 56 | } 57 | } 58 | 59 | return nil 60 | } 61 | 62 | func valueSegmentOverlap(r *http.Request) *int { 63 | if val := r.FormValue("segment_overlap"); val != "" { 64 | if val, err := strconv.Atoi(val); err == nil { 65 | return &val 66 | } 67 | } 68 | 69 | return nil 70 | } 71 | -------------------------------------------------------------------------------- /server/api/handler_summarize.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | 7 | "github.com/adrianliechti/wingman/pkg/summarizer" 8 | ) 9 | 10 | func (h *Handler) handleSummarize(w http.ResponseWriter, r *http.Request) { 11 | model := valueModel(r) 12 | 13 | p, err := h.Summarizer(model) 14 | 15 | if err != nil { 16 | writeError(w, http.StatusBadRequest, err) 17 | return 18 | } 19 | 20 | text, err := h.readText(r) 21 | 22 | if err != nil { 23 | writeError(w, http.StatusBadRequest, err) 24 | return 25 | } 26 | 27 | options := &summarizer.SummarizerOptions{} 28 | 29 | summary, err := p.Summarize(r.Context(), text, options) 30 | 31 | if err != nil { 32 | writeError(w, http.StatusBadRequest, err) 33 | return 34 | } 35 | 36 | w.Header().Set("Content-Type", "text/plain") 37 | 38 | w.WriteHeader(http.StatusOK) 39 | io.WriteString(w, summary.Text) 40 | } 41 | -------------------------------------------------------------------------------- /server/api/handler_transcribe.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 
4 | "io" 5 | "net/http" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | ) 9 | // handleTranscribe transcribes the uploaded multipart "file" with the selected transcriber and writes the transcript as text/plain. 10 | func (h *Handler) handleTranscribe(w http.ResponseWriter, r *http.Request) { 11 | model := valueModel(r) 12 | language := valueLanguage(r) 13 | 14 | p, err := h.Transcriber(model) 15 | 16 | if err != nil { 17 | writeError(w, http.StatusBadRequest, err) 18 | return 19 | } 20 | 21 | file, header, err := r.FormFile("file") 22 | 23 | if err != nil { 24 | writeError(w, http.StatusBadRequest, err) 25 | return 26 | } 27 | 28 | defer file.Close() 29 | 30 | data, err := io.ReadAll(file) 31 | 32 | if err != nil { 33 | writeError(w, http.StatusBadRequest, err) 34 | return 35 | } 36 | 37 | input := provider.File{ 38 | Name: header.Filename, 39 | 40 | Content: data, 41 | ContentType: header.Header.Get("Content-Type"), 42 | } 43 | 44 | options := &provider.TranscribeOptions{ 45 | Language: language, 46 | } 47 | 48 | transcription, err := p.Transcribe(r.Context(), input, options) 49 | 50 | if err != nil { 51 | writeError(w, http.StatusBadRequest, err) 52 | return 53 | } 54 | 55 | w.Header().Set("Content-Type", "text/plain") 56 | 57 | w.WriteHeader(http.StatusOK) 58 | io.WriteString(w, transcription.Text) 59 | } 60 | -------------------------------------------------------------------------------- /server/api/handler_translate.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | 7 | "github.com/adrianliechti/wingman/pkg/translator" 8 | ) 9 | // handleTranslate translates the request text into the requested language with the selected translator and writes the result as text/plain. 10 | func (h *Handler) handleTranslate(w http.ResponseWriter, r *http.Request) { 11 | model := valueModel(r) 12 | language := valueLanguage(r) 13 | 14 | p, err := h.Translator(model) 15 | 16 | if err != nil { 17 | writeError(w, http.StatusBadRequest, err) 18 | return 19 | } 20 | 21 | text, err := h.readText(r) 22 | 23 | if err != nil { 24 | writeError(w, http.StatusBadRequest, err) 25 | return 26 | } 27 | 28 | options := &translator.TranslateOptions{ 29 | 
Language: language, 30 | } 31 | 32 | translation, err := p.Translate(r.Context(), text, options) 33 | 34 | if err != nil { 35 | writeError(w, http.StatusBadRequest, err) 36 | return 37 | } 38 | 39 | w.Header().Set("Content-Type", "text/plain") 40 | 41 | w.WriteHeader(http.StatusOK) 42 | io.WriteString(w, translation.Text) 43 | } 44 | -------------------------------------------------------------------------------- /server/api/models.go: -------------------------------------------------------------------------------- 1 | package api 2 | // Result is a document with an index and a score. NOTE(review): because of omitempty, an Index or Score of 0 is left out of the JSON, so clients cannot tell index 0 from "not set". encoding/json has no "inline" option; Document's fields are flattened only because it is an embedded (anonymous) field. 3 | type Result struct { 4 | Index int `json:"index,omitempty"` 5 | Score float64 `json:"score,omitempty"` 6 | Document `json:",inline"` 7 | } 8 | 9 | type Segment struct { 10 | Text string `json:"text"` 11 | } 12 | 13 | type Document struct { 14 | Text string `json:"text,omitempty"` 15 | 16 | Segments []Segment `json:"segments,omitempty"` 17 | } 18 | -------------------------------------------------------------------------------- /server/index/handler.go: -------------------------------------------------------------------------------- 1 | package index 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | 7 | "github.com/adrianliechti/wingman/config" 8 | 9 | "github.com/go-chi/chi/v5" 10 | ) 11 | // Handler serves the per-index list, delete, index and query routes using the indexes from the embedded config. 12 | type Handler struct { 13 | *config.Config 14 | http.Handler 15 | } 16 | // New creates a Handler with its routes registered on a new chi mux. 17 | func New(cfg *config.Config) (*Handler, error) { 18 | mux := chi.NewMux() 19 | 20 | h := &Handler{ 21 | Config: cfg, 22 | Handler: mux, 23 | } 24 | 25 | h.Attach(mux) 26 | return h, nil 27 | } 28 | // Attach registers the index routes on r. 29 | func (h *Handler) Attach(r chi.Router) { 30 | r.Get("/{index}", h.handleList) 31 | r.Delete("/{index}", h.handleDeletion) 32 | 33 | r.Post("/{index}", h.handleIndex) 34 | r.Post("/{index}/query", h.handleQuery) 35 | } 36 | // writeJson writes v as JSON without HTML escaping; an encoding error is ignored. 37 | func writeJson(w http.ResponseWriter, v any) { 38 | w.Header().Set("Content-Type", "application/json") 39 | 40 | enc := json.NewEncoder(w) 41 | enc.SetEscapeHTML(false) 42 | 43 | enc.Encode(v) 44 | } 45 | 46 | // writeError writes a plain-text error response with the given status code; a nil err falls back to the standard status text instead of panicking. 47 | func writeError(w 
http.ResponseWriter, code int, err error) { 48 | message := http.StatusText(code) 49 | 50 | if err != nil { 51 | message = err.Error() 52 | } 53 | 54 | // Same headers as http.Error: error text may echo request input, so it must never be sniffed as HTML. 55 | w.Header().Set("Content-Type", "text/plain; charset=utf-8") 56 | w.Header().Set("X-Content-Type-Options", "nosniff") 57 | 58 | w.WriteHeader(code) 59 | w.Write([]byte(message)) 60 | } 61 | -------------------------------------------------------------------------------- /server/index/handler_delete.go: -------------------------------------------------------------------------------- 1 | package index 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | ) 7 | // handleDeletion deletes the documents whose IDs are sent as a JSON string array in the request body and responds with 204 No Content. 8 | func (h *Handler) handleDeletion(w http.ResponseWriter, r *http.Request) { 9 | i, err := h.Index(r.PathValue("index")) 10 | 11 | if err != nil { 12 | http.Error(w, err.Error(), http.StatusBadRequest) 13 | return 14 | } 15 | 16 | var ids []string 17 | 18 | if err := json.NewDecoder(r.Body).Decode(&ids); err != nil { 19 | http.Error(w, err.Error(), http.StatusBadRequest) 20 | return 21 | } 22 | 23 | if err := i.Delete(r.Context(), ids...); err != nil { 24 | http.Error(w, err.Error(), http.StatusBadRequest) 25 | return 26 | } 27 | 28 | w.WriteHeader(http.StatusNoContent) 29 | } 30 | -------------------------------------------------------------------------------- /server/index/handler_index.go: -------------------------------------------------------------------------------- 1 | package index 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | 7 | "github.com/adrianliechti/wingman/pkg/index" 8 | ) 9 | // handleIndex decodes a JSON array of documents from the request body, passes them to the index's Index method and responds with 204 No Content. 10 | func (h *Handler) handleIndex(w http.ResponseWriter, r *http.Request) { 11 | i, err := h.Index(r.PathValue("index")) 12 | 13 | if err != nil { 14 | http.Error(w, err.Error(), http.StatusBadRequest) 15 | return 16 | } 17 | 18 | var request []Document 19 | 20 | if err := json.NewDecoder(r.Body).Decode(&request); err != nil { 21 | http.Error(w, err.Error(), http.StatusBadRequest) 22 | return 23 | } 24 | 25 | var documents []index.Document 26 | 27 | for _, d := range request { 28 | document := index.Document{ 29 | ID: d.ID, 30 | 31 | Title: d.Title, 32 | Source: d.Source, 33 | Content: d.Content, 34 | 35 | Metadata: d.Metadata, 36 | 37 | Embedding: d.Embedding, 38 | } 39 | 40 | 
documents = append(documents, document) 41 | } 42 | 43 | if err := i.Index(r.Context(), documents...); err != nil { 44 | http.Error(w, err.Error(), http.StatusBadRequest) 45 | return 46 | } 47 | 48 | w.WriteHeader(http.StatusNoContent) 49 | } 50 | -------------------------------------------------------------------------------- /server/index/handler_list.go: -------------------------------------------------------------------------------- 1 | package index 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/adrianliechti/wingman/pkg/index" 7 | ) 8 | // handleList writes one page of the index's documents as JSON; the optional "cursor" query parameter is passed through as ListOptions.Cursor. 9 | func (h *Handler) handleList(w http.ResponseWriter, r *http.Request) { 10 | i, err := h.Index(r.PathValue("index")) 11 | 12 | if err != nil { 13 | http.Error(w, err.Error(), http.StatusBadRequest) 14 | return 15 | } 16 | 17 | opts := &index.ListOptions{} 18 | 19 | if val := r.URL.Query().Get("cursor"); val != "" { 20 | opts.Cursor = val 21 | } 22 | 23 | page, err := i.List(r.Context(), opts) 24 | 25 | if err != nil { 26 | http.Error(w, err.Error(), http.StatusBadRequest) 27 | return 28 | } 29 | 30 | items := make([]Document, 0, len(page.Items)) 31 | 32 | for _, d := range page.Items { 33 | items = append(items, Document{ 34 | ID: d.ID, 35 | 36 | Title: d.Title, 37 | Source: d.Source, 38 | Content: d.Content, 39 | 40 | Metadata: d.Metadata, 41 | 42 | Embedding: d.Embedding, 43 | }) 44 | } 45 | 46 | result := Page[Document]{ 47 | Items: items, 48 | Cursor: page.Cursor, 49 | } 50 | 51 | writeJson(w, result) 52 | } 53 | -------------------------------------------------------------------------------- /server/index/handler_query.go: -------------------------------------------------------------------------------- 1 | package index 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | 7 | "github.com/adrianliechti/wingman/pkg/index" 8 | "github.com/adrianliechti/wingman/pkg/to" 9 | ) 10 | // handleQuery runs the JSON Query from the request body against the index and writes the scored documents (without embeddings) as JSON. 11 | func (h *Handler) handleQuery(w http.ResponseWriter, r *http.Request) { 12 | i, err := h.Index(r.PathValue("index")) 13 | 14 | if err != nil { 15 | 
http.Error(w, err.Error(), http.StatusBadRequest) 16 | return 17 | } 18 | 19 | var query Query 20 | 21 | if err := json.NewDecoder(r.Body).Decode(&query); err != nil { 22 | http.Error(w, err.Error(), http.StatusBadRequest) 23 | return 24 | } 25 | 26 | if len(query.Text) == 0 { 27 | http.Error(w, "query text is required", http.StatusBadRequest) 28 | return 29 | } 30 | 31 | options := &index.QueryOptions{ 32 | Limit: query.Limit, 33 | } 34 | 35 | result, err := i.Query(r.Context(), query.Text, options) 36 | 37 | if err != nil { 38 | http.Error(w, err.Error(), http.StatusBadRequest) 39 | return 40 | } 41 | // Pre-sized, and non-nil so an empty result still encodes as [] rather than null. 42 | results := make([]Result, 0, len(result)) 43 | // doc (not r) so the loop variable does not shadow the *http.Request. 44 | for _, doc := range result { 45 | results = append(results, Result{ 46 | Score: to.Ptr(float64(doc.Score)), 47 | 48 | Document: Document{ 49 | ID: doc.ID, 50 | 51 | Title: doc.Title, 52 | Source: doc.Source, 53 | Content: doc.Content, 54 | 55 | Metadata: doc.Metadata, 56 | 57 | //Embedding: doc.Embedding, 58 | }, 59 | }) 60 | } 61 | 62 | writeJson(w, results) 63 | } 64 | -------------------------------------------------------------------------------- /server/index/models.go: -------------------------------------------------------------------------------- 1 | package index 2 | // Page is one page of items plus the cursor from the index's List result (see handleList). NOTE(review): omitempty drops an empty Items slice, so an empty page encodes without an "items" field. 3 | type Page[T any] struct { 4 | Items []T `json:"items,omitempty"` 5 | Cursor string `json:"cursor,omitempty"` 6 | } 7 | // Document is the JSON form of an index document. 8 | type Document struct { 9 | ID string `json:"id,omitempty"` 10 | 11 | Title string `json:"title,omitempty"` 12 | Source string `json:"source,omitempty"` 13 | Content string `json:"content,omitempty"` 14 | 15 | Metadata map[string]string `json:"metadata,omitempty"` 16 | 17 | Embedding []float32 `json:"embedding,omitempty"` 18 | } 19 | // Result is a Document with an optional score; a non-nil Score is encoded even when it is 0. encoding/json has no "inline" option: the embedded Document is flattened because it is anonymous. 20 | type Result struct { 21 | Score *float64 `json:"score,omitempty"` 22 | Document `json:",inline"` 23 | } 24 | // Query is the JSON request body of the query endpoint. 25 | type Query struct { 26 | Text string `json:"text,omitempty"` 27 | 28 | Limit *int `json:"limit,omitempty"` 29 | } 30 | -------------------------------------------------------------------------------- 
/server/openai/handler_audio_speach.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | ) 9 | 10 | // handleAudioSpeech decodes a SpeechRequest, synthesizes req.Input with the 11 | // requested model and voice, and writes the audio bytes with their content type. 12 | func (h *Handler) handleAudioSpeech(w http.ResponseWriter, r *http.Request) { 13 | var req SpeechRequest 14 | 15 | if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 16 | writeError(w, http.StatusBadRequest, err) 17 | return 18 | } 19 | 20 | synthesizer, err := h.Synthesizer(req.Model) 21 | 22 | if err != nil { 23 | writeError(w, http.StatusBadRequest, err) 24 | return 25 | } 26 | 27 | options := &provider.SynthesizeOptions{ 28 | Voice: req.Voice, 29 | } 30 | 31 | synthesis, err := synthesizer.Synthesize(r.Context(), req.Input, options) 32 | 33 | if err != nil { 34 | writeError(w, http.StatusBadRequest, err) 35 | return 36 | } 37 | 38 | w.Header().Set("Content-Type", synthesis.ContentType) 39 | w.Write(synthesis.Content) 40 | } 41 | -------------------------------------------------------------------------------- /server/openai/handler_audio_transcription.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | 7 | "github.com/adrianliechti/wingman/pkg/provider" 8 | ) 9 | 10 | // handleAudioTranscription reads a multipart form with a "model" field and a 11 | // "file" part, transcribes the file, and writes the text as a JSON Transcription. 12 | func (h *Handler) handleAudioTranscription(w http.ResponseWriter, r *http.Request) { 13 | // Up to 32 MiB of file parts are kept in memory; the rest spills to temporary files. 14 | if err := r.ParseMultipartForm(32 << 20); err != nil { 15 | writeError(w, http.StatusBadRequest, err) 16 | return 17 | } 18 | 19 | model := r.FormValue("model") 20 | 21 | transcriber, err := h.Transcriber(model) 22 | 23 | if err != nil { 24 | writeError(w, http.StatusBadRequest, err) 25 | return 26 | } 27 | 28 | // prompt and language are read but currently not forwarded to the transcriber. 29 | prompt := r.FormValue("prompt") 30 | language := r.FormValue("language") 31 | 32 | _ = prompt 33 | _ = language 34 | 35 | file, header, err := r.FormFile("file") 36 | 37 | if err != nil { 38 | writeError(w, http.StatusBadRequest, err) 39 | 
return 40 | } 41 | 42 | defer file.Close() 43 | 44 | data, err := io.ReadAll(file) 45 | 46 | if err != nil { 47 | writeError(w, http.StatusBadRequest, err) 48 | return 49 | } 50 | 51 | input := provider.File{ 52 | Name: header.Filename, 53 | 54 | Content: data, 55 | ContentType: header.Header.Get("Content-Type"), 56 | } 57 | 58 | options := &provider.TranscribeOptions{} 59 | 60 | transcription, err := transcriber.Transcribe(r.Context(), input, options) 61 | 62 | if err != nil { 63 | writeError(w, http.StatusBadRequest, err) 64 | return 65 | } 66 | 67 | result := Transcription{ 68 | Task: "transcribe", 69 | 70 | // Language: transcription.Language, 71 | // Duration: transcription.Duration, 72 | 73 | Text: transcription.Text, 74 | } 75 | 76 | writeJson(w, result) 77 | } 78 | -------------------------------------------------------------------------------- /server/openai/handler_embeddings.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "net/http" 7 | ) 8 | 9 | // handleEmbeddings embeds req.Input, given as a string or an array of strings, 10 | // and writes an EmbeddingList whose entries carry their position as Index. 11 | func (h *Handler) handleEmbeddings(w http.ResponseWriter, r *http.Request) { 12 | var req EmbeddingsRequest 13 | 14 | if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 15 | writeError(w, http.StatusBadRequest, err) 16 | return 17 | } 18 | 19 | embedder, err := h.Embedder(req.Model) 20 | 21 | if err != nil { 22 | writeError(w, http.StatusBadRequest, err) 23 | return 24 | } 25 | 26 | var inputs []string 27 | 28 | switch v := req.Input.(type) { 29 | case string: 30 | inputs = []string{v} 31 | case []string: 32 | inputs = v 33 | case []any: 34 | // When Input is a plain interface value, encoding/json decodes a JSON 35 | // array into it as []any, so string arrays arrive here, not as []string. 36 | for _, item := range v { 37 | s, ok := item.(string) 38 | 39 | if !ok { 40 | writeError(w, http.StatusBadRequest, errors.New("input must be a string or an array of strings")) 41 | return 42 | } 43 | 44 | inputs = append(inputs, s) 45 | } 46 | } 47 | 48 | if len(inputs) == 0 { 49 | writeError(w, http.StatusBadRequest, errors.New("no input provided")) 50 | return 51 | } 52 | 53 | embedding, err := embedder.Embed(r.Context(), inputs) 54 | 55 | if err != nil { 56 | writeError(w, http.StatusBadRequest, err) 57 | return 58 | } 59 | 60 | result := &EmbeddingList{ 61 | Object: "list", 62 | 63 | Model: embedding.Model, 
64 | } 65 | 66 | if result.Model == "" { 67 | result.Model = req.Model 68 | } 69 | 70 | for i, e := range embedding.Embeddings { 71 | result.Data = append(result.Data, Embedding{ 72 | Object: "embedding", 73 | 74 | Index: i, 75 | Embedding: e, 76 | }) 77 | } 78 | 79 | if embedding.Usage != nil { 80 | result.Usage = &Usage{ 81 | PromptTokens: embedding.Usage.InputTokens, 82 | TotalTokens: embedding.Usage.InputTokens + embedding.Usage.OutputTokens, 83 | } 84 | } 85 | 86 | writeJson(w, result) 87 | } 88 | -------------------------------------------------------------------------------- /server/openai/handler_image_generation.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/json" 6 | "net/http" 7 | 8 | "github.com/adrianliechti/wingman/pkg/provider" 9 | ) 10 | 11 | // handleImageGeneration renders req.Prompt and returns one image: as a base64 12 | // data URL when ResponseFormat is "url", otherwise as a raw base64 string. 13 | func (h *Handler) handleImageGeneration(w http.ResponseWriter, r *http.Request) { 14 | var req ImageCreateRequest 15 | 16 | if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 17 | writeError(w, http.StatusBadRequest, err) 18 | return 19 | } 20 | 21 | renderer, err := h.Renderer(req.Model) 22 | 23 | if err != nil { 24 | writeError(w, http.StatusBadRequest, err) 25 | return 26 | } 27 | 28 | options := &provider.RenderOptions{} 29 | 30 | image, err := renderer.Render(r.Context(), req.Prompt, options) 31 | 32 | if err != nil { 33 | writeError(w, http.StatusBadRequest, err) 34 | return 35 | } 36 | 37 | result := ImageList{} 38 | 39 | encoded := base64.StdEncoding.EncodeToString(image.Content) 40 | 41 | if req.ResponseFormat == "url" { 42 | // The "url" format is served as an inline data: URL carrying the image bytes. 43 | result.Images = []Image{ 44 | { 45 | URL: "data:" + image.ContentType + ";base64," + encoded, 46 | }, 47 | } 48 | } else { 49 | result.Images = []Image{ 50 | { 51 | B64JSON: encoded, 52 | }, 53 | } 54 | } 55 | 56 | writeJson(w, result) 57 | } 58 | --------------------------------------------------------------------------------
/server/openai/handler_models.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | ) 7 | 8 | // handleModels lists every configured model as an OpenAI-style model list. 9 | // Created is set to the time of the request. 10 | func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) { 11 | // One timestamp for the whole response so every entry reports the same Created value. 12 | created := time.Now().Unix() 13 | 14 | result := &ModelList{ 15 | Object: "list", 16 | } 17 | 18 | for _, m := range h.Models() { 19 | result.Models = append(result.Models, Model{ 20 | Object: "model", 21 | 22 | ID: m.ID, 23 | Created: created, 24 | OwnedBy: "openai", 25 | }) 26 | } 27 | 28 | writeJson(w, result) 29 | } 30 | 31 | // handleModel returns the model named by the "id" path value, or 404 when 32 | // h.Model reports an error for it. Created is set to the time of the request. 33 | func (h *Handler) handleModel(w http.ResponseWriter, r *http.Request) { 34 | model, err := h.Model(r.PathValue("id")) 35 | 36 | if err != nil { 37 | writeError(w, http.StatusNotFound, err) 38 | return 39 | } 40 | 41 | result := &Model{ 42 | Object: "model", 43 | 44 | ID: model.ID, 45 | Created: time.Now().Unix(), 46 | OwnedBy: "openai", 47 | } 48 | 49 | writeJson(w, result) 50 | } 51 | -------------------------------------------------------------------------------- /server/unstructured/handler.go: -------------------------------------------------------------------------------- 1 | package unstructured 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | 7 | "github.com/adrianliechti/wingman/config" 8 | 9 | "github.com/go-chi/chi/v5" 10 | ) 11 | 12 | // Handler serves the routes of this package from its own chi mux. 13 | type Handler struct { 14 | *config.Config 15 | http.Handler 16 | } 17 | 18 | // New creates a Handler whose embedded http.Handler is a chi mux with the 19 | // package routes attached. The returned error is currently always nil. 20 | func New(cfg *config.Config) (*Handler, error) { 21 | mux := chi.NewMux() 22 | 23 | h := &Handler{ 24 | Config: cfg, 25 | Handler: mux, 26 | } 27 | 28 | h.Attach(mux) 29 | return h, nil 30 | } 31 | 32 | // Attach registers the package routes (POST /partition) on r. 33 | func (h *Handler) Attach(r chi.Router) { 34 | r.Post("/partition", h.handlePartition) 35 | } 36 | 37 | // writeJson encodes v as JSON without HTML escaping. The Encode error is 38 | // ignored: once writing has started the status line may already be sent. 39 | func writeJson(w http.ResponseWriter, v any) { 40 | w.Header().Set("Content-Type", "application/json") 41 | 42 | enc := json.NewEncoder(w) 43 | enc.SetEscapeHTML(false) 44 | 45 | enc.Encode(v) 46 | } 47 | 48 | // writeError writes code and a plain-text message. A nil err falls back to the 49 | // standard status text instead of panicking on err.Error(). 50 | func writeError(w http.ResponseWriter, code int, err error) { 51 | 
message := http.StatusText(code) 52 | 53 | if err != nil { 54 | message = err.Error() 55 | } 56 | 57 | w.WriteHeader(code) 58 | w.Write([]byte(message)) 59 | } 60 | -------------------------------------------------------------------------------- /server/unstructured/models.go: -------------------------------------------------------------------------------- 1 | package unstructured 2 | 3 | // Partition is the JSON shape of one partitioned element (presumably an entry of the /partition response; confirm in handlePartition). 4 | type Partition struct { 5 | ID string `json:"element_id,omitempty"` 6 | 7 | Type string `json:"type,omitempty"` 8 | Text string `json:"text,omitempty"` 9 | 10 | //Metadata PartitionMetadata `json:"metadata,omitempty"` 11 | } 12 | 13 | // type PartitionMetadata struct { 14 | // FileName string `json:"filename,omitempty"` 15 | // FileType string `json:"filetype,omitempty"` 16 | 17 | // Languages []string `json:"languages,omitempty"` 18 | // } 19 | 20 | // ChunkingStrategy is a string enum; only the values below are defined here. 21 | type ChunkingStrategy string 22 | 23 | const ( 24 | ChunkingStrategyUnknown ChunkingStrategy = "" 25 | ChunkingStrategyNone ChunkingStrategy = "none" 26 | ) 27 | -------------------------------------------------------------------------------- /test/test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/adrianliechti/wingman/pkg/provider" 7 | "github.com/adrianliechti/wingman/pkg/provider/ollama" 8 | ) 9 | 10 | // TestContext bundles a context with an embedder and a completer for tests. 11 | type TestContext struct { 12 | Context context.Context 13 | 14 | Embedder provider.Embedder 15 | Completer provider.Completer 16 | } 17 | 18 | // NewContext returns a TestContext wired to an Ollama server at 19 | // http://localhost:11434. Constructor errors are discarded, so Embedder or 20 | // Completer may be unusable if construction failed. NOTE(review): consider surfacing them. 21 | func NewContext() *TestContext { 22 | url := "http://localhost:11434" 23 | 24 | completer, _ := ollama.NewCompleter(url, "llama3.1:latest") 25 | embedder, _ := ollama.NewEmbedder(url, "nomic-embed-text:latest") 26 | 27 | return &TestContext{ 28 | Context: context.Background(), 29 | 30 | Embedder: embedder, 31 | Completer: completer, 32 | } 33 | } 34 | --------------------------------------------------------------------------------