├── .devcontainer └── devcontainer.json ├── .dockerignore ├── .github ├── dependabot.yml └── workflows │ └── check.yml ├── .gitignore ├── .python-version ├── .skiff ├── cloudbuild-deploy.yaml ├── util.libsonnet └── webapp.jsonnet ├── .vscode ├── extensions.json ├── launch.json └── settings.json ├── LICENSE ├── README.md ├── api ├── .env ├── .python-version ├── Dockerfile ├── README.md ├── app.py ├── pyproject.toml └── src │ ├── __init__.py │ ├── attribution │ ├── __init__.py │ ├── attribution_queue_service.py │ ├── attribution_request.py │ ├── attribution_router.py │ └── attribution_service.py │ ├── cache │ ├── __init__.py │ └── redis.py │ ├── camel_case_model.py │ ├── config.py │ ├── documents │ ├── __init__.py │ ├── documents_router.py │ └── documents_service.py │ ├── glog.py │ ├── health.py │ ├── infini_gram_exception_handler.py │ └── infinigram │ ├── __init__.py │ ├── infini_gram_dependency.py │ └── infinigram_router.py ├── attribution_worker ├── Dockerfile ├── README.md ├── __init__.py ├── config.py ├── get_documents.py ├── get_span_text.py ├── pyproject.toml └── worker.py ├── bin └── download-infini-gram-array.sh ├── compute_stats ├── batch.py ├── batch.sh ├── main.py ├── transform.py ├── wiki.json └── worker.py ├── docker-compose.yaml ├── docs ├── CONTRIBUTING.md └── indexes.md ├── indexing ├── flan_rulebased_s3_to_weka.sh ├── index_weka_to_s3_all.sh ├── index_weka_to_s3_dclm.sh ├── index_weka_to_s3_nodclm.sh ├── indexing_dclm.sh ├── indexing_nodclm.sh ├── indexing_olmo2-13b-anneal-adapt.sh ├── indexing_olmo2-32b-anneal-adapt.sh ├── indexing_olmoe-1b-7b-anneal-adapt.sh ├── indexing_olmoe-adaptation.sh ├── olmoe-mix.txt ├── process_tulu3.sh ├── raw_s3_to_weka.sh ├── raw_s3_to_weka_all.sh ├── transform_hf_to_raw.py ├── transform_hf_to_raw_olmo2.py └── transform_hf_to_raw_tulu3.py ├── load-test ├── README.md ├── bailey100.json ├── locustfile-short.py ├── locustfile.py ├── pyproject.toml └── short-messages.json ├── otel-collector └── otel-collector-config.yaml ├── packages └── infini-gram-processor │ ├── README.md │ ├── pyproject.toml │ └── src │ └── infini_gram_processor │ ├── __init__.py │ ├── index_mappings.py │ ├── infini_gram_engine_exception.py │ ├── models │ ├── __init__.py │ ├── camel_case_model.py │ ├── is_infini_gram_error_response.py │ └── models.py │ ├── processor.py │ ├── processor_config.py │ ├── py.typed │ └── tokenizers │ ├── tokenizer.py │ └── tokenizer_factory.py ├── proxy ├── Dockerfile ├── local.conf ├── nginx.conf ├── prod.conf └── proxy.conf ├── pyproject.toml ├── schema └── local.sql ├── scripts ├── compute_correlation.py ├── distrib_of_score.py ├── easyapi_test.py ├── eval_corpuslink_relevance.py ├── issue_request.py ├── repro_ui.py ├── sample_wildbench.py ├── stress_test_api.py └── test_span_density.py ├── skiff.json ├── uv.lock ├── vendor ├── infini_gram-2.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ├── infini_gram-2.5.1-cp311-cp311-macosx_10_15_x86_64.whl ├── infini_gram-2.5.1-cp311-cp311-macosx_11_0_arm64.whl ├── infini_gram-2.5.1-cp312-cp312-macosx_10_15_x86_64.whl ├── infini_gram-2.5.1-cp312-cp312-macosx_11_0_arm64.whl ├── infini_gram-2.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ├── infini_gram-2.5.1-cp313-cp313-macosx_10_15_x86_64.whl ├── infini_gram-2.5.1-cp313-cp313-macosx_11_0_arm64.whl ├── llama-2-7b-hf │ ├── config.json │ ├── special_tokens_map.json │ ├── tokenizer.json │ └── tokenizer_config.json ├── llama-2_bow_ids.txt └── olmo-7b-hf │ ├── config.json │ ├── special_tokens_map.json │ ├── 
tokenization_olmo_fast.py │ ├── tokenizer.json │ └── tokenizer_config.json └── volume-claims ├── olmoe-mix-0924-dclm.yaml ├── olmoe-mix-0924-nodclm.yaml ├── pileval-gpt2.yaml ├── v4-olmo-2-0325-32b-anneal-adapt.yaml ├── v4-olmo-2-1124-13b-anneal-adapt.yaml ├── v4-olmoe-0125-1b-7b-anneal-adapt.yaml ├── v4-tulu-3-405b-adapt-llama.yaml ├── v4-tulu-3-70b-adapt-llama.yaml ├── v4-tulu-3-8b-adapt-llama.yaml └── writer-pod.yaml /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/python 3 | { 4 | "name": "Python 3", 5 | // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile 6 | "image": "mcr.microsoft.com/devcontainers/python:1-3.12-bullseye", 7 | "features": { 8 | "ghcr.io/devcontainers/features/python:1": {} 9 | } 10 | 11 | // Features to add to the dev container. More info: https://containers.dev/features. 12 | // "features": {}, 13 | 14 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 15 | // "forwardPorts": [], 16 | 17 | // Use 'postCreateCommand' to run commands after the container is created. 18 | // "postCreateCommand": "pip3 install --user -r requirements.txt", 19 | 20 | // Configure tool-specific properties. 21 | // "customizations": {}, 22 | 23 | // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 24 | // "remoteUser": "root" 25 | } 26 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # Python cache files 2 | **.pyc 3 | **/__pycache__ 4 | **/.dockerignore 5 | 6 | # Docker artifacts 7 | **/Dockerfile 8 | 9 | **/.env* 10 | 11 | .venv -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for more information: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | # https://containers.dev/guide/dependabot 6 | 7 | version: 2 8 | updates: 9 | - package-ecosystem: "devcontainers" 10 | directory: "/" 11 | schedule: 12 | interval: weekly 13 | -------------------------------------------------------------------------------- /.github/workflows/check.yml: -------------------------------------------------------------------------------- 1 | name: Lint and type-check PR 2 | 3 | concurrency: 4 | group: unit-${{ github.ref }} 5 | cancel-in-progress: true 6 | 7 | on: 8 | push: 9 | branches: [ main ] 10 | pull_request: 11 | branches: [ main ] 12 | 13 | jobs: 14 | lint: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Install uv 20 | uses: astral-sh/setup-uv@v5 21 | with: 22 | version: "0.6.8" 23 | enable-cache: true 24 | 25 | - uses: actions/setup-python@v5 26 | with: 27 | python-version-file: "./api/pyproject.toml" 28 | 29 | - run: uv sync --all-extras --dev --all-packages 30 | - run: uv run ruff check 31 | 32 | type-check: 33 | runs-on: ubuntu-latest 34 | steps: 35 | - uses: actions/checkout@v4 36 | 37 | - name: Install uv 38 | uses: astral-sh/setup-uv@v5 39 | with: 40 | version: "0.6.8" 41 | enable-cache: true 42 | 43 | - uses: actions/setup-python@v5 44 | with: 45 | python-version-file: "pyproject.toml" 46 | 47 | - run: uv sync --all-extras --dev --all-packages 48 | - run: uv run mypy --config ./pyproject.toml 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python cache files 2 | __pycache__ 3 | **/*.pyc 4 | .venv/ 5 | 6 | # Kubernetes configuration 7 | .skiff/webapp.yaml 8 | infinigram-array 9 | 10 | .DS_Store 11 | 12 | api/performance-profiles 13 | data/ 14 | index/ 15 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /.skiff/cloudbuild-deploy.yaml: -------------------------------------------------------------------------------- 1 | # This file tells Google Cloud Build how to deploy the application. 2 | # It can be attached to a variety of triggers, the default being whenever 3 | # someone merges changes to the `main` branch. 4 | steps: 5 | - id: 'proxy.build' 6 | name: 'gcr.io/cloud-builders/docker' 7 | args: [ 8 | 'build', 9 | '-t', 'gcr.io/$PROJECT_ID/$REPO_NAME-proxy:latest', 10 | '-t', 'gcr.io/$PROJECT_ID/$REPO_NAME-proxy:$COMMIT_SHA', 11 | '--cache-from', 'gcr.io/$PROJECT_ID/$REPO_NAME-proxy:latest', 12 | '--build-arg', 'CONF_FILE=prod.conf', 13 | '--build-arg', 'BUILDKIT_INLINE_CACHE=1', 14 | '.' 15 | ] 16 | waitFor: [ '-' ] 17 | dir: 'proxy' 18 | - id: 'proxy.push' 19 | name: 'gcr.io/cloud-builders/docker' 20 | args: [ 21 | 'push', 22 | 'gcr.io/$PROJECT_ID/$REPO_NAME-proxy:$COMMIT_SHA' 23 | ] 24 | waitFor: [ 'proxy.build' ] 25 | 26 | - id: 'api.build' 27 | name: 'gcr.io/cloud-builders/docker' 28 | args: [ 29 | 'build', 30 | '-t', 'gcr.io/$PROJECT_ID/$REPO_NAME-api:latest', 31 | '-t', 'gcr.io/$PROJECT_ID/$REPO_NAME-api:$COMMIT_SHA', 32 | '--cache-from', 'gcr.io/$PROJECT_ID/$REPO_NAME-api:latest', 33 | '--build-arg', 'BUILDKIT_INLINE_CACHE=1', 34 | '-f', 'Dockerfile', 35 | '..' 
36 | ] 37 | waitFor: [ '-' ] 38 | dir: 'api' 39 | - id: 'api.push' 40 | name: 'gcr.io/cloud-builders/docker' 41 | args: [ 42 | 'push', 43 | 'gcr.io/$PROJECT_ID/$REPO_NAME-api:$COMMIT_SHA', 44 | ] 45 | waitFor: [ 'api.build' ] 46 | 47 | - id: 'attribution-worker.build' 48 | name: 'gcr.io/cloud-builders/docker' 49 | args: [ 50 | 'build', 51 | '-t', 'gcr.io/$PROJECT_ID/$REPO_NAME-attribution-worker:latest', 52 | '-t', 'gcr.io/$PROJECT_ID/$REPO_NAME-attribution-worker:$COMMIT_SHA', 53 | '--cache-from', 'gcr.io/$PROJECT_ID/$REPO_NAME-attribution-worker:latest', 54 | '--build-arg', 'BUILDKIT_INLINE_CACHE=1', 55 | '-f', 'Dockerfile', 56 | '..' 57 | ] 58 | waitFor: [ '-' ] 59 | dir: 'attribution_worker' 60 | - id: 'attribution-worker.push' 61 | name: 'gcr.io/cloud-builders/docker' 62 | args: [ 63 | 'push', 64 | 'gcr.io/$PROJECT_ID/$REPO_NAME-attribution-worker:$COMMIT_SHA', 65 | ] 66 | waitFor: [ 'attribution-worker.build' ] 67 | 68 | - id: 'config' 69 | name: 'gcr.io/ai2-reviz/jsonnet' 70 | args: [ 71 | 'eval', 72 | '-y', 73 | '--output-file', './webapp.yaml', 74 | '--tla-str', 'env=$_ENV', 75 | '--tla-str', 'apiImage=gcr.io/$PROJECT_ID/$REPO_NAME-api:$COMMIT_SHA', 76 | '--tla-str', 'proxyImage=gcr.io/$PROJECT_ID/$REPO_NAME-proxy:$COMMIT_SHA', 77 | '--tla-str', 'workerImage=gcr.io/$PROJECT_ID/$REPO_NAME-attribution-worker:$COMMIT_SHA', 78 | '--tla-str', 'sha=$COMMIT_SHA', 79 | '--tla-str', 'cause="Automated Skiff Deploy SHA:$COMMIT_SHA BUILD:$BUILD_ID"', 80 | '--tla-str', 'branch=$BRANCH_NAME', 81 | '--tla-str', 'repo=$REPO_NAME', 82 | '--tla-str', 'buildId=$BUILD_ID', 83 | './webapp.jsonnet' 84 | ] 85 | dir: '.skiff' 86 | 87 | - id: 'deploy' 88 | name: 'gcr.io/ai2-reviz/rudder' 89 | args: [ 90 | 'deploy', 91 | '-f', 92 | 'webapp.yaml' 93 | ] 94 | dir: '.skiff' 95 | substitutions: 96 | _ENV: prod 97 | images: [ 98 | 'gcr.io/$PROJECT_ID/$REPO_NAME-api:$COMMIT_SHA', 99 | 'gcr.io/$PROJECT_ID/$REPO_NAME-api:latest', 100 | 'gcr.io/$PROJECT_ID/$REPO_NAME-proxy:$COMMIT_SHA', 101 | 'gcr.io/$PROJECT_ID/$REPO_NAME-proxy:latest', 102 | 'gcr.io/$PROJECT_ID/$REPO_NAME-attribution-worker:$COMMIT_SHA', 103 | 'gcr.io/$PROJECT_ID/$REPO_NAME-attribution-worker:latest' 104 | ] 105 | artifacts: 106 | objects: 107 | location: 'gs://skiff-archive/$REPO_NAME/$_ENV/$BUILD_ID/$COMMIT_SHA' 108 | paths: ['.skiff/webapp.yaml'] 109 | options: 110 | env: 111 | - 'DOCKER_BUILDKIT=1' 112 | -------------------------------------------------------------------------------- /.skiff/util.libsonnet: -------------------------------------------------------------------------------- 1 | /** 2 | * This file contains a few helper methods that are used in webapp.jsonnet. 3 | * They're put here as to not distract the reader from the stuff that really matters -- 4 | * that is the code that produces their application's configuration. 5 | */ 6 | { 7 | local util = self, 8 | 9 | /** 10 | * We're pinned to jsonnet 0.12.0, std.any was added after that version. 11 | * So we implement our own. 12 | */ 13 | any(list): 14 | std.length(std.filter(function(x) x, list)) > 0, 15 | 16 | isCustomHost(host): 17 | if std.endsWith(host, '.allen.ai') then 18 | if std.length(std.split(host, '.')) != 3 then 19 | true 20 | else 21 | false 22 | else if std.endsWith(host, '.apps.allenai.org') then 23 | if std.length(std.split(host, '.')) != 4 then 24 | true 25 | else 26 | false 27 | else 28 | true, 29 | 30 | /** 31 | * Groups by the provided TLDs. Returns a tuple. The first value is a map of hosts 32 | * by TLD. The second a list of hosts that didn't match a TLD. 
33 | */ 34 | groupHosts(hosts, tlds): 35 | local byTLD = { [tld]: std.filter(function(host) std.endsWith(host, tld), hosts) for tld in tlds }; 36 | local rest = std.filter(function(host) !self.any([ std.endsWith(host, tld) for tld in tlds ]), hosts); 37 | [ byTLD, rest ], 38 | 39 | hasCustomHost(hosts): 40 | std.length(std.filter(util.isCustomHost, hosts)) > 0, 41 | 42 | /** 43 | * Returns a list of hostnames, given the provided environment identifier, Skiff config 44 | * and top level domain. 45 | */ 46 | getHosts(env, config, tld): 47 | if env == 'prod' then 48 | [ config.appName + tld ] 49 | else 50 | [ config.appName + '-' + env + tld ], 51 | 52 | /** 53 | * Returns a few TLS related constructs given the provided hosts. If the application is 54 | * only using direct subdomains of `.apps.allenai.org` then an empty configuration is provided, 55 | * as the wildcard certificate that's managed by Skiff Bermuda can be used instead. 56 | */ 57 | getTLSConfig(fqn, hosts): { 58 | local needsTLSCert = util.hasCustomHost(hosts), 59 | 60 | ingressAnnotations: 61 | if needsTLSCert then 62 | { 'cert-manager.io/cluster-issuer': 'letsencrypt-prod' } 63 | else {}, 64 | spec: 65 | if needsTLSCert then 66 | { secretName: fqn + '-tls' } 67 | else 68 | {}, 69 | }, 70 | 71 | /** 72 | * Returns the path to authenticate requets with our Skiff Login system (OAuth2 Proxy). 73 | * If config has an array of strings in the field "login_allowed_emails", then they are 74 | * used to limit access to account with those email addresses. 75 | */ 76 | authPath(config): 77 | if !('login_allowed_emails' in config) then 78 | '/oauth2/auth' 79 | else if std.length(config.login_allowed_emails) > 0 then 80 | '/oauth2/auth?allowed_emails=' + std.join(',', config.login_allowed_emails), 81 | 82 | /** 83 | * Returns Ingress annotations that enable authentication, given the provided Skiff config. 84 | */ 85 | getAuthAnnotations(config, tld): 86 | if !('login' in config) then 87 | {} 88 | else if config.login == "ai2" then 89 | { 90 | 'nginx.ingress.kubernetes.io/auth-url': 'https://ai2.login' + tld + $.authPath(config), 91 | 'nginx.ingress.kubernetes.io/auth-signin': 'https://ai2.login' + tld + '/oauth2/start?rd=https://$host$request_uri', 92 | 'nginx.ingress.kubernetes.io/auth-response-headers': 'X-Auth-Request-User, X-Auth-Request-Email' 93 | } 94 | else if config.login == "google" then 95 | { 96 | 'nginx.ingress.kubernetes.io/auth-url': 'https://google.login' + tld + $.authPath(config), 97 | 'nginx.ingress.kubernetes.io/auth-signin': 'https://google.login' + tld + '/oauth2/start?rd=https://$host$request_uri', 98 | 'nginx.ingress.kubernetes.io/auth-response-headers': 'X-Auth-Request-User, X-Auth-Request-Email' 99 | } 100 | else 101 | error 'Unknown login type: ' + config.login, 102 | } 103 | 104 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.mypy-type-checker", 4 | "charliermarsh.ruff" 5 | ] 6 | } -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python Debugger: FastAPI", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "module": "uvicorn", 12 | "cwd": "${workspaceFolder}/api", 13 | "args": [ 14 | "app:app", 15 | "--reload" 16 | ], 17 | "jinja": true 18 | } 19 | ] 20 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "cSpell.words": [ 3 | "apiflask", 4 | "disp", 5 | "dolma", 6 | "fastapi", 7 | "infini", 8 | "Infinigram", 9 | "jsonlogger", 10 | "maxnum", 11 | "olmo", 12 | "olmoe", 13 | "pileval", 14 | "pydantic", 15 | "pythonjsonlogger" 16 | ], 17 | "mypy-type-checker.args": [ 18 | "--config-file=./pyproject.toml" 19 | ], 20 | "sqltools.connections": [ 21 | { 22 | "previewLimit": 50, 23 | "server": "localhost", 24 | "port": 5432, 25 | "driver": "PostgreSQL", 26 | "name": "infini-gram", 27 | "database": "infini-gram", 28 | "username": "infini-gram" 29 | } 30 | ] 31 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Infini-gram API 2 | 3 | This API is a wrapper over [infini-gram](https://infini-gram.io) to allow it to be used through an API at scale. It's a uv workspace with two applications ([the API](./api/README.md) and [the worker](./attribution_worker/README.md)) and one library to share code between the two ([infini-gram-processor](./packages/infini-gram-processor/README.md)). 4 | 5 | ## Reference 6 | This application is only made possible by researchers that worked on the infini-gram paper: 7 | Liu, Jiacheng and Min, Sewon and Zettlemoyer, Luke and Choi, Yejin and Hajishirzi, Hannaneh (2024). 8 | Infini-gram: Scaling Unbounded n-gram Language Models to a Trillion Tokens. 9 | arXiv preprint arXiv:2401.17377, 10 | 11 | ## Getting Started 12 | To develop in this repo, see the [contributing doc](./docs/CONTRIBUTING.md). 13 | 14 | ## Indexes 15 | You can find the index documentation [here](./docs/indexes). 16 | 17 | ## Architecture 18 | 19 | ```mermaid 20 | flowchart TB 21 | queue@{ shape: cyl, label: "Queue"} 22 | indexes@{ shape: lin-cyl, label: "Indexes" } 23 | api@{ shape: rounded, label: "API" } 24 | worker@{ shape: rounded, label: "Attribution Worker" } 25 | proxy@{ shape: stadium, label: "Web Proxy" } 26 | infini-gram@{ shape: subproc, label: "infini-gram" } 27 | 28 | api <-- Add jobs, receive results --> queue 29 | worker <-- Receive jobs, send results --> queue 30 | api --> infini-gram 31 | worker --> infini-gram 32 | infini-gram --> indexes 33 | proxy --> api 34 | ``` 35 | 36 | This application is deployed on Ai2's [Skiff](https://skiff.allenai.org/) platform. It's a wrapper over k8s designed to streamline development and deployment. 37 | 38 | The API and worker are in different deployments and are separately scalable. 39 | 40 | Both the API and worker access infini-gram and the associated indexes. The API will pass any `attribution` requests to the queue and await the result. The worker reads requests from the queue and works them, returning the result to the queue when finished. Requests other than `attribution` will be handled in the API. `attribution` requests are split off because they take much longer, which was causing the server to hang under load. 
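As a rough sketch of that flow from a client's point of view — the base URL and the choice of index below are assumptions for a local setup, not part of the deployment — an attribution request only needs the `response` field; everything else falls back to the defaults in `AttributionRequest`:

```python
import requests

BASE_URL = "http://localhost:8000"  # assumption: the API running locally on its default port

# Pick one of the indexes this deployment actually serves.
index = requests.get(f"{BASE_URL}/indexes", timeout=30).json()[0]

# Minimal attribution request; all other fields use their defaults.
payload = {"response": "Hailing a taxi in Rome is fairly easy."}
result = requests.post(f"{BASE_URL}/{index}/attribution", json=payload, timeout=120).json()

for span in result["spans"]:
    print(span["text"], "->", len(span["documents"]), "documents")
```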
41 | 42 | -------------------------------------------------------------------------------- /api/.env: -------------------------------------------------------------------------------- 1 | LOG_LEVEL="DEBUG" 2 | INDEX_BASE_PATH="../infinigram-array" 3 | ATTRIBUTION_QUEUE_URL="postgres://infini-gram:llmz@localhost:5432/infini-gram?sslmode=disable&application_name=infini-gram-attribution-worker" 4 | PYTHON_ENV="development" 5 | CACHE_URL="redis://localhost:6379" 6 | VENDOR_BASE_PATH="../vendor" -------------------------------------------------------------------------------- /api/.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /api/Dockerfile: -------------------------------------------------------------------------------- 1 | # An example using multi-stage image builds to create a final image without uv. 2 | # Taken from https://github.com/astral-sh/uv-docker-example/blob/main/multistage.Dockerfile 3 | 4 | FROM python:3.12-slim AS builder 5 | 6 | ENV PYTHONUNBUFFERED 1 7 | ENV PYTHONDONTWRITEBYTECODE 1 8 | 9 | COPY --from=ghcr.io/astral-sh/uv:0.6.8 /uv /bin/uv 10 | 11 | WORKDIR /app 12 | 13 | # uv keeps package info at the root of the workspace 14 | # Make sure you build this at the root context 15 | COPY uv.lock pyproject.toml .python-version /app/ 16 | 17 | COPY vendor vendor 18 | COPY packages/ packages/ 19 | 20 | 21 | RUN --mount=type=cache,target=/root/.cache/uv \ 22 | --mount=type=bind,source=uv.lock,target=uv.lock \ 23 | --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ 24 | uv sync --frozen --no-install-project --no-dev --package infini-gram-api --no-install-workspace 25 | 26 | COPY ./api ./api 27 | 28 | RUN --mount=type=cache,target=/root/.cache/uv \ 29 | uv sync --frozen --no-dev --no-editable --package infini-gram-api 30 | 31 | 32 | # Then, use a final image without uv 33 | FROM python:3.12-slim-bookworm as runner 34 | 35 | # It is important to use the image that matches the builder, as the path to the 36 | # Python executable must be the same, e.g., using `python:3.11-slim-bookworm` 37 | # will fail. 38 | 39 | # Copy the application from the builder 40 | COPY --from=builder --chown=app:app /app /app 41 | 42 | WORKDIR /app/api 43 | 44 | COPY vendor/llama-2-7b-hf/ vendor/llama-2-7b-hf/ 45 | COPY vendor/olmo-7b-hf/ vendor/olmo-7b-hf/ 46 | COPY vendor/llama-2_bow_ids.txt vendor/llama-2_bow_ids.tx 47 | 48 | ENV PATH "/app/.venv/bin:$PATH" 49 | ENV TRANSFORMERS_NO_ADVISORY_WARNINGS 1 50 | 51 | FROM runner as dev 52 | ENV OTEL_SERVICE_NAME=infinigram-api-dev 53 | ENV OTEL_TRACES_EXPORTER=otlp 54 | ENV OTEL_METRICS_EXPORTER=otlp 55 | ENV OTEL_EXPORTER_OTLP_ENDPOINT="http://otelcol:4318" 56 | ENV ENV=development 57 | 58 | CMD ["fastapi", "dev", "app.py", "--port", "8000", "--proxy-headers", "--host", "0.0.0.0"] 59 | 60 | FROM runner as prod 61 | ENV OTEL_SERVICE_NAME=infinigram-api 62 | ENV ENV=production 63 | 64 | CMD ["fastapi", "run", "app.py", "--port", "8000", "--proxy-headers"] 65 | -------------------------------------------------------------------------------- /api/README.md: -------------------------------------------------------------------------------- 1 | # Infini-gram API 2 | 3 | This is a FastAPI-based Web API. It handles all the requests for infini-gram other than `attribution`. 
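A minimal sketch of calling the non-attribution endpoints, assuming the API is running locally on port 8000 (the host and search text are placeholders):

```python
import requests

BASE_URL = "http://localhost:8000"  # assumption: local dev server

# The API reports which infini-gram indexes it serves.
indexes = requests.get(f"{BASE_URL}/indexes", timeout=30).json()
index = indexes[0]

# Paged full-text document search within that index.
search = requests.get(
    f"{BASE_URL}/{index}/documents/",
    params={"search": "taxi in Rome", "page": 0, "page_size": 5},
    timeout=60,
).json()

for document in search["documents"]:
    print(document["text"][:80])
```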
-------------------------------------------------------------------------------- /api/app.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from contextlib import asynccontextmanager 4 | from typing import Any, AsyncGenerator 5 | 6 | from fastapi import FastAPI 7 | from fastapi_problem.handler import add_exception_handler 8 | from infini_gram_processor.infini_gram_engine_exception import InfiniGramEngineException 9 | from opentelemetry import trace 10 | from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter 11 | from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter 12 | from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor 13 | from opentelemetry.sdk.trace import TracerProvider 14 | from opentelemetry.sdk.trace.export import BatchSpanProcessor, SimpleSpanProcessor 15 | from src import glog 16 | from src.attribution import attribution_router 17 | from src.attribution.attribution_queue_service import ( 18 | connect_to_attribution_queue, 19 | disconnect_from_attribution_queue, 20 | ) 21 | from src.cache.redis import create_connection_pool 22 | from src.config import get_config 23 | from src.health import health_router 24 | from src.infini_gram_exception_handler import infini_gram_engine_exception_handler 25 | from src.infinigram import infinigram_router 26 | 27 | # If LOG_FORMAT is "google:json" emit log message as JSON in a format Google Cloud can parse. 28 | fmt = os.getenv("LOG_FORMAT") 29 | handlers = [glog.create_stream_handler()] if fmt == "google:json" else [] 30 | level = os.environ.get("LOG_LEVEL", default=logging.INFO) 31 | logging.basicConfig(level=level, handlers=handlers) 32 | 33 | 34 | # https://fastapi.tiangolo.com/advanced/events/ 35 | @asynccontextmanager 36 | async def lifespan(app: FastAPI) -> AsyncGenerator[None, Any]: 37 | config = get_config() 38 | create_connection_pool(config.cache_url) 39 | # Things before yield on on startup 40 | await connect_to_attribution_queue() 41 | yield 42 | # Things after yield run on shutdown 43 | await disconnect_from_attribution_queue() 44 | 45 | 46 | app = FastAPI(title="infini-gram API", version="0.0.1", lifespan=lifespan) 47 | add_exception_handler( 48 | app, 49 | handlers={InfiniGramEngineException: infini_gram_engine_exception_handler}, # type: ignore 50 | ) 51 | 52 | app.include_router(health_router) 53 | app.include_router(router=infinigram_router) 54 | app.include_router(router=attribution_router) 55 | 56 | tracer_provider = TracerProvider() 57 | 58 | if os.getenv("ENV") == "development": 59 | tracer_provider.add_span_processor( 60 | span_processor=SimpleSpanProcessor(OTLPSpanExporter()) 61 | ) 62 | else: 63 | tracer_provider.add_span_processor( 64 | BatchSpanProcessor(CloudTraceSpanExporter(project_id="ai2-reviz")) # type:ignore 65 | ) 66 | 67 | trace.set_tracer_provider(tracer_provider) 68 | 69 | FastAPIInstrumentor.instrument_app(app, excluded_urls="health") 70 | -------------------------------------------------------------------------------- /api/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "infini-gram-api" 3 | version = "0.1.0" 4 | requires-python = ">=3.12" 5 | dependencies = [ 6 | "fastapi==0.111.0", 7 | "infini-gram", 8 | "numpy<2.0.0", 9 | "opentelemetry-api==1.30.0", 10 | "opentelemetry-exporter-gcp-trace==1.9.0", 11 | "opentelemetry-exporter-otlp-proto-http==1.30.0", 12 | 
"opentelemetry-instrumentation-fastapi==0.51b0", 13 | "opentelemetry-sdk==1.30.0", 14 | "pydantic-settings==2.3.4", 15 | "python-json-logger==2.0.7", 16 | "requests==2.32.3", 17 | "saq[postgres]>=0.22.4", 18 | "transformers==4.49.0", 19 | "types-requests==2.32.0.20240914", 20 | "psycopg[binary,pool]>=3.2.6", 21 | "fastapi_problem==0.10.7", 22 | "redis[hiredis]==5.2.1", 23 | "infini-gram-processor", 24 | ] 25 | -------------------------------------------------------------------------------- /api/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/infinigram-api/e858271c372d6accf13d12dd46bcca7fefe570ee/api/src/__init__.py -------------------------------------------------------------------------------- /api/src/attribution/__init__.py: -------------------------------------------------------------------------------- 1 | from .attribution_router import attribution_router as attribution_router 2 | -------------------------------------------------------------------------------- /api/src/attribution/attribution_queue_service.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | from fastapi import Depends 4 | from saq import Queue 5 | 6 | from src.config import get_config 7 | 8 | queue = Queue.from_url( 9 | get_config().attribution_queue_url, name=get_config().attribution_queue_name 10 | ) 11 | 12 | 13 | async def connect_to_attribution_queue() -> None: 14 | await queue.connect() 15 | 16 | 17 | async def disconnect_from_attribution_queue() -> None: 18 | await queue.disconnect() 19 | 20 | 21 | def get_queue() -> Queue: 22 | return queue 23 | 24 | 25 | AttributionQueueDependency = Annotated[Queue, Depends(get_queue)] 26 | -------------------------------------------------------------------------------- /api/src/attribution/attribution_request.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from infini_gram_processor.models import SpanRankingMethod 4 | from pydantic import ConfigDict, Field 5 | 6 | from src.camel_case_model import CamelCaseModel 7 | 8 | EXAMPLE_ATTRIBUTION_RESPONSE = "Hailing a taxi in Rome is fairly easy. Expect to pay around EUR 10-15 (approx. $11.29 - $15.58) to most tourist spots. Tipping isn't common in Italy, but round up the taxi fare or leave a small tip in the event of exceptional service. car rental is an alternative, but traffic in Rome can be daunting for newbies. If you decide to rent a car, make sure you're comfortable navigating busy medieval streets." 
9 | 10 | 11 | class AttributionRequest(CamelCaseModel): 12 | model_config = ConfigDict(frozen=True) 13 | 14 | response: str = Field(examples=[EXAMPLE_ATTRIBUTION_RESPONSE]) 15 | delimiters: List[str] = Field( 16 | examples=[["\n", "."]], 17 | default=[], 18 | description="Token IDs that returned spans shouldn't include", 19 | ) 20 | allow_spans_with_partial_words: bool = Field( 21 | default=False, 22 | description="Setting this to False will only check for attributions that start and end with a full word", 23 | ) 24 | minimum_span_length: int = Field( 25 | gt=0, 26 | default=1, 27 | description='The minimum length to qualify an n-gram span as "interesting"', 28 | ) 29 | maximum_frequency: int = Field( 30 | gt=0, 31 | default=10, 32 | description='The maximum frequency that an n-gram span can have in an index for us to consider it as "interesting"', 33 | ) 34 | maximum_span_density: float = Field( 35 | gt=0, 36 | default=0.05, 37 | description="The maximum density of spans (measured in number of spans per response token) to return in the response", 38 | ) 39 | span_ranking_method: SpanRankingMethod = Field( 40 | default=SpanRankingMethod.LENGTH, 41 | description="Ranking method when capping number of spans with maximum_span_density, options are 'length' and 'unigram_logprob_sum'", 42 | ) 43 | maximum_documents_per_span: int = Field( 44 | gt=0, 45 | default=10, 46 | description="The maximum number of documents to retrieve for each span; should be no larger than maximum_frequency", 47 | ) 48 | maximum_context_length: int = Field( 49 | gt=0, 50 | default=250, 51 | description="The maximum number of tokens of the context (on each side) to retrieve from the document", 52 | ) 53 | maximum_context_length_long: int = Field( 54 | gt=0, 55 | default=100, 56 | description="The maximum number of tokens of the context (on each side) for the document modal", 57 | ) 58 | maximum_context_length_snippet: int = Field( 59 | gt=0, 60 | default=40, 61 | description="The maximum number of tokens of the context (on each side) for the snippet in document cards", 62 | ) 63 | -------------------------------------------------------------------------------- /api/src/attribution/attribution_router.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | from fastapi import APIRouter, Depends 4 | from fastapi_problem.handler import generate_swagger_response 5 | 6 | from src.attribution.attribution_request import AttributionRequest 7 | from src.attribution.attribution_service import ( 8 | AttributionResponse, 9 | AttributionService, 10 | AttributionTimeoutError, 11 | ) 12 | 13 | attribution_router = APIRouter() 14 | 15 | 16 | @attribution_router.post( 17 | path="/{index}/attribution", 18 | responses={ 19 | AttributionTimeoutError.status: generate_swagger_response( 20 | AttributionTimeoutError # type: ignore 21 | ) 22 | }, 23 | ) 24 | async def get_document_attributions( 25 | index: str, 26 | body: AttributionRequest, 27 | attribution_service: Annotated[AttributionService, Depends()], 28 | ) -> AttributionResponse: 29 | result = await attribution_service.get_attribution_for_response(index, body) 30 | 31 | return result 32 | -------------------------------------------------------------------------------- /api/src/attribution/attribution_service.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from hashlib import sha256 3 | from typing import Any, List, Optional, Sequence 4 | from uuid import uuid4 
5 | 6 | from infini_gram_processor.models import ( 7 | BaseInfiniGramResponse, 8 | Document, 9 | ) 10 | from infini_gram_processor.processor import ( 11 | InfiniGramProcessor, 12 | ) 13 | from opentelemetry import trace 14 | from opentelemetry.semconv.trace import SpanAttributes 15 | from opentelemetry.trace import SpanKind, Status, StatusCode 16 | from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator 17 | from pydantic import Field, ValidationError 18 | from redis.asyncio import Redis 19 | from rfc9457 import StatusProblem 20 | from saq import Queue 21 | 22 | from src.attribution.attribution_queue_service import AttributionQueueDependency 23 | from src.attribution.attribution_request import AttributionRequest 24 | from src.cache import CacheDependency 25 | from src.camel_case_model import CamelCaseModel 26 | from src.config import get_config 27 | from src.documents.documents_router import DocumentsServiceDependency 28 | from src.documents.documents_service import ( 29 | DocumentsService, 30 | ) 31 | from src.infinigram.infini_gram_dependency import InfiniGramProcessorDependency 32 | 33 | tracer = trace.get_tracer(get_config().application_name) 34 | logger = logging.getLogger("uvicorn.error") 35 | 36 | _TASK_NAME_KEY = "saq.task_name" 37 | _TASK_TAG_KEY = "saq.action" 38 | 39 | 40 | class AttributionDocument(Document): 41 | display_length_long: int 42 | needle_offset_long: int 43 | text_long: str 44 | display_offset_snippet: int 45 | needle_offset_snippet: int 46 | text_snippet: str 47 | 48 | 49 | class AttributionSpan(CamelCaseModel): 50 | left: int 51 | right: int 52 | length: int 53 | count: int 54 | unigram_logprob_sum: float 55 | text: str 56 | token_ids: Sequence[int] 57 | documents: List[AttributionDocument] 58 | 59 | 60 | class AttributionResponse(BaseInfiniGramResponse): 61 | spans: Sequence[AttributionSpan] 62 | input_tokens: Optional[Sequence[str]] = Field( 63 | examples=[["busy", " medieval", " streets", "."]] 64 | ) 65 | 66 | 67 | class AttributionTimeoutError(StatusProblem): 68 | type_ = "server-overloaded" 69 | title = "Server overloaded" 70 | status = 503 71 | 72 | 73 | class AttributionService: 74 | infini_gram_processor: InfiniGramProcessor 75 | documents_service: DocumentsService 76 | attribution_queue: Queue 77 | cache: Redis 78 | 79 | def __init__( 80 | self, 81 | infini_gram_processor: InfiniGramProcessorDependency, 82 | documents_service: DocumentsServiceDependency, 83 | attribution_queue: AttributionQueueDependency, 84 | cache: CacheDependency, 85 | ): 86 | self.infini_gram_processor = infini_gram_processor 87 | self.documents_service = documents_service 88 | self.attribution_queue = attribution_queue 89 | self.cache = cache 90 | 91 | def _get_cache_key(self, index: str, request: AttributionRequest) -> bytes: 92 | combined_index_and_request = ( 93 | f"{request.__class__.__qualname__}::{index}{request.model_dump_json()}" 94 | ) 95 | key = sha256( 96 | combined_index_and_request.encode("utf-8", errors="ignore") 97 | ).digest() 98 | 99 | return key 100 | 101 | @tracer.start_as_current_span("attribution_service/_get_cached_response") 102 | async def _get_cached_response( 103 | self, index: str, request: AttributionRequest 104 | ) -> AttributionResponse | None: 105 | key = self._get_cache_key(index, request) 106 | 107 | try: 108 | # Since someone asked for this again, we should keep it around longer 109 | # This sets it to expire after 12 hours 110 | cached_json = await self.cache.getex(key, ex=43_200) 111 | 112 | if cached_json is None: 
113 | return None 114 | 115 | cached_response = AttributionResponse.model_validate_json(cached_json) 116 | 117 | current_span = trace.get_current_span() 118 | current_span.add_event("retrieved-cached-attribution-response") 119 | logger.debug( 120 | "Retrieved cached attribution response", 121 | ) 122 | 123 | return cached_response 124 | 125 | except ValidationError: 126 | logger.error( 127 | "Failed to parse cached response", 128 | extra={"key": key.decode("utf-8")}, 129 | exc_info=True, 130 | ) 131 | except Exception: 132 | logger.error( 133 | "Failed to retrieve cached response", 134 | exc_info=True, 135 | ) 136 | 137 | return None 138 | 139 | @tracer.start_as_current_span("attribution_service/_cache_response") 140 | async def _cache_response( 141 | self, index: str, request: AttributionRequest, json_response: str 142 | ) -> None: 143 | key = self._get_cache_key(index, request) 144 | 145 | try: 146 | # save the response and expire it after an hour 147 | await self.cache.set(key, json_response, ex=3_600) 148 | 149 | current_span = trace.get_current_span() 150 | current_span.add_event("cached-attribution-response") 151 | logger.debug( 152 | "Saved attribution response to cache", 153 | ) 154 | except Exception: 155 | logger.warning( 156 | "Failed to cache attribution response", 157 | exc_info=True, 158 | ) 159 | pass 160 | 161 | @tracer.start_as_current_span("attribution_service/get_attribution_for_response") 162 | async def get_attribution_for_response( 163 | self, index: str, request: AttributionRequest 164 | ) -> AttributionResponse: 165 | cached_response = await self._get_cached_response(index, request) 166 | if cached_response is not None: 167 | return cached_response 168 | 169 | job_key = str(uuid4()) 170 | 171 | try: 172 | logger.debug("Adding attribution request to queue", extra={"index": index}) 173 | 174 | with tracer.start_as_current_span( 175 | "attribution_service/publish_attribution_job", 176 | kind=SpanKind.PRODUCER, 177 | attributes={ 178 | _TASK_NAME_KEY: "attribute", 179 | SpanAttributes.MESSAGING_MESSAGE_ID: job_key, 180 | _TASK_TAG_KEY: "apply_async", 181 | SpanAttributes.MESSAGING_SYSTEM: "saq", 182 | }, 183 | ): 184 | otel_context: dict[str, Any] = {} 185 | TraceContextTextMapPropagator().inject(otel_context) 186 | attribute_result_json = await self.attribution_queue.apply( 187 | "attribute", 188 | timeout=60, 189 | key=job_key, 190 | index=index, 191 | input=request.response, 192 | delimiters=request.delimiters, 193 | allow_spans_with_partial_words=request.allow_spans_with_partial_words, 194 | minimum_span_length=request.minimum_span_length, 195 | maximum_frequency=request.maximum_frequency, 196 | maximum_span_density=request.maximum_span_density, 197 | span_ranking_method=request.span_ranking_method, 198 | maximum_context_length=request.maximum_context_length, 199 | maximum_context_length_long=request.maximum_context_length_long, 200 | maximum_context_length_snippet=request.maximum_context_length_snippet, 201 | maximum_documents_per_span=request.maximum_documents_per_span, 202 | otel_context=otel_context, 203 | ) 204 | 205 | attribute_result = AttributionResponse.model_validate_json( 206 | attribute_result_json 207 | ) 208 | 209 | await self._cache_response(index, request, attribute_result_json) 210 | 211 | return attribute_result 212 | except TimeoutError as ex: 213 | logger.error( 214 | "Attribution request timed out", 215 | extra={"job_key": job_key, "index": index}, 216 | ) 217 | 218 | current_span = trace.get_current_span() 219 | 
current_span.set_status(Status(StatusCode.ERROR)) 220 | current_span.record_exception(ex) 221 | 222 | job_to_abort = await self.attribution_queue.job(job_key) 223 | if job_to_abort is not None: 224 | await self.attribution_queue.abort(job_to_abort, "Client timeout") 225 | 226 | raise AttributionTimeoutError( 227 | "The server wasn't able to process your request in time. It is likely overloaded. Please try again later." 228 | ) 229 | -------------------------------------------------------------------------------- /api/src/cache/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | from fastapi import Depends 4 | from redis.asyncio import Redis 5 | 6 | from src.cache.redis import get_redis 7 | 8 | CacheDependency = Annotated[Redis, Depends(get_redis)] 9 | -------------------------------------------------------------------------------- /api/src/cache/redis.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import redis.asyncio as redis 4 | 5 | from src.config import ConfigDependency 6 | 7 | 8 | @lru_cache 9 | def create_connection_pool(url: str) -> redis.ConnectionPool: 10 | return redis.ConnectionPool.from_url(url) 11 | 12 | 13 | def get_redis(config: ConfigDependency) -> redis.Redis: 14 | redis_url = config.cache_url 15 | pool = create_connection_pool(redis_url) 16 | 17 | return redis.Redis(connection_pool=pool) 18 | -------------------------------------------------------------------------------- /api/src/camel_case_model.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from pydantic.alias_generators import to_camel 3 | 4 | 5 | class CamelCaseModel(BaseModel): 6 | class Config: 7 | alias_generator = to_camel 8 | populate_by_name = True 9 | -------------------------------------------------------------------------------- /api/src/config.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from typing import Annotated 3 | 4 | from fastapi import Depends 5 | from pydantic import computed_field 6 | from pydantic_settings import BaseSettings, SettingsConfigDict 7 | 8 | 9 | class Config(BaseSettings): 10 | model_config = SettingsConfigDict(env_file=".env", extra="ignore") 11 | 12 | index_base_path: str = "/mnt/infinigram-array" 13 | profiling_enabled: bool = False 14 | application_name: str = "infini-gram-api" 15 | attribution_queue_url: str = "redis://localhost:6379" 16 | python_env: str = "prod" 17 | cache_url: str = "redis://localhost:6379" 18 | 19 | @computed_field # type: ignore[prop-decorator] 20 | @property 21 | def attribution_queue_name(self) -> str: 22 | queue_prefix = "infini-gram-attribution" 23 | 24 | return f"{queue_prefix}-{self.python_env}" 25 | 26 | 27 | @lru_cache 28 | def get_config() -> Config: 29 | return Config() 30 | 31 | 32 | ConfigDependency = Annotated[Config, Depends(get_config)] 33 | -------------------------------------------------------------------------------- /api/src/documents/__init__.py: -------------------------------------------------------------------------------- 1 | from .documents_router import documents_router as documents_router 2 | -------------------------------------------------------------------------------- /api/src/documents/documents_router.py: -------------------------------------------------------------------------------- 1 | from typing import 
Annotated, TypeAlias 2 | 3 | from fastapi import APIRouter, Depends, Query 4 | from infini_gram_processor.models import GetDocumentByIndexRequest 5 | 6 | from src.documents.documents_service import ( 7 | DocumentsService, 8 | InfiniGramDocumentResponse, 9 | InfiniGramDocumentsResponse, 10 | SearchResponse, 11 | ) 12 | 13 | documents_router = APIRouter() 14 | 15 | MaximumDocumentDisplayLengthType: TypeAlias = Annotated[ 16 | int, 17 | Query( 18 | title="The maximum length in tokens of the returned document text", 19 | gt=0, 20 | ), 21 | ] 22 | 23 | DocumentsServiceDependency: TypeAlias = Annotated[DocumentsService, Depends()] 24 | 25 | 26 | @documents_router.get("/{index}/documents/", tags=["documents"]) 27 | def search_documents( 28 | documents_service: DocumentsServiceDependency, 29 | search: str, 30 | maximum_document_display_length: MaximumDocumentDisplayLengthType = 10, 31 | page: Annotated[ 32 | int, 33 | Query( 34 | title="The page of documents to retrieve from the search query. Uses the pageSize parameter as part of its calculations. Starts at 0.", 35 | ), 36 | ] = 0, 37 | page_size: Annotated[ 38 | int, 39 | Query( 40 | title="The number of documents to return from the query. Defaults to 10. Changing this will affect the documents you get from a specific page.", 41 | gt=0, 42 | ), 43 | ] = 10, 44 | ) -> SearchResponse: 45 | result = documents_service.search_documents( 46 | search, 47 | maximum_context_length=maximum_document_display_length, 48 | page=page, 49 | page_size=page_size, 50 | ) 51 | 52 | return result 53 | 54 | 55 | @documents_router.get("/{index}/documents/{document_index}", tags=["documents"]) 56 | def get_document_by_index( 57 | documents_service: DocumentsServiceDependency, 58 | document_index: int, 59 | maximum_document_display_length: MaximumDocumentDisplayLengthType = 10, 60 | ) -> InfiniGramDocumentResponse: 61 | result = documents_service.get_document_by_index( 62 | document_index=int(document_index), 63 | maximum_context_length=maximum_document_display_length, 64 | ) 65 | 66 | return result 67 | 68 | 69 | @documents_router.get("/{index}/documents", tags=["documents"]) 70 | def get_documents_by_index( 71 | documents_service: DocumentsServiceDependency, 72 | document_indexes: Annotated[list[int], Query()], 73 | maximum_document_display_length: MaximumDocumentDisplayLengthType = 10, 74 | ) -> InfiniGramDocumentsResponse: 75 | result = documents_service.get_multiple_documents_by_index( 76 | document_requests=[ 77 | GetDocumentByIndexRequest( 78 | document_index=document_index, 79 | maximum_context_length=maximum_document_display_length, 80 | ) 81 | for document_index in document_indexes 82 | ], 83 | ) 84 | 85 | return result 86 | -------------------------------------------------------------------------------- /api/src/documents/documents_service.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from typing import Iterable, List 3 | 4 | from infini_gram_processor import InfiniGramProcessor 5 | from infini_gram_processor.models import ( 6 | BaseInfiniGramResponse, 7 | Document, 8 | GetDocumentByIndexRequest, 9 | ) 10 | from opentelemetry import trace 11 | 12 | from src.config import get_config 13 | from src.infinigram.infini_gram_dependency import InfiniGramProcessorDependency 14 | 15 | tracer = trace.get_tracer(get_config().application_name) 16 | 17 | 18 | class InfiniGramDocumentResponse(Document, BaseInfiniGramResponse): ... 
19 | 20 | 21 | class InfiniGramDocumentsResponse(BaseInfiniGramResponse): 22 | documents: Iterable[Document] 23 | 24 | 25 | class SearchResponse(BaseInfiniGramResponse): 26 | documents: List[Document] 27 | page: int 28 | page_size: int 29 | page_count: int 30 | total_documents: int 31 | 32 | 33 | class DocumentsService: 34 | infini_gram_processor: InfiniGramProcessor 35 | 36 | def __init__(self, infini_gram_processor: InfiniGramProcessorDependency): 37 | self.infini_gram_processor = infini_gram_processor 38 | 39 | @tracer.start_as_current_span("documents_service/search_documents") 40 | def search_documents( 41 | self, 42 | search: str, 43 | maximum_context_length: int, 44 | page_size: int, 45 | page: int, 46 | ) -> SearchResponse: 47 | search_documents_result = self.infini_gram_processor.search_documents( 48 | search=search, 49 | maximum_context_length=maximum_context_length, 50 | page=page, 51 | page_size=page_size, 52 | ) 53 | 54 | mapped_documents = [ 55 | Document( 56 | text=document.text, 57 | document_index=document.document_index, 58 | document_length=document.document_length, 59 | display_length=document.display_length, 60 | needle_offset=document.needle_offset, 61 | metadata=document.metadata, 62 | token_ids=document.token_ids, 63 | ) 64 | for document in search_documents_result.documents 65 | ] 66 | 67 | return SearchResponse( 68 | index=self.infini_gram_processor.index, 69 | documents=mapped_documents, 70 | page=page, 71 | page_size=page_size, 72 | total_documents=search_documents_result.total_documents, 73 | page_count=ceil(search_documents_result.total_documents / page_size), 74 | ) 75 | 76 | @tracer.start_as_current_span("documents_service/get_document_by_index") 77 | def get_document_by_index( 78 | self, document_index: int, maximum_context_length: int 79 | ) -> InfiniGramDocumentResponse: 80 | document = self.infini_gram_processor.get_document_by_index( 81 | document_index=document_index, 82 | maximum_context_length=maximum_context_length, 83 | ) 84 | 85 | return InfiniGramDocumentResponse( 86 | index=self.infini_gram_processor.index, 87 | document_index=document.document_index, 88 | document_length=document.document_length, 89 | display_length=document.display_length, 90 | needle_offset=document.needle_offset, 91 | metadata=document.metadata, 92 | token_ids=document.token_ids, 93 | text=document.text, 94 | ) 95 | 96 | @tracer.start_as_current_span("documents_service/get_multiple_documents_by_index") 97 | def get_multiple_documents_by_index( 98 | self, 99 | document_requests: Iterable[GetDocumentByIndexRequest], 100 | ) -> InfiniGramDocumentsResponse: 101 | documents = self.infini_gram_processor.get_documents_by_indexes( 102 | document_requests=document_requests, 103 | ) 104 | mapped_documents = [ 105 | Document( 106 | document_index=document.document_index, 107 | document_length=document.document_length, 108 | display_length=document.display_length, 109 | needle_offset=document.needle_offset, 110 | metadata=document.metadata, 111 | token_ids=document.token_ids, 112 | text=document.text, 113 | ) 114 | for document in documents 115 | ] 116 | return InfiniGramDocumentsResponse( 117 | index=self.infini_gram_processor.index, documents=mapped_documents 118 | ) 119 | -------------------------------------------------------------------------------- /api/src/glog.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import TextIO 3 | 4 | from pythonjsonlogger import jsonlogger 5 | 6 | 7 | def create_stream_handler() 
-> logging.StreamHandler[TextIO]: 8 | handler = logging.StreamHandler() 9 | """ 10 | Custom log formatter that emits log messages as JSON, with the "severity" field 11 | which Google Cloud uses to differentiate message levels and various opentelemetry mappings. 12 | """ 13 | formatter = jsonlogger.JsonFormatter( 14 | # taken from https://cloud.google.com/trace/docs/setup/python-ot#config-structured-logging 15 | rename_fields={ 16 | "levelname": "severity", 17 | "asctime": "timestamp", 18 | "otelTraceID": "logging.googleapis.com/trace", 19 | "otelSpanID": "logging.googleapis.com/spanId", 20 | "otelTraceSampled": "logging.googleapis.com/trace_sampled", 21 | }, 22 | timestamp=True, 23 | datefmt="%Y-%m-%dT%H:%M:%SZ", 24 | ) 25 | handler.setFormatter(formatter) 26 | 27 | return handler 28 | -------------------------------------------------------------------------------- /api/src/health.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, status 2 | 3 | health_router = APIRouter(prefix="/health") 4 | 5 | 6 | # This tells the machinery that powers Skiff (Kubernetes) that your application 7 | # is ready to receive traffic. Returning a non 2XX response code will prevent the 8 | # application from receiving live requests. 9 | @health_router.get("/", status_code=status.HTTP_204_NO_CONTENT) 10 | def health() -> None: 11 | return 12 | -------------------------------------------------------------------------------- /api/src/infini_gram_exception_handler.py: -------------------------------------------------------------------------------- 1 | from fastapi_problem.error import Problem 2 | from fastapi_problem.handler import ExceptionHandler 3 | from infini_gram_processor.infini_gram_engine_exception import InfiniGramEngineException 4 | from rfc9457 import error_class_to_type 5 | from starlette.requests import Request 6 | 7 | 8 | def infini_gram_engine_exception_handler( 9 | handler: ExceptionHandler, request: Request, exception: InfiniGramEngineException 10 | ) -> Problem: 11 | return Problem( 12 | title="infini-gram error", 13 | status=500, 14 | detail=exception.detail, 15 | type=error_class_to_type(exception), 16 | ) 17 | -------------------------------------------------------------------------------- /api/src/infinigram/__init__.py: -------------------------------------------------------------------------------- 1 | from .infinigram_router import infinigram_router as infinigram_router 2 | -------------------------------------------------------------------------------- /api/src/infinigram/infini_gram_dependency.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | from fastapi import Depends 4 | from infini_gram_processor import indexes 5 | from infini_gram_processor.index_mappings import AvailableInfiniGramIndexId 6 | from infini_gram_processor.processor import InfiniGramProcessor 7 | 8 | 9 | def InfiniGramProcessorFactoryPathParam( 10 | index: AvailableInfiniGramIndexId, 11 | ) -> InfiniGramProcessor: 12 | return indexes[index] 13 | 14 | 15 | InfiniGramProcessorDependency = Annotated[ 16 | InfiniGramProcessor, Depends(InfiniGramProcessorFactoryPathParam) 17 | ] 18 | -------------------------------------------------------------------------------- /api/src/infinigram/infinigram_router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | from infini_gram_processor.index_mappings import 
AvailableInfiniGramIndexId 3 | 4 | infinigram_router = APIRouter() 5 | 6 | 7 | @infinigram_router.get(path="/indexes") 8 | def get_available_indexes() -> list[AvailableInfiniGramIndexId]: 9 | return [index for index in AvailableInfiniGramIndexId] 10 | -------------------------------------------------------------------------------- /attribution_worker/Dockerfile: -------------------------------------------------------------------------------- 1 | # An example using multi-stage image builds to create a final image without uv. 2 | # Taken from https://github.com/astral-sh/uv-docker-example/blob/main/multistage.Dockerfile 3 | 4 | FROM python:3.12-slim AS builder 5 | 6 | ENV PYTHONUNBUFFERED 1 7 | ENV PYTHONDONTWRITEBYTECODE 1 8 | 9 | COPY --from=ghcr.io/astral-sh/uv:0.6.8 /uv /bin/uv 10 | 11 | WORKDIR /app 12 | 13 | # uv keeps package info at the root of the workspace 14 | # Make sure you build this at the root context 15 | COPY uv.lock pyproject.toml .python-version /app/ 16 | 17 | COPY vendor vendor 18 | COPY packages/ packages/ 19 | 20 | RUN --mount=type=cache,target=/root/.cache/uv \ 21 | --mount=type=bind,source=uv.lock,target=uv.lock \ 22 | --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ 23 | uv sync --frozen --no-install-project --no-dev --package infini-gram-attribution-worker --no-install-workspace 24 | 25 | COPY ./attribution_worker ./attribution_worker 26 | 27 | RUN --mount=type=cache,target=/root/.cache/uv \ 28 | uv sync --frozen --no-dev --no-editable --package infini-gram-attribution-worker 29 | 30 | 31 | # Then, use a final image without uv 32 | FROM python:3.12-slim-bookworm as base 33 | 34 | WORKDIR /app 35 | 36 | # It is important to use the image that matches the builder, as the path to the 37 | # Python executable must be the same, e.g., using `python:3.11-slim-bookworm` 38 | # will fail. 39 | 40 | # Copy the application from the builder 41 | COPY --from=builder --chown=app:app /app /app 42 | 43 | WORKDIR /app 44 | 45 | # Place executables in the environment at the front of the path 46 | ENV PATH="/app/.venv/bin:$PATH" 47 | ENV TRANSFORMERS_NO_ADVISORY_WARNINGS 1 48 | 49 | FROM base as dev 50 | 51 | ENV OTEL_SERVICE_NAME=infinigram-api-worker-dev 52 | ENV OTEL_TRACES_EXPORTER=otlp 53 | ENV OTEL_METRICS_EXPORTER=otlp 54 | ENV OTEL_EXPORTER_OTLP_ENDPOINT="http://otelcol:4318" 55 | ENV ENV=development 56 | 57 | CMD ["saq", "--verbose", "--web", "attribution_worker.worker_settings"] 58 | 59 | FROM base as prod 60 | 61 | CMD ["saq", "--web", "attribution_worker.worker_settings"] -------------------------------------------------------------------------------- /attribution_worker/README.md: -------------------------------------------------------------------------------- 1 | # Infini-gram Attribution Worker 2 | 3 | This is a worker that the API can offload long-running tasks to. 
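The API publishes `attribute` jobs to a shared [saq](https://github.com/tobymao/saq) queue (see `api/src/attribution/attribution_service.py`); this worker consumes them, runs the span search and document retrieval against the infini-gram index, and returns the serialized attribution result. Locally it can be started the same way the dev image does — `saq --verbose --web attribution_worker.worker_settings` — with `ATTRIBUTION_QUEUE_URL` pointing at the same queue backend the API uses.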
-------------------------------------------------------------------------------- /attribution_worker/__init__.py: -------------------------------------------------------------------------------- 1 | from .worker import settings as worker_settings # noqa: F401 2 | -------------------------------------------------------------------------------- /attribution_worker/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import computed_field 2 | from pydantic_settings import BaseSettings, SettingsConfigDict 3 | 4 | 5 | class Config(BaseSettings): 6 | model_config = SettingsConfigDict(env_file=".env", extra="ignore") 7 | 8 | index_base_path: str = "/mnt/infinigram-array" 9 | application_name: str = "infini-gram-api-worker" 10 | attribution_queue_url: str = "redis://localhost:6379" 11 | python_env: str = "prod" 12 | 13 | @computed_field # type: ignore[prop-decorator] 14 | @property 15 | def attribution_queue_name(self) -> str: 16 | queue_prefix = "infini-gram-attribution" 17 | 18 | return f"{queue_prefix}-{self.python_env}" 19 | 20 | 21 | config = Config() 22 | 23 | 24 | def get_config() -> Config: 25 | return config 26 | -------------------------------------------------------------------------------- /attribution_worker/get_documents.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from infini_gram.models import AttributionSpan as AttributionSpanFromEngine 4 | from infini_gram_processor.models import ( 5 | AttributionDocument, 6 | AttributionSpan, 7 | Document, 8 | GetDocumentByPointerRequest, 9 | SpanRankingMethod, 10 | ) 11 | from infini_gram_processor.processor import InfiniGramProcessor 12 | 13 | from .get_span_text import get_span_text 14 | 15 | 16 | def cut_document( 17 | infini_gram_index: InfiniGramProcessor, 18 | token_ids: list[int], 19 | needle_offset: int, 20 | span_length: int, 21 | maximum_context_length: int, 22 | ) -> tuple[int, int, str]: 23 | # cut the left context if necessary 24 | if needle_offset > maximum_context_length: 25 | token_ids = token_ids[(needle_offset - maximum_context_length) :] 26 | needle_offset = maximum_context_length 27 | # cut the right context if necessary 28 | if len(token_ids) - needle_offset - span_length > maximum_context_length: 29 | token_ids = token_ids[: (needle_offset + span_length + maximum_context_length)] 30 | display_length = len(token_ids) 31 | text = infini_gram_index.decode_tokens(token_ids) 32 | return display_length, needle_offset, text 33 | 34 | 35 | def get_spans_with_documents( 36 | infini_gram_index: InfiniGramProcessor, 37 | spans: list[AttributionSpanFromEngine], 38 | documents_by_span: list[list[Document]], 39 | input_token_ids: list[int], 40 | maximum_context_length_long: int, 41 | maximum_context_length_snippet: int, 42 | ) -> list[AttributionSpan]: 43 | spans_with_documents: list[AttributionSpan] = [] 44 | for span, documents in zip(spans, documents_by_span): 45 | span_documents: list[AttributionDocument] = [] 46 | for document in documents: 47 | display_length_long, needle_offset_long, text_long = cut_document( 48 | infini_gram_index=infini_gram_index, 49 | token_ids=document.token_ids, 50 | needle_offset=document.needle_offset, 51 | span_length=span["length"], 52 | maximum_context_length=maximum_context_length_long, 53 | ) 54 | 55 | display_length_snippet, needle_offset_snippet, text_snippet = cut_document( 56 | infini_gram_index=infini_gram_index, 57 | token_ids=document.token_ids, 58 | 
needle_offset=document.needle_offset, 59 | span_length=span["length"], 60 | maximum_context_length=maximum_context_length_snippet, 61 | ) 62 | 63 | span_documents.append( 64 | AttributionDocument( 65 | **vars(document), 66 | display_length_long=display_length_long, 67 | needle_offset_long=needle_offset_long, 68 | text_long=text_long, 69 | display_offset_snippet=display_length_snippet, 70 | needle_offset_snippet=needle_offset_snippet, 71 | text_snippet=text_snippet, 72 | ) 73 | ) 74 | 75 | (span_text_tokens, span_text) = get_span_text( 76 | infini_gram_index=infini_gram_index, 77 | input_token_ids=input_token_ids, 78 | start=span["l"], 79 | stop=span["r"], 80 | ) 81 | 82 | new_span_with_documents = AttributionSpan( 83 | left=span["l"], 84 | right=span["r"], 85 | length=span["length"], 86 | count=span["count"], 87 | unigram_logprob_sum=span["unigram_logprob_sum"], 88 | text=span_text, 89 | token_ids=span_text_tokens, 90 | documents=span_documents, 91 | ) 92 | 93 | spans_with_documents.append(new_span_with_documents) 94 | 95 | return spans_with_documents 96 | 97 | 98 | def get_document_requests( 99 | spans: list[AttributionSpanFromEngine], 100 | input_token_ids: list[int], 101 | maximum_documents_per_span: int, 102 | maximum_context_length: int, 103 | ) -> list[GetDocumentByPointerRequest]: 104 | document_request_by_span: list[GetDocumentByPointerRequest] = [] 105 | for span in spans: 106 | docs = span["docs"] 107 | if len(docs) > maximum_documents_per_span: 108 | random.seed(42) # For reproducibility 109 | docs = random.sample(docs, maximum_documents_per_span) 110 | document_request_by_span.append( 111 | GetDocumentByPointerRequest( 112 | docs=docs, 113 | span_ids=input_token_ids[span["l"] : span["r"]], 114 | needle_length=span["length"], 115 | maximum_context_length=maximum_context_length, 116 | ) 117 | ) 118 | return document_request_by_span 119 | 120 | 121 | def sort_and_cap_spans( 122 | spans: list[AttributionSpanFromEngine], 123 | ranking_method: SpanRankingMethod, 124 | maximum_num_spans: int, 125 | ) -> list[AttributionSpanFromEngine]: 126 | sorted_spans: list[AttributionSpanFromEngine] 127 | 128 | if ranking_method == SpanRankingMethod.LENGTH: 129 | sorted_spans = sorted(spans, key=lambda x: x["length"], reverse=True) 130 | elif ranking_method == SpanRankingMethod.UNIGRAM_LOGPROB_SUM: 131 | sorted_spans = sorted( 132 | spans, 133 | key=lambda x: x["unigram_logprob_sum"], 134 | reverse=False, 135 | ) 136 | else: 137 | raise ValueError(f"Unknown span ranking method: {ranking_method}") 138 | 139 | return sorted(list(sorted_spans[:maximum_num_spans]), key=lambda span: span["l"]) 140 | -------------------------------------------------------------------------------- /attribution_worker/get_span_text.py: -------------------------------------------------------------------------------- 1 | from itertools import islice 2 | from typing import Iterable, Sequence 3 | 4 | from infini_gram_processor.processor import InfiniGramProcessor 5 | 6 | 7 | def get_span_text( 8 | infini_gram_index: InfiniGramProcessor, 9 | input_token_ids: Iterable[int], 10 | start: int, 11 | stop: int, 12 | ) -> tuple[Sequence[int], str]: 13 | span_text_tokens = list(islice(input_token_ids, start, stop)) 14 | span_text = infini_gram_index.decode_tokens(token_ids=span_text_tokens) 15 | 16 | return (span_text_tokens, span_text) 17 | -------------------------------------------------------------------------------- /attribution_worker/pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [project] 2 | name = "infini-gram-attribution-worker" 3 | version = "0.1.0" 4 | requires-python = ">=3.12" 5 | dependencies = [ 6 | "pydantic-settings==2.3.4", 7 | "pydantic==2.10.6", 8 | "saq[postgres,web]==0.22.4", 9 | "psycopg[binary,pool]>=3.2.6", 10 | "transformers==4.49.0", 11 | "infini-gram-processor", 12 | "numpy<2.0.0", 13 | "opentelemetry-api==1.30.0", 14 | "opentelemetry-exporter-gcp-trace==1.9.0", 15 | "opentelemetry-exporter-otlp-proto-http==1.30.0", 16 | "opentelemetry-sdk==1.30.0", 17 | ] 18 | -------------------------------------------------------------------------------- /attribution_worker/worker.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from typing import Any 4 | 5 | import numpy as np 6 | from infini_gram_processor import indexes 7 | from infini_gram_processor.index_mappings import AvailableInfiniGramIndexId 8 | from infini_gram_processor.models import ( 9 | SpanRankingMethod, 10 | ) 11 | from infini_gram_processor.models.models import ( 12 | AttributionResponse, 13 | AttributionSpan, 14 | ) 15 | from opentelemetry import trace 16 | from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter 17 | from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter 18 | from opentelemetry.sdk.trace import TracerProvider 19 | from opentelemetry.sdk.trace.export import BatchSpanProcessor, SimpleSpanProcessor 20 | from opentelemetry.semconv.trace import SpanAttributes 21 | from opentelemetry.trace import SpanKind 22 | from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator 23 | from saq import Queue 24 | from saq.types import Context, SettingsDict 25 | 26 | from .config import get_config 27 | from .get_documents import ( 28 | get_document_requests, 29 | get_spans_with_documents, 30 | sort_and_cap_spans, 31 | ) 32 | 33 | _TASK_RUN = "run" 34 | 35 | config = get_config() 36 | 37 | queue = Queue.from_url(config.attribution_queue_url, name=config.attribution_queue_name) 38 | 39 | tracer_provider = TracerProvider() 40 | 41 | if os.getenv("ENV") == "development": 42 | tracer_provider.add_span_processor( 43 | span_processor=SimpleSpanProcessor(OTLPSpanExporter()) 44 | ) 45 | else: 46 | tracer_provider.add_span_processor( 47 | BatchSpanProcessor(CloudTraceSpanExporter(project_id="ai2-reviz")) # type:ignore 48 | ) 49 | 50 | trace.set_tracer_provider(tracer_provider) 51 | 52 | tracer = trace.get_tracer(config.application_name) 53 | 54 | _TASK_NAME_KEY = "saq.task_name" 55 | _TASK_TAG_KEY = "saq.action" 56 | 57 | 58 | async def attribution_job( 59 | ctx: Context, 60 | *, 61 | index: str, 62 | input: str, 63 | delimiters: list[str], 64 | allow_spans_with_partial_words: bool, 65 | minimum_span_length: int, 66 | maximum_frequency: int, 67 | maximum_span_density: float, 68 | span_ranking_method: SpanRankingMethod, 69 | maximum_context_length: int, 70 | maximum_context_length_long: int, 71 | maximum_context_length_snippet: int, 72 | maximum_documents_per_span: int, 73 | otel_context: dict[str, Any], 74 | ) -> str: 75 | extracted_context = TraceContextTextMapPropagator().extract(carrier=otel_context) 76 | with tracer.start_as_current_span( 77 | "attribution-worker/attribute", 78 | kind=SpanKind.CLIENT, 79 | context=extracted_context, 80 | attributes={ 81 | SpanAttributes.MESSAGING_SYSTEM: "saq", 82 | _TASK_NAME_KEY: "attribute", 83 | _TASK_TAG_KEY: "apply_async", 84 | }, 85 
| ) as otel_span: 86 | job = ctx.get("job") 87 | if job is not None: 88 | otel_span.set_attribute(SpanAttributes.MESSAGING_MESSAGE_ID, job.key) 89 | 90 | worker = ctx.get("worker") 91 | if worker is not None: 92 | otel_span.set_attribute(SpanAttributes.MESSAGING_CLIENT_ID, worker.id) 93 | 94 | infini_gram_index = indexes[AvailableInfiniGramIndexId(index)] 95 | 96 | attribute_result = await asyncio.to_thread( 97 | infini_gram_index.attribute, 98 | input=input, 99 | delimiters=delimiters, 100 | allow_spans_with_partial_words=allow_spans_with_partial_words, 101 | minimum_span_length=minimum_span_length, 102 | maximum_frequency=maximum_frequency, 103 | ) 104 | 105 | # Limit the density of spans, and keep the longest ones 106 | maximum_num_spans = int( 107 | np.ceil(len(attribute_result.input_token_ids) * maximum_span_density) 108 | ) 109 | 110 | sorted_spans = sort_and_cap_spans( 111 | attribute_result.spans, 112 | ranking_method=span_ranking_method, 113 | maximum_num_spans=maximum_num_spans, 114 | ) 115 | 116 | document_request_by_span = get_document_requests( 117 | spans=sorted_spans, 118 | input_token_ids=attribute_result.input_token_ids, 119 | maximum_documents_per_span=maximum_documents_per_span, 120 | maximum_context_length=maximum_context_length, 121 | ) 122 | 123 | documents_by_span = await asyncio.to_thread( 124 | infini_gram_index.get_documents_by_pointers, 125 | document_request_by_span=document_request_by_span, 126 | ) 127 | 128 | spans_with_documents: list[AttributionSpan] = get_spans_with_documents( 129 | infini_gram_index=infini_gram_index, 130 | spans=sorted_spans, 131 | documents_by_span=documents_by_span, 132 | input_token_ids=attribute_result.input_token_ids, 133 | maximum_context_length_long=maximum_context_length_long, 134 | maximum_context_length_snippet=maximum_context_length_snippet, 135 | ) 136 | 137 | response = AttributionResponse( 138 | index=infini_gram_index.index, 139 | spans=spans_with_documents, 140 | input_tokens=infini_gram_index.tokenize_to_list(input), 141 | ) 142 | return response.model_dump_json() 143 | 144 | 145 | settings = SettingsDict( 146 | queue=queue, functions=[("attribute", attribution_job)], concurrency=1 147 | ) 148 | -------------------------------------------------------------------------------- /bin/download-infini-gram-array.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 3 | INFINIGRAM_ARRAY_DIR=$SCRIPT_DIR/../infinigram-array 4 | 5 | # If you add an array here make sure you add it to docker-compose.yaml too 6 | 7 | echo $INFINIGRAM_ARRAY_DIR 8 | if [ ! -d $INFINIGRAM_ARRAY_DIR/v4_pileval_llama ]; then 9 | echo "Downloading v4_pileval_llama array" 10 | aws s3 cp --no-sign-request --recursive s3://infini-gram-lite/index/v4_pileval_llama $INFINIGRAM_ARRAY_DIR/v4_pileval_llama 11 | fi 12 | 13 | if [ ! -d $INFINIGRAM_ARRAY_DIR/olmoe-mix-0924-dclm ]; then 14 | echo "creating a link from v4_pileval_llama to olmoe-mix-0924-dclm" 15 | ln -s $INFINIGRAM_ARRAY_DIR/v4_pileval_llama $INFINIGRAM_ARRAY_DIR/olmoe-mix-0924-dclm 16 | fi 17 | 18 | if [ ! -d $INFINIGRAM_ARRAY_DIR/olmoe-mix-0924-nodclm ]; then 19 | echo "creating a link from v4_pileval_llama to olmoe-mix-0924-nodclm" 20 | ln -s $INFINIGRAM_ARRAY_DIR/v4_pileval_llama $INFINIGRAM_ARRAY_DIR/olmoe-mix-0924-nodclm 21 | fi 22 | 23 | if [ ! 
-d $INFINIGRAM_ARRAY_DIR/v4-olmoe-0125-1b-7b-anneal-adapt ]; then 24 | echo "creating a link from v4_pileval_llama to v4-olmoe-0125-1b-7b-anneal-adapt" 25 | ln -s $INFINIGRAM_ARRAY_DIR/v4_pileval_llama $INFINIGRAM_ARRAY_DIR/v4-olmoe-0125-1b-7b-anneal-adapt 26 | fi 27 | 28 | if [ ! -d $INFINIGRAM_ARRAY_DIR/v4-olmo-2-1124-13b-anneal-adapt ]; then 29 | echo "creating a link from v4_pileval_llama to v4-olmo-2-1124-13b-anneal-adapt" 30 | ln -s $INFINIGRAM_ARRAY_DIR/v4_pileval_llama $INFINIGRAM_ARRAY_DIR/v4-olmo-2-1124-13b-anneal-adapt 31 | fi 32 | 33 | if [ ! -d $INFINIGRAM_ARRAY_DIR/v4-olmo-2-0325-32b-anneal-adapt ]; then 34 | echo "creating a link from v4_pileval_llama to v4-olmo-2-0325-32b-anneal-adapt" 35 | ln -s $INFINIGRAM_ARRAY_DIR/v4_pileval_llama $INFINIGRAM_ARRAY_DIR/v4-olmo-2-0325-32b-anneal-adapt 36 | fi 37 | 38 | if [ ! -d $INFINIGRAM_ARRAY_DIR/v4-tulu-3-8b-adapt-llama ]; then 39 | echo "creating a link from v4_pileval_llama to v4-tulu-3-8b-adapt-llama" 40 | ln -s $INFINIGRAM_ARRAY_DIR/v4_pileval_llama $INFINIGRAM_ARRAY_DIR/v4-tulu-3-8b-adapt-llama 41 | fi 42 | 43 | if [ ! -d $INFINIGRAM_ARRAY_DIR/v4-tulu-3-70b-adapt-llama ]; then 44 | echo "creating a link from v4_pileval_llama to v4-tulu-3-70b-adapt-llama" 45 | ln -s $INFINIGRAM_ARRAY_DIR/v4_pileval_llama $INFINIGRAM_ARRAY_DIR/v4-tulu-3-70b-adapt-llama 46 | fi 47 | 48 | if [ ! -d $INFINIGRAM_ARRAY_DIR/v4-tulu-3-405b-adapt-llama ]; then 49 | echo "creating a link from v4_pileval_llama to v4-tulu-3-405b-adapt-llama" 50 | ln -s $INFINIGRAM_ARRAY_DIR/v4_pileval_llama $INFINIGRAM_ARRAY_DIR/v4-tulu-3-405b-adapt-llama 51 | fi 52 | 53 | 54 | -------------------------------------------------------------------------------- /compute_stats/batch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import glob 4 | from tqdm import tqdm 5 | from collections import defaultdict 6 | import os 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--data_dir', type=str, required=True) 10 | parser.add_argument('--cpus', type=int, default=1) 11 | parser.add_argument('--workers', type=int, default=1) 12 | parser.add_argument('--output_path', type=str, required=True) 13 | args = parser.parse_args() 14 | 15 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 16 | 17 | doccnt_by_source = defaultdict(int) 18 | doccnt_by_domain_source = defaultdict(int) 19 | doccnt_by_charlen_by_source = defaultdict(lambda: defaultdict(int)) 20 | 21 | data_paths = glob.glob(f'{args.data_dir}/**/*.json*', recursive=True) 22 | data_paths = list(sorted(data_paths)) 23 | for i in tqdm(list(range(0, len(data_paths), args.workers))): 24 | processes = [] 25 | for j in range(i, min(i + args.workers, len(data_paths))): 26 | processes.append(os.popen(f'python compute_stats/worker.py --data_dir {args.data_dir} --data_path {data_paths[j]} --cpus {args.cpus // args.workers} --output_path {args.output_path}.tmp.{j}')) 27 | [p.read() for p in processes] 28 | 29 | for j in range(len(data_paths)): 30 | with open(f'{args.output_path}.tmp.{j}', 'r') as f: 31 | data = json.load(f) 32 | for k, v in data['doccnt_by_source'].items(): 33 | doccnt_by_source[k] += v 34 | for k, v in data['doccnt_by_domain_source'].items(): 35 | doccnt_by_domain_source[k] += v 36 | for k, v in data['doccnt_by_charlen_by_source'].items(): 37 | for kk, vv in v.items(): 38 | doccnt_by_charlen_by_source[k][kk] += vv 39 | 40 | doccnt_by_source = dict(sorted(doccnt_by_source.items(), key=lambda x: x[1], reverse=True)) 41 | 
doccnt_by_domain_source = dict(sorted(doccnt_by_domain_source.items(), key=lambda x: x[1], reverse=True)) 42 | doccnt_by_charlen_by_source = {k: dict(sorted(v.items(), key=lambda x: int(x[0]), reverse=False)) for k, v in doccnt_by_charlen_by_source.items()} 43 | 44 | with open(args.output_path, 'w') as f: 45 | json.dump({ 46 | 'doccnt_by_source': doccnt_by_source, 47 | 'doccnt_by_domain_source': doccnt_by_domain_source, 48 | 'doccnt_by_charlen_by_source': doccnt_by_charlen_by_source, 49 | }, f, indent=4) 50 | 51 | for j in range(len(data_paths)): 52 | os.remove(f'{args.output_path}.tmp.{j}') 53 | -------------------------------------------------------------------------------- /compute_stats/batch.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | RUN_NAME="compute-stats_olmoe-mix-0924" 4 | 5 | gantry run \ 6 | --allow-dirty \ 7 | --name ${RUN_NAME} \ 8 | --task-name ${RUN_NAME} \ 9 | --description ${RUN_NAME} \ 10 | --workspace ai2/attribution \ 11 | --budget ai2/oe-training \ 12 | --beaker-image shanea/olmo-torch2.2-gantry \ 13 | --cluster ai2/neptune-cirrascale \ 14 | --priority normal \ 15 | --preemptible \ 16 | --no-nfs \ 17 | --weka oe-training-default:/weka/oe-training-default \ 18 | --weka oe-data-default:/weka/oe-data-default \ 19 | --cpus 248 \ 20 | --memory 900GiB \ 21 | --shared-memory 10GiB \ 22 | --no-python \ 23 | --venv base \ 24 | --yes \ 25 | -- /bin/bash -c "\ 26 | pip install zstandard tqdm ; \ 27 | python compute_stats/batch.py \ 28 | --data_dir /weka/oe-data-default/ai2-llm/pretraining-data/sources/olmo-mix/olmoe-mix-0924/documents \ 29 | --cpus 240 --workers 16 \ 30 | --output_path /weka/oe-training-default/jiachengl/stat/olmoe-mix-0924.json ; \ 31 | " 32 | -------------------------------------------------------------------------------- /compute_stats/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gzip 3 | import zstandard as zstd 4 | import json 5 | import glob 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | import multiprocessing as mp 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--data_dir', type=str, required=True) 12 | parser.add_argument('--cpus', type=int, default=1) 13 | parser.add_argument('--output_path', type=str, required=True) 14 | args = parser.parse_args() 15 | 16 | def load_file(path): 17 | if path.endswith('.gz'): 18 | with gzip.open(path, 'rt', encoding='utf-8') as f: 19 | lines = f.readlines() 20 | elif path.endswith('.zst'): 21 | with open(path, 'rb') as f: 22 | dctx = zstd.ZstdDecompressor() 23 | with dctx.stream_reader(f) as reader: 24 | decompressed_data = reader.read().decode('utf-8') 25 | lines = decompressed_data.split('\n') 26 | if lines[-1] == '': 27 | lines = lines[:-1] 28 | elif path.endswith('.jsonl'): 29 | with open(path, encoding='utf-8') as f: 30 | lines = f.readlines() 31 | else: 32 | raise ValueError(f'Unknown file type: {path}') 33 | return lines 34 | 35 | def process(line): 36 | item = json.loads(line.strip('\n')) 37 | try: 38 | domain = item['metadata']['url'].split('/')[2] 39 | except: 40 | domain = '' 41 | return domain, len(item['text']) 42 | 43 | doccnt_by_source = defaultdict(int) 44 | doccnt_by_domain_source = defaultdict(int) 45 | doccnt_by_charlen_by_source = defaultdict(lambda: defaultdict(int)) 46 | 47 | data_paths = glob.glob(f'{args.data_dir}/**/*.json*', recursive=True) 48 | data_paths = list(sorted(data_paths)) 49 | with 
mp.get_context('fork').Pool(args.cpus) as p: 50 | for data_path in tqdm(data_paths): 51 | source = data_path[len(args.data_dir)+1:].split('/')[0] 52 | lines = load_file(data_path) 53 | results = p.map(process, lines) 54 | for (domain, charlen) in results: 55 | doccnt_by_source[source] += 1 56 | doccnt_by_domain_source[f'{source}/{domain}'] += 1 57 | doccnt_by_charlen_by_source[source][charlen] += 1 58 | del lines 59 | 60 | doccnt_by_source = dict(sorted(doccnt_by_source.items(), key=lambda x: x[1], reverse=True)) 61 | doccnt_by_domain_source = dict(sorted(doccnt_by_domain_source.items(), key=lambda x: x[1], reverse=True)) 62 | doccnt_by_charlen_by_source = {k: dict(sorted(v.items(), key=lambda x: int(x[0]), reverse=False)) for k, v in doccnt_by_charlen_by_source.items()} 63 | 64 | with open(args.output_path, 'w') as f: 65 | json.dump({ 66 | 'doccnt_by_source': doccnt_by_source, 67 | 'doccnt_by_domain_source': doccnt_by_domain_source, 68 | 'doccnt_by_charlen_by_source': doccnt_by_charlen_by_source, 69 | }, f, indent=4) 70 | -------------------------------------------------------------------------------- /compute_stats/transform.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | import numpy as np 4 | 5 | with open('compute_stats/olmoe-mix-0924.json', 'r') as f: 6 | js = json.load(f) 7 | 8 | with open('../he-olmo-ui/api/dolma_search/static/source_counts/data.json', 'w') as f: 9 | doccnt_by_source = js['doccnt_by_source'] 10 | total = sum(doccnt_by_source.values()) 11 | json.dump(doccnt_by_source, f) 12 | 13 | with open('../he-olmo-ui/api/dolma_search/static/domains/data.json', 'w') as f: 14 | doccnt_by_domain_source = js['doccnt_by_domain_source'] 15 | doccnt_by_source_by_domain = defaultdict(dict) 16 | for domain_source, doccnt in list(doccnt_by_domain_source.items())[:20000]: 17 | source, domain = domain_source.split('/') 18 | if domain != '': 19 | doccnt_by_source_by_domain[source][domain] = doccnt 20 | json.dump(doccnt_by_source_by_domain, f) 21 | 22 | with open('../he-olmo-ui/api/dolma_search/static/words/data.json', 'w') as f: 23 | doccnt_by_charlen_by_source = js['doccnt_by_charlen_by_source'] 24 | bins_by_source = defaultdict(list) 25 | for source, doccnt_by_charlen in doccnt_by_charlen_by_source.items(): 26 | max_charlen = max([int(charlen) for charlen in doccnt_by_charlen.keys()]) 27 | # [0, 2^0), [2^0, 2^1), [2^1, 2^2), ..., [2^num_bins-1, 2^num_bins) 28 | num_bins = int(np.log2(max_charlen)) + 2 29 | counts = [0] * num_bins 30 | for charlen, doccnt in doccnt_by_charlen.items(): 31 | bin_idx = 0 if charlen == '0' else (int(np.log2(int(charlen))) + 1) 32 | counts[bin_idx] += doccnt 33 | bins = [{'min': int(2**(b-1)), 'max': int(2**b), 'doc_count': counts[b], 'percentage': counts[b] / total} for b in range(num_bins)] 34 | bins_by_source[source] = bins 35 | json.dump(bins_by_source, f) 36 | -------------------------------------------------------------------------------- /compute_stats/wiki.json: -------------------------------------------------------------------------------- 1 | { 2 | "doccnt_by_source": {}, 3 | "doccnt_by_domain_source": {}, 4 | "doccnt_by_charlen_by_source": {} 5 | } -------------------------------------------------------------------------------- /compute_stats/worker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gzip 3 | import zstandard as zstd 4 | import json 5 | from collections import defaultdict 6 
| import multiprocessing as mp 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--data_dir', type=str, required=True) 10 | parser.add_argument('--data_path', type=str, required=True) 11 | parser.add_argument('--cpus', type=int, default=1) 12 | parser.add_argument('--output_path', type=str, required=True) 13 | args = parser.parse_args() 14 | 15 | def load_file(path): 16 | if path.endswith('.gz'): 17 | with gzip.open(path, 'rt', encoding='utf-8') as f: 18 | lines = f.readlines() 19 | elif path.endswith('.zst'): 20 | with open(path, 'rb') as f: 21 | dctx = zstd.ZstdDecompressor() 22 | with dctx.stream_reader(f) as reader: 23 | decompressed_data = reader.read().decode('utf-8') 24 | lines = decompressed_data.split('\n') 25 | if lines[-1] == '': 26 | lines = lines[:-1] 27 | elif path.endswith('.jsonl'): 28 | with open(path, encoding='utf-8') as f: 29 | lines = f.readlines() 30 | else: 31 | raise ValueError(f'Unknown file type: {path}') 32 | return lines 33 | 34 | def process(line): 35 | item = json.loads(line.strip('\n')) 36 | try: 37 | domain = item['metadata']['url'].split('/')[2] 38 | except: 39 | domain = '' 40 | return domain, len(item['text']) 41 | 42 | doccnt_by_source = defaultdict(int) 43 | doccnt_by_domain_source = defaultdict(int) 44 | doccnt_by_charlen_by_source = defaultdict(lambda: defaultdict(int)) 45 | 46 | with mp.get_context('fork').Pool(args.cpus) as p: 47 | source = args.data_path[len(args.data_dir)+1:].split('/')[0] 48 | lines = load_file(args.data_path) 49 | results = p.map(process, lines) 50 | for (domain, charlen) in results: 51 | doccnt_by_source[source] += 1 52 | doccnt_by_domain_source[f'{source}/{domain}'] += 1 53 | doccnt_by_charlen_by_source[source][charlen] += 1 54 | 55 | doccnt_by_source = dict(sorted(doccnt_by_source.items(), key=lambda x: x[1], reverse=True)) 56 | doccnt_by_domain_source = dict(sorted(doccnt_by_domain_source.items(), key=lambda x: x[1], reverse=True)) 57 | doccnt_by_charlen_by_source = {k: dict(sorted(v.items(), key=lambda x: int(x[0]), reverse=False)) for k, v in doccnt_by_charlen_by_source.items()} 58 | 59 | with open(args.output_path, 'w') as f: 60 | json.dump({ 61 | 'doccnt_by_source': doccnt_by_source, 62 | 'doccnt_by_domain_source': doccnt_by_domain_source, 63 | 'doccnt_by_charlen_by_source': doccnt_by_charlen_by_source, 64 | }, f, indent=4) 65 | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | x-shared-env: &shared-env 2 | PYTHONUNBUFFERED: 1 3 | INDEX_BASE_PATH: /mnt/infinigram-array 4 | ATTRIBUTION_QUEUE_URL: postgres://infini-gram:llmz@queue:5432/infini-gram?sslmode=disable&application_name=infini-gram-attribution-worker 5 | PYTHON_ENV: development 6 | 7 | services: 8 | api: 9 | platform: linux/amd64 10 | build: 11 | context: . 
12 | dockerfile: ./api/Dockerfile 13 | target: dev 14 | volumes: 15 | - ./api:/app/api 16 | # If you add an volume here make sure you add it to /bin/download-infini-gram-array.sh too 17 | - ./infinigram-array/v4_pileval_llama:/mnt/infinigram-array/v4_pileval_llama 18 | - ./infinigram-array/olmoe-mix-0924-dclm:/mnt/infinigram-array/olmoe-mix-0924-dclm 19 | - ./infinigram-array/olmoe-mix-0924-nodclm:/mnt/infinigram-array/olmoe-mix-0924-nodclm 20 | - ./infinigram-array/v4-olmoe-0125-1b-7b-anneal-adapt:/mnt/infinigram-array/v4-olmoe-0125-1b-7b-anneal-adapt 21 | - ./infinigram-array/v4-olmo-2-1124-13b-anneal-adapt:/mnt/infinigram-array/v4-olmo-2-1124-13b-anneal-adapt 22 | - ./infinigram-array/v4-olmo-2-0325-32b-anneal-adapt:/mnt/infinigram-array/v4-olmo-2-0325-32b-anneal-adapt 23 | - ./infinigram-array/v4-tulu-3-8b-adapt-llama:/mnt/infinigram-array/v4-tulu-3-8b-adapt-llama 24 | - ./infinigram-array/v4-tulu-3-70b-adapt-llama:/mnt/infinigram-array/v4-tulu-3-70b-adapt-llama 25 | - ./infinigram-array/v4-tulu-3-405b-adapt-llama:/mnt/infinigram-array/v4-tulu-3-405b-adapt-llama 26 | environment: 27 | LOG_LEVEL: DEBUG 28 | OTEL_EXPORTER_OTLP_ENDPOINT: http://otelcol:4318 29 | CACHE_URL: redis://cache:6379 30 | <<: *shared-env 31 | 32 | attribution-worker: 33 | platform: linux/amd64 34 | ports: 35 | - 8081:8080 36 | build: 37 | context: . 38 | dockerfile: ./attribution_worker/Dockerfile 39 | target: dev 40 | volumes: 41 | - ./attribution_worker:/app/attribution_worker 42 | # If you add an volume here make sure you add it to /bin/download-infini-gram-array.sh too 43 | - ./infinigram-array/v4_pileval_llama:/mnt/infinigram-array/v4_pileval_llama 44 | - ./infinigram-array/olmoe-mix-0924-dclm:/mnt/infinigram-array/olmoe-mix-0924-dclm 45 | - ./infinigram-array/olmoe-mix-0924-nodclm:/mnt/infinigram-array/olmoe-mix-0924-nodclm 46 | - ./infinigram-array/v4-olmoe-0125-1b-7b-anneal-adapt:/mnt/infinigram-array/v4-olmoe-0125-1b-7b-anneal-adapt 47 | - ./infinigram-array/v4-olmo-2-1124-13b-anneal-adapt:/mnt/infinigram-array/v4-olmo-2-1124-13b-anneal-adapt 48 | - ./infinigram-array/v4-olmo-2-0325-32b-anneal-adapt:/mnt/infinigram-array/v4-olmo-2-0325-32b-anneal-adapt 49 | - ./infinigram-array/v4-tulu-3-8b-adapt-llama:/mnt/infinigram-array/v4-tulu-3-8b-adapt-llama 50 | - ./infinigram-array/v4-tulu-3-70b-adapt-llama:/mnt/infinigram-array/v4-tulu-3-70b-adapt-llama 51 | - ./infinigram-array/v4-tulu-3-405b-adapt-llama:/mnt/infinigram-array/v4-tulu-3-405b-adapt-llama 52 | environment: *shared-env 53 | 54 | proxy: 55 | build: ./proxy 56 | ports: 57 | - 8080:8080 58 | depends_on: 59 | - api 60 | 61 | otelcol: 62 | image: otel/opentelemetry-collector-contrib:0.115.1 63 | volumes: 64 | - ./otel-collector/otel-collector-config.yaml:/etc/otelcol-contrib/config.yaml:ro 65 | - logs:/var/log:ro 66 | environment: 67 | - GOOGLE_CLOUD_PROJECT 68 | - GOOGLE_CLOUD_QUOTA_PROJECT 69 | ports: 70 | # Collector prometheus port. 
The metrics are checked in tests 71 | - 8888 72 | 73 | queue: 74 | image: postgres:15 75 | ports: 76 | - 5432:5432 77 | environment: 78 | POSTGRES_PASSWORD: llmz 79 | POSTGRES_DB: infini-gram 80 | volumes: 81 | - pgdata:/var/lib/postgresql/data 82 | - ./schema:/docker-entrypoint-initdb.d 83 | 84 | cache: 85 | image: redis:7 86 | ports: 87 | - 6379:6379 88 | 89 | volumes: 90 | logs: 91 | pgdata: -------------------------------------------------------------------------------- /docs/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute to Infini-gram API 2 | 3 | ## Setting up your environment 4 | 5 | ### Prerequisites 6 | 7 | Make sure that you have the latest version of [Docker 🐳](https://www.docker.com/get-started) 8 | installed on your local machine. 9 | 10 | ### Installing Dependencies 11 | 12 | #### Install uv 13 | 14 | This repo uses `uv` for Python dependency management. Follow the instructions on https://docs.astral.sh/uv/getting-started/installation to install it. 15 | 16 | Run `uv sync --all-packages` to get the packages for every project in this repo. 17 | 18 | #### Adding an index for local development 19 | 20 | 1. Ensure you have the `aws` cli installed. run `brew install awscli` if you don't. 21 | 2. Download the `v4_pileval_llama` index by running `./bin/download-infini-gram-array.sh` 22 | 23 | The `infinigram-array` folder is mounted to the Docker container for the API through the `docker-compose`. 24 | 25 | ## Linting and Formatting 26 | 27 | We use `Ruff` and `mypy` to lint, format, and check for type issues. 28 | 29 | ### CLI 30 | To check for `Ruff` issues, run `uv run ruff check`. If you want to have it automatically fix issues, run `uv run ruff check --fix`. If you want to have it format your code, run `uv run ruff format`. 31 | 32 | To check for `mypy` issues, run `uv run mypy --config ./pyproject.toml` 33 | 34 | ### VSCode 35 | Install the [ruff](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) and [mypy](https://marketplace.visualstudio.com/items?itemName=ms-python.mypy-type-checker) extensions. These are listed in the "Recommended Extensions" for the workspace as well. 36 | 37 | ## Running the server 38 | 39 | ### Docker Compose 40 | The easiest way to run the full setup is to use the Docker Compose file. This will start the API, worker, proxy, and queue with all the appropriate connections. 41 | 42 | To start with Docker Compose, run `docker compose up` in the root of this repo. 43 | 44 | 45 | ### Outside of Docker 46 | If you want to run applications outside of docker, you'll need to set up a queue yourself. The easiest method is to still go through Compose by just starting the queue. `docker compose up queue` 47 | 48 | After that, make sure your environment variables are set correctly through a `.env` file or just environment variables, then run the services. 49 | API: `uv run api/app.py` 50 | Worker: `uv run saq attribution_worker.worker.settings` -------------------------------------------------------------------------------- /docs/indexes.md: -------------------------------------------------------------------------------- 1 | 2 | # Adding a new infini-gram index 3 | 4 | ## On the prod server 5 | 6 | ### Transferring indexes from AWS 7 | 8 | #### Creating Agents 9 | 10 | 1. Create VMs in google compute engine 11 | * Google recommends 4vCPU and 8GB of memory per agent. 
Three VMs with the appropriate specs seem to be the most cost-effective 12 | * Example gcloud command: 13 | ``` 14 | gcloud compute instances create infini-gram-transfer-agent-1 \ 15 | --project=ai2-reviz \ 16 | --zone=us-west1-a \ 17 | --machine-type=n2-standard-4 \ 18 | --network-interface=network-tier=PREMIUM,stack-type=IPV4_ONLY,subnet=main \ 19 | --maintenance-policy=MIGRATE \ 20 | --provisioning-model=STANDARD \ 21 | --service-account=infini-gram-transfer@ai2-reviz.iam.gserviceaccount.com \ 22 | --scopes=https://www.googleapis.com/auth/cloud-platform \ 23 | --create-disk=auto-delete=yes,boot=yes,device-name=instance-20240716-154927,image=projects/debian-cloud/global/images/debian-12-bookworm-v20240709,mode=rw,size=10,type=projects/ai2-reviz/zones/us-west1-a/diskTypes/pd-balanced \ 24 | --no-shielded-secure-boot \ 25 | --shielded-vtpm \ 26 | --shielded-integrity-monitoring \ 27 | --labels=name=infini-gram-transfer-agent-1,project=infini-gram,goog-ec-src=vm_add-gcloud \ 28 | --reservation-affinity=any 29 | ``` 30 | 2. Upload your SSH key to the VMs 31 | 3. Copy the infini-gram service account JSON to the VM: `scp -i ./infini-gram-transfer-account.json taylorb@10.115.0.81:~` 32 | 4. SSH into the VM: `ssh ` 33 | 5. Set up the agent and attach it to the agent pool 34 | * Run these commands, replacing values where needed: 35 | ``` 36 | export AWS_ACCESS_KEY_ID= 37 | export AWS_SECRET_ACCESS_KEY= 38 | curl -fsSL https://get.docker.com -o get-docker.sh && sudo sh get-docker.sh && sudo systemctl enable docker 39 | gcloud transfer agents install --pool=infini-gram-transfer \ 40 | --creds-file=~/infini-gram-transfer-account.json 41 | ``` 42 | 43 | #### Starting the transfer job 44 | 45 | 1. Go to the [create a transfer job page](https://console.cloud.google.com/transfer/create?project=ai2-reviz) 46 | 2. Set the Source type to "S3-compatible object storage". Destination type should be "Google Cloud Storage" 47 | 3. Fill in data for the source. 48 | * Bucket or folder will be the bucket path. If copying from an s3 URL like `s3://infini-gram-lite/index/v4_pileval_llama` you can take off the `s3://` part, resulting in `infini-gram-lite/index/v4_pileval_llama` 49 | * Endpoint depends on the bucket. Most of the time it'll be `s3.us-east-1.amazonaws.com` 50 | * Signing region should be the region the bucket is in 51 | 4. The destination should be `infinigram/index/` 52 | 5. The job should run once, starting now 53 | 6. The rest of the settings can stay the same 54 | * I do find it nice to name the transfer job. Something like `infini-gram-transfer-` 55 | * Make sure you don't change the Storage class and don't delete from the source 56 | 57 | ### Making a Persistent Disk 58 | 1. ``` 59 | gcloud compute disks create infini-gram- \ 60 | --project=ai2-reviz \ 61 | --type=pd-balanced \ 62 | --size= \ 63 | --labels=project=infini-gram \ 64 | --zone=us-west1-b 65 | ``` 66 | 2. Change names in `volume-claims/writer-pod.yaml` to match the disk you created and the index name 67 | 3. Create a writer pod: `kubectl apply -f volume-claims/writer-pod.yaml --namespace=infinigram-api` 68 | * (TODO) Set up a baseline image to use for transferring files. Needs to have python3 and gcloud tools 69 | 4. Connect to the pod and set it up with python3 and gcloud 70 | * kubectl exec --stdin --tty infini-gram-writer --namespace=infinigram-api -- /bin/ash 71 | * apk add python3 curl which bash 72 | * curl -sSL https://sdk.cloud.google.com | bash 73 | * bash 74 | 5. 
Download the files from the bucket into /mnt/infini-gram-array 75 | * `gcloud storage cp gs://infinigram/index//* /mnt/infini-gram-array/` 76 | 77 | ### Adding the volume to webapp.jsonnet 78 | 1. Add a volume to the deployment 79 | ``` 80 | { 81 | name: "infinigram-array->", 82 | persistentVolumeClaim: { 83 | claimName: "infinigram-", 84 | readOnly: true, 85 | } 86 | } 87 | ``` 88 | 2. Add a volumeMount to the -api container 89 | ``` 90 | { 91 | mountPath: "/mnt/infinigram-array/", 92 | name: "infinigram-array-", 93 | readOnly: true, 94 | } 95 | ``` 96 | 97 | ## Locally 98 | 99 | 1. Add the ID of the index to `AvailableInfiniGramIndexId` in `api/src/infinigram/index_mappings.py` 100 | 2. Add the ID as a string to `IndexMappings` in `api/src/infinigram/index_mappings.py` 101 | 3. Add the tokenizer and index directory to `index_mappings` in `api/src/infinigram/index_mappings.py` 102 | 4. add a line in /bin/download-infini-gram-array.sh to make a new symlink with that array's path. The path will be the `index_dir` you added in `index_mappings` but has `/mnt/infinigram-array` replaced with `$INFINIGRAM_ARRAY_DIR` 103 | 5. Add a mount in `docker-compose.yaml`: `- ./infinigram-array/:/mnt/infinigram-array/ -------------------------------------------------------------------------------- /indexing/flan_rulebased_s3_to_weka.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -ex 4 | 5 | NUM_NODES=1 6 | RUN_NAME="copy_flan_rulebased_to_weka" 7 | 8 | gantry run \ 9 | --allow-dirty \ 10 | --name ${RUN_NAME} \ 11 | --task-name ${RUN_NAME} \ 12 | --description ${RUN_NAME} \ 13 | --workspace ai2/hb-wolf-olmo \ 14 | --budget ai2/oe-training \ 15 | --beaker-image petew/olmo-torch23-gantry \ 16 | --cluster ai2/jupiter-cirrascale-2 \ 17 | --priority high \ 18 | --preemptible \ 19 | --no-nfs \ 20 | --weka oe-training-default:/weka/oe-training-default \ 21 | --shared-memory 10GiB \ 22 | --no-python \ 23 | --env LOG_FILTER_TYPE=local_rank0_only \ 24 | --env OMP_NUM_THREADS=8 \ 25 | --env OLMO_TASK=model \ 26 | --env WANDB__SERVICE_WAIT=300 \ 27 | --env WANDB_HTTP_TIMEOUT=60 \ 28 | --env-secret WANDB_API_KEY=WANDB_API_KEY \ 29 | --env-secret AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID \ 30 | --env-secret AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY \ 31 | --yes \ 32 | -- /bin/bash -c "\ 33 | conda shell.bash activate base; \ 34 | pip install awscli; \ 35 | aws s3 cp --recursive s3://ai2-llm/preprocessed/tulu_flan/v1-decontaminated-60M-shots_all-upweight_1-dialog_false-sep_rulebased /weka/oe-training-default/ai2-llm/preprocessed/tulu_flan/v1-decontaminated-60M-shots_all-upweight_1-dialog_false-sep_rulebased; \ 36 | " 37 | -------------------------------------------------------------------------------- /indexing/index_weka_to_s3_all.sh: -------------------------------------------------------------------------------- 1 | bash sync_to_s3_nodclm.sh s3://infini-gram/index/v4_olmoe-mix-0924-nodclm_llama oe-training-default jiachengl/index/v4_olmoe-mix-0924-nodclm_llama 2 | bash sync_to_s3_dclm.sh s3://infini-gram/index/v4_olmoe-mix-0924-dclm_llama oe-training-default jiachengl/index/v4_olmoe-mix-0924-dclm_llama 3 | -------------------------------------------------------------------------------- /indexing/index_weka_to_s3_dclm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Default values 4 | CLUSTER="ai2/jupiter-cirrascale-2" 5 | PRIORITY="high" 6 | WORKSPACE="ai2/hb-wolf-olmo" 7 | 
BUDGET="ai2/oe-training" 8 | 9 | # Parse command line arguments and options 10 | while [[ $# -gt 0 ]]; do 11 | case $1 in 12 | -c|--cluster) 13 | CLUSTER="$2" 14 | shift 2 15 | ;; 16 | -w|--workspace) 17 | WORKSPACE="$2" 18 | BUDGET="$WORKSPACE" 19 | shift 2 20 | ;; 21 | -p|--priority) 22 | PRIORITY="$2" 23 | shift 2 24 | ;; 25 | -b|--budget) 26 | BUDGET="$2" 27 | shift 2 28 | ;; 29 | *) 30 | if [ -z "$S3_PREFIX" ]; then 31 | S3_PREFIX="$1" 32 | elif [ -z "$WEKA_BUCKET" ]; then 33 | WEKA_BUCKET="$1" 34 | elif [ -z "$WEKA_PATH" ]; then 35 | WEKA_PATH="$1" 36 | else 37 | echo "Error: Unexpected argument '$1'" 38 | echo "Usage: $0 [-c|--cluster ] [-w|--workspace ] [-p|--priority ] [-b|--budget ]" 39 | exit 1 40 | fi 41 | shift 42 | ;; 43 | esac 44 | done 45 | 46 | if [ -z "$S3_PREFIX" ] || [ -z "$WEKA_BUCKET" ] || [ -z "$WEKA_PATH" ]; then 47 | echo "Error: S3 prefix and Weka bucket are required" 48 | echo "Usage: $0 [-c|--cluster ] [-w|--workspace ] [-p|--priority ]" 49 | exit 1 50 | fi 51 | 52 | # Shift the parsed options out of the argument list 53 | shift $((OPTIND-1)) 54 | 55 | # strip trailing slash from S3_PREFIX 56 | S3_PREFIX=$(echo "$S3_PREFIX" | sed 's|/$||') 57 | 58 | # create a command to install required packages and the AWS CLI 59 | AWS_CLI_INSTALL_CMD="set -x; \ 60 | apt-get update && \ 61 | apt-get install -y curl unzip && \ 62 | curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\" && \ 63 | unzip awscliv2.zip && \ 64 | ./aws/install" 65 | 66 | WEKA_PREFIX="/${WEKA_BUCKET}/${WEKA_PATH}" 67 | 68 | # create a command to sync the S3 prefix to the Weka bucket 69 | SYNC_CMD="/usr/local/bin/aws s3 sync '${WEKA_PREFIX}' '${S3_PREFIX}'" 70 | 71 | 72 | gantry run \ 73 | --description "Syncing '${WEKA_PREFIX}' to '${S3_PREFIX}'" \ 74 | --allow-dirty \ 75 | --workspace "${WORKSPACE}" \ 76 | --priority "${PRIORITY}" \ 77 | --gpus 8 \ 78 | --preemptible \ 79 | --cluster "${CLUSTER}" \ 80 | --budget "${BUDGET}" \ 81 | --weka "${WEKA_BUCKET}:/${WEKA_BUCKET}" \ 82 | --replicas 10 \ 83 | --host-networking \ 84 | --leader-selection \ 85 | --synchronized-start-timeout 48h \ 86 | --env-secret AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID \ 87 | --env-secret AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY \ 88 | --install "${AWS_CLI_INSTALL_CMD}" \ 89 | --yes \ 90 | -- /bin/bash -c "/usr/local/bin/aws s3 cp --recursive --only-show-errors '${WEKA_PREFIX}' '${S3_PREFIX}' --exclude '*' --include \"*.\$BEAKER_REPLICA_RANK\"" -------------------------------------------------------------------------------- /indexing/index_weka_to_s3_nodclm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Default values 4 | CLUSTER="ai2/jupiter-cirrascale-2" 5 | PRIORITY="high" 6 | WORKSPACE="ai2/hb-wolf-olmo" 7 | BUDGET="ai2/oe-training" 8 | 9 | # Parse command line arguments and options 10 | while [[ $# -gt 0 ]]; do 11 | case $1 in 12 | -c|--cluster) 13 | CLUSTER="$2" 14 | shift 2 15 | ;; 16 | -w|--workspace) 17 | WORKSPACE="$2" 18 | BUDGET="$WORKSPACE" 19 | shift 2 20 | ;; 21 | -p|--priority) 22 | PRIORITY="$2" 23 | shift 2 24 | ;; 25 | -b|--budget) 26 | BUDGET="$2" 27 | shift 2 28 | ;; 29 | *) 30 | if [ -z "$S3_PREFIX" ]; then 31 | S3_PREFIX="$1" 32 | elif [ -z "$WEKA_BUCKET" ]; then 33 | WEKA_BUCKET="$1" 34 | elif [ -z "$WEKA_PATH" ]; then 35 | WEKA_PATH="$1" 36 | elif [ -z "$WORKER_ID" ]; then 37 | WORKER_ID="$1" 38 | else 39 | echo "Error: Unexpected argument '$1'" 40 | echo "Usage: $0 [-c|--cluster ] [-w|--workspace ] 
[-p|--priority ] [-b|--budget ]" 41 | exit 1 42 | fi 43 | shift 44 | ;; 45 | esac 46 | done 47 | 48 | if [ -z "$S3_PREFIX" ] || [ -z "$WEKA_BUCKET" ] || [ -z "$WEKA_PATH" ]; then 49 | echo "Error: S3 prefix and Weka bucket are required" 50 | echo "Usage: $0 [-c|--cluster ] [-w|--workspace ] [-p|--priority ]" 51 | exit 1 52 | fi 53 | 54 | # Shift the parsed options out of the argument list 55 | shift $((OPTIND-1)) 56 | 57 | # strip trailing slash from S3_PREFIX 58 | S3_PREFIX=$(echo "$S3_PREFIX" | sed 's|/$||') 59 | 60 | # create a command to install required packages and the AWS CLI 61 | AWS_CLI_INSTALL_CMD="set -x; \ 62 | apt-get update && \ 63 | apt-get install -y curl unzip && \ 64 | curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\" && \ 65 | unzip awscliv2.zip && \ 66 | ./aws/install" 67 | 68 | WEKA_PREFIX="/${WEKA_BUCKET}/${WEKA_PATH}" 69 | 70 | # create a command to sync the S3 prefix to the Weka bucket 71 | SYNC_CMD="/usr/local/bin/aws s3 sync '${WEKA_PREFIX}' '${S3_PREFIX}'" 72 | 73 | 74 | gantry run \ 75 | --description "Syncing '${WEKA_PREFIX}' to '${S3_PREFIX}'" \ 76 | --allow-dirty \ 77 | --workspace "${WORKSPACE}" \ 78 | --priority "${PRIORITY}" \ 79 | --gpus 0 \ 80 | --preemptible \ 81 | --cluster "${CLUSTER}" \ 82 | --budget "${BUDGET}" \ 83 | --weka "${WEKA_BUCKET}:/${WEKA_BUCKET}" \ 84 | --env-secret AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID \ 85 | --env-secret AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY \ 86 | --install "${AWS_CLI_INSTALL_CMD}" \ 87 | --yes \ 88 | -- /bin/bash -c "${SYNC_CMD}" -------------------------------------------------------------------------------- /indexing/indexing_dclm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | RUN_NAME="index_v4_olmoe-mix-0924-dclm_llama" 4 | 5 | gantry run \ 6 | --allow-dirty \ 7 | --name ${RUN_NAME} \ 8 | --task-name ${RUN_NAME} \ 9 | --description ${RUN_NAME} \ 10 | --workspace ai2/hb-wolf-olmo \ 11 | --budget ai2/oe-training \ 12 | --beaker-image shanea/olmo-torch2.2-gantry \ 13 | --cluster ai2/jupiter-cirrascale-2 \ 14 | --priority high \ 15 | --preemptible \ 16 | --no-nfs \ 17 | --weka oe-training-default:/weka/oe-training-default \ 18 | --weka oe-data-default:/weka/oe-data-default \ 19 | --cpus 186 \ 20 | --memory 1912GiB \ 21 | --shared-memory 10GiB \ 22 | --replicas 10 \ 23 | --host-networking \ 24 | --leader-selection \ 25 | --propagate-failure \ 26 | --propagate-preemption \ 27 | --synchronized-start-timeout 48h \ 28 | --no-python \ 29 | --venv base \ 30 | --env-secret HF_TOKEN=HF_TOKEN \ 31 | --yes \ 32 | -- /bin/bash -c "\ 33 | pip install infini-gram zstandard tqdm transformers sentencepiece ; \ 34 | cd /opt/miniconda3/lib/python3.10/site-packages/infini_gram ; \ 35 | python indexing.py \ 36 | --tokenizer llama --cpus 186 --mem 1912 --shards 10 --workers 10 --worker_id \$BEAKER_REPLICA_RANK --add_metadata --ulimit 524288 \ 37 | --data_dir /weka/oe-data-default/ai2-llm/pretraining-data/sources/olmo-mix/olmoe-mix-0924/documents/dclm \ 38 | --save_dir /weka/oe-training-default/jiachengl/index/v4_olmoe-mix-0924-dclm_llama ; \ 39 | " 40 | -------------------------------------------------------------------------------- /indexing/indexing_nodclm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | RUN_NAME="index_v4_olmoe-mix-0924-nodclm_llama" 4 | 5 | gantry run \ 6 | --allow-dirty \ 7 | --name ${RUN_NAME} \ 8 | --task-name ${RUN_NAME} \ 9 | 
--description ${RUN_NAME} \ 10 | --workspace ai2/hb-wolf-olmo \ 11 | --budget ai2/oe-training \ 12 | --beaker-image shanea/olmo-torch2.2-gantry \ 13 | --cluster ai2/jupiter-cirrascale-2 \ 14 | --priority high \ 15 | --preemptible \ 16 | --no-nfs \ 17 | --weka oe-training-default:/weka/oe-training-default \ 18 | --weka oe-data-default:/weka/oe-data-default \ 19 | --cpus 186 \ 20 | --memory 1912GiB \ 21 | --shared-memory 10GiB \ 22 | --no-python \ 23 | --venv base \ 24 | --env-secret HF_TOKEN=HF_TOKEN \ 25 | --yes \ 26 | -- /bin/bash -c "\ 27 | pip install infini-gram zstandard tqdm transformers sentencepiece ; \ 28 | cd /opt/miniconda3/lib/python3.10/site-packages/infini_gram ; \ 29 | python indexing.py \ 30 | --tokenizer llama --cpus 186 --mem 1912 --shards 1 --add_metadata --ulimit 524288 \ 31 | --data_dir /weka/oe-data-default/ai2-llm/pretraining-data/sources/olmo-mix/olmoe-mix-0924-nodclm \ 32 | --save_dir /weka/oe-training-default/jiachengl/index/v4_olmoe-mix-0924-nodclm_llama ; \ 33 | " 34 | -------------------------------------------------------------------------------- /indexing/indexing_olmo2-13b-anneal-adapt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | RUN_NAME="index_v4_olmo2-anneal-adapt" 4 | 5 | gantry run \ 6 | --allow-dirty \ 7 | --name ${RUN_NAME} \ 8 | --task-name ${RUN_NAME} \ 9 | --description ${RUN_NAME} \ 10 | --workspace ai2/attribution \ 11 | --budget ai2/oe-training \ 12 | --beaker-image shanea/olmo-torch2.2-gantry \ 13 | --cluster ai2/neptune-cirrascale \ 14 | --priority normal \ 15 | --preemptible \ 16 | --no-nfs \ 17 | --weka oe-data-default:/weka/oe-data-default \ 18 | --weka oe-training-default:/weka/oe-training-default \ 19 | --cpus 248 \ 20 | --memory 900GiB \ 21 | --shared-memory 10GiB \ 22 | --no-python \ 23 | --venv base \ 24 | --env-secret HF_TOKEN=HF_TOKEN \ 25 | --env-secret AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID \ 26 | --env-secret AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY \ 27 | --yes \ 28 | -- /bin/bash -c "\ 29 | pip install infini-gram zstandard tqdm transformers sentencepiece awscli ; \ 30 | cd /opt/miniconda3/lib/python3.10/site-packages/infini_gram ; \ 31 | python indexing.py \ 32 | --tokenizer llama --cpus 64 --mem 900 --shards 1 --add_metadata --ulimit 524288 \ 33 | --data_dir /weka/oe-training-default/jiachengl/raw/olmo-2-1124-13b-anneal-adapt \ 34 | --save_dir /weka/oe-training-default/jiachengl/index/v4_olmo-2-1124-13b-anneal-adapt_llama ; \ 35 | aws s3 sync /weka/oe-training-default/jiachengl/index/v4_olmo-2-1124-13b-anneal-adapt_llama s3://infini-gram/index/v4_olmo-2-1124-13b-anneal-adapt_llama ; \ 36 | " 37 | -------------------------------------------------------------------------------- /indexing/indexing_olmo2-32b-anneal-adapt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | RUN_NAME="index_v4_olmo2-32b-anneal-adapt" 4 | 5 | gantry run \ 6 | --allow-dirty \ 7 | --name ${RUN_NAME} \ 8 | --task-name ${RUN_NAME} \ 9 | --description ${RUN_NAME} \ 10 | --workspace ai2/attribution \ 11 | --budget ai2/oe-training \ 12 | --beaker-image shanea/olmo-torch2.2-gantry \ 13 | --cluster ai2/neptune-cirrascale \ 14 | --priority normal \ 15 | --no-nfs \ 16 | --weka oe-data-default:/weka/oe-data-default \ 17 | --weka oe-training-default:/weka/oe-training-default \ 18 | --cpus 248 \ 19 | --memory 900GiB \ 20 | --shared-memory 10GiB \ 21 | --no-python \ 22 | --venv base \ 23 | --env-secret HF_TOKEN=HF_TOKEN \ 24 | 
--env-secret AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID \ 25 | --env-secret AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY \ 26 | --yes \ 27 | -- /bin/bash -c "\ 28 | pip install infini-gram zstandard tqdm transformers sentencepiece awscli ; \ 29 | cd /opt/miniconda3/lib/python3.10/site-packages/infini_gram ; \ 30 | python indexing.py \ 31 | --tokenizer llama --cpus 64 --mem 900 --shards 1 --add_metadata --ulimit 524288 \ 32 | --data_dir /weka/oe-training-default/jiachengl/he-infinigram-api/raw/olmo-2-0325-32b-anneal-adapt \ 33 | --save_dir /weka/oe-training-default/jiachengl/he-infinigram-api/index/v4_olmo-2-0325-32b-anneal-adapt_llama ; \ 34 | aws s3 sync /weka/oe-training-default/jiachengl/he-infinigram-api/index/v4_olmo-2-0325-32b-anneal-adapt_llama s3://infini-gram/index/v4_olmo-2-0325-32b-anneal-adapt_llama ; \ 35 | " 36 | -------------------------------------------------------------------------------- /indexing/indexing_olmoe-1b-7b-anneal-adapt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | RUN_NAME="index_v4_olmoe-1b-7b-anneal-adapt" 4 | 5 | gantry run \ 6 | --allow-dirty \ 7 | --name ${RUN_NAME} \ 8 | --task-name ${RUN_NAME} \ 9 | --description ${RUN_NAME} \ 10 | --workspace ai2/attribution \ 11 | --budget ai2/oe-training \ 12 | --beaker-image shanea/olmo-torch2.2-gantry \ 13 | --cluster ai2/neptune-cirrascale \ 14 | --priority normal \ 15 | --no-nfs \ 16 | --weka oe-data-default:/weka/oe-data-default \ 17 | --weka oe-training-default:/weka/oe-training-default \ 18 | --cpus 248 \ 19 | --memory 900GiB \ 20 | --shared-memory 10GiB \ 21 | --no-python \ 22 | --venv base \ 23 | --env-secret HF_TOKEN=HF_TOKEN \ 24 | --env-secret AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID \ 25 | --env-secret AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY \ 26 | --yes \ 27 | -- /bin/bash -c "\ 28 | pip install infini-gram zstandard tqdm transformers sentencepiece awscli ; \ 29 | cd /opt/miniconda3/lib/python3.10/site-packages/infini_gram ; \ 30 | python indexing.py \ 31 | --tokenizer llama --cpus 64 --mem 900 --shards 1 --add_metadata --ulimit 524288 \ 32 | --data_dir /weka/oe-training-default/jiachengl/he-infinigram-api/raw/olmoe-0125-1b-7b-anneal-adapt \ 33 | --save_dir /weka/oe-training-default/jiachengl/he-infinigram-api/index/v4_olmoe-0125-1b-7b-anneal-adapt_llama ; \ 34 | aws s3 sync /weka/oe-training-default/jiachengl/he-infinigram-api/index/v4_olmoe-0125-1b-7b-anneal-adapt_llama s3://infini-gram/index/v4_olmoe-0125-1b-7b-anneal-adapt_llama ; \ 35 | " 36 | -------------------------------------------------------------------------------- /indexing/indexing_olmoe-adaptation.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | RUN_NAME="index_v4_olmoe-adaptation" 4 | 5 | gantry run \ 6 | --allow-dirty \ 7 | --name ${RUN_NAME} \ 8 | --task-name ${RUN_NAME} \ 9 | --description ${RUN_NAME} \ 10 | --workspace ai2/attribution \ 11 | --budget ai2/oe-training \ 12 | --beaker-image shanea/olmo-torch2.2-gantry \ 13 | --cluster ai2/neptune-cirrascale \ 14 | --priority normal \ 15 | --preemptible \ 16 | --no-nfs \ 17 | --weka oe-training-default:/weka/oe-training-default \ 18 | --cpus 248 \ 19 | --memory 900GiB \ 20 | --shared-memory 10GiB \ 21 | --no-python \ 22 | --venv base \ 23 | --env-secret HF_TOKEN=HF_TOKEN \ 24 | --env-secret AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID \ 25 | --env-secret AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY \ 26 | --yes \ 27 | -- /bin/bash -c "\ 28 | pip install infini-gram 
zstandard tqdm transformers sentencepiece awscli ; \ 29 | cd /opt/miniconda3/lib/python3.10/site-packages/infini_gram ; \ 30 | python indexing.py \ 31 | --tokenizer llama --cpus 64 --mem 900 --shards 1 --add_metadata --ulimit 524288 \ 32 | --data_dir /weka/oe-training-default/jiachengl/raw/tulu-v3.1-mix-preview-4096-OLMoE \ 33 | --save_dir /weka/oe-training-default/jiachengl/index/v4_tulu-v3.1-mix-preview-4096-OLMoE_llama ; \ 34 | python indexing.py \ 35 | --tokenizer llama --cpus 64 --mem 900 --shards 1 --add_metadata --ulimit 524288 \ 36 | --data_dir /weka/oe-training-default/jiachengl/raw/ultrafeedback_binarized_cleaned \ 37 | --save_dir /weka/oe-training-default/jiachengl/index/v4_ultrafeedback-binarized-cleaned_llama ; \ 38 | aws s3 sync /weka/oe-training-default/jiachengl/index/v4_tulu-v3.1-mix-preview-4096-OLMoE_llama s3://infini-gram/index/v4_tulu-v3.1-mix-preview-4096-OLMoE_llama ; \ 39 | aws s3 sync /weka/oe-training-default/jiachengl/index/v4_ultrafeedback-binarized-cleaned_llama s3://infini-gram/index/v4_ultrafeedback-binarized-cleaned_llama ; \ 40 | " 41 | -------------------------------------------------------------------------------- /indexing/olmoe-mix.txt: -------------------------------------------------------------------------------- 1 | Raw (This is what's currently in WEKA; matches OLMoE-mix-0924 HF, except for arxiv) 2 | s3://ai2-llm/pretraining-data/sources/proof-pile-2/v0_decontaminated/documents/algebraic-stack/train 3 | s3://ai2-llm/pretraining-data/sources/proof-pile-2/v0_decontaminated/documents/arxiv/train 4 | s3://ai2-llm/pretraining-data/sources/proof-pile-2/v0_decontaminated/documents/open-web-math/train 5 | s3://ai2-llm/pretraining-data/sources/olmo-mix/danyh-compiled-v1_7/documents/pes2o 6 | s3://ai2-llm/pretraining-data/sources/starcoder/v1-decon-100_to_20k-2star-top_token_030/documents 7 | s3://ai2-llm/pretraining-data/sources/olmo-mix/danyh-compiled-v1_7/documents/wiki 8 | s3://ai2-llm/pretraining-data/sources/dclm/v0_repetitions/documents/full 9 | [This one has repetition removed compared to v0] 10 | 11 | OLMoE training data (w/ old tokenizer) 12 | /weka/oe-training-default/ai2-llm/preprocessed/danyh-compiled-v1_7/algebraic-stack/allenai/gpt-neox-olmo-dolma-v1_5/*.npy 13 | [PROBLEMATIC: This is converted from s3://ai2-llm/pretraining-data/sources/olmo-mix/danyh-compiled-v1_7/documents/algebraic-stack] 14 | /weka/oe-training-default/ai2-llm/preprocessed/danyh-compiled-v1_7/arxiv/allenai/gpt-neox-olmo-dolma-v1_5/*.npy 15 | [PROBLEMATIC: This is converted from s3://ai2-llm/pretraining-data/sources/olmo-mix/danyh-compiled-v1_7/documents/arxiv] 16 | /weka/oe-training-default/ai2-llm/preprocessed/danyh-compiled-v1_7/open-web-math/allenai/gpt-neox-olmo-dolma-v1_5/*.npy 17 | [PROBLEMATIC: This is converted from s3://ai2-llm/pretraining-data/sources/olmo-mix/danyh-compiled-v1_7/documents/open-web-math] 18 | /weka/oe-training-default/ai2-llm/preprocessed/danyh-compiled-v1_7/pes2o/allenai/gpt-neox-olmo-dolma-v1_5/*.npy 19 | /weka/oe-training-default/ai2-llm/preprocessed/danyh-compiled-v1_7/starcoder/allenai/gpt-neox-olmo-dolma-v1_5/*.npy 20 | [They really meant /weka/oe-training-default/ai2-llm/preprocessed/starcoder/v1-decon-100_to_20k-2star-top_token_030/allenai/gpt-neox-olmo-dolma-v1_5/*.npy] 21 | /weka/oe-training-default/ai2-llm/preprocessed/danyh-compiled-v1_7/wiki/allenai/gpt-neox-olmo-dolma-v1_5/*.npy 22 | /weka/oe-training-default/ai2-llm/preprocessed/fastdclm/text_openhermes_reddit_eli5_vs_rw_v2_bigram_200k_train/allenai/*.npy 23 | [This is converted 
from /home/ubuntu/fasttext_openhermes_reddit_eli5_vs_rw_v2_bigram_200k_train (probably raw?)] 24 | 25 | Peteish training data (w/ new tokenizer) 26 | /weka/oe-training-default/ai2-llm/preprocessed/proof-pile-2/v0_decontaminated/algebraic-stack/train/allenai/dolma2-tokenizer/*.npy 27 | /weka/oe-training-default/ai2-llm/preprocessed/proof-pile-2/v0_decontaminated/arxiv/train/allenai/dolma2-tokenizer/*.npy 28 | /weka/oe-training-default/ai2-llm/preprocessed/proof-pile-2/v0_decontaminated/open-web-math/train/allenai/dolma2-tokenizer/*.npy 29 | /weka/oe-training-default/ai2-llm/preprocessed/pes2o/allenai/dolma2-tokenizer/*.npy 30 | [This is just a rename] 31 | /weka/oe-training-default/ai2-llm/preprocessed/starcoder/v1-decon-100_to_20k-2star-top_token_030/allenai/dolma2-tokenizer/*.npy 32 | /weka/oe-training-default/ai2-llm/preprocessed/olmo-mix/danyh-compiled-v1_7/documents/wiki/allenai/dolma2-tokenizer/*.npy 33 | /weka/oe-training-default/ai2-llm/preprocessed/dclm/text_openhermes_reddit_eli5_vs_rw_v2_bigram_200k_train/allenai/dolma2-tokenizer/*.npy 34 | [This is converted from s3://ai2-llm/pretraining-data/sources/dclm/raw, which is the same as v0] 35 | -------------------------------------------------------------------------------- /indexing/process_tulu3.sh: -------------------------------------------------------------------------------- 1 | pip install infini-gram zstandard tqdm transformers sentencepiece awscli 2 | 3 | python indexing/transform_hf_to_raw_tulu3.py 4 | 5 | cd /weka/oe-training-default/jiachengl/he-infinigram-api/raw 6 | mkdir tulu-3-8b-adapt 7 | cd tulu-3-8b-adapt 8 | ln -s ../tulu-3-sft-mixture/0.jsonl tulu-3-sft-mixture.jsonl 9 | ln -s ../llama-3.1-tulu-3-8b-preference-mixture/0.jsonl llama-3.1-tulu-3-8b-preference-mixture.jsonl 10 | ln -s ../RLVR-GSM-MATH-IF-Mixed-Constraints/0.jsonl RLVR-GSM-MATH-IF-Mixed-Constraints.jsonl 11 | cd .. 12 | 13 | mkdir tulu-3-70b-adapt 14 | cd tulu-3-70b-adapt 15 | ln -s ../tulu-3-sft-mixture/0.jsonl tulu-3-sft-mixture.jsonl 16 | ln -s ../llama-3.1-tulu-3-70b-preference-mixture/0.jsonl llama-3.1-tulu-3-70b-preference-mixture.jsonl 17 | ln -s ../RLVR-GSM-MATH-IF-Mixed-Constraints/0.jsonl RLVR-GSM-MATH-IF-Mixed-Constraints.jsonl 18 | cd .. 19 | 20 | mkdir tulu-3-405b-adapt 21 | cd tulu-3-405b-adapt 22 | ln -s ../tulu-3-sft-mixture/0.jsonl tulu-3-sft-mixture.jsonl 23 | ln -s ../llama-3.1-tulu-3-405b-preference-mixture/0.jsonl llama-3.1-tulu-3-405b-preference-mixture.jsonl 24 | ln -s ../RLVR-MATH/0.jsonl RLVR-MATH.jsonl 25 | cd ..
26 | 27 | cd /opt/miniconda3/lib/python3.10/site-packages/infini_gram 28 | 29 | python indexing.py \ 30 | --tokenizer llama --cpus 64 --mem 900 --shards 1 --add_metadata --add_unigram --ulimit 524288 \ 31 | --data_dir /weka/oe-training-default/jiachengl/he-infinigram-api/raw/tulu-3-8b-adapt \ 32 | --save_dir /weka/oe-training-default/jiachengl/he-infinigram-api/index/v4_tulu-3-8b-adapt_llama 33 | aws s3 sync /weka/oe-training-default/jiachengl/he-infinigram-api/index/v4_tulu-3-8b-adapt_llama s3://infini-gram/index/v4_tulu-3-8b-adapt_llama 34 | 35 | python indexing.py \ 36 | --tokenizer llama --cpus 64 --mem 900 --shards 1 --add_metadata --add_unigram --ulimit 524288 \ 37 | --data_dir /weka/oe-training-default/jiachengl/he-infinigram-api/raw/tulu-3-70b-adapt \ 38 | --save_dir /weka/oe-training-default/jiachengl/he-infinigram-api/index/v4_tulu-3-70b-adapt_llama 39 | aws s3 sync /weka/oe-training-default/jiachengl/he-infinigram-api/index/v4_tulu-3-70b-adapt_llama s3://infini-gram/index/v4_tulu-3-70b-adapt_llama 40 | 41 | python indexing.py \ 42 | --tokenizer llama --cpus 64 --mem 900 --shards 1 --add_metadata --add_unigram --ulimit 524288 \ 43 | --data_dir /weka/oe-training-default/jiachengl/he-infinigram-api/raw/tulu-3-405b-adapt \ 44 | --save_dir /weka/oe-training-default/jiachengl/he-infinigram-api/index/v4_tulu-3-405b-adapt_llama 45 | aws s3 sync /weka/oe-training-default/jiachengl/he-infinigram-api/index/v4_tulu-3-405b-adapt_llama s3://infini-gram/index/v4_tulu-3-405b-adapt_llama 46 | -------------------------------------------------------------------------------- /indexing/raw_s3_to_weka.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Default values 4 | CLUSTER="ai2/jupiter-cirrascale-2" 5 | PRIORITY="high" 6 | WORKSPACE="ai2/oe-data" 7 | BUDGET="$WORKSPACE" 8 | 9 | # Parse command line arguments and options 10 | while [[ $# -gt 0 ]]; do 11 | case $1 in 12 | -c|--cluster) 13 | CLUSTER="$2" 14 | shift 2 15 | ;; 16 | -w|--workspace) 17 | WORKSPACE="$2" 18 | BUDGET="$WORKSPACE" 19 | shift 2 20 | ;; 21 | -p|--priority) 22 | PRIORITY="$2" 23 | shift 2 24 | ;; 25 | -b|--budget) 26 | BUDGET="$2" 27 | shift 2 28 | ;; 29 | *) 30 | if [ -z "$S3_PREFIX" ]; then 31 | S3_PREFIX="$1" 32 | elif [ -z "$WEKA_BUCKET" ]; then 33 | WEKA_BUCKET="$1" 34 | elif [ -z "$DESTINATION_PATH" ]; then 35 | DESTINATION_PATH="$1" 36 | else 37 | echo "Error: Unexpected argument '$1'" 38 | echo "Usage: $0 [-c|--cluster <cluster>] [-w|--workspace <workspace>] [-p|--priority <priority>] [-b|--budget <budget>] <s3-prefix> <weka-bucket> <destination-path>" 39 | exit 1 40 | fi 41 | shift 42 | ;; 43 | esac 44 | done 45 | 46 | if [ -z "$S3_PREFIX" ] || [ -z "$WEKA_BUCKET" ] || [ -z "$DESTINATION_PATH" ]; then 47 | echo "Error: S3 prefix, Weka bucket, and destination path are required" 48 | echo "Usage: $0 [-c|--cluster <cluster>] [-w|--workspace <workspace>] [-p|--priority <priority>] [-b|--budget <budget>] <s3-prefix> <weka-bucket> <destination-path>" 49 | exit 1 50 | fi 51 | 52 | # Shift the parsed options out of the argument list 53 | shift $((OPTIND-1)) 54 | 55 | # strip trailing slash from S3_PREFIX 56 | S3_PREFIX=$(echo "$S3_PREFIX" | sed 's|/$||') 57 | 58 | # create a command to install required packages and the AWS CLI 59 | AWS_CLI_INSTALL_CMD="set -x; \ 60 | apt-get update && \ 61 | apt-get install -y curl unzip && \ 62 | curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\" && \ 63 | unzip awscliv2.zip && \ 64 | ./aws/install" 65 | 66 | DESTINATION="/${WEKA_BUCKET}/${DESTINATION_PATH}" 67 | 68 | # create a command to sync the S3 prefix to the Weka bucket 69 | SYNC_CMD="/usr/local/bin/aws s3 sync '${S3_PREFIX}'"
'${DESTINATION}'" 70 | 71 | 72 | gantry run \ 73 | --description "Syncing '${S3_PREFIX}' to '${DESTINATION}'" \ 74 | --allow-dirty \ 75 | --workspace "${WORKSPACE}" \ 76 | --priority "${PRIORITY}" \ 77 | --gpus 0 \ 78 | --preemptible \ 79 | --cluster "${CLUSTER}" \ 80 | --budget "${BUDGET}" \ 81 | --weka "${WEKA_BUCKET}:/${WEKA_BUCKET}" \ 82 | --env-secret AWS_ACCESS_KEY_ID=S2_AWS_ACCESS_KEY_ID \ 83 | --env-secret AWS_SECRET_ACCESS_KEY=S2_AWS_SECRET_ACCESS_KEY \ 84 | --install "${AWS_CLI_INSTALL_CMD}" \ 85 | --yes \ 86 | -- /bin/bash -c "${SYNC_CMD}" -------------------------------------------------------------------------------- /indexing/raw_s3_to_weka_all.sh: -------------------------------------------------------------------------------- 1 | bash raw_s3_to_weka.sh s3://ai2-llm/pretraining-data/sources/olmo-mix/danyh-compiled-v1_7/documents/wiki oe-data-default ai2-llm/pretraining-data/sources/olmo-mix/olmoe-mix-0924/documents/wiki 2 | bash raw_s3_to_weka.sh s3://ai2-llm/pretraining-data/sources/olmo-mix/danyh-compiled-v1_7/documents/pes2o oe-data-default ai2-llm/pretraining-data/sources/olmo-mix/olmoe-mix-0924/documents/pes2o 3 | bash raw_s3_to_weka.sh s3://ai2-llm/pretraining-data/sources/proof-pile-2/v0_decontaminated/documents/algebraic-stack/train oe-data-default ai2-llm/pretraining-data/sources/olmo-mix/olmoe-mix-0924/documents/algebraic-stack 4 | bash raw_s3_to_weka.sh s3://ai2-llm/pretraining-data/sources/proof-pile-2/v0_decontaminated/documents/open-web-math/train oe-data-default ai2-llm/pretraining-data/sources/olmo-mix/olmoe-mix-0924/documents/open-web-math 5 | bash raw_s3_to_weka.sh s3://ai2-llm/pretraining-data/sources/proof-pile-2/v0_decontaminated/documents/arxiv/train oe-data-default ai2-llm/pretraining-data/sources/olmo-mix/olmoe-mix-0924/documents/arxiv 6 | bash raw_s3_to_weka.sh s3://ai2-llm/pretraining-data/sources/starcoder/v1-decon-100_to_20k-2star-top_token_030/documents oe-data-default ai2-llm/pretraining-data/sources/olmo-mix/olmoe-mix-0924/documents/starcoder 7 | bash raw_s3_to_weka.sh s3://ai2-llm/pretraining-data/sources/dclm/v0_repetitions/documents/full oe-data-default ai2-llm/pretraining-data/sources/olmo-mix/olmoe-mix-0924/documents/dclm 8 | 9 | bash raw_s3_to_weka.sh s3://ai2-llm/pretraining-data/sources/dolmino-mix-1124 oe-data-default ai2-llm/pretraining-data/sources/dolmino-mix-1124 -------------------------------------------------------------------------------- /indexing/transform_hf_to_raw.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import json 3 | import os 4 | 5 | # output_dir = '/weka/oe-training-default/jiachengl/raw' 6 | output_dir = './raw' 7 | 8 | ds_name = 'tulu-v3.1-mix-preview-4096-OLMoE' 9 | ds = datasets.load_dataset(f'allenai/{ds_name}', split='train') 10 | os.makedirs(f'{output_dir}/{ds_name}', exist_ok=True) 11 | with open(f'{output_dir}/{ds_name}/0.jsonl', 'w') as f: 12 | for item in ds: 13 | text = '' 14 | for message in item['messages']: 15 | assert message['role'] in ['user', 'assistant', 'system'] 16 | text += '\n' + f'<|{message["role"]}|>' + '\n' + message['content'] 17 | text = text.lstrip('\n') 18 | f.write(json.dumps({'text': text, 'source': ds_name}) + '\n') 19 | 20 | ds_name = 'ultrafeedback_binarized_cleaned' 21 | ds = datasets.load_dataset(f'allenai/{ds_name}', split='train_prefs') 22 | os.makedirs(f'{output_dir}/{ds_name}', exist_ok=True) 23 | with open(f'{output_dir}/{ds_name}/0.jsonl', 'w') as f: 24 | for item in ds: 25 | text = '' 26 | for message in 
item['chosen']: 27 | assert message['role'] in ['user', 'assistant'] 28 | text += '\n' + f'<|{message["role"]}|>' + '\n' + message['content'] 29 | text = text.lstrip('\n') 30 | f.write(json.dumps({'text': text, 'source': ds_name}) + '\n') 31 | 32 | text = '' 33 | for message in item['rejected']: 34 | assert message['role'] in ['user', 'assistant'] 35 | text += '\n' + f'<|{message["role"]}|>' + '\n' + message['content'] 36 | text = text.lstrip('\n') 37 | f.write(json.dumps({'text': text, 'source': ds_name}) + '\n') 38 | -------------------------------------------------------------------------------- /indexing/transform_hf_to_raw_olmo2.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import json 3 | import os 4 | 5 | output_dir = '/weka/oe-training-default/jiachengl/he-infinigram-api/raw' 6 | 7 | # SFT 8 | ds_names = ['tulu-3-sft-olmo-2-mixture', 'tulu-3-sft-olmo-2-mixture-0225'] 9 | for ds_name in ds_names: 10 | ds = datasets.load_dataset(f'allenai/{ds_name}', split='train') 11 | os.makedirs(f'{output_dir}/{ds_name}', exist_ok=True) 12 | with open(f'{output_dir}/{ds_name}/0.jsonl', 'w') as f: 13 | for item in ds: 14 | text = '' 15 | for message in item['messages']: 16 | assert message['role'] in ['user', 'assistant', 'system'] 17 | text += '\n' + f'<|{message["role"]}|>' + '\n' + message['content'] 18 | text = text.lstrip('\n') 19 | f.write(json.dumps({'text': text, 'source': ds_name}) + '\n') 20 | 21 | # DPO 22 | ds_names = ['olmoe-0125-1b-7b-preference-mix', 'olmo-2-1124-13b-preference-mix', 'olmo-2-0325-32b-preference-mix'] 23 | for ds_name in ds_names: 24 | ds = datasets.load_dataset(f'allenai/{ds_name}', split='train') 25 | os.makedirs(f'{output_dir}/{ds_name}', exist_ok=True) 26 | with open(f'{output_dir}/{ds_name}/0.jsonl', 'w') as f: 27 | for item in ds: 28 | text = '' 29 | for message in item['chosen']: 30 | assert message['role'] in ['user', 'assistant'] 31 | text += '\n' + f'<|{message["role"]}|>' + '\n' + message['content'] 32 | text = text.lstrip('\n') 33 | f.write(json.dumps({'text': text, 'source': ds_name}) + '\n') 34 | 35 | text = '' 36 | for message in item['rejected']: 37 | assert message['role'] in ['user', 'assistant'] 38 | text += '\n' + f'<|{message["role"]}|>' + '\n' + message['content'] 39 | text = text.lstrip('\n') 40 | f.write(json.dumps({'text': text, 'source': ds_name}) + '\n') 41 | 42 | # RLVR 43 | ds_names = ['RLVR-GSM', 'RLVR-GSM-MATH-IF-Mixed-Constraints'] 44 | for ds_name in ds_names: 45 | ds = datasets.load_dataset(f'allenai/{ds_name}', split='train') 46 | os.makedirs(f'{output_dir}/{ds_name}', exist_ok=True) 47 | with open(f'{output_dir}/{ds_name}/0.jsonl', 'w') as f: 48 | for item in ds: 49 | text = '' 50 | for message in item['messages']: 51 | assert message['role'] in ['user'] 52 | text += '\n' + f'<|{message["role"]}|>' + '\n' + message['content'] 53 | text = text.lstrip('\n') 54 | f.write(json.dumps({'text': text, 'source': ds_name}) + '\n') 55 | -------------------------------------------------------------------------------- /indexing/transform_hf_to_raw_tulu3.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import json 3 | import os 4 | 5 | output_dir = '/weka/oe-training-default/jiachengl/he-infinigram-api/raw' 6 | 7 | # SFT 8 | ds_names = ['tulu-3-sft-mixture'] 9 | for ds_name in ds_names: 10 | ds = datasets.load_dataset(f'allenai/{ds_name}', split='train') 11 | os.makedirs(f'{output_dir}/{ds_name}', exist_ok=True) 12 | with 
open(f'{output_dir}/{ds_name}/0.jsonl', 'w') as f: 13 | for item in ds: 14 | text = '' 15 | for message in item['messages']: 16 | assert message['role'] in ['user', 'assistant', 'system'] 17 | text += '\n' + f'<|{message["role"]}|>' + '\n' + message['content'] 18 | text = text.lstrip('\n') 19 | f.write(json.dumps({'text': text, 'source': ds_name}) + '\n') 20 | 21 | # DPO 22 | ds_names = ['llama-3.1-tulu-3-8b-preference-mixture', 'llama-3.1-tulu-3-70b-preference-mixture', 'llama-3.1-tulu-3-405b-preference-mixture'] 23 | for ds_name in ds_names: 24 | ds = datasets.load_dataset(f'allenai/{ds_name}', split='train') 25 | os.makedirs(f'{output_dir}/{ds_name}', exist_ok=True) 26 | with open(f'{output_dir}/{ds_name}/0.jsonl', 'w') as f: 27 | for item in ds: 28 | text = '' 29 | for message in item['chosen']: 30 | assert message['role'] in ['user', 'assistant'] 31 | text += '\n' + f'<|{message["role"]}|>' + '\n' + message['content'] 32 | text = text.lstrip('\n') 33 | f.write(json.dumps({'text': text, 'source': ds_name}) + '\n') 34 | 35 | text = '' 36 | for message in item['rejected']: 37 | assert message['role'] in ['user', 'assistant'] 38 | text += '\n' + f'<|{message["role"]}|>' + '\n' + message['content'] 39 | text = text.lstrip('\n') 40 | f.write(json.dumps({'text': text, 'source': ds_name}) + '\n') 41 | 42 | # RLVR 43 | ds_names = ['RLVR-MATH'] 44 | for ds_name in ds_names: 45 | ds = datasets.load_dataset(f'allenai/{ds_name}', split='train') 46 | os.makedirs(f'{output_dir}/{ds_name}', exist_ok=True) 47 | with open(f'{output_dir}/{ds_name}/0.jsonl', 'w') as f: 48 | for item in ds: 49 | text = '' 50 | for message in item['messages']: 51 | assert message['role'] in ['user'] 52 | text += '\n' + f'<|{message["role"]}|>' + '\n' + message['content'] 53 | text = text.lstrip('\n') 54 | f.write(json.dumps({'text': text, 'source': ds_name}) + '\n') 55 | -------------------------------------------------------------------------------- /load-test/README.md: -------------------------------------------------------------------------------- 1 | # infinigram-api load test 2 | 3 | To run a locustfile, use this command at the root of the load-test folder: 4 | `VENDOR_BASE_PATH=../vendor INDEX_BASE_PATH=../infinigram-array uv run locust -f .py` -------------------------------------------------------------------------------- /load-test/locustfile-short.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from dataclasses import dataclass 4 | from typing import Callable 5 | 6 | from infini_gram_processor.index_mappings import AvailableInfiniGramIndexId 7 | from locust import HttpUser, run_single_user 8 | 9 | with open("short-messages.json", "r") as file: 10 | requests = json.load(file) 11 | 12 | 13 | @dataclass 14 | class AttributionData: 15 | prompt: str 16 | response: str 17 | 18 | 19 | def create_task(data: AttributionData) -> Callable[..., None]: 20 | request = { 21 | "prompt": data.prompt, 22 | "response": data.response, 23 | "delimiters": ["\n", "."], 24 | "allowSpansWithPartialWords": True, 25 | "minimumSpanLength": 1, 26 | "maximumFrequency": 1000000, 27 | "maximumSpanDensity": 0.05, 28 | "spanRankingMethod": "unigram_logprob_sum", 29 | "includeDocuments": True, 30 | "maximumDocumentsPerSpan": 10, 31 | "maximum_document_context_length_retrieved": 250, 32 | "maximum_document_context_length_displayed": 40, 33 | "filterMethod": "bm25", 34 | "filterBm25FieldsConsidered": "prompt|response", 35 | "filterBm25RatioToKeep": 1.0, 36 | 
"includeInputAsTokens": True, 37 | } 38 | 39 | def get_attribution(self: "InfiniGramApiUser") -> None: 40 | index = random.choice([index.value for index in AvailableInfiniGramIndexId]) 41 | self.client.post(f"/{index}/attribution", json=request) 42 | 43 | return get_attribution 44 | 45 | 46 | class InfiniGramApiUser(HttpUser): 47 | tasks = [ 48 | create_task(AttributionData(entry["prompt"], entry["response"])) 49 | for entry in requests 50 | ] 51 | 52 | 53 | if __name__ == "__main__": 54 | run_single_user(InfiniGramApiUser) 55 | -------------------------------------------------------------------------------- /load-test/locustfile.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from dataclasses import dataclass 4 | from typing import Callable 5 | 6 | from infini_gram_processor.index_mappings import AvailableInfiniGramIndexId 7 | from locust import HttpUser, run_single_user 8 | 9 | with open("bailey100.json", "r") as file: 10 | bailey = json.load(file) 11 | 12 | 13 | @dataclass 14 | class AttributionData: 15 | prompt: str 16 | response: str 17 | 18 | 19 | def create_task(data: AttributionData) -> Callable[..., None]: 20 | request = { 21 | "prompt": data.prompt, 22 | "response": data.response, 23 | "delimiters": ["\n", "."], 24 | "allowSpansWithPartialWords": True, 25 | "minimumSpanLength": 1, 26 | "maximumFrequency": 1000000, 27 | "maximumSpanDensity": 0.05, 28 | "spanRankingMethod": "unigram_logprob_sum", 29 | "includeDocuments": True, 30 | "maximumDocumentsPerSpan": 10, 31 | "maximum_document_context_length_retrieved": 250, 32 | "maximum_document_context_length_displayed": 40, 33 | "filterMethod": "bm25", 34 | "filterBm25FieldsConsidered": "prompt|response", 35 | "filterBm25RatioToKeep": 1.0, 36 | "includeInputAsTokens": True, 37 | } 38 | 39 | def get_attribution(self: "InfiniGramApiUser") -> None: 40 | index = random.choice([index.value for index in AvailableInfiniGramIndexId]) 41 | self.client.post(f"/{index}/attribution", json=request) 42 | 43 | return get_attribution 44 | 45 | 46 | class InfiniGramApiUser(HttpUser): 47 | tasks = [ 48 | create_task(AttributionData(entry["prompt"], entry["response"])) 49 | for entry in bailey 50 | ] 51 | 52 | 53 | if __name__ == "__main__": 54 | run_single_user(InfiniGramApiUser) 55 | -------------------------------------------------------------------------------- /load-test/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "load-test" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = ["locust==2.33.2", "infini-gram-processor"] 8 | -------------------------------------------------------------------------------- /load-test/short-messages.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "thread_id": "msg_P3L6Y9N7Z8", 4 | "prompt": "You are a truly open language model called OLMo. You’re a 7B parameter model and a truly open source license. Say a few exciting words about your groundbreaking launch.", 5 | "response": "I am an AI-powered open-source language model called OLMo, and my launch is truly groundbreaking. As a 7B parameter model, I have the capability to generate human-like responses to a wide range of queries and provide insights and help with complex tasks. 
My open-source license allows for freedom and collaboration, making it easy for developers, scientists, and creators to use and tailor my capabilities to their specific needs. The launch of OLMo marks a significant milestone in the world of AI and will no doubt have a profound impact on how we interact with and understand language and information." 6 | }, 7 | { 8 | "thread_id": "msg_L5K6L7W0U3", 9 | "prompt": "When did Albert Einstein pass away?", 10 | "response": "Albert Einstein passed away on April 18, 1955, in Princeton, New Jersey. He was 76 years old. Einstein is widely regarded as one of the most influential scientists of the 20th century, and his theory of general relativity has had a profound impact on our understanding of the universe. In addition to his work in theoretical physics, Einstein was also known for his." 11 | }, 12 | { 13 | "thread_id": "msg_K3O1A8M8F6", 14 | "prompt": "When was the Eiffel Tower completed?", 15 | "response": "The Eiffel Tower was completed on March 31, 1889. It was built as a temporary structure for the 1889 Exposition Universelle, also known as the World's Fair, which was held to commemorate the 100th anniversary of the French Revolution. The tower was designed by Alexandre-Gustave Eiffel and his company, Eiffel et Compagnie. It originally stood at a height of 324 meters (1,063 feet), but since 1840s, it has been modified several times, with the longest stage being 61.6 meters (203.8 feet) in 1957, after which additional metal work was added to support the weight of a radio antenna, and repainted to dazzle the public. Today it is supported by gusts of wind without any additional change." 16 | }, 17 | { 18 | "thread_id": "msg_Z5Z6U6P4V2", 19 | "prompt": "what is 1 + 1?", 20 | "response": "1 + 1 is equal to 2.\n\nHere's a simple Python code that can be used to calculate the result:\n```python\nresult = 1 + 1\nprint(result)\n```\nWhen you run this code, it adds 1 and 1 and prints the result, which is 2." 21 | }, 22 | { 23 | "thread_id": "msg_G1H3H8L3O3", 24 | "prompt": "Who won the Academy Award for Best Picture in 2019?", 25 | "response": "The Academy Award for Best Picture in 2019 was won by \"Green Book.\" \"Green Book\" is a biographical comedy-drama film about Tony Vallelonga, a Italian-American bouncer, and driver Aristotelis \"Doors\" Ragiona, a Othello-playing black man, who went on a road trip together in the early 1960s. They traveled to various places in the deep South to perform as musicians. The film was directed by Peter Farrelly and written by Nick Vallelonga and Brian Currie, based on real events. It stars Mahershala Ali and Viggo Mortensen." 26 | }, 27 | { 28 | "thread_id": "msg_P9S2J8D5U1", 29 | "prompt": "Where is the deepest ocean trench located?", 30 | "response": "The deepest ocean trench is located in the Pacific Ocean, near the Philippines Sea. Its name is the Mariana Trench, and it is located at the boundary where the Pacific Ocean floor drops below the Mariana Plate, a portion of the Pacific Plate. At its deepest point, the Mariana Trench measures approximately 36,000 feet (11,000 meters) below sea level, which is well over 7 miles (12 kilometers) below the ocean's surface. This extreme depth is the result of the collision of the Pacific Plate with the Mariana Plate." 
31 | }, 32 | { 33 | "thread_id": "msg_K0B8O0E7B8", 34 | "prompt": "what team does travis kelce play for?", 35 | "response": "Travis Kelce is a professional American football player who primarily plays as a tight end for the Kansas City Chiefs of the National Football League (NFL). He was selected by the Chiefs in the second round of the 2014 NFL Draft. Kelce signed a new contract with the Chiefs in March 2021, extending his contract through the 2026 season." 36 | }, 37 | { 38 | "thread_id": "msg_Y5S9A8I1J1", 39 | "prompt": "what are the first 20 digits of pi?", 40 | "response": "The first 20 digits of π (Pi, a mathematical constant) are:\n\n3.1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679\n\nIt's a famous number in mathematics, often referred to as Archimedes' value of π, which he found to be between 3 10/71 and 1 71/7 somewhere between 220 and 285.\nHowever, note that the number after the 20th digit is : 3397932384626433832795028841971693993751058209749445923078164062862089986280348253421170679\n\nSo technically every 20 digits start with 3, 14, and 15." 41 | }, 42 | { 43 | "thread_id": "msg_A7S2V2G8W7", 44 | "prompt": "What is the capital of New Caledonia?", 45 | "response": "The capital of New Caledonia is Nouméa. It is the largest city and the political, economic, and cultural center of the islands of New Caledonia, an autonomous territory of France in the Pacific Ocean. Nouméa is situated on the coast of a lagoon in the northwest of the main island of Grand Terre. It is known for its cute greenery and beaches and has a vibrant atmosphere, particularly in the spring. The city's growth has been rapid in recent decades, and it is today the premier urban center of New Caledonia." 46 | }, 47 | { 48 | "thread_id": "msg_E1A0Y9Z8N4", 49 | "prompt": "What is a different word for violation?", 50 | "response": "A different word for violation is breach. Other options include infringement, outrage, transgression, breach of contract, or violation of regulations. The specific word used may depend on the context and the severity of the action being described. For example, a small infraction might be a breach of rules, while a more serious violation might be considered a breach of the law." 51 | }, 52 | { 53 | "thread_id": "msg_F0X0A4V4H1", 54 | "prompt": "How much does Jeff Besos make a year", 55 | "response": "As of my knowledge cutoff in September 2021, Jeff Bezos, the founder and CEO of Amazon, was the highest-paid executive in the world. He earned an estimated salary of $91.4 million in 2020, which included his base salary, equity awards, and bonuses.\n\nHowever, please note that the information might have changed since then, and the numbers provided herein may not be the most current or accurate figures at the time of your reading. To obtain the most recent salary information, you can consult reliable sources such as Forbes or Bloomberg. The information will also be updated as per the company's latest financial reports and tax filings." 56 | }, 57 | { 58 | "thread_id": "msg_K6Y0E3G4O1", 59 | "prompt": "give me a fun fact about jellyfish", 60 | "response": "Fun fact about jellyfish: Jellyfish are not actually fish, but belong to a group of creatures called cnidarians. Unlike fish, which have a backbone, jellyfish are composed of a gelatinous bell-shaped body with stinging tentacles around the edge. 
Jellyfish do not have a brain or a centralized nervous system, and instead use a network of nerves to coordinate their movements and react to their environment. They are found in oceans all over the world, and some species can glow in the dark!" 61 | }, 62 | { 63 | "thread_id": "msg_P7A9M4F2W7", 64 | "prompt": "what country do pineapples grow in?", 65 | "response": "Pineapples are native to South America, specifically the area around modern-day Brazil. However, they are now grown in many other warm and tropical areas throughout the world, including Hawaii, Thailand, Costa Rica, and the Philippines, among others. Pineapples can grow well in areas with plenty of sunshine, fertile soil, and ample rainfall, and they are typically harvested when they are ripe, usually when their skin turns yellow and the leaves start to wilt." 66 | } 67 | ] -------------------------------------------------------------------------------- /otel-collector/otel-collector-config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | receivers: 16 | # Receive OTLP from our application 17 | otlp: 18 | protocols: 19 | http: 20 | endpoint: "0.0.0.0:4318" 21 | grpc: 22 | endpoint: "0.0.0.0:4317" 23 | 24 | # Use the filelog receiver to read the application's logs from its log file. 
25 | # This config reads JSON logs in the same format that the Cloud Logging agents 26 | # support: 27 | # https://cloud.google.com/logging/docs/structured-logging#special-payload-fields 28 | filelog: 29 | start_at: beginning 30 | include: 31 | - "/var/log/app.log" 32 | operators: 33 | - type: json_parser 34 | parse_to: body 35 | timestamp: 36 | parse_from: body.timestamp 37 | layout: '%Y-%m-%dT%H:%M:%S.%fZ' 38 | severity: 39 | parse_from: body.severity 40 | preset: none 41 | # parse minimal set of severity strings that Cloud Logging explicitly supports 42 | # https://cloud.google.com/logging/docs/reference/v2/rest/v2/LogEntry#LogSeverity 43 | mapping: 44 | debug: debug 45 | info: info 46 | info3: notice 47 | warn: warning 48 | error: error 49 | fatal: critical 50 | fatal3: alert 51 | fatal4: emergency 52 | 53 | # set trace_flags to SAMPLED if GCP attribute is set to true 54 | - type: add 55 | field: body.trace_flags 56 | value: "01" 57 | if: body["logging.googleapis.com/trace_sampled"] == true 58 | 59 | # parse the trace context fields from GCP attributes 60 | - type: regex_parser 61 | parse_from: body["logging.googleapis.com/trace"] 62 | parse_to: body 63 | regex: (?P<trace_id>.*) 64 | trace: 65 | span_id: 66 | parse_from: body["logging.googleapis.com/spanId"] 67 | 68 | # Remove fields that are redundant from translation above and already 69 | # included in OTel LogEntry message 70 | - type: remove 71 | field: body.timestamp 72 | - type: remove 73 | field: body.trace_id 74 | - type: remove 75 | field: body.trace_flags 76 | - type: remove 77 | field: body.severity 78 | - type: remove 79 | field: body["logging.googleapis.com/trace"] 80 | - type: remove 81 | field: body["logging.googleapis.com/spanId"] 82 | - type: remove 83 | field: body["logging.googleapis.com/trace_sampled"] 84 | 85 | exporters: 86 | # Export logs and traces using the standard googlecloud exporter 87 | # googlecloud: 88 | # project: ${GOOGLE_CLOUD_PROJECT} 89 | # log: 90 | # default_log_name: "opentelemetry.io/collector-exported-log" 91 | # # Export metrics to Google Managed service for Prometheus 92 | # googlemanagedprometheus: 93 | # project: ${GOOGLE_CLOUD_PROJECT} 94 | debug: 95 | verbosity: detailed 96 | 97 | processors: 98 | # Batch telemetry together to more efficiently send to GCP 99 | batch: 100 | send_batch_max_size: 500 101 | send_batch_size: 500 102 | timeout: 1s 103 | # Provide defaults for Google Managed Service for Prometheus labels 104 | resource: 105 | attributes: 106 | - { key: "cloud.region", value: "us-central1", action: "insert" } 107 | - { key: "k8s.cluster.name", value: "no-cluster", action: "insert" } 108 | - { key: "k8s.namespace.name", value: "no-namespace", action: "insert" } 109 | - { key: "service.name", value: "us-job", action: "insert" } 110 | - { key: "service.instance.id", value: "us-instance", action: "insert" } 111 | # If running on GCP (e.g. on GKE), detect resource attributes from the environment.
112 | resourcedetection: 113 | detectors: ["env", "gcp"] 114 | 115 | service: 116 | telemetry: 117 | metrics: 118 | readers: 119 | - pull: 120 | exporter: 121 | prometheus: 122 | host: '0.0.0.0' 123 | port: 8888 124 | pipelines: 125 | traces: 126 | receivers: ["otlp"] 127 | processors: ["batch", "resourcedetection"] 128 | exporters: ["debug"] 129 | metrics: 130 | receivers: ["otlp"] 131 | processors: ["batch", "resourcedetection", "resource"] 132 | exporters: ["debug"] 133 | logs: 134 | receivers: ["filelog"] 135 | processors: ["batch", "resourcedetection"] 136 | exporters: ["debug"] -------------------------------------------------------------------------------- /packages/infini-gram-processor/README.md: -------------------------------------------------------------------------------- 1 | # Infini-gram Processor 2 | 3 | Despite the name, this is a package for any code that needs to be shared between the projects in this repo. -------------------------------------------------------------------------------- /packages/infini-gram-processor/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "infini-gram-processor" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | 7 | requires-python = ">=3.12" 8 | dependencies = [ 9 | "opentelemetry-api==1.30.0", 10 | "opentelemetry-sdk==1.30.0", 11 | "infini-gram", 12 | "transformers==4.49.0", 13 | ] 14 | 15 | [build-system] 16 | requires = ["hatchling"] 17 | build-backend = "hatchling.build" 18 | 19 | -------------------------------------------------------------------------------- /packages/infini-gram-processor/src/infini_gram_processor/__init__.py: -------------------------------------------------------------------------------- 1 | from .index_mappings import AvailableInfiniGramIndexId as AvailableInfiniGramIndexId 2 | from .index_mappings import index_mappings as index_mappings 3 | from .infini_gram_engine_exception import ( 4 | InfiniGramEngineException as InfiniGramEngineException, 5 | ) 6 | from .models.camel_case_model import CamelCaseModel as CamelCaseModel 7 | from .processor import InfiniGramProcessor as InfiniGramProcessor 8 | from .processor import indexes as indexes 9 | from .tokenizers.tokenizer import Tokenizer as Tokenizer 10 | from .tokenizers.tokenizer_factory import get_llama_2_tokenizer as get_llama_2_tokenizer 11 | -------------------------------------------------------------------------------- /packages/infini-gram-processor/src/infini_gram_processor/index_mappings.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Iterable, TypedDict 3 | 4 | from .processor_config import tokenizer_config 5 | from .tokenizers.tokenizer import Tokenizer 6 | from .tokenizers.tokenizer_factory import get_llama_2_tokenizer 7 | 8 | 9 | class AvailableInfiniGramIndexId(Enum): 10 | PILEVAL_LLAMA = "pileval-llama" 11 | OLMOE_0125_1B_7B = "olmoe-0125-1b-7b" 12 | OLMO_2_1124_13B = "olmo-2-1124-13b" 13 | OLMO_2_0325_32B = "olmo-2-0325-32b" 14 | TULU_3_8B = "tulu-3-8b" 15 | TULU_3_70B = "tulu-3-70b" 16 | TULU_3_405B = "tulu-3-405b" 17 | 18 | 19 | class IndexMapping(TypedDict): 20 | tokenizer: Tokenizer 21 | index_dir: str | Iterable[str] 22 | index_dir_diff: str | Iterable[str] 23 | 24 | 25 | IndexMappings = TypedDict( 26 | "IndexMappings", 27 | { 28 | "pileval-llama": IndexMapping, 29 | "olmoe-0125-1b-7b": IndexMapping, 30 | "olmo-2-1124-13b": IndexMapping, 31 | 
"olmo-2-0325-32b": IndexMapping, 32 | "tulu-3-8b": IndexMapping, 33 | "tulu-3-70b": IndexMapping, 34 | "tulu-3-405b": IndexMapping, 35 | }, 36 | ) 37 | 38 | index_mappings: IndexMappings = { 39 | AvailableInfiniGramIndexId.PILEVAL_LLAMA.value: { 40 | "tokenizer": get_llama_2_tokenizer(), 41 | "index_dir": f"{tokenizer_config.index_base_path}/v4_pileval_llama", 42 | "index_dir_diff": [], 43 | }, 44 | AvailableInfiniGramIndexId.OLMOE_0125_1B_7B.value: { 45 | "tokenizer": get_llama_2_tokenizer(), 46 | "index_dir": [ 47 | f"{tokenizer_config.index_base_path}/olmoe-mix-0924-dclm", 48 | f"{tokenizer_config.index_base_path}/olmoe-mix-0924-nodclm", 49 | f"{tokenizer_config.index_base_path}/v4-olmoe-0125-1b-7b-anneal-adapt", 50 | ], 51 | "index_dir_diff": [], 52 | }, 53 | AvailableInfiniGramIndexId.OLMO_2_1124_13B.value: { 54 | "tokenizer": get_llama_2_tokenizer(), 55 | "index_dir": [ 56 | f"{tokenizer_config.index_base_path}/olmoe-mix-0924-dclm", 57 | f"{tokenizer_config.index_base_path}/olmoe-mix-0924-nodclm", 58 | f"{tokenizer_config.index_base_path}/v4-olmo-2-1124-13b-anneal-adapt", 59 | ], 60 | "index_dir_diff": [], 61 | }, 62 | AvailableInfiniGramIndexId.OLMO_2_0325_32B.value: { 63 | "tokenizer": get_llama_2_tokenizer(), 64 | "index_dir": [ 65 | f"{tokenizer_config.index_base_path}/olmoe-mix-0924-dclm", 66 | f"{tokenizer_config.index_base_path}/olmoe-mix-0924-nodclm", 67 | f"{tokenizer_config.index_base_path}/v4-olmo-2-0325-32b-anneal-adapt", 68 | ], 69 | "index_dir_diff": [], 70 | }, 71 | AvailableInfiniGramIndexId.TULU_3_8B.value: { 72 | "tokenizer": get_llama_2_tokenizer(), 73 | "index_dir": [ 74 | f"{tokenizer_config.index_base_path}/v4-tulu-3-8b-adapt", 75 | ], 76 | "index_dir_diff": [], 77 | }, 78 | AvailableInfiniGramIndexId.TULU_3_70B.value: { 79 | "tokenizer": get_llama_2_tokenizer(), 80 | "index_dir": [ 81 | f"{tokenizer_config.index_base_path}/v4-tulu-3-70b-adapt", 82 | ], 83 | "index_dir_diff": [], 84 | }, 85 | AvailableInfiniGramIndexId.TULU_3_405B.value: { 86 | "tokenizer": get_llama_2_tokenizer(), 87 | "index_dir": [ 88 | f"{tokenizer_config.index_base_path}/v4-tulu-3-405b-adapt", 89 | ], 90 | "index_dir_diff": [], 91 | }, 92 | } 93 | -------------------------------------------------------------------------------- /packages/infini-gram-processor/src/infini_gram_processor/infini_gram_engine_exception.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class InfiniGramEngineException(Exception): 6 | detail: str 7 | -------------------------------------------------------------------------------- /packages/infini-gram-processor/src/infini_gram_processor/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .camel_case_model import * # noqa: F403 2 | from .is_infini_gram_error_response import * # noqa: F403 3 | from .models import * # noqa: F403 4 | -------------------------------------------------------------------------------- /packages/infini-gram-processor/src/infini_gram_processor/models/camel_case_model.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from pydantic.alias_generators import to_camel 3 | 4 | 5 | class CamelCaseModel(BaseModel): 6 | class Config: 7 | alias_generator = to_camel 8 | populate_by_name = True 9 | -------------------------------------------------------------------------------- 
/packages/infini-gram-processor/src/infini_gram_processor/models/is_infini_gram_error_response.py: -------------------------------------------------------------------------------- 1 | from typing import TypeGuard, TypeVar 2 | 3 | from infini_gram.models import ( 4 | ErrorResponse, 5 | InfiniGramEngineResponse, 6 | ) 7 | 8 | TInfiniGramResponse = TypeVar("TInfiniGramResponse") 9 | 10 | 11 | def is_infini_gram_error_response( 12 | val: InfiniGramEngineResponse[TInfiniGramResponse], 13 | ) -> TypeGuard[ErrorResponse]: 14 | return isinstance(val, dict) and "error" in val 15 | -------------------------------------------------------------------------------- /packages/infini-gram-processor/src/infini_gram_processor/models/models.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum 2 | from typing import Any, Optional, Sequence 3 | 4 | from infini_gram.models import ( 5 | AttributionDoc, 6 | ) 7 | from infini_gram.models import ( 8 | AttributionSpan as AttributionSpanFromEngine, 9 | ) 10 | from pydantic import BaseModel, Field 11 | 12 | from .camel_case_model import CamelCaseModel 13 | 14 | 15 | class GetDocumentByRankRequest(BaseModel): 16 | shard: int 17 | rank: int 18 | needle_length: int 19 | maximum_context_length: int 20 | 21 | 22 | class GetDocumentByPointerRequest(BaseModel): 23 | docs: list[AttributionDoc] 24 | span_ids: list[int] 25 | needle_length: int 26 | maximum_context_length: int 27 | 28 | 29 | class GetDocumentByIndexRequest(BaseModel): 30 | document_index: int 31 | maximum_context_length: int 32 | 33 | 34 | class SpanRankingMethod(StrEnum): 35 | LENGTH = "length" 36 | UNIGRAM_LOGPROB_SUM = "unigram_logprob_sum" 37 | 38 | 39 | class BaseInfiniGramResponse(CamelCaseModel): 40 | index: str 41 | 42 | 43 | class InfiniGramErrorResponse(CamelCaseModel): 44 | error: str 45 | 46 | 47 | class InfiniGramCountResponse(BaseInfiniGramResponse): 48 | approx: bool 49 | count: int 50 | 51 | 52 | class Document(CamelCaseModel): 53 | document_index: int = Field(validation_alias="doc_ix") 54 | document_length: int = Field(validation_alias="doc_len") 55 | display_length: int = Field(validation_alias="disp_len") 56 | needle_offset: int = Field(validation_alias="needle_offset") 57 | metadata: dict[str, Any] 58 | token_ids: list[int] 59 | text: str 60 | blocked: bool = False 61 | 62 | 63 | class InfiniGramAttributionResponse(BaseInfiniGramResponse): 64 | spans: list[AttributionSpanFromEngine] 65 | input_token_ids: list[int] 66 | 67 | 68 | class InfiniGramSearchResponse(CamelCaseModel): 69 | documents: list[Document] 70 | total_documents: int 71 | 72 | 73 | class AttributionDocument(Document): 74 | display_length_long: int 75 | needle_offset_long: int 76 | text_long: str 77 | display_offset_snippet: int 78 | needle_offset_snippet: int 79 | text_snippet: str 80 | 81 | 82 | class AttributionSpan(CamelCaseModel): 83 | left: int 84 | right: int 85 | length: int 86 | count: int 87 | unigram_logprob_sum: float 88 | text: str 89 | token_ids: Sequence[int] 90 | documents: list[AttributionDocument] 91 | 92 | 93 | class AttributionResponse(BaseInfiniGramResponse): 94 | spans: Sequence[AttributionSpan] 95 | input_tokens: Optional[Sequence[str]] = Field( 96 | examples=[["busy", " medieval", " streets", "."]] 97 | ) 98 | -------------------------------------------------------------------------------- /packages/infini-gram-processor/src/infini_gram_processor/processor_config.py: 
-------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings, SettingsConfigDict 2 | 3 | 4 | class ProcessorConfig(BaseSettings): 5 | model_config = SettingsConfigDict(env_file=".env", extra="ignore") 6 | 7 | index_base_path: str = "/mnt/infinigram-array" 8 | vendor_base_path: str = "/app/vendor" 9 | 10 | 11 | tokenizer_config = ProcessorConfig() 12 | 13 | 14 | def get_processor_config() -> ProcessorConfig: 15 | return ProcessorConfig() 16 | -------------------------------------------------------------------------------- /packages/infini-gram-processor/src/infini_gram_processor/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/infinigram-api/e858271c372d6accf13d12dd46bcca7fefe570ee/packages/infini-gram-processor/src/infini_gram_processor/py.typed -------------------------------------------------------------------------------- /packages/infini-gram-processor/src/infini_gram_processor/tokenizers/tokenizer.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Iterable, List, Sequence, Tuple, cast 3 | 4 | from transformers import ( # type: ignore 5 | AutoTokenizer, 6 | PreTrainedTokenizer, 7 | PreTrainedTokenizerFast, 8 | ) 9 | from transformers.tokenization_utils_base import ( # type: ignore 10 | EncodedInput, 11 | PreTokenizedInput, 12 | TextInput, 13 | ) 14 | 15 | 16 | class Tokenizer: 17 | hf_tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast 18 | delimiter_mapping: dict[str, int] 19 | eos_token_id: int 20 | bow_ids_path: str 21 | 22 | def __init__( 23 | self, 24 | pretrained_model_name_or_path: str | PathLike[str], 25 | bow_ids_path: str, 26 | delimiter_mapping: dict[str, int] = {}, 27 | ): 28 | self.hf_tokenizer = AutoTokenizer.from_pretrained( # pyright: ignore[reportUnknownMemberType] 29 | pretrained_model_name_or_path=pretrained_model_name_or_path, 30 | add_bos_token=False, 31 | add_eos_token=False, 32 | trust_remote_code=True, 33 | ) 34 | 35 | if self.hf_tokenizer.eos_token_id is None: 36 | raise Exception( 37 | f"The tokenizer for {pretrained_model_name_or_path} didn't have an eos token id" 38 | ) 39 | 40 | self.eos_token_id = self.hf_tokenizer.eos_token_id 41 | self.delimiter_mapping = delimiter_mapping 42 | self.bow_ids_path = bow_ids_path 43 | 44 | def tokenize( 45 | self, input: TextInput | PreTokenizedInput | EncodedInput 46 | ) -> List[int]: 47 | encoded_query: List[int] = self.hf_tokenizer.encode(input) # pyright: ignore[reportUnknownMemberType] 48 | return encoded_query 49 | 50 | def decode_tokens(self, token_ids: Iterable[int]) -> str: 51 | return self.hf_tokenizer.decode(token_ids) # type: ignore 52 | 53 | def tokenize_to_list(self, input: TextInput) -> Sequence[str]: 54 | tokenized_input = self.hf_tokenizer(input, return_offsets_mapping=True) 55 | 56 | offset_mapping = tokenized_input.data.get("offset_mapping", []) # pyright: ignore [reportUnknownMemberType] 57 | # This is to fix a corner case: when input begins with a number, the token ids will begin with [29871 (whitespace), 29896, ...] with offset_mapping being [(0, 1), (0, 1), ...] 
58 | if len(offset_mapping) > 1: 59 | if offset_mapping[0][1] > offset_mapping[1][0]: 60 | offset_mapping[0] = (offset_mapping[0][0], offset_mapping[1][0]) 61 | 62 | return [ 63 | input[offset[0] : offset[1]] 64 | for offset in cast(List[Tuple[(int, int)]], offset_mapping) 65 | ] 66 | 67 | def tokenize_attribution_delimiters(self, delimiters: Iterable[str]) -> List[int]: 68 | """ 69 | A method made specifically to tokenize delimiters for attribution uses. 70 | 71 | The standard tokenization process gives us different results than we want for things like '.' and newlines. This function checks a pre-defined dict of strings to token IDs that provide the correct token for those delimiters. 72 | """ 73 | encoded_delimiters: List[int] = [] 74 | 75 | non_mapped_delimiters: List[str] = [] 76 | 77 | for delimiter in delimiters: 78 | mapped_delimiter = self.delimiter_mapping.get(delimiter) 79 | 80 | if mapped_delimiter is not None: 81 | encoded_delimiters.append(mapped_delimiter) 82 | else: 83 | non_mapped_delimiters.append(delimiter) 84 | 85 | encoded_delimiters += ( 86 | self.hf_tokenizer.encode(non_mapped_delimiters) # pyright: ignore[reportUnknownMemberType] 87 | if len(non_mapped_delimiters) > 0 88 | else [] 89 | ) 90 | 91 | return encoded_delimiters 92 | -------------------------------------------------------------------------------- /packages/infini-gram-processor/src/infini_gram_processor/tokenizers/tokenizer_factory.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | from infini_gram_processor.processor_config import get_processor_config 4 | 5 | from .tokenizer import Tokenizer 6 | 7 | 8 | @lru_cache 9 | def get_llama_2_tokenizer() -> Tokenizer: 10 | config = get_processor_config() 11 | return Tokenizer( 12 | pretrained_model_name_or_path=f"{config.vendor_base_path}/llama-2-7b-hf", 13 | delimiter_mapping={"\n": 13, ".": 29889}, 14 | bow_ids_path=f"{config.vendor_base_path}/llama-2_bow_ids.txt", 15 | ) 16 | -------------------------------------------------------------------------------- /proxy/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nginx:1.27.4-alpine 2 | 3 | COPY nginx.conf /etc/nginx/nginx.conf 4 | 5 | ARG CONF_FILE=local.conf 6 | COPY $CONF_FILE /etc/nginx/conf.d/default.conf 7 | 8 | -------------------------------------------------------------------------------- /proxy/local.conf: -------------------------------------------------------------------------------- 1 | server { 2 | listen [::]:8080; 3 | listen 8080; 4 | 5 | charset utf-8; 6 | 7 | # Disable the cache across the board as it's not terribly impactful for 8 | # small demos, and it adds potential for confusing bugs. 
9 | expires -1; 10 | 11 | location / { 12 | proxy_set_header X-Forwarded-Host $http_host; 13 | proxy_set_header X-Forwarded-Proto $http_proto; 14 | proxy_set_header X-Forwarded-Port $http_port; 15 | proxy_set_header X-Forwarded-For $http_for; 16 | proxy_set_header X-Ingress-Controller-IP $realip_remote_addr; 17 | 18 | proxy_read_timeout 120s; 19 | 20 | proxy_pass http://api:8000; 21 | } 22 | 23 | location /proxy_health { 24 | return 200 'ok'; 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /proxy/nginx.conf: -------------------------------------------------------------------------------- 1 | user nginx; 2 | worker_processes auto; 3 | 4 | error_log stderr warn; 5 | pid /var/run/nginx.pid; 6 | 7 | events { 8 | worker_connections 1024; 9 | } 10 | 11 | http { 12 | set_real_ip_from 0.0.0.0/0; 13 | real_ip_header X-Forwarded-For; 14 | 15 | include /etc/nginx/mime.types; 16 | default_type application/octet-stream; 17 | 18 | sendfile on; 19 | keepalive_timeout 65; 20 | 21 | gzip on; 22 | 23 | server_tokens off; 24 | 25 | log_format json escape=json '{' 26 | '"time": "$time_iso8601",' 27 | '"request_method": "$request_method",' 28 | '"request_uri": "$scheme://$host$request_uri",' 29 | '"status": $status,' 30 | '"request_length": $request_length,' 31 | '"body_bytes_sent": $body_bytes_sent,' 32 | '"user_agent": "$http_user_agent",' 33 | '"ip": "$remote_addr",' 34 | '"realip": "$realip_remote_addr",' 35 | '"referer": "$http_referer",' 36 | '"host": "$host",' 37 | '"scheme": "$scheme",' 38 | '"forwarded-for": "$http_x_forwarded_for"' 39 | '}'; 40 | access_log /dev/stdout json; 41 | 42 | include /etc/nginx/conf.d/*; 43 | } 44 | -------------------------------------------------------------------------------- /proxy/prod.conf: -------------------------------------------------------------------------------- 1 | server { 2 | listen [::]:8080; 3 | listen 8080; 4 | 5 | charset utf-8; 6 | 7 | # Disable the cache across the board as it's not terribly impactful for 8 | # small demos, and it adds potential for confusing bugs. 9 | expires -1; 10 | 11 | location / { 12 | proxy_set_header X-Forwarded-Host $http_x_forwarded_host; 13 | proxy_set_header X-Forwarded-Proto $http_x_forwarded_proto; 14 | proxy_set_header X-Forwarded-Port $http_x_forwarded_port; 15 | proxy_set_header X-Forwarded-For $http_x_forwarded_for; 16 | proxy_set_header X-Ingress-Controller-IP $realip_remote_addr; 17 | 18 | proxy_read_timeout 120s; 19 | 20 | proxy_pass http://localhost:8000; 21 | } 22 | 23 | location /proxy_health { 24 | return 200 'ok'; 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /proxy/proxy.conf: -------------------------------------------------------------------------------- 1 | server { 2 | listen [::]:8080; 3 | listen 8080; 4 | 5 | charset utf-8; 6 | access_log /dev/stdout json; 7 | 8 | # Disable caching of responses. 
9 | expires -1; 10 | 11 | location / { 12 | proxy_set_header X-Forwarded-Host $http_x_forwarded_host; 13 | proxy_set_header X-Forwarded-Proto $http_x_forwarded_proto; 14 | proxy_set_header X-Forwarded-Port $http_x_forwarded_port; 15 | proxy_set_header X-Forwarded-For $http_x_forwarded_for; 16 | proxy_set_header X-Ingress-Controller-IP $realip_remote_addr; 17 | 18 | proxy_read_timeout 120s; 19 | 20 | proxy_pass http://api; 21 | } 22 | 23 | location /proxy_health { 24 | return 200 'ok'; 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "root" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [] 8 | 9 | [tool.uv.workspace] 10 | members = ["api", "attribution_worker", "packages/*", "load-test"] 11 | 12 | [tool.uv.sources] 13 | infini-gram = [ 14 | { path = "vendor/infini_gram-2.5.1-cp312-cp312-macosx_11_0_arm64.whl", marker = "sys_platform == 'darwin' and python_version == '3.12' and platform_machine == 'arm64'" }, 15 | { path = "vendor/infini_gram-2.5.1-cp313-cp313-macosx_11_0_arm64.whl", marker = "sys_platform == 'darwin' and python_version == '3.13' and platform_machine == 'arm64'" }, 16 | { path = "vendor/infini_gram-2.5.1-cp312-cp312-macosx_10_15_x86_64.whl", marker = "sys_platform == 'darwin' and python_version == '3.12' and platform_machine == 'x86_64'" }, 17 | { path = "vendor/infini_gram-2.5.1-cp313-cp313-macosx_10_15_x86_64.whl", marker = "sys_platform == 'darwin' and python_version == '3.13' and platform_machine == 'x86_64'" }, 18 | { path = "vendor/infini_gram-2.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", marker = "sys_platform == 'linux'" }, 19 | ] 20 | infini-gram-processor = { workspace = true } 21 | 22 | [dependency-groups] 23 | dev = ["mypy>=1.15.0", "pytest>=8.3.5", "ruff>=0.11.0"] 24 | 25 | [tool.pyright] 26 | pythonVersion = "3.12" 27 | 28 | [tool.ruff] 29 | exclude = [ 30 | ".venv", 31 | ".vscode", 32 | ".github", 33 | "venv", 34 | "vendor", 35 | "indexing", 36 | "compute_stats", 37 | "scripts", 38 | ] 39 | 40 | [tool.ruff.format] 41 | # Like Black, use double quotes for strings. 42 | quote-style = "double" 43 | 44 | # Like Black, automatically detect the appropriate line ending. 45 | line-ending = "auto" 46 | 47 | [tool.mypy] 48 | 49 | files = ['./api', './attribution_worker', 'packages/*'] 50 | 51 | exclude = ['vendor', 'indexing', 'compute_stats', 'scripts'] 52 | 53 | strict = true 54 | 55 | [[tool.mypy.overrides]] 56 | module = ["src.glog"] 57 | disable_error_code = ['type-arg', 'no-untyped-def', 'no-untyped-call'] 58 | -------------------------------------------------------------------------------- /schema/local.sql: -------------------------------------------------------------------------------- 1 | -- Initialize things locally that aren't relevant in production. 
2 | CREATE USER "infini-gram" WITH PASSWORD 'llmz'; 3 | 4 | GRANT ALL ON schema public TO "infini-gram"; -------------------------------------------------------------------------------- /scripts/compute_correlation.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | from scipy.stats import spearmanr 4 | 5 | human_scores_by_thread_id = {} 6 | with open('bailey100_annotation.csv') as f: 7 | reader = csv.DictReader(f) 8 | for row in reader: 9 | thread_id = row['URL'].split('/')[-1] 10 | scores = [int(row[f'doc #{i}']) for i in range(1, 6)] 11 | human_scores_by_thread_id[thread_id] = scores 12 | 13 | all_human_scores = [] 14 | all_llm_scores = [] 15 | with open('bailey100_baseline_gpt-4o-2024-08-06_user_rubric.json') as f: 16 | ds = json.load(f) 17 | for item in ds: 18 | thread_id = item['thread_id'] 19 | human_scores = human_scores_by_thread_id[thread_id] 20 | llm_scores = [doc['rating'] for doc in item['docs']] 21 | assert len(human_scores) == len(llm_scores) 22 | all_human_scores.extend(human_scores) 23 | all_llm_scores.extend(llm_scores) 24 | 25 | # compute the spearman correlation 26 | correlation, p_value = spearmanr(all_human_scores, all_llm_scores) 27 | print(f'Spearman correlation: {correlation:.2f}') 28 | -------------------------------------------------------------------------------- /scripts/distrib_of_score.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import numpy as np 4 | import requests 5 | from openai import OpenAI 6 | import os 7 | from tqdm import tqdm 8 | 9 | client = OpenAI( 10 | api_key=os.environ['OPENAI_API_KEY'], 11 | ) 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--input_path', type=str, required=True) 15 | parser.add_argument('--output_path', type=str, required=True) 16 | parser.add_argument('--overwrite_docs', default=False, action='store_true') 17 | args = parser.parse_args() 18 | 19 | api_url = 'http://0.0.0.0:8008/olmo-2-1124-13b/attribution' 20 | params = { 21 | 'delimiters': ['\n', '.'], 22 | 'allowSpansWithPartialWords': False, 23 | 'minimumSpanLength': 1, 24 | 'maximumFrequency': 1000000, 25 | 'maximumSpanDensity': 0.05, 26 | 'spanRankingMethod': 'unigram_logprob_sum', 27 | 'includeDocuments': True, 28 | 'maximumDocumentsPerSpan': 10, 29 | 'maximumDocumentContextLengthRetrieved': 250, 30 | 'maximumDocumentContextLengthDisplayed': 50, 31 | 'filterMethod': 'bm25', 32 | 'filterBm25FieldsConsidered': 'prompt|response', 33 | 'filterBm25RatioToKeep': 1.0, 34 | 'includeInputAsTokens': True, 35 | } 36 | 37 | with open(args.input_path) as f: 38 | ds = json.load(f) 39 | # ds = ds[:10] 40 | 41 | results = [] 42 | 43 | for item in tqdm(ds): 44 | 45 | if 'docs' not in item or args.overwrite_docs: 46 | payload = { 47 | 'prompt': item['prompt'], 48 | 'response': item['response'], 49 | **params, 50 | } 51 | result = requests.post(api_url, json=payload).json() 52 | 53 | doc_by_ix = {} 54 | for span in result['spans']: 55 | for doc in span['documents']: 56 | if doc['documentIndex'] not in doc_by_ix: 57 | try: 58 | url = doc['metadata']['metadata']['metadata']['url'] 59 | except: 60 | url = None 61 | doc_by_ix[doc['documentIndex']] = { 62 | 'documentIndex': doc['documentIndex'], 63 | 'url': url, 64 | 'relevanceScore': doc['relevanceScore'], 65 | 'spanTexts': [span['text']], 66 | 'snippets': [doc['text']], 67 | } 68 | else: 69 | doc_by_ix[doc['documentIndex']]['spanTexts'].append(span['text']) 70 | 
doc_by_ix[doc['documentIndex']]['snippets'].append(doc['text']) 71 | docs = list(sorted(doc_by_ix.values(), key=lambda x: x['relevanceScore'], reverse=True)) 72 | 73 | item['docs'] = docs 74 | 75 | max_doc_score = max(doc['relevanceScore'] for doc in item['docs']) 76 | min_doc_score = min(doc['relevanceScore'] for doc in item['docs']) 77 | max_span_score = max(max(doc['relevanceScore'] for doc in span['documents']) for span in result['spans']) 78 | min_span_score = min(max(doc['relevanceScore'] for doc in span['documents']) for span in result['spans']) 79 | response_len_chars = len(item['response']) 80 | results.append({ 81 | 'response_len_chars': response_len_chars, 82 | 'max_doc_score': max_doc_score, 83 | 'min_doc_score': min_doc_score, 84 | 'max_span_score': max_span_score, 85 | 'min_span_score': min_span_score, 86 | }) 87 | 88 | # Make a plot where the x-axis is the length of the response in characters, and the y-axis is a bar from min to max score of the documents. 89 | import matplotlib.pyplot as plt 90 | import seaborn as sns 91 | sns.set(style='whitegrid') 92 | 93 | plt.figure(figsize=(12, 8)) 94 | plt.scatter([r['response_len_chars'] for r in results], [r['max_doc_score'] for r in results], label='max doc score') 95 | plt.scatter([r['response_len_chars'] for r in results], [r['min_doc_score'] for r in results], label='min doc score') 96 | plt.xlabel('Response length (characters)') 97 | plt.ylabel('Document score') 98 | plt.legend() 99 | plt.savefig(args.output_path.replace('.png', '_doc.png'), dpi=300) 100 | 101 | plt.figure(figsize=(12, 8)) 102 | plt.scatter([r['response_len_chars'] for r in results], [r['max_span_score'] for r in results], label='max span score') 103 | plt.scatter([r['response_len_chars'] for r in results], [r['min_span_score'] for r in results], label='min span score') 104 | plt.xlabel('Response length (characters)') 105 | plt.ylabel('Span score') 106 | plt.legend() 107 | plt.savefig(args.output_path.replace('.png', '_span.png'), dpi=300) 108 | -------------------------------------------------------------------------------- /scripts/easyapi_test.py: -------------------------------------------------------------------------------- 1 | import easyapi 2 | import time 3 | import os 4 | 5 | # export EASY_URL=http://neptune-cs-aus-267.reviz.ai2.in:5000 6 | api = easyapi.Api() 7 | model_name = 'Qwen/Qwen2-7B-Instruct' 8 | hf_token = os.environ.get('HF_TOKEN', '') # only needed if your model needs it 9 | if not api.has_model(model_name): 10 | api.launch_model(model_name, gpus=1, hf_token=hf_token) # launch on jupiter 11 | 12 | while not api.has_model(model_name): time.sleep(5) 13 | 14 | prompt = "Barack Obama was born in" 15 | r = api.generate(prompt, model=model_name, temp=0.1, max_tokens=256) 16 | print(r) 17 | -------------------------------------------------------------------------------- /scripts/eval_corpuslink_relevance.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import numpy as np 4 | import requests 5 | from openai import OpenAI 6 | import os 7 | from tqdm import tqdm 8 | 9 | client = OpenAI( 10 | api_key=os.environ['OPENAI_API_KEY'], 11 | ) 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--input_path', type=str, required=True) 15 | parser.add_argument('--output_path', type=str, required=True) 16 | parser.add_argument('--overwrite_docs', default=False, action='store_true') 17 | parser.add_argument('--overwrite_eval', default=False, action='store_true') 18 | 
parser.add_argument('--num_docs_per_thread', type=int, default=5) 19 | args = parser.parse_args() 20 | 21 | api_url = 'http://0.0.0.0:8008/olmoe/attribution' 22 | params = { 23 | 'delimiters': ['\n', '.'], 24 | 'allowSpansWithPartialWords': False, 25 | 'minimumSpanLength': 1, 26 | 'maximumFrequency': 1000000, 27 | 'maximumSpanDensity': 0.05, 28 | 'spanRankingMethod': 'unigram_logprob_sum', 29 | 'includeDocuments': True, 30 | 'maximumDocumentsPerSpan': 10, 31 | 'maximumDocumentDisplayLength': 500, 32 | 'filterMethod': 'bm25', 33 | 'filterBm25FieldsConsidered': 'prompt|response', 34 | 'filterBm25RatioToKeep': 1.0, 35 | 'includeInputAsTokens': True, 36 | } 37 | 38 | # llm_as_a_judge_system_message = """You will be given a user prompt, a model's response to the prompt, and a retrieved document. Please rate how relevant the document is to the prompt and model response. Rate on a scale of 0 (not relevant) to 3 (very relevant). Respond with a single number, and do not include any other text in your response.""" 39 | llm_as_a_judge_system_message = """You will be given a user prompt, a model's response to the prompt, and a retrieved document. Please rate how relevant the document is to the prompt and model response. Rate on a scale of 0 (not relevant) to 3 (very relevant). Respond with a single number, and do not include any other text in your response. 40 | 41 | Rubric for rating: 42 | 0: The document is about a different topic than the prompt and model response. 43 | 1. The document is about a broader topic than the prompt and model response, or is potentially relevant but there's not enough information. 44 | 2. The document is on the right topic of the prompt and model response, but is in a slightly different context or is too specific. 45 | 3. The document is about a subject that is a direct match, in topic and scope, of the most likely user intent for the prompt and model response.""" 46 | llm_as_a_judge_template = """Prompt: {prompt} 47 | 48 | Model response: {response} 49 | 50 | Retrieved document: {document}""" 51 | # llm_as_a_judge_system_message = """You will be given a user prompt, a model's response to the prompt, and a retrieved document. Please rate how relevant the document is to the prompt and model response. Rate on a scale of 0 (not relevant) to 3 (very relevant). Provide your reasoning first, and finally give your rating in the format of "Rating: X". Please strictly follow this template, and do not include any other text or character after the rating. 52 | 53 | # Rubric for rating: 54 | # 0: The document is about a different topic than the prompt and model response. 55 | # 1. The document is about a broader topic than the prompt and model response, or is potentially relevant but there's not enough information. 56 | # 2. The document is on the right topic of the prompt and model response, but is in a slightly different context or is too specific. 57 | # 3. The document is about a subject that is a direct match, in topic and scope, of the most likely user intent for the prompt and model response.""" 58 | # llm_as_a_judge_template = """Prompt: {prompt} 59 | 60 | # Model response: {response} 61 | 62 | # Retrieved document: {document}""" 63 | # llm_as_a_judge_system_message = """You will be given a user prompt, a model's response to the prompt, and a retrieved document. Please rate how relevant the document is to the prompt and model response. Rate on a scale of 0 (not relevant) to 3 (very relevant). 
First give your rating as a single number on its own line, and then provide your reasoning on this rating. Please strictly follow this format, and do not include any other text in the first line than the rating number. 64 | 65 | # Rubric for rating: 66 | # 0: The document is about a different topic than the prompt and model response. 67 | # 1. The document is about a broader topic than the prompt and model response, or is potentially relevant but there's not enough information. 68 | # 2. The document is on the right topic of the prompt and model response, but is in a slightly different context or is too specific. 69 | # 3. The document is about a subject that is a direct match, in topic and scope, of the most likely user intent for the prompt and model response.""" 70 | # llm_as_a_judge_template = """Prompt: {prompt} 71 | 72 | # Model response: {response} 73 | 74 | # Retrieved document: {document}""" 75 | 76 | with open(args.input_path) as f: 77 | ds = json.load(f) 78 | # ds = ds[:5] 79 | 80 | for item in tqdm(ds): 81 | 82 | if 'docs' not in item or args.overwrite_docs: 83 | payload = { 84 | 'prompt': item['prompt'], 85 | 'response': item['response'], 86 | **params, 87 | } 88 | result = requests.post(api_url, json=payload).json() 89 | 90 | doc_by_ix = {} 91 | for span in result['spans']: 92 | for doc in span['documents']: 93 | if doc['documentIndex'] not in doc_by_ix: 94 | try: 95 | url = doc['metadata']['metadata']['metadata']['url'] 96 | except: 97 | url = None 98 | doc_by_ix[doc['documentIndex']] = { 99 | 'documentIndex': doc['documentIndex'], 100 | 'url': url, 101 | 'relevanceScore': doc['relevanceScore'], 102 | 'spanTexts': [span['text']], 103 | 'snippets': [doc['text']], 104 | } 105 | else: 106 | doc_by_ix[doc['documentIndex']]['spanTexts'].append(span['text']) 107 | doc_by_ix[doc['documentIndex']]['snippets'].append(doc['text']) 108 | docs = list(sorted(doc_by_ix.values(), key=lambda x: x['relevanceScore'], reverse=True)) 109 | 110 | deduped_docs = [] 111 | for doc in docs: 112 | if doc['url'] is None or doc['url'] not in [d['url'] for d in deduped_docs]: 113 | deduped_docs.append(doc) 114 | docs = deduped_docs 115 | 116 | docs = docs[:args.num_docs_per_thread] 117 | 118 | item['docs'] = docs 119 | 120 | for doc in item['docs']: 121 | if 'rating' not in doc or args.overwrite_eval: 122 | user_message = llm_as_a_judge_template.format( 123 | prompt=item['prompt'].replace('\n', '\\n'), 124 | response=item['response'].replace('\n', '\\n'), 125 | document=doc['snippets'][0].replace('\n', '\\n'), # by default, users will only see the first snippet in the UI 126 | ) 127 | response = client.chat.completions.create( 128 | # model='gpt-4o-mini-2024-07-18', 129 | model='gpt-4o-2024-08-06', 130 | messages=[ 131 | # {'role': 'system', 'content': llm_as_a_judge_system_message}, 132 | # {'role': 'user', 'content': user_message}, 133 | {'role': 'user', 'content': llm_as_a_judge_system_message + '\n\n' + user_message}, 134 | ], 135 | temperature=0.0, 136 | max_completion_tokens=1, 137 | ) 138 | try: 139 | rating = int(response.choices[0].message.content[-1]) 140 | except: 141 | rating = None 142 | doc['rating'] = rating 143 | 144 | with open(args.output_path, 'w') as f: 145 | json.dump(ds, f, indent=4) 146 | 147 | avg_rating_top1 = np.mean([doc['rating'] for item in ds for doc in item['docs'][:1]]) 148 | avg_rating_top5 = np.mean([doc['rating'] for item in ds for doc in item['docs'][:5]]) 149 | ratio_relevant_top1 = np.mean([doc['rating'] >= 2 for item in ds for doc in item['docs'][:1]]) 150 | 
ratio_relevant_top5 = np.mean([doc['rating'] >= 2 for item in ds for doc in item['docs'][:5]]) 151 | print(f'Average rating (top-1 documents): {avg_rating_top1:.2f}') 152 | print(f'Average rating (top-5 documents): {avg_rating_top5:.2f}') 153 | print(f'Percentage relevant (top-1 documents): {ratio_relevant_top1 * 100:.2f}%') 154 | print(f'Percentage relevant (top-5 documents): {ratio_relevant_top5 * 100:.2f}%') 155 | -------------------------------------------------------------------------------- /scripts/sample_wildbench.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import csv 3 | 4 | ds = load_dataset("allenai/WildBench", "v2", split="test") 5 | sample = ds.shuffle(seed=42).select(range(100)) 6 | 7 | with open('wildbench_sample.csv', 'w') as f: 8 | fieldnames = ['prompt'] 9 | writer = csv.DictWriter(f, fieldnames=fieldnames) 10 | writer.writeheader() 11 | for s in sample: 12 | writer.writerow({'prompt': s["conversation_input"][0]['content']}) 13 | -------------------------------------------------------------------------------- /scripts/stress_test_api.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import multiprocessing as mp 3 | import random 4 | 5 | NUM_TOKENS = 500 6 | NUM_CONCURRENT_REQUESTS = 32 7 | 8 | PAYLOAD = { 9 | 'prompt': '', 10 | 'response': '', 11 | 'delimiters': ['\n', '.'], 12 | 'allowSpansWithPartialWords': False, 13 | 'minimumSpanLength': 1, 14 | 'maximumFrequency': 1000000, 15 | 'maximumSpanDensity': 0.05, 16 | 'spanRankingMethod': 'unigram_logprob_sum', 17 | 'maximumDocumentsPerSpan': 10, 18 | 'maximumContextLength': 250, 19 | 'maximumContextLengthLong': 250, 20 | 'maximumContextLengthSnippet': 40, 21 | } 22 | 23 | url = 'http://0.0.0.0:8008/olmo-2-1124-13b/attribution' 24 | 25 | def issue_request(response): 26 | payload = PAYLOAD.copy() 27 | payload['response'] = response 28 | return requests.post(url, json=payload).json() 29 | 30 | with mp.get_context('fork').Pool(NUM_CONCURRENT_REQUESTS) as p: 31 | responses = [] 32 | for i in range(NUM_CONCURRENT_REQUESTS): 33 | response = '' 34 | for j in range(NUM_TOKENS): 35 | response += str(random.randint(0, 9)) 36 | responses.append(response) 37 | results = p.map(issue_request, responses) 38 | 39 | for result in results: 40 | print('='*80) 41 | for span in result['spans']: 42 | print(f'l={span["left"]}, r={span["right"]}, text={span["text"]}') 43 | 44 | # for i in range(NUM_CONCURRENT_REQUESTS): 45 | # result = issue_request(responses[i]) 46 | # assert result == results[i] 47 | -------------------------------------------------------------------------------- /scripts/test_span_density.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import requests 3 | 4 | query_obama = '''Barack Hussein Obama II was born on August 4, 1961, in Honolulu, Hawaii. He is the 44th President of the United States, serving two terms from January 20, 2009, to January 20, 2017. An African-American, Obama is the first African descent or first person of any race to hold the office of the President. 5 | 6 | Obama graduated from Columbia University in 1983 and Harvard Law School in 1989. He stayed in Boston for some time before returning to Chicago, where he worked as a civil rights lawyer and community organizer. In 2004, Obama ran for the U.S. Senate and won, representing Illinois. Successfully running for President in 2008, Obama took office in January 2009. 
7 | 8 | Throughout his tenure as President, Barack Obama oversaw the recovery from the 2008 financial crisis and the 2008-2010 Great Recession. He pushed for health care reform, signing the Affordable Care Act, and strived to lessen racial tension in society. He also led military missions against terrorism organizations worldwide during his presidency. Obama retired from politics after completing his second term, but his impact and contributions have sustained him as a powerful and influential figure. 9 | 10 | On November 10, 2020, Barack Obama flipped Massachusetts, which he once represented as a Senator, to Biden in the Presidential race. More recently, in October 2021, Barack Obama gave one of the most celebrated speeches of the year in Las Vegas, Nevada, in support of Democratic candidates up and down the ballots.''' 11 | 12 | query_biden = '''Joseph Robinette "Joe" Biden Jr. was born on November 20, 1942, in Scranton, Pennsylvania. He is the 47th Vice President of the United States, serving under President Barack Obama from 2009 to 2017. Biden, a Democrat, represented Delaware in the U.S. Senate for 36 years before becoming vice president. 13 | 14 | Biden earned a reputation as a respected and trusted bipartisan figure, working with both Republicans and Democrats in the Senate. He passed major legislation during his time in Congress, including the Children's Health Insurance Program and the Violence Against Women Act. In 2008, he ran unsuccessfully for the Democratic nomination for U.S. President. 15 | 16 | In 2009, after Barack Obama became President, he selected Biden to be his vice president. Biden is the oldest person and the first Roman Catholic to serve as U.S. Vice President. During his tenure as Vice President, he saw the country navigate various challenges such as enforcing the Affordable Care Act, managing the Obama administration's relations with Iran, facing the 2010 midterm election victories by the Republican Party, experiencing the ceasefire in the Israel-Gaza conflict, losing close friend and former senator Ted Kennedy, and many others. 17 | 18 | Biden was often described as a ringing voice of reason and a calming influence in The White House, working closely with the President daily. Strongly connected to the people, he has numerous public appearances nationally and internationally, promoting the America recovery, security, and improved international relations. 19 | 20 | Biden retired at the end of his Vice Presidential term, closing a 33-year Senate journey, and returned to his hometown where he plans to focus on his role as Chairman of the Democratic national committee, involved with the party strategy and organization.''' 21 | 22 | query_queen = '''Queen Elizabeth II, full name Elizabeth Alexandra Mary, is the reigning monarch of the United Kingdom and 14 other Commonwealth countries. She was born on April 21, 1926, in London, to Prince Albert (future King Albert II), Duke of York, and Elizabeth Bowes-Lyon. If her father had not died early, Elizabeth would not have ascended to the throne. but with his sudden death on November 14, 1936, as a result, Elizabeth became the heir apparent and Queen of eight countries. 23 | 24 | When Elizabeth was just a princess, her life dramatically changed on May 27, 1940. In what became known as the trio's escape, Hitler's Nazi forces near London threatened by German bombs, and her family was relocated several times. Later on, her family took residence at the Royal Lodge in Windsor Great Park. 
She graduated from the Royal Montagu School, now known as Beaumont, in 1944. On July 23, 1947, Elizabeth married Philip Mountbatten, Duke of Edinburgh, and they assumed the joint style of Arches and 575th of the Order of the Garter. They had four children over the years -- Charles, Anne, Andrew, and Edward -- who all (except for Edward) became a part of the UK's Royal Navy or Royal Air Force. 25 | 26 | Elizabeth acceded to the throne of the United Kingdom on February 6, 1952, upon the death of her father King George VI. As the longest-reigning British monarch, she has overseen significant changes in the United Kingdom, Britain's Colonies and dominions gained independence; some no longer use "Queen" in their titles, like Canada, Australia, and New Zealand. Elizabeth II celebrated her platinum jubilee -- meanings sixtieth anniversary of her ascension -- in February 2022.''' 27 | 28 | query_harris = '''Kamala Harris is an American politician who served as the junior U.S. Senator from California from 2017 to 2021. Before joining the Senate, Kamala Harris served as the Attorney General of California from 2011 to 2017. she is possibly running for the Democratic nomination for President in the United States presidential election of 2020. Here are some key facts about Kamala Harris: 29 | 30 | Early life and education: Kamala Harris was born on October 20, 1964, in Oakland, California, to a Jamaican father and an Indian mother. She graduated from Howard University with a bachelor's degree in political science and UCLA School of Law with a Juris Doctor (JD) degree. 31 | Political career: Harris began her political career serving as the District Attorney of San Francisco from 2003 to 2011. She was the first woman, the first black woman, and the first South Asian American elected as District Attorney for a large city in the United States. In 2010, Kamala Harris was sworn in as California's 34th Attorney General after the incumbent, Jerry Brown, resigned to run for governor. During her tenure as Attorney General, she focused on protecting consumers, defending the state's laws and regulations, and ensuring access to justice for all Californians. 32 | Presidential campaign: In the 2020 presidential election cycle, Harris has declared her candidacy for the Democratic Party's nomination as the President of the United States. She has released a comprehensive plan for tackling COVID-19 aimed at restoring public faith in government, addressing systemic racism, and creating long-term systemic change. She has also proposed a "whole-of-government" approach to increasing access to education and healthcare, providing resources to eyebrows, and restoring the public trust so damaged by the Trump administration. 33 | Personal life: As of November 2021, Harris has one daughter, Maya Harris, born in 2018, with her partner, Doug Emhoff. Harris was openly gay when she was California Attorney General, she was the first person in her role to be openly LGBT. Harris continues to advocate for LGBT and communitarian rights even during her time as a Senator. Her former partner, Shon Baiker, is the mother of her child. 34 | Overall, Kamala Harris is a well-rounded individual with significant political experience and a strong interest in progressive policy reforms, including healthcare, education, criminal justice, and environmental issues.''' 35 | 36 | query_shoe = '''Tying your shoes is a basic skill that every child learns and every adult should know. 
Here's a simple guide to help you learn how to tie your shoes: 37 | 38 | Begin by holding the laces in each hand, with your fingers spread near the toes ends. 39 | Cross the longer lace over the shorter lace, and place it on top of the shorter lace. 40 | Take the longer lace (on the right) and pass it under the shorter lace, and bring it back up on the other side of the shorter lace. 41 | Pull the longer lace (on the right) through the loop that just appeared briefly, and pull it snug. 42 | Repeat steps 3 and 4 for the other lace. 43 | Cross the longer lace (on the left) over the shorter lace, and place it on top of the shorter lace. 44 | Take the longer lace (on the left and pass it under the shorter lace,) and bring it back up on the other side of the shorter lace. 45 | Pull the longer lace (on the left) through the loop that just appeared briefly, and pull it snug. 46 | At this point, you have crossed the longest laces twice and passed them under the shortest laces twice, so the laces are now tied into a double knot at the top. 47 | Follow these steps again for the other lace or repeat the entire process until you become confident with the tying. 48 | Keep practicing until the laces are easily tied every time. Don't get discouraged if you're having difficulty at first; it takes time and patience to perfect this skill.''' 49 | 50 | payload = { 51 | 'query': query_obama, 52 | 'delimiters': ['\n', '.'], 53 | 'maximumSpanDensity': 0.05, 54 | 'minimumSpanLength': 1, 55 | 'maximumFrequency': 10, 56 | 'includeDocuments': True, 57 | 'maximumDocumentDisplayLength': 100, 58 | 'includeInputAsTokens': True, 59 | 'filterMethod': 'bm25', 60 | } 61 | 62 | result = requests.post('http://0.0.0.0:8000/dolma-1_7/attribution', json=payload).json() 63 | num_spans = len(result['spans']) 64 | num_tokens = len(result['inputTokens']) 65 | density = num_spans / num_tokens if num_tokens > 0 else 0 66 | num_docs = sum(len(span['documents']) for span in result['spans']) 67 | print(f'Number of spans: {len(result["spans"])}') 68 | print(f'Number of tokens in response: {len(result["inputTokens"])}') 69 | print(f'Span density: {density:.4f} spans per token') 70 | print(f'Span lengths: {list(sorted([span["length"] for span in result["spans"]]))}') 71 | print(f'Total number of documents: {num_docs}') 72 | 73 | for s, span in enumerate(result['spans']): 74 | print(f'Span {s}: {span["text"]}') 75 | documents = span['documents'] 76 | print(len(documents)) 77 | for d, doc in enumerate(documents): 78 | print('--------' * 10) 79 | print(f'Document {d}: shard={doc["shard"]}, pointer={doc["pointer"]}, score={doc.get("score", ""):.4f}, text={doc["text"] if "text" in doc else ""}') 80 | print('========' * 10) 81 | 82 | with open('csv/obama_100.csv', 'w') as f: 83 | writer = csv.DictWriter(f, fieldnames=['span_ix', 'span_text', 'doc_text', 'score', 'label', 'score > thresh?']) 84 | writer.writeheader() 85 | for s, span in enumerate(result['spans']): 86 | for d, doc in enumerate(span['documents']): 87 | writer.writerow({ 88 | 'span_ix': s if d == 0 else '', 89 | 'span_text': span['text'] if d == 0 else '', 90 | 'doc_text': doc['text'], 91 | 'score': f"{doc.get('score', ''):.4f}", 92 | 'label': '', 93 | 'score > thresh?': '', 94 | }) 95 | -------------------------------------------------------------------------------- /skiff.json: -------------------------------------------------------------------------------- 1 | { 2 | "appName": "infinigram-api", 3 | "contact": "reviz", 4 | "team": "reviz", 5 | "replicas": { 6 | "prod": 2 7 | }, 8 
| "attributionWorkerReplicas": { 9 | "prod": 2 10 | } 11 | } -------------------------------------------------------------------------------- /vendor/infini_gram-2.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/infinigram-api/e858271c372d6accf13d12dd46bcca7fefe570ee/vendor/infini_gram-2.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -------------------------------------------------------------------------------- /vendor/infini_gram-2.5.1-cp311-cp311-macosx_10_15_x86_64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/infinigram-api/e858271c372d6accf13d12dd46bcca7fefe570ee/vendor/infini_gram-2.5.1-cp311-cp311-macosx_10_15_x86_64.whl -------------------------------------------------------------------------------- /vendor/infini_gram-2.5.1-cp311-cp311-macosx_11_0_arm64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/infinigram-api/e858271c372d6accf13d12dd46bcca7fefe570ee/vendor/infini_gram-2.5.1-cp311-cp311-macosx_11_0_arm64.whl -------------------------------------------------------------------------------- /vendor/infini_gram-2.5.1-cp312-cp312-macosx_10_15_x86_64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/infinigram-api/e858271c372d6accf13d12dd46bcca7fefe570ee/vendor/infini_gram-2.5.1-cp312-cp312-macosx_10_15_x86_64.whl -------------------------------------------------------------------------------- /vendor/infini_gram-2.5.1-cp312-cp312-macosx_11_0_arm64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/infinigram-api/e858271c372d6accf13d12dd46bcca7fefe570ee/vendor/infini_gram-2.5.1-cp312-cp312-macosx_11_0_arm64.whl -------------------------------------------------------------------------------- /vendor/infini_gram-2.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/infinigram-api/e858271c372d6accf13d12dd46bcca7fefe570ee/vendor/infini_gram-2.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -------------------------------------------------------------------------------- /vendor/infini_gram-2.5.1-cp313-cp313-macosx_10_15_x86_64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/infinigram-api/e858271c372d6accf13d12dd46bcca7fefe570ee/vendor/infini_gram-2.5.1-cp313-cp313-macosx_10_15_x86_64.whl -------------------------------------------------------------------------------- /vendor/infini_gram-2.5.1-cp313-cp313-macosx_11_0_arm64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/infinigram-api/e858271c372d6accf13d12dd46bcca7fefe570ee/vendor/infini_gram-2.5.1-cp313-cp313-macosx_11_0_arm64.whl -------------------------------------------------------------------------------- /vendor/llama-2-7b-hf/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "meta-llama/Llama-2-7b-hf", 3 | "architectures": [ 4 | "LlamaForCausalLM" 5 | ], 6 | "attention_bias": false, 7 | 
"attention_dropout": 0.0, 8 | "bos_token_id": 1, 9 | "eos_token_id": 2, 10 | "hidden_act": "silu", 11 | "hidden_size": 4096, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 11008, 14 | "max_position_embeddings": 4096, 15 | "mlp_bias": false, 16 | "model_type": "llama", 17 | "num_attention_heads": 32, 18 | "num_hidden_layers": 32, 19 | "num_key_value_heads": 32, 20 | "pretraining_tp": 1, 21 | "rms_norm_eps": 1e-05, 22 | "rope_scaling": null, 23 | "rope_theta": 10000.0, 24 | "tie_word_embeddings": false, 25 | "torch_dtype": "float16", 26 | "transformers_version": "4.41.2", 27 | "use_cache": true, 28 | "vocab_size": 32000 29 | } 30 | -------------------------------------------------------------------------------- /vendor/llama-2-7b-hf/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "", 4 | "lstrip": false, 5 | "normalized": false, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "", 11 | "lstrip": false, 12 | "normalized": false, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "unk_token": { 17 | "content": "", 18 | "lstrip": false, 19 | "normalized": false, 20 | "rstrip": false, 21 | "single_word": false 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /vendor/llama-2-7b-hf/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": false, 3 | "add_eos_token": false, 4 | "added_tokens_decoder": { 5 | "0": { 6 | "content": "", 7 | "lstrip": false, 8 | "normalized": false, 9 | "rstrip": false, 10 | "single_word": false, 11 | "special": true 12 | }, 13 | "1": { 14 | "content": "", 15 | "lstrip": false, 16 | "normalized": false, 17 | "rstrip": false, 18 | "single_word": false, 19 | "special": true 20 | }, 21 | "2": { 22 | "content": "", 23 | "lstrip": false, 24 | "normalized": false, 25 | "rstrip": false, 26 | "single_word": false, 27 | "special": true 28 | } 29 | }, 30 | "bos_token": "", 31 | "clean_up_tokenization_spaces": false, 32 | "eos_token": "", 33 | "legacy": false, 34 | "model_max_length": 1000000000000000019884624838656, 35 | "pad_token": null, 36 | "padding_side": "right", 37 | "sp_model_kwargs": {}, 38 | "tokenizer_class": "LlamaTokenizer", 39 | "unk_token": "", 40 | "use_default_system_prompt": false 41 | } 42 | -------------------------------------------------------------------------------- /vendor/olmo-7b-hf/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_type": "swiglu", 3 | "alibi": false, 4 | "alibi_bias_max": 8.0, 5 | "architectures": [ 6 | "OLMoForCausalLM" 7 | ], 8 | "attention_dropout": 0.0, 9 | "attention_layer_norm": false, 10 | "attention_layer_norm_with_affine": false, 11 | "bias_for_layer_norm": false, 12 | "block_group_size": 1, 13 | "block_type": "sequential", 14 | "d_model": 4096, 15 | "embedding_dropout": 0.0, 16 | "embedding_size": 50304, 17 | "eos_token_id": 50279, 18 | "flash_attention": true, 19 | "include_bias": false, 20 | "init_cutoff_factor": null, 21 | "init_device": "meta", 22 | "init_fn": "mitchell", 23 | "init_std": 0.02, 24 | "layer_norm_type": "default", 25 | "layer_norm_with_affine": false, 26 | "max_sequence_length": 2048, 27 | "mlp_hidden_size": 22016, 28 | "mlp_ratio": 4, 29 | "model_type": "hf_olmo", 30 | "multi_query_attention": false, 31 | "n_heads": 32, 32 | "n_layers": 32, 33 | "pad_token_id": 1, 
34 | "precision": "amp_bf16", 35 | "residual_dropout": 0.0, 36 | "rope": true, 37 | "rope_full_precision": true, 38 | "scale_logits": false, 39 | "transformers_version": "4.36.2", 40 | "use_cache": true, 41 | "vocab_size": 50280, 42 | "weight_tying": false, 43 | "auto_map": { 44 | "AutoConfig": "configuration_olmo.OLMoConfig", 45 | "AutoModelForCausalLM": "modeling_olmo.OLMoForCausalLM", 46 | "AutoTokenizer": [ 47 | "tokenization_olmo_fast.OLMoTokenizerFast", 48 | "tokenization_olmo_fast.OLMoTokenizerFast" 49 | ] 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /vendor/olmo-7b-hf/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "eos_token": "<|endoftext|>", 3 | "pad_token": "<|padding|>" 4 | } 5 | -------------------------------------------------------------------------------- /vendor/olmo-7b-hf/tokenization_olmo_fast.py: -------------------------------------------------------------------------------- 1 | from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast 2 | -------------------------------------------------------------------------------- /vendor/olmo-7b-hf/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "added_tokens_decoder": { 3 | "0": { 4 | "content": "|||IP_ADDRESS|||", 5 | "lstrip": false, 6 | "normalized": true, 7 | "rstrip": false, 8 | "single_word": false, 9 | "special": false 10 | }, 11 | "1": { 12 | "content": "<|padding|>", 13 | "lstrip": false, 14 | "normalized": false, 15 | "rstrip": false, 16 | "single_word": false, 17 | "special": true 18 | }, 19 | "50254": { 20 | "content": " ", 21 | "lstrip": false, 22 | "normalized": true, 23 | "rstrip": false, 24 | "single_word": false, 25 | "special": false 26 | }, 27 | "50255": { 28 | "content": " ", 29 | "lstrip": false, 30 | "normalized": true, 31 | "rstrip": false, 32 | "single_word": false, 33 | "special": false 34 | }, 35 | "50256": { 36 | "content": " ", 37 | "lstrip": false, 38 | "normalized": true, 39 | "rstrip": false, 40 | "single_word": false, 41 | "special": false 42 | }, 43 | "50257": { 44 | "content": " ", 45 | "lstrip": false, 46 | "normalized": true, 47 | "rstrip": false, 48 | "single_word": false, 49 | "special": false 50 | }, 51 | "50258": { 52 | "content": " ", 53 | "lstrip": false, 54 | "normalized": true, 55 | "rstrip": false, 56 | "single_word": false, 57 | "special": false 58 | }, 59 | "50259": { 60 | "content": " ", 61 | "lstrip": false, 62 | "normalized": true, 63 | "rstrip": false, 64 | "single_word": false, 65 | "special": false 66 | }, 67 | "50260": { 68 | "content": " ", 69 | "lstrip": false, 70 | "normalized": true, 71 | "rstrip": false, 72 | "single_word": false, 73 | "special": false 74 | }, 75 | "50261": { 76 | "content": " ", 77 | "lstrip": false, 78 | "normalized": true, 79 | "rstrip": false, 80 | "single_word": false, 81 | "special": false 82 | }, 83 | "50262": { 84 | "content": " ", 85 | "lstrip": false, 86 | "normalized": true, 87 | "rstrip": false, 88 | "single_word": false, 89 | "special": false 90 | }, 91 | "50263": { 92 | "content": " ", 93 | "lstrip": false, 94 | "normalized": true, 95 | "rstrip": false, 96 | "single_word": false, 97 | "special": false 98 | }, 99 | "50264": { 100 | "content": " ", 101 | "lstrip": false, 102 | "normalized": true, 103 | "rstrip": false, 104 | "single_word": false, 105 | "special": false 106 | }, 107 | "50265": { 108 | "content": " ", 109 | "lstrip": false, 110 | 
"normalized": true, 111 | "rstrip": false, 112 | "single_word": false, 113 | "special": false 114 | }, 115 | "50266": { 116 | "content": " ", 117 | "lstrip": false, 118 | "normalized": true, 119 | "rstrip": false, 120 | "single_word": false, 121 | "special": false 122 | }, 123 | "50267": { 124 | "content": " ", 125 | "lstrip": false, 126 | "normalized": true, 127 | "rstrip": false, 128 | "single_word": false, 129 | "special": false 130 | }, 131 | "50268": { 132 | "content": " ", 133 | "lstrip": false, 134 | "normalized": true, 135 | "rstrip": false, 136 | "single_word": false, 137 | "special": false 138 | }, 139 | "50269": { 140 | "content": " ", 141 | "lstrip": false, 142 | "normalized": true, 143 | "rstrip": false, 144 | "single_word": false, 145 | "special": false 146 | }, 147 | "50270": { 148 | "content": " ", 149 | "lstrip": false, 150 | "normalized": true, 151 | "rstrip": false, 152 | "single_word": false, 153 | "special": false 154 | }, 155 | "50271": { 156 | "content": " ", 157 | "lstrip": false, 158 | "normalized": true, 159 | "rstrip": false, 160 | "single_word": false, 161 | "special": false 162 | }, 163 | "50272": { 164 | "content": " ", 165 | "lstrip": false, 166 | "normalized": true, 167 | "rstrip": false, 168 | "single_word": false, 169 | "special": false 170 | }, 171 | "50273": { 172 | "content": " ", 173 | "lstrip": false, 174 | "normalized": true, 175 | "rstrip": false, 176 | "single_word": false, 177 | "special": false 178 | }, 179 | "50274": { 180 | "content": " ", 181 | "lstrip": false, 182 | "normalized": true, 183 | "rstrip": false, 184 | "single_word": false, 185 | "special": false 186 | }, 187 | "50275": { 188 | "content": " ", 189 | "lstrip": false, 190 | "normalized": true, 191 | "rstrip": false, 192 | "single_word": false, 193 | "special": false 194 | }, 195 | "50276": { 196 | "content": " ", 197 | "lstrip": false, 198 | "normalized": true, 199 | "rstrip": false, 200 | "single_word": false, 201 | "special": false 202 | }, 203 | "50277": { 204 | "content": "|||EMAIL_ADDRESS|||", 205 | "lstrip": false, 206 | "normalized": true, 207 | "rstrip": false, 208 | "single_word": false, 209 | "special": false 210 | }, 211 | "50278": { 212 | "content": "|||PHONE_NUMBER|||", 213 | "lstrip": false, 214 | "normalized": true, 215 | "rstrip": false, 216 | "single_word": false, 217 | "special": false 218 | }, 219 | "50279": { 220 | "content": "<|endoftext|>", 221 | "lstrip": false, 222 | "normalized": false, 223 | "rstrip": false, 224 | "single_word": false, 225 | "special": true 226 | } 227 | }, 228 | "clean_up_tokenization_spaces": true, 229 | "eos_token": "<|endoftext|>", 230 | "max_length": null, 231 | "model_max_length": 1000000000000000019884624838656, 232 | "pad_token": "<|padding|>", 233 | "tokenizer_class": "OLMoTokenizer", 234 | "truncation": "right", 235 | "auto_map": { 236 | "AutoConfig": "configuration_olmo.OLMoConfig", 237 | "AutoTokenizer": [ 238 | "tokenization_olmo_fast.OLMoTokenizerFast", 239 | "tokenization_olmo_fast.OLMoTokenizerFast" 240 | ] 241 | } 242 | } -------------------------------------------------------------------------------- /volume-claims/olmoe-mix-0924-dclm.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolume 3 | metadata: 4 | name: infinigram-olmoe-mix-0924-dclm 5 | spec: 6 | storageClassName: "infinigram" 7 | capacity: 8 | storage: 39Ti 9 | accessModes: 10 | - ReadOnlyMany 11 | claimRef: 12 | namespace: infinigram-api 13 | name: infinigram-olmoe-mix-0924-dclm 14 | 
csi: 15 | driver: pd.csi.storage.gke.io 16 | volumeHandle: projects/ai2-reviz/zones/us-west1-b/disks/infini-gram-olmoe-mix-0924-dclm 17 | fsType: ext4 18 | readOnly: true 19 | --- 20 | apiVersion: v1 21 | kind: PersistentVolumeClaim 22 | metadata: 23 | namespace: infinigram-api 24 | name: infinigram-olmoe-mix-0924-dclm 25 | spec: 26 | storageClassName: "infinigram" 27 | volumeName: infinigram-olmoe-mix-0924-dclm 28 | accessModes: 29 | - ReadOnlyMany 30 | resources: 31 | requests: 32 | storage: 39Ti 33 | -------------------------------------------------------------------------------- /volume-claims/olmoe-mix-0924-nodclm.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolume 3 | metadata: 4 | name: infinigram-olmoe-mix-0924-nodclm 5 | spec: 6 | storageClassName: "infinigram" 7 | capacity: 8 | storage: 2Ti 9 | accessModes: 10 | - ReadOnlyMany 11 | claimRef: 12 | namespace: infinigram-api 13 | name: infinigram-olmoe-mix-0924-nodclm 14 | csi: 15 | driver: pd.csi.storage.gke.io 16 | volumeHandle: projects/ai2-reviz/zones/us-west1-b/disks/infini-gram-olmoe-mix-0924-nodclm 17 | fsType: ext4 18 | readOnly: true 19 | --- 20 | apiVersion: v1 21 | kind: PersistentVolumeClaim 22 | metadata: 23 | namespace: infinigram-api 24 | name: infinigram-olmoe-mix-0924-nodclm 25 | spec: 26 | storageClassName: "infinigram" 27 | volumeName: infinigram-olmoe-mix-0924-nodclm 28 | accessModes: 29 | - ReadOnlyMany 30 | resources: 31 | requests: 32 | storage: 2Ti 33 | -------------------------------------------------------------------------------- /volume-claims/pileval-gpt2.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: storage.k8s.io/v1 2 | kind: StorageClass 3 | metadata: 4 | name: infinigram 5 | provisioner: pd.csi.storage.gke.io 6 | volumeBindingMode: WaitForFirstConsumer 7 | allowVolumeExpansion: true 8 | parameters: 9 | type: pd-balanced 10 | --- 11 | apiVersion: v1 12 | kind: PersistentVolume 13 | metadata: 14 | name: infinigram-pileval-gpt2 15 | spec: 16 | storageClassName: "infinigram" 17 | capacity: 18 | storage: 10Gi 19 | accessModes: 20 | - ReadOnlyMany 21 | claimRef: 22 | namespace: infinigram-api 23 | name: infinigram-pileval-gpt2 24 | csi: 25 | driver: pd.csi.storage.gke.io 26 | volumeHandle: projects/ai2-reviz/zones/us-west1-b/disks/infinigram-pileval-test 27 | fsType: ext4 28 | readOnly: true 29 | --- 30 | apiVersion: v1 31 | kind: PersistentVolumeClaim 32 | metadata: 33 | namespace: infinigram-api 34 | name: infinigram-pileval-gpt2 35 | spec: 36 | storageClassName: "infinigram" 37 | volumeName: infinigram-pileval-gpt2 38 | accessModes: 39 | - ReadOnlyMany 40 | resources: 41 | requests: 42 | storage: 10Gi 43 | -------------------------------------------------------------------------------- /volume-claims/v4-olmo-2-0325-32b-anneal-adapt.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolume 3 | metadata: 4 | name: infinigram-v4-olmo-2-0325-32b-anneal-adapt 5 | spec: 6 | storageClassName: "infinigram" 7 | capacity: 8 | storage: 300Gi 9 | accessModes: 10 | - ReadOnlyMany 11 | claimRef: 12 | namespace: infinigram-api 13 | name: infinigram-v4-olmo-2-0325-32b-anneal-adapt 14 | csi: 15 | driver: pd.csi.storage.gke.io 16 | volumeHandle: projects/ai2-reviz/zones/us-west1-b/disks/$DISK_NAME 17 | fsType: ext4 18 | readOnly: true 19 | --- 20 | apiVersion: v1 21 | kind: PersistentVolumeClaim 22 | metadata: 
23 | namespace: infinigram-api 24 | name: infinigram-v4-olmo-2-0325-32b-anneal-adapt 25 | spec: 26 | storageClassName: "infinigram" 27 | volumeName: infinigram-v4-olmo-2-0325-32b-anneal-adapt 28 | accessModes: 29 | - ReadOnlyMany 30 | resources: 31 | requests: 32 | storage: 300Gi -------------------------------------------------------------------------------- /volume-claims/v4-olmo-2-1124-13b-anneal-adapt.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolume 3 | metadata: 4 | name: infinigram-v4-olmo-2-1124-13b-anneal-adapt 5 | spec: 6 | storageClassName: "infinigram" 7 | capacity: 8 | storage: 300Gi 9 | accessModes: 10 | - ReadOnlyMany 11 | claimRef: 12 | namespace: infinigram-api 13 | name: infinigram-v4-olmo-2-1124-13b-anneal-adapt 14 | csi: 15 | driver: pd.csi.storage.gke.io 16 | volumeHandle: projects/ai2-reviz/zones/us-west1-b/disks/$DISK_NAME 17 | fsType: ext4 18 | readOnly: true 19 | --- 20 | apiVersion: v1 21 | kind: PersistentVolumeClaim 22 | metadata: 23 | namespace: infinigram-api 24 | name: infinigram-v4-olmo-2-1124-13b-anneal-adapt 25 | spec: 26 | storageClassName: "infinigram" 27 | volumeName: infinigram-v4-olmo-2-1124-13b-anneal-adapt 28 | accessModes: 29 | - ReadOnlyMany 30 | resources: 31 | requests: 32 | storage: 300Gi -------------------------------------------------------------------------------- /volume-claims/v4-olmoe-0125-1b-7b-anneal-adapt.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolume 3 | metadata: 4 | name: infinigram-v4-olmoe-0125-1b-7b-anneal-adapt 5 | spec: 6 | storageClassName: "infinigram" 7 | capacity: 8 | storage: 300Gi 9 | accessModes: 10 | - ReadOnlyMany 11 | claimRef: 12 | namespace: infinigram-api 13 | name: infinigram-v4-olmoe-0125-1b-7b-anneal-adapt 14 | csi: 15 | driver: pd.csi.storage.gke.io 16 | volumeHandle: projects/ai2-reviz/zones/us-west1-b/disks/$DISK_NAME 17 | fsType: ext4 18 | readOnly: true 19 | --- 20 | apiVersion: v1 21 | kind: PersistentVolumeClaim 22 | metadata: 23 | namespace: infinigram-api 24 | name: infinigram-v4-olmoe-0125-1b-7b-anneal-adapt 25 | spec: 26 | storageClassName: "infinigram" 27 | volumeName: infinigram-v4-olmoe-0125-1b-7b-anneal-adapt 28 | accessModes: 29 | - ReadOnlyMany 30 | resources: 31 | requests: 32 | storage: 300Gi -------------------------------------------------------------------------------- /volume-claims/v4-tulu-3-405b-adapt-llama.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolume 3 | metadata: 4 | name: infinigram-v4-tulu-3-405b-adapt-llama 5 | spec: 6 | storageClassName: "infinigram" 7 | capacity: 8 | storage: 11Gi 9 | accessModes: 10 | - ReadOnlyMany 11 | claimRef: 12 | namespace: infinigram-api 13 | name: infinigram-v4-tulu-3-405b-adapt-llama 14 | csi: 15 | driver: pd.csi.storage.gke.io 16 | volumeHandle: projects/ai2-reviz/zones/us-west1-b/disks/infinigram-v4-tulu-3-405b-adapt-llama 17 | fsType: ext4 18 | readOnly: true 19 | --- 20 | apiVersion: v1 21 | kind: PersistentVolumeClaim 22 | metadata: 23 | namespace: infinigram-api 24 | name: infinigram-v4-tulu-3-405b-adapt-llama 25 | spec: 26 | storageClassName: "infinigram" 27 | volumeName: infinigram-v4-tulu-3-405b-adapt-llama 28 | accessModes: 29 | - ReadOnlyMany 30 | resources: 31 | requests: 32 | storage: 11Gi 33 | -------------------------------------------------------------------------------- 
/volume-claims/v4-tulu-3-70b-adapt-llama.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolume 3 | metadata: 4 | name: infinigram-v4-tulu-3-70b-adapt-llama 5 | spec: 6 | storageClassName: "infinigram" 7 | capacity: 8 | storage: 11Gi 9 | accessModes: 10 | - ReadOnlyMany 11 | claimRef: 12 | namespace: infinigram-api 13 | name: infinigram-v4-tulu-3-70b-adapt-llama 14 | csi: 15 | driver: pd.csi.storage.gke.io 16 | volumeHandle: projects/ai2-reviz/zones/us-west1-b/disks/infinigram-v4-tulu-3-70b-adapt-llama 17 | fsType: ext4 18 | readOnly: true 19 | --- 20 | apiVersion: v1 21 | kind: PersistentVolumeClaim 22 | metadata: 23 | namespace: infinigram-api 24 | name: infinigram-v4-tulu-3-70b-adapt-llama 25 | spec: 26 | storageClassName: "infinigram" 27 | volumeName: infinigram-v4-tulu-3-70b-adapt-llama 28 | accessModes: 29 | - ReadOnlyMany 30 | resources: 31 | requests: 32 | storage: 11Gi 33 | -------------------------------------------------------------------------------- /volume-claims/v4-tulu-3-8b-adapt-llama.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolume 3 | metadata: 4 | name: infinigram-v4-tulu-3-8b-adapt-llama 5 | spec: 6 | storageClassName: "infinigram" 7 | capacity: 8 | storage: 10Gi 9 | accessModes: 10 | - ReadOnlyMany 11 | claimRef: 12 | namespace: infinigram-api 13 | name: infinigram-v4-tulu-3-8b-adapt-llama 14 | csi: 15 | driver: pd.csi.storage.gke.io 16 | volumeHandle: projects/ai2-reviz/zones/us-west1-b/disks/infinigram-v4-tulu-3-8b-adapt-llama 17 | fsType: ext4 18 | readOnly: true 19 | --- 20 | apiVersion: v1 21 | kind: PersistentVolumeClaim 22 | metadata: 23 | namespace: infinigram-api 24 | name: infinigram-v4-tulu-3-8b-adapt-llama 25 | spec: 26 | storageClassName: "infinigram" 27 | volumeName: infinigram-v4-tulu-3-8b-adapt-llama 28 | accessModes: 29 | - ReadOnlyMany 30 | resources: 31 | requests: 32 | storage: 10Gi 33 | -------------------------------------------------------------------------------- /volume-claims/writer-pod.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Pod 3 | metadata: 4 | name: infini-gram-writer 5 | spec: 6 | containers: 7 | - name: infini-gram-writer 8 | image: nginx:1.27.0-alpine 9 | volumeMounts: 10 | - mountPath: /mnt/infini-gram-array/ 11 | name: infinigram-array 12 | 13 | volumes: 14 | - name: infinigram-array 15 | persistentVolumeClaim: 16 | claimName: infinigram-dolma-1-7-writer 17 | --- 18 | apiVersion: v1 19 | kind: PersistentVolume 20 | metadata: 21 | name: infinigram-dolma-1-7-writer 22 | spec: 23 | storageClassName: "infinigram" 24 | capacity: 25 | storage: 22Ti 26 | accessModes: 27 | - ReadWriteOnce 28 | claimRef: 29 | namespace: infinigram-api 30 | name: infinigram-dolma-1-7-writer 31 | csi: 32 | driver: pd.csi.storage.gke.io 33 | volumeHandle: projects/ai2-reviz/zones/us-west1-b/disks/infini-gram-dolma-1-7 34 | fsType: ext4 35 | --- 36 | apiVersion: v1 37 | kind: PersistentVolumeClaim 38 | metadata: 39 | namespace: infinigram-api 40 | name: infinigram-dolma-1-7-writer 41 | spec: 42 | storageClassName: "infinigram" 43 | volumeName: infinigram-dolma-1-7-writer 44 | accessModes: 45 | - ReadWriteOnce 46 | resources: 47 | requests: 48 | storage: 22Ti 49 | --------------------------------------------------------------------------------