├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── build-test.yml │ ├── docker-publish.yml │ └── golangci-lint.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile.ce ├── LICENSE ├── Makefile.ce ├── README.md ├── SECURITY.md ├── assets └── zep-logo-icon-gradient-rgb.svg ├── docker-compose.ce.yaml ├── go.work ├── go.work.sum ├── src ├── api │ ├── apidata │ │ ├── common.go │ │ ├── fact.go │ │ ├── memory_ce.go │ │ ├── memory_common.go │ │ ├── message.go │ │ ├── session_ce.go │ │ ├── session_common.go │ │ ├── user_ce.go │ │ └── user_common.go │ ├── apihandlers │ │ ├── fact_handlers_ce.go │ │ ├── fact_handlers_common.go │ │ ├── memory_handlers_ce.go │ │ ├── memory_handlers_common.go │ │ ├── message_handlers.go │ │ ├── session_handlers_common.go │ │ └── user_handlers.go │ ├── handlertools │ │ ├── request_state_ce.go │ │ └── tools.go │ ├── middleware │ │ ├── auth.go │ │ ├── secret_key_auth_ce.go │ │ └── send_version.go │ ├── routes.go │ └── server_ce.go ├── go.mod ├── go.sum ├── golangci.yaml ├── lib │ ├── communication │ │ ├── communication_ce.go │ │ ├── service.go │ │ └── service_mock.go │ ├── config │ │ ├── config.go │ │ ├── env_template.go │ │ ├── load_ce.go │ │ ├── models_ce.go │ │ └── version_ce.go │ ├── enablement │ │ ├── enablement_ce.go │ │ ├── events.go │ │ ├── plan_ce.go │ │ ├── service.go │ │ └── service_mock.go │ ├── graphiti │ │ └── service_ce.go │ ├── logger │ │ ├── bun_hook.go │ │ └── logger.go │ ├── observability │ │ ├── observability_ce.go │ │ ├── service.go │ │ └── service_mock.go │ ├── pg │ │ ├── db.go │ │ └── integrity.go │ ├── search │ │ ├── mmr.go │ │ └── rrf.go │ ├── telemetry │ │ ├── events.go │ │ ├── service.go │ │ ├── service_mock.go │ │ └── telemetry_ce.go │ ├── util │ │ ├── httputil │ │ │ ├── http_base.go │ │ │ ├── http_base_mock.go │ │ │ └── retryable_http_client.go │ │ └── utils.go │ └── zerrors │ │ ├── errors.go │ │ └── storage.go ├── main.go ├── models │ ├── 
app_state_ce.go │ ├── fact_common.go │ ├── memory_ce.go │ ├── memory_common.go │ ├── memorystore_ce.go │ ├── memorystore_common.go │ ├── options.go │ ├── projectsetting.go │ ├── request_state_ce.go │ ├── search_ce.go │ ├── search_common.go │ ├── session_ce.go │ ├── session_common.go │ ├── state.go │ ├── tasks_ce.go │ ├── tasks_common.go │ └── userstore.go ├── setup_ce.go ├── state.go └── store │ ├── db_utils_ce.go │ ├── memory_ce.go │ ├── memory_common.go │ ├── memorystore_common.go │ ├── message_ce.go │ ├── message_common.go │ ├── metadata_utils.go │ ├── migrations │ ├── 000000000001_database_setup.down.sql │ ├── 000000000001_database_setup.up.sql │ └── migrate.go │ ├── purge_ce.go │ ├── purge_common.go │ ├── schema_ce.go │ ├── schema_common.go │ ├── session_ce.go │ ├── sessionstore_ce.go │ ├── sessionstore_common.go │ ├── userstore_ce.go │ └── userstore_common.go ├── zep └── zep.yaml /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Please provide a clear set of steps that would allow us to reproduce the bug. Provide code samples or shell commands as necessary. 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Logs** 20 | Please provide your Zep server and application logs. 21 | 22 | **Environment (please complete the following information):** 23 | - Zep version: [e.g. vX.X.X] 24 | - Zep SDK and version: [e.g. `zep-js` or `zep-python` and vX.X.X] 25 | - Deployment [e.g.
using `docker compose`, to a hosted environment such as Render] 26 | 27 | _Note_: The Zep server version is available in the Zep server logs at startup: 28 | `Starting zep server version 0.11.0-cbf4fe4 (2023-08-30T12:49:03+0000)` 29 | 30 | **Additional context** 31 | Add any other context about the problem here. 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[FEAT]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/build-test.yml: -------------------------------------------------------------------------------- 1 | name: build-test 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | push: 8 | branches: [ "main" ] 9 | jobs: 10 | build: 11 | runs-on: ubuntu-4c-16GB-150GB 12 | container: debian:bullseye-slim 13 | environment: build-test 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: install certs and build-essential (required by CGO) 17 | run: apt-get update && apt-get install -y ca-certificates build-essential 18 | - name: Set up Go 19 | uses: actions/setup-go@v5 20 | with: 21 | go-version: '^1.22' 22 | - name: Cache Go modules 23 | uses: actions/cache@v4 24 | with: 25 | path: ~/go/pkg/mod 26 | key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} 27 | restore-keys: | 28 | ${{ runner.os }}-go- 29 | - name: Build 30 | run: go build -v ./src/... 31 | -------------------------------------------------------------------------------- /.github/workflows/docker-publish.yml: -------------------------------------------------------------------------------- 1 | name: Zep Server Docker Build and Publish 2 | 3 | on: 4 | push: 5 | # Publish semver tags as releases. 
6 | tags: [ 'v*.*.*' ] 7 | 8 | env: 9 | REGISTRY: docker.io 10 | IMAGE_NAME: zepai/zep 11 | 12 | jobs: 13 | docker-image: 14 | environment: 15 | name: release 16 | runs-on: ubuntu-latest 17 | permissions: 18 | contents: read 19 | id-token: write 20 | steps: 21 | - name: Checkout repository 22 | uses: actions/checkout@v3 23 | with: 24 | ref: ${{ github.event.inputs.tag || github.ref }} 25 | 26 | - name: Set up Depot CLI 27 | uses: depot/setup-action@v1 28 | 29 | - name: Login to DockerHub 30 | uses: docker/login-action@v2 31 | with: 32 | username: ${{ secrets.DOCKERHUB_USERNAME }} 33 | password: ${{ secrets.DOCKERHUB_TOKEN }} 34 | 35 | - name: Extract Docker metadata 36 | id: meta 37 | uses: docker/metadata-action@v4.4.0 38 | with: 39 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 40 | tags: | 41 | type=semver,pattern={{version}} 42 | type=semver,pattern={{major}}.{{minor}} 43 | type=match,pattern=v(.*-beta),group=1 44 | type=match,pattern=v.*-(beta),group=1 45 | 46 | - name: Depot build and push image 47 | uses: depot/build-push-action@v1 48 | with: 49 | project: v9jv1mlpwc 50 | context: . 
51 | platforms: linux/amd64,linux/arm64 52 | push: ${{ github.event_name != 'pull_request' }} 53 | tags: ${{ steps.meta.outputs.tags || env.TAGS }} 54 | labels: ${{ steps.meta.outputs.labels }} 55 | cache-from: type=gha 56 | cache-to: type=gha,mode=max 57 | file: Dockerfile.ce 58 | -------------------------------------------------------------------------------- /.github/workflows/golangci-lint.yml: -------------------------------------------------------------------------------- 1 | name: golangci-lint 2 | on: 3 | push: 4 | tags: 5 | - v* 6 | pull_request: 7 | branches: 8 | - main 9 | permissions: 10 | contents: read 11 | jobs: 12 | golangci: 13 | name: lint 14 | runs-on: depot-ubuntu-22.04-8 15 | steps: 16 | - uses: actions/setup-go@v5 17 | with: 18 | go-version: '1.22' 19 | cache: false 20 | - uses: actions/checkout@v4 21 | - name: golangci-lint 22 | uses: golangci/golangci-lint-action@v6 23 | with: 24 | working-directory: ./src 25 | version: v1.61.0 26 | args: 27 | --config=golangci.yaml 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | out/ 6 | *.exe 7 | *.exe~ 8 | *.dll 9 | *.so 10 | *.dylib 11 | 12 | # Secrets 13 | .env 14 | .env.local 15 | 16 | # Test data 17 | test_data 18 | 19 | # Test binary, built with `go test -c` 20 | *.test 21 | 22 | # Output of the go coverage tool, specifically when used with LiteIDE 23 | *.out 24 | 25 | # Dependency directories (remove the comment below to include it) 26 | # vendor/ 27 | 28 | # IDE and editor settings 29 | .idea 30 | .vscode 31 | 32 | # VSCode local history 33 | .history 34 | -------------------------------------------------------------------------------- 
/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action 
in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | founders@getzep.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 
87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. 
Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Zep 2 | 3 | Thank you for your interest in contributing to Zep! We appreciate your efforts and look forward to collaborating with you. 4 | 5 | ### Getting Started 6 | 7 | 1. **Fork and Clone**: Start by forking the [Zep repo](https://github.com/getzep/zep). Then, clone your fork locally: 8 | 9 | ``` 10 | git clone https://github.com//zep.git 11 | ``` 12 | 13 | 2. **Set Upstream**: Keep your fork synced with the upstream repo by adding it as a remote: 14 | 15 | ``` 16 | git remote add upstream https://github.com/getzep/zep.git 17 | ``` 18 | 19 | 3. **Create a Feature Branch**: Always create a new branch for your work. This helps in keeping your changes organized and easier for maintainers to review. 20 | 21 | ``` 22 | git checkout -b feature/your-feature-name 23 | ``` 24 | 25 | ### Setting "Development" Mode 26 | "Development" mode forces Zep's log level to "debug" and disables caching of the web UI. This is useful when developing Zep locally. 27 | 28 | To enable "development" mode, set the `ZEP_DEVELOPMENT` environment variable to `true`: 29 | 30 | ``` 31 | export ZEP_DEVELOPMENT=true 32 | ``` 33 | 34 | or modify your `.env` file accordingly. 35 | 36 | 37 | ### Running the Database and NLP Server Stack 38 | 39 | A development stack can be started by running: 40 | 41 | ```bash 42 | make dev 43 | ``` 44 | 45 | This starts the DB and NLP services using docker compose and exposes the DB on port 5432 and the NLP service on port 5557. 46 | The database volume is also not persistent, so it will be wiped out when the stack is stopped. 
47 | 48 | ### Automatically Rebuilding Zep using Go Watch 49 | 50 | **Note:** You will need to have [Go Watch](https://github.com/mitranim/gow) installed. 51 | 52 | If you want to automatically rebuild Zep when you make changes to the code, run: 53 | 54 | ``` 55 | make watch 56 | ``` 57 | 58 | The above sets "Development" mode and binds Zep to localhost only. 59 | 60 | 61 | ### Rebuilding Tailwind CSS 62 | 63 | If you make changes to the CSS used by HTML template files, you will need to rebuild the Tailwind CSS file. 64 | 65 | Run: 66 | ``` 67 | make web 68 | ``` 69 | 70 | ### Building Zep 71 | 72 | Follow these steps to build Zep locally: 73 | 74 | 1. Navigate to the project root: 75 | 76 | ``` 77 | cd zep 78 | ``` 79 | 80 | 2. Build the project: 81 | 82 | ``` 83 | make build 84 | ``` 85 | 86 | This will produce the binary in `./out/bin`. 87 | 88 | ### Running Tests 89 | 90 | It's essential to ensure that your code passes all tests. Run the tests using: 91 | 92 | ``` 93 | make test 94 | ``` 95 | 96 | If you want to check the coverage, run: 97 | 98 | ``` 99 | make coverage 100 | ``` 101 | 102 | ### Code Linting 103 | 104 | Ensure your code adheres to our linting standards: 105 | 106 | ``` 107 | make lint 108 | ``` 109 | 110 | ### Generating Swagger Docs 111 | 112 | If you make changes to the API or its documentation, regenerate the Swagger docs: 113 | 114 | ``` 115 | make swagger 116 | ``` 117 | 118 | ### Submitting Changes 119 | 120 | 1. **Commit Your Changes**: Use meaningful commit messages that describe the changes made. 121 | 122 | ``` 123 | git add . 124 | git commit -m "Your detailed commit message" 125 | ``` 126 | 127 | 2. **Push to Your Fork**: 128 | 129 | ``` 130 | git push origin feature/your-feature-name 131 | ``` 132 | 133 | 3. **Open a Pull Request**: Navigate to the [Zep GitHub repo](https://github.com/getzep/zep) and click on "New pull request". Choose your fork and the branch you've been working on. Submit the PR with a descriptive message. 
134 | 135 | ### Feedback 136 | 137 | Maintainers will review your PR and provide feedback. If any changes are needed, make them in your feature branch and push to your fork. The PR will update automatically. 138 | 139 | ### Final Notes 140 | 141 | - Always be respectful and kind. 142 | - If you're unsure about something, ask. We're here to help. 143 | - Once again, thank you for contributing to Zep! 144 | 145 | --- 146 | 147 | If you encounter any issues or have suggestions, please open an issue! -------------------------------------------------------------------------------- /Dockerfile.ce: -------------------------------------------------------------------------------- 1 | FROM golang:1.22.5-bookworm AS BUILD 2 | 3 | RUN mkdir /app 4 | WORKDIR /app 5 | COPY . . 6 | WORKDIR /app/src 7 | RUN go mod download 8 | WORKDIR /app 9 | RUN make -f Makefile.ce build 10 | 11 | FROM debian:bookworm-slim AS RUNTIME 12 | RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/* 13 | WORKDIR /app 14 | COPY --from=BUILD /app/out/bin/zep /app/ 15 | # Ship with default config that can be overridden by ENV vars 16 | COPY zep.yaml /app/ 17 | 18 | EXPOSE 8000 19 | ENTRYPOINT ["/app/zep"] 20 | -------------------------------------------------------------------------------- /Makefile.ce: -------------------------------------------------------------------------------- 1 | GOCMD=go 2 | 3 | VERSION?=0.0.0 4 | SERVICE_PORT?= 5 | DOCKER_REGISTRY?= 6 | BINARY_NAME?=zep 7 | EXPORT_RESULT?=false # for CI please set EXPORT_RESULT to true 8 | BINARY_DEST=./out/bin 9 | BINARY=$(BINARY_DEST)/zep 10 | GOTEST_CMD=$(GOCMD) test 11 | GOVET_CMD=$(GOCMD) vet 12 | GORUN_CMD=$(GOCMD) run 13 | GOBUILD_CMD=$(GOCMD) build 14 | 15 | WD=$(shell pwd) 16 | 17 | SRC_DIR=$(WD)/src 18 | APP_DIR=$(SRC_DIR) 19 | 20 | RUN_ARGS=-r 21 | 22 | PACKAGE := github.com/getzep/zep/lib/config 23 | VERSION := $(shell git describe --tags --always --abbrev=0 --match='v[0-9]*.[0-9]*.[0-9]*' 2> /dev/null 
| sed 's/^.//') 24 | COMMIT_HASH := $(shell git rev-parse --short HEAD) 25 | BUILD_TIMESTAMP := $(shell date '+%Y-%m-%dT%H:%M:%S%z') 26 | 27 | LDFLAGS = -X '${PACKAGE}.Version=${VERSION}' \ 28 | -X '${PACKAGE}.CommitHash=${COMMIT_HASH}' \ 29 | -X '${PACKAGE}.BuildTime=${BUILD_TIMESTAMP}' 30 | 31 | GREEN := $(shell tput -Txterm setaf 2) 32 | YELLOW := $(shell tput -Txterm setaf 3) 33 | WHITE := $(shell tput -Txterm setaf 7) 34 | CYAN := $(shell tput -Txterm setaf 6) 35 | RESET := $(shell tput -Txterm sgr0) 36 | 37 | .PHONY: all test build dev-dump restore-db-from-dump help 38 | 39 | all: test build 40 | 41 | run: 42 | $(GORUN_CMD) -ldflags="${LDFLAGS}" $(APP_DIR)/... $(RUN_ARGS) 43 | 44 | build: 45 | mkdir -p $(BINARY_DEST) 46 | $(GOBUILD_CMD) -ldflags="${LDFLAGS}" -o $(BINARY) $(APP_DIR) 47 | 48 | build-run: build 49 | $(BINARY) $(RUN_ARGS) 50 | 51 | ## Go Watch to run server and restart on changes 52 | ## https://github.com/mitranim/gow 53 | watch: 54 | gow run $(APP_DIR)/... $(RUN_ARGS) 55 | 56 | test: ## Run project tests 57 | $(GOTEST_CMD) -shuffle on -race $(SRC_DIR)/... -p 1 58 | 59 | clean: ## Remove build-related files 60 | rm -f $(BINARY) 61 | rm -f ./junit-report.xml checkstyle-report.xml ./coverage.xml ./profile.cov yamllint-checkstyle.xml 62 | 63 | coverage: ## Run the tests of the project and export the coverage 64 | $(GOTEST_CMD) -cover -covermode=count -coverprofile=profile.cov $(SRC_DIR)/... 65 | $(GOCMD) tool cover -func profile.cov 66 | ifeq ($(EXPORT_RESULT), true) 67 | GO111MODULE=off go get -u github.com/AlekSi/gocov-xml 68 | GO111MODULE=off go get -u github.com/axw/gocov/gocov 69 | gocov convert profile.cov | gocov-xml > coverage.xml 70 | endif 71 | 72 | ## Lint: 73 | lint: 74 | cd src && golangci-lint run --sort-results -c golangci.yaml 75 | 76 | ## Run the dev stack docker compose setup. This exposes DB and NLP services 77 | ## for local development. This does not start the Zep service.
78 | dev: 79 | docker compose up -d 80 | 81 | ## Docker: 82 | docker-build: ## Use the dockerfile to build the container 83 | DOCKER_BUILDKIT=1 docker build --rm -f Dockerfile.ce --tag $(BINARY_NAME) . 84 | 85 | docker-release: ## Release the container with tag latest and version 86 | docker tag $(BINARY_NAME) $(DOCKER_REGISTRY)$(BINARY_NAME):latest 87 | docker tag $(BINARY_NAME) $(DOCKER_REGISTRY)$(BINARY_NAME):$(VERSION) 88 | # Push the docker images 89 | docker push $(DOCKER_REGISTRY)$(BINARY_NAME):latest 90 | docker push $(DOCKER_REGISTRY)$(BINARY_NAME):$(VERSION) 91 | 92 | 93 | ## Help: 94 | help: ## Show this help. 95 | @echo '' 96 | @echo 'Usage:' 97 | @echo ' ${YELLOW}make${RESET} ${GREEN}${RESET}' 98 | @echo '' 99 | @echo 'Targets:' 100 | @awk 'BEGIN {FS = ":.*?## "} { \ 101 | if (/^[a-zA-Z_-]+:.*?##.*$$/) {printf " ${YELLOW}%-20s${GREEN}%s${RESET}\n", $$1, $$2} \ 102 | else if (/^## .*$$/) {printf " ${CYAN}%s${RESET}\n", substr($$1,4)} \ 103 | }' $(MAKEFILE_LIST) 104 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | The following versions of Zep are currently supported with security 6 | updates. 7 | 8 | | Version | Supported | 9 | | ------- | ------------------ | 10 | | 0.x.x | :white_check_mark: | 11 | 12 | ## Reporting a Vulnerability 13 | 14 | Please use GitHub's Private Vulnerability Reporting mechanism found in the Security section of this repo.
15 | -------------------------------------------------------------------------------- /assets/zep-logo-icon-gradient-rgb.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /docker-compose.ce.yaml: -------------------------------------------------------------------------------- 1 | name: zep-ce 2 | 3 | services: 4 | zep: 5 | image: zepai/zep:latest 6 | # build: 7 | # context: . 8 | # dockerfile: Dockerfile.ce 9 | ports: 10 | - "8000:8000" 11 | volumes: 12 | - ./zep.yaml:/app/zep.yaml 13 | environment: 14 | - ZEP_CONFIG_FILE=zep.yaml 15 | networks: 16 | - zep-network 17 | depends_on: 18 | graphiti: 19 | condition: service_healthy 20 | db: 21 | condition: service_healthy 22 | db: 23 | image: ankane/pgvector:v0.5.1 24 | container_name: zep-ce-postgres 25 | restart: on-failure 26 | shm_size: "128mb" # Increase this if vacuuming fails with a "no space left on device" error 27 | environment: 28 | - POSTGRES_USER=postgres 29 | - POSTGRES_PASSWORD=postgres 30 | networks: 31 | - zep-network 32 | healthcheck: 33 | test: ["CMD", "pg_isready", "-q", "-d", "postgres", "-U", "postgres"] 34 | interval: 5s 35 | timeout: 5s 36 | retries: 5 37 | volumes: 38 | - zep-db:/var/lib/postgresql/data 39 | ports: 40 | - "5432:5432" 41 | graphiti: 42 | image: zepai/graphiti:0.3 43 | ports: 44 | - "8003:8003" 45 | env_file: 46 | - .env 47 | networks: 48 | - zep-network 49 | healthcheck: 50 | test: 51 | [ 52 | "CMD", 53 | "python", 54 | "-c", 55 | "import urllib.request; urllib.request.urlopen('http://localhost:8003/healthcheck')", 56 | ] 57 | interval: 10s 58 | timeout: 5s 59 | retries: 3 60 | depends_on: 61 | neo4j: 62 | condition: service_healthy 63 | environment: 64 | - OPENAI_API_KEY=${OPENAI_API_KEY} 65 | - MODEL_NAME=gpt-4o-mini 66 | - NEO4J_URI=bolt://neo4j:7687 67 | - 
NEO4J_USER=neo4j 68 | - NEO4J_PASSWORD=zepzepzep 69 | - PORT=8003 70 | neo4j: 71 | image: neo4j:5.22.0 72 | networks: 73 | - zep-network 74 | healthcheck: 75 | test: wget http://localhost:7687 || exit 1 76 | interval: 1s 77 | timeout: 10s 78 | retries: 20 79 | start_period: 3s 80 | ports: 81 | - "7474:7474" # HTTP 82 | - "7687:7687" # Bolt 83 | volumes: 84 | - neo4j_data:/data 85 | environment: 86 | - NEO4J_AUTH=neo4j/zepzepzep 87 | volumes: 88 | neo4j_data: 89 | zep-db: 90 | networks: 91 | zep-network: 92 | driver: bridge -------------------------------------------------------------------------------- /go.work: -------------------------------------------------------------------------------- 1 | go 1.21.5 2 | 3 | use ./src 4 | -------------------------------------------------------------------------------- /go.work.sum: -------------------------------------------------------------------------------- 1 | github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= 2 | github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= 3 | github.com/cenkalti/backoff/v3 v3.2.2/go.mod h1:cIeZDE3IrqwwJl6VUwCN6trj1oXrTS4rc0ij+ULvLYs= 4 | github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 5 | github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= 6 | github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= 7 | github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= 8 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= 9 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= 10 | github.com/prometheus/client_golang v1.20.2/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= 11 | github.com/prometheus/client_model v0.6.1/go.mod 
h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= 12 | github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= 13 | github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= 14 | github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= 15 | github.com/sony/gobreaker v1.0.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= 16 | github.com/uptrace/opentelemetry-go-extra/otelsql v0.2.3 h1:LNi0Qa7869/loPjz2kmMvp/jwZZnMZ9scMJKhDJ1DIo= 17 | github.com/uptrace/opentelemetry-go-extra/otelsql v0.2.3/go.mod h1:jyigonKik3C5V895QNiAGpKYKEvFuqjw9qAEZks1mUg= 18 | go.opentelemetry.io/otel/sdk v1.28.0/go.mod h1:oYj7ClPUA7Iw3m+r7GeEjz0qckQRJK2B8zjcZEfu7Pg= 19 | golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= 20 | golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 21 | golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= 22 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= 23 | -------------------------------------------------------------------------------- /src/api/apidata/common.go: -------------------------------------------------------------------------------- 1 | package apidata 2 | 3 | // APIError represents an error response. Used for swagger documentation. 
4 | type APIError struct { 5 | Message string `json:"message"` 6 | } 7 | 8 | type SuccessResponse struct { 9 | Message string `json:"message"` 10 | } 11 | -------------------------------------------------------------------------------- /src/api/apidata/fact.go: -------------------------------------------------------------------------------- 1 | package apidata 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/google/uuid" 7 | 8 | "github.com/getzep/zep/models" 9 | ) 10 | 11 | func FactListTransformer(facts []models.Fact) []Fact { 12 | f := make([]Fact, len(facts)) 13 | for i, fact := range facts { 14 | f[i] = FactTransformer(fact) 15 | } 16 | 17 | return f 18 | } 19 | 20 | func FactTransformerPtr(fact *models.Fact) *Fact { 21 | if fact == nil { 22 | return nil 23 | } 24 | 25 | f := FactTransformer(*fact) 26 | 27 | return &f 28 | } 29 | 30 | func FactTransformer(fact models.Fact) Fact { 31 | return Fact{ 32 | UUID: fact.UUID, 33 | CreatedAt: fact.CreatedAt, 34 | Fact: fact.Fact, 35 | Rating: fact.Rating, 36 | } 37 | } 38 | 39 | type Fact struct { 40 | UUID uuid.UUID `json:"uuid"` 41 | CreatedAt time.Time `json:"created_at"` 42 | Fact string `json:"fact"` 43 | Rating *float64 `json:"rating,omitempty"` 44 | } 45 | 46 | type FactsResponse struct { 47 | Facts []Fact `json:"facts"` 48 | } 49 | 50 | type FactResponse struct { 51 | Fact Fact `json:"fact"` 52 | } 53 | -------------------------------------------------------------------------------- /src/api/apidata/memory_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package apidata 3 | 4 | import "github.com/getzep/zep/models" 5 | 6 | func MemoryTransformer(memory *models.Memory) Memory { 7 | return Memory{ 8 | MemoryCommon: commonMemoryTransformer(memory), 9 | } 10 | } 11 | 12 | type Memory struct { 13 | MemoryCommon 14 | } 15 | 16 | type AddMemoryRequest struct { 17 | AddMemoryRequestCommon 18 | } 19 | 
-------------------------------------------------------------------------------- /src/api/apidata/memory_common.go: -------------------------------------------------------------------------------- 1 | package apidata 2 | 3 | import "github.com/getzep/zep/models" 4 | 5 | type MemoryCommon struct { 6 | // A list of message objects, where each message contains a role and content. Only last_n messages will be returned 7 | Messages []Message `json:"messages"` 8 | 9 | // A dictionary containing metadata associated with the memory. 10 | Metadata map[string]any `json:"metadata,omitempty"` 11 | 12 | RelevantFacts []Fact `json:"relevant_facts"` 13 | } 14 | 15 | func commonMemoryTransformer(memory *models.Memory) MemoryCommon { 16 | return MemoryCommon{ 17 | Messages: MessageListTransformer(memory.Messages), 18 | Metadata: memory.Metadata, 19 | RelevantFacts: FactListTransformer(memory.RelevantFacts), 20 | } 21 | } 22 | 23 | type AddMemoryRequestCommon struct { 24 | // A list of message objects, where each message contains a role and content. 
25 | Messages []Message `json:"messages" validate:"required"` 26 | } 27 | -------------------------------------------------------------------------------- /src/api/apidata/message.go: -------------------------------------------------------------------------------- 1 | package apidata 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/google/uuid" 7 | 8 | "github.com/getzep/zep/models" 9 | ) 10 | 11 | func MessageListTransformer(messages []models.Message) []Message { 12 | m := make([]Message, len(messages)) 13 | for i, message := range messages { 14 | m[i] = MessageTransformer(message) 15 | } 16 | 17 | return m 18 | } 19 | 20 | func MessageTransformerPtr(message models.Message) *Message { 21 | msg := MessageTransformer(message) 22 | 23 | return &msg 24 | } 25 | 26 | func MessageTransformer(message models.Message) Message { 27 | return Message{ 28 | UUID: message.UUID, 29 | CreatedAt: message.CreatedAt, 30 | UpdatedAt: message.UpdatedAt, 31 | Role: message.Role, 32 | RoleType: RoleType(message.RoleType), 33 | Content: message.Content, 34 | Metadata: message.Metadata, 35 | TokenCount: message.TokenCount, 36 | } 37 | } 38 | func MessagesToModelMessagesTransformer(messages []Message) []models.Message { 39 | m := make([]models.Message, len(messages)) 40 | for i, message := range messages { 41 | m[i] = MessageToModelMessageTransformer(message) 42 | } 43 | 44 | return m 45 | } 46 | func MessageToModelMessageTransformer(message Message) models.Message { 47 | return models.Message{ 48 | UUID: message.UUID, 49 | CreatedAt: message.CreatedAt, 50 | UpdatedAt: message.UpdatedAt, 51 | Role: message.Role, 52 | RoleType: models.RoleType(message.RoleType), 53 | Content: message.Content, 54 | Metadata: message.Metadata, 55 | TokenCount: message.TokenCount, 56 | } 57 | } 58 | 59 | // Message Represents a message in a conversation. 60 | type Message struct { 61 | // The unique identifier of the message. 
62 | UUID uuid.UUID `json:"uuid"` 63 | // The timestamp of when the message was created. 64 | CreatedAt time.Time `json:"created_at"` 65 | // The timestamp of when the message was last updated. 66 | UpdatedAt time.Time `json:"updated_at"` 67 | // The role of the sender of the message (e.g., "user", "assistant"). 68 | Role string `json:"role"` 69 | // The type of the role (e.g., "user", "system"). 70 | RoleType RoleType `json:"role_type,omitempty"` 71 | // The content of the message. 72 | Content string `json:"content"` 73 | // The metadata associated with the message. 74 | Metadata map[string]any `json:"metadata,omitempty"` 75 | // The number of tokens in the message. 76 | TokenCount int `json:"token_count"` 77 | } 78 | 79 | type RoleType string 80 | 81 | const ( 82 | NoRole RoleType = "norole" 83 | SystemRole RoleType = "system" 84 | AssistantRole RoleType = "assistant" 85 | UserRole RoleType = "user" 86 | FunctionRole RoleType = "function" 87 | ToolRole RoleType = "tool" 88 | ) 89 | 90 | type MessageListResponse struct { 91 | // A list of message objects. 92 | Messages []Message `json:"messages"` 93 | // The total number of messages. 94 | TotalCount int `json:"total_count"` 95 | // The number of messages returned. 
96 | RowCount int `json:"row_count"` 97 | } 98 | -------------------------------------------------------------------------------- /src/api/apidata/session_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package apidata 3 | 4 | import "github.com/getzep/zep/models" 5 | 6 | func transformSession(in *models.Session, out *Session) {} 7 | 8 | func SessionSearchResultTransformer(result models.SessionSearchResult) SessionSearchResult { 9 | return SessionSearchResult{ 10 | SessionSearchResultCommon: SessionSearchResultCommon{ 11 | Fact: FactTransformerPtr(result.Fact), 12 | }, 13 | } 14 | } 15 | 16 | type Session struct { 17 | SessionCommon 18 | } 19 | 20 | type SessionSearchResult struct { 21 | SessionSearchResultCommon 22 | } 23 | -------------------------------------------------------------------------------- /src/api/apidata/session_common.go: -------------------------------------------------------------------------------- 1 | package apidata 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/google/uuid" 7 | 8 | "github.com/getzep/zep/models" 9 | ) 10 | 11 | func SessionListTransformer(sessions []*models.Session) []Session { 12 | transformedSessions := make([]Session, len(sessions)) 13 | for i, session := range sessions { 14 | transformedSessions[i] = SessionTransformer(session) 15 | } 16 | return transformedSessions 17 | } 18 | 19 | func SessionSearchResultListTransformer(result []models.SessionSearchResult) []SessionSearchResult { 20 | transformedResults := make([]SessionSearchResult, len(result)) 21 | for i, r := range result { 22 | transformedResults[i] = SessionSearchResultTransformer(r) 23 | } 24 | 25 | return transformedResults 26 | } 27 | 28 | func SessionTransformer(session *models.Session) Session { 29 | s := Session{ 30 | SessionCommon: SessionCommon{ 31 | UUID: session.UUID, 32 | ID: session.ID, 33 | CreatedAt: session.CreatedAt, 34 | UpdatedAt: session.UpdatedAt, 35 | DeletedAt: session.DeletedAt, 36 | EndedAt: 
session.EndedAt, 37 | SessionID: session.SessionID, 38 | Metadata: session.Metadata, 39 | UserID: session.UserID, 40 | }, 41 | } 42 | 43 | transformSession(session, &s) 44 | 45 | return s 46 | } 47 | 48 | type SessionCommon struct { 49 | UUID uuid.UUID `json:"uuid"` 50 | ID int64 `json:"id"` 51 | CreatedAt time.Time `json:"created_at"` 52 | UpdatedAt time.Time `json:"updated_at"` 53 | DeletedAt *time.Time `json:"deleted_at"` 54 | EndedAt *time.Time `json:"ended_at"` 55 | SessionID string `json:"session_id"` 56 | Metadata map[string]any `json:"metadata"` 57 | // Must be a pointer to allow for null values 58 | UserID *string `json:"user_id"` 59 | ProjectUUID uuid.UUID `json:"project_uuid"` 60 | } 61 | 62 | type SessionSearchResultCommon struct { 63 | Fact *Fact `json:"fact"` 64 | } 65 | 66 | type SessionSearchResponse struct { 67 | Results []SessionSearchResult `json:"results"` 68 | } 69 | 70 | type SessionListResponse struct { 71 | Sessions []Session `json:"sessions"` 72 | TotalCount int `json:"total_count"` 73 | RowCount int `json:"response_count"` 74 | } 75 | 76 | type CreateSessionRequestCommon struct { 77 | // The unique identifier of the session. 78 | SessionID string `json:"session_id" validate:"required"` 79 | // The unique identifier of the user associated with the session 80 | UserID *string `json:"user_id"` 81 | // The metadata associated with the session. 
82 | Metadata map[string]any `json:"metadata"` 83 | } 84 | -------------------------------------------------------------------------------- /src/api/apidata/user_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package apidata 3 | 4 | import "github.com/getzep/zep/models" 5 | 6 | type User struct { 7 | UserCommon 8 | } 9 | 10 | func transformUser(in *models.User, out *User) { 11 | } 12 | -------------------------------------------------------------------------------- /src/api/apidata/user_common.go: -------------------------------------------------------------------------------- 1 | package apidata 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/getzep/zep/models" 7 | "github.com/google/uuid" 8 | ) 9 | 10 | func UserTransformer(user *models.User) User { 11 | u := User{ 12 | UserCommon: UserCommon{ 13 | UUID: user.UUID, 14 | ID: user.ID, 15 | CreatedAt: user.CreatedAt, 16 | UpdatedAt: user.UpdatedAt, 17 | DeletedAt: user.DeletedAt, 18 | UserID: user.UserID, 19 | Email: user.Email, 20 | FirstName: user.FirstName, 21 | LastName: user.LastName, 22 | Metadata: user.Metadata, 23 | SessionCount: user.SessionCount, 24 | }, 25 | } 26 | 27 | transformUser(user, &u) 28 | 29 | return u 30 | } 31 | 32 | func UserListTransformer(users []*models.User) []User { 33 | userList := make([]User, len(users)) 34 | for i, user := range users { 35 | u := user 36 | userList[i] = UserTransformer(u) 37 | } 38 | return userList 39 | } 40 | 41 | type UserCommon struct { 42 | UUID uuid.UUID `json:"uuid"` 43 | ID int64 `json:"id"` 44 | CreatedAt time.Time `json:"created_at"` 45 | UpdatedAt time.Time `json:"updated_at"` 46 | DeletedAt *time.Time `json:"deleted_at"` 47 | UserID string `json:"user_id"` 48 | Email string `json:"email,omitempty"` 49 | FirstName string `json:"first_name,omitempty"` 50 | LastName string `json:"last_name,omitempty"` 51 | ProjectUUID uuid.UUID `json:"project_uuid"` 52 | Metadata map[string]any `json:"metadata,omitempty"` 53 | 
SessionCount int `json:"session_count,omitempty"` 54 | } 55 | 56 | type UserListResponse struct { 57 | Users []User `json:"users"` 58 | TotalCount int `json:"total_count"` 59 | RowCount int `json:"row_count"` 60 | } 61 | -------------------------------------------------------------------------------- /src/api/apihandlers/fact_handlers_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package apihandlers 3 | 4 | import ( 5 | "context" 6 | 7 | "github.com/google/uuid" 8 | 9 | "github.com/getzep/zep/lib/graphiti" 10 | "github.com/getzep/zep/models" 11 | ) 12 | 13 | func getFact(ctx context.Context, factUUID uuid.UUID, _ *models.RequestState) (*models.Fact, error) { 14 | graphFact, err := graphiti.I().GetFact(ctx, factUUID) 15 | if err != nil { 16 | return nil, err 17 | } 18 | 19 | return &models.Fact{ 20 | UUID: graphFact.UUID, 21 | Fact: graphFact.Fact, 22 | CreatedAt: graphFact.ExtractCreatedAt(), 23 | }, nil 24 | } 25 | 26 | func deleteSessionFact(ctx context.Context, factUUID uuid.UUID, _ *models.RequestState) error { 27 | return graphiti.I().DeleteFact(ctx, factUUID) 28 | } 29 | -------------------------------------------------------------------------------- /src/api/apihandlers/fact_handlers_common.go: -------------------------------------------------------------------------------- 1 | package apihandlers 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | 9 | "github.com/go-chi/chi/v5" 10 | "github.com/google/uuid" 11 | 12 | "github.com/getzep/zep/api/apidata" 13 | "github.com/getzep/zep/api/handlertools" 14 | "github.com/getzep/zep/lib/observability" 15 | "github.com/getzep/zep/lib/zerrors" 16 | "github.com/getzep/zep/models" 17 | ) 18 | 19 | // GetFactHandler godoc 20 | // 21 | // @Summary Returns a fact by UUID 22 | // @Description get fact by uuid 23 | // @Tags fact 24 | // @Accept json 25 | // @Produce json 26 | // @Param factUUID path string true "Fact UUID" 27 | // @Success 200 {object} 
apidata.FactResponse "The fact with the specified UUID." 28 | // @Failure 404 {object} apidata.APIError "Not Found" 29 | // @Failure 500 {object} apidata.APIError "Internal Server Error" 30 | // @Security Bearer 31 | // @x-fern-audiences ["cloud", "community"] 32 | // @Router /facts/{factUUID} [get] 33 | func GetFactHandler(as *models.AppState) http.HandlerFunc { // nolint:dupl // not duplicate 34 | return func(w http.ResponseWriter, r *http.Request) { 35 | rs, err := handlertools.NewRequestState(r, as) 36 | if err != nil { 37 | handlertools.HandleErrorRequestState(w, err) 38 | return 39 | } 40 | 41 | factUUIDValue, err := url.PathUnescape(chi.URLParam(r, "factUUID")) 42 | if err != nil { 43 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 44 | return 45 | } 46 | 47 | observability.I().CaptureBreadcrumb( 48 | observability.Category_Facts, 49 | "get_fact", 50 | ) 51 | 52 | factUUID, err := uuid.Parse(factUUIDValue) 53 | if err != nil { 54 | handlertools.LogAndRenderError(w, fmt.Errorf("not found"), http.StatusNotFound) 55 | return 56 | } 57 | 58 | fact, err := getFact(r.Context(), factUUID, rs) 59 | if err != nil { 60 | if errors.Is(err, zerrors.ErrNotFound) { 61 | handlertools.LogAndRenderError(w, fmt.Errorf("not found"), http.StatusNotFound) 62 | return 63 | } 64 | 65 | handlertools.LogAndRenderError(w, err, http.StatusInternalServerError) 66 | return 67 | } 68 | 69 | resp := apidata.FactResponse{ 70 | Fact: apidata.FactTransformer(*fact), 71 | } 72 | 73 | if err := handlertools.EncodeJSON(w, resp); err != nil { 74 | handlertools.LogAndRenderError(w, err, http.StatusInternalServerError) 75 | return 76 | } 77 | } 78 | } 79 | 80 | // DeleteFactHandler godoc 81 | // 82 | // @Summary Delete a fact for the given UUID 83 | // @Description delete a fact 84 | // @Tags fact 85 | // @Accept json 86 | // @Produce json 87 | // @Param factUUID path string true "Fact UUID" 88 | // @Success 200 {string} apidata.SuccessResponse "Deleted" 89 | // @Failure 404
{object} apidata.APIError "Not Found" 90 | // @Failure 500 {object} apidata.APIError "Internal Server Error" 91 | // @Security Bearer 92 | // @x-fern-audiences ["cloud", "community"] 93 | // @Router /facts/{factUUID} [delete] 94 | func DeleteFactHandler(as *models.AppState) http.HandlerFunc { // nolint:dupl // not duplicate 95 | return func(w http.ResponseWriter, r *http.Request) { 96 | rs, err := handlertools.NewRequestState(r, as) 97 | if err != nil { 98 | handlertools.HandleErrorRequestState(w, err) 99 | return 100 | } 101 | 102 | factUUIDValue, err := url.PathUnescape(chi.URLParam(r, "factUUID")) 103 | if err != nil { 104 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 105 | return 106 | } 107 | 108 | observability.I().CaptureBreadcrumb( 109 | observability.Category_Facts, 110 | "delete_fact", 111 | ) 112 | 113 | factUUID, err := uuid.Parse(factUUIDValue) 114 | if err != nil { 115 | handlertools.LogAndRenderError(w, fmt.Errorf("not found"), http.StatusNotFound) 116 | return 117 | } 118 | 119 | err = deleteSessionFact(r.Context(), factUUID, rs) 120 | if err != nil { 121 | if errors.Is(err, zerrors.ErrNotFound) { 122 | handlertools.LogAndRenderError(w, fmt.Errorf("not found"), http.StatusNotFound) 123 | return 124 | } 125 | 126 | handlertools.LogAndRenderError(w, err, http.StatusInternalServerError) 127 | return 128 | } 129 | 130 | w.WriteHeader(http.StatusOK) 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /src/api/apihandlers/memory_handlers_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package apihandlers 3 | 4 | import ( 5 | "context" 6 | "net/http" 7 | 8 | "github.com/getzep/zep/api/apidata" 9 | "github.com/getzep/zep/lib/graphiti" 10 | 11 | "github.com/getzep/zep/models" 12 | ) 13 | 14 | func putMemory(r *http.Request, rs *models.RequestState, sessionID string, memory apidata.AddMemoryRequest) error { 15 | return rs.Memories.PutMemory( 16 
| r.Context(), 17 | sessionID, 18 | &models.Memory{ 19 | MemoryCommon: models.MemoryCommon{ 20 | Messages: apidata.MessagesToModelMessagesTransformer(memory.Messages), 21 | }, 22 | }, 23 | false, /* skipNotify */ 24 | ) 25 | } 26 | 27 | func extractMemoryFilterOptions(_ *http.Request) ([]models.MemoryFilterOption, error) { 28 | var memoryOptions []models.MemoryFilterOption 29 | 30 | return memoryOptions, nil 31 | } 32 | 33 | func deleteMemory(ctx context.Context, sessionID string, rs *models.RequestState) error { 34 | mList, err := rs.Memories.GetMessageList(ctx, sessionID, 0, 1) 35 | if err != nil { 36 | return err 37 | } 38 | totalSize := mList.TotalCount 39 | if totalSize == 0 { 40 | return rs.Memories.DeleteSession(ctx, sessionID) 41 | } 42 | err = graphiti.I().DeleteGroup(ctx, sessionID) 43 | if err != nil { 44 | return err 45 | } 46 | 47 | return rs.Memories.DeleteSession(ctx, sessionID) 48 | } 49 | -------------------------------------------------------------------------------- /src/api/apihandlers/memory_handlers_common.go: -------------------------------------------------------------------------------- 1 | package apihandlers 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | "unicode/utf8" 9 | 10 | "github.com/go-chi/chi/v5" 11 | 12 | "github.com/getzep/zep/api/apidata" 13 | "github.com/getzep/zep/api/handlertools" 14 | "github.com/getzep/zep/lib/observability" 15 | "github.com/getzep/zep/lib/zerrors" 16 | "github.com/getzep/zep/models" 17 | ) 18 | 19 | const ( 20 | maxMessagesPerMemory = 30 21 | maxMessageLength = 2500 22 | maxLongMessageLength = 100_000 23 | DefaultLastNMessages = 6 24 | ) 25 | 26 | // GetMemoryHandler godoc 27 | // 28 | // @Summary Get session memory 29 | // @Description Returns a memory (latest summary, list of messages and facts) for a given session 30 | // @Tags memory 31 | // @Accept json 32 | // @Produce json 33 | // @Param sessionId path string true "The ID of the session for which to retrieve memory." 
34 | // @Param lastn query integer false "The number of most recent memory entries to retrieve." 35 | // @Param minRating query float64 false "The minimum rating by which to filter facts" 36 | // @Success 200 {object} apidata.Memory 37 | // @Failure 404 {object} apidata.APIError "Not Found" 38 | // @Failure 500 {object} apidata.APIError "Internal Server Error" 39 | // @Security Bearer 40 | // @x-fern-audiences ["cloud", "community"] 41 | // 42 | // @Router /sessions/{sessionId}/memory [get] 43 | func GetMemoryHandler(as *models.AppState) http.HandlerFunc { 44 | return func(w http.ResponseWriter, r *http.Request) { 45 | rs, err := handlertools.NewRequestState(r, as) 46 | if err != nil { 47 | handlertools.HandleErrorRequestState(w, err) 48 | return 49 | } 50 | 51 | sessionID, err := url.PathUnescape(chi.URLParam(r, "sessionId")) 52 | if err != nil { 53 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 54 | return 55 | } 56 | 57 | lastN, err := handlertools.IntFromQuery[int](r, "lastn") 58 | if err != nil { 59 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 60 | return 61 | } 62 | 63 | if lastN < 0 { 64 | handlertools.LogAndRenderError(w, fmt.Errorf("lastn cannot be negative"), http.StatusBadRequest) 65 | return 66 | } 67 | 68 | memoryOptions, err := extractMemoryFilterOptions(r) 69 | if err != nil { 70 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 71 | return 72 | } 73 | 74 | // if lastN is 0 (or lastn was omitted), fall back to DefaultLastNMessages 75 | if lastN == 0 { 76 | lastN = DefaultLastNMessages 77 | } 78 | 79 | observability.I().CaptureBreadcrumb( 80 | observability.Category_Sessions, 81 | "get_memory", 82 | map[string]any{ 83 | "last_n": lastN, 84 | }, 85 | ) 86 | 87 | sessionMemory, err := rs.Memories.GetMemory(r.Context(), sessionID, lastN, memoryOptions...)
88 | if err != nil { 89 | handlertools.HandleErrorRequestState(w, err) 90 | return 91 | } 92 | 93 | if sessionMemory == nil || sessionMemory.Messages == nil { 94 | handlertools.LogAndRenderError(w, fmt.Errorf("not found"), http.StatusNotFound) 95 | return 96 | } 97 | 98 | resp := apidata.MemoryTransformer(sessionMemory) 99 | 100 | if err := handlertools.EncodeJSON(w, resp); err != nil { 101 | handlertools.LogAndRenderError(w, err, http.StatusInternalServerError) 102 | return 103 | } 104 | } 105 | } 106 | 107 | // PostMemoryHandler godoc 108 | // 109 | // @Summary Add memory to the specified session. 110 | // @Description Add memory to the specified session. 111 | // @Tags memory 112 | // @Accept json 113 | // @Produce json 114 | // @Param sessionId path string true "The ID of the session to which memory should be added." 115 | // @Param memoryMessages body apidata.AddMemoryRequest true "A Memory object representing the memory messages to be added." 116 | // @Success 201 {object} apidata.SuccessResponse "Created" 117 | // @Failure 500 {object} apidata.APIError "Internal Server Error" 118 | // @Security Bearer 119 | // @x-fern-audiences ["cloud", "community"] 120 | // 121 | // @Router /sessions/{sessionId}/memory [post] 122 | func PostMemoryHandler(as *models.AppState) http.HandlerFunc { 123 | return func(w http.ResponseWriter, r *http.Request) { 124 | rs, err := handlertools.NewRequestState(r, as) 125 | if err != nil { 126 | handlertools.HandleErrorRequestState(w, err) 127 | return 128 | } 129 | 130 | sessionID, err := url.PathUnescape(chi.URLParam(r, "sessionId")) 131 | if err != nil { 132 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 133 | return 134 | } 135 | 136 | var memoryMessages apidata.AddMemoryRequest 137 | if err = handlertools.DecodeJSON(r, &memoryMessages); err != nil { 138 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 139 | return 140 | } 141 | 142 | if len(memoryMessages.Messages) > maxMessagesPerMemory { 143 |
maxMemoryError := fmt.Errorf( 144 | "max messages per memory of %d exceeded. reduce the number of messages in your request", 145 | maxMessagesPerMemory, 146 | ) 147 | handlertools.LogAndRenderError(w, maxMemoryError, http.StatusBadRequest) 148 | return 149 | } 150 | l := maxMessageLength 151 | if !rs.EnablementProfile.Plan.IsFree() { 152 | l = maxLongMessageLength 153 | } 154 | 155 | for i := range memoryMessages.Messages { 156 | if utf8.RuneCountInString(memoryMessages.Messages[i].Content) > l { 157 | err := fmt.Errorf("message content exceeds %d characters", l) 158 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 159 | return 160 | } 161 | } 162 | 163 | for i := range memoryMessages.Messages { 164 | if memoryMessages.Messages[i].RoleType == "" { 165 | handlertools.LogAndRenderError(w, fmt.Errorf("messages are required to have a RoleType"), http.StatusBadRequest) 166 | return 167 | } 168 | } 169 | 170 | if err := handlertools.Validate.Struct(memoryMessages); err != nil { 171 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 172 | return 173 | } 174 | 175 | observability.I().CaptureBreadcrumb( 176 | observability.Category_Sessions, 177 | "post_memory", 178 | ) 179 | 180 | if err := putMemory(r, rs, sessionID, memoryMessages); err != nil { 181 | handlertools.HandleErrorRequestState(w, err) 182 | return 183 | } 184 | 185 | handlertools.JSONOK(w, http.StatusCreated) 186 | } 187 | } 188 | 189 | // DeleteMemoryHandler godoc 190 | // 191 | // @Summary Delete memory messages for a given session 192 | // @Description delete memory messages by session id 193 | // @Tags memory 194 | // @Accept json 195 | // @Produce json 196 | // @Param sessionId path string true "The ID of the session for which memory should be deleted."
197 | // @Success 200 {object} apidata.SuccessResponse "OK" 198 | // @Failure 404 {object} apidata.APIError "Not Found" 199 | // @Failure 500 {object} apidata.APIError "Internal Server Error" 200 | // @Security Bearer 201 | // @x-fern-audiences ["cloud", "community"] 202 | // 203 | // @Router /sessions/{sessionId}/memory [delete] 204 | func DeleteMemoryHandler(as *models.AppState) http.HandlerFunc { 205 | return func(w http.ResponseWriter, r *http.Request) { 206 | rs, err := handlertools.NewRequestState(r, as) 207 | if err != nil { 208 | handlertools.HandleErrorRequestState(w, err) 209 | return 210 | } 211 | 212 | sessionID, err := url.PathUnescape(chi.URLParam(r, "sessionId")) 213 | if err != nil { 214 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 215 | return 216 | } 217 | 218 | observability.I().CaptureBreadcrumb( 219 | observability.Category_Sessions, 220 | "delete_memory", 221 | ) 222 | if err := deleteMemory(r.Context(), sessionID, rs); err != nil { 223 | if errors.Is(err, zerrors.ErrNotFound) { 224 | handlertools.LogAndRenderError(w, fmt.Errorf("not found"), http.StatusNotFound) 225 | return 226 | } 227 | 228 | handlertools.LogAndRenderError(w, err, http.StatusInternalServerError) 229 | return 230 | } 231 | 232 | handlertools.JSONOK(w, http.StatusOK) 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /src/api/apihandlers/message_handlers.go: -------------------------------------------------------------------------------- 1 | package apihandlers 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "net/url" 9 | 10 | "github.com/go-chi/chi/v5" 11 | "github.com/google/uuid" 12 | 13 | "github.com/getzep/zep/api/apidata" 14 | "github.com/getzep/zep/api/handlertools" 15 | "github.com/getzep/zep/lib/observability" 16 | "github.com/getzep/zep/lib/zerrors" 17 | "github.com/getzep/zep/models" 18 | ) 19 | 20 | const defaultMessageLimit = 100 21 | 22 | // 
UpdateMessageMetadataHandler updates the metadata of a message. 23 | // 24 | // @Summary Updates the metadata of a message. 25 | // @Description Updates the metadata of a message. 26 | // @Tags messages 27 | // @Accept json 28 | // @Produce json 29 | // @Param sessionId path string true "The ID of the session." 30 | // @Param messageUUID path string true "The UUID of the message." 31 | // @Param body body models.MessageMetadataUpdate true "The metadata to update." 32 | // @Success 200 {object} apidata.Message "The updated message." 33 | // @Failure 404 {object} apidata.APIError "Not Found" 34 | // @Failure 500 {object} apidata.APIError "Internal Server Error" 35 | // @x-fern-audiences ["cloud", "community"] 36 | // @Router /sessions/{sessionId}/messages/{messageUUID} [patch] 37 | func UpdateMessageMetadataHandler(as *models.AppState) http.HandlerFunc { 38 | return func(w http.ResponseWriter, r *http.Request) { 39 | rs, err := handlertools.NewRequestState(r, as) 40 | if err != nil { 41 | handlertools.HandleErrorRequestState(w, err) 42 | return 43 | } 44 | 45 | sessionID, err := url.PathUnescape(chi.URLParam(r, "sessionId")) 46 | if err != nil { 47 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 48 | return 49 | } 50 | 51 | messageUUID := handlertools.UUIDFromURL(r, w, "messageUUID") 52 | if messageUUID == uuid.Nil { 53 | handlertools.LogAndRenderError(w, zerrors.NewBadRequestError("messageUUID is required"), http.StatusBadRequest) 54 | return 55 | } 56 | 57 | var messageUpdate models.MessageMetadataUpdate 58 | 59 | err = json.NewDecoder(r.Body).Decode(&messageUpdate) 60 | if err != nil { 61 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 62 | return 63 | } 64 | 65 | message := models.Message{ 66 | UUID: messageUUID, 67 | Metadata: messageUpdate.Metadata, 68 | } 69 | 70 | observability.I().CaptureBreadcrumb( 71 | observability.Category_Messages, 72 | "update_message_metadata", 73 | map[string]any{ 74 | "message_uuid": messageUUID, 75 | }, 76 | )
77 | 78 | err = rs.Memories.UpdateMessages( 79 | r.Context(), sessionID, []models.Message{message}, false, false, 80 | ) 81 | if err != nil { 82 | handlertools.HandleErrorRequestState(w, err) 83 | return 84 | } 85 | 86 | messages, err := rs.Memories.GetMessagesByUUID(r.Context(), sessionID, []uuid.UUID{messageUUID}) 87 | if err != nil { 88 | handlertools.HandleErrorRequestState(w, err) 89 | return 90 | } 91 | 92 | resp := apidata.MessageTransformer(messages[0]) 93 | 94 | if err := handlertools.EncodeJSON(w, resp); err != nil { 95 | handlertools.LogAndRenderError(w, err, http.StatusInternalServerError) 96 | return 97 | } 98 | } 99 | } 100 | 101 | // GetMessageHandler retrieves a specific message. 102 | // 103 | // @Summary Gets a specific message from a session 104 | // @Description Gets a specific message from a session 105 | // @Tags messages 106 | // @Accept json 107 | // @Produce json 108 | // @Param sessionId path string true "The ID of the session." 109 | // @Param messageUUID path string true "The UUID of the message." 110 | // @Success 200 {object} apidata.Message "The message." 
111 | // @Failure 404 {object} apidata.APIError "Not Found" 112 | // @Failure 500 {object} apidata.APIError "Internal Server Error" 113 | // @x-fern-audiences ["cloud", "community"] 114 | // @Router /sessions/{sessionId}/messages/{messageUUID} [get] 115 | func GetMessageHandler(as *models.AppState) http.HandlerFunc { 116 | return func(w http.ResponseWriter, r *http.Request) { 117 | rs, err := handlertools.NewRequestState(r, as) 118 | if err != nil { 119 | handlertools.HandleErrorRequestState(w, err) 120 | return 121 | } 122 | 123 | sessionID, err := url.PathUnescape(chi.URLParam(r, "sessionId")) 124 | if err != nil { 125 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 126 | return 127 | } 128 | 129 | messageUUID := handlertools.UUIDFromURL(r, w, "messageUUID") 130 | messageIDs := []uuid.UUID{messageUUID} 131 | 132 | observability.I().CaptureBreadcrumb( 133 | observability.Category_Messages, 134 | "get_message", 135 | map[string]any{ 136 | "message_uuid": messageUUID, 137 | }, 138 | ) 139 | 140 | messages, err := rs.Memories.GetMessagesByUUID(r.Context(), sessionID, messageIDs) 141 | if err != nil { 142 | if errors.Is(err, zerrors.ErrNotFound) { 143 | handlertools.LogAndRenderError(w, fmt.Errorf("not found"), http.StatusNotFound) 144 | return 145 | } 146 | 147 | handlertools.LogAndRenderError(w, err, http.StatusInternalServerError) 148 | return 149 | } 150 | 151 | if len(messages) == 0 { 152 | handlertools.LogAndRenderError(w, fmt.Errorf("no message found for UUID"), http.StatusNotFound) 153 | return 154 | } 155 | 156 | resp := apidata.MessageTransformer(messages[0]) 157 | 158 | if err := handlertools.EncodeJSON(w, resp); err != nil { 159 | handlertools.LogAndRenderError(w, err, http.StatusInternalServerError) 160 | return 161 | } 162 | } 163 | } 164 | 165 | // GetMessagesForSessionHandler retrieves all messages for a specific session. 
166 | // 167 | // @Summary Lists messages for a session 168 | // @Description Lists messages for a session, specified by limit and cursor. 169 | // @Tags messages 170 | // @Accept json 171 | // @Produce json 172 | // @Param sessionId path string true "Session ID" 173 | // @Param limit query integer false "Limit the number of results returned" 174 | // @Param cursor query int64 false "Cursor for pagination" 175 | // @Success 200 {object} apidata.MessageListResponse 176 | // @Failure 404 {object} apidata.APIError "Not Found" 177 | // @Failure 500 {object} apidata.APIError "Internal Server Error" 178 | // @x-fern-audiences ["cloud", "community"] 179 | // @Router /sessions/{sessionId}/messages [get] 180 | func GetMessagesForSessionHandler(as *models.AppState) http.HandlerFunc { 181 | return func(w http.ResponseWriter, r *http.Request) { 182 | rs, err := handlertools.NewRequestState(r, as) 183 | if err != nil { 184 | handlertools.HandleErrorRequestState(w, err) 185 | return 186 | } 187 | 188 | sessionID, err := url.PathUnescape(chi.URLParam(r, "sessionId")) 189 | if err != nil { 190 | handlertools.LogAndRenderError(w, err, http.StatusBadRequest) 191 | return 192 | } 193 | 194 | limit, err := handlertools.IntFromQuery[int](r, "limit") 195 | if err != nil { 196 | limit = defaultMessageLimit 197 | } 198 | 199 | cursor, err := handlertools.IntFromQuery[int](r, "cursor") 200 | if err != nil { 201 | cursor = 1 202 | } 203 | 204 | observability.I().CaptureBreadcrumb( 205 | observability.Category_Messages, 206 | "get_messages_for_session", 207 | map[string]any{ 208 | "cursor": cursor, 209 | "limit": limit, 210 | }, 211 | ) 212 | 213 | messages, err := rs.Memories.GetMessageList(r.Context(), sessionID, cursor, limit) 214 | if err != nil { 215 | handlertools.LogAndRenderError(w, err, http.StatusInternalServerError) 216 | return 217 | } 218 | 219 | resp := apidata.MessageListResponse{ 220 | Messages: apidata.MessageListTransformer(messages.Messages), 221 | TotalCount: 
messages.TotalCount, 222 | RowCount: messages.RowCount, 223 | } 224 | 225 | if err := handlertools.EncodeJSON(w, resp); err != nil { 226 | handlertools.LogAndRenderError(w, err, http.StatusInternalServerError) 227 | return 228 | } 229 | } 230 | } 231 | -------------------------------------------------------------------------------- /src/api/handlertools/request_state_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package handlertools 3 | 4 | import ( 5 | "net/http" 6 | 7 | "github.com/getzep/zep/lib/config" 8 | "github.com/getzep/zep/models" 9 | "github.com/getzep/zep/store" 10 | ) 11 | 12 | func NewRequestState(r *http.Request, as *models.AppState, opts ...RequestStateOption) (*models.RequestState, error) { 13 | options := &requestStateOptions{} 14 | for _, opt := range opts { 15 | opt.apply(options) 16 | } 17 | 18 | rs := &models.RequestState{} 19 | 20 | rs.SchemaName = config.Postgres().SchemaName 21 | rs.ProjectUUID = config.ProjectUUID() 22 | rs.Memories = store.NewMemoryStore(as, rs) 23 | rs.Sessions = store.NewSessionDAO(as, rs) 24 | rs.Users = store.NewUserStore(as, rs) 25 | 26 | return rs, nil 27 | } 28 | -------------------------------------------------------------------------------- /src/api/handlertools/tools.go: -------------------------------------------------------------------------------- 1 | package handlertools 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "regexp" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/go-chi/chi/v5" 13 | "github.com/go-playground/validator/v10" 14 | "github.com/google/uuid" 15 | 16 | "github.com/getzep/zep/api/apidata" 17 | "github.com/getzep/zep/lib/logger" 18 | "github.com/getzep/zep/lib/observability" 19 | "github.com/getzep/zep/lib/zerrors" 20 | ) 21 | 22 | const ( 23 | RequestIDHeader = "X-Zep-Request-ID" 24 | RequestIDKey = "_zep_req_id" 25 | ) 26 | 27 | var Validate = validator.New() 28 | 29 | func AlphanumericWithUnderscores(fl 
validator.FieldLevel) bool { 30 | name := fl.Field().String() 31 | return regexp.MustCompile("^[a-zA-Z0-9_]+$").MatchString(name) 32 | } 33 | 34 | func NonEmptyStrings(fl validator.FieldLevel) bool { 35 | slice, ok := fl.Field().Interface().([]string) 36 | if !ok { 37 | return false 38 | } 39 | for _, s := range slice { 40 | if s == "" { 41 | return false 42 | } 43 | } 44 | return true 45 | } 46 | 47 | func RegisterValidations(validations map[string]func(fl validator.FieldLevel) bool) error { 48 | for name, validationFunc := range validations { 49 | if err := Validate.RegisterValidation(name, validationFunc); err != nil { 50 | logger.Error("Error registering validation", "name", name, "error", err) 51 | return fmt.Errorf("registering validation %q: %w", name, err) 52 | } 53 | } 54 | return nil 55 | } 56 | 57 | func DecodeAndValidateJSON(r *http.Request, v any) error { 58 | if err := DecodeJSON(r, v); err != nil { 59 | return err 60 | } 61 | return Validate.Struct(v) 62 | } 63 | 64 | func HandleErrorRequestState(w http.ResponseWriter, err error) { 65 | switch { 66 | case errors.Is(err, zerrors.ErrNotFound): 67 | LogAndRenderError(w, fmt.Errorf("not found"), http.StatusNotFound) 68 | case errors.Is(err, zerrors.ErrBadRequest): 69 | LogAndRenderError(w, err, http.StatusBadRequest) 70 | case errors.Is(err, zerrors.ErrUnauthorized): 71 | LogAndRenderError(w, err, http.StatusUnauthorized) 72 | case errors.Is(err, zerrors.ErrDeprecated): 73 | LogAndRenderError(w, err, http.StatusGone) 74 | case errors.Is(err, zerrors.ErrLockAcquisitionFailed): 75 | LogAndRenderError(w, err, http.StatusTooManyRequests) 76 | case errors.Is(err, zerrors.ErrSessionEnded): 77 | LogAndRenderError(w, err, http.StatusConflict) 78 | default: 79 | LogAndRenderError(w, err, http.StatusInternalServerError) 80 | } 81 | } 82 | 83 | type requestStateOptions struct { 84 | noCache bool 85 | // Indicates whether the handler is public (i.e.
uses token with zmiddleware.PublicKeyAuthorizationPrefix) 86 | publicHandler bool 87 | } 88 | 89 | type RequestStateOption interface { 90 | apply(*requestStateOptions) 91 | } 92 | 93 | type noCacheRequestStateOption bool 94 | 95 | func (r noCacheRequestStateOption) apply(opts *requestStateOptions) { 96 | opts.noCache = bool(r) 97 | } 98 | 99 | func WithoutFlagCache(c bool) RequestStateOption { 100 | return noCacheRequestStateOption(c) 101 | } 102 | 103 | type publicHandlerRequestStateOption bool 104 | 105 | func (r publicHandlerRequestStateOption) apply(opts *requestStateOptions) { 106 | opts.publicHandler = bool(r) 107 | } 108 | 109 | func PublicHandler(c bool) RequestStateOption { 110 | return publicHandlerRequestStateOption(c) 111 | } 112 | 113 | // IntFromQuery extracts a query string value and converts it to an integer of 114 | // type T if it is not empty. If the value is empty, it returns 0. 115 | func IntFromQuery[T ~int | ~int32 | ~int64]( 116 | r *http.Request, 117 | param string, 118 | ) (T, error) { 119 | bitsize := 0 120 | 121 | p := strings.TrimSpace(r.URL.Query().Get(param)) 122 | var pInt T 123 | if p != "" { 124 | switch any(pInt).(type) { 125 | case int: // bitsize 0 parses using the platform's int size 126 | case int32: 127 | bitsize = 32 //nolint:revive // 32 is the size of an int32 128 | case int64: 129 | bitsize = 64 //nolint:revive // 64 is the size of an int64 130 | default: 131 | return 0, errors.New("unsupported type") 132 | } 133 | 134 | parsed, err := strconv.ParseInt(p, 10, bitsize) //nolint:revive // 10 is the base 135 | if err != nil { 136 | return 0, err 137 | } 138 | return T(parsed), nil 139 | } 140 | return 0, nil 141 | } 142 | 143 | func FloatFromQuery[T ~float32 | ~float64](r *http.Request, param string) (T, error) { 144 | p := strings.TrimSpace(r.URL.Query().Get(param)) 145 | if p == "" { 146 | return 0, nil 147 | } 148 | 149 | var ft T 150 | var bitsize int 151 | switch any(ft).(type) { 152 | case float32: 153 | bitsize = 32 //nolint:revive // 32 is the size of a float32 154 | case float64: 155
| bitsize = 64 //nolint:revive // 64 is the size of a float64 156 | default: 157 | return 0, errors.New("unsupported type") 158 | } 159 | 160 | pf, err := strconv.ParseFloat(p, bitsize) 161 | if err != nil { 162 | return 0, err 163 | } 164 | return T(pf), nil 165 | } 166 | 167 | // BoolFromQuery extracts a query string value and converts it to a bool 168 | func BoolFromQuery(r *http.Request, param string) (bool, error) { 169 | p := strings.TrimSpace(r.URL.Query().Get(param)) 170 | if p != "" { 171 | return strconv.ParseBool(p) 172 | } 173 | return false, nil 174 | } 175 | 176 | // BoundedStringFromQuery extracts a query string value and checks if it is one of the provided options. 177 | func BoundedStringFromQuery(r *http.Request, param string, options []string) (string, error) { 178 | p := strings.TrimSpace(r.URL.Query().Get(param)) 179 | if p == "" { 180 | return "", nil 181 | } 182 | for _, option := range options { 183 | if p == option { 184 | return p, nil 185 | } 186 | } 187 | return "", fmt.Errorf("invalid value for %s", param) 188 | } 189 | 190 | // EncodeJSON encodes data into JSON and writes it to the response writer. 191 | func EncodeJSON(w http.ResponseWriter, data any) error { 192 | return json.NewEncoder(w).Encode(data) 193 | } 194 | 195 | // DecodeJSON decodes a JSON request body into the provided data struct. 
196 | func DecodeJSON(r *http.Request, data any) error { 197 | return json.NewDecoder(r.Body).Decode(data) 198 | } 199 | 200 | func JSONError(w http.ResponseWriter, e error, code int) { 201 | w.Header().Set("Content-Type", "application/json; charset=utf-8") 202 | w.Header().Set("X-Content-Type-Options", "nosniff") 203 | w.WriteHeader(code) 204 | errorResponse := zerrors.ErrorResponse{ 205 | Message: e.Error(), 206 | } 207 | if err := EncodeJSON(w, errorResponse); err != nil { 208 | // the status line is already written, so only log the failure 209 | logger.Error("failed to encode JSON error response", "error", err) 210 | } 211 | } 212 | 213 | func JSONOK(w http.ResponseWriter, code int) { 214 | w.Header().Set("Content-Type", "application/json; charset=utf-8") 215 | w.Header().Set("X-Content-Type-Options", "nosniff") 216 | w.WriteHeader(code) 217 | r := apidata.SuccessResponse{ 218 | Message: "OK", 219 | } 220 | if err := EncodeJSON(w, r); err != nil { 221 | // the status line is already written, so only log the failure 222 | logger.Error("failed to encode JSON response", "error", err) 223 | } 224 | } 225 | 226 | // LogAndRenderError logs, sanitizes, and renders an error response.
227 | func LogAndRenderError(w http.ResponseWriter, err error, status int) { 228 | // report server errors (status >= 500) to the observability service 229 | if status >= http.StatusInternalServerError { 230 | if errors.Is(err, zerrors.ErrInternalCustomMessage) { 231 | var customMsgInternalErr *zerrors.CustomMessageInternalError 232 | if errors.As(err, &customMsgInternalErr) { 233 | observability.I().CaptureError("Custom message internal error", errors.New(customMsgInternalErr.InternalMessage)) 234 | } 235 | } else { 236 | observability.I().CaptureError("Internal server error", err) 237 | } 238 | } 239 | 240 | // Map oversized request bodies (the *http.MaxBytesError returned by 241 | // http.MaxBytesReader, which chi's RequestSize middleware uses) to a 413 242 | var maxBytesErr *http.MaxBytesError 243 | if errors.As(err, &maxBytesErr) { 244 | status = http.StatusRequestEntityTooLarge 245 | err = errors.New("request body too large") 246 | } 247 | 248 | // sanitize error if it is an auth error 249 | if status == http.StatusUnauthorized { 250 | err = zerrors.ErrUnauthorized 251 | } 252 | 253 | // If the error is a bad request, return a 400 254 | if strings.Contains(err.Error(), "is deleted") || errors.Is(err, zerrors.ErrBadRequest) { 255 | status = http.StatusBadRequest 256 | } 257 | 258 | // Handle too many requests error 259 | if status == http.StatusTooManyRequests { 260 | err = errors.New("too many concurrent writes to the same record") 261 | } 262 | 263 | JSONError(w, err, status) 264 | } 265 | 266 | // UUIDFromURL parses a UUID from a Path parameter. If the UUID is invalid, an error is 267 | // rendered and uuid.Nil is returned.
268 | func UUIDFromURL(r *http.Request, w http.ResponseWriter, paramName string) uuid.UUID { 269 | value := chi.URLParam(r, paramName) 270 | 271 | objUUID, err := uuid.Parse(value) 272 | if err != nil { 273 | LogAndRenderError( 274 | w, 275 | fmt.Errorf("unable to parse UUID: %w", err), 276 | http.StatusBadRequest, 277 | ) 278 | return uuid.Nil 279 | } 280 | 281 | return objUUID 282 | } 283 | 284 | func ExtractPaginationFromRequest(r *http.Request) (pNum, pSize int, pErr error) { 285 | pageNumber, err := IntFromQuery[int](r, "pageNumber") 286 | if err != nil { 287 | return 0, 0, err 288 | } 289 | 290 | pageSize, err := IntFromQuery[int](r, "pageSize") 291 | if err != nil { 292 | return 0, 0, err 293 | } 294 | 295 | return pageNumber, pageSize, nil 296 | } 297 | -------------------------------------------------------------------------------- /src/api/middleware/auth.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | type ZepContextKey string 4 | 5 | const ( 6 | UserId ZepContextKey = "user_id" 7 | ProjectId ZepContextKey = "project_id" 8 | 9 | RequestTokenType ZepContextKey = "request_token_type" 10 | ) 11 | 12 | const BearerRequestTokenType = "bearer" 13 | 14 | const ( 15 | apiKeyAuthorizationPrefix = "Api-Key" 16 | ) 17 | -------------------------------------------------------------------------------- /src/api/middleware/secret_key_auth_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package middleware 3 | 4 | import ( 5 | "context" 6 | "crypto/subtle" 7 | "net/http" 8 | "strings" 9 | 10 | "github.com/getzep/zep/lib/config" 11 | ) 12 | 13 | const secretKeyRequestTokenType = "secret-key" 14 | 15 | func SecretKeyAuthMiddleware(next http.Handler) http.Handler { 16 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 17 | authHeader := r.Header.Get("Authorization") 18 | parts := strings.Split(authHeader, " ") 19 | if 
len(parts) != 2 { 20 | http.Error(w,
"unauthorized", http.StatusUnauthorized) 21 | return 22 | } 23 | 24 | prefix, tokenString := parts[0], parts[1] 25 | if prefix != apiKeyAuthorizationPrefix { 26 | http.Error(w, "unauthorized", http.StatusUnauthorized) 27 | return 28 | } 29 | 30 | // compare in constant time so response timing does not reveal how many leading bytes matched 31 | if subtle.ConstantTimeCompare([]byte(tokenString), []byte(config.ApiSecret())) != 1 { 32 | http.Error(w, "unauthorized", http.StatusUnauthorized) 33 | return 34 | } 35 | 36 | ctx := r.Context() 37 | ctx = context.WithValue(ctx, RequestTokenType, secretKeyRequestTokenType) 38 | ctx = context.WithValue(ctx, ProjectId, config.ProjectUUID()) 39 | 40 | r = r.WithContext(ctx) 41 | 42 | next.ServeHTTP(w, r) 43 | }) 44 | } 45 | -------------------------------------------------------------------------------- /src/api/middleware/send_version.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/getzep/zep/lib/config" 7 | ) 8 | 9 | const VersionHeader = "X-Zep-Version" 10 | 11 | // SendVersion is a middleware that adds the current version to the response. 12 | func SendVersion(next http.Handler) http.Handler { 13 | fn := func(w http.ResponseWriter, r *http.Request) { 14 | // Set the header before calling the handler: headers added after next.ServeHTTP 15 | // returns are ignored once the handler has written the response. A handler that 16 | // calls Header().Set for the same key still overrides this value. 
17 | w.Header().Set(VersionHeader, config.VersionString()) 18 | 19 | next.ServeHTTP(w, r) 20 | } 21 | 22 | return http.HandlerFunc(fn) 23 | } 24 | -------------------------------------------------------------------------------- /src/api/routes.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "time" 8 | 9 | "github.com/go-chi/chi/v5" 10 | chiMiddleware "github.com/go-chi/chi/v5/middleware" 11 |
"github.com/go-chi/cors" 12 | "github.com/go-playground/validator/v10" 13 | "github.com/google/uuid" 14 | "github.com/riandyrn/otelchi" 15 | 16 | "github.com/getzep/zep/api/apihandlers" 17 | "github.com/getzep/zep/api/handlertools" 18 | "github.com/getzep/zep/api/middleware" 19 | "github.com/getzep/zep/lib/config" 20 | "github.com/getzep/zep/lib/logger" 21 | "github.com/getzep/zep/models" 22 | ) 23 | 24 | const ( 25 | MaxRequestSize = 5 << 20 // 5MB 26 | ServerContextTimeout = 30 * time.Second 27 | ReadHeaderTimeout = 5 * time.Second 28 | RouterName = "zep-api" 29 | ) 30 | 31 | func Create(as *models.AppState) (*http.Server, error) { 32 | host := config.Http().Host 33 | port := config.Http().Port 34 | 35 | mw := getMiddleware(as) 36 | 37 | router, err := setupRouter(as, mw) 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | return &http.Server{ 43 | Addr: fmt.Sprintf("%s:%d", host, port), 44 | Handler: router, 45 | ReadHeaderTimeout: ReadHeaderTimeout, 46 | }, nil 47 | } 48 | 49 | // SetupRouter 50 | // 51 | // @title Zep Cloud API 52 | // 53 | // @version 0.x 54 | // @host api.getzep.com 55 | // @BasePath /api/v2 56 | // @schemes https 57 | // @securityDefinitions.apikey Api-Key 58 | // @in header 59 | // @name Authorization 60 | // 61 | // 62 | // @description Type "Api-Key" followed by a space and JWT token. 
63 | func setupRouter(as *models.AppState, mw []func(http.Handler) http.Handler) (*chi.Mux, error) { 64 | validations := map[string]func(fl validator.FieldLevel) bool{ 65 | "alphanumeric_with_underscores": handlertools.AlphanumericWithUnderscores, 66 | "nonemptystrings": handlertools.NonEmptyStrings, 67 | } 68 | 69 | if err := handlertools.RegisterValidations(validations); err != nil { 70 | return nil, err 71 | } 72 | 73 | router := chi.NewRouter() 74 | router.Use( 75 | cors.Handler(cors.Options{ 76 | AllowOriginFunc: func(_ *http.Request, _ string) bool { return true }, 77 | AllowedMethods: []string{"GET", "POST", "PUT", "DELETE"}, 78 | AllowedHeaders: []string{"Authorization"}, 79 | }), 80 | func(next http.Handler) http.Handler { 81 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 82 | st := time.Now().UTC() 83 | resp := chiMiddleware.NewWrapResponseWriter(w, r.ProtoMajor) 84 | 85 | defer func() { 86 | logger.Info( 87 | "HTTP Request Served", 88 | "proto", r.Proto, 89 | "method", r.Method, 90 | "path", r.URL.Path, 91 | "request_id", chiMiddleware.GetReqID(r.Context()), 92 | "duration", time.Since(st), 93 | "status", resp.Status(), 94 | "response_size", resp.BytesWritten(), 95 | ) 96 | }() 97 | 98 | next.ServeHTTP(resp, r) 99 | }) 100 | }, 101 | chiMiddleware.Heartbeat("/healthz"), 102 | chiMiddleware.RequestSize(config.Http().MaxRequestSize), 103 | func(next http.Handler) http.Handler { 104 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 105 | requestID := r.Header.Get(handlertools.RequestIDHeader) 106 | if requestID == "" { 107 | requestID = uuid.New().String() 108 | } 109 | 110 | ctx := context.WithValue(r.Context(), handlertools.RequestIDKey, requestID) //nolint:staticcheck // it will be fine 111 | 112 | next.ServeHTTP(w, r.WithContext(ctx)) 113 | }) 114 | }, 115 | chiMiddleware.Timeout(ServerContextTimeout), 116 | chiMiddleware.RealIP, 117 | chiMiddleware.CleanPath, 118 | middleware.SendVersion, 119 | 
otelchi.Middleware( 120 | RouterName, 121 | otelchi.WithChiRoutes(router), 122 | otelchi.WithRequestMethodInSpanName(true), 123 | ), 124 | ) 125 | 126 | setupAPIRoutes(router, as, mw) 127 | 128 | return router, nil 129 | } 130 | 131 | func setupSessionRoutes(router chi.Router, appState *models.AppState, extend ...map[string]func(chi.Router, *models.AppState)) { 132 | var extensions map[string]func(chi.Router, *models.AppState) 133 | if len(extend) > 0 { 134 | extensions = extend[0] 135 | } 136 | 137 | router.Get("/sessions-ordered", apihandlers.GetOrderedSessionListHandler(appState)) 138 | 139 | // these need to be explicitly defined to avoid conflicts with the /sessions/{sessionId} route 140 | router.Post("/sessions/search", apihandlers.SearchSessionsHandler(appState)) 141 | 142 | router.Route("/sessions", func(r chi.Router) { 143 | r.Get("/", apihandlers.GetSessionListHandler(appState)) 144 | r.Post("/", apihandlers.CreateSessionHandler(appState)) 145 | 146 | if ex, ok := extensions["/sessions"]; ok { 147 | ex(r, appState) 148 | } 149 | }) 150 | 151 | router.Route("/sessions/{sessionId}", func(r chi.Router) { 152 | r.Get("/", apihandlers.GetSessionHandler(appState)) 153 | r.Patch("/", apihandlers.UpdateSessionHandler(appState)) 154 | 155 | if ex, ok := extensions["/sessions/{sessionId}"]; ok { 156 | ex(r, appState) 157 | } 158 | 159 | // Memory-related routes 160 | r.Route("/memory", func(r chi.Router) { 161 | r.Get("/", apihandlers.GetMemoryHandler(appState)) 162 | r.Post("/", apihandlers.PostMemoryHandler(appState)) 163 | r.Delete("/", apihandlers.DeleteMemoryHandler(appState)) 164 | 165 | if ex, ok := extensions["/sessions/{sessionId}/memory"]; ok { 166 | ex(r, appState) 167 | } 168 | }) 169 | 170 | // Message-related routes 171 | r.Route("/messages", func(r chi.Router) { 172 | r.Get("/", apihandlers.GetMessagesForSessionHandler(appState)) 173 | r.Route("/{messageUUID}", func(r chi.Router) { 174 | r.Get("/", apihandlers.GetMessageHandler(appState)) 175 | 
r.Patch("/", apihandlers.UpdateMessageMetadataHandler(appState)) 176 | 177 | if ex, ok := extensions["/sessions/{sessionId}/messages/{messageUUID}"]; ok { 178 | ex(r, appState) 179 | } 180 | }) 181 | 182 | if ex, ok := extensions["/sessions/{sessionId}/messages"]; ok { 183 | ex(r, appState) 184 | } 185 | }) 186 | }) 187 | } 188 | 189 | func setupUserRoutes(router chi.Router, appState *models.AppState) { 190 | router.Post("/users", apihandlers.CreateUserHandler(appState)) 191 | router.Get("/users", apihandlers.ListAllUsersHandler(appState)) 192 | router.Get("/users-ordered", apihandlers.ListAllOrderedUsersHandler(appState)) 193 | router.Route("/users/{userId}", func(r chi.Router) { 194 | r.Get("/", apihandlers.GetUserHandler(appState)) 195 | r.Patch("/", apihandlers.UpdateUserHandler(appState)) 196 | r.Delete("/", apihandlers.DeleteUserHandler(appState)) 197 | r.Get("/sessions", apihandlers.ListUserSessionsHandler(appState)) 198 | }) 199 | } 200 | 201 | func setupFactRoutes(router chi.Router, appState *models.AppState) { 202 | router.Route("/facts/{factUUID}", func(r chi.Router) { 203 | r.Get("/", apihandlers.GetFactHandler(appState)) 204 | r.Delete("/", apihandlers.DeleteFactHandler(appState)) 205 | }) 206 | } 207 | -------------------------------------------------------------------------------- /src/api/server_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package api 3 | 4 | import ( 5 | "net/http" 6 | 7 | "github.com/go-chi/chi/v5" 8 | 9 | "github.com/getzep/zep/api/middleware" 10 | "github.com/getzep/zep/models" 11 | ) 12 | 13 | func getMiddleware(appState *models.AppState) []func(http.Handler) http.Handler { 14 | mw := []func(http.Handler) http.Handler{ 15 | middleware.SecretKeyAuthMiddleware, 16 | } 17 | 18 | return mw 19 | } 20 | 21 | func setupAPIRoutes(router chi.Router, as *models.AppState, mw []func(http.Handler) http.Handler) { 22 | router.Route("/api/v2", func(r chi.Router) { 23 | for _, m := range mw { 
24 | r.Use(m) 25 | } 26 | 27 | setupUserRoutes(r, as) 28 | setupSessionRoutes(r, as) 29 | setupFactRoutes(r, as) 30 | }) 31 | } 32 | -------------------------------------------------------------------------------- /src/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/getzep/zep 2 | 3 | go 1.21.5 4 | 5 | require ( 6 | dario.cat/mergo v1.0.1 7 | github.com/ThreeDotsLabs/watermill v1.3.7 8 | github.com/failsafe-go/failsafe-go v0.6.8 9 | github.com/go-chi/chi/v5 v5.1.0 10 | github.com/go-chi/cors v1.2.1 11 | github.com/go-playground/validator/v10 v10.22.1 12 | github.com/google/uuid v1.6.0 13 | github.com/hashicorp/go-retryablehttp v0.7.7 14 | github.com/riandyrn/otelchi v0.9.0 15 | github.com/uptrace/bun v1.1.17 16 | github.com/uptrace/bun/dialect/pgdialect v1.1.17 17 | github.com/uptrace/bun/driver/pgdriver v1.1.17 18 | github.com/uptrace/bun/extra/bunotel v1.1.17 19 | github.com/viterin/vek v0.4.2 20 | go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.46.1 21 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.48.0 22 | go.uber.org/zap v1.27.0 23 | gopkg.in/yaml.v3 v3.0.1 24 | ) 25 | 26 | require ( 27 | github.com/chewxy/math32 v1.10.1 // indirect 28 | github.com/felixge/httpsnoop v1.0.4 // indirect 29 | github.com/gabriel-vasile/mimetype v1.4.3 // indirect 30 | github.com/go-logr/logr v1.4.2 // indirect 31 | github.com/go-logr/stdr v1.2.2 // indirect 32 | github.com/go-playground/locales v0.14.1 // indirect 33 | github.com/go-playground/universal-translator v0.18.1 // indirect 34 | github.com/hashicorp/go-cleanhttp v0.5.2 // indirect 35 | github.com/jinzhu/inflection v1.0.0 // indirect 36 | github.com/leodido/go-urn v1.4.0 // indirect 37 | github.com/lithammer/shortuuid/v3 v3.0.7 // indirect 38 | github.com/oklog/ulid v1.3.1 // indirect 39 | github.com/pkg/errors v0.9.1 // indirect 40 | github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect 
41 | github.com/uptrace/opentelemetry-go-extra/otelsql v0.2.3 // indirect 42 | github.com/viterin/partial v1.1.0 // indirect 43 | github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect 44 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect 45 | go.opentelemetry.io/otel v1.28.0 // indirect 46 | go.opentelemetry.io/otel/metric v1.28.0 // indirect 47 | go.opentelemetry.io/otel/trace v1.28.0 // indirect 48 | go.uber.org/multierr v1.10.0 // indirect 49 | golang.org/x/crypto v0.26.0 // indirect 50 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 // indirect 51 | golang.org/x/net v0.26.0 // indirect 52 | golang.org/x/sys v0.25.0 // indirect 53 | golang.org/x/text v0.17.0 // indirect 54 | mellium.im/sasl v0.3.1 // indirect 55 | ) 56 | -------------------------------------------------------------------------------- /src/golangci.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | go: '1.21' 3 | linters: 4 | enable: 5 | - gocritic 6 | - revive 7 | - gosimple 8 | - govet 9 | - ineffassign 10 | - staticcheck 11 | - typecheck 12 | - unused 13 | - gosec 14 | - forbidigo 15 | - errcheck 16 | 17 | linters-settings: 18 | revive: 19 | ignore-generated-header: true 20 | severity: warning 21 | enable-all-rules: true 22 | confidence: 0.8 23 | rules: 24 | - name: flag-parameter 25 | severity: warning 26 | disabled: true 27 | - name: line-length-limit 28 | severity: warning 29 | disabled: true 30 | - name: max-public-structs 31 | severity: warning 32 | disabled: true 33 | - name: var-naming 34 | severity: warning 35 | disabled: true 36 | - name: cyclomatic 37 | severity: warning 38 | disabled: true 39 | - name: cognitive-complexity 40 | severity: warning 41 | disabled: true 42 | arguments: [15] 43 | - name: add-constant 44 | severity: warning 45 | disabled: false 46 | arguments: 47 | - maxLitCount: "5" 48 | allowStrs: '"","OK"' 49 | allowInts: "0,1" 50 | - name: function-length 51 | severity: warning 52 | disabled: true 53 | - name: 
flag-parameter 54 | severity: warning 55 | disabled: false 56 | - name: unexported-return 57 | disabled: true 58 | - name: import-alias-naming 59 | severity: warning 60 | disabled: false 61 | exclude: [""] 62 | arguments: 63 | - "^[a-z][a-zA-Z0-9]{0,}$" 64 | - name: unused-parameter 65 | disabled: true 66 | - name: unused-receiver 67 | disabled: true 68 | - name: unhandled-error 69 | severity: warning 70 | disabled: false 71 | arguments: 72 | - "io.Closer.Close" 73 | - "os.Setenv" 74 | - "strings.Builder.WriteString" 75 | - "net/http.Server.Shutdown" 76 | gocritic: 77 | enabled-tags: [diagnostic, style, performance, opinionated] 78 | disabled-checks: 79 | - rangeValCopy 80 | - unnamedResult 81 | settings: 82 | hugeParam: 83 | sizeThreshold: 5120 # 5kb 84 | forbidigo: 85 | # Forbid the following identifiers (list of regexp). 86 | # Default: ["^(fmt\\.Print(|f|ln)|print|println)$"] 87 | forbid: 88 | - p: ^fmt\.Print.*$ 89 | msg: Do not commit print statements. 90 | # Optional message that gets included in error reports. 91 | - p: ^log\.Println.*$ 92 | msg: Do not commit log.Println statements. 93 | # Exclude godoc examples from forbidigo checks. 94 | # Default: true 95 | exclude-godoc-examples: false 96 | # Instead of matching the literal source code, 97 | # use type information to replace expressions with strings that contain the package name 98 | # and (for methods and fields) the type name. 99 | # This makes it possible to handle import renaming and forbid struct fields and methods. 100 | # Default: false 101 | analyze-types: true 102 | errcheck: 103 | check-type-assertions: true 104 | exclude-functions: 105 | - (*net/http.Server).Shutdown 106 | # output: 107 | # Format: colored-line-number|line-number|json|colored-tab|tab|checkstyle|code-climate|junit-xml|github-actions|teamcity 108 | # 109 | # Multiple can be specified by separating them by comma, output can be provided 110 | # for each of them by separating format name and path by colon symbol. 
111 | # Output path can be either `stdout`, `stderr` or path to the file to write to. 112 | # Example: "checkstyle:report.xml,json:stdout,colored-line-number" 113 | # 114 | # Default: colored-line-number 115 | # format: json 116 | 117 | severity: 118 | # Set the default severity for issues. 119 | # 120 | # If severity rules are defined and the issues do not match or no severity is provided to the rule 121 | # this will be the default severity applied. 122 | # Severities should match the supported severity names of the selected out format. 123 | # - Code climate: https://docs.codeclimate.com/docs/issues#issue-severity 124 | # - Checkstyle: https://checkstyle.sourceforge.io/property_types.html#SeverityLevel 125 | # - GitHub: https://help.github.com/en/actions/reference/workflow-commands-for-github-actions#setting-an-error-message 126 | # - TeamCity: https://www.jetbrains.com/help/teamcity/service-messages.html#Inspection+Instance 127 | # 128 | # Default value is an empty string. 129 | default-severity: error 130 | # If set to true `severity-rules` regular expressions become case-sensitive. 131 | # Default: false 132 | case-sensitive: true 133 | issues: 134 | exclude-dirs: 135 | - deploy 136 | - test_data 137 | - pkg/triton_grpc_client 138 | exclude-rules: 139 | - path: ".*\\.go" 140 | text: "flag-parameter" # seems to be a bug in revive and this doesn't disable in the revive config 141 | linters: 142 | - revive 143 | - path: ".*\\.go" 144 | text: "add-constant: string literal" # ignore repeated string literals 145 | linters: 146 | - revive 147 | - path: "tasks/.*\\.go" 148 | text: "deep-exit" 149 | linters: 150 | - revive 151 | - path: "lib/util/testutil/.*\\.go" 152 | text: "deep-exit" 153 | linters: 154 | - revive 155 | # Exclude some linters from running on tests files. 
156 | - path: ".*test_.*\\.go" 157 | linters: 158 | - gocyclo 159 | - errcheck 160 | - dupl 161 | - gosec 162 | - varnamelen 163 | - revive 164 | - path: ".*_test\\.go" 165 | linters: 166 | - gocyclo 167 | - errcheck 168 | - dupl 169 | - gosec 170 | - varnamelen 171 | - revive 172 | -------------------------------------------------------------------------------- /src/lib/communication/communication_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package communication 3 | 4 | func I() Service { 5 | return NewMockService() 6 | } 7 | -------------------------------------------------------------------------------- /src/lib/communication/service.go: -------------------------------------------------------------------------------- 1 | package communication 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/google/uuid" 7 | ) 8 | 9 | type Recipient struct { 10 | Email string 11 | FirstName string 12 | LastName string 13 | } 14 | 15 | type AlertRecipientType string 16 | 17 | const ( 18 | EmailRecipientType AlertRecipientType = "email" 19 | ) 20 | 21 | type AlertTopic string 22 | 23 | const ( 24 | AccountOverageTopic AlertTopic = "account_overage" 25 | ) 26 | 27 | type Service interface { 28 | HandleSignup(ctx context.Context, recip Recipient) error 29 | HandleMemberInvite(ctx context.Context, recip Recipient) error 30 | HandleMemberDelete(ctx context.Context, recip Recipient) error 31 | 32 | NotifyAccountOverage(accountUUID uuid.UUID, email, plan string) 33 | NotifyAccountCreation( 34 | accountUUID uuid.UUID, 35 | ownerEmail, ownerFirstName, ownerLastName string, 36 | ownerUUID uuid.UUID, 37 | ) 38 | NotifyAccountMemberAdded( 39 | accountUUID uuid.UUID, 40 | memberEmail, memberFirstName, memberLastName string, 41 | memberUUID uuid.UUID, 42 | ) 43 | } 44 | -------------------------------------------------------------------------------- /src/lib/communication/service_mock.go: 
-------------------------------------------------------------------------------- 1 | package communication 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/google/uuid" 7 | ) 8 | 9 | func NewMockService() Service { 10 | return &mockService{} 11 | } 12 | 13 | type mockService struct{} 14 | 15 | func (*mockService) HandleSignup(_ context.Context, _ Recipient) error { 16 | return nil 17 | } 18 | 19 | func (*mockService) HandleMemberInvite(_ context.Context, _ Recipient) error { 20 | return nil 21 | } 22 | 23 | func (*mockService) HandleMemberDelete(_ context.Context, _ Recipient) error { 24 | return nil 25 | } 26 | 27 | func (*mockService) NotifyAccountOverage(_ uuid.UUID, _, _ string) {} 28 | 29 | func (*mockService) NotifyAccountCreation(_ uuid.UUID, _, _, _ string, _ uuid.UUID) { 30 | } 31 | 32 | func (*mockService) NotifyAccountMemberAdded(_ uuid.UUID, _, _, _ string, _ uuid.UUID) { 33 | } 34 | -------------------------------------------------------------------------------- /src/lib/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "net/url" 6 | ) 7 | 8 | // this is a pointer so that if someone attempts to use it before loading it will 9 | // panic and force them to load it first. 10 | // it is also private so that it cannot be modified after loading. 11 | var _loaded *Config 12 | 13 | func LoadDefault() { 14 | config := defaultConfig 15 | 16 | _loaded = &config 17 | } 18 | 19 | // set sane defaults for all of the config options. when loading the config from 20 | // the file, any options that are not set will be set to these defaults. 
21 | var defaultConfig = Config{ 22 | Common: Common{ 23 | Log: logConfig{ 24 | Level: "warn", 25 | Format: "json", 26 | }, 27 | Http: httpConfig{ 28 | Port: 9000, 29 | MaxRequestSize: 5242880, 30 | }, 31 | Carbon: carbonConfig{ 32 | Locale: "en", 33 | }, 34 | }, 35 | } 36 | 37 | type Common struct { 38 | Log logConfig `yaml:"log"` 39 | Http httpConfig `yaml:"http"` 40 | Postgres postgresConfig `yaml:"postgres"` 41 | Carbon carbonConfig `yaml:"carbon"` 42 | } 43 | 44 | type logConfig struct { 45 | Level string `yaml:"level"` 46 | Format string `yaml:"format"` 47 | } 48 | 49 | type httpConfig struct { 50 | Host string `yaml:"host"` 51 | Port int `yaml:"port"` 52 | MaxRequestSize int64 `yaml:"max_request_size"` 53 | } 54 | 55 | type postgresConfigCommon struct { 56 | User string `yaml:"user"` 57 | Password string `yaml:"password"` 58 | Host string `yaml:"host"` 59 | Port int `yaml:"port"` 60 | Database string `yaml:"database"` 61 | ReadTimeout int `yaml:"read_timeout"` 62 | WriteTimeout int `yaml:"write_timeout"` 63 | MaxOpenConnections int `yaml:"max_open_connections"` 64 | } 65 | 66 | func (c postgresConfigCommon) DSN() string { 67 | return fmt.Sprintf( 68 | "postgres://%s:%s@%s:%d/%s?sslmode=disable", 69 | url.QueryEscape(c.User), 70 | url.QueryEscape(c.Password), 71 | c.Host, 72 | c.Port, 73 | url.QueryEscape(c.Database), 74 | ) 75 | } 76 | 77 | type carbonConfig struct { 78 | // should be the name of one of the language files in carbon 79 | // https://github.com/golang-module/carbon/tree/master/lang 80 | Locale string `yaml:"locale"` 81 | } 82 | 83 | // there should be a getter for each top level field in the config struct. 84 | // these getters will panic if the config has not been loaded. 
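The comment above says the getters panic until the config has been loaded. That works because `_loaded` is a nil `*Config` until `Load` or `LoadDefault` assigns it, so the first getter call dereferences a nil pointer. Below is a minimal, standard-library-only sketch of that pattern. `appConfig`, `load`, `Port`, and `portOrError` are hypothetical names used only for illustration; they are not part of the `config` package.

```go
package main

import "fmt"

// appConfig is a hypothetical, trimmed-down stand-in for the package's Config type.
type appConfig struct {
	Port int
}

// loaded mirrors the `_loaded *Config` pattern: it stays nil until load runs.
var loaded *appConfig

// load copies a value and publishes a pointer to it, similar to LoadDefault.
func load(port int) {
	c := appConfig{Port: port}
	loaded = &c
}

// Port is a getter in the same style as Http() or Logger(): it dereferences
// loaded directly, so calling it before load panics with a nil pointer
// dereference instead of silently returning a zero value.
func Port() int {
	return loaded.Port
}

// portOrError converts that panic into an error so the behavior is observable.
func portOrError() (port int, err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("config used before load: %v", r)
		}
	}()
	return Port(), nil
}
```

Failing loudly like this surfaces an initialization-order bug at the first config access, rather than letting a zero port or an empty DSN flow silently into the rest of the server.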
85 | 86 | func Logger() logConfig { 87 | return _loaded.Log 88 | } 89 | 90 | func Http() httpConfig { 91 | return _loaded.Http 92 | } 93 | 94 | func Postgres() postgresConfig { 95 | return _loaded.Postgres 96 | } 97 | 98 | func Carbon() carbonConfig { 99 | return _loaded.Carbon 100 | } 101 | -------------------------------------------------------------------------------- /src/lib/config/env_template.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strings" 7 | "text/template" 8 | ) 9 | 10 | func parseConfigTemplate(data []byte) ([]byte, error) { //nolint:unused // this is only called in CE 11 | var missingVars []string 12 | 13 | tmpl, err := template.New("config").Funcs(template.FuncMap{ 14 | "Env": func(key string) string { 15 | val := os.Getenv(key) 16 | if val == "" { 17 | missingVars = append(missingVars, key) 18 | } 19 | 20 | return val 21 | }, 22 | }).Parse(string(data)) 23 | if err != nil { 24 | return nil, err 25 | } 26 | 27 | var result strings.Builder 28 | 29 | err = tmpl.Execute(&result, nil) 30 | if err != nil { 31 | return nil, err 32 | } 33 | 34 | if len(missingVars) > 0 { 35 | return nil, fmt.Errorf("missing environment variables: %s", strings.Join(missingVars, ", ")) 36 | } 37 | 38 | return []byte(result.String()), nil 39 | } 40 | -------------------------------------------------------------------------------- /src/lib/config/load_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package config 3 | 4 | import ( 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | 10 | "gopkg.in/yaml.v3" 11 | ) 12 | 13 | // Load loads the config from the file named by ZEP_CONFIG_FILE, falling back to zep.yaml in the working directory. 14 | // This function will panic if the config file cannot be loaded or if the 15 | // config file is not valid. 
16 | // Load should be called as early as possible in the application lifecycle and 17 | // before any config options are used.
18 | func Load() { 19 | location := os.Getenv("ZEP_CONFIG_FILE") 20 | if location == "" { 21 | wd, _ := os.Getwd() 22 | location = filepath.Join(wd, "zep.yaml") 23 | } 24 | 25 | data, err := os.ReadFile(location) 26 | if err != nil { 27 | panic(fmt.Errorf("config file could not be read: %w", err)) 28 | } 29 | 30 | data, err = parseConfigTemplate(data) 31 | if err != nil { 32 | panic(fmt.Errorf("error processing config file: %w", err)) 33 | } 34 | 35 | config := defaultConfig 36 | if err := yaml.Unmarshal(data, &config); err != nil { 37 | panic(fmt.Errorf("config file contains invalid yaml: %w", err)) 38 | } 39 | 40 | if err := cleanAndValidateConfig(&config); err != nil { 41 | panic(fmt.Errorf("config file is invalid: %w", err)) 42 | } 43 | 44 | _loaded = &config 45 | } 46 | 47 | func cleanAndValidateConfig(config *Config) error { 48 | secret := strings.TrimSpace(config.ApiSecret) 49 | if secret == "" { 50 | return fmt.Errorf("api_secret is not set") 51 | } 52 | 53 | config.ApiSecret = secret 54 | 55 | return nil 56 | } 57 | -------------------------------------------------------------------------------- /src/lib/config/models_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package config 3 | 4 | import "github.com/google/uuid" 5 | 6 | type Config struct { 7 | Common `yaml:",inline"` 8 | 9 | Telemetry telemetryConfig `yaml:"telemetry"` 10 | Graphiti graphitiConfig `yaml:"graphiti"` 11 | ApiSecret string `yaml:"api_secret"` 12 | } 13 | 14 | type postgresConfig struct { 15 | postgresConfigCommon `yaml:",inline"` 16 | 17 | SchemaName string `yaml:"schema_name"` 18 | } 19 | 20 | type graphitiConfig struct { 21 | ServiceUrl string `yaml:"service_url"` 22 | } 23 | 24 | type telemetryConfig struct { 25 | Disabled bool `yaml:"disabled"` 26 | OrganizationName string `yaml:"organization_name"` 27 | } 28 | 29 | func Graphiti() graphitiConfig { 30 | return _loaded.Graphiti 31 | } 32 | 33 | func ApiSecret() string { 34 | return 
_loaded.ApiSecret 35 | } 36 | 37 | func ProjectUUID() uuid.UUID { 38 | return uuid.MustParse("399e79e0-d0ec-4ea8-a0bf-fe556d19fb9f") 39 | } 40 | 41 | func Telemetry() telemetryConfig { 42 | return _loaded.Telemetry 43 | } 44 | -------------------------------------------------------------------------------- /src/lib/config/version_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package config 3 | 4 | var Version = "dev" 5 | 6 | func VersionString() string { 7 | return Version 8 | } 9 | -------------------------------------------------------------------------------- /src/lib/enablement/enablement_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package enablement 3 | 4 | func I() Service { 5 | return NewMockService() 6 | } 7 | 8 | type EventMetadata any 9 | -------------------------------------------------------------------------------- /src/lib/enablement/events.go: -------------------------------------------------------------------------------- 1 | package enablement 2 | 3 | type Event string 4 | 5 | func (t Event) String() string { 6 | return string(t) 7 | } 8 | 9 | const ( 10 | Event_CreateUser Event = "user_create" 11 | Event_DeleteUser Event = "user_delete" 12 | Event_CreateAPIKey Event = "api_key_create" 13 | Event_CreateAccountMember Event = "account_create_member" 14 | Event_CreateProject Event = "project_create" 15 | Event_DeleteProject Event = "project_delete" 16 | Event_DataExtractor Event = "sde_call" 17 | Event_CreateSession Event = "session_create" 18 | Event_CreateMemoryMessage Event = "memory_create_message" 19 | ) 20 | -------------------------------------------------------------------------------- /src/lib/enablement/plan_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package enablement 3 | 4 | func (z BillingPlan) IsFree() bool { 5 | return true 6 | } 7 | 
-------------------------------------------------------------------------------- /src/lib/enablement/service.go: -------------------------------------------------------------------------------- 1 | package enablement 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/google/uuid" 7 | ) 8 | 9 | type Trait func(t map[string]any) 10 | 11 | type BillingPlan string 12 | 13 | func (z BillingPlan) String() string { 14 | return string(z) 15 | } 16 | 17 | type Profile struct { 18 | UUID uuid.UUID 19 | Plan BillingPlan 20 | UnderMessagesQuota bool 21 | } 22 | 23 | type Service interface { 24 | UpdateSubscription(ctx context.Context, customerId string, accountUUID uuid.UUID, newPlan BillingPlan) error 25 | GenerateSubscriptionURL(accountUUID uuid.UUID, plan BillingPlan) (string, error) 26 | GenerateCustomerPortalURL(customerId string) (string, error) 27 | ConfirmSubscription(accountUUID uuid.UUID, sessionId string) (string, error) 28 | UpdatePlan(ctx context.Context, accountUUID uuid.UUID, newPlan BillingPlan) error 29 | 30 | GetProfile(ctx context.Context, accountUUID uuid.UUID) Profile 31 | IsEnabled(ctx context.Context, accountUUID uuid.UUID, flag string) bool 32 | UnderProjectQuota(ctx context.Context, accountUUID uuid.UUID) bool 33 | 34 | CreateProfile(ctx context.Context, accountUUID uuid.UUID) 35 | CreateUser( 36 | ctx context.Context, 37 | accountUUID, memberUUID uuid.UUID, 38 | firstName, lastName, email string, 39 | traits ...Trait, 40 | ) 41 | UpdateProjectCount(ctx context.Context, accountUUID uuid.UUID, projectCount int) 42 | 43 | TrackEvent(event Event, metadata EventMetadata) 44 | } 45 | -------------------------------------------------------------------------------- /src/lib/enablement/service_mock.go: -------------------------------------------------------------------------------- 1 | package enablement 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/google/uuid" 7 | ) 8 | 9 | func NewMockService() Service { 10 | return &mockService{} 11 | } 12 | 13 | type 
mockService struct{} 14 | 15 | func (*mockService) UpdateSubscription(_ context.Context, _ string, _ uuid.UUID, _ BillingPlan) error { 16 | return nil 17 | } 18 | 19 | func (*mockService) GenerateSubscriptionURL(_ uuid.UUID, _ BillingPlan) (string, error) { 20 | return "", nil 21 | } 22 | 23 | func (*mockService) GenerateCustomerPortalURL(_ string) (string, error) { 24 | return "", nil 25 | } 26 | 27 | func (*mockService) ConfirmSubscription(_ uuid.UUID, _ string) (string, error) { 28 | return "", nil 29 | } 30 | 31 | func (*mockService) UpdatePlan(_ context.Context, _ uuid.UUID, _ BillingPlan) error { 32 | return nil 33 | } 34 | 35 | func (*mockService) GetProfile(_ context.Context, _ uuid.UUID) Profile { 36 | return Profile{} 37 | } 38 | 39 | func (*mockService) IsEnabled(_ context.Context, _ uuid.UUID, _ string) bool { 40 | return true 41 | } 42 | 43 | func (*mockService) UnderProjectQuota(_ context.Context, _ uuid.UUID) bool { 44 | return true 45 | } 46 | 47 | func (*mockService) CreateProfile(_ context.Context, _ uuid.UUID) {} 48 | 49 | func (*mockService) CreateUser( 50 | _ context.Context, 51 | _, _ uuid.UUID, 52 | _, _, _ string, 53 | _ ...Trait, 54 | ) { 55 | } 56 | 57 | func (*mockService) UpdateProjectCount(_ context.Context, _ uuid.UUID, _ int) { 58 | } 59 | 60 | func (*mockService) TrackEvent(_ Event, _ EventMetadata) {} 61 | 62 | func (*mockService) Close() {} 63 | -------------------------------------------------------------------------------- /src/lib/graphiti/service_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package graphiti 3 | 4 | import ( 5 | "bytes" 6 | "context" 7 | "encoding/json" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "time" 12 | 13 | "github.com/google/uuid" 14 | 15 | "github.com/getzep/zep/lib/config" 16 | "github.com/getzep/zep/lib/util/httputil" 17 | "github.com/getzep/zep/models" 18 | ) 19 | 20 | type GetMemoryRequest struct { 21 | GroupID string `json:"group_id"` 22 | MaxFacts int 
`json:"max_facts"` 23 | CenterNodeUUID string `json:"center_node_uuid"` 24 | Messages []models.Message `json:"messages"` 25 | } 26 | 27 | type Fact struct { 28 | UUID uuid.UUID `json:"uuid"` 29 | Name string `json:"name"` 30 | Fact string `json:"fact"` 31 | CreatedAt time.Time `json:"created_at"` 32 | ExpiredAt *time.Time `json:"expired_at"` 33 | ValidAt *time.Time `json:"valid_at"` 34 | InvalidAt *time.Time `json:"invalid_at"` 35 | } 36 | 37 | func (f Fact) ExtractCreatedAt() time.Time { 38 | if f.ValidAt != nil { 39 | return *f.ValidAt 40 | } 41 | return f.CreatedAt 42 | } 43 | 44 | type GetMemoryResponse struct { 45 | Facts []Fact `json:"facts"` 46 | } 47 | 48 | type Message struct { 49 | UUID string `json:"uuid"` 50 | // The role of the sender of the message (e.g., "user", "assistant"). 51 | Role string `json:"role"` 52 | // The type of the role (e.g., "user", "system"). 53 | RoleType models.RoleType `json:"role_type,omitempty"` 54 | // The content of the message. 55 | Content string `json:"content"` 56 | } 57 | 58 | type PutMemoryRequest struct { 59 | GroupId string `json:"group_id"` 60 | Messages []Message `json:"messages"` 61 | } 62 | 63 | type SearchRequest struct { 64 | GroupIDs []string `json:"group_ids"` 65 | Text string `json:"query"` 66 | MaxFacts int `json:"max_facts,omitempty"` 67 | } 68 | 69 | type SearchResponse struct { 70 | Facts []Fact `json:"facts"` 71 | } 72 | 73 | type AddNodeRequest struct { 74 | GroupID string `json:"group_id"` 75 | UUID string `json:"uuid"` 76 | Name string `json:"name"` 77 | Summary string `json:"summary"` 78 | } 79 | 80 | type Service interface { 81 | GetMemory(ctx context.Context, payload GetMemoryRequest) (*GetMemoryResponse, error) 82 | PutMemory(ctx context.Context, groupID string, messages []models.Message, addGroupIDPrefix bool) error 83 | Search(ctx context.Context, payload SearchRequest) (*SearchResponse, error) 84 | AddNode(ctx context.Context, payload AddNodeRequest) error 85 | GetFact(ctx context.Context, 
factUUID uuid.UUID) (*Fact, error) 86 | DeleteFact(ctx context.Context, factUUID uuid.UUID) error 87 | DeleteGroup(ctx context.Context, groupID string) error 88 | DeleteMessage(ctx context.Context, messageUUID uuid.UUID) error 89 | } 90 | 91 | var _instance Service 92 | 93 | func I() Service { 94 | return _instance 95 | } 96 | 97 | type service struct { 98 | Client httputil.HTTPClient 99 | BaseUrl string 100 | } 101 | 102 | func Setup() { 103 | if _instance != nil { 104 | return 105 | } 106 | 107 | _instance = &service{ 108 | Client: httputil.NewRetryableHTTPClient( 109 | httputil.DefaultRetryMax, 110 | httputil.DefaultTimeout, 111 | httputil.IgnoreBadRequestRetryPolicy, 112 | "", 113 | ), 114 | BaseUrl: config.Graphiti().ServiceUrl, 115 | } 116 | } 117 | 118 | func (s *service) newRequest(ctx context.Context, method, path string, body any) (*http.Request, error) { 119 | buf := new(bytes.Buffer) 120 | if body != nil { 121 | err := json.NewEncoder(buf).Encode(body) 122 | if err != nil { 123 | return nil, err 124 | } 125 | } 126 | 127 | req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s/%s", s.BaseUrl, path), buf) 128 | if err != nil { 129 | return nil, err 130 | } 131 | 132 | req.Header.Set("Content-Type", "application/json") 133 | 134 | return req, nil 135 | } 136 | 137 | func (s *service) doRequest(req *http.Request, v any) error { 138 | resp, err := s.Client.Do(req) 139 | if err != nil { 140 | return err 141 | } 142 | 143 | defer func(body io.ReadCloser) { 144 | body.Close() 145 | }(resp.Body) 146 | 147 | if resp.StatusCode > http.StatusAccepted { 148 | return fmt.Errorf("received status code: %d", resp.StatusCode) 149 | } 150 | 151 | if v == nil { 152 | return nil 153 | } 154 | 155 | body, err := io.ReadAll(resp.Body) 156 | if err != nil { 157 | return err 158 | } 159 | 160 | if len(body) == 0 { 161 | return fmt.Errorf("received empty response") 162 | } 163 | 164 | return json.Unmarshal(body, v) 165 | } 166 | 167 | func (s *service) 
GetMemory(ctx context.Context, payload GetMemoryRequest) (*GetMemoryResponse, error) { 168 | req, err := s.newRequest(ctx, http.MethodPost, "get-memory", payload) 169 | if err != nil { 170 | return nil, fmt.Errorf("failed to create request: %w", err) 171 | } 172 | 173 | var resp GetMemoryResponse 174 | 175 | err = s.doRequest(req, &resp) 176 | if err != nil { 177 | return nil, fmt.Errorf("failed to do request: %w", err) 178 | } 179 | 180 | return &resp, nil 181 | } 182 | 183 | func (s *service) PutMemory(ctx context.Context, groupID string, messages []models.Message, addGroupIDPrefix bool) error { 184 | var graphitiMessages []Message 185 | for _, m := range messages { 186 | episodeUUID := m.UUID.String() 187 | if addGroupIDPrefix { 188 | episodeUUID = fmt.Sprintf("%s-%s", groupID, m.UUID) 189 | } 190 | graphitiMessages = append(graphitiMessages, Message{ 191 | UUID: episodeUUID, 192 | Role: m.Role, 193 | RoleType: m.RoleType, 194 | Content: m.Content, 195 | }) 196 | } 197 | 198 | req, err := s.newRequest(ctx, http.MethodPost, "messages", &PutMemoryRequest{ 199 | GroupId: groupID, 200 | Messages: graphitiMessages, 201 | }) 202 | if err != nil { 203 | return fmt.Errorf("failed to create request: %w", err) 204 | } 205 | err = s.doRequest(req, nil) 206 | if err != nil { 207 | return fmt.Errorf("failed to do request: %w", err) 208 | } 209 | 210 | return nil 211 | } 212 | 213 | func (s *service) AddNode(ctx context.Context, payload AddNodeRequest) error { 214 | req, err := s.newRequest(ctx, http.MethodPost, "entity-node", payload) 215 | if err != nil { 216 | return fmt.Errorf("failed to create request: %w", err) 217 | } 218 | 219 | err = s.doRequest(req, nil) 220 | if err != nil { 221 | return fmt.Errorf("failed to do request: %w", err) 222 | } 223 | 224 | return nil 225 | } 226 | 227 | func (s *service) Search(ctx context.Context, payload SearchRequest) (*SearchResponse, error) { 228 | req, err := s.newRequest(ctx, http.MethodPost, "search", payload) 229 | if err != nil 
{ 230 | return nil, fmt.Errorf("failed to create request: %w", err) 231 | } 232 | 233 | var resp SearchResponse 234 | 235 | err = s.doRequest(req, &resp) 236 | if err != nil { 237 | return nil, fmt.Errorf("failed to do request: %w", err) 238 | } 239 | 240 | return &resp, nil 241 | } 242 | 243 | func (s *service) GetFact(ctx context.Context, factUUID uuid.UUID) (*Fact, error) { 244 | req, err := s.newRequest(ctx, http.MethodGet, fmt.Sprintf("entity-edge/%s", factUUID), nil) 245 | if err != nil { 246 | return nil, fmt.Errorf("failed to create request: %w", err) 247 | } 248 | 249 | var resp Fact 250 | 251 | err = s.doRequest(req, &resp) 252 | if err != nil { 253 | return nil, fmt.Errorf("failed to do request: %w", err) 254 | } 255 | 256 | return &resp, nil 257 | } 258 | 259 | func (s *service) DeleteGroup(ctx context.Context, groupID string) error { 260 | req, err := s.newRequest(ctx, http.MethodDelete, fmt.Sprintf("group/%s", groupID), nil) 261 | if err != nil { 262 | return fmt.Errorf("failed to create request: %w", err) 263 | } 264 | 265 | err = s.doRequest(req, nil) 266 | if err != nil { 267 | return fmt.Errorf("failed to do request: %w", err) 268 | } 269 | 270 | return nil 271 | } 272 | 273 | func (s *service) DeleteFact(ctx context.Context, factUUID uuid.UUID) error { 274 | req, err := s.newRequest(ctx, http.MethodDelete, fmt.Sprintf("entity-edge/%s", factUUID), nil) 275 | if err != nil { 276 | return fmt.Errorf("failed to create request: %w", err) 277 | } 278 | 279 | err = s.doRequest(req, nil) 280 | if err != nil { 281 | return fmt.Errorf("failed to do request: %w", err) 282 | } 283 | 284 | return nil 285 | } 286 | 287 | func (s *service) DeleteMessage(ctx context.Context, messageUUID uuid.UUID) error { 288 | req, err := s.newRequest(ctx, http.MethodDelete, fmt.Sprintf("episode/%s", messageUUID), nil) 289 | if err != nil { 290 | return fmt.Errorf("failed to create request: %w", err) 291 | } 292 | 293 | err = s.doRequest(req, nil) 294 | if err != nil { 295 | 
return fmt.Errorf("failed to do request: %w", err) 296 | } 297 | 298 | return nil 299 | } 300 | -------------------------------------------------------------------------------- /src/lib/logger/bun_hook.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | // this is for the most part a copy of the logrusbun hook - https://github.com/oiime/logrusbun 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "fmt" 9 | "strings" 10 | "time" 11 | 12 | "github.com/uptrace/bun" 13 | ) 14 | 15 | const ( 16 | maxQueryOpNameLen = 16 17 | ) 18 | 19 | type QueryHookOptions struct { 20 | LogSlow time.Duration 21 | QueryLevel LogLevel 22 | SlowLevel LogLevel 23 | ErrorLevel LogLevel 24 | } 25 | 26 | type QueryHook struct { 27 | opts QueryHookOptions 28 | } 29 | 30 | type LogEntryVars struct { 31 | Timestamp time.Time 32 | Query string 33 | Operation string 34 | Duration time.Duration 35 | Error error 36 | } 37 | 38 | // NewQueryHook returns new instance 39 | func NewQueryHook(opts QueryHookOptions) *QueryHook { 40 | h := QueryHook{ 41 | opts: opts, 42 | } 43 | 44 | return &h 45 | } 46 | 47 | func (*QueryHook) BeforeQuery(ctx context.Context, _ *bun.QueryEvent) context.Context { 48 | return ctx 49 | } 50 | 51 | func (h *QueryHook) AfterQuery(_ context.Context, event *bun.QueryEvent) { 52 | var level LogLevel 53 | 54 | now := time.Now() 55 | dur := now.Sub(event.StartTime) 56 | 57 | switch event.Err { 58 | case nil, sql.ErrNoRows: 59 | level = h.opts.QueryLevel 60 | 61 | if h.opts.LogSlow > 0 && dur >= h.opts.LogSlow { 62 | level = h.opts.SlowLevel 63 | } 64 | default: 65 | level = h.opts.ErrorLevel 66 | } 67 | 68 | if level == "" { 69 | return 70 | } 71 | 72 | msg := fmt.Sprintf("[%s]: %s", eventOperation(event), string(event.Query)) 73 | 74 | fields := []any{ 75 | "timestamp", now, 76 | "duration", dur, 77 | } 78 | 79 | if event.Err != nil { 80 | fields = append(fields, "error", event.Err) 81 | } 82 | 83 | switch level { 84 | 
case DebugLevel: 85 | Debug(msg, fields...) 86 | case InfoLevel: 87 | Info(msg, fields...) 88 | case WarnLevel: 89 | Warn(msg, fields...) 90 | case ErrorLevel: 91 | Error(msg, fields...) 92 | case FatalLevel: 93 | Fatal(msg, fields...) 94 | case PanicLevel: 95 | Panic(msg, fields...) 96 | default: 97 | panic(fmt.Errorf("unsupported level: %v", level)) 98 | } 99 | } 100 | 101 | func eventOperation(event *bun.QueryEvent) string { 102 | switch event.IQuery.(type) { 103 | case *bun.SelectQuery: 104 | return "SELECT" 105 | case *bun.InsertQuery: 106 | return "INSERT" 107 | case *bun.UpdateQuery: 108 | return "UPDATE" 109 | case *bun.DeleteQuery: 110 | return "DELETE" 111 | case *bun.CreateTableQuery: 112 | return "CREATE TABLE" 113 | case *bun.DropTableQuery: 114 | return "DROP TABLE" 115 | } 116 | return queryOperation(event.Query) 117 | } 118 | 119 | func queryOperation(name string) string { 120 | if idx := strings.Index(name, " "); idx > 0 { 121 | name = name[:idx] 122 | } 123 | if len(name) > maxQueryOpNameLen { 124 | name = name[:maxQueryOpNameLen] 125 | } 126 | return string(name) 127 | } 128 | -------------------------------------------------------------------------------- /src/lib/logger/logger.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/ThreeDotsLabs/watermill" 8 | "go.uber.org/zap" 9 | "go.uber.org/zap/zapcore" 10 | 11 | "github.com/getzep/zep/lib/config" 12 | ) 13 | 14 | type LogFormat int 15 | 16 | const ( 17 | JsonFormat LogFormat = iota 18 | ConsoleFormat 19 | ) 20 | 21 | type LogLevel string 22 | 23 | func (ll LogLevel) String() string { 24 | return string(ll) 25 | } 26 | 27 | const ( 28 | DebugLevel LogLevel = "DEBUG" 29 | InfoLevel LogLevel = "INFO" 30 | WarnLevel LogLevel = "WARN" 31 | ErrorLevel LogLevel = "ERROR" 32 | PanicLevel LogLevel = "PANIC" 33 | DPanicLevel LogLevel = "DPANIC" 34 | FatalLevel LogLevel = "FATAL" 35 | ) 36 | 
37 | // we use a singleton for the default logger. in the future we may expose an 38 | // api for creating new instances for specific use cases. 39 | var _instance *logger 40 | 41 | func InitDefaultLogger() { 42 | if _instance != nil { 43 | return 44 | } 45 | 46 | var ( 47 | lvl = InfoLevel 48 | format = JsonFormat 49 | 50 | zapLevel zap.AtomicLevel 51 | zapFormat string 52 | ) 53 | 54 | if envLevel := config.Logger().Level; envLevel != "" { 55 | lvl = LogLevel(strings.ToUpper(envLevel)) 56 | } 57 | 58 | if envFormat := config.Logger().Format; envFormat != "" { 59 | switch envFormat { 60 | case "json": 61 | format = JsonFormat 62 | case "console": 63 | format = ConsoleFormat 64 | default: 65 | // an unrecognized format is a configuration error; panicking is fine because 66 | // we'd want to prevent startup. 67 | panic(fmt.Errorf("bad log format in config: %s", envFormat)) 68 | } 69 | } 70 | 71 | switch lvl { 72 | case DebugLevel: 73 | zapLevel = zap.NewAtomicLevelAt(zap.DebugLevel) 74 | case InfoLevel: 75 | zapLevel = zap.NewAtomicLevelAt(zap.InfoLevel) 76 | case WarnLevel: 77 | zapLevel = zap.NewAtomicLevelAt(zap.WarnLevel) 78 | case ErrorLevel: 79 | zapLevel = zap.NewAtomicLevelAt(zap.ErrorLevel) 80 | case PanicLevel: 81 | zapLevel = zap.NewAtomicLevelAt(zap.PanicLevel) 82 | case DPanicLevel: 83 | zapLevel = zap.NewAtomicLevelAt(zap.DPanicLevel) 84 | case FatalLevel: 85 | zapLevel = zap.NewAtomicLevelAt(zap.FatalLevel) 86 | default: 87 | // an unrecognized level is a configuration error; panicking is fine because 88 | // we'd want to prevent startup.
89 | panic(fmt.Errorf("bad log level: %s", lvl)) 90 | } 91 | 92 | switch format { 93 | case JsonFormat: 94 | zapFormat = "json" 95 | case ConsoleFormat: 96 | zapFormat = "console" 97 | default: 98 | panic(fmt.Errorf("bad log format: %d", format)) 99 | } 100 | 101 | zapConfig := zap.Config{ 102 | Level: zapLevel, 103 | Development: false, 104 | Encoding: zapFormat, 105 | OutputPaths: []string{"stdout"}, 106 | ErrorOutputPaths: []string{"stdout"}, 107 | DisableCaller: false, 108 | EncoderConfig: zapcore.EncoderConfig{ 109 | MessageKey: "msg", 110 | LevelKey: "level", 111 | TimeKey: "ts", 112 | StacktraceKey: "stack", 113 | LineEnding: zapcore.DefaultLineEnding, 114 | EncodeLevel: zapcore.CapitalLevelEncoder, 115 | EncodeTime: zapcore.ISO8601TimeEncoder, 116 | EncodeDuration: zapcore.StringDurationEncoder, 117 | EncodeCaller: zapcore.ShortCallerEncoder, 118 | }, 119 | } 120 | 121 | log, err := zapConfig.Build() 122 | if err != nil { 123 | panic(err) 124 | } 125 | 126 | l := &logger{ 127 | level: lvl, 128 | format: format, 129 | 130 | logger: log.Sugar(), 131 | } 132 | 133 | _instance = l 134 | } 135 | 136 | func GetLogLevel() LogLevel { 137 | return _instance.level 138 | } 139 | 140 | // Use only if absolutely needed 141 | func GetZapLogger() *zap.Logger { 142 | return _instance.logger.Desugar() 143 | } 144 | 145 | func GetLogger() Logger { 146 | return _instance 147 | } 148 | 149 | type Logger interface { 150 | Debug(msg string, keysAndValues ...any) 151 | Info(msg string, keysAndValues ...any) 152 | Warn(msg string, keysAndValues ...any) 153 | Error(msg string, keysAndValues ...any) 154 | Panic(msg string, keysAndValues ...any) 155 | DPanic(msg string, keysAndValues ...any) 156 | Fatal(msg string, keysAndValues ...any) 157 | } 158 | 159 | type logger struct { 160 | level LogLevel 161 | format LogFormat 162 | 163 | logger *zap.SugaredLogger 164 | } 165 | 166 | func (l logger) Debug(msg string, keysAndValues ...any) { 167 | l.logger.Debugw(msg, keysAndValues...) 
168 | } 169 | 170 | func (l logger) Info(msg string, keysAndValues ...any) { 171 | l.logger.Infow(msg, keysAndValues...) 172 | } 173 | 174 | func (l logger) Warn(msg string, keysAndValues ...any) { 175 | l.logger.Warnw(msg, keysAndValues...) 176 | } 177 | 178 | func (l logger) Error(msg string, keysAndValues ...any) { 179 | l.logger.Errorw(msg, keysAndValues...) 180 | } 181 | 182 | func (l logger) Panic(msg string, keysAndValues ...any) { 183 | l.logger.Panicw(msg, keysAndValues...) 184 | } 185 | 186 | func (l logger) DPanic(msg string, keysAndValues ...any) { 187 | l.logger.DPanicw(msg, keysAndValues...) 188 | } 189 | 190 | func (l logger) Fatal(msg string, keysAndValues ...any) { 191 | l.logger.Fatalw(msg, keysAndValues...) 192 | } 193 | 194 | func Debug(msg string, keysAndValues ...any) { 195 | _instance.Debug(msg, keysAndValues...) 196 | } 197 | 198 | func Info(msg string, keysAndValues ...any) { 199 | _instance.Info(msg, keysAndValues...) 200 | } 201 | 202 | func Warn(msg string, keysAndValues ...any) { 203 | _instance.Warn(msg, keysAndValues...) 204 | } 205 | 206 | func Error(msg string, keysAndValues ...any) { 207 | _instance.Error(msg, keysAndValues...) 208 | } 209 | 210 | func Panic(msg string, keysAndValues ...any) { 211 | _instance.Panic(msg, keysAndValues...) 212 | } 213 | 214 | func DPanic(msg string, keysAndValues ...any) { 215 | _instance.DPanic(msg, keysAndValues...) 216 | } 217 | 218 | func Fatal(msg string, keysAndValues ...any) { 219 | _instance.Fatal(msg, keysAndValues...) 
220 | } 221 | 222 | type watermillLogger struct { 223 | fields watermill.LogFields 224 | } 225 | 226 | func GetWatermillLogger() watermill.LoggerAdapter { 227 | return &watermillLogger{} 228 | } 229 | 230 | func (l *watermillLogger) Error(msg string, err error, fields watermill.LogFields) { 231 | fields = l.fields.Add(fields) 232 | 233 | keysAndValues := make([]any, 0, 2*len(fields)+2) 234 | 235 | for k, v := range fields { 236 | keysAndValues = append(keysAndValues, k, v) 237 | } 238 | 239 | keysAndValues = append(keysAndValues, "error", err) 240 | 241 | _instance.Error(msg, keysAndValues...) 242 | } 243 | 244 | func (l *watermillLogger) Info(msg string, fields watermill.LogFields) { 245 | fields = l.fields.Add(fields) 246 | 247 | keysAndValues := make([]any, 0, 2*len(fields)) 248 | 249 | for k, v := range fields { 250 | keysAndValues = append(keysAndValues, k, v) 251 | } 252 | 253 | _instance.Info(msg, keysAndValues...) 254 | } 255 | 256 | func (l *watermillLogger) Debug(msg string, fields watermill.LogFields) { 257 | fields = l.fields.Add(fields) 258 | 259 | keysAndValues := make([]any, 0, 2*len(fields)) 260 | 261 | for k, v := range fields { 262 | keysAndValues = append(keysAndValues, k, v) 263 | } 264 | 265 | _instance.Debug(msg, keysAndValues...) 266 | } 267 | 268 | func (l *watermillLogger) Trace(msg string, fields watermill.LogFields) { 269 | fields = l.fields.Add(fields) 270 | 271 | keysAndValues := make([]any, 0, 2*len(fields)) 272 | 273 | for k, v := range fields { 274 | keysAndValues = append(keysAndValues, k, v) 275 | } 276 | 277 | _instance.Debug(msg, keysAndValues...)
278 | } 279 | 280 | func (l *watermillLogger) With(fields watermill.LogFields) watermill.LoggerAdapter { 281 | return &watermillLogger{ 282 | fields: l.fields.Add(fields), 283 | } 284 | } 285 | -------------------------------------------------------------------------------- /src/lib/observability/observability_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package observability 3 | 4 | func I() Service { 5 | return NewMockService() 6 | } 7 | 8 | func Setup() {} 9 | 10 | func Shutdown() {} 11 | -------------------------------------------------------------------------------- /src/lib/observability/service.go: -------------------------------------------------------------------------------- 1 | package observability 2 | 3 | import "github.com/google/uuid" 4 | 5 | type Category string 6 | 7 | func (c Category) String() string { 8 | return string(c) 9 | } 10 | 11 | type Service interface { 12 | CaptureError(msg string, err error, keysAndValues ...any) 13 | CaptureBreadcrumb(category Category, message string, metadata ...map[string]any) 14 | LogError(msg string, keysAndValues ...any) 15 | SetRequestScope(accountUUID, projectUUID uuid.UUID) 16 | } 17 | 18 | const ( 19 | Category_Projects Category = "projects" 20 | Category_Messages Category = "messages" 21 | Category_Users Category = "users" 22 | Category_Facts Category = "facts" 23 | Category_Accounts Category = "accounts" 24 | Category_Sessions Category = "sessions" 25 | Category_Auth Category = "auth" 26 | Category_AccountStore Category = "account_store" 27 | Category_ProjectStore Category = "project_store" 28 | Category_Tasks Category = "task" 29 | ) 30 | -------------------------------------------------------------------------------- /src/lib/observability/service_mock.go: -------------------------------------------------------------------------------- 1 | package observability 2 | 3 | import "github.com/google/uuid" 4 | 5 | func NewMockService() *mockService { 6 | 
return &mockService{} 7 | } 8 | 9 | type mockService struct{} 10 | 11 | func (*mockService) CaptureError(_ string, _ error, _ ...any) {} 12 | 13 | func (*mockService) CaptureBreadcrumb(_ Category, _ string, _ ...map[string]any) { 14 | } 15 | 16 | func (*mockService) LogError(_ string, _ ...any) {} 17 | 18 | func (*mockService) SetRequestScope(_, _ uuid.UUID) {} 19 | -------------------------------------------------------------------------------- /src/lib/pg/db.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "runtime" 8 | "time" 9 | 10 | "github.com/uptrace/bun" 11 | "github.com/uptrace/bun/dialect/pgdialect" 12 | "github.com/uptrace/bun/driver/pgdriver" 13 | "github.com/uptrace/bun/extra/bunotel" 14 | 15 | "github.com/getzep/zep/lib/config" 16 | "github.com/getzep/zep/lib/logger" 17 | ) 18 | 19 | var maxOpenConns = 4 * runtime.GOMAXPROCS(0) 20 | 21 | type Connection struct { 22 | *bun.DB 23 | } 24 | 25 | // NewConnection creates a new database connection and will panic if the connection fails. 26 | // Assumed to be called at startup so panicking is ok as it will prevent the app from starting. 
27 | func NewConnection() Connection { 28 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 29 | defer cancel() 30 | 31 | if config.Postgres().DSN() == "" { 32 | panic(fmt.Errorf("missing postgres DSN")) 33 | } 34 | 35 | sqldb := sql.OpenDB( 36 | pgdriver.NewConnector( 37 | pgdriver.WithDSN(config.Postgres().DSN()), 38 | pgdriver.WithReadTimeout(15*time.Second), 39 | pgdriver.WithWriteTimeout(15*time.Second), 40 | ), 41 | ) 42 | sqldb.SetMaxOpenConns(maxOpenConns) 43 | sqldb.SetMaxIdleConns(maxOpenConns) 44 | 45 | db := bun.NewDB(sqldb, pgdialect.New()) 46 | db.AddQueryHook(bunotel.NewQueryHook(bunotel.WithDBName("zep"))) 47 | 48 | // Enable pgvector extension 49 | err := enablePgVectorExtension(ctx, db) 50 | if err != nil { 51 | panic(fmt.Errorf("error enabling pgvector extension: %w", err)) 52 | } 53 | 54 | if logger.GetLogLevel() == logger.DebugLevel { 55 | enableDebugLogging(db) 56 | } 57 | 58 | return Connection{ 59 | DB: db, 60 | } 61 | } 62 | 63 | func enableDebugLogging(db *bun.DB) { 64 | db.AddQueryHook(logger.NewQueryHook(logger.QueryHookOptions{ 65 | LogSlow: time.Second, 66 | QueryLevel: logger.DebugLevel, 67 | ErrorLevel: logger.ErrorLevel, 68 | SlowLevel: logger.WarnLevel, 69 | })) 70 | } 71 | 72 | func enablePgVectorExtension(ctx context.Context, db *bun.DB) error { 73 | // Create the pgvector extension in the public schema if it does not exist 74 | _, err := db.ExecContext(ctx, "CREATE EXTENSION IF NOT EXISTS vector WITH SCHEMA PUBLIC;") 75 | if err != nil { 76 | return fmt.Errorf("error creating pgvector extension: %w", err) 77 | } 78 | 79 | // if this is an upgrade, we may need to update the pgvector extension 80 | // this is a no-op if the extension is already up to date 81 | // if this fails, Zep may not have rights to update extensions. 82 | // this is not an issue if running on a managed service.
83 | _, err = db.ExecContext(ctx, "ALTER EXTENSION vector UPDATE") 84 | if err != nil { 85 | // TODO should this just panic or at least return the error? 86 | logger.Error( 87 | "error updating pgvector extension. this may happen if running on a managed service without rights to update extensions.", 88 | "error", err, 89 | ) 90 | 91 | return nil 92 | } 93 | 94 | return nil 95 | } 96 | -------------------------------------------------------------------------------- /src/lib/pg/integrity.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import "github.com/uptrace/bun/driver/pgdriver" 4 | 5 | func IsIntegrityViolation(err error) bool { 6 | pgErr, ok := err.(pgdriver.Error) 7 | 8 | return ok && pgErr.IntegrityViolation() 9 | } 10 | -------------------------------------------------------------------------------- /src/lib/search/mmr.go: -------------------------------------------------------------------------------- 1 | package search 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "math" 7 | 8 | "github.com/viterin/vek/vek32" 9 | ) 10 | 11 | // pairwiseCosineSimilarity takes two matrices of vectors and returns a matrix, where 12 | // the value at [i][j] is the cosine similarity between the ith vector in matrix1 and 13 | // the jth vector in matrix2. 14 | func pairwiseCosineSimilarity(matrix1, matrix2 [][]float32) ([][]float32, error) { 15 | result := make([][]float32, len(matrix1)) 16 | for i, vec1 := range matrix1 { 17 | result[i] = make([]float32, len(matrix2)) 18 | for j, vec2 := range matrix2 { 19 | if len(vec1) != len(vec2) { 20 | return nil, fmt.Errorf( 21 | "vector lengths do not match: %d != %d", 22 | len(vec1), 23 | len(vec2), 24 | ) 25 | } 26 | result[i][j] = vek32.CosineSimilarity(vec1, vec2) 27 | } 28 | } 29 | return result, nil 30 | } 31 | 32 | // MaximalMarginalRelevance implements the Maximal Marginal Relevance algorithm.
33 | // It takes a query embedding, a list of embeddings, a lambda multiplier, and a 34 | // number of results to return, k. It greedily selects up to k indices, scoring each 35 | // candidate as lambdaMult*sim(query) - (1-lambdaMult)*max(sim(already selected)). 36 | // See https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf 37 | // Implementation borrowed from LangChain 38 | // https://github.com/langchain-ai/langchain/blob/4a2f0c51a116cc3141142ea55254e270afb6acde/libs/langchain/langchain/vectorstores/utils.py 39 | func MaximalMarginalRelevance( 40 | queryEmbedding []float32, 41 | embeddingList [][]float32, 42 | lambdaMult float32, 43 | k int, 44 | ) ([]int, error) { 45 | // if either k or the length of the embedding list is 0, return an empty list 46 | if min(k, len(embeddingList)) <= 0 { 47 | return []int{}, nil 48 | } 49 | 50 | // We expect the query embedding and the embeddings in the list to have the same width 51 | if len(queryEmbedding) != len(embeddingList[0]) { 52 | return []int{}, errors.New("query embedding width does not match embedding vector width") 53 | } 54 | 55 | similarityToQueryMatrix, err := pairwiseCosineSimilarity( 56 | [][]float32{queryEmbedding}, 57 | embeddingList, 58 | ) 59 | if err != nil { 60 | return nil, err 61 | } 62 | similarityToQuery := similarityToQueryMatrix[0] 63 | 64 | mostSimilar := vek32.ArgMax(similarityToQuery) 65 | idxs := []int{mostSimilar} 66 | selected := [][]float32{embeddingList[mostSimilar]} 67 | 68 | for len(idxs) < min(k, len(embeddingList)) { 69 | var bestScore float32 = -math.MaxFloat32 70 | idxToAdd := -1 71 | similarityToSelected, err := pairwiseCosineSimilarity(embeddingList, selected) 72 | if err != nil { 73 | return nil, err 74 | } 75 | 76 | for i, queryScore := range similarityToQuery { 77 | if contains(idxs, i) { 78 | continue 79 | } 80 | redundantScore := vek32.Max(similarityToSelected[i]) 81 | equationScore := lambdaMult*queryScore - (1-lambdaMult)*redundantScore 82 | if equationScore > bestScore { 83 | bestScore =
equationScore 84 | idxToAdd = i 85 | } 86 | } 87 | idxs = append(idxs, idxToAdd) 88 | selected = append(selected, embeddingList[idxToAdd]) 89 | } 90 | return idxs, nil 91 | } 92 | 93 | // contains returns true if the slice contains the value 94 | func contains(slice []int, val int) bool { 95 | for _, item := range slice { 96 | if item == val { 97 | return true 98 | } 99 | } 100 | return false 101 | } 102 | -------------------------------------------------------------------------------- /src/lib/search/rrf.go: -------------------------------------------------------------------------------- 1 | package search 2 | 3 | import ( 4 | "slices" 5 | 6 | "github.com/google/uuid" 7 | ) 8 | 9 | type Rankable interface { 10 | GetUUID() uuid.UUID 11 | } 12 | 13 | // ReciprocalRankFusion is a function that takes a list of result sets and returns a single list of results, 14 | // where each result is ranked by the sum of the reciprocal ranks of the results in each result set. 15 | func ReciprocalRankFusion[T Rankable](results [][]T) []T { 16 | rankings := make(map[uuid.UUID]float64) 17 | for _, resultSet := range results { 18 | for rank, result := range resultSet { 19 | rankings[result.GetUUID()] += 1.0 / float64(rank+1) //nolint:revive //declaring consts here would be silly 20 | } 21 | } 22 | 23 | uniqueResults := make(map[uuid.UUID]T) 24 | for _, resultSet := range results { 25 | for _, result := range resultSet { 26 | id := result.GetUUID() 27 | if _, exists := uniqueResults[id]; !exists { 28 | uniqueResults[id] = result 29 | } 30 | } 31 | } 32 | 33 | finalResults := make([]T, 0, len(uniqueResults)) 34 | for _, item := range uniqueResults { 35 | finalResults = append(finalResults, item) 36 | } 37 | 38 | slices.SortFunc(finalResults, func(a, b T) int { 39 | if rankings[a.GetUUID()] > rankings[b.GetUUID()] { 40 | return -1 41 | } else if rankings[a.GetUUID()] < rankings[b.GetUUID()] { 42 | return 1 43 | } 44 | return 0 45 | }) 46 | 47 | return finalResults 48 | } 49 | 
-------------------------------------------------------------------------------- /src/lib/telemetry/events.go: -------------------------------------------------------------------------------- 1 | package telemetry 2 | 3 | type Event string 4 | 5 | func (t Event) String() string { 6 | return string(t) 7 | } 8 | 9 | const ( 10 | Event_CreateUser Event = "user_create" 11 | Event_DeleteUser Event = "user_delete" 12 | Event_CreateFacts Event = "facts_create" 13 | Event_CreateMemoryMessage Event = "memory_create_message" 14 | Event_GetMemory Event = "memory_get" 15 | Event_CreateSession Event = "session_create" 16 | Event_DeleteSession Event = "session_delete" 17 | Event_SearchSessions Event = "sessions_search" 18 | 19 | Event_CEStart Event = "ce_start" 20 | Event_CEStop Event = "ce_stop" 21 | ) 22 | -------------------------------------------------------------------------------- /src/lib/telemetry/service.go: -------------------------------------------------------------------------------- 1 | package telemetry 2 | 3 | import "github.com/google/uuid" 4 | 5 | type CEEvent struct { 6 | Event Event `json:"event"` 7 | InstallID string `json:"install_id"` 8 | OrgName string `json:"org_name"` 9 | Data map[string]any `json:"data,omitempty"` 10 | } 11 | 12 | type Service interface { 13 | TrackEvent(req Request, event Event, metadata ...map[string]any) 14 | } 15 | 16 | // this interface is used to avoid needing to have a dependency on the models package. 
17 | type RequestCommon interface { 18 | GetProjectUUID() uuid.UUID 19 | GetRequestTokenType() string 20 | } 21 | 22 | var _instance Service 23 | 24 | func I() Service { 25 | return _instance 26 | } 27 | -------------------------------------------------------------------------------- /src/lib/telemetry/service_mock.go: -------------------------------------------------------------------------------- 1 | package telemetry 2 | 3 | func NewMockService() Service { 4 | return &mockService{} 5 | } 6 | 7 | type mockService struct{} 8 | 9 | func (*mockService) TrackEvent(_ Request, _ Event, _ ...map[string]any) { 10 | } 11 | 12 | func (*mockService) Close() { 13 | } 14 | -------------------------------------------------------------------------------- /src/lib/telemetry/telemetry_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package telemetry 3 | 4 | import ( 5 | "bytes" 6 | "encoding/json" 7 | "net/http" 8 | "os" 9 | "strings" 10 | "time" 11 | 12 | "github.com/google/uuid" 13 | 14 | "github.com/getzep/zep/lib/config" 15 | ) 16 | 17 | const installIDFilePermissions = 0o644 18 | 19 | type Request interface { 20 | RequestCommon 21 | } 22 | 23 | func Setup() { 24 | if _instance != nil { 25 | return 26 | } 27 | 28 | noop := config.Telemetry().Disabled 29 | 30 | var installID string 31 | 32 | if !noop { 33 | installID = getInstallID() 34 | } 35 | 36 | _instance = &service{ 37 | noop: noop, 38 | installID: installID, 39 | orgName: config.Telemetry().OrganizationName, 40 | } 41 | 42 | touchInstallIDFile() 43 | } 44 | 45 | func Shutdown() {} 46 | 47 | type service struct { 48 | noop bool 49 | installID string 50 | orgName string 51 | } 52 | 53 | func (s *service) TrackEvent(req Request, event Event, metadata ...map[string]any) { 54 | if s.noop { 55 | return 56 | } 57 | 58 | if !isCEEvent(event) { 59 | return 60 | } 61 | 62 | ev := CEEvent{ 63 | Event: event, 64 | } 65 | 66 | if s.installID != "" { 67 | ev.InstallID = s.installID 68 | 
} 69 | 70 | if s.orgName != "" { 71 | ev.OrgName = s.orgName 72 | } 73 | 74 | if len(metadata) > 0 { 75 | ev.Data = metadata[0] 76 | } 77 | 78 | b, _ := json.Marshal(ev) 79 | resp, err := http.Post(apiEndpoint, "application/json", bytes.NewBuffer(b)) 80 | if err != nil { 81 | // if we error, make it noop so we don't continue to try and error 82 | s.noop = true 83 | return 84 | } 85 | _ = resp.Body.Close() 86 | } 87 | 88 | const ( 89 | installIDFile = "/tmp/_zep" 90 | unknownID = "UNKNOWN" 91 | 92 | apiEndpoint = "https://api.getzep.com/api/v2/telemetry" 93 | ) 94 | 95 | func touchInstallIDFile() { 96 | go func() { 97 | t := time.NewTicker(1 * time.Hour) 98 | 99 | for { 100 | <-t.C 101 | 102 | if _, err := os.Stat(installIDFile); os.IsNotExist(err) { 103 | return 104 | } 105 | 106 | os.ReadFile(installIDFile) //nolint:errcheck,revive // we don't care if this fails 107 | } 108 | }() 109 | } 110 | 111 | func getInstallID() string { 112 | if _, err := os.Stat(installIDFile); os.IsNotExist(err) { 113 | return createInstallID() 114 | } 115 | 116 | b, err := os.ReadFile(installIDFile) 117 | if err != nil { 118 | return unknownID 119 | } 120 | 121 | return strings.TrimSpace(string(b)) 122 | } 123 | 124 | func createInstallID() string { 125 | id := uuid.New().String() 126 | 127 | err := os.WriteFile(installIDFile, []byte(id), installIDFilePermissions) //nolint:gosec // we want this to be readable by the user 128 | if err != nil { 129 | return unknownID 130 | } 131 | 132 | return id 133 | } 134 | 135 | func isCEEvent(event Event) bool { 136 | return event == Event_CEStart || event == Event_CEStop 137 | } 138 | -------------------------------------------------------------------------------- /src/lib/util/httputil/http_base.go: -------------------------------------------------------------------------------- 1 | package httputil 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "time" 11 | 12 |
"github.com/hashicorp/go-retryablehttp" 13 | ) 14 | 15 | const ( 16 | DefaultHTTPTimeout = 10 * time.Second 17 | DefaultMaxRetryAttempts = 3 18 | ) 19 | 20 | type HTTPBaser interface { 21 | Request(ctx context.Context, payload any) ([]byte, error) 22 | healthCheck(ctx context.Context) error 23 | } 24 | 25 | var _ HTTPBaser = &HTTPBase{} 26 | 27 | // HTTPBase is a MixIn for Models that have HTTP APIs and use Bearer tokens for authorization 28 | type HTTPBase struct { 29 | ApiURL string 30 | ApiKey string 31 | HealthURL string 32 | ServerName string 33 | RequestTimeOut time.Duration 34 | MaxRetryAttempts int 35 | } 36 | 37 | // request makes a POST request to the LLM's API endpoint. payload is marshalled to JSON and sent 38 | // as the request body. The response body is returned as a []byte. 39 | // Assumes the content type is application/json 40 | func (h *HTTPBase) Request(ctx context.Context, payload any) ([]byte, error) { 41 | var requestTimeout time.Duration 42 | if h.RequestTimeOut != 0 { 43 | requestTimeout = h.RequestTimeOut 44 | } else { 45 | requestTimeout = DefaultHTTPTimeout 46 | } 47 | 48 | var maxRetryAttempts int 49 | if h.MaxRetryAttempts != 0 { 50 | maxRetryAttempts = h.MaxRetryAttempts 51 | } else { 52 | maxRetryAttempts = DefaultMaxRetryAttempts 53 | } 54 | 55 | ctx, cancel := context.WithTimeout(ctx, requestTimeout) 56 | defer cancel() 57 | 58 | httpClient := NewRetryableHTTPClient( 59 | maxRetryAttempts, 60 | requestTimeout, 61 | IgnoreBadRequestRetryPolicy, 62 | h.ServerName, 63 | ) 64 | 65 | p, err := json.Marshal(payload) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | body := bytes.NewBuffer(p) 71 | 72 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.ApiURL, body) 73 | if err != nil { 74 | return nil, err 75 | } 76 | 77 | req.Header.Set("Authorization", "Bearer "+h.ApiKey) 78 | req.Header.Set("Content-Type", "application/json") 79 | 80 | resp, err := httpClient.Do(req) 81 | if err != nil { 82 | return nil, err 83 | 
} 84 | 85 | defer resp.Body.Close() 86 | 87 | if resp.StatusCode != http.StatusOK { 88 | return nil, fmt.Errorf( 89 | "error making POST request: %d - %s", 90 | resp.StatusCode, 91 | resp.Status, 92 | ) 93 | } 94 | 95 | rb, err := io.ReadAll(resp.Body) 96 | if err != nil { 97 | return nil, err 98 | } 99 | 100 | return rb, nil 101 | } 102 | 103 | func (h *HTTPBase) healthCheck(ctx context.Context) error { 104 | ctx, cancel := context.WithTimeout(ctx, DefaultHTTPTimeout) 105 | defer cancel() 106 | 107 | httpClient := NewRetryableHTTPClient( 108 | 1, 109 | DefaultHTTPTimeout, 110 | retryablehttp.DefaultRetryPolicy, 111 | h.ServerName, 112 | ) 113 | 114 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, h.HealthURL, http.NoBody) 115 | if err != nil { 116 | return err 117 | } 118 | 119 | resp, err := httpClient.Do(req) 120 | if err != nil { 121 | return err 122 | } 123 | 124 | defer resp.Body.Close() 125 | 126 | if resp.StatusCode != http.StatusOK { 127 | return fmt.Errorf("health check failed with status: %d", resp.StatusCode) 128 | } 129 | 130 | return nil 131 | } 132 | -------------------------------------------------------------------------------- /src/lib/util/httputil/http_base_mock.go: -------------------------------------------------------------------------------- 1 | package httputil 2 | 3 | import "context" 4 | 5 | type MockHTTPBase struct { 6 | ReturnPayload []byte 7 | } 8 | 9 | func (m *MockHTTPBase) Request(_ context.Context, _ any) ([]byte, error) { 10 | return m.ReturnPayload, nil 11 | } 12 | 13 | func (m *MockHTTPBase) healthCheck(_ context.Context) error { 14 | return nil 15 | } 16 | -------------------------------------------------------------------------------- /src/lib/util/httputil/retryable_http_client.go: -------------------------------------------------------------------------------- 1 | package httputil 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "net/http" 7 | "net/http/httptrace" 8 | "sync" 9 | "time" 10 | 11 | 
"github.com/hashicorp/go-retryablehttp" 12 | "go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace" 13 | "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" 14 | 15 | "github.com/getzep/zep/lib/logger" 16 | ) 17 | 18 | var httpClients sync.Map 19 | 20 | const ( 21 | DefaultRetryMax = 3 22 | DefaultTimeout = 5 * time.Second 23 | MaxIdleConns = 100 24 | MaxIdleConnsPerHost = 20 25 | IdleConnTimeout = 30 * time.Second 26 | ) 27 | 28 | type HTTPClient interface { 29 | Do(req *http.Request) (*http.Response, error) 30 | } 31 | 32 | // NewRetryableHTTPClient returns a new retryable HTTP client with the given retryMax and timeout. 33 | // The retryable HTTP transport is wrapped in an OpenTelemetry transport. 34 | func NewRetryableHTTPClient( 35 | retryMax int, 36 | timeout time.Duration, 37 | retryPolicy retryablehttp.CheckRetry, 38 | serverName string, 39 | ) *http.Client { 40 | client, ok := httpClients.Load(serverName) 41 | if ok { 42 | if httpClient, ok := client.(*http.Client); ok { 43 | return httpClient 44 | } 45 | } 46 | 47 | tlsConfig := &tls.Config{ 48 | MinVersion: tls.VersionTLS12, 49 | } 50 | if serverName != "" { 51 | tlsConfig.ServerName = serverName 52 | } 53 | 54 | httpClient := retryablehttp.Client{ 55 | HTTPClient: &http.Client{ 56 | Timeout: timeout, 57 | Transport: otelhttp.NewTransport(&http.Transport{ 58 | TLSClientConfig: tlsConfig, 59 | MaxIdleConns: MaxIdleConns, 60 | MaxIdleConnsPerHost: MaxIdleConnsPerHost, 61 | IdleConnTimeout: IdleConnTimeout, 62 | ResponseHeaderTimeout: timeout, 63 | DisableKeepAlives: false, 64 | }, otelhttp.WithClientTrace( 65 | func(ctx context.Context) *httptrace.ClientTrace { 66 | return otelhttptrace.NewClientTrace(ctx) 67 | }), 68 | ), 69 | }, 70 | Logger: logger.GetLogger(), 71 | RetryMax: retryMax, 72 | Backoff: retryablehttp.DefaultBackoff, 73 | CheckRetry: retryPolicy, 74 | } 75 | 76 | httpClients.Store(serverName, &httpClient) 77 | 78 | return httpClient.HTTPClient 79 | } 80 | 
81 | func IgnoreBadRequestRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { 82 | if resp != nil && resp.StatusCode != http.StatusOK { 83 | logger.Warn("Retry policy invoked with response", "status", resp.Status, "error", err) 84 | } 85 | 86 | // do not retry on context.Canceled or context.DeadlineExceeded 87 | if ctx.Err() != nil { 88 | return false, ctx.Err() 89 | } 90 | 91 | // Do not retry 400 errors as they're used by OpenAI to indicate maximum 92 | // context length exceeded 93 | if resp != nil && resp.StatusCode == http.StatusBadRequest { 94 | return false, err 95 | } 96 | 97 | shouldRetry, _ := retryablehttp.DefaultRetryPolicy(ctx, resp, err) 98 | return shouldRetry, nil 99 | } 100 | -------------------------------------------------------------------------------- /src/lib/util/utils.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "bytes" 5 | "math/rand" 6 | "reflect" 7 | "text/template" 8 | ) 9 | 10 | func ParsePrompt(promptTemplate string, data any) (string, error) { 11 | tmpl, err := template.New("prompt").Parse(promptTemplate) 12 | if err != nil { 13 | return "", err 14 | } 15 | 16 | var buf bytes.Buffer 17 | err = tmpl.Execute(&buf, data) 18 | if err != nil { 19 | return "", err 20 | } 21 | 22 | return buf.String(), nil 23 | } 24 | 25 | // StructToMap converts a struct to a map, recursively handling nested structs or lists of structs. 
26 | func StructToMap(item any) map[string]any { 27 | val := reflect.ValueOf(item) 28 | 29 | processSlice := func(val reflect.Value) []any { 30 | sliceOut := make([]any, val.Len()) 31 | for i := 0; i < val.Len(); i++ { 32 | sliceVal := val.Index(i) 33 | if sliceVal.Kind() == reflect.Struct { 34 | sliceOut[i] = StructToMap(sliceVal.Interface()) 35 | } else { 36 | sliceOut[i] = sliceVal.Interface() 37 | } 38 | } 39 | return sliceOut 40 | } 41 | 42 | switch val.Kind() { 43 | case reflect.Slice: 44 | return map[string]any{"data": processSlice(val)} 45 | case reflect.Ptr: 46 | val = val.Elem() 47 | if val.Kind() != reflect.Struct { 48 | return map[string]any{} 49 | } 50 | default: 51 | if val.Kind() != reflect.Struct { 52 | return map[string]any{} 53 | } 54 | } 55 | 56 | out := make(map[string]any) 57 | typeOfT := val.Type() 58 | 59 | for i := 0; i < val.NumField(); i++ { 60 | field := typeOfT.Field(i) 61 | value := val.Field(i) 62 | 63 | switch value.Kind() { 64 | case reflect.Struct: 65 | out[field.Name] = StructToMap(value.Interface()) 66 | case reflect.Slice: 67 | out[field.Name] = processSlice(value) 68 | default: 69 | out[field.Name] = value.Interface() 70 | } 71 | } 72 | 73 | return out 74 | } 75 | 76 | func MergeMaps[T any](maps ...map[string]T) map[string]T { 77 | result := make(map[string]T) 78 | for _, m := range maps { 79 | for k, v := range m { 80 | result[k] = v 81 | } 82 | } 83 | return result 84 | } 85 | 86 | // ShuffleSlice shuffles a slice in place. 87 | func ShuffleSlice[T any](a []T) { 88 | rand.Shuffle(len(a), func(i, j int) { a[i], a[j] = a[j], a[i] }) 89 | } 90 | 91 | func IsInterfaceNilValue(i any) bool { 92 | return i == nil || reflect.ValueOf(i).IsNil() 93 | } 94 | 95 | type ptrTypes interface { 96 | int | int32 | int64 | float32 | float64 | bool | string 97 | } 98 | 99 | func AsPtr[T ptrTypes](value T) *T { 100 | return &value 101 | } 102 | 103 | // SafelyDereference safely dereferences a pointer of any type T. 
104 | // It returns the value pointed to if the pointer is not nil, otherwise it returns the zero value of T. 105 | func SafelyDereference[T any](ptr *T) T { 106 | if ptr != nil { 107 | return *ptr // Dereference the pointer and return the value 108 | } 109 | var zero T // Initialize a variable with the zero value of type T 110 | return zero // Return the zero value if ptr is nil 111 | } 112 | -------------------------------------------------------------------------------- /src/lib/zerrors/errors.go: -------------------------------------------------------------------------------- 1 | package zerrors 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | type ErrorResponse struct { 9 | Message string `json:"message"` 10 | } 11 | 12 | /* NotFoundError */ 13 | 14 | var ErrNotFound = errors.New("not found") 15 | 16 | type NotFoundError struct { 17 | Resource string 18 | } 19 | 20 | func (e *NotFoundError) Error() string { 21 | return fmt.Sprintf("%s not found", e.Resource) 22 | } 23 | 24 | func (*NotFoundError) Unwrap() error { 25 | return ErrNotFound 26 | } 27 | 28 | func NewNotFoundError(resource string) error { 29 | return &NotFoundError{Resource: resource} 30 | } 31 | 32 | /* UnauthorizedError */ 33 | 34 | var ErrUnauthorized = errors.New("unauthorized") 35 | 36 | type UnauthorizedError struct { 37 | Message string 38 | } 39 | 40 | func (e *UnauthorizedError) Error() string { 41 | return fmt.Sprintf("unauthorized %s", e.Message) 42 | } 43 | 44 | func (*UnauthorizedError) Unwrap() error { 45 | return ErrUnauthorized 46 | } 47 | 48 | func NewUnauthorizedError(message string) error { 49 | return &UnauthorizedError{Message: message} 50 | } 51 | 52 | /* BadRequestError */ 53 | 54 | var ErrBadRequest = errors.New("bad request") 55 | 56 | type BadRequestError struct { 57 | Message string 58 | } 59 | 60 | func (e *BadRequestError) Error() string { 61 | return fmt.Sprintf("bad request: %s", e.Message) 62 | } 63 | 64 | func (*BadRequestError) Unwrap() error { 65 | return 
ErrBadRequest 66 | } 67 | 68 | func NewBadRequestError(message string) error { 69 | return &BadRequestError{Message: message} 70 | } 71 | 72 | /* CustomMessageInternalError */ 73 | 74 | var ErrInternalCustomMessage = errors.New("internal error") 75 | 76 | type CustomMessageInternalError struct { 77 | // User friendly message 78 | ExternalMessage string 79 | // Internal message, raw error message to be logged to sentry 80 | InternalMessage string 81 | } 82 | 83 | func (e *CustomMessageInternalError) Error() string { 84 | return e.ExternalMessage 85 | } 86 | 87 | func (*CustomMessageInternalError) Unwrap() error { 88 | return ErrInternalCustomMessage 89 | } 90 | 91 | func NewCustomMessageInternalError(externalMessage, internalMessage string) error { 92 | return &CustomMessageInternalError{ExternalMessage: externalMessage, InternalMessage: internalMessage} 93 | } 94 | 95 | var ErrDeprecated = errors.New("deprecated") 96 | 97 | type DeprecationError struct { 98 | Message string 99 | } 100 | 101 | func (e *DeprecationError) Error() string { 102 | return fmt.Sprintf("deprecation error: %s", e.Message) 103 | } 104 | 105 | func (*DeprecationError) Unwrap() error { 106 | return ErrDeprecated 107 | } 108 | 109 | func NewDeprecationError(message string) error { 110 | return &DeprecationError{Message: message} 111 | } 112 | 113 | var ErrLockAcquisitionFailed = errors.New("failed to acquire advisory lock") 114 | 115 | type AdvisoryLockError struct { 116 | Err error 117 | } 118 | 119 | func (e AdvisoryLockError) Error() string { 120 | if e.Err != nil { 121 | return fmt.Sprintf("failed to acquire advisory lock: %v", e.Err) 122 | } 123 | return ErrLockAcquisitionFailed.Error() 124 | } 125 | 126 | func (AdvisoryLockError) Unwrap() error { 127 | return ErrLockAcquisitionFailed 128 | } 129 | 130 | func NewAdvisoryLockError(err error) error { 131 | return &AdvisoryLockError{Err: err} 132 | } 133 | 134 | var ErrSessionEnded = errors.New("session ended") 135 | 136 | type 
SessionEndedError struct { 137 | Message string 138 | } 139 | 140 | func (e *SessionEndedError) Error() string { 141 | return fmt.Sprintf("session ended: %s", e.Message) 142 | } 143 | 144 | func (*SessionEndedError) Unwrap() error { 145 | return ErrSessionEnded 146 | } 147 | 148 | func NewSessionEndedError(message string) error { 149 | return &SessionEndedError{Message: message} 150 | } 151 | 152 | var ErrRepeatedPattern = errors.New("llm provider reports too many repeated characters") 153 | 154 | type RepeatedPatternError struct { 155 | Message string 156 | } 157 | 158 | func (e *RepeatedPatternError) Error() string { 159 | return fmt.Sprintf("repeated pattern: %s", e.Message) 160 | } 161 | 162 | func (*RepeatedPatternError) Unwrap() error { 163 | return ErrRepeatedPattern 164 | } 165 | 166 | func NewRepeatedPatternError(message string) error { 167 | return &RepeatedPatternError{Message: message} 168 | } 169 | -------------------------------------------------------------------------------- /src/lib/zerrors/storage.go: -------------------------------------------------------------------------------- 1 | package zerrors 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/uptrace/bun/driver/pgdriver" 8 | ) 9 | 10 | type StorageError struct { 11 | Message string 12 | OriginalError error 13 | } 14 | 15 | func (e *StorageError) Error() string { 16 | return fmt.Sprintf("storage error: %s (original error: %v)", e.Message, e.OriginalError) 17 | } 18 | 19 | func NewStorageError(message string, originalError error) *StorageError { 20 | return &StorageError{Message: message, OriginalError: originalError} 21 | } 22 | 23 | var ErrEmbeddingMismatch = errors.New("embedding width mismatch") 24 | 25 | type EmbeddingMismatchError struct { 26 | Message string 27 | OriginalError error 28 | } 29 | 30 | func (e *EmbeddingMismatchError) Error() string { 31 | return fmt.Sprintf( 32 | "embedding width mismatch. 
please ensure that the embeddings "+ 33 | "you have configured in the zep config are the same width as those "+ 34 | "you are generating. (original error: %v)", 35 | e.OriginalError, 36 | ) 37 | } 38 | 39 | func (*EmbeddingMismatchError) Unwrap() error { 40 | return ErrEmbeddingMismatch 41 | } 42 | 43 | func NewEmbeddingMismatchError( 44 | originalError error, 45 | ) *EmbeddingMismatchError { 46 | return &EmbeddingMismatchError{ 47 | OriginalError: originalError, 48 | } 49 | } 50 | 51 | func CheckForIntegrityViolationError(err error, integrityErrorMessage, generalErrorMessage string) error { 52 | var pgDriverError pgdriver.Error 53 | if errors.As(err, &pgDriverError) && pgDriverError.IntegrityViolation() { 54 | return NewBadRequestError(integrityErrorMessage) 55 | } 56 | return fmt.Errorf("%s %w", generalErrorMessage, err) 57 | } 58 | -------------------------------------------------------------------------------- /src/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | 11 | "github.com/getzep/zep/api" 12 | "github.com/getzep/zep/lib/config" 13 | "github.com/getzep/zep/lib/logger" 14 | "github.com/getzep/zep/lib/observability" 15 | "github.com/getzep/zep/lib/telemetry" 16 | "github.com/getzep/zep/models" 17 | ) 18 | 19 | func main() { 20 | config.Load() 21 | 22 | logger.InitDefaultLogger() 23 | 24 | as := newAppState() 25 | 26 | srv, err := api.Create(as) 27 | if err != nil { 28 | logger.Panic("Failed to create server", "error", err) 29 | } 30 | 31 | done := setupSignalHandler(as, srv) 32 | 33 | err = srv.ListenAndServe() 34 | if err != nil && !errors.Is(err, http.ErrServerClosed) { 35 | logger.Panic("Failed to start server", "error", err) 36 | } 37 | 38 | <-done 39 | } 40 | 41 | func setupSignalHandler(as *models.AppState, srv *http.Server) chan struct{} { 42 | done := make(chan struct{}, 1) 43 | 44 | 
signalCh := make(chan os.Signal, 1) 45 | signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM) 46 | 47 | go func() { 48 | <-signalCh 49 | 50 | // the order of these calls is important and intentional 51 | // shutting down the server and task router first stops all work 52 | // then we shut down ancillary services 53 | // then we close database connections 54 | // finally close observability. this is last to ensure we can capture 55 | // any errors that occurred during shutdown. 56 | 57 | // ignoring the error here because we're going to shutdown anyways. 58 | // the error here is irrelevant as it is not actionable and very unlikely to 59 | // happen. 60 | srv.Shutdown(context.Background()) 61 | 62 | if err := as.TaskRouter.Close(); err != nil { 63 | logger.Error("Error closing task router", "error", err) 64 | } 65 | 66 | telemetry.Shutdown() 67 | 68 | gracefulShutdown() 69 | 70 | if err := as.DB.Close(); err != nil { 71 | logger.Error("Error closing database connection", "error", err) 72 | } 73 | 74 | observability.Shutdown() 75 | 76 | done <- struct{}{} 77 | }() 78 | 79 | return done 80 | } 81 | -------------------------------------------------------------------------------- /src/models/app_state_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package models 3 | 4 | type AppState struct { 5 | AppStateCommon 6 | } 7 | -------------------------------------------------------------------------------- /src/models/fact_common.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/google/uuid" 7 | ) 8 | 9 | type Fact struct { 10 | UUID uuid.UUID `json:"uuid"` 11 | CreatedAt time.Time `json:"created_at"` 12 | Fact string `json:"fact"` 13 | Rating *float64 `json:"rating,omitempty"` 14 | } 15 | -------------------------------------------------------------------------------- /src/models/memory_ce.go: 
-------------------------------------------------------------------------------- 1 | 2 | package models 3 | 4 | type Memory struct { 5 | MemoryCommon 6 | } 7 | 8 | type MemoryFilterOptions struct{} 9 | 10 | func (m *Message) MessageTask(rs *RequestState, memory Memory) MessageTask { 11 | return MessageTask{ 12 | MessageTaskCommon: MessageTaskCommon{ 13 | TaskState: rs.GetTaskState(m.UUID), 14 | }, 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/models/memory_common.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "time" 7 | 8 | "github.com/google/uuid" 9 | ) 10 | 11 | type RoleType string 12 | 13 | const ( 14 | NoRole RoleType = "norole" 15 | SystemRole RoleType = "system" 16 | AssistantRole RoleType = "assistant" 17 | UserRole RoleType = "user" 18 | FunctionRole RoleType = "function" 19 | ToolRole RoleType = "tool" 20 | ) 21 | 22 | var validRoleTypes = map[string]RoleType{ 23 | string(NoRole): NoRole, 24 | string(SystemRole): SystemRole, 25 | string(AssistantRole): AssistantRole, 26 | string(UserRole): UserRole, 27 | string(FunctionRole): FunctionRole, 28 | string(ToolRole): ToolRole, 29 | } 30 | 31 | func (rt *RoleType) UnmarshalJSON(b []byte) error { 32 | str := strings.Trim(string(b), "\"") 33 | 34 | if str == "" { 35 | *rt = NoRole 36 | return nil 37 | } 38 | 39 | value, ok := validRoleTypes[str] 40 | if !ok { 41 | return fmt.Errorf("invalid RoleType: %v", str) 42 | } 43 | 44 | *rt = value 45 | return nil 46 | } 47 | 48 | func (rt RoleType) MarshalJSON() ([]byte, error) { 49 | return []byte(fmt.Sprintf("%q", rt)), nil 50 | } 51 | 52 | // Message Represents a message in a conversation. 53 | type Message struct { 54 | // The unique identifier of the message. 55 | UUID uuid.UUID `json:"uuid"` 56 | // The timestamp of when the message was created. 
57 | CreatedAt time.Time `json:"created_at"` 58 | // The timestamp of when the message was last updated. 59 | UpdatedAt time.Time `json:"updated_at"` 60 | // The role of the sender of the message (e.g., "user", "assistant"). 61 | Role string `json:"role"` 62 | // The type of the role (e.g., "user", "system"). 63 | RoleType RoleType `json:"role_type,omitempty"` 64 | // The content of the message. 65 | Content string `json:"content"` 66 | // The metadata associated with the message. 67 | Metadata map[string]any `json:"metadata,omitempty"` 68 | // The number of tokens in the message. 69 | TokenCount int `json:"token_count"` 70 | } 71 | 72 | type MessageMetadataUpdate struct { 73 | // The metadata to update 74 | Metadata map[string]any `json:"metadata" validate:"required"` 75 | } 76 | 77 | type MessageListResponse struct { 78 | // A list of message objects. 79 | Messages []Message `json:"messages"` 80 | // The total number of messages. 81 | TotalCount int `json:"total_count"` 82 | // The number of messages returned. 83 | RowCount int `json:"row_count"` 84 | } 85 | 86 | type MemoryCommon struct { 87 | // A list of message objects, where each message contains a role and content. 88 | Messages []Message `json:"messages"` 89 | RelevantFacts []Fact `json:"relevant_facts"` 90 | // A dictionary containing metadata associated with the memory. 
91 | Metadata map[string]any `json:"metadata,omitempty"` 92 | } 93 | 94 | type MemoryFilterOption = FilterOption[MemoryFilterOptions] 95 | -------------------------------------------------------------------------------- /src/models/memorystore_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package models 3 | 4 | type MemoryStore interface { 5 | MemoryStoreCommon 6 | } 7 | 8 | type SessionStorer interface { 9 | SessionStorerCommon 10 | } 11 | 12 | type MessageStorer interface { 13 | MessageStorerCommon 14 | } 15 | 16 | type MemoryStorer interface { 17 | MemoryStorerCommon 18 | } 19 | -------------------------------------------------------------------------------- /src/models/memorystore_common.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/google/uuid" 7 | ) 8 | 9 | type MemoryStoreCommon interface { 10 | MemoryStorer 11 | MessageStorer 12 | SessionStorer 13 | PurgeDeleted(ctx context.Context, schemaName string) error 14 | PutMessages(ctx context.Context, sessionID string, messages []Message) ([]Message, error) 15 | } 16 | 17 | type SessionStorerCommon interface { 18 | CreateSession(ctx context.Context, session *CreateSessionRequest) (*Session, error) 19 | GetSession(ctx context.Context, sessionID string) (*Session, error) 20 | UpdateSession(ctx context.Context, session *UpdateSessionRequest, isPrivileged bool) (*Session, error) 21 | DeleteSession(ctx context.Context, sessionID string) error 22 | ListSessions(ctx context.Context, cursor int64, limit int) ([]*Session, error) 23 | ListSessionsOrdered( 24 | ctx context.Context, 25 | pageNumber, pageSize int, 26 | orderedBy string, 27 | asc bool, 28 | ) (*SessionListResponse, error) 29 | } 30 | 31 | type MessageStorerCommon interface { 32 | GetMessagesLastN(ctx context.Context, sessionID string, lastNMessages int, beforeUUID uuid.UUID) ([]Message, error) 33 | 
GetMessagesByUUID(ctx context.Context, sessionID string, uuids []uuid.UUID) ([]Message, error) 34 | GetMessageList(ctx context.Context, sessionID string, pageNumber, pageSize int) (*MessageListResponse, error) 35 | UpdateMessages(ctx context.Context, sessionID string, messages []Message, isPrivileged, includeContent bool) error 36 | } 37 | 38 | type MemoryStorerCommon interface { 39 | GetMemory(ctx context.Context, sessionID string, lastNmessages int, opts ...MemoryFilterOption) (*Memory, error) 40 | // PutMemory stores a Memory for a given sessionID. If the SessionID doesn't exist, a new one is created. 41 | PutMemory(ctx context.Context, sessionID string, memoryMessages *Memory, skipNotify bool) error // skipNotify is used to prevent loops when calling NotifyExtractors. 42 | SearchSessions(ctx context.Context, query *SessionSearchQuery, limit int) (*SessionSearchResponse, error) 43 | } 44 | -------------------------------------------------------------------------------- /src/models/options.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type FilterOption[T any] func(*T) 4 | 5 | func ApplyFilterOptions[T any](opts ...FilterOption[T]) T { 6 | var o T 7 | for _, opt := range opts { 8 | opt(&o) 9 | } 10 | return o 11 | } 12 | -------------------------------------------------------------------------------- /src/models/projectsetting.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import "github.com/google/uuid" 4 | 5 | type ProjectSettings struct { 6 | UUID uuid.UUID `json:"uuid"` 7 | } 8 | -------------------------------------------------------------------------------- /src/models/request_state_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package models 3 | 4 | import "github.com/google/uuid" 5 | 6 | type RequestState struct { 7 | RequestStateCommon 8 | } 9 | 10 | func (rs *RequestState) 
GetTaskState(itemUUID uuid.UUID, projectUUIDOverride ...uuid.UUID) TaskState { 11 | projectUUID := rs.ProjectUUID 12 | if len(projectUUIDOverride) > 0 { 13 | projectUUID = projectUUIDOverride[0] 14 | } 15 | 16 | return TaskState{ 17 | TaskStateCommon: TaskStateCommon{ 18 | UUID: itemUUID, 19 | ProjectUUID: projectUUID, 20 | SchemaName: rs.SchemaName, 21 | }, 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/models/search_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package models 3 | 4 | type SessionSearchResult struct { 5 | SessionSearchResultCommon 6 | } 7 | 8 | type SessionSearchQuery struct { 9 | SessionSearchQueryCommon 10 | } 11 | 12 | func (s SessionSearchQuery) BreadcrumbFields() map[string]any { 13 | return map[string]any{} 14 | } 15 | -------------------------------------------------------------------------------- /src/models/search_common.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type SessionSearchQueryCommon struct { 4 | // The search text. 5 | Text string `json:"text"` 6 | // User ID used to determine which sessions to search. Required on Community Edition. 
7 | UserID string `json:"user_id,omitempty"` 8 | 9 | // the session ids to search 10 | SessionIDs []string `json:"session_ids,omitempty"` 11 | } 12 | 13 | type SessionSearchResultCommon struct { 14 | Fact *Fact `json:"fact"` 15 | Embedding []float32 `json:"-" swaggerignore:"true"` 16 | } 17 | 18 | type SessionSearchRequest struct { 19 | Query *SessionSearchQuery `json:"query"` 20 | } 21 | 22 | type SessionSearchResponse struct { 23 | Results []SessionSearchResult `json:"results"` 24 | } 25 | -------------------------------------------------------------------------------- /src/models/session_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package models 3 | 4 | type Session struct { 5 | SessionCommon 6 | } 7 | 8 | type SessionStore interface { 9 | SessionStoreCommon 10 | } 11 | 12 | type CreateSessionRequest struct { 13 | CreateSessionRequestCommon 14 | } 15 | 16 | type UpdateSessionRequest struct { 17 | UpdateSessionRequestCommon 18 | } 19 | -------------------------------------------------------------------------------- /src/models/session_common.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/google/uuid" 8 | ) 9 | 10 | type SessionCommon struct { 11 | UUID uuid.UUID `json:"uuid"` 12 | ID int64 `json:"id"` 13 | CreatedAt time.Time `json:"created_at"` 14 | UpdatedAt time.Time `json:"updated_at"` 15 | DeletedAt *time.Time `json:"deleted_at"` 16 | EndedAt *time.Time `json:"ended_at"` 17 | SessionID string `json:"session_id"` 18 | Metadata map[string]any `json:"metadata"` 19 | // Must be a pointer to allow for null values 20 | UserID *string `json:"user_id"` 21 | ProjectUUID uuid.UUID `json:"project_uuid"` 22 | } 23 | 24 | type SessionListResponse struct { 25 | Sessions []*Session `json:"sessions"` 26 | TotalCount int `json:"total_count"` 27 | RowCount int `json:"response_count"` 28 | } 29 | 30 | type 
CreateSessionRequestCommon struct { 31 | // The unique identifier of the session. 32 | SessionID string `json:"session_id" validate:"required"` 33 | // The unique identifier of the user associated with the session 34 | UserID *string `json:"user_id"` 35 | // The metadata associated with the session. 36 | Metadata map[string]any `json:"metadata"` 37 | } 38 | 39 | type UpdateSessionRequestCommon struct { 40 | SessionID string `json:"session_id" swaggerignore:"true"` 41 | // The metadata to update 42 | Metadata map[string]any `json:"metadata" validate:"required"` 43 | } 44 | 45 | type SessionStoreCommon interface { 46 | Update(ctx context.Context, session *UpdateSessionRequest, isPrivileged bool) (*Session, error) 47 | Create(ctx context.Context, session *CreateSessionRequest) (*Session, error) 48 | Get(ctx context.Context, sessionID string) (*Session, error) 49 | Delete(ctx context.Context, sessionID string) error 50 | ListAll(ctx context.Context, cursor int64, limit int) ([]*Session, error) 51 | ListAllOrdered( 52 | ctx context.Context, 53 | pageNumber int, 54 | pageSize int, 55 | orderBy string, 56 | asc bool, 57 | ) (*SessionListResponse, error) 58 | } 59 | -------------------------------------------------------------------------------- /src/models/state.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "github.com/google/uuid" 5 | 6 | "github.com/getzep/zep/lib/enablement" 7 | "github.com/getzep/zep/lib/pg" 8 | ) 9 | 10 | type AppStateCommon struct { 11 | DB pg.Connection 12 | TaskRouter TaskRouter 13 | TaskPublisher TaskPublisher 14 | } 15 | 16 | type RequestStateCommon struct { 17 | Memories MemoryStore 18 | Users UserStore 19 | Sessions SessionStore 20 | 21 | ProjectUUID uuid.UUID 22 | SessionUUID uuid.UUID 23 | 24 | EnablementProfile enablement.Profile 25 | 26 | SchemaName string 27 | RequestTokenType string 28 | } 29 | 30 | func (rs *RequestState) GetProjectUUID() uuid.UUID { 31 | 
return rs.ProjectUUID 32 | } 33 | 34 | func (rs *RequestState) GetRequestTokenType() string { 35 | return rs.RequestTokenType 36 | } 37 | -------------------------------------------------------------------------------- /src/models/tasks_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package models 3 | 4 | type TaskPublisher interface { 5 | TaskPublisherCommon 6 | } 7 | 8 | type MessageTask struct { 9 | MessageTaskCommon 10 | } 11 | 12 | type TaskState struct { 13 | TaskStateCommon 14 | } 15 | -------------------------------------------------------------------------------- /src/models/tasks_common.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/ThreeDotsLabs/watermill/message" 7 | "github.com/google/uuid" 8 | ) 9 | 10 | type TaskTopic string 11 | 12 | const ( 13 | MessageEmbedderTopic TaskTopic = "message_embedder" 14 | PurgeDeletedResourcesTopic TaskTopic = "purge_deleted" 15 | ) 16 | 17 | type Task interface { 18 | Execute(ctx context.Context, event *message.Message) error 19 | HandleError(msgId string, err error) 20 | } 21 | 22 | type TaskRouter interface { 23 | Run(ctx context.Context) error 24 | AddTask(ctx context.Context, name string, taskType TaskTopic, task Task, numOfSubscribers int) 25 | AddTaskWithMultiplePools(ctx context.Context, name string, taskType TaskTopic, task Task, numberOfPools int) error 26 | RunHandlers(ctx context.Context) error 27 | IsRunning() bool 28 | Close() error 29 | } 30 | 31 | type TaskPublisherCommon interface { 32 | Publish(ctx context.Context, taskType TaskTopic, metadata map[string]string, payload any) error 33 | PublishMessage(ctx context.Context, metadata map[string]string, payload []MessageTask) error 34 | Close() error 35 | } 36 | 37 | type MessageTaskCommon struct { 38 | TaskState 39 | } 40 | 41 | type TaskStateCommon struct { 42 | UUID uuid.UUID `json:"uuid"` 43 | 
ProjectUUID uuid.UUID `json:"project_uuid"` 44 | SchemaName string `json:"schema_name"` 45 | } 46 | 47 | func (ts *TaskStateCommon) LogData(data ...any) []any { 48 | if ts.UUID != uuid.Nil { 49 | data = append(data, "uuid", ts.UUID) 50 | } 51 | 52 | if ts.ProjectUUID != uuid.Nil { 53 | data = append(data, "project_uuid", ts.ProjectUUID) 54 | } 55 | 56 | if ts.SchemaName != "" { 57 | data = append(data, "schema_name", ts.SchemaName) 58 | } 59 | 60 | return data 61 | } 62 | -------------------------------------------------------------------------------- /src/models/userstore.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/google/uuid" 8 | ) 9 | 10 | type User struct { 11 | UUID uuid.UUID `json:"uuid"` 12 | ID int64 `json:"id"` 13 | CreatedAt time.Time `json:"created_at"` 14 | UpdatedAt time.Time `json:"updated_at"` 15 | DeletedAt *time.Time `json:"deleted_at"` 16 | UserID string `json:"user_id"` 17 | Email string `json:"email,omitempty"` 18 | FirstName string `json:"first_name,omitempty"` 19 | LastName string `json:"last_name,omitempty"` 20 | ProjectUUID uuid.UUID `json:"project_uuid"` 21 | Metadata map[string]any `json:"metadata,omitempty"` 22 | SessionCount int `json:"session_count,omitempty"` 23 | } 24 | 25 | type UserListResponse struct { 26 | Users []*User `json:"users"` 27 | TotalCount int `json:"total_count"` 28 | RowCount int `json:"row_count"` 29 | } 30 | 31 | type CreateUserRequest struct { 32 | // The unique identifier of the user. 33 | UserID string `json:"user_id"` 34 | // The email address of the user. 35 | Email string `json:"email"` 36 | // The first name of the user. 37 | FirstName string `json:"first_name"` 38 | // The last name of the user. 39 | LastName string `json:"last_name"` 40 | // The metadata associated with the user. 
41 | Metadata map[string]any `json:"metadata"` 42 | } 43 | 44 | type UpdateUserRequest struct { 45 | UUID uuid.UUID `json:"uuid" swaggerignore:"true"` 46 | UserID string `json:"user_id" swaggerignore:"true"` 47 | // The email address of the user. 48 | Email string `json:"email"` 49 | // The first name of the user. 50 | FirstName string `json:"first_name"` 51 | // The last name of the user. 52 | LastName string `json:"last_name"` 53 | // The metadata to update 54 | Metadata map[string]any `json:"metadata"` 55 | } 56 | 57 | type UserStore interface { 58 | Create(ctx context.Context, user *CreateUserRequest) (*User, error) 59 | Get(ctx context.Context, userID string) (*User, error) 60 | Update(ctx context.Context, user *UpdateUserRequest, isPrivileged bool) (*User, error) 61 | Delete(ctx context.Context, userID string) error 62 | GetSessionsForUser(ctx context.Context, userID string) ([]*Session, error) 63 | ListAll(ctx context.Context, cursor int64, limit int) ([]*User, error) 64 | ListAllOrdered(ctx context.Context, 65 | pageNumber int, 66 | pageSize int, 67 | orderBy string, 68 | asc bool, 69 | ) (*UserListResponse, error) 70 | } 71 | -------------------------------------------------------------------------------- /src/setup_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package main 3 | 4 | import ( 5 | "context" 6 | "log" 7 | 8 | "github.com/getzep/zep/lib/config" 9 | "github.com/getzep/zep/lib/graphiti" 10 | "github.com/getzep/zep/lib/telemetry" 11 | "github.com/getzep/zep/models" 12 | "github.com/getzep/zep/store" 13 | ) 14 | 15 | func setup(as *models.AppState) { 16 | graphiti.Setup() 17 | 18 | telemetry.I().TrackEvent(nil, telemetry.Event_CEStart) 19 | } 20 | 21 | func gracefulShutdown() { 22 | telemetry.I().TrackEvent(nil, telemetry.Event_CEStop) 23 | } 24 | 25 | func initializeDB(ctx context.Context, as *models.AppState) { 26 | err := store.MigrateSchema(ctx, as.DB, config.Postgres().SchemaName) 27 | if err 
!= nil { 28 | log.Fatalf("Failed to migrate schema: %v", err) //nolint:revive // this is only called from main 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/state.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/getzep/zep/lib/pg" 7 | "github.com/getzep/zep/lib/telemetry" 8 | "github.com/getzep/zep/models" 9 | ) 10 | 11 | func newAppState() *models.AppState { 12 | as := &models.AppState{} 13 | 14 | as.DB = pg.NewConnection() 15 | 16 | initializeDB(context.Background(), as) 17 | 18 | telemetry.Setup() 19 | 20 | setup(as) 21 | 22 | return as 23 | } 24 | -------------------------------------------------------------------------------- /src/store/db_utils_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package store 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | 8 | "github.com/getzep/zep/lib/logger" 9 | "github.com/getzep/zep/lib/pg" 10 | ) 11 | 12 | // purgeDeletedResources hard-deletes soft-deleted resources from the database. It is called when a user or a session is deleted. 13 | // On Cloud, a PurgeDeletedResources task is used instead. 14 | func purgeDeletedResources(ctx context.Context, db pg.Connection) error { 15 | logger.Debug("purging memory store") 16 | 17 | for _, schema := range messageTableList { 18 | logger.Debug("purging schema", "schema", fmt.Sprintf("%T", schema)) 19 | _, err := db.NewDelete(). 20 | Model(schema). 21 | WhereDeleted(). 22 | ForceDelete(). 23 | Exec(ctx) 24 | if err != nil { 25 | return fmt.Errorf("error purging rows from %T: %w", schema, err) 26 | } 27 | } 28 | 29 | // Vacuum database post-purge. This avoids issues with HNSW indexes 30 | // after deleting a large number of rows.
31 | // https://github.com/pgvector/pgvector/issues/244 32 | _, err := db.ExecContext(ctx, "VACUUM ANALYZE") 33 | if err != nil { 34 | return fmt.Errorf("error vacuuming database: %w", err) 35 | } 36 | 37 | logger.Info("completed purging store") 38 | 39 | return nil 40 | } 41 | -------------------------------------------------------------------------------- /src/store/memory_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package store 3 | 4 | import ( 5 | "context" 6 | "errors" 7 | 8 | "github.com/getzep/zep/lib/graphiti" 9 | "github.com/getzep/zep/lib/telemetry" 10 | "github.com/getzep/zep/models" 11 | ) 12 | 13 | const maxMessagesForFactRetrieval = 4 // 2 chat turns 14 | 15 | func (dao *memoryDAO) _get( 16 | ctx context.Context, 17 | session *models.Session, 18 | messages []models.Message, 19 | _ models.MemoryFilterOptions, 20 | ) (*models.Memory, error) { 21 | mForRetrieval := messages 22 | if len(messages) > maxMessagesForFactRetrieval { 23 | mForRetrieval = messages[len(messages)-maxMessagesForFactRetrieval:] 24 | } 25 | var result models.Memory 26 | groupID := session.SessionID 27 | if session.UserID != nil { 28 | groupID = *session.UserID 29 | } 30 | memory, err := graphiti.I().GetMemory( 31 | ctx, 32 | graphiti.GetMemoryRequest{ 33 | GroupID: groupID, 34 | MaxFacts: 5, 35 | Messages: mForRetrieval, 36 | }, 37 | ) 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | result.Messages = messages 43 | var memoryFacts []models.Fact 44 | for _, fact := range memory.Facts { 45 | createdAt := fact.CreatedAt 46 | if fact.ValidAt != nil { 47 | createdAt = *fact.ValidAt 48 | } 49 | memoryFacts = append(memoryFacts, models.Fact{ 50 | Fact: fact.Fact, 51 | UUID: fact.UUID, 52 | CreatedAt: createdAt, 53 | }) 54 | } 55 | result.RelevantFacts = memoryFacts 56 | return &result, nil 57 | } 58 | 59 | func (dao *memoryDAO) _initializeProcessingMemory( 60 | ctx context.Context, 61 | session *models.Session, 62 | 
memoryMessages *models.Memory, 63 | ) error { 64 | err := graphiti.I().PutMemory(ctx, session.SessionID, memoryMessages.Messages, true) 65 | if err != nil { 66 | return err 67 | } 68 | if session.UserID != nil { 69 | err = graphiti.I().PutMemory(ctx, *session.UserID, memoryMessages.Messages, true) 70 | } 71 | return err 72 | } 73 | 74 | func (dao *memoryDAO) _searchSessions(ctx context.Context, query *models.SessionSearchQuery, limit int) (*models.SessionSearchResponse, error) { 75 | if query == nil { 76 | return nil, errors.New("nil query received") 77 | } 78 | var groupIDs []string 79 | if query.UserID != "" { 80 | groupIDs = append(groupIDs, query.UserID) 81 | } 82 | if len(query.SessionIDs) > 0 { 83 | groupIDs = append(groupIDs, query.SessionIDs...) 84 | } 85 | result, err := graphiti.I().Search( 86 | ctx, 87 | graphiti.SearchRequest{ 88 | GroupIDs: groupIDs, 89 | Text: query.Text, 90 | MaxFacts: limit, 91 | }, 92 | ) 93 | if err != nil { 94 | return nil, err 95 | } 96 | 97 | var searchResults []models.SessionSearchResult 98 | 99 | for _, r := range result.Facts { 100 | createdAt := r.CreatedAt 101 | if r.ValidAt != nil { 102 | createdAt = *r.ValidAt 103 | } 104 | searchResults = append(searchResults, models.SessionSearchResult{ 105 | SessionSearchResultCommon: models.SessionSearchResultCommon{ 106 | Fact: &models.Fact{ 107 | Fact: r.Fact, 108 | UUID: r.UUID, 109 | CreatedAt: createdAt, 110 | }, 111 | }, 112 | }) 113 | } 114 | 115 | telemetry.I().TrackEvent(dao.requestState, telemetry.Event_SearchSessions, map[string]any{ 116 | "result_count": len(searchResults), 117 | "query_text_len": len(query.Text), 118 | }) 119 | 120 | return &models.SessionSearchResponse{ 121 | Results: searchResults, 122 | }, nil 123 | } 124 | -------------------------------------------------------------------------------- /src/store/memory_common.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "context" 5 | 
"errors" 6 | "fmt" 7 | "unicode/utf8" 8 | 9 | "github.com/getzep/zep/lib/enablement" 10 | "github.com/getzep/zep/lib/telemetry" 11 | "github.com/getzep/zep/lib/zerrors" 12 | "github.com/getzep/zep/models" 13 | "github.com/google/uuid" 14 | ) 15 | 16 | const defaultLastNMessages = 4 17 | 18 | func newMemoryDAO(appState *models.AppState, requestState *models.RequestState, sessionID string, lastNMessages int) *memoryDAO { 19 | return &memoryDAO{ 20 | appState: appState, 21 | requestState: requestState, 22 | sessionID: sessionID, 23 | lastNMessages: lastNMessages, 24 | } 25 | } 26 | 27 | // memoryDAO is a data access object for Memory. A Memory is an overlay over Messages. It is used to 28 | // retrieve a set of messages for a given sessionID, to store a new set of messages from 29 | // a chat client, and to search for messages. 30 | type memoryDAO struct { 31 | appState *models.AppState 32 | requestState *models.RequestState 33 | sessionID string 34 | lastNMessages int 35 | } 36 | 37 | func (dao *memoryDAO) Get(ctx context.Context, opts ...models.MemoryFilterOption) (*models.Memory, error) { 38 | if dao.lastNMessages < 0 { 39 | return nil, errors.New("lastNMessages cannot be negative") 40 | } 41 | 42 | memoryFilterOptions := models.ApplyFilterOptions(opts...) 
43 | 44 | messageDAO := newMessageDAO(dao.appState, dao.requestState, dao.sessionID) 45 | 46 | // fetch at least defaultLastNMessages messages; the full fetched set is passed to _get for fact retrieval 47 | mCnt := dao.lastNMessages 48 | if mCnt < defaultLastNMessages { 49 | mCnt = defaultLastNMessages 50 | } 51 | 52 | messages, err := messageDAO.GetLastN(ctx, mCnt, uuid.Nil) 53 | if err != nil { 54 | return nil, fmt.Errorf("failed to get messages: %w", err) 55 | } 56 | 57 | // return early if there are no messages 58 | if len(messages) == 0 { 59 | return &models.Memory{ 60 | MemoryCommon: models.MemoryCommon{ 61 | Messages: messages, 62 | }, 63 | }, nil 64 | } 65 | 66 | session, err := dao.requestState.Sessions.Get(ctx, dao.sessionID) 67 | if err != nil { 68 | return nil, fmt.Errorf("failed to get session: %w", err) 69 | } 70 | 71 | // return at most dao.lastNMessages messages as chat history 72 | mChatHistory := messages 73 | if len(messages) > dao.lastNMessages { 74 | mChatHistory = messages[len(messages)-dao.lastNMessages:] 75 | } 76 | 77 | result, err := dao._get(ctx, session, messages, memoryFilterOptions) 78 | if err != nil { 79 | return nil, err 80 | } 81 | 82 | telemetry.I().TrackEvent(dao.requestState, telemetry.Event_GetMemory, map[string]any{ 83 | "message_count": len(mChatHistory), 84 | }) 85 | 86 | result.MemoryCommon.Messages = mChatHistory 87 | 88 | return result, nil 89 | } 90 | 91 | // Create stores a Memory for a given sessionID. If the SessionID doesn't exist, a new one is created. 92 | // If skipProcessing is true, the new messages will not be published to the message queue router. 93 | func (dao *memoryDAO) Create(ctx context.Context, memoryMessages *models.Memory, skipProcessing bool) error { 94 | sessionStore := NewSessionDAO(dao.appState, dao.requestState) 95 | 96 | // Try to update the session first. If no rows are affected, create a new session.
97 | session, err := sessionStore.Update(ctx, &models.UpdateSessionRequest{ 98 | UpdateSessionRequestCommon: models.UpdateSessionRequestCommon{ 99 | SessionID: dao.sessionID, 100 | }, 101 | }, false) 102 | if err != nil { 103 | if !errors.Is(err, zerrors.ErrNotFound) { 104 | return err 105 | } 106 | session, err = sessionStore.Create(ctx, &models.CreateSessionRequest{ 107 | CreateSessionRequestCommon: models.CreateSessionRequestCommon{ 108 | SessionID: dao.sessionID, 109 | }, 110 | }) 111 | if err != nil { 112 | return err 113 | } 114 | } 115 | 116 | if session.EndedAt != nil { 117 | return zerrors.NewSessionEndedError("session has ended") 118 | } 119 | 120 | messageDAO := newMessageDAO(dao.appState, dao.requestState, dao.sessionID) 121 | 122 | for _, msg := range memoryMessages.Messages { 123 | telemetry.I().TrackEvent(dao.requestState, 124 | telemetry.Event_CreateMemoryMessage, 125 | map[string]any{ 126 | "message_length": utf8.RuneCountInString(msg.Content), 127 | "with_metadata": len(msg.Metadata) > 0, 128 | "session_uuid": session.UUID.String(), 129 | }, 130 | ) 131 | enablement.I().TrackEvent(enablement.Event_CreateMemoryMessage, dao.requestState) 132 | } 133 | 134 | messageResult, err := messageDAO.CreateMany(ctx, memoryMessages.Messages) 135 | if err != nil { 136 | return err 137 | } 138 | memoryMessages.Messages = messageResult 139 | // If we are skipping pushing new messages to the message router, return early 140 | if skipProcessing { 141 | return nil 142 | } 143 | 144 | err = dao._initializeProcessingMemory(ctx, session, memoryMessages) 145 | if err != nil { 146 | return fmt.Errorf("failed to initialize processing memory: %w", err) 147 | } 148 | 149 | return nil 150 | } 151 | 152 | func (dao *memoryDAO) SearchSessions(ctx context.Context, query *models.SessionSearchQuery, limit int) (*models.SessionSearchResponse, error) { 153 | return dao._searchSessions(ctx, query, limit) 154 | } 155 | 
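The message windowing in `memoryDAO.Get` (fetch at least `defaultLastNMessages`, hand the whole fetched set to `_get`, which keeps only the last `maxMessagesForFactRetrieval` for fact retrieval, then trim the returned chat history to `lastNMessages`) can be illustrated with a small standalone sketch. The helper names `tail` and `fetchCount`, and the use of plain strings in place of `models.Message`, are hypothetical; this is not the Zep implementation, only an illustration of the slicing rules, assuming the constant values shown in the CE source.

```go
package main

import "fmt"

// Assumed to mirror defaultLastNMessages (store/memory_common.go) and
// maxMessagesForFactRetrieval (store/memory_ce.go).
const (
	defaultLastNMessages        = 4
	maxMessagesForFactRetrieval = 4
)

// tail (hypothetical helper) returns the last n elements of s,
// or all of s when it has n or fewer elements.
func tail[T any](s []T, n int) []T {
	if n < 0 {
		n = 0
	}
	if len(s) > n {
		return s[len(s)-n:]
	}
	return s
}

// fetchCount (hypothetical helper) returns how many messages Get asks the
// message store for: never fewer than defaultLastNMessages.
func fetchCount(lastN int) int {
	if lastN < defaultLastNMessages {
		return defaultLastNMessages
	}
	return lastN
}

func main() {
	// Hypothetical session history; strings stand in for models.Message.
	history := []string{"m1", "m2", "m3", "m4", "m5", "m6"}
	lastN := 2

	// What messageDAO.GetLastN would return: at least defaultLastNMessages.
	fetched := tail(history, fetchCount(lastN))
	// The window _get hands to graphiti for fact retrieval.
	forFacts := tail(fetched, maxMessagesForFactRetrieval)
	// The chat history actually returned to the caller.
	chatHistory := tail(fetched, lastN)

	fmt.Println(fetched, forFacts, chatHistory) // prints [m3 m4 m5 m6] [m3 m4 m5 m6] [m5 m6]
}
```

Because the fetch size is floored at `defaultLastNMessages`, a caller asking for only two messages still gives fact retrieval four messages of context, while receiving just two in the chat history.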
-------------------------------------------------------------------------------- /src/store/memorystore_common.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "context" 5 | "crypto/sha256" 6 | "database/sql" 7 | "encoding/binary" 8 | "errors" 9 | "fmt" 10 | "time" 11 | 12 | "github.com/failsafe-go/failsafe-go" 13 | "github.com/failsafe-go/failsafe-go/retrypolicy" 14 | "github.com/google/uuid" 15 | "github.com/uptrace/bun" 16 | 17 | "github.com/getzep/zep/lib/logger" 18 | "github.com/getzep/zep/lib/zerrors" 19 | "github.com/getzep/zep/models" 20 | ) 21 | 22 | func NewMemoryStore(as *models.AppState, rs *models.RequestState) models.MemoryStore { 23 | return &memoryStore{ 24 | as: as, 25 | rs: rs, 26 | } 27 | } 28 | 29 | type memoryStore struct { 30 | as *models.AppState 31 | rs *models.RequestState 32 | } 33 | 34 | func (ms *memoryStore) dao(sessionID string, lastNMessages int) *memoryDAO { 35 | return newMemoryDAO(ms.as, ms.rs, sessionID, lastNMessages) 36 | } 37 | 38 | func (ms *memoryStore) messages(sessionID string) *messageDAO { 39 | return newMessageDAO(ms.as, ms.rs, sessionID) 40 | } 41 | 42 | func (ms *memoryStore) GetSession(ctx context.Context, sessionID string) (*models.Session, error) { 43 | return ms.rs.Sessions.Get(ctx, sessionID) 44 | } 45 | 46 | func (ms *memoryStore) CreateSession(ctx context.Context, session *models.CreateSessionRequest) (*models.Session, error) { 47 | return ms.rs.Sessions.Create(ctx, session) 48 | } 49 | 50 | func (ms *memoryStore) UpdateSession(ctx context.Context, session *models.UpdateSessionRequest, isPrivileged bool) (*models.Session, error) { 51 | return ms.rs.Sessions.Update(ctx, session, isPrivileged) 52 | } 53 | 54 | func (ms *memoryStore) DeleteSession(ctx context.Context, sessionID string) error { 55 | return ms.rs.Sessions.Delete(ctx, sessionID) 56 | } 57 | 58 | func (ms *memoryStore) ListSessions(ctx context.Context, cursor int64, limit int) 
([]*models.Session, error) { 59 | return ms.rs.Sessions.ListAll(ctx, cursor, limit) 60 | } 61 | 62 | func (ms *memoryStore) ListSessionsOrdered( 63 | ctx context.Context, 64 | pageNumber, pageSize int, 65 | orderedBy string, 66 | asc bool, 67 | ) (*models.SessionListResponse, error) { 68 | return ms.rs.Sessions.ListAllOrdered(ctx, pageNumber, pageSize, orderedBy, asc) 69 | } 70 | 71 | func (ms *memoryStore) GetMemory( 72 | ctx context.Context, 73 | sessionID string, 74 | lastNMessages int, 75 | opts ...models.MemoryFilterOption, 76 | ) (*models.Memory, error) { 77 | if lastNMessages < 0 { 78 | return nil, errors.New("cannot specify negative lastNMessages") 79 | } 80 | 81 | return ms.dao(sessionID, lastNMessages).Get(ctx, opts...) 82 | } 83 | 84 | func (ms *memoryStore) PutMemory( 85 | ctx context.Context, 86 | sessionID string, 87 | memoryMessages *models.Memory, 88 | skipProcessing bool, 89 | ) error { 90 | return ms.dao(sessionID, 0).Create(ctx, memoryMessages, skipProcessing) 91 | } 92 | 93 | func (ms *memoryStore) GetMessagesLastN( 94 | ctx context.Context, 95 | sessionID string, 96 | lastNMessages int, 97 | beforeUUID uuid.UUID, 98 | ) ([]models.Message, error) { 99 | if lastNMessages < 0 { 100 | return nil, errors.New("cannot specify negative lastNMessages") 101 | } 102 | 103 | return ms.messages(sessionID).GetLastN(ctx, lastNMessages, beforeUUID) 104 | } 105 | 106 | func (ms *memoryStore) GetMessageList( 107 | ctx context.Context, 108 | sessionID string, 109 | pageNumber, pageSize int, 110 | ) (*models.MessageListResponse, error) { 111 | return ms.messages(sessionID).GetListBySession(ctx, pageNumber, pageSize) 112 | } 113 | 114 | func (ms *memoryStore) GetMessagesByUUID( 115 | ctx context.Context, 116 | sessionID string, 117 | uuids []uuid.UUID, 118 | ) ([]models.Message, error) { 119 | return ms.messages(sessionID).GetListByUUID(ctx, uuids) 120 | } 121 | 122 | func (ms *memoryStore) PutMessages(ctx context.Context, sessionID string, messages 
[]models.Message) ([]models.Message, error) { 123 | return ms.messages(sessionID).CreateMany(ctx, messages) 124 | } 125 | 126 | func (ms *memoryStore) UpdateMessages( 127 | ctx context.Context, 128 | sessionID string, 129 | messages []models.Message, 130 | isPrivileged, includeContent bool, 131 | ) error { 132 | return ms.messages(sessionID).UpdateMany(ctx, messages, includeContent, isPrivileged) 133 | } 134 | 135 | func (ms *memoryStore) SearchSessions(ctx context.Context, query *models.SessionSearchQuery, limit int) (*models.SessionSearchResponse, error) { 136 | return ms.dao("", 0).SearchSessions(ctx, query, limit) 137 | } 138 | 139 | func (ms *memoryStore) PurgeDeleted(ctx context.Context, schemaName string) error { 140 | err := purgeDeleted(ctx, ms.as.DB.DB, schemaName, ms.rs.ProjectUUID) 141 | if err != nil { 142 | return zerrors.NewStorageError("failed to purge deleted", err) 143 | } 144 | 145 | return nil 146 | } 147 | 148 | func generateLockID(key string) (uint64, error) { 149 | hasher := sha256.New() 150 | _, err := hasher.Write([]byte(key)) 151 | if err != nil { 152 | return 0, fmt.Errorf("failed to hash key: %w", err) 153 | } 154 | hash := hasher.Sum(nil) 155 | return binary.BigEndian.Uint64(hash[:8]), nil 156 | } 157 | 158 | // safelyAcquireMetadataLock attempts to safely acquire a PostgreSQL advisory lock for the given key using a default retry policy.
159 | func safelyAcquireMetadataLock(ctx context.Context, db bun.IDB, key string) (uint64, error) { 160 | lockRetryPolicy := buildDefaultLockRetryPolicy() 161 | 162 | lockIDVal, err := failsafe.Get( 163 | func() (any, error) { 164 | return tryAcquireAdvisoryLock(ctx, db, key) 165 | }, lockRetryPolicy, 166 | ) 167 | if err != nil { 168 | return 0, fmt.Errorf("failed to acquire advisory lock: %w", err) 169 | } 170 | 171 | lockID, ok := lockIDVal.(uint64) 172 | if !ok { 173 | return 0, fmt.Errorf("failed to acquire advisory lock: %w", zerrors.ErrLockAcquisitionFailed) 174 | } 175 | 176 | return lockID, nil 177 | } 178 | 179 | // tryAcquireAdvisoryLock attempts to acquire a PostgreSQL advisory lock using pg_try_advisory_lock. 180 | // It does not block: it returns an error if the lock cannot be acquired immediately. 181 | // Accepts a bun.IDB, which can be either a *bun.DB or *bun.Tx. 182 | // Returns the lock ID, or an error if the lock could not be acquired. 183 | func tryAcquireAdvisoryLock(ctx context.Context, db bun.IDB, key string) (uint64, error) { 184 | lockID, err := generateLockID(key) 185 | if err != nil { 186 | return 0, fmt.Errorf("failed to generate lock ID: %w", err) 187 | } 188 | 189 | var acquired bool 190 | if err := db.QueryRowContext(ctx, "SELECT pg_try_advisory_lock(?)", lockID).Scan(&acquired); err != nil { 191 | return 0, fmt.Errorf("tryAcquireAdvisoryLock: %w", err) 192 | } 193 | if !acquired { 194 | return 0, zerrors.NewAdvisoryLockError(fmt.Errorf("failed to acquire advisory lock for %s", key)) 195 | } 196 | return lockID, nil 197 | } 198 | 199 | func buildDefaultLockRetryPolicy() retrypolicy.RetryPolicy[any] { 200 | return retrypolicy.Builder[any](). 201 | HandleErrors(zerrors.ErrLockAcquisitionFailed). 202 | WithBackoff(200*time.Millisecond, 30*time.Second). 203 | WithMaxRetries(15). 204 | Build() 205 | } 206 | 207 | // releaseAdvisoryLock releases the PostgreSQL advisory lock identified by lockID.
208 | // Accepts a bun.IDB, which can be either a *bun.DB or *bun.Tx. 209 | func releaseAdvisoryLock(ctx context.Context, db bun.IDB, lockID uint64) error { 210 | if _, err := db.ExecContext(ctx, "SELECT pg_advisory_unlock(?)", lockID); err != nil { 211 | return fmt.Errorf("failed to release advisory lock: %w", err) 212 | } 213 | 214 | return nil 215 | } 216 | 217 | // rollbackOnError is meant to be deferred right after BeginTx: it rolls back any transaction that was not committed. 218 | // If Rollback returns sql.ErrTxDone, the transaction has already been committed or rolled back, 219 | // so that error is ignored. 220 | func rollbackOnError(tx bun.Tx) { 221 | if rollBackErr := tx.Rollback(); rollBackErr != nil && !errors.Is(rollBackErr, sql.ErrTxDone) { 222 | logger.Error("failed to rollback transaction", "error", rollBackErr) 223 | } 224 | } 225 | -------------------------------------------------------------------------------- /src/store/message_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package store 3 | 4 | import ( 5 | "context" 6 | 7 | "github.com/google/uuid" 8 | "github.com/uptrace/bun" 9 | ) 10 | 11 | func (dao *messageDAO) cleanup(ctx context.Context, messageUUID uuid.UUID, tx *bun.Tx) error { 12 | return nil 13 | } 14 | -------------------------------------------------------------------------------- /src/store/metadata_utils.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | 9 | "dario.cat/mergo" 10 | "github.com/getzep/zep/lib/zerrors" 11 | "github.com/uptrace/bun" 12 | ) 13 | 14 | type mergeMetadataRequest struct { 15 | entityField string 16 | entityID string 17 | table string 18 | metadata map[string]any 19 | } 20 | 21 | // mergeMetadata merges the received metadata map with the existing metadata map in DB, 22 | // creating keys that don't exist and overwriting the values of keys that do.
23 | func mergeMetadata(ctx context.Context, 24 | db bun.IDB, 25 | schemaName string, 26 | mergeData mergeMetadataRequest, 27 | isPrivileged bool, 28 | ) (map[string]any, error) { 29 | if mergeData.entityField == "" { 30 | return nil, errors.New("entityField cannot be empty") 31 | } 32 | 33 | if mergeData.entityID == "" { 34 | return nil, errors.New("entityID cannot be empty") 35 | } 36 | 37 | if mergeData.table == "" { 38 | return nil, errors.New("table cannot be empty") 39 | } 40 | 41 | if len(mergeData.metadata) == 0 { 42 | return nil, errors.New("metadata cannot be empty") 43 | } 44 | 45 | // remove the top-level `system` key from the metadata if the caller is not privileged 46 | if !isPrivileged { 47 | delete(mergeData.metadata, "system") 48 | } 49 | 50 | // this should include selection of soft-deleted entities 51 | dbMetadata := new(map[string]any) 52 | 53 | err := db.NewSelect(). 54 | Table(fmt.Sprintf("%s.%s", schemaName, mergeData.table)). 55 | Column("metadata"). 56 | Where("? = ?", bun.Ident(mergeData.entityField), mergeData.entityID). 57 | Scan(ctx, &dbMetadata) 58 | if err != nil { 59 | if errors.Is(err, sql.ErrNoRows) { 60 | return nil, zerrors.NewNotFoundError( 61 | fmt.Sprintf("%s %s", mergeData.entityField, mergeData.entityID), 62 | ) 63 | } 64 | return nil, fmt.Errorf("failed to get %s: %w", mergeData.entityField, err) 65 | } 66 | 67 | // merge the existing metadata with the new metadata 68 | if err := mergo.Merge(dbMetadata, mergeData.metadata, mergo.WithOverride); err != nil { 69 | return nil, fmt.Errorf("failed to merge metadata: %w", err) 70 | } 71 | 72 | return *dbMetadata, nil 73 | } 74 | -------------------------------------------------------------------------------- /src/store/migrations/000000000001_database_setup.down.sql: -------------------------------------------------------------------------------- 1 | -- normally the down migration is the opposite of the up migration 2 | -- but in this case we don't want to drop everything. 
if the user wants to 3 | -- start fresh, they should manually drop the database. 4 | SELECT 1; 5 | -------------------------------------------------------------------------------- /src/store/migrations/000000000001_database_setup.up.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE IF NOT EXISTS "users" 2 | ( 3 | "uuid" uuid NOT NULL DEFAULT gen_random_uuid(), 4 | "id" BIGSERIAL, 5 | "created_at" timestamptz NOT NULL DEFAULT current_timestamp, 6 | "updated_at" timestamptz DEFAULT current_timestamp, 7 | "deleted_at" timestamptz, 8 | "user_id" VARCHAR NOT NULL, 9 | "email" VARCHAR, 10 | "first_name" VARCHAR, 11 | "last_name" VARCHAR, 12 | "project_uuid" uuid NOT NULL, 13 | "metadata" jsonb, 14 | PRIMARY KEY ("uuid"), 15 | UNIQUE ("user_id") 16 | ); 17 | 18 | CREATE TYPE role_type_enum AS ENUM ( 19 | 'norole', 20 | 'system', 21 | 'assistant', 22 | 'user', 23 | 'function', 24 | 'tool' 25 | ); 26 | 27 | CREATE TABLE IF NOT EXISTS "sessions" 28 | ( 29 | "uuid" uuid NOT NULL DEFAULT gen_random_uuid(), 30 | "id" BIGSERIAL, 31 | "session_id" VARCHAR NOT NULL, 32 | "created_at" timestamptz NOT NULL DEFAULT current_timestamp, 33 | "updated_at" timestamptz NOT NULL DEFAULT current_timestamp, 34 | "deleted_at" timestamptz, 35 | "ended_at" timestamptz, 36 | "metadata" jsonb, 37 | "user_id" VARCHAR, 38 | "project_uuid" uuid NOT NULL, 39 | PRIMARY KEY ("uuid"), 40 | UNIQUE ("session_id"), 41 | FOREIGN KEY ("user_id") REFERENCES "users" ("user_id") ON UPDATE NO ACTION ON DELETE CASCADE 42 | ); 43 | 44 | CREATE TABLE IF NOT EXISTS "messages" 45 | ( 46 | "uuid" uuid NOT NULL DEFAULT gen_random_uuid(), 47 | "id" BIGSERIAL, 48 | "created_at" timestamptz NOT NULL DEFAULT current_timestamp, 49 | "updated_at" timestamptz DEFAULT current_timestamp, 50 | "deleted_at" timestamptz, 51 | "session_id" VARCHAR NOT NULL, 52 | "project_uuid" uuid NOT NULL, 53 | "role" VARCHAR NOT NULL, 54 | "role_type" role_type_enum DEFAULT 'norole', 55 | 
"content" VARCHAR NOT NULL, 56 | "token_count" BIGINT NOT NULL, 57 | "metadata" jsonb, 58 | PRIMARY KEY ("uuid"), 59 | FOREIGN KEY ("session_id") REFERENCES "sessions" ("session_id") ON UPDATE NO ACTION ON DELETE CASCADE 60 | ); 61 | 62 | 63 | 64 | CREATE INDEX IF NOT EXISTS "user_user_id_idx" ON "users" ("user_id"); 65 | CREATE INDEX IF NOT EXISTS "user_email_idx" ON "users" ("email"); 66 | CREATE INDEX IF NOT EXISTS "memstore_session_id_idx" ON "messages" ("session_id"); 67 | CREATE INDEX IF NOT EXISTS "memstore_id_idx" ON "messages" ("id"); 68 | CREATE INDEX IF NOT EXISTS "memstore_session_id_project_uuid_deleted_at_idx" ON "messages" ("session_id", "project_uuid", "deleted_at"); 69 | CREATE INDEX IF NOT EXISTS "session_user_id_idx" ON "sessions" ("user_id"); 70 | CREATE INDEX IF NOT EXISTS "session_id_project_uuid_deleted_at_idx" ON "sessions" ("session_id", "project_uuid", "deleted_at"); 71 | -------------------------------------------------------------------------------- /src/store/migrations/migrate.go: -------------------------------------------------------------------------------- 1 | package migrations 2 | 3 | import ( 4 | "context" 5 | "embed" 6 | "fmt" 7 | 8 | "github.com/uptrace/bun/migrate" 9 | 10 | "github.com/getzep/zep/lib/logger" 11 | "github.com/getzep/zep/lib/pg" 12 | ) 13 | 14 | //go:embed *.sql 15 | var sqlMigrations embed.FS 16 | 17 | func Migrate(ctx context.Context, db pg.Connection, schemaName string) error { 18 | migrations := migrate.NewMigrations() 19 | 20 | if err := migrations.Discover(sqlMigrations); err != nil { 21 | return fmt.Errorf("failed to discover migrations: %w", err) 22 | } 23 | 24 | // Set the search path to the current schema. 
25 | if _, err := db.Exec(`SET search_path TO ?`, schemaName); err != nil { 26 | return fmt.Errorf("failed to set search path: %w", err) 27 | } 28 | 29 | migrator := migrate.NewMigrator(db.DB, migrations) 30 | 31 | if err := migrator.Init(ctx); err != nil { 32 | return fmt.Errorf("failed to init migrator: %w", err) 33 | } 34 | 35 | if err := migrator.Lock(ctx); err != nil { 36 | return fmt.Errorf("failed to lock migrator: %w", err) 37 | } 38 | defer func(migrator *migrate.Migrator, ctx context.Context) { 39 | err := migrator.Unlock(ctx) 40 | if err != nil { 41 | panic(fmt.Errorf("failed to unlock migrator: %w", err)) 42 | } 43 | }(migrator, ctx) 44 | 45 | group, err := migrator.Migrate(ctx) 46 | if err != nil { 47 | // The deferred Unlock registered above also runs while panicking, so it is not repeated here. 48 | _, rollBackErr := migrator.Rollback(ctx) 49 | if rollBackErr != nil { 50 | panic( 51 | fmt.Errorf("failed to apply migrations and rollback was unsuccessful: %v %w", err, rollBackErr), 52 | ) 53 | } 54 | 55 | panic(fmt.Errorf("failed to apply migrations (rolled back successfully): %w", err))
56 | } 57 | 58 | if group.IsZero() { 59 | logger.Info("there are no new migrations to run (database is up to date)") 60 | return nil 61 | } 62 | 63 | logger.Info("migration complete", "group", group) 64 | 65 | return nil 66 | } 67 | -------------------------------------------------------------------------------- /src/store/purge_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package store 3 | 4 | import ( 5 | "context" 6 | 7 | "github.com/google/uuid" 8 | "github.com/uptrace/bun" 9 | ) 10 | 11 | func tableCleanup(ctx context.Context, tx *bun.Tx, schemaName string, projectUUID uuid.UUID) error { 12 | return nil 13 | } 14 | -------------------------------------------------------------------------------- /src/store/purge_common.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/google/uuid" 8 | "github.com/uptrace/bun" 9 | ) 10 | 11 | // purgeDeleted hard-deletes all soft-deleted records from the memory store. 12 | func purgeDeleted(ctx context.Context, db *bun.DB, schemaName string, projectUUID uuid.UUID) error { 13 | if schemaName == "" { 14 | return fmt.Errorf("schemaName cannot be empty") 15 | } 16 | 17 | tx, err := db.BeginTx(ctx, nil) 18 | if err != nil { 19 | return fmt.Errorf("failed to begin transaction: %w", err) 20 | } 21 | defer rollbackOnError(tx) 22 | 23 | _, err = tx.Exec("SET LOCAL search_path TO ?"+SearchPathSuffix, schemaName) 24 | if err != nil { 25 | return fmt.Errorf("error setting schema: %w", err) 26 | } 27 | 28 | // Delete all messages, message embeddings, and summaries associated with sessions 29 | for _, schema := range messageTableList { 30 | _, err := tx.NewDelete(). 31 | Model(schema). 32 | WhereDeleted(). 33 | ForceDelete().
34 | Exec(ctx) 35 | if err != nil { 36 | return fmt.Errorf("error purging rows from %T: %w", schema, err) 37 | } 38 | } 39 | 40 | // Delete user store records. 41 | _, err = tx.NewDelete(). 42 | Model(&UserSchema{}). 43 | WhereDeleted(). 44 | ForceDelete(). 45 | Exec(ctx) 46 | if err != nil { 47 | return fmt.Errorf("error purging rows from %T: %w", &UserSchema{}, err) 48 | } 49 | 50 | err = tableCleanup(ctx, &tx, schemaName, projectUUID) 51 | if err != nil { 52 | return fmt.Errorf("failed to cleanup tables: %w", err) 53 | } 54 | 55 | if err := tx.Commit(); err != nil { 56 | return fmt.Errorf("failed to commit transaction: %w", err) 57 | } 58 | 59 | return nil 60 | } 61 | -------------------------------------------------------------------------------- /src/store/schema_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package store 3 | 4 | import "github.com/uptrace/bun" 5 | 6 | type SessionSchemaExt struct { 7 | bun.BaseModel `bun:"table:sessions,alias:s" yaml:"-"` 8 | } 9 | 10 | type UserSchemaExt struct { 11 | bun.BaseModel `bun:"table:users,alias:u" yaml:"-"` 12 | } 13 | 14 | var ( 15 | indexes = __indexes 16 | messageTableList = __messageTableList 17 | bunModels = __bunModels 18 | embeddingTables = __embeddingTables 19 | 20 | _ = indexes 21 | _ = __indexes 22 | _ = messageTableList 23 | _ = __messageTableList 24 | _ = bunModels 25 | _ = __bunModels 26 | _ = embeddingTables 27 | _ = __embeddingTables 28 | ) 29 | -------------------------------------------------------------------------------- /src/store/schema_common.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/google/uuid" 9 | "github.com/uptrace/bun" 10 | 11 | "github.com/getzep/zep/lib/pg" 12 | "github.com/getzep/zep/models" 13 | "github.com/getzep/zep/store/migrations" 14 | ) 15 | 16 | const SearchPathSuffix = ", public" 17 | 18 | type 
BaseSchema struct { 19 | SchemaName string `bun:"-" yaml:"schema_name"` 20 | TableName string `bun:"-"` 21 | Alias string `bun:"-"` 22 | } 23 | 24 | func (s *BaseSchema) GetTableName() string { 25 | return fmt.Sprintf("%s.%s", s.SchemaName, s.TableName) 26 | } 27 | 28 | func (s *BaseSchema) GetTableAndAlias() string { 29 | return fmt.Sprintf("%s AS %s", s.GetTableName(), s.Alias) 30 | } 31 | 32 | func NewBaseSchema(schemaName, tableName string) BaseSchema { 33 | return BaseSchema{ 34 | SchemaName: schemaName, 35 | TableName: tableName, 36 | } 37 | } 38 | 39 | type SessionSchema struct { 40 | bun.BaseModel `bun:"table:sessions,alias:s" yaml:"-"` 41 | BaseSchema `yaml:"-"` 42 | 43 | SessionSchemaExt `bun:",extend"` 44 | 45 | UUID uuid.UUID `bun:",pk,type:uuid,default:gen_random_uuid()" yaml:"uuid,omitempty"` 46 | ID int64 `bun:",autoincrement" yaml:"id,omitempty"` // used as a cursor for pagination 47 | SessionID string `bun:",unique,notnull" yaml:"session_id,omitempty"` 48 | CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp" yaml:"created_at,omitempty"` 49 | UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp" yaml:"updated_at,omitempty"` 50 | DeletedAt time.Time `bun:"type:timestamptz,soft_delete,nullzero" yaml:"deleted_at,omitempty"` 51 | EndedAt *time.Time `bun:"type:timestamptz,nullzero" yaml:"ended_at,omitempty"` 52 | Metadata map[string]any `bun:"type:jsonb,nullzero,json_use_number" yaml:"metadata,omitempty"` 53 | // UserID must be a pointer type in order to be nullable 54 | UserID *string `bun:"," yaml:"user_id,omitempty"` 55 | User *UserSchema `bun:"rel:belongs-to,join:user_id=user_id,on_delete:cascade" yaml:"-"` 56 | ProjectUUID uuid.UUID `bun:"type:uuid,notnull" yaml:"project_uuid,omitempty"` 57 | } 58 | 59 | func (s *SessionSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { 60 | if _, ok := query.(*bun.UpdateQuery); ok { 61 | s.UpdatedAt = time.Now() 62 | } 63 |
return nil 64 | } 65 | 66 | type MessageStoreSchema struct { 67 | bun.BaseModel `bun:"table:messages,alias:m" yaml:"-"` 68 | BaseSchema `yaml:"-"` 69 | 70 | UUID uuid.UUID `bun:",pk,type:uuid,default:gen_random_uuid()" yaml:"uuid"` 71 | // ID is used only for sorting / slicing purposes as we can't sort by CreatedAt for messages created simultaneously 72 | ID int64 `bun:",autoincrement" yaml:"id,omitempty"` 73 | CreatedAt time.Time `bun:"type:timestamptz,notnull,default:current_timestamp" yaml:"created_at,omitempty"` 74 | UpdatedAt time.Time `bun:"type:timestamptz,nullzero,default:current_timestamp" yaml:"updated_at,omitempty"` 75 | DeletedAt time.Time `bun:"type:timestamptz,soft_delete,nullzero" yaml:"deleted_at,omitempty"` 76 | SessionID string `bun:",notnull" yaml:"session_id,omitempty"` 77 | ProjectUUID uuid.UUID `bun:"type:uuid,notnull" yaml:"project_uuid,omitempty"` 78 | Role string `bun:",notnull" yaml:"role,omitempty"` 79 | RoleType models.RoleType `bun:",type:public.role_type_enum,nullzero,default:'norole'" yaml:"role_type,omitempty"` 80 | Content string `bun:",notnull" yaml:"content,omitempty"` 81 | TokenCount int `bun:",notnull" yaml:"token_count,omitempty"` 82 | Metadata map[string]any `bun:"type:jsonb,nullzero,json_use_number" yaml:"metadata,omitempty"` 83 | Session *SessionSchema `bun:"rel:belongs-to,join:session_id=session_id,on_delete:cascade" yaml:"-"` 84 | } 85 | 86 | func (s *MessageStoreSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { 87 | if _, ok := query.(*bun.UpdateQuery); ok { 88 | s.UpdatedAt = time.Now() 89 | } 90 | return nil 91 | } 92 | 93 | type UserSchema struct { 94 | bun.BaseModel `bun:"table:users,alias:u" yaml:"-"` 95 | BaseSchema `yaml:"-"` 96 | 97 | UserSchemaExt `bun:",extend"` 98 | 99 | UUID uuid.UUID `bun:",pk,type:uuid,default:gen_random_uuid()" yaml:"uuid,omitempty"` 100 | ID int64 `bun:",autoincrement" yaml:"id,omitempty"` // used as a cursor for pagination 101 | CreatedAt time.Time 
`bun:"type:timestamptz,notnull,default:current_timestamp" yaml:"created_at,omitempty"` 102 | UpdatedAt time.Time `bun:"type:timestamptz,nullzero,default:current_timestamp" yaml:"updated_at,omitempty"` 103 | DeletedAt time.Time `bun:"type:timestamptz,soft_delete,nullzero" yaml:"deleted_at,omitempty"` 104 | UserID string `bun:",unique,notnull" yaml:"user_id,omitempty"` 105 | Email string `bun:"," yaml:"email,omitempty"` 106 | FirstName string `bun:"," yaml:"first_name,omitempty"` 107 | LastName string `bun:"," yaml:"last_name,omitempty"` 108 | ProjectUUID uuid.UUID `bun:"type:uuid,notnull" yaml:"project_uuid,omitempty"` 109 | Metadata map[string]any `bun:"type:jsonb,nullzero,json_use_number" yaml:"metadata,omitempty"` 110 | } 111 | 112 | func (u *UserSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { 113 | if _, ok := query.(*bun.UpdateQuery); ok { 114 | u.UpdatedAt = time.Now() 115 | } 116 | return nil 117 | } 118 | 119 | type indexInfo struct { 120 | model any 121 | column string 122 | indexName string 123 | compositeColumn []string 124 | unique bool //nolint:unused // unused 125 | custom string //nolint:unused // unused 126 | } 127 | 128 | var ( 129 | // messageTableList is a list of tables that are created when the schema is created. 130 | // The list is also used when deleting message-related rows from the database. 131 | // DO NOT USE this directly. Use messageTableList instead. 132 | __messageTableList = []any{ 133 | &MessageStoreSchema{}, 134 | &SessionSchema{}, 135 | } 136 | 137 | // DO NOT USE this directly. Use bunModels instead. 138 | __bunModels = []any{ 139 | &UserSchema{}, 140 | &MessageStoreSchema{}, 141 | &SessionSchema{}, 142 | } 143 | 144 | __embeddingTables = []string{} 145 | 146 | // DO NOT USE this directly. Use indexes instead.
147 | __indexes = []indexInfo{ 148 | {model: &UserSchema{}, column: "user_id", indexName: "user_user_id_idx"}, 149 | {model: &UserSchema{}, column: "email", indexName: "user_email_idx"}, 150 | {model: &MessageStoreSchema{}, column: "session_id", indexName: "memstore_session_id_idx"}, 151 | {model: &MessageStoreSchema{}, column: "id", indexName: "memstore_id_idx"}, 152 | { 153 | model: &MessageStoreSchema{}, 154 | compositeColumn: []string{"session_id", "project_uuid", "deleted_at"}, 155 | indexName: "memstore_session_id_project_uuid_deleted_at_idx", 156 | }, 157 | {model: &SessionSchema{}, column: "user_id", indexName: "session_user_id_idx"}, 158 | { 159 | model: &SessionSchema{}, 160 | compositeColumn: []string{"session_id", "project_uuid", "deleted_at"}, 161 | indexName: "session_id_project_uuid_deleted_at_idx", 162 | }, 163 | } 164 | ) 165 | 166 | func MigrateSchema(ctx context.Context, db pg.Connection, schemaName string) error { 167 | if err := migrations.Migrate(ctx, db, schemaName); err != nil { 168 | return fmt.Errorf("failed to apply migrations: %w", err) 169 | } 170 | 171 | return nil 172 | } 173 | -------------------------------------------------------------------------------- /src/store/session_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package store 3 | 4 | import ( 5 | "context" 6 | "database/sql" 7 | "errors" 8 | "fmt" 9 | 10 | "github.com/uptrace/bun" 11 | 12 | "github.com/getzep/zep/lib/zerrors" 13 | "github.com/getzep/zep/models" 14 | ) 15 | 16 | func sessionSchemaExt(data ...*models.CreateSessionRequest) SessionSchemaExt { 17 | return SessionSchemaExt{} 18 | } 19 | 20 | func (dao *sessionDAO) buildUpdate(ctx context.Context, session *models.UpdateSessionRequest) (SessionSchema, []string) { 21 | return dao._buildUpdate(ctx, session) 22 | } 23 | 24 | func (dao *sessionDAO) sessionRelations(q *bun.SelectQuery) {} 25 | 26 | func (dao *sessionDAO) cleanup(ctx context.Context, sessionID string, tx 
bun.Tx) error { 27 | return nil 28 | } 29 | 30 | func (dao *sessionDAO) Get(ctx context.Context, sessionID string) (*models.Session, error) { 31 | session, err := dao.getBySessionID(ctx, sessionID, false) 32 | if err != nil { 33 | if errors.Is(err, sql.ErrNoRows) { 34 | return nil, zerrors.NewNotFoundError("session " + sessionID) 35 | } 36 | return nil, fmt.Errorf("sessionDAO Get failed to get session: %w", err) 37 | } 38 | 39 | resp := sessionSchemaToSession(*session)[0] 40 | 41 | return resp, nil 42 | } 43 | 44 | func sessionSchemaToSession(sessions ...SessionSchema) []*models.Session { 45 | retSessions := make([]*models.Session, len(sessions)) 46 | for i, sess := range sessions { 47 | s := _sessionSchemaToSession(sess) 48 | 49 | retSessions[i] = s 50 | } 51 | return retSessions 52 | } 53 | -------------------------------------------------------------------------------- /src/store/sessionstore_ce.go: -------------------------------------------------------------------------------- 1 | 2 | package store 3 | 4 | import ( 5 | "context" 6 | "errors" 7 | "fmt" 8 | 9 | "github.com/getzep/zep/lib/graphiti" 10 | ) 11 | 12 | func (dao *sessionDAO) _cleanupDeletedSession(ctx context.Context) error { 13 | return purgeDeletedResources(ctx, dao.as.DB) 14 | } 15 | 16 | func (dao *sessionDAO) _postCreateSession(ctx context.Context, sessionID, userID string) error { 17 | user, err := dao.rs.Users.Get(ctx, userID) 18 | if err != nil { 19 | return fmt.Errorf("failed to get user: %w", err) 20 | } 21 | if user == nil { 22 | return errors.New("user not found") 23 | } 24 | name := fmt.Sprintf("User %s %s", user.FirstName, user.LastName) 25 | return graphiti.I().AddNode(ctx, graphiti.AddNodeRequest{ 26 | GroupID: sessionID, 27 | UUID: fmt.Sprintf("%s_%s", sessionID, userID), 28 | Name: name, 29 | Summary: name, 30 | }) 31 | } 32 | -------------------------------------------------------------------------------- /src/store/userstore_ce.go: 
-------------------------------------------------------------------------------- 1 | 2 | package store 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | 8 | "github.com/getzep/zep/lib/graphiti" 9 | "github.com/getzep/zep/models" 10 | ) 11 | 12 | func (us *userStore) _processCreatedUser(ctx context.Context, user *models.User) error { 13 | err := graphiti.I().AddNode(ctx, graphiti.AddNodeRequest{ 14 | GroupID: user.UserID, 15 | UUID: user.UserID, 16 | Name: fmt.Sprintf("User %s %s", user.FirstName, user.LastName), 17 | Summary: fmt.Sprintf("User %s %s", user.FirstName, user.LastName), 18 | }) 19 | return err 20 | } 21 | 22 | func (us *userStore) _cleanupDeletedUser(ctx context.Context, userID string, sessionIDs []string) error { 23 | err := graphiti.I().DeleteGroup(ctx, userID) 24 | if err != nil { 25 | return err 26 | } 27 | for _, sessionID := range sessionIDs { 28 | err := graphiti.I().DeleteGroup(ctx, sessionID) 29 | if err != nil { 30 | return err 31 | } 32 | } 33 | return purgeDeletedResources(ctx, us.as.DB) 34 | } 35 | -------------------------------------------------------------------------------- /zep: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # List of supported docker compose commands 4 | DOCKER_COMPOSE_COMMANDS=("up" "pull" "down" "logs" "ps" "restart" "stop" "start") 5 | 6 | _make() { 7 | make -f Makefile.ce "${@:1}" 8 | } 9 | 10 | CMD="${1}" 11 | 12 | # Function to check if a value is in an array 13 | contains_element() { 14 | local e match="$1" 15 | shift 16 | for e; do [[ "$e" == "$match" ]] && return 0; done 17 | return 1 18 | } 19 | 20 | # Check if the command is in the list of supported docker compose commands 21 | if contains_element "$CMD" "${DOCKER_COMPOSE_COMMANDS[@]}"; then 22 | docker compose -f docker-compose.ce.yaml "$CMD" "${@:2}" 23 | elif [ "$CMD" = "make" ]; then 24 | _make "${@:2}" 25 | else 26 | echo "${CMD} is not a valid command" 27 | echo "Usage: " 28 | echo " ./zep 
[$(printf "%s | " "${DOCKER_COMPOSE_COMMANDS[@]}" | sed 's/ | $//')]" 29 | echo " ./zep make " 30 | fi 31 | -------------------------------------------------------------------------------- /zep.yaml: -------------------------------------------------------------------------------- 1 | log: 2 | # debug, info, warn, error, panic, dpanic, or fatal. Default = info 3 | level: info 4 | # How should logs be formatted? Setting to "console" will print human-readable logs, 5 | # while "json" will print structured JSON logs. Default is "json". 6 | format: json 7 | http: 8 | # Host to bind to. Default is 0.0.0.0 9 | host: 0.0.0.0 10 | # Port to bind to. Default is 8000 11 | port: 8000 12 | max_request_size: 5242880 13 | postgres: 14 | user: postgres 15 | password: postgres 16 | host: db 17 | port: 5432 18 | database: postgres 19 | schema_name: public 20 | read_timeout: 30 21 | write_timeout: 30 22 | max_open_connections: 10 23 | # Carbon (github.com/golang-module/carbon) is a package for working with dates and times. 24 | # It is primarily used for generating human-readable relative time strings like "2 hours ago". 25 | # See the list of supported languages here: https://github.com/golang-module/carbon?tab=readme-ov-file#i18n 26 | carbon: 27 | locale: en 28 | graphiti: 29 | # Base URL of the Graphiti service 30 | service_url: http://graphiti:8003 31 | # To authenticate API requests to the Zep service, a secret must be provided. 32 | # This value should be shared only between the Zep service and its clients. It can be any string value. 33 | # When making requests to the Zep service, include the secret in the Authorization header. 34 | api_secret: 35 | # To better understand how Zep is used, we can collect telemetry data. 36 | # This is optional and can be disabled by setting disabled to true. 37 | # We do not collect any PII or any of your data. We only collect anonymized data 38 | # about how Zep is used.
39 | telemetry: 40 | disabled: false 41 | # Please provide an identifying name for your organization so we can get a better understanding 42 | # about who is using Zep. This is optional. 43 | organization_name: 44 | --------------------------------------------------------------------------------
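The `generateLockID` helper in the store package derives a PostgreSQL advisory-lock key by hashing a string with SHA-256 and reading the first 8 bytes of the digest as a big-endian `uint64`. What follows is a minimal standalone sketch of that derivation. The names `lockIDForKey` and `asPgBigint` are hypothetical, and the signed conversion is an assumption of this sketch rather than something the original code does. PostgreSQL's `pg_try_advisory_lock(bigint)` takes a signed 64-bit key, so a bit-preserving `int64` reinterpretation is one way to keep every derived ID inside that range.

```go
package main

import (
	"crypto/sha256"
	"encoding/binary"
	"fmt"
)

// lockIDForKey (hypothetical name) follows the same steps as generateLockID:
// SHA-256 the key, then read the first 8 digest bytes as a big-endian uint64.
// sha256.Sum256 yields the same digest as sha256.New().Write(...).Sum(nil).
func lockIDForKey(key string) uint64 {
	sum := sha256.Sum256([]byte(key))
	return binary.BigEndian.Uint64(sum[:8])
}

// asPgBigint (hypothetical) reinterprets the 64 bits as a two's-complement
// int64 so the value always fits PostgreSQL's signed bigint lock key.
// This is an assumption of the sketch; the original passes the uint64 directly.
func asPgBigint(id uint64) int64 {
	return int64(id)
}

func main() {
	a := lockIDForKey("session-123")
	b := lockIDForKey("session-123")

	// The same key always maps to the same lock ID, so concurrent callers
	// for that key contend on a single advisory lock.
	fmt.Println(a == b) // prints true

	// The signed reinterpretation round-trips without losing any bits.
	fmt.Println(uint64(asPgBigint(a)) == a) // prints true

	fmt.Printf("uint64=%d bigint=%d\n", a, asPgBigint(a))
}
```

Keeping only 64 of the 256 digest bits means two distinct keys can, very rarely, map to the same lock ID. At worst, that would serialize callers working on unrelated keys.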