├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENCE ├── Makefile ├── README.md ├── compose ├── dev │ ├── .env │ ├── docker-compose.nnsight.yml │ ├── docker-compose.yml │ ├── ray_config.yml │ └── service_config.yml └── prod │ ├── .env │ ├── api-start.sh │ ├── docker-compose-worker.yml │ ├── docker-compose.yml │ ├── ray_config.yml │ └── service_config.yml ├── docker ├── dockerfile.base ├── dockerfile.conda ├── dockerfile.service └── helpers │ ├── check_and_update_env.sh │ └── nns_inst.sh ├── scripts ├── redeploy.py └── test.py ├── src ├── common │ ├── logging │ │ ├── __init__.py │ │ └── logger.py │ ├── metrics │ │ ├── __init__.py │ │ ├── gpu_mem.py │ │ ├── metric.py │ │ ├── network_data.py │ │ ├── request_execution_time.py │ │ ├── request_response_size.py │ │ ├── request_status.py │ │ └── request_transport_latency.py │ └── schema │ │ ├── __init__.py │ │ ├── mixins.py │ │ ├── request.py │ │ ├── response.py │ │ └── result.py └── services │ ├── api │ ├── environment.yml │ ├── src │ │ ├── __init__.py │ │ ├── api_key.py │ │ ├── app.py │ │ ├── gunicorn.conf.py │ │ ├── logging │ │ ├── metrics │ │ ├── schema │ │ └── util.py │ └── start.sh │ ├── base │ └── environment.yml │ └── ray │ ├── environment.yml │ ├── src │ ├── __init__.py │ ├── logging │ ├── metrics │ ├── ray │ │ ├── __init__.py │ │ ├── config │ │ │ ├── ray_config.yml │ │ │ └── service_config.yml │ │ ├── deployments │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── controller.py │ │ │ ├── distributed_model.py │ │ │ ├── model.py │ │ │ ├── model │ │ │ │ ├── __init__.py │ │ │ │ └── base.py │ │ │ ├── protocols.py │ │ │ └── request.py │ │ ├── distributed │ │ │ ├── __init__.py │ │ │ ├── parallel_dims.py │ │ │ ├── tensor_parallelism │ │ │ │ ├── __init__.py │ │ │ │ ├── plans │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── llama.py │ │ │ │ └── test.py │ │ │ └── util.py │ │ ├── raystate.py │ │ ├── resources.py │ │ └── util.py │ └── schema │ ├── start-worker.sh │ └── start.sh └── telemetry ├── grafana ├── dashboards │ └── telemetry.json └── provisioning │ ├── dashboards │ └── telemetry.yml │ └── datasources │ ├── influxdb.yml │ └── prometheus.yml └── prometheus └── prometheus.yml /.gitignore: -------------------------------------------------------------------------------- 1 | .config 2 | __pycache__/ 3 | creds.json -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributor Covenant Code of Conduct 3 | 4 | ## Our Pledge 5 | 6 | We as members, contributors, and leaders pledge to make participation in our 7 | community a harassment-free experience for everyone, regardless of age, body 8 | size, visible or invisible disability, ethnicity, sex characteristics, gender 9 | identity and expression, level of experience, education, socio-economic status, 10 | nationality, personal appearance, race, caste, color, religion, or sexual 11 | identity and orientation. 12 | 13 | We pledge to act and interact in ways that contribute to an open, welcoming, 14 | diverse, inclusive, and healthy community. 
15 | 16 | ## Our Standards 17 | 18 | Examples of behavior that contributes to a positive environment for our 19 | community include: 20 | 21 | * Demonstrating empathy and kindness toward other people 22 | * Being respectful of differing opinions, viewpoints, and experiences 23 | * Giving and gracefully accepting constructive feedback 24 | * Accepting responsibility and apologizing to those affected by our mistakes, 25 | and learning from the experience 26 | * Focusing on what is best not just for us as individuals, but for the overall 27 | community 28 | 29 | Examples of unacceptable behavior include: 30 | 31 | * The use of sexualized language or imagery, and sexual attention or advances of 32 | any kind 33 | * Trolling, insulting or derogatory comments, and personal or political attacks 34 | * Public or private harassment 35 | * Publishing others' private information, such as a physical or email address, 36 | without their explicit permission 37 | * Other conduct which could reasonably be considered inappropriate in a 38 | professional setting 39 | 40 | ## Enforcement Responsibilities 41 | 42 | Community leaders are responsible for clarifying and enforcing our standards of 43 | acceptable behavior and will take appropriate and fair corrective action in 44 | response to any behavior that they deem inappropriate, threatening, offensive, 45 | or harmful. 46 | 47 | Community leaders have the right and responsibility to remove, edit, or reject 48 | comments, commits, code, wiki edits, issues, and other contributions that are 49 | not aligned to this Code of Conduct, and will communicate reasons for moderation 50 | decisions when appropriate. 51 | 52 | ## Scope 53 | 54 | This Code of Conduct applies within all community spaces, and also applies when 55 | an individual is officially representing the community in public spaces. 56 | Examples of representing our community include using an official email address, 57 | posting via an official social media account, or acting as an appointed 58 | representative at an online or offline event. 59 | 60 | ## Enforcement 61 | 62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 63 | reported to the community leaders responsible for enforcement at j.bell@northeastern.edu . 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series of 86 | actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. 
Violating these terms may lead to a temporary or permanent 93 | ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within the 113 | community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.1, available at 119 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 126 | [https://www.contributor-covenant.org/translations][translations]. 127 | 128 | [homepage]: https://www.contributor-covenant.org 129 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 130 | [Mozilla CoC]: https://github.com/mozilla/diversity 131 | [FAQ]: https://www.contributor-covenant.org/faq 132 | [translations]: https://www.contributor-covenant.org/translations 133 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-2024 Northeastern University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | -include .config 2 | 3 | IP_ADDR := $(shell hostname -I | awk '{print $$1}') 4 | N_DEVICES := $(shell command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi -L | wc -l || echo 0) 5 | 6 | # Treat "up", "down" and "ta" as targets (not files) 7 | .PHONY: up down ta 8 | 9 | # Define valid environments 10 | VALID_ENVS := dev prod delta 11 | 12 | # Default environment 13 | DEFAULT_ENV ?= dev 14 | 15 | # Configs for local nnsight installation 16 | DEV_NNS ?= False 17 | NNS_PATH ?= ~/nnsight 18 | TAG ?= latest 19 | 20 | # Function to check if the environment is valid 21 | check_env = $(if $(filter $(1),$(VALID_ENVS)),,$(error Invalid environment '$(1)'. Use one of: $(VALID_ENVS))) 22 | 23 | # Function to set environment and print message if no environment was specified 24 | set_env = $(eval ENV := $(if $(filter $(words $(MAKECMDGOALS)),1),$(DEFAULT_ENV),$(word 2,$(MAKECMDGOALS)))) \ 25 | $(if $(filter $(words $(MAKECMDGOALS)),1),$(info Using default environment: $(DEFAULT_ENV)),) 26 | 27 | build_base: 28 | docker build --no-cache -t ndif_base:$(TAG) -f docker/dockerfile.base . 29 | 30 | build_conda: 31 | docker build --no-cache --build-arg NAME=$(NAME) --build-arg TAG=$(TAG) -t $(NAME)_conda:$(TAG) -f docker/dockerfile.conda . 32 | 33 | build_service: 34 | cp docker/helpers/check_and_update_env.sh ./ 35 | tar -hczvf src.tar.gz --directory=src/services/$(NAME) src 36 | docker build --no-cache --build-arg NAME=$(NAME) --build-arg TAG=$(TAG) -t $(NAME):$(TAG) -f docker/dockerfile.service . 37 | rm src.tar.gz 38 | rm check_and_update_env.sh 39 | 40 | build_all_base: 41 | $(call set_env) 42 | $(call check_env,$(ENV)) 43 | make build_base 44 | 45 | build_all_conda: 46 | $(call set_env) 47 | $(call check_env,$(ENV)) 48 | make build_conda NAME=api 49 | make build_conda NAME=ray 50 | 51 | 52 | build_all_service: 53 | $(call set_env) 54 | $(call check_env,$(ENV)) 55 | make build_service NAME=api 56 | make build_service NAME=ray 57 | 58 | 59 | build: 60 | $(call set_env) 61 | $(call check_env,$(ENV)) 62 | make build_all_base 63 | make build_all_conda 64 | make build_all_service 65 | make up $(ENV) 66 | 67 | up: 68 | $(call set_env) 69 | $(call check_env,$(ENV)) 70 | @if [ "$(ENV)" = "dev" ] && [ "$(DEV_NNS)" = "True" ]; then \ 71 | export HOST_IP=$(IP_ADDR) N_DEVICES=$(N_DEVICES) NNS_PATH=$(NNS_PATH) && \ 72 | docker compose -f compose/dev/docker-compose.yml -f compose/dev/docker-compose.nnsight.yml up --detach; \ 73 | else \ 74 | export HOST_IP=$(IP_ADDR) N_DEVICES=$(N_DEVICES) NNS_PATH=$(NNS_PATH) && \ 75 | docker compose -f compose/$(ENV)/docker-compose.yml up --detach; \ 76 | fi 77 | 78 | down: 79 | $(call set_env) 80 | $(call check_env,$(ENV)) 81 | export HOST_IP=${IP_ADDR} N_DEVICES=${N_DEVICES} && docker compose -f compose/$(ENV)/docker-compose.yml down 82 | 83 | ta: 84 | $(call set_env) 85 | $(call check_env,$(ENV)) 86 | make down $(ENV) 87 | make build_all_service 88 | make up $(ENV) 89 | 90 | save-vars: 91 | @echo "DEFAULT_ENV=$(DEFAULT_ENV)" > .config 92 | @echo "DEV_NNS=$(DEV_NNS)" >> .config 93 | @echo "NNS_PATH=$(NNS_PATH)" >> .config 94 | 95 | reset-vars: 96 | @rm ./.config 97 | 98 | # Consumes the second argument (e.g. 'dev', 'prod') so it doesn't cause an error. 
99 | %: 100 | @: 101 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NDIF Development Guide 2 | 3 | This guide explains how to set up a development environment, install dependencies, and get started with contributing to the `NDIF` project. 4 | 5 | ## Prerequisites 6 | 7 | - Python 3.10 8 | - Docker 9 | - Docker Compose 10 | 11 | 12 | ## Setup 13 | 14 | ### 1. Install Conda 15 | If you don’t have Conda installed, download and install Anaconda or Miniconda from the [official Conda website](https://docs.conda.io/en/latest/miniconda.html). 16 | 17 | ### 2. Create Conda Environment 18 | 19 | Fork the `NDIF` repository (or clone it directly) to your local machine. Then create a new Conda virtual environment: 20 | 21 | ```sh 22 | conda create -n ndif-dev python=3.10 23 | conda activate ndif-dev 24 | ``` 25 | 26 | ### 3. Install NNsight 27 | 28 | Choose one of the following methods: 29 | 30 | a. Via pip (simple) 31 | 32 | ```sh 33 | pip install nnsight 34 | ``` 35 | 36 | b. From the repository (recommended for specific branches) 37 | 38 | ```sh 39 | git clone https://github.com/nnsight/nnsight.git 40 | cd nnsight 41 | git checkout <branch>  # e.g., 0.3 42 | pip install -e . 43 | ``` 44 | 45 | ## Building and Running `NDIF` 46 | 47 | ### 1. Build and start the development environment 48 | 49 | For first-time setup, use: 50 | 51 | ```sh 52 | make build 53 | ``` 54 | 55 | If you’ve made changes to the codebase but did not modify the `environment.yml` files, you can quickly rebuild the services using: 56 | 57 | ```sh 58 | make ta 59 | ``` 60 | 61 | This method is faster than running `make build` again. 62 | 63 | ### 2. Verify server status 64 | 65 | After building the `NDIF` containers, you can check the Docker logs to verify that the services are running correctly: 66 | ```sh 67 | docker logs dev-api-1 68 | ``` 69 | You should see a message like `Application startup complete.` in the API service log. 70 | 71 | ### 3. Run tests 72 | 73 | ```sh 74 | python scripts/test.py 75 | ``` 76 | 77 | This will send a test NNsight request to the API service running in the local container. 78 | 79 | ## Additional Commands 80 | 81 | - To start the development environment without rebuilding: 82 | 83 | ```sh 84 | make up 85 | ``` 86 | 87 | - To stop the development environment: 88 | 89 | ```sh 90 | make down 91 | ``` 92 | 93 | - To rebuild services and restart the environment (useful during development): 94 | 95 | ```sh 96 | make ta 97 | ``` 98 | 99 | _Note: Modifying any of the `environment.yml` files will require you to rebuild from scratch._ 100 | 101 | ## Environment Configuration 102 | 103 | The project uses separate `.env` files for development and production environments: 104 | 105 | - Development: `compose/dev/.env` 106 | - Production: `compose/prod/.env` 107 | 108 | For most users, only the development environment is necessary. The production environment is configured separately and is not required for local development. 109 | 110 | ### Note 111 | 112 | The Makefile includes configurations for both development and production environments. As an end user or developer, you'll primarily interact with the development environment. The production environment settings are managed separately and are not typically needed for local development work.
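## Example: Sending a Remote Request

For reference, `scripts/test.py` boils down to the client-side snippet below. This is a minimal sketch: the port (`5001`) and model (`openai-community/gpt2`) assume the defaults in `compose/dev/.env` and `compose/dev/service_config.yml`, and the placeholder API key mirrors the one used in `scripts/test.py`.

```py
import nnsight

# Point the nnsight client at the local dev API instead of the public NDIF endpoint.
nnsight.CONFIG.set_default_api_key("api key")  # placeholder key, as in scripts/test.py
nnsight.CONFIG.API.HOST = "localhost:5001"     # DEV_API_PORT from compose/dev/.env
nnsight.CONFIG.API.SSL = False

model = nnsight.LanguageModel("openai-community/gpt2")

# remote=True routes the traced computation through the API service in the container.
with model.trace("The Eiffel Tower is located in ", remote=True):
    output = model.output.save()

print(output)
```

If this prints the model output, the round trip through the API and Ray services is working.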
-------------------------------------------------------------------------------- /compose/dev/.env: -------------------------------------------------------------------------------- 1 | # Broker Ports 2 | DEV_BROKER_PORT=6379 3 | BROKER_INTERNAL_PORT=6379 4 | BROKER_PROTOCOL=redis:// 5 | 6 | # MinIO Ports 7 | DEV_MINIO_PORT=27018 8 | MINIO_INTERNAL_PORT=9000 9 | 10 | # Ray Ports 11 | DEV_RAY_HEAD_PORT=6380 12 | RAY_HEAD_INTERNAL_PORT=6379 13 | 14 | DEV_RAY_CLIENT_PORT=9998 15 | RAY_CLIENT_INTERNAL_PORT=10001 16 | 17 | DEV_RAY_DASHBOARD_PORT=8266 18 | RAY_DASHBOARD_INTERNAL_PORT=8265 19 | 20 | DEV_RAY_SERVE_PORT=8267 21 | RAY_SERVE_INTERNAL_PORT=8267 22 | 23 | RAY_DASHBOARD_GRPC_PORT=8268 24 | OBJECT_MANAGER_PORT=8076 #Raylet port for object manager 25 | 26 | # API Ports 27 | DEV_API_PORT=5001 28 | API_INTERNAL_PORT=8001 29 | 30 | # Prometheus Ports 31 | DEV_PROMETHEUS_PORT=9090 32 | PROMETHEUS_INTERNAL_PORT=9090 33 | 34 | # Grafana Ports 35 | DEV_GRAFANA_PORT=3000 36 | GRAFANA_INTERNAL_PORT=3000 37 | 38 | # influxDB Ports 39 | DEV_INFLUXDB_PORT=8086 40 | INFLUXDB_INTERNAL_PORT=8086 41 | 42 | # Loki Ports 43 | DEV_LOKI_PORT=3100 44 | LOKI_INTERNAL_PORT=3100 45 | 46 | # Device Configuration 47 | N_DEVICES=$N_DEVICES 48 | 49 | # Credentials and Other Configs 50 | GRAFANA_ADMIN_USER=admin 51 | GRAFANA_ADMIN_PASSWORD=admin 52 | HOST_IP=$HOST_IP 53 | RAY_METRICS_GAUGE_EXPORT_INTERVAL_MS=1000 # Gauge export interval in ms 54 | RAY_DASHBOARD_HOST=0.0.0.0 55 | 56 | RAY_SERVE_QUEUE_LENGTH_RESPONSE_DEADLINE_S=10 57 | 58 | #InfluxDB Configs 59 | INFLUXDB_ORG=NDIF 60 | INFLUXDB_BUCKET=data 61 | SECRET_INFLUXDB_ADMIN_USERNAME=admin 62 | SECRET_INFLUXDB_ADMIN_PASSWORD=adminadmin 63 | SECRET_INFLUXDB_ADMIN_TOKEN=njkldhsfbdsfkl2o32== 64 | 65 | POSTGRES_HOST='' 66 | POSTGRES_PORT='5432' 67 | POSTGRES_DB='accounts' 68 | POSTGRES_USER='postgres' 69 | POSTGRES_PASSWORD='postgres' 70 | -------------------------------------------------------------------------------- /compose/dev/docker-compose.nnsight.yml: -------------------------------------------------------------------------------- 1 | services: 2 | api: 3 | command: ["bash", "./src/nns_inst.sh"] 4 | volumes: 5 | - ../../docker/helpers/nns_inst.sh:/src/nns_inst.sh 6 | - ${NNS_PATH}:/nnsight 7 | 8 | ray-head: 9 | command: ["bash", "./src/nns_inst.sh"] 10 | volumes: 11 | - ../../docker/helpers/nns_inst.sh:/src/nns_inst.sh 12 | - ${NNS_PATH}:/nnsight 13 | -------------------------------------------------------------------------------- /compose/dev/docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | message_broker: 3 | image: redis:latest 4 | ports: 5 | - ${DEV_BROKER_PORT}:${BROKER_INTERNAL_PORT} 6 | 7 | minio: 8 | image: minio/minio:latest 9 | command: server /data 10 | ports: 11 | - ${DEV_MINIO_PORT}:${MINIO_INTERNAL_PORT} 12 | 13 | ray: 14 | image: ray:latest 15 | shm_size: '15gb' 16 | volumes: 17 | - ~/.cache/huggingface/hub/:/root/.cache/huggingface/hub 18 | - ./service_config.yml:/src/ray/config/service_config.yml 19 | - ./ray_config.yml:/src/ray/config/ray_config.yml 20 | - ../../src/services/ray/start.sh:/start.sh 21 | - ray-data:/tmp/ray/ 22 | ports: 23 | - ${DEV_RAY_HEAD_PORT}:${RAY_HEAD_INTERNAL_PORT} 24 | - ${DEV_RAY_CLIENT_PORT}:${RAY_CLIENT_INTERNAL_PORT} 25 | - ${DEV_RAY_DASHBOARD_PORT}:${RAY_DASHBOARD_INTERNAL_PORT} 26 | - ${DEV_RAY_SERVE_PORT}:${RAY_SERVE_INTERNAL_PORT} 27 | deploy: 28 | replicas: 1 29 | resources: 30 | reservations: 31 | devices: 32 | - driver: nvidia 33 | 
count: ${N_DEVICES} 34 | capabilities: [ gpu ] 35 | environment: 36 | - LOKI_URL=http://${HOST_IP}:${DEV_LOKI_PORT}/loki/api/v1/push 37 | - OBJECT_STORE_URL=${HOST_IP}:${DEV_MINIO_PORT} 38 | - API_URL=http://${HOST_IP}:${DEV_API_PORT} 39 | - INFLUXDB_ADDRESS=http://${HOST_IP}:${DEV_INFLUXDB_PORT} 40 | - INFLUXDB_ADMIN_TOKEN=${SECRET_INFLUXDB_ADMIN_TOKEN} 41 | - INFLUXDB_ORG=${INFLUXDB_ORG} 42 | - INFLUXDB_BUCKET=${INFLUXDB_BUCKET} 43 | env_file: 44 | - .env 45 | 46 | api: 47 | image: api:latest 48 | ports: 49 | - ${DEV_API_PORT}:${API_INTERNAL_PORT} 50 | environment: 51 | OBJECT_STORE_URL: ${HOST_IP}:${DEV_MINIO_PORT} 52 | BROKER_URL: ${BROKER_PROTOCOL}@${HOST_IP}:${DEV_BROKER_PORT}/ 53 | WORKERS: 1 54 | RAY_ADDRESS: ray://${HOST_IP}:${DEV_RAY_CLIENT_PORT} 55 | LOKI_URL: http://${HOST_IP}:${DEV_LOKI_PORT}/loki/api/v1/push 56 | RAY_SERVE_QUEUE_LENGTH_RESPONSE_DEADLINE_S: 10 57 | INFLUXDB_ADDRESS: http://${HOST_IP}:${DEV_INFLUXDB_PORT} 58 | INFLUXDB_ADMIN_TOKEN: ${SECRET_INFLUXDB_ADMIN_TOKEN} 59 | INFLUXDB_ORG: ${INFLUXDB_ORG} 60 | INFLUXDB_BUCKET: ${INFLUXDB_BUCKET} 61 | API_INTERNAL_PORT: ${API_INTERNAL_PORT} 62 | 63 | prometheus: 64 | image: prom/prometheus:latest 65 | network_mode: "host" 66 | command: 67 | - '--config.file=/etc/prometheus/prometheus.yml' 68 | - '--web.enable-remote-write-receiver' 69 | volumes: 70 | - prometheus-data:/prometheus 71 | - ../../telemetry/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml 72 | - ray-data:/tmp/ray 73 | depends_on: 74 | - api 75 | - ray 76 | 77 | influxdb: 78 | image: influxdb:2 79 | network_mode: "host" 80 | environment: 81 | DOCKER_INFLUXDB_INIT_MODE: setup 82 | DOCKER_INFLUXDB_INIT_USERNAME: ${SECRET_INFLUXDB_ADMIN_USERNAME} 83 | DOCKER_INFLUXDB_INIT_PASSWORD: ${SECRET_INFLUXDB_ADMIN_PASSWORD} 84 | DOCKER_INFLUXDB_INIT_ADMIN_TOKEN: ${SECRET_INFLUXDB_ADMIN_TOKEN} 85 | DOCKER_INFLUXDB_INIT_ORG: ${INFLUXDB_ORG} 86 | DOCKER_INFLUXDB_INIT_BUCKET: ${INFLUXDB_BUCKET} 87 | volumes: 88 | - influxdb2-data:/var/lib/influxdb2 89 | - influxdb2-config:/etc/influxdb2 90 | 91 | grafana: 92 | image: grafana/grafana:latest 93 | network_mode: "host" 94 | environment: 95 | - GF_SECURITY_ADMIN_USER=${GRAFANA_ADMIN_USER} 96 | - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD} 97 | - INFLUXDB_ADMIN_TOKEN=${SECRET_INFLUXDB_ADMIN_TOKEN} 98 | volumes: 99 | - grafana-storage:/var/lib/grafana 100 | - ../../telemetry/grafana/provisioning:/etc/grafana/provisioning 101 | - ../../telemetry/grafana/dashboards:/var/lib/grafana/dashboards 102 | depends_on: 103 | - prometheus 104 | - influxdb 105 | 106 | loki: 107 | image: grafana/loki:2.8.1 108 | network_mode: "host" 109 | volumes: 110 | - loki-data:/loki 111 | 112 | volumes: 113 | grafana-storage: 114 | loki-data: 115 | prometheus-data: 116 | ray-data: 117 | influxdb2-data: 118 | influxdb2-config: 119 | -------------------------------------------------------------------------------- /compose/dev/ray_config.yml: -------------------------------------------------------------------------------- 1 | proxy_location: Disabled 2 | http_options: 3 | host: 0.0.0.0 4 | port: 5005 5 | grpc_options: 6 | port: 9000 7 | grpc_servicer_functions: [] 8 | logging_config: 9 | encoding: TEXT 10 | log_level: INFO 11 | logs_dir: null 12 | enable_access_log: true 13 | applications: 14 | - name: Controller 15 | import_path: src.ray.deployments.controller:app 16 | args: 17 | ray_config_path: /src/ray/config/ray_config.yml 18 | service_config_path: /src/ray/config/service_config.yml 
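The `service_config_path` argument above points at the `service_config.yml` shown next, which lists the models the Controller application serves along with their Ray actor options and replica counts. For reference, here is a minimal sketch of what an additional entry under `models:` might look like; the `gpt2-xl` repo id and single-GPU sizing are illustrative assumptions, not part of this repository:

```yaml
# Hypothetical extra model entry for compose/dev/service_config.yml (illustration only).
- model_key: 'nnsight.modeling.language.LanguageModel:{"repo_id": "openai-community/gpt2-xl"}'
  ray_actor_options:
    num_gpus: 1
  num_replicas: 1
```

Entries follow the same shape as the existing `gpt2` entry; the production `service_config.yml` additionally shows per-node `resources` constraints and a distributed `model_import_path` for larger models.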
-------------------------------------------------------------------------------- /compose/dev/service_config.yml: -------------------------------------------------------------------------------- 1 | default_model_import_path: src.ray.deployments.model:app 2 | request_import_path: src.ray.deployments.request:app 3 | request_num_replicas: 1 4 | models: 5 | 6 | - model_key: 'nnsight.modeling.language.LanguageModel:{"repo_id": "openai-community/gpt2"}' 7 | ray_actor_options: 8 | num_gpus: 1 9 | num_replicas: 1 10 | 11 | -------------------------------------------------------------------------------- /compose/prod/.env: -------------------------------------------------------------------------------- 1 | # RabbitMQ Ports 2 | PROD_RABBITMQ_PORT=5672 3 | RABBITMQ_INTERNAL_PORT=5672 4 | 5 | # MinIO Ports 6 | PROD_MINIO_PORT=27017 7 | MINIO_INTERNAL_PORT=9000 8 | 9 | # Ray Ports 10 | PROD_RAY_HEAD_PORT=10001 11 | RAY_HEAD_INTERNAL_PORT=6379 12 | 13 | RAY_CLIENT_INTERNAL_PORT=10001 14 | 15 | RAY_DASHBOARD_INTERNAL_PORT=8265 16 | 17 | RAY_SERVE_INTERNAL_PORT=8267 18 | 19 | RAY_DASHBOARD_GRPC_PORT=8268 20 | OBJECT_MANAGER_PORT=8076 #Raylet port for object manager 21 | 22 | # API Ports 23 | PROD_API_PORT=5000 24 | API_INTERNAL_PORT=80 25 | 26 | # Prometheus Ports 27 | PROD_PROMETHEUS_PORT=9090 28 | PROMETHEUS_INTERNAL_PORT=9090 29 | 30 | # Grafana Ports 31 | PROD_GRAFANA_PORT=3000 32 | GRAFANA_INTERNAL_PORT=3000 33 | 34 | # influxDB Ports 35 | DEV_INFLUXDB_PORT=8086 36 | INFLUXDB_INTERNAL_PORT=8086 37 | 38 | # Loki Ports 39 | PROD_LOKI_PORT=3100 40 | LOKI_INTERNAL_PORT=3100 41 | 42 | # Device Configuration 43 | N_DEVICES=$N_DEVICES 44 | 45 | # Credentials and Other Configs 46 | RABBITMQ_DEFAULT_USER=guest 47 | RABBITMQ_DEFAULT_PASS=guest 48 | GRAFANA_ADMIN_USER=admin 49 | GRAFANA_ADMIN_PASSWORD=admin 50 | RAY_METRICS_GAUGE_EXPORT_INTERVAL_MS=1000 # Gauge export interval in ms 51 | RAY_DASHBOARD_HOST=0.0.0.0 52 | PROD_HOST_IP=nagoya.research.khoury.northeastern.edu 53 | FIREBASE_CREDS_PATH=/src/creds.json 54 | 55 | #InfluxDB Configs 56 | INFLUXDB_ORG=NDIF 57 | INFLUXDB_BUCKET=data 58 | SECRET_INFLUXDB_ADMIN_USERNAME=admin 59 | SECRET_INFLUXDB_ADMIN_PASSWORD=adminadmin 60 | SECRET_INFLUXDB_ADMIN_TOKEN=njkldhsfbdsfkl2o32== -------------------------------------------------------------------------------- /compose/prod/api-start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | pip install python-logging-loki 3 | 4 | gunicorn src.app:app --bind 0.0.0.0:80 --workers $WORKERS --worker-class uvicorn.workers.UvicornWorker --timeout 120 5 | 6 | -------------------------------------------------------------------------------- /compose/prod/docker-compose-worker.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | ray-worker: 4 | image: ray_worker:latest 5 | volumes: 6 | - /drive7/hf-cache/hub/:/root/.cache/huggingface/hub 7 | network_mode: "host" 8 | environment: 9 | RAY_ADDRESS: 127.0.0.1:6379 10 | INFLUXDB_ADDRESS: http://${HOST_IP}:${DEV_INFLUXDB_PORT} 11 | INFLUXDB_ADMIN_TOKEN: ${SECRET_INFLUXDB_ADMIN_TOKEN} 12 | INFLUXDB_ORG: ${INFLUXDB_ORG} 13 | INFLUXDB_BUCKET: ${INFLUXDB_BUCKET} 14 | deploy: 15 | replicas: 1 16 | resources: 17 | reservations: 18 | devices: 19 | - driver: nvidia 20 | count: 8 21 | capabilities: [ gpu ] 22 | -------------------------------------------------------------------------------- /compose/prod/docker-compose.yml: 
-------------------------------------------------------------------------------- 1 | services: 2 | 3 | rabbitmq: 4 | image: rabbitmq:3.11.28 5 | environment: 6 | RABBITMQ_DEFAULT_USER: ${RABBITMQ_DEFAULT_USER} 7 | RABBITMQ_DEFAULT_PASS: ${RABBITMQ_DEFAULT_PASS} 8 | ports: 9 | - ${PROD_RABBITMQ_PORT}:${RABBITMQ_INTERNAL_PORT} 10 | 11 | minio: 12 | image: minio/minio:latest 13 | command: server /data 14 | # environment: 15 | # - MINIO_ACCESS_KEY=minioadmin 16 | # MINIO_SECRET_KEY=minioadmin 17 | ports: 18 | - ${PROD_MINIO_PORT}:${MINIO_INTERNAL_PORT} 19 | 20 | ray-head: 21 | image: ray_head:latest 22 | network_mode: "host" 23 | shm_size: '15gb' 24 | volumes: 25 | - /disk/u/jfiottok/.cache/huggingface/hub/:/root/.cache/huggingface/hub 26 | - ./service_config.yml:/src/ray/config/service_config.yml 27 | - ./ray_config.yml:/src/ray/config/ray_config.yml 28 | - ../../services/ray_head/start.sh:/start.sh 29 | - ray-data:/tmp/ray/ 30 | environment: 31 | NCCL_DEBUG: INFO 32 | LOKI_URL: http://${PROD_HOST_IP}:${PROD_LOKI_PORT}/loki/api/v1/push 33 | OBJECT_STORE_URL: ${PROD_HOST_IP}:${PROD_MINIO_PORT} 34 | API_URL: https://ndif.dev 35 | INFLUXDB_ADDRESS: http://${HOST_IP}:${DEV_INFLUXDB_PORT} 36 | INFLUXDB_ADMIN_TOKEN: ${SECRET_INFLUXDB_ADMIN_TOKEN} 37 | INFLUXDB_ORG: ${INFLUXDB_ORG} 38 | INFLUXDB_BUCKET: ${INFLUXDB_BUCKET} 39 | env_file: 40 | - .env 41 | deploy: 42 | replicas: 1 43 | resources: 44 | reservations: 45 | devices: 46 | - driver: nvidia 47 | count: ${N_DEVICES} 48 | capabilities: [ gpu ] 49 | 50 | api: 51 | depends_on: 52 | - rabbitmq 53 | - ray-head 54 | - minio 55 | image: api:latest 56 | ports: 57 | - ${PROD_API_PORT}:${API_INTERNAL_PORT} 58 | volumes: 59 | - ./api-start.sh:/start.sh 60 | - ../../services/api/src/creds.json:/src/creds.json 61 | environment: 62 | OBJECT_STORE_URL: ${PROD_HOST_IP}:${PROD_MINIO_PORT} 63 | RMQ_URL: amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@${PROD_HOST_IP}:${PROD_RABBITMQ_PORT}/ 64 | WORKERS: 12 65 | RAY_ADDRESS: ray://${PROD_HOST_IP}:${PROD_RAY_HEAD_PORT} 66 | LOKI_URL: http://${PROD_HOST_IP}:${PROD_LOKI_PORT}/loki/api/v1/push 67 | FIREBASE_CREDS_PATH: /src/creds.json 68 | INFLUXDB_ADDRESS: http://${HOST_IP}:${DEV_INFLUXDB_PORT} 69 | INFLUXDB_ADMIN_TOKEN: ${SECRET_INFLUXDB_ADMIN_TOKEN} 70 | INFLUXDB_ORG: ${INFLUXDB_ORG} 71 | INFLUXDB_BUCKET: ${INFLUXDB_BUCKET} 72 | 73 | prometheus: 74 | image: prom/prometheus:latest 75 | network_mode: "host" 76 | command: 77 | - '--config.file=/etc/prometheus/prometheus.yml' 78 | - '--web.enable-remote-write-receiver' 79 | volumes: 80 | - prometheus-data:/prometheus 81 | - ../../telemetry/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml 82 | - ray-data:/tmp/ray 83 | depends_on: 84 | - api 85 | - ray-head 86 | 87 | influxdb: 88 | image: influxdb:2 89 | network_mode: "host" 90 | environment: 91 | DOCKER_INFLUXDB_INIT_MODE: setup 92 | DOCKER_INFLUXDB_INIT_USERNAME: ${SECRET_INFLUXDB_ADMIN_USERNAME} 93 | DOCKER_INFLUXDB_INIT_PASSWORD: ${SECRET_INFLUXDB_ADMIN_PASSWORD} 94 | DOCKER_INFLUXDB_INIT_ADMIN_TOKEN: ${SECRET_INFLUXDB_ADMIN_TOKEN} 95 | DOCKER_INFLUXDB_INIT_ORG: ${INFLUXDB_ORG} 96 | DOCKER_INFLUXDB_INIT_BUCKET: ${INFLUXDB_BUCKET} 97 | volumes: 98 | - influxdb2-data:/var/lib/influxdb2 99 | - influxdb2-config:/etc/influxdb2 100 | 101 | grafana: 102 | image: grafana/grafana:latest 103 | network_mode: "host" 104 | environment: 105 | - GF_SECURITY_ADMIN_USER=${GRAFANA_ADMIN_USER} 106 | - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD} 107 | - INFLUXDB_ADMIN_TOKEN=${SECRET_INFLUXDB_ADMIN_TOKEN} 108 | 
volumes: 109 | - grafana-storage:/var/lib/grafana 110 | - ../../telemetry/grafana/provisioning:/etc/grafana/provisioning 111 | - ../../telemetry/grafana/dashboards:/var/lib/grafana/dashboards 112 | depends_on: 113 | - prometheus 114 | - influxdb 115 | 116 | loki: 117 | image: grafana/loki:2.8.1 118 | network_mode: "host" 119 | volumes: 120 | - loki-data:/loki 121 | 122 | volumes: 123 | grafana-storage: 124 | loki-data: 125 | prometheus-data: 126 | ray-data: 127 | influxdb2-data: 128 | influxdb2-config: 129 | -------------------------------------------------------------------------------- /compose/prod/ray_config.yml: -------------------------------------------------------------------------------- 1 | proxy_location: Disabled 2 | http_options: 3 | host: 0.0.0.0 4 | port: 5005 5 | grpc_options: 6 | port: 9000 7 | grpc_servicer_functions: [] 8 | logging_config: 9 | encoding: TEXT 10 | log_level: INFO 11 | logs_dir: null 12 | enable_access_log: true 13 | applications: 14 | - name: Controller 15 | import_path: src.ray.deployments.controller:app 16 | args: 17 | ray_config_path: /src/ray/config/ray_config.yml 18 | service_config_path: /src/ray/config/service_config.yml 19 | -------------------------------------------------------------------------------- /compose/prod/service_config.yml: -------------------------------------------------------------------------------- 1 | default_model_import_path: src.ray.deployments.model:app 2 | request_import_path: src.ray.deployments.request:app 3 | request_num_replicas: 1 4 | models: 5 | 6 | - model_key: 'nnsight.models.LanguageModel.LanguageModel:{"repo_id": "meta-llama/Meta-Llama-3.1-8B"}' 7 | ray_actor_options: 8 | num_gpus: 1 9 | resources: 10 | fukuyama: 1 11 | num_replicas: 1 12 | 13 | - model_key: 'nnsight.models.LanguageModel.LanguageModel:{"repo_id": "EleutherAI/gpt-j-6b"}' 14 | ray_actor_options: 15 | num_gpus: 1 16 | resources: 17 | hamada: 1 18 | num_replicas: 1 19 | 20 | 21 | - model_key: 'nnsight.models.LanguageModel.LanguageModel:{"repo_id": "meta-llama/Meta-Llama-3.1-405B"}' 22 | num_replicas: 1 23 | model_import_path: src.ray.deployments.distributed_model:app 24 | args: 25 | torch_distributed_port: 5003 26 | torch_distributed_world_size: 16 27 | torch_distributed_world_timeout_seconds: 40 28 | tensor_parallelism_size: 16 29 | ray_actor_options: 30 | num_gpus: 1 31 | 32 | 33 | - model_key: 'nnsight.models.LanguageModel.LanguageModel:{"repo_id": "meta-llama/Meta-Llama-3.1-70B"}' 34 | ray_actor_options: 35 | num_gpus: 3 36 | resources: 37 | boa: 1 38 | num_replicas: 1 39 | -------------------------------------------------------------------------------- /docker/dockerfile.base: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | # Install base utilities 4 | RUN apt-get update \ 5 | && apt-get install -y build-essential wget python3-distutils \ 6 | && apt-get clean \ 7 | && rm -rf /var/lib/apt/lists/* 8 | 9 | # Install miniconda 10 | ENV CONDA_DIR=/opt/conda 11 | ENV PATH=$CONDA_DIR/bin:$PATH 12 | 13 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh \ 14 | && /bin/bash ~/miniconda.sh -b -p /opt/conda \ 15 | && rm ~/miniconda.sh 16 | 17 | # Create base environment with minimal dependencies 18 | COPY services/base/environment.yml . 
19 | RUN conda env create --name service -f environment.yml -------------------------------------------------------------------------------- /docker/dockerfile.conda: -------------------------------------------------------------------------------- 1 | ARG BASE_IMAGE=ndif_base 2 | ARG TAG=latest 3 | FROM ${BASE_IMAGE}:${TAG} 4 | 5 | # Update environment with service-specific dependencies 6 | ARG NAME 7 | COPY src/services/${NAME}/environment.yml . 8 | 9 | RUN conda env update --name service -f environment.yml -------------------------------------------------------------------------------- /docker/dockerfile.service: -------------------------------------------------------------------------------- 1 | ARG NAME 2 | ARG TAG=latest 3 | FROM ${NAME}_conda:${TAG} 4 | 5 | # New build stage, so need to redeclare NAME 6 | ARG NAME 7 | COPY ./src.tar.gz ./src.tar.gz 8 | COPY src/services/${NAME}/start.sh ./start.sh 9 | 10 | RUN tar -xvf ./src.tar.gz \ 11 | && rm ./src.tar.gz 12 | 13 | SHELL ["/bin/bash", "-c"] 14 | 15 | CMD source activate service && bash ./start.sh 16 | -------------------------------------------------------------------------------- /docker/helpers/check_and_update_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | check_and_update_environment() { 6 | echo "Checking for environment updates..." 7 | source activate service 8 | 9 | # Compare current environment with environment.yml 10 | if conda env export --name service | grep -v "^prefix: " > /tmp/current_env.yml && \ 11 | diff -q /tmp/current_env.yml environment.yml > /dev/null 2>&1; then 12 | echo "Environment is up-to-date." 13 | else 14 | echo "Differences detected in environment. Updating..." 15 | conda env update --file environment.yml --prune 16 | echo "Environment updated successfully." 
17 | fi 18 | 19 | rm -f /tmp/current_env.yml 20 | } 21 | 22 | # Check and update the environment 23 | check_and_update_environment 24 | -------------------------------------------------------------------------------- /docker/helpers/nns_inst.sh: -------------------------------------------------------------------------------- 1 | source activate service 2 | 3 | pip uninstall nnsight -y 4 | 5 | pip install -e ./nnsight 6 | 7 | bash ../start.sh -------------------------------------------------------------------------------- /scripts/redeploy.py: -------------------------------------------------------------------------------- 1 | import ray 2 | from ray import serve 3 | 4 | ray.init() 5 | 6 | serve.get_app_handle("Controller").redeploy.remote() 7 | -------------------------------------------------------------------------------- /scripts/test.py: -------------------------------------------------------------------------------- 1 | import nnsight 2 | 3 | nnsight.CONFIG.set_default_api_key("api key") 4 | nnsight.CONFIG.API.HOST = "localhost:5001" 5 | nnsight.CONFIG.API.SSL = False 6 | 7 | model = nnsight.LanguageModel("openai-community/gpt2") 8 | 9 | with model.trace("The Eiffel Tower is located in ", remote=True): 10 | 11 | output = model.output.save() 12 | 13 | print(output) 14 | -------------------------------------------------------------------------------- /src/common/logging/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import load_logger -------------------------------------------------------------------------------- /src/common/logging/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging_loki 3 | import os 4 | import socket 5 | import sys 6 | import traceback 7 | from typing import Dict, Optional, Any 8 | import time 9 | from functools import wraps 10 | 11 | # Environment variables for Loki configuration 12 | LOKI_URL = os.environ.get('LOKI_URL') 13 | LOKI_RETRY_COUNT = int(os.environ.get('LOKI_RETRY_COUNT', '3')) # Number of retry attempts for failed log sends 14 | 15 | # Global logger instance 16 | LOGGER: Optional[logging.Logger] = None 17 | 18 | class CustomJSONFormatter(logging.Formatter): 19 | """ 20 | Custom JSON formatter for structured logging. 21 | 22 | Extends the standard logging.Formatter to add additional fields 23 | like service name, hostname, and metrics data to log records. 24 | """ 25 | def __init__(self, service_name, fmt=None, datefmt=None, style='%', *args, **kwargs): 26 | """ 27 | Initialize the formatter with service information. 28 | 29 | Args: 30 | service_name: Name of the service generating logs 31 | fmt: Log format string 32 | datefmt: Date format string 33 | style: Style of the format string (%, {, or $) 34 | """ 35 | super().__init__(fmt=fmt, datefmt=datefmt, style=style, *args, **kwargs) 36 | self.service_name = service_name 37 | self.hostname = socket.gethostname() 38 | 39 | def format(self, record): 40 | """ 41 | Format the log record by adding custom fields. 
42 | 43 | Args: 44 | record: The log record to format 45 | 46 | Returns: 47 | Formatted log record as a string 48 | """ 49 | # Add custom fields to the log record 50 | record.service_name = self.service_name 51 | record.hostname = self.hostname 52 | record.process_id = os.getpid() 53 | record.thread_name = record.threadName 54 | # Add code location 55 | record.code_file = record.pathname 56 | record.code_line = record.lineno 57 | 58 | # Format the log record using the standard logging format 59 | return super().format(record) 60 | 61 | class RetryingLokiHandler(logging_loki.LokiHandler): 62 | """ 63 | Extended Loki handler with retry capability for handling network issues. 64 | 65 | Attempts to resend logs to Loki if initial attempts fail, using 66 | exponential backoff between retries. 67 | """ 68 | def __init__(self, retry_count=LOKI_RETRY_COUNT, *args, **kwargs): 69 | """ 70 | Initialize the handler with retry configuration. 71 | 72 | Args: 73 | retry_count: Number of times to retry sending logs 74 | *args, **kwargs: Arguments passed to LokiHandler 75 | """ 76 | self.retry_count = retry_count 77 | super().__init__(*args, **kwargs) 78 | 79 | def emit(self, record): 80 | """ 81 | Send the log record to Loki with retry logic. 82 | 83 | Args: 84 | record: The log record to send 85 | """ 86 | for attempt in range(self.retry_count): 87 | try: 88 | super().emit(record) 89 | return 90 | except Exception as e: 91 | if attempt == self.retry_count - 1: 92 | sys.stderr.write(f"Failed to send log to Loki after {self.retry_count} attempts: {e}\n") 93 | else: 94 | time.sleep(0.5 * (attempt + 1)) # Exponential backoff 95 | 96 | def load_logger(service_name: str="", logger_name: str="") -> logging.Logger: 97 | """ 98 | Configure and return a logger with console and optional Loki handlers. 99 | 100 | Sets up a logger with structured JSON formatting for Loki and simpler 101 | formatting for console output. Uses a singleton pattern to avoid 102 | creating multiple loggers. 
103 | 104 | Args: 105 | service_name: Name of the service using the logger 106 | logger_name: Name for the logger instance 107 | 108 | Returns: 109 | Configured logging.Logger instance 110 | """ 111 | global LOGGER 112 | 113 | if LOGGER is not None: 114 | return LOGGER 115 | 116 | logger = logging.getLogger(logger_name) 117 | logger.setLevel(logging.DEBUG) 118 | logger.handlers.clear() 119 | 120 | # JSON format for structured logging (used by Loki handler) 121 | json_format = '''{ 122 | "timestamp": "%(asctime)s", 123 | "service": { 124 | "name": "%(service_name)s", 125 | "hostname": "%(hostname)s", 126 | "process_id": %(process_id)d 127 | }, 128 | "log": { 129 | "level": "%(levelname)s", 130 | "logger": "%(name)s", 131 | "function": "%(funcName)s", 132 | "thread": "%(thread_name)s" 133 | }, 134 | "code": { 135 | "file": "%(code_file)s", 136 | "line": %(code_line)d 137 | }, 138 | "message": "%(message)s", 139 | }''' 140 | 141 | # Simpler format for console output 142 | console_format = '%(asctime)s [%(levelname)s] %(name)s - %(message)s' 143 | 144 | # Create formatters for different outputs 145 | json_formatter = CustomJSONFormatter( 146 | fmt=json_format, 147 | service_name=service_name, 148 | datefmt="%Y-%m-%d %H:%M:%S.%f%z" 149 | ) 150 | 151 | console_formatter = logging.Formatter( 152 | fmt=console_format, 153 | datefmt="%Y-%m-%d %H:%M:%S" 154 | ) 155 | 156 | # Set up console handler for local debugging 157 | console_handler = logging.StreamHandler() 158 | console_handler.setFormatter(console_formatter) 159 | console_handler.setLevel(logging.DEBUG) 160 | logger.addHandler(console_handler) 161 | 162 | # Set up Loki handler if URL is configured 163 | if LOKI_URL is not None: 164 | # Loki handler configuration with batching and retries 165 | loki_handler = RetryingLokiHandler( 166 | url=LOKI_URL, 167 | tags={ 168 | "application": service_name, 169 | "hostname": socket.gethostname(), 170 | }, 171 | auth=None, 172 | version="1", 173 | 174 | ) 175 | loki_handler.setFormatter(json_formatter) 176 | loki_handler.setLevel(logging.INFO) # Only send INFO and above to Loki 177 | logger.addHandler(loki_handler) 178 | 179 | 180 | LOGGER = logger 181 | return logger -------------------------------------------------------------------------------- /src/common/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metric import Metric 2 | 3 | from .gpu_mem import GPUMemMetric 4 | from .network_data import NetworkStatusMetric 5 | from .request_status import RequestStatusMetric 6 | from .request_transport_latency import TransportLatencyMetric 7 | from .request_execution_time import ExecutionTimeMetric 8 | from .request_response_size import RequestResponseSizeMetric -------------------------------------------------------------------------------- /src/common/metrics/gpu_mem.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | from . 
import Metric 3 | 4 | if TYPE_CHECKING: 5 | 6 | from ..schema import BackendRequestModel 7 | 8 | 9 | class GPUMemMetric(Metric): 10 | 11 | name:str = "request_gpu_mem" 12 | 13 | @classmethod 14 | def update(cls, request: "BackendRequestModel", gpu_mem:float): 15 | 16 | super().update( 17 | gpu_mem, 18 | request_id=request.id, 19 | api_key=request.api_key, 20 | model_key=request.model_key, 21 | ) -------------------------------------------------------------------------------- /src/common/metrics/metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import TYPE_CHECKING, Any, Optional, Union 3 | 4 | from influxdb_client import InfluxDBClient, WriteApi 5 | from influxdb_client.client.write_api import SYNCHRONOUS 6 | from influxdb_client import Point 7 | from ..logging.logger import load_logger 8 | 9 | logger = load_logger() 10 | 11 | class Metric: 12 | 13 | name: str 14 | client: Optional[WriteApi] = None 15 | 16 | @classmethod 17 | def update(cls, measurement: Union[Any, Point], **tags): 18 | 19 | try: 20 | 21 | if Metric.client is None: 22 | 23 | Metric.client = InfluxDBClient( 24 | url=os.getenv("INFLUXDB_ADDRESS"), 25 | token=os.getenv("INFLUXDB_ADMIN_TOKEN"), 26 | ).write_api(write_options=SYNCHRONOUS) 27 | 28 | # If youre providing a Point directly, use it as is 29 | if isinstance(measurement, Point): 30 | point = measurement 31 | 32 | #Otherwise build it from the value (measurement) and its tags. 33 | else: 34 | 35 | point: Point = Point(cls.name).field(cls.name, measurement) 36 | 37 | for key, value in tags.items(): 38 | 39 | point = point.tag(key, value) 40 | 41 | Metric.client.write( 42 | bucket=os.getenv("INFLUXDB_BUCKET"), 43 | org=os.getenv("INFLUXDB_ORG"), 44 | record=point, 45 | ) 46 | 47 | except Exception as e: 48 | logger.exception(f"Error updating metric {cls.name}: {e}") 49 | -------------------------------------------------------------------------------- /src/common/metrics/network_data.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any 2 | from . import Metric 3 | 4 | if TYPE_CHECKING: 5 | 6 | from ..schema import BackendRequestModel 7 | 8 | else: 9 | BackendRequestModel = Any 10 | 11 | class NetworkStatusMetric(Metric): 12 | 13 | name: str = "network_data" 14 | 15 | @classmethod 16 | def update( 17 | cls, 18 | request: BackendRequestModel, 19 | ip_address: str, 20 | user_agent: str, 21 | content_length: int, 22 | ) -> None: 23 | 24 | super().update( 25 | content_length, 26 | request_id=request.id, 27 | model_key=request.model_key, 28 | api_key=request.api_key, 29 | ip_address=ip_address, 30 | user_agent=user_agent, 31 | ) 32 | -------------------------------------------------------------------------------- /src/common/metrics/request_execution_time.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | from . 
import Metric 3 | 4 | if TYPE_CHECKING: 5 | 6 | from ..schema import BackendRequestModel 7 | 8 | 9 | class ExecutionTimeMetric(Metric): 10 | 11 | name:str = "request_execution_time" 12 | 13 | @classmethod 14 | def update(cls, request: "BackendRequestModel", time_s:float): 15 | 16 | super().update( 17 | time_s, 18 | request_id=request.id, 19 | api_key=request.api_key, 20 | model_key=request.model_key, 21 | ) -------------------------------------------------------------------------------- /src/common/metrics/request_response_size.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | from . import Metric 3 | 4 | if TYPE_CHECKING: 5 | 6 | from ..schema import BackendRequestModel 7 | 8 | 9 | class RequestResponseSizeMetric(Metric): 10 | 11 | name:str = "request_response_size" 12 | 13 | @classmethod 14 | def update(cls, request: "BackendRequestModel", size: int): 15 | 16 | super().update( 17 | size, 18 | request_id=request.id, 19 | api_key=request.api_key, 20 | model_key=request.model_key, 21 | ) -------------------------------------------------------------------------------- /src/common/metrics/request_status.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import TYPE_CHECKING 3 | 4 | from . import Metric 5 | 6 | if TYPE_CHECKING: 7 | 8 | from ..schema import BackendRequestModel, BackendResponseModel 9 | 10 | 11 | class RequestStatusMetric(Metric): 12 | """ 13 | This class abstracts the usage of metrics for tracking the status of requests across different services. 14 | Specifically, it handles the complexity introduced by Ray's distributed system when using Prometheus. 15 | 16 | Considerations: 17 | - Ray's distributed nature complicates direct use of Prometheus client objects, requiring dynamic HTTP servers or a Pushgateway, which adds complexity and potential performance issues. 18 | - To avoid this, Ray's built-in metrics API (Gauge) is used, which handles the distributed aspect automatically. 19 | - However, Ray's API differs slightly from the Prometheus client, leading to a messier interface in this class. 20 | - Additionally, Ray prepends "ray_" to metric names, which needs to be handled separately in Grafana. 21 | 22 | This class supports both Ray's Gauge API and Prometheus' Gauge API, switching between them based on the service type. 23 | """ 24 | 25 | class NumericJobStatus(Enum): 26 | RECEIVED = 1 27 | APPROVED = 2 28 | RUNNING = 3 29 | COMPLETED = 4 30 | LOG = 5 31 | ERROR = 6 32 | STREAM = 7 33 | NNSIGHT_ERROR = 8 34 | 35 | name: str = "request_status" 36 | 37 | @classmethod 38 | def update( 39 | cls, 40 | request: "BackendRequestModel", 41 | response: "BackendResponseModel", 42 | ) -> None: 43 | """ 44 | Update the values of the gauge to reflect the current status of a request. 45 | Handles both Ray and Prometheus Gauge APIs. 46 | 47 | Args: 48 | - request (RequestModel): request object. 49 | - status (ResponseModel.JobStatus): user request job status. 50 | - api_key (str): user api key. 51 | - user_id (str): 52 | - msg (str): description of the current job status of the request. 
53 | 54 | Returns: 55 | """ 56 | numeric_status = int(cls.NumericJobStatus[response.status.value].value) 57 | 58 | super().update( 59 | numeric_status, 60 | request_id=str(request.id), 61 | api_key=str(request.api_key), 62 | model_key=str(request.model_key), 63 | msg=response.description, 64 | ) 65 | -------------------------------------------------------------------------------- /src/common/metrics/request_transport_latency.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import TYPE_CHECKING 3 | from . import Metric 4 | 5 | if TYPE_CHECKING: 6 | 7 | from ..schema import BackendRequestModel 8 | 9 | 10 | class TransportLatencyMetric(Metric): 11 | 12 | name:str = "request_transport_latency" 13 | 14 | @classmethod 15 | def update(cls, request: "BackendRequestModel"): 16 | 17 | if request.sent is not None: 18 | 19 | super().update( 20 | time.time() - request.sent, 21 | request_id=request.id, 22 | api_key=request.api_key, 23 | model_key=request.model_key, 24 | ) -------------------------------------------------------------------------------- /src/common/schema/__init__.py: -------------------------------------------------------------------------------- 1 | from .request import BackendRequestModel 2 | from .response import BackendResponseModel 3 | from .result import BackendResultModel 4 | from .mixins import ObjectStorageMixin, TelemetryMixin 5 | from nnsight.schema.result import RESULT -------------------------------------------------------------------------------- /src/common/schema/mixins.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from io import BytesIO 5 | from typing import ClassVar, TYPE_CHECKING, Union, Any 6 | 7 | import torch 8 | import boto3 9 | from botocore.response import StreamingBody 10 | from pydantic import BaseModel 11 | from typing_extensions import Self 12 | 13 | if TYPE_CHECKING: 14 | from nnsight.schema.response import ResponseModel 15 | from nnsight.schema.request import RequestModel 16 | 17 | 18 | class ObjectStorageMixin(BaseModel): 19 | """ 20 | Mixin to provide object storage functionality for models using S3. 21 | 22 | This mixin allows models to save and load themselves from an S3 object store 23 | by serializing their data and interacting with the S3 API. 24 | 25 | Attributes: 26 | id (str): Unique identifier for the object to be stored. 27 | _bucket_name (ClassVar[str]): The default bucket name for storing objects. 28 | _file_extension (ClassVar[str]): The file extension used for stored objects. 29 | 30 | Methods: 31 | object_name(id: str) -> str: 32 | Returns the object name based on the provided ID and file extension. 33 | 34 | save(client: boto3.client) -> Self: 35 | Serializes and saves the object to S3 storage. 36 | 37 | load(client: boto3.client, id: str, stream: bool = False) -> StreamingBody | Self: 38 | Loads and deserializes the object from S3 storage. 39 | 40 | delete(client: boto3.client, id: str) -> None: 41 | Deletes the object from S3 storage. 
42 | """ 43 | id: str 44 | size: int = None 45 | 46 | _bucket_name: ClassVar[str] = "default" 47 | _file_extension: ClassVar[str] = "json" 48 | 49 | 50 | @classmethod 51 | def object_name(cls, id: str): 52 | return f"{id}.{cls._file_extension}" 53 | 54 | def _save(self, client: boto3.client, data: BytesIO, content_type: str, bucket_name: str = None) -> None: 55 | bucket_name = self._bucket_name if bucket_name is None else bucket_name 56 | object_name = self.object_name(self.id) 57 | 58 | data.seek(0) 59 | 60 | # Check if bucket exists, create if it doesn't 61 | try: 62 | client.head_bucket(Bucket=bucket_name) 63 | except client.exceptions.ClientError: 64 | client.create_bucket(Bucket=bucket_name) 65 | 66 | # Upload object to S3 67 | client.upload_fileobj( 68 | Fileobj=data, 69 | Bucket=bucket_name, 70 | Key=object_name, 71 | ExtraArgs={'ContentType': content_type} 72 | ) 73 | 74 | @classmethod 75 | def _load( 76 | cls, client: boto3.client, id: str, stream: bool = False 77 | ) -> Union[StreamingBody, bytes]: 78 | bucket_name = cls._bucket_name 79 | object_name = cls.object_name(id) 80 | 81 | response = client.get_object(Bucket=bucket_name, Key=object_name) 82 | 83 | if stream: 84 | return response['Body'], response['ContentLength'] 85 | 86 | data = response['Body'].read() 87 | response['Body'].close() 88 | 89 | return data 90 | 91 | def save(self, client: boto3.client) -> Self: 92 | if self._file_extension == "json": 93 | data = BytesIO(self.model_dump_json().encode("utf-8")) 94 | content_type = "application/json" 95 | elif self._file_extension == "pt": 96 | data = BytesIO() 97 | torch.save(self.model_dump(), data) 98 | content_type = "application/octet-stream" 99 | 100 | self.size = data.getbuffer().nbytes 101 | 102 | self._save(client, data, content_type) 103 | 104 | return self 105 | 106 | @classmethod 107 | def load(cls, client: boto3.client, id: str, stream: bool = False) -> Union[StreamingBody, Self]: 108 | object_data = cls._load(client, id, stream=stream) 109 | 110 | if stream: 111 | return object_data 112 | 113 | if cls._file_extension == "json": 114 | return cls.model_validate_json(object_data.decode("utf-8")) 115 | elif cls._file_extension == "pt": 116 | return torch.load(BytesIO(object_data), map_location="cpu", weights_only=False) 117 | 118 | @classmethod 119 | def delete(cls, client: boto3.client, id: str) -> None: 120 | bucket_name = cls._bucket_name 121 | object_name = cls.object_name(id) 122 | 123 | try: 124 | client.delete_object(Bucket=bucket_name, Key=object_name) 125 | except: 126 | pass 127 | 128 | 129 | class TelemetryMixin: 130 | """ 131 | Mixin to provide telemetry functionality for models, including logging and gauge updates. 132 | 133 | This mixin enables models to log their status and update Prometheus or Ray metrics (gauges) 134 | to track their state in the system. It abstracts the underlying telemetry mechanisms and 135 | allows easy integration of logging and metric updates. 136 | 137 | Methods: 138 | backend_log(logger: logging.Logger, message: str, level: str = 'info') -> Self: 139 | Logs a message with the specified logging level (info, error, exception). 140 | 141 | update_gauge(gauge: NDIFGauge) -> Self: 142 | Updates the telemetry gauge to track the status of a request or response. 
143 | """ 144 | def backend_log(self, logger: logging.Logger, message: str, level: str = 'info'): 145 | if level == 'info': 146 | logger.info(message) 147 | elif level == 'error': 148 | logger.error(message) 149 | elif level == 'exception': 150 | logger.exception(message) 151 | return self 152 | 153 | -------------------------------------------------------------------------------- /src/common/schema/request.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import uuid 5 | import time 6 | from typing import TYPE_CHECKING, ClassVar, Optional, Union, Coroutine 7 | 8 | import ray 9 | from fastapi import Request 10 | from pydantic import ConfigDict 11 | from typing_extensions import Self 12 | 13 | from nnsight import NNsight 14 | from nnsight.schema.request import RequestModel 15 | from nnsight.schema.response import ResponseModel 16 | from nnsight.tracing.graph import Graph 17 | 18 | from .mixins import ObjectStorageMixin 19 | from .response import BackendResponseModel 20 | 21 | 22 | class BackendRequestModel(ObjectStorageMixin): 23 | """ 24 | 25 | Attributes: 26 | - model_config: model configuration. 27 | - graph (Union[bytes, ray.ObjectRef]): intervention graph object, could be in multiple forms. 28 | - model_key (str): model key name. 29 | - session_id (Optional[str]): connection session id. 30 | - format (str): format of the request body. 31 | - zlib (bool): is the request body compressed. 32 | - id (str): request id. 33 | - received (datetime.datetime): time of the request being received. 34 | - api_key (str): api key associated with this request. 35 | - _bucket_name (str): request result bucket storage name. 36 | - _file_extension (str): file extension. 
37 | """ 38 | 39 | model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) 40 | 41 | _bucket_name: ClassVar[str] = "serialized-requests" 42 | _file_extension: ClassVar[str] = "json" 43 | 44 | graph: Optional[Union[Coroutine, bytes, ray.ObjectRef]] = None 45 | 46 | model_key: str 47 | session_id: Optional[str] = None 48 | format: str 49 | zlib: bool 50 | 51 | id: str 52 | 53 | sent: Optional[float] = None 54 | 55 | api_key: Optional[str] = '' 56 | 57 | def deserialize(self, model: NNsight) -> Graph: 58 | 59 | graph = self.graph 60 | 61 | if isinstance(self.graph, ray.ObjectRef): 62 | 63 | graph = ray.get(graph) 64 | 65 | return RequestModel.deserialize(model, graph, "json", self.zlib) 66 | 67 | @classmethod 68 | def from_request( 69 | cls, request: Request, api_key: str 70 | ) -> Self: 71 | 72 | headers = request.headers 73 | 74 | return BackendRequestModel( 75 | graph=request.body(), 76 | model_key=headers["model_key"], 77 | session_id=headers.get("session_id", None), 78 | format=headers["format"], 79 | zlib=headers["zlib"], 80 | id=str(uuid.uuid4()), 81 | sent=float(headers.get("sent-timestamp", None)), 82 | api_key=api_key, 83 | ) 84 | 85 | def create_response( 86 | self, 87 | status: ResponseModel.JobStatus, 88 | logger: logging.Logger, 89 | description: str = "", 90 | data: bytes = None, 91 | ) -> BackendResponseModel: 92 | """Generates a BackendResponseModel given a change in status to an ongoing request.""" 93 | 94 | log_msg = f"{self.id} - {status.name}: {description}" 95 | 96 | logging_level = "info" 97 | 98 | if status == ResponseModel.JobStatus.ERROR: 99 | logging_level = "exception" 100 | elif status == ResponseModel.JobStatus.NNSIGHT_ERROR: 101 | logging_level = "exception" 102 | 103 | 104 | response = ( 105 | BackendResponseModel( 106 | id=self.id, 107 | session_id=self.session_id, 108 | status=status, 109 | description=description, 110 | data=data, 111 | ) 112 | .backend_log( 113 | logger=logger, 114 | message=log_msg, 115 | level=logging_level, 116 | ) 117 | .update_metric( 118 | self, 119 | ) 120 | ) 121 | 122 | return response 123 | -------------------------------------------------------------------------------- /src/common/schema/response.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any, ClassVar, Optional 4 | 5 | import requests 6 | import socketio 7 | import boto3 8 | from pydantic import field_serializer 9 | from typing_extensions import Self 10 | 11 | from nnsight.schema.response import ResponseModel 12 | 13 | from ..metrics import RequestStatusMetric 14 | from .mixins import ObjectStorageMixin, TelemetryMixin 15 | 16 | if TYPE_CHECKING: 17 | from . 
import BackendRequestModel 18 | 19 | 20 | class BackendResponseModel(ResponseModel, ObjectStorageMixin, TelemetryMixin): 21 | 22 | _bucket_name: ClassVar[str] = "responses" 23 | 24 | def __str__(self) -> str: 25 | return f"{self.id} - {self.status.name}: {self.description}" 26 | 27 | @property 28 | def blocking(self) -> bool: 29 | return self.session_id is not None 30 | 31 | def respond(self, sio: socketio.SimpleClient, object_store: boto3.client) -> ResponseModel: 32 | if self.blocking: 33 | 34 | fn = sio.client.emit 35 | 36 | if ( 37 | self.status == ResponseModel.JobStatus.COMPLETED 38 | or self.status == ResponseModel.JobStatus.ERROR 39 | or self.status == ResponseModel.JobStatus.NNSIGHT_ERROR 40 | ): 41 | 42 | fn = sio.client.call 43 | 44 | fn("blocking_response", data=(self.session_id, self.pickle())) 45 | else: 46 | self.save(object_store) 47 | 48 | return self 49 | 50 | @field_serializer("status") 51 | def sstatus(self, value, _info): 52 | return value.value 53 | def update_metric( 54 | self, 55 | request: "BackendRequestModel", 56 | ) -> Self: 57 | """Updates the telemetry gauge to track the status of a request or response. 58 | 59 | Args: 60 | 61 | - gauge (NDIFGauge): Telemetry Gauge. 62 | - request (RequestModel): user request. 63 | - status (ResponseModel.JobStatus): status of the user request. 64 | - kwargs: key word arguments to NDIFGauge.update(). 65 | 66 | Returns: 67 | Self. 68 | """ 69 | 70 | RequestStatusMetric.update(request, self) 71 | 72 | return self 73 | -------------------------------------------------------------------------------- /src/common/schema/result.py: -------------------------------------------------------------------------------- 1 | from typing import ClassVar 2 | 3 | from nnsight.schema.result import ResultModel 4 | from .mixins import ObjectStorageMixin 5 | 6 | class BackendResultModel(ResultModel, ObjectStorageMixin): 7 | 8 | _bucket_name: ClassVar[str] = "dev-ndif-results" 9 | _file_extension: ClassVar[str] = "pt" 10 | -------------------------------------------------------------------------------- /src/services/api/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | dependencies: 4 | - python=3.10 5 | - pip 6 | - git 7 | - pip: 8 | - pip 9 | - setuptools 10 | # API 11 | - fastapi-cache2 12 | - fastapi==0.108.0 13 | - fastapi-socketio 14 | - python-socketio 15 | - redis 16 | #- firebase-admin 17 | - psycopg2-binary 18 | # Telemetry 19 | - asgiref 20 | - opentelemetry-api 21 | - opentelemetry-sdk 22 | - opentelemetry-exporter-otlp 23 | - opentelemetry-instrumentation-fastapi 24 | - prometheus_client 25 | - prometheus-fastapi-instrumentator 26 | - python-logging-loki 27 | # Http server 28 | - uvicorn[standard]==0.24.0 29 | - gunicorn 30 | - eventlet 31 | # Database 32 | - boto3 33 | # Tasks 34 | - ray[serve]==2.45.0 35 | - requests 36 | - torch 37 | - python-slugify 38 | - influxdb-client 39 | - nnsight -------------------------------------------------------------------------------- /src/services/api/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ndif-team/ndif/546316aac1f625cbeb2d93a0d2be47e856623452/src/services/api/src/__init__.py -------------------------------------------------------------------------------- /src/services/api/src/api_key.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import TYPE_CHECKING 3 | 
import psycopg2 4 | from typing import Optional 5 | 6 | from fastapi import HTTPException 7 | from fastapi.responses import JSONResponse 8 | from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_400_BAD_REQUEST 9 | 10 | from .logging import load_logger 11 | from .metrics import NetworkStatusMetric 12 | from .schema import BackendRequestModel 13 | import json 14 | #from .util import check_valid_email 15 | 16 | if TYPE_CHECKING: 17 | from fastapi import Request 18 | 19 | from .schema import BackendRequestModel 20 | 21 | logger = load_logger(service_name="api", logger_name="gunicorn.error") 22 | 23 | # TODO: Make this be derived from a base class 24 | 25 | class AccountsDB: 26 | """Database class for accounts""" 27 | def __init__(self, host, port, database, user, password): 28 | self.conn = psycopg2.connect( 29 | host=host, 30 | port=port, 31 | database=database, 32 | user=user, 33 | password=password, 34 | connect_timeout=10 35 | ) 36 | self.cur = self.conn.cursor() 37 | 38 | def __del__(self): 39 | self.cur.close() 40 | self.conn.close() 41 | 42 | def api_key_exists(self, key_id: str) -> bool: 43 | """Check if a key exists""" 44 | try: 45 | with self.conn.cursor() as cur: 46 | cur.execute("SELECT EXISTS(SELECT 1 FROM keys WHERE key_id = %s)", (key_id,)) 47 | result = cur.fetchone() 48 | return result[0] if result else False 49 | except Exception as e: 50 | logger.error(f"Error checking if key exists: {e}") 51 | self.conn.rollback() 52 | return False 53 | 54 | def model_id_from_key(self, key_id: str) -> Optional[str]: 55 | """Get the model ID from a key ID""" 56 | try: 57 | with self.conn.cursor() as cur: 58 | cur.execute("SELECT model_id FROM models WHERE model_key = %s", (key_id,)) 59 | result = cur.fetchone() 60 | return result[0] if result else None 61 | except Exception as e: 62 | logger.error(f"Error getting model ID from key ID: {e}") 63 | self.conn.rollback() 64 | return None 65 | 66 | def key_has_access_to_model(self, key_id: str, model_id: str) -> bool: 67 | """Check if a key has access to a model""" 68 | try: 69 | with self.conn.cursor() as cur: 70 | cur.execute(""" 71 | SELECT EXISTS( 72 | SELECT 1 FROM key_tier_assignments 73 | JOIN model_tier_assignments ON key_tier_assignments.tier_id = model_tier_assignments.tier_id 74 | WHERE key_tier_assignments.key_id = %s AND model_tier_assignments.model_id = %s 75 | ) 76 | """, (key_id, model_id)) 77 | result = cur.fetchone() 78 | return result[0] if result else False 79 | except Exception as e: 80 | logger.error(f"Error checking if key has access to model: {e}") 81 | self.conn.rollback() 82 | return False 83 | 84 | 85 | host = os.environ.get("POSTGRES_HOST") 86 | port = os.environ.get("POSTGRES_PORT") 87 | database = os.environ.get("POSTGRES_DB") 88 | user = os.environ.get("POSTGRES_USER") 89 | password = os.environ.get("POSTGRES_PASSWORD") 90 | 91 | 92 | api_key_store = None 93 | if host is not None: 94 | api_key_store = AccountsDB(host, port, database, user, password) 95 | 96 | def extract_request_metadata(raw_request: "Request") -> dict: 97 | """ 98 | Extracts relevant metadata from the incoming raw request, such as IP address, 99 | user agent, and content length, and returns them as a dictionary. 
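    Example return value (illustrative; the keys mirror the dictionary built below):

        {"ip_address": "203.0.113.7", "user_agent": "python-requests/2.31.0", "content_length": 1024}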
100 | """ 101 | metadata = { 102 | "ip_address": raw_request.client.host, 103 | "user_agent": raw_request.headers.get("user-agent"), 104 | "content_length": int(raw_request.headers.get("content-length", 0)), 105 | } 106 | return metadata 107 | 108 | 109 | def api_key_auth( 110 | raw_request: "Request", 111 | request: "BackendRequestModel", 112 | ) -> None: 113 | """ 114 | Authenticates the API request by extracting metadata and initializing the BackendRequestModel 115 | with relevant information, including API key, client details, and headers. 116 | 117 | Args: 118 | - raw_request (Request): user request. 119 | - request (BackendRequestModel): user request object. 120 | 121 | Returns: 122 | """ 123 | 124 | metadata = extract_request_metadata(raw_request) 125 | 126 | ip_address, user_agent, content_length = metadata.values() 127 | NetworkStatusMetric.update(request, ip_address, user_agent, content_length) 128 | 129 | # For local development, we don't want to check the API key 130 | if host is None: 131 | return 132 | 133 | # TODO: There should be some form of caching here 134 | # TODO: I should reintroduce the user email check here (unless we choose not to migrate keys which are missing an email) 135 | 136 | # Check if the API key exists and is valid 137 | if not api_key_store.api_key_exists(request.api_key): 138 | raise HTTPException( 139 | status_code=HTTP_401_UNAUTHORIZED, 140 | detail="Missing or invalid API key. Please visit https://login.ndif.us/ to create a new one.", 141 | ) 142 | model_key = request.model_key.lower() 143 | # Get the model ID from the API key 144 | model_id = api_key_store.model_id_from_key(model_key) 145 | if not model_id: 146 | # Let them have access by default (to support future usecase of dynamic model loading) 147 | return 148 | 149 | # Check if the model has access to the API key 150 | if not api_key_store.key_has_access_to_model(request.api_key, model_id): 151 | raise HTTPException( 152 | status_code=HTTP_401_UNAUTHORIZED, 153 | detail=f"API key does not have authorization to access the requested model: {model_key}.", 154 | ) -------------------------------------------------------------------------------- /src/services/api/src/app.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import threading 4 | import time 5 | import traceback 6 | from contextlib import asynccontextmanager 7 | from datetime import datetime 8 | from typing import Any, Dict 9 | import uuid 10 | 11 | import ray 12 | import socketio 13 | import uvicorn 14 | import boto3 15 | from fastapi import FastAPI, Request, Security 16 | from fastapi.middleware.cors import CORSMiddleware 17 | from fastapi.responses import StreamingResponse 18 | from fastapi.security.api_key import APIKeyHeader 19 | from fastapi_cache import FastAPICache 20 | from fastapi_cache.backends.inmemory import InMemoryBackend 21 | from fastapi_cache.decorator import cache 22 | from fastapi_socketio import SocketManager 23 | from influxdb_client import Point 24 | from prometheus_fastapi_instrumentator import Instrumentator 25 | from ray import serve 26 | 27 | from nnsight.schema.response import ResponseModel 28 | 29 | from .logging import load_logger 30 | 31 | logger = load_logger(service_name="API", logger_name="API") 32 | 33 | 34 | from .api_key import api_key_auth 35 | from .metrics import TransportLatencyMetric 36 | from .schema import BackendRequestModel, BackendResponseModel, BackendResultModel 37 | 38 | 39 | 40 | @asynccontextmanager 41 | async def 
lifespan(app: FastAPI): 42 | FastAPICache.init(InMemoryBackend()) 43 | yield 44 | 45 | 46 | # Init FastAPI app 47 | app = FastAPI(lifespan=lifespan) 48 | # Add middleware for CORS 49 | app.add_middleware( 50 | CORSMiddleware, 51 | allow_origins=["*"], 52 | allow_credentials=False, 53 | allow_methods=["*"], 54 | allow_headers=["*"], 55 | ) 56 | 57 | # Init async manager for communication between socketio servers 58 | socketio_manager = socketio.AsyncRedisManager(url=os.environ.get("BROKER_URL")) 59 | # Init socketio manager app 60 | sm = SocketManager( 61 | app=app, 62 | mount_location="/ws", 63 | client_manager=socketio_manager, 64 | max_http_buffer_size=1000000000000000, 65 | ping_timeout=60, 66 | always_connect=True, 67 | ) 68 | 69 | # Init object_store connection 70 | object_store = boto3.client( 71 | 's3', 72 | endpoint_url=f"http://{os.environ.get('OBJECT_STORE_URL')}", 73 | aws_access_key_id=os.environ.get("OBJECT_STORE_ACCESS_KEY", "minioadmin"), 74 | aws_secret_access_key=os.environ.get("OBJECT_STORE_SECRET_KEY", "minioadmin"), 75 | region_name='us-east-1', 76 | # Skip verification for local or custom S3 implementations 77 | verify=False, 78 | # Set to path style for compatibility with non-AWS S3 implementations 79 | config=boto3.session.Config(signature_version='s3v4', s3={'addressing_style': 'path'}) 80 | ) 81 | 82 | # Init Ray connection 83 | RAY_RETRY_INTERVAL_S = os.environ.get("RAY_RETRY_INTERVAL_S", 5) 84 | 85 | def connect_to_ray(): 86 | while True: 87 | try: 88 | if not ray.is_initialized(): 89 | ray.shutdown() 90 | serve.context._set_global_client(None) 91 | ray.init(logging_level="error") 92 | logger.info("Connected to Ray cluster.") 93 | except Exception as e: 94 | logger.error(f"Failed to connect to Ray cluster: {e}") 95 | 96 | time.sleep(RAY_RETRY_INTERVAL_S) 97 | 98 | 99 | # Start the background thread 100 | ray_watchdog = threading.Thread(target=connect_to_ray, daemon=True) 101 | ray_watchdog.start() 102 | 103 | # Prometheus instrumentation (for metrics) 104 | Instrumentator().instrument(app).expose(app) 105 | 106 | api_key_header = APIKeyHeader(name="ndif-api-key", auto_error=False) 107 | 108 | 109 | @app.post("/request") 110 | async def request( 111 | raw_request: Request, api_key: str = Security(api_key_header) 112 | ) -> BackendResponseModel: 113 | """Endpoint to submit request. 114 | 115 | Header: 116 | - api_key: user api key. 117 | 118 | Request Body: 119 | raw_request (Request): user request containing the intervention graph. 120 | 121 | Returns: 122 | BackendResponseModel: reponse to the user request. 123 | """ 124 | 125 | # extract the request data 126 | 127 | request: BackendRequestModel = BackendRequestModel.from_request( 128 | raw_request, api_key 129 | ) 130 | 131 | # process the request 132 | try: 133 | 134 | TransportLatencyMetric.update(request) 135 | 136 | response = request.create_response( 137 | status=ResponseModel.JobStatus.RECEIVED, 138 | description="Your job has been received and is waiting approval.", 139 | logger=logger, 140 | ) 141 | 142 | # authenticate api key 143 | api_key_auth(raw_request, request) 144 | 145 | request.graph = await request.graph 146 | request.graph = ray.put(request.graph) 147 | 148 | # Send to request workers waiting to process requests on the "request" queue. 149 | # Forget as we don't care about the response. 
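        # (The "Request" application is the RequestDeployment deployed from
        # src/services/ray/src/ray/deployments/request.py via service_config.yml's
        # request_import_path. It approves the job and forwards it to the per-model
        # app named f"Model:{slugify(request.model_key)}"; the DeploymentResponse
        # returned by .remote() is deliberately not awaited.)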
150 | serve.get_app_handle("Request").remote(request) 151 | 152 | # Back up request object by default (to be deleted on successful completion) 153 | # request = request.model_copy() 154 | # request.object = object 155 | # request.save(object_store) 156 | except Exception as exception: 157 | 158 | if 'ray ' in str(exception).lower(): 159 | description = "Issue with Ray. NDIF compute backend must be down :(" 160 | else: 161 | description = f"{traceback.format_exc()}\n{str(exception)}" 162 | 163 | # Create exception response object. 164 | response = request.create_response( 165 | status=ResponseModel.JobStatus.ERROR, 166 | description=description, 167 | logger=logger, 168 | ) 169 | 170 | if not response.blocking: 171 | 172 | response.save(object_store) 173 | 174 | # Return response. 175 | return response 176 | 177 | 178 | @sm.on("connect") 179 | async def connect(session_id: str, environ: Dict): 180 | params = environ.get("QUERY_STRING") 181 | params = dict(x.split("=") for x in params.split("&")) 182 | 183 | if "job_id" in params: 184 | 185 | await sm.enter_room(session_id, params["job_id"]) 186 | 187 | 188 | @sm.on("blocking_response") 189 | async def blocking_response(session_id: str, client_session_id: str, data: Any): 190 | 191 | await sm.emit("blocking_response", data=data, to=client_session_id) 192 | 193 | 194 | @sm.on("stream_upload") 195 | async def stream_upload(session_id: str, data: bytes, job_id: str): 196 | 197 | await sm.emit("stream_upload", data=data, room=job_id) 198 | 199 | 200 | @app.get("/response/{id}") 201 | async def response(id: str) -> BackendResponseModel: 202 | """Endpoint to get latest response for id. 203 | 204 | Args: 205 | id (str): ID of request/response. 206 | 207 | Returns: 208 | BackendResponseModel: Response. 209 | """ 210 | 211 | # Load response from client given id. 212 | return BackendResponseModel.load(object_store, id) 213 | 214 | 215 | @app.get("/result/{id}") 216 | async def result(id: str) -> BackendResultModel: 217 | """Endpoint to retrieve result for id. 218 | 219 | Args: 220 | id (str): ID of request/response. 221 | 222 | Returns: 223 | BackendResultModel: Result. 224 | 225 | Yields: 226 | Iterator[BackendResultModel]: _description_ 227 | """ 228 | 229 | # Get cursor to bytes stored in data backend. 230 | object, content_length = BackendResultModel.load(object_store, id, stream=True) 231 | 232 | # Inform client the total size of result in bytes. 233 | headers = { 234 | "Content-length": str(content_length), 235 | } 236 | 237 | def stream(): 238 | try: 239 | while True: 240 | data = object.read(8192) 241 | if not data: 242 | break 243 | yield data 244 | finally: 245 | object.close() 246 | 247 | BackendResultModel.delete(object_store, id) 248 | BackendResponseModel.delete(object_store, id) 249 | BackendRequestModel.delete(object_store, id) 250 | 251 | return StreamingResponse( 252 | content=stream(), 253 | media_type="application/octet-stream", 254 | headers=headers, 255 | ) 256 | 257 | 258 | @app.get("/ping", status_code=200) 259 | async def ping(): 260 | """Endpoint to check if the server is online. 
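
    Example (illustrative; assumes the app is run locally on port 8001 as in the
    __main__ block below):

        curl http://localhost:8001/ping   # -> "pong"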
261 | 262 | Returns: 263 | _type_: _description_ 264 | """ 265 | return "pong" 266 | 267 | 268 | @app.get("/stats", status_code=200) 269 | @cache(expire=600) 270 | async def status(): 271 | 272 | response = {} 273 | 274 | status = serve.status() 275 | 276 | model_configurations = await serve.get_app_handle( 277 | "Controller" 278 | ).get_model_configurations.remote() 279 | 280 | for application_name, application in status.applications.items(): 281 | 282 | if application_name.startswith("Model"): 283 | 284 | deployment = application.deployments["ModelDeployment"] 285 | 286 | num_running_replicas = 0 287 | 288 | for replica_status in deployment.replica_states: 289 | 290 | if replica_status == "RUNNING": 291 | 292 | num_running_replicas += 1 293 | 294 | if num_running_replicas > 0: 295 | 296 | config = model_configurations[application_name] 297 | 298 | response[application_name] = { 299 | "num_running_replicas": num_running_replicas, 300 | **config, 301 | } 302 | 303 | return response 304 | 305 | 306 | if __name__ == "__main__": 307 | uvicorn.run(app, host="0.0.0.0", port=8001, workers=1) 308 | -------------------------------------------------------------------------------- /src/services/api/src/gunicorn.conf.py: -------------------------------------------------------------------------------- 1 | # from opentelemetry import trace 2 | # from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter 3 | # from opentelemetry.sdk.resources import Resource 4 | # from opentelemetry.sdk.trace import TracerProvider 5 | # from opentelemetry.sdk.trace.export import BatchSpanProcessor 6 | # import os 7 | 8 | # def post_fork(server, worker): 9 | # server.log.info("Worker spawned (pid: %s)", worker.pid) 10 | 11 | # resource = Resource.create(attributes={ 12 | # "service.name": "api-service" 13 | # }) 14 | 15 | # trace.set_tracer_provider(TracerProvider(resource=resource)) 16 | # span_processor = BatchSpanProcessor( 17 | # OTLPSpanExporter(endpoint=os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")) 18 | # ) 19 | # trace.get_tracer_provider().add_span_processor(span_processor) -------------------------------------------------------------------------------- /src/services/api/src/logging: -------------------------------------------------------------------------------- 1 | ../../../common/logging/ -------------------------------------------------------------------------------- /src/services/api/src/metrics: -------------------------------------------------------------------------------- 1 | ../../../common/metrics/ -------------------------------------------------------------------------------- /src/services/api/src/schema: -------------------------------------------------------------------------------- 1 | ../../../common/schema/ -------------------------------------------------------------------------------- /src/services/api/src/util.py: -------------------------------------------------------------------------------- 1 | 2 | def check_valid_email(user_id : str) -> bool: 3 | '''Helper function which verifies that the `user_id` field contains a "valid" email.''' 4 | if user_id != '' and user_id is not None and '@' in user_id: 5 | return True 6 | return False -------------------------------------------------------------------------------- /src/services/api/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Use API_INTERNAL_PORT if defined, otherwise default to 80 4 | PORT="${API_INTERNAL_PORT:-80}" 5 | 6 | gunicorn src.app:app 
--bind 0.0.0.0:$PORT --workers $WORKERS --worker-class uvicorn.workers.UvicornWorker --timeout 120 7 | -------------------------------------------------------------------------------- /src/services/base/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | dependencies: 4 | - python=3.10 5 | - pip 6 | - git 7 | - pip: 8 | - pip 9 | - setuptools 10 | - git+https://github.com/ndif-team/nnsight@dev 11 | # Telemetry 12 | - prometheus_client 13 | - python-logging-loki 14 | -------------------------------------------------------------------------------- /src/services/ray/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | dependencies: 4 | - python=3.10 5 | - pip 6 | - git 7 | - pip: 8 | - pip 9 | - setuptools 10 | - nnsight 11 | - ray[serve]==2.45.0 12 | - python-slugify 13 | # Telemetry 14 | - prometheus_client 15 | - python-logging-loki 16 | # Database 17 | - boto3 18 | - influxdb-client 19 | -------------------------------------------------------------------------------- /src/services/ray/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ndif-team/ndif/546316aac1f625cbeb2d93a0d2be47e856623452/src/services/ray/src/__init__.py -------------------------------------------------------------------------------- /src/services/ray/src/logging: -------------------------------------------------------------------------------- 1 | ../../../common/logging/ -------------------------------------------------------------------------------- /src/services/ray/src/metrics: -------------------------------------------------------------------------------- 1 | ../../../common/metrics/ -------------------------------------------------------------------------------- /src/services/ray/src/ray/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ndif-team/ndif/546316aac1f625cbeb2d93a0d2be47e856623452/src/services/ray/src/ray/__init__.py -------------------------------------------------------------------------------- /src/services/ray/src/ray/config/ray_config.yml: -------------------------------------------------------------------------------- 1 | proxy_location: Disabled 2 | http_options: 3 | host: 0.0.0.0 4 | port: 5005 5 | grpc_options: 6 | port: 9000 7 | grpc_servicer_functions: [] 8 | logging_config: 9 | encoding: TEXT 10 | log_level: INFO 11 | logs_dir: null 12 | enable_access_log: true 13 | applications: 14 | - name: Controller 15 | import_path: src.ray.deployments.controller:app 16 | args: 17 | ray_config_path: src/ray/config/ray_config.yml 18 | service_config_path: src/ray/config/service_config.yml 19 | ray_dashboard_url: http://localhost:8265 20 | database_url: mongodb://user:pass@localhost:27017 21 | api_url: http://localhost:80 -------------------------------------------------------------------------------- /src/services/ray/src/ray/config/service_config.yml: -------------------------------------------------------------------------------- 1 | default_model_import_path: src.ray.deployments.model:app 2 | request_import_path: src.ray.deployments.request:app 3 | request_num_replicas: 1 4 | models: 5 | # - model_key: 'nnsight.models.LanguageModel.LanguageModel:{"repo_id": "openai-community/gpt2"}' 6 | # num_replicas: 1 7 | # ray_actor_options: 8 | # resources: 9 | # cuda_memory_MB: 15000 10 | 11 | 12 
| - model_key: 'nnsight.models.LanguageModel.LanguageModel:{"repo_id": "meta-llama/Meta-Llama-3-8B"}' 13 | num_replicas: 1 14 | model_import_path: src.ray.deployments.distributed_model:app 15 | args: 16 | torch_distributed_port: 5003 17 | torch_distributed_world_size: 2 18 | torch_distributed_world_timeout_seconds: 40 19 | tensor_parallelism_size: 2 20 | ray_actor_options: 21 | num_gpus: 1 22 | 23 | 24 | # - model_key: 'nnsight.models.LanguageModel.LanguageModel:{"repo_id": "meta-llama/Meta-Llama-3-8B"}' 25 | # num_replicas: 1 26 | 27 | # ray_actor_options: 28 | # num_gpus: 1 -------------------------------------------------------------------------------- /src/services/ray/src/ray/deployments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ndif-team/ndif/546316aac1f625cbeb2d93a0d2be47e856623452/src/services/ray/src/ray/deployments/__init__.py -------------------------------------------------------------------------------- /src/services/ray/src/ray/deployments/base.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import os 4 | import sys 5 | import time 6 | import traceback 7 | import weakref 8 | from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError 9 | from functools import wraps 10 | from typing import Any, Dict 11 | 12 | import ray 13 | import socketio 14 | import torch 15 | import boto3 16 | from pydantic import BaseModel, ConfigDict 17 | from ray import serve 18 | from torch.amp import autocast 19 | from torch.cuda import max_memory_allocated, memory_allocated, reset_peak_memory_stats 20 | 21 | from nnsight.intervention.contexts import RemoteContext 22 | from nnsight.modeling.mixins import RemoteableMixin 23 | from nnsight.schema.request import StreamValueModel 24 | from nnsight.tracing.backends import Backend 25 | from nnsight.tracing.graph import Graph 26 | from nnsight.tracing.protocols import StopProtocol 27 | from nnsight.util import NNsightError 28 | 29 | from ...logging import load_logger 30 | from ...metrics import GPUMemMetric, ExecutionTimeMetric, RequestResponseSizeMetric 31 | from ...schema import ( 32 | RESULT, 33 | BackendRequestModel, 34 | BackendResponseModel, 35 | BackendResultModel, 36 | ) 37 | from ..util import set_cuda_env_var 38 | from . 
import protocols 39 | 40 | 41 | class ExtractionBackend(Backend): 42 | 43 | def __call__(self, graph: Graph) -> RESULT: 44 | 45 | try: 46 | 47 | graph.nodes[-1].execute() 48 | 49 | result = BackendResultModel.from_graph(graph) 50 | 51 | except StopProtocol.StopException: 52 | 53 | result = BackendResultModel.from_graph(graph) 54 | 55 | finally: 56 | 57 | graph.nodes.clear() 58 | graph.stack.clear() 59 | 60 | return result 61 | 62 | 63 | class BaseDeployment: 64 | 65 | def __init__( 66 | self, 67 | api_url: str, 68 | object_store_url: str, 69 | object_store_access_key: str, 70 | object_store_secret_key: str, 71 | ) -> None: 72 | 73 | super().__init__() 74 | 75 | self.api_url = api_url 76 | self.object_store_url = object_store_url 77 | self.object_store_access_key = object_store_access_key 78 | self.object_store_secret_key = object_store_secret_key 79 | 80 | # Initialize S3 client (either AWS S3 or compatible service like MinIO) 81 | self.object_store = boto3.client( 82 | 's3', 83 | endpoint_url=f"http://{self.object_store_url}", 84 | aws_access_key_id=self.object_store_access_key, 85 | aws_secret_access_key=self.object_store_secret_key, 86 | # Skip verification for local or custom S3 implementations 87 | verify=False, 88 | # Set to path style for compatibility with non-AWS S3 implementations 89 | config=boto3.session.Config(signature_version='s3v4', s3={'addressing_style': 'path'}) 90 | ) 91 | 92 | self.sio = socketio.SimpleClient(reconnection_attempts=10) 93 | 94 | self.logger = load_logger( 95 | service_name=str(self.__class__), logger_name="ray.serve" 96 | ) 97 | 98 | try: 99 | self.replica_context = serve.get_replica_context() 100 | except: 101 | self.replica_context = None 102 | 103 | 104 | class BaseDeploymentArgs(BaseModel): 105 | 106 | model_config = ConfigDict(arbitrary_types_allowed=True) 107 | 108 | api_url: str 109 | object_store_url: str 110 | object_store_access_key: str 111 | object_store_secret_key: str 112 | 113 | 114 | def threaded(method, size: int = 1): 115 | 116 | group = ThreadPoolExecutor(size) 117 | 118 | @wraps(method) 119 | def inner(*args, **kwargs): 120 | 121 | return group.submit(method, *args, **kwargs) 122 | 123 | return inner 124 | 125 | 126 | class BaseModelDeployment(BaseDeployment): 127 | 128 | def __init__( 129 | self, 130 | model_key: str, 131 | execution_timeout: float | None, 132 | device_map: str | None, 133 | dispatch: bool, 134 | dtype: str | torch.dtype, 135 | *args, 136 | extra_kwargs: Dict[str, Any] = {}, 137 | **kwargs, 138 | ) -> None: 139 | 140 | super().__init__(*args, **kwargs) 141 | 142 | if os.environ.get("CUDA_VISIBLE_DEVICES", "") == "": 143 | set_cuda_env_var() 144 | 145 | self.model_key = model_key 146 | self.execution_timeout = execution_timeout 147 | 148 | if isinstance(dtype, str): 149 | 150 | dtype = getattr(torch, dtype) 151 | 152 | torch.set_default_dtype(torch.bfloat16) 153 | 154 | self.model = RemoteableMixin.from_model_key( 155 | self.model_key, 156 | device_map=device_map, 157 | dispatch=dispatch, 158 | torch_dtype=dtype, 159 | **extra_kwargs, 160 | ) 161 | 162 | if dispatch: 163 | self.model._model.requires_grad_(False) 164 | 165 | torch.cuda.empty_cache() 166 | 167 | self.request: BackendRequestModel 168 | 169 | protocols.LogProtocol.set(lambda *args: self.log(*args)) 170 | 171 | RemoteContext.set(self.stream_send, self.stream_receive) 172 | 173 | def __call__(self, request: BackendRequestModel) -> None: 174 | """Executes the model service pipeline: 175 | 176 | 1.) Pre-processing 177 | 2.) Execution 178 | 3.) 
Post-processing 179 | 4.) Cleanup 180 | 181 | Args: 182 | request (BackendRequestModel): Request. 183 | """ 184 | 185 | self.request = weakref.proxy(request) 186 | 187 | try: 188 | 189 | result = None 190 | 191 | inputs = self.pre() 192 | 193 | with autocast(device_type="cuda", dtype=torch.get_default_dtype()): 194 | 195 | result = self.execute(inputs) 196 | 197 | if isinstance(result, Future): 198 | result = result.result(timeout=self.execution_timeout) 199 | 200 | self.post(result) 201 | 202 | except TimeoutError as e: 203 | 204 | exception = Exception( 205 | f"Job took longer than timeout: {self.execution_timeout} seconds" 206 | ) 207 | 208 | self.exception(exception) 209 | 210 | except Exception as e: 211 | 212 | self.exception(e) 213 | 214 | finally: 215 | 216 | del request 217 | del result 218 | 219 | self.cleanup() 220 | 221 | # Ray checks this method and restarts replica if it raises an exception 222 | def check_health(self): 223 | pass 224 | 225 | ### ABSTRACT METHODS ################################# 226 | 227 | def pre(self) -> Graph: 228 | """Logic to execute before execution.""" 229 | graph = self.request.deserialize(self.model) 230 | 231 | self.respond( 232 | status=BackendResponseModel.JobStatus.RUNNING, 233 | description="Your job has started running.", 234 | ) 235 | 236 | return graph 237 | 238 | def execute(self, graph: Graph) -> Any: 239 | """Execute request. 240 | 241 | Args: 242 | request (BackendRequestModel): Request. 243 | 244 | Returns: 245 | Any: Result. 246 | """ 247 | 248 | # For tracking peak GPU usage 249 | if torch.cuda.is_available(): 250 | reset_peak_memory_stats() 251 | model_memory = memory_allocated() 252 | 253 | execution_time = time.time() 254 | 255 | # Execute object. 256 | result = ExtractionBackend()(graph) 257 | 258 | execution_time = time.time() - execution_time 259 | 260 | # Compute GPU memory usage 261 | if torch.cuda.is_available(): 262 | gpu_mem = max_memory_allocated() - model_memory 263 | else: 264 | gpu_mem = 0 265 | 266 | return result, gpu_mem, execution_time 267 | 268 | def post(self, result: Any) -> None: 269 | """Logic to execute after execution with result from `.execute`. 270 | 271 | Args: 272 | request (BackendRequestModel): Request. 273 | result (Any): Result. 274 | """ 275 | 276 | saves = result[0] 277 | gpu_mem: int = result[1] 278 | execution_time_s: float = result[2] 279 | 280 | result = BackendResultModel( 281 | id=self.request.id, 282 | result=saves, 283 | ).save(self.object_store) 284 | 285 | self.respond( 286 | status=BackendResponseModel.JobStatus.COMPLETED, 287 | description="Your job has been completed.", 288 | ) 289 | 290 | RequestResponseSizeMetric.update(self.request, result.size) 291 | GPUMemMetric.update(self.request, gpu_mem) 292 | ExecutionTimeMetric.update(self.request, execution_time_s) 293 | 294 | def exception(self, exception: Exception) -> None: 295 | """Handles exceptions that occur during model execution. 296 | 297 | This method processes different types of exceptions and sends appropriate error responses 298 | back to the client. For NNsight-specific errors, it includes detailed traceback information. 299 | For other errors, it includes the full exception traceback and message. 300 | 301 | Args: 302 | exception (Exception): The exception that was raised during __call__. 
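        Example (illustrative shape of the data attached to an NNSIGHT_ERROR response;
        the keys mirror the dictionary built below):

            {"err_message": "...", "node_id": "...", "traceback": "..."}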
303 | """ 304 | if isinstance(exception, NNsightError): 305 | # Remove traceback limit to get full stack trace 306 | sys.tracebacklimit = None 307 | self.respond( 308 | status=BackendResponseModel.JobStatus.NNSIGHT_ERROR, 309 | description=f"An error has occured during the execution of the intervention graph.\n{exception.traceback_content}", 310 | data={ 311 | "err_message": exception.message, 312 | "node_id": exception.node_id, 313 | "traceback": exception.traceback_content, 314 | }, 315 | ) 316 | else: 317 | # For non-NNsight errors, include full traceback 318 | description = traceback.format_exc() 319 | self.respond( 320 | status=BackendResponseModel.JobStatus.ERROR, 321 | description=f"{description}\n{str(exception)}", 322 | ) 323 | 324 | # Special handling for CUDA device-side assertion errors 325 | if "device-side assert triggered" in str(exception): 326 | self.restart() 327 | 328 | def restart(self): 329 | """Restarts the Ray serve deployment in response to critical errors. 330 | 331 | This is typically called when encountering CUDA device-side assertion errors 332 | or other critical failures that require a fresh replica state. 333 | """ 334 | app_name = serve.get_replica_context().app_name 335 | serve.get_app_handle("Controller").restart.remote(app_name) 336 | 337 | def cleanup(self): 338 | """Performs cleanup operations after request processing. 339 | 340 | This method: 341 | 1. Disconnects from socketio if connected 342 | 2. Zeros out model gradients 343 | 3. Forces garbage collection 344 | 4. Clears CUDA cache 345 | 346 | This cleanup is important for preventing memory leaks and ensuring 347 | the replica is ready for the next request. 348 | """ 349 | if self.sio.connected: 350 | self.sio.disconnect() 351 | 352 | self.model._model.zero_grad() 353 | gc.collect() 354 | torch.cuda.empty_cache() 355 | 356 | def log(self, *data): 357 | """Logs data during model execution. 358 | 359 | This method is used to send log messages back to the client through 360 | the websocket connection. It joins all provided data into a single string 361 | and sends it as a LOG status response. 362 | 363 | Args: 364 | *data: Variable number of arguments to be converted to strings and logged. 365 | """ 366 | description = "".join([str(_data) for _data in data]) 367 | self.respond(status=BackendResponseModel.JobStatus.LOG, description=description) 368 | 369 | def stream_send(self, data: Any): 370 | """Sends streaming data back to the client. 371 | 372 | This method is used to send intermediate results or progress updates 373 | during model execution. It wraps the data in a STREAM status response. 374 | 375 | Args: 376 | data (Any): The data to stream back to the client. 377 | """ 378 | self.respond(status=BackendResponseModel.JobStatus.STREAM, data=data) 379 | 380 | def stream_receive(self, *args): 381 | """Receives streaming data from the client. 382 | 383 | This method establishes a websocket connection if needed and waits 384 | for data from the client. It has a 5-second timeout for receiving data. 385 | 386 | Returns: 387 | The deserialized data received from the client. 388 | """ 389 | self.stream_connect() 390 | return StreamValueModel.deserialize(self.sio.receive(5)[1], "json", True) 391 | 392 | def stream_connect(self): 393 | """Establishes a websocket connection if one doesn't exist. 394 | 395 | This method ensures that there is an active websocket connection 396 | before attempting to send or receive data. It: 397 | 1. Checks if a connection exists 398 | 2. 
If not, creates a new connection with appropriate parameters 399 | 3. Adds a small delay to ensure the connection is fully established 400 | 401 | The connection is established with: 402 | - WebSocket transport only (no polling fallback) 403 | - 10-second timeout for connection establishment 404 | - Job ID included in the connection URL for proper routing of receiving stream data from the user. 405 | """ 406 | if self.sio.client is None or not self.sio.connected: 407 | self.sio.connected = False 408 | self.sio.connect( 409 | f"{self.api_url}?job_id={self.request.id}", 410 | socketio_path="/ws/socket.io", 411 | transports=["websocket"], 412 | wait_timeout=10, 413 | ) 414 | # Wait for connection to be fully established 415 | time.sleep(0.1) # Small delay to ensure connection is ready 416 | 417 | def respond(self, **kwargs) -> None: 418 | """Sends a response back to the client. 419 | 420 | This method handles sending responses through either websocket 421 | or object store, depending on whether a session_id exists. 422 | 423 | If session_id exists: 424 | 1. Establishes websocket connection if needed 425 | 2. Sends response through websocket 426 | 427 | If no session_id: 428 | 1. Saves response to object store 429 | 430 | Args: 431 | **kwargs: Arguments to be passed to create_response, including: 432 | - status: The job status 433 | - description: Human-readable status description 434 | - data: Optional additional data 435 | """ 436 | if self.request.session_id is not None: 437 | self.stream_connect() 438 | 439 | self.request.create_response(**kwargs, logger=self.logger).respond( 440 | self.sio, self.object_store 441 | ) 442 | 443 | 444 | class BaseModelDeploymentArgs(BaseDeploymentArgs): 445 | 446 | model_key: str 447 | execution_timeout: float | None = None 448 | device_map: str | None = "auto" 449 | dispatch: bool = True 450 | dtype: str | torch.dtype = torch.bfloat16 451 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/deployments/controller.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | from typing import Dict 4 | 5 | from pydantic import BaseModel 6 | from ray import serve 7 | from ray.serve import Application 8 | 9 | from ..raystate import RayState 10 | 11 | 12 | @serve.deployment(ray_actor_options={"num_cpus": 1, "resources": {"head": 1}}) 13 | class ControllerDeployment: 14 | def __init__( 15 | self, 16 | ray_config_path: str, 17 | service_config_path: str, 18 | object_store_url: str, 19 | object_store_access_key: str, 20 | object_store_secret_key: str, 21 | api_url: str, 22 | ): 23 | self.ray_config_path = ray_config_path 24 | self.service_config_path = service_config_path 25 | self.object_store_url = object_store_url 26 | self.object_store_access_key = object_store_access_key 27 | self.object_store_secret_key = object_store_secret_key 28 | 29 | self.api_url = api_url 30 | 31 | self.state = RayState( 32 | self.ray_config_path, 33 | self.service_config_path, 34 | self.object_store_url, 35 | self.object_store_access_key, 36 | self.object_store_secret_key, 37 | self.api_url, 38 | ) 39 | 40 | self.state.redeploy() 41 | 42 | self.model_configurations = {} 43 | 44 | async def redeploy(self): 45 | """Redeploy serve configuration using service_config.yml""" 46 | 47 | self.state.redeploy() 48 | 49 | async def restart(self, name: str): 50 | 51 | self.state.name_to_application[name].runtime_env["env_vars"]["restart_hash"] = ( 52 | str(uuid.uuid4()) 53 | ) 54 | 55 | 
self.state.apply() 56 | 57 | async def set_model_configuration(self, name:str, configuration: Dict): 58 | 59 | self.model_configurations[name] = configuration 60 | 61 | async def get_model_configurations(self): 62 | 63 | return self.model_configurations 64 | 65 | 66 | class ControllerDeploymentArgs(BaseModel): 67 | 68 | ray_config_path: str = os.environ.get("RAY_CONFIG_PATH", None) 69 | service_config_path: str = os.environ.get("SERVICE_CONFIG_PATH", None) 70 | object_store_url: str = os.environ.get("OBJECT_STORE_URL", None) 71 | object_store_access_key: str = os.environ.get( 72 | "OBJECT_STORE_ACCESS_KEY", "minioadmin" 73 | ) 74 | object_store_secret_key: str = os.environ.get( 75 | "OBJECT_STORE_SECRET_KEY", "minioadmin" 76 | ) 77 | api_url: str = os.environ.get("API_URL", None) 78 | 79 | 80 | def app(args: ControllerDeploymentArgs) -> Application: 81 | return ControllerDeployment.bind(**args.model_dump()) 82 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/deployments/distributed_model.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from typing import Any, Dict 3 | 4 | import ray 5 | import torch 6 | import torch.distributed 7 | from ray import serve 8 | from ray.serve import Application 9 | 10 | from nnsight.tracing.graph import Graph 11 | 12 | from ...schema import ( 13 | RESULT, 14 | BackendRequestModel, 15 | BackendResponseModel, 16 | BackendResultModel, 17 | ) 18 | from ..distributed.parallel_dims import ParallelDims 19 | from ..distributed.tensor_parallelism import parallelize_model 20 | from ..distributed.util import load_hf_model_from_cache, patch_intervention_protocol 21 | from ..util import NNsightTimer 22 | from .base import BaseModelDeployment, BaseModelDeploymentArgs 23 | 24 | 25 | class _ModelDeployment(BaseModelDeployment): 26 | def __init__( 27 | self, 28 | torch_distributed_address: str, 29 | torch_distributed_port: int, 30 | torch_distributed_world_size: int, 31 | torch_distributed_world_rank: int, 32 | torch_distributed_world_timeout_seconds: int, 33 | data_parallelism_size: int, 34 | tensor_parallelism_size: int, 35 | pipeline_parallelism_size: int, 36 | *args, 37 | **kwargs, 38 | ): 39 | 40 | super().__init__( 41 | *args, 42 | extra_kwargs={"meta_buffers": False, "patch_llama_scan": False}, 43 | **kwargs, 44 | ) 45 | 46 | self.torch_distributed_address = torch_distributed_address 47 | self.torch_distributed_port = torch_distributed_port 48 | self.torch_distributed_world_size = torch_distributed_world_size 49 | self.torch_distributed_world_rank = torch_distributed_world_rank 50 | self.torch_distributed_world_timeout_seconds = ( 51 | torch_distributed_world_timeout_seconds 52 | ) 53 | self.data_parallelism_size = data_parallelism_size 54 | self.tensor_parallelism_size = tensor_parallelism_size 55 | self.pipeline_parallelism_size = pipeline_parallelism_size 56 | 57 | # Patches nnsight intervention protocol to handle DTensors. 
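        # (After the protocol patch below, rank 0 acts as the head replica: it derives
        # the tcp:// rendezvous address from its own node IP when none is provided,
        # spawns one ModelDeploymentShard actor per remaining world rank, and then
        # every rank joins the same NCCL process group inside init_distributed().)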
58 | patch_intervention_protocol() 59 | 60 | self.head = torch_distributed_world_rank == 0 61 | 62 | if self.head: 63 | 64 | print("Initializing distributed head...") 65 | 66 | if self.torch_distributed_address is None: 67 | 68 | ip_address = ray.get_runtime_context().worker.node_ip_address 69 | self.torch_distributed_address = ( 70 | f"tcp://{ip_address}:{self.torch_distributed_port}" 71 | ) 72 | 73 | print(f"=> Torch distributed address: {self.torch_distributed_address}") 74 | 75 | self.worker_actors = [] 76 | 77 | for worker_world_rank in range(1, self.torch_distributed_world_size): 78 | 79 | name = f"Shard-{worker_world_rank}:{self.replica_context.app_name}" 80 | 81 | distributed_model_deployment_args = DistributedModelDeploymentArgs( 82 | model_key=self.model_key, 83 | api_url=self.api_url, 84 | object_store_url=self.object_store_url, 85 | object_store_access_key=self.object_store_access_key, 86 | object_store_secret_key=self.object_store_secret_key, 87 | torch_distributed_address=self.torch_distributed_address, 88 | torch_distributed_world_size=self.torch_distributed_world_size, 89 | torch_distributed_world_rank=worker_world_rank, 90 | torch_distributed_world_timeout_seconds=self.torch_distributed_world_timeout_seconds, 91 | tensor_parallelism_size=self.tensor_parallelism_size, 92 | data_parallelism_size=self.data_parallelism_size, 93 | pipeline_parallelism_size=self.pipeline_parallelism_size, 94 | execution_timeout=self.execution_timeout, 95 | ) 96 | 97 | print(f"=> Creating distributed worker: {worker_world_rank}...") 98 | 99 | worker_actor = ModelDeploymentShard.options(name=name).remote( 100 | **distributed_model_deployment_args.model_dump() 101 | ) 102 | 103 | self.worker_actors.append(worker_actor) 104 | 105 | print(f"=> Created distributed worker: {worker_world_rank}.") 106 | 107 | print(f"Initialized distributed head.") 108 | 109 | self.init_distributed() 110 | 111 | self.timer = NNsightTimer(self.execution_timeout) 112 | 113 | def init_process_group(self): 114 | 115 | print("Initializing torch.distributed process group...") 116 | 117 | torch.distributed.init_process_group( 118 | "nccl", 119 | init_method=self.torch_distributed_address, 120 | timeout=timedelta(seconds=self.torch_distributed_world_timeout_seconds), 121 | world_size=self.torch_distributed_world_size, 122 | rank=self.torch_distributed_world_rank, 123 | device_id=self.device, 124 | ) 125 | 126 | print("Initialized torch.distributed process group.") 127 | 128 | def init_distributed(self): 129 | 130 | print( 131 | f"Initializing distributed worker: {self.torch_distributed_world_rank}. Ray address: {self.torch_distributed_address}..." 132 | ) 133 | 134 | self.device = torch.device("cuda:0") 135 | 136 | self.init_process_group() 137 | 138 | print(f"Initialized distributed worker: {self.torch_distributed_world_rank}.") 139 | 140 | parallel_dims = ParallelDims( 141 | dp=self.data_parallelism_size, 142 | tp=self.tensor_parallelism_size, 143 | pp=self.pipeline_parallelism_size, 144 | world_size=self.torch_distributed_world_size, 145 | enable_loss_parallel=False, 146 | ) 147 | 148 | world_mesh = parallel_dims.build_mesh(device_type=f"cuda") 149 | 150 | torch.set_default_device(self.device) 151 | 152 | print( 153 | f"Parallelizing distributed worker: {self.torch_distributed_world_rank}..." 
154 | ) 155 | 156 | parallelize_model( 157 | self.model._model, 158 | self.model._model.config._name_or_path, 159 | world_mesh["tp"], 160 | ) 161 | 162 | print(f"Parallelized distributed worker: {self.torch_distributed_world_rank}.") 163 | print( 164 | f"Loading model for distributed worker: {self.torch_distributed_world_rank}..." 165 | ) 166 | 167 | load_hf_model_from_cache( 168 | self.model._model, self.model._model.config._name_or_path 169 | ) 170 | 171 | # Handle buffers 172 | self.model._model = self.model._model.to(self.device) 173 | 174 | self.model._model.requires_grad_(False) 175 | 176 | print( 177 | f"Loaded model for distributed worker: {self.torch_distributed_world_rank}." 178 | ) 179 | 180 | self.model.dispatched = True 181 | 182 | torch.cuda.empty_cache() 183 | 184 | if self.head: 185 | 186 | config = { 187 | "config_json_string": self.model._model.config.to_json_string(), 188 | "repo_id": self.model._model.config._name_or_path, 189 | } 190 | 191 | serve.get_app_handle("Controller").set_model_configuration.remote( 192 | self.replica_context.app_name, config 193 | ) 194 | 195 | def execute(self, graph: Graph): 196 | 197 | with self.timer: 198 | 199 | return super().execute(graph) 200 | 201 | def pre(self) -> Graph: 202 | 203 | graph = self.request.deserialize(self.model) 204 | 205 | if self.head: 206 | 207 | self.respond( 208 | status=BackendResponseModel.JobStatus.RUNNING, 209 | description="Your job has started running.", 210 | ) 211 | 212 | for worker_deployment in self.worker_actors: 213 | 214 | worker_deployment.__call__.remote(self.request) 215 | 216 | torch.distributed.barrier() 217 | 218 | return graph 219 | 220 | def post(self, *args, **kwargs): 221 | 222 | if self.head: 223 | super().post(*args, **kwargs) 224 | 225 | def exception(self, *args, **kwargs): 226 | 227 | if self.head: 228 | super().exception(*args, **kwargs) 229 | 230 | def log(self, *args, **kwargs): 231 | 232 | if self.head: 233 | super().log(*args, **kwargs) 234 | 235 | def stream_send(self, *args, **kwargs): 236 | if self.head: 237 | super().stream_send(*args, **kwargs) 238 | 239 | def cleanup(self): 240 | 241 | torch.distributed.barrier() 242 | 243 | super().cleanup() 244 | 245 | 246 | @serve.deployment( 247 | ray_actor_options={"num_gpus": 1, "num_cpus": 2}, 248 | health_check_period_s=10000000000000000000000000000000, 249 | health_check_timeout_s=12000000000000000000000000000000, 250 | max_ongoing_requests=200, 251 | max_queued_requests=200, 252 | ) 253 | class ModelDeployment(_ModelDeployment): 254 | pass 255 | 256 | 257 | @ray.remote(num_cpus=2, num_gpus=1) 258 | class ModelDeploymentShard(_ModelDeployment): 259 | pass 260 | 261 | 262 | class DistributedModelDeploymentArgs(BaseModelDeploymentArgs): 263 | 264 | device_map: str | None = None 265 | dispatch: bool = False 266 | 267 | torch_distributed_address: str = None 268 | torch_distributed_port: int = None 269 | torch_distributed_world_rank: int = 0 270 | 271 | torch_distributed_world_size: int 272 | torch_distributed_world_timeout_seconds: int 273 | 274 | data_parallelism_size: int = 1 275 | tensor_parallelism_size: int = 1 276 | pipeline_parallelism_size: int = 1 277 | 278 | 279 | def app(args: DistributedModelDeploymentArgs) -> Application: 280 | return ModelDeployment.bind(**args.model_dump()) 281 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/deployments/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from ray 
import serve 4 | 5 | from nnsight.tracing.graph import Graph 6 | from ..util import set_cuda_env_var 7 | from .base import BaseModelDeployment, BaseModelDeploymentArgs, threaded 8 | 9 | 10 | class ThreadedModelDeployment(BaseModelDeployment): 11 | 12 | @threaded 13 | def execute(self, graph: Graph): 14 | return super().execute(graph) 15 | 16 | 17 | @serve.deployment( 18 | ray_actor_options={ 19 | "num_cpus": 2, 20 | }, 21 | max_ongoing_requests=200, max_queued_requests=200, 22 | health_check_period_s=10000000000000000000000000000000, 23 | health_check_timeout_s=12000000000000000000000000000000, 24 | ) 25 | class ModelDeployment(ThreadedModelDeployment): 26 | 27 | def __init__(self, *args, **kwargs): 28 | super().__init__(*args, **kwargs) 29 | 30 | config = { 31 | "config_json_string": self.model._model.config.to_json_string(), 32 | "repo_id": self.model._model.config._name_or_path, 33 | } 34 | 35 | serve.get_app_handle("Controller").set_model_configuration.remote(self.replica_context.app_name, config) 36 | 37 | 38 | def app(args: BaseModelDeploymentArgs) -> serve.Application: 39 | 40 | return ModelDeployment.bind(**args.model_dump()) 41 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/deployments/model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/deployments/model/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def cpu(): 5 | 6 | tensor = torch.randn((1000,1000,10), device="cpu") 7 | 8 | torch.save(tensor, "tensor.pt") 9 | 10 | def gpu(): 11 | tensor = torch.randn((1000,1000,10), device="cuda") 12 | 13 | torch.save(tensor, "tensor.pt") 14 | 15 | 16 | 17 | import time 18 | 19 | cpu() 20 | 21 | start = time.time() 22 | 23 | for i in range(100): 24 | cpu() 25 | 26 | end = time.time() 27 | 28 | print((end - start) / 100) 29 | 30 | start = time.time() 31 | 32 | for i in range(100): 33 | gpu() 34 | 35 | end = time.time() 36 | 37 | print((end - start) / 100) -------------------------------------------------------------------------------- /src/services/ray/src/ray/deployments/protocols.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from nnsight.schema.format.functions import update_function 4 | 5 | from nnsight.tracing.protocols import Protocol 6 | 7 | class LogProtocol(Protocol): 8 | 9 | @classmethod 10 | def set(cls, fn: Callable): 11 | 12 | update_function(print, fn) 13 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/deployments/request.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from functools import wraps 3 | import traceback 4 | 5 | import ray 6 | from ray import serve 7 | from ray.serve import Application 8 | from ray.serve.handle import DeploymentHandle 9 | 10 | try: 11 | from slugify import slugify 12 | except: 13 | pass 14 | 15 | from nnsight.schema.response import ResponseModel 16 | 17 | from ...schema import BackendRequestModel 18 | from .base import BaseDeployment, BaseDeploymentArgs 19 | 20 | 21 | @serve.deployment(max_ongoing_requests=200, max_queued_requests=200) 22 | class RequestDeployment(BaseDeployment): 23 | def __init__(self, *args, **kwargs) -> None: 24 | super().__init__(*args, 
**kwargs) 25 | 26 | def __call__(self, request: BackendRequestModel): 27 | 28 | if not self.sio.connected: 29 | self.sio.connect( 30 | self.api_url, 31 | socketio_path="/ws/socket.io", 32 | transports=["websocket"], 33 | wait_timeout=100000, 34 | ) 35 | 36 | try: 37 | 38 | model_key = f"Model:{slugify(request.model_key)}" 39 | 40 | app_handle = self.get_ray_app_handle(model_key) 41 | 42 | request.create_response( 43 | status=ResponseModel.JobStatus.APPROVED, 44 | description="Your job was approved and is waiting to be run.", 45 | logger=self.logger, 46 | ).respond(self.sio, self.object_store) 47 | 48 | app_handle.remote(request) 49 | 50 | except Exception as exception: 51 | 52 | description = traceback.format_exc() 53 | 54 | request.create_response( 55 | status=ResponseModel.JobStatus.ERROR, 56 | description=f"{description}\n{str(exception)}", 57 | logger=self.logger, 58 | ).respond(self.sio, self.object_store) 59 | 60 | def get_ray_app_handle(self, name: str) -> DeploymentHandle: 61 | 62 | return serve.get_app_handle(name) 63 | 64 | 65 | def app(args: BaseDeploymentArgs) -> Application: 66 | return RequestDeployment.bind(**args.model_dump()) 67 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ndif-team/ndif/546316aac1f625cbeb2d93a0d2be47e856623452/src/services/ray/src/ray/distributed/__init__.py -------------------------------------------------------------------------------- /src/services/ray/src/ray/distributed/parallel_dims.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import cached_property 3 | 4 | from torch.distributed.device_mesh import init_device_mesh 5 | 6 | 7 | @dataclass 8 | class ParallelDims: 9 | dp: int 10 | tp: int 11 | pp: int 12 | world_size: int 13 | enable_loss_parallel: bool 14 | 15 | def __post_init__(self): 16 | self._validate() 17 | 18 | def _validate(self): 19 | dp, tp, pp = self.dp, self.tp, self.pp 20 | if dp == -1: 21 | self.dp = dp = self.world_size // (tp * pp) 22 | assert dp >= 1, dp 23 | assert tp >= 1, tp 24 | assert pp >= 1, pp 25 | assert ( 26 | dp * tp * pp == self.world_size 27 | ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" 28 | 29 | def build_mesh(self, device_type): 30 | dims = [] 31 | names = [] 32 | for d, name in zip( 33 | [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True 34 | ): 35 | if d > 1: 36 | dims.append(d) 37 | names.append(name) 38 | names = tuple(names) 39 | return init_device_mesh(device_type, dims, mesh_dim_names=names) 40 | 41 | @property 42 | def dp_enabled(self): 43 | return self.dp > 1 44 | 45 | @property 46 | def tp_enabled(self): 47 | return self.tp > 1 48 | 49 | @property 50 | def pp_enabled(self): 51 | return self.pp > 1 52 | 53 | @property 54 | def loss_parallel_enabled(self): 55 | return self.tp > 1 and self.enable_loss_parallel 56 | 57 | @cached_property 58 | def model_parallel_size(self): 59 | return self.tp * self.pp 60 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/distributed/tensor_parallelism/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torch.distributed._tensor.api import DTensor 5 | from 
torch.distributed.tensor.parallel import ParallelStyle, parallelize_module as _parallelize_module 6 | 7 | from .plans import model_id_to_plans 8 | 9 | 10 | def ready_to_be_parallelized(module: torch.nn.Module): 11 | 12 | param = next(module.parameters()) 13 | 14 | return not isinstance(param, DTensor) and param.device.type != 'meta' 15 | 16 | 17 | def parallelize_on_state_dict_load(plan, module: torch.nn.Module, tp_mesh): 18 | 19 | def parallelize_hook(plan, module: torch.nn.Module, keys): 20 | 21 | if ready_to_be_parallelized(module): 22 | 23 | _parallelize_module(module, tp_mesh, plan) 24 | 25 | return module.register_load_state_dict_post_hook(partial(parallelize_hook, plan)) 26 | 27 | 28 | def parallelize_module(module: torch.nn.Module, module_path: str, plan, tp_mesh): 29 | 30 | module_path_components = module_path.split(".*", 1) 31 | 32 | module = module.get_submodule(module_path_components[0]) 33 | 34 | if len(module_path_components) == 1: 35 | 36 | if isinstance(plan, ParallelStyle): 37 | 38 | parallelize_on_state_dict_load(plan, module, tp_mesh) 39 | else: 40 | plan(module, tp_mesh) 41 | else: 42 | for _module in module: 43 | parallelize_module(_module, module_path_components[1], plan, tp_mesh) 44 | 45 | 46 | def parallelize_model( 47 | model: torch.nn.Module, model_name: str, tp_mesh 48 | ) -> torch.nn.Module: 49 | 50 | model_plans = model_id_to_plans[model_name] 51 | 52 | for module_path, plan in model_plans.items(): 53 | 54 | parallelize_module(model, module_path, plan, tp_mesh) 55 | 56 | return model 57 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/distributed/tensor_parallelism/plans/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama import model_plans 2 | 3 | model_id_to_plans = { 4 | "meta-llama/Meta-Llama-3-8B": model_plans, 5 | "meta-llama/Meta-Llama-3-70B": model_plans, 6 | "meta-llama/Meta-Llama-3.1-70B": model_plans, 7 | "meta-llama/Meta-Llama-3.1-8B": model_plans, 8 | "meta-llama/Meta-Llama-3.1-405B": model_plans, 9 | "meta-llama/Meta-Llama-3.1-405B-Instruct": model_plans, 10 | "meta-llama/Meta-Llama-3.1-70B-Instruct": model_plans, 11 | "meta-llama/Llama-3-8B": model_plans, 12 | "meta-llama/Llama-3-70B": model_plans, 13 | "meta-llama/Llama-3.1-70B": model_plans, 14 | "meta-llama/Llama-3.1-8B": model_plans, 15 | "meta-llama/Llama-3.1-405B": model_plans, 16 | "meta-llama/Llama-3.1-405B-Instruct": model_plans, 17 | "meta-llama/Llama-3.1-70B-Instruct": model_plans, 18 | } 19 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/distributed/tensor_parallelism/plans/llama.py: -------------------------------------------------------------------------------- 1 | from torch.distributed._tensor import Replicate, Shard 2 | from torch.distributed.tensor.parallel import ( 3 | ColwiseParallel, 4 | PrepareModuleInput, 5 | RowwiseParallel, 6 | SequenceParallel, 7 | ) 8 | 9 | 10 | def update_attention(module, mesh): 11 | 12 | module.num_heads = module.num_heads // mesh.size() 13 | module.num_key_value_heads = module.num_key_value_heads // mesh.size() 14 | 15 | 16 | model_plans = { 17 | # "model.norm": SequenceParallel(), 18 | # "model.embed_tokens": RowwiseParallel( 19 | # input_layouts=Replicate(), 20 | # output_layouts=Shard(1), 21 | # ), 22 | "lm_head": ColwiseParallel( 23 | output_layouts=Replicate(), 24 | use_local_output=True, 25 | ), 26 | # "model.layers.*input_layernorm": SequenceParallel(), 27 | # 
"model.layers.*post_attention_layernorm": SequenceParallel(), 28 | # "model.layers.*self_attn": PrepareModuleInput( 29 | # input_layouts=(Shard(1), None), 30 | # desired_input_layouts=(Replicate(), None), 31 | # ), 32 | # "model.layers.*mlp": PrepareModuleInput( 33 | # input_layouts=(Shard(1),), 34 | # desired_input_layouts=(Replicate(),), 35 | # ), 36 | "model.layers.*self_attn.q_proj": ColwiseParallel( 37 | output_layouts=Replicate(), 38 | use_local_output=True, 39 | ), 40 | "model.layers.*self_attn.k_proj": ColwiseParallel( 41 | output_layouts=Replicate(), 42 | use_local_output=True, 43 | ), 44 | "model.layers.*self_attn.v_proj": ColwiseParallel( 45 | output_layouts=Replicate(), 46 | use_local_output=True, 47 | ), 48 | "model.layers.*self_attn.o_proj": ColwiseParallel( 49 | output_layouts=Replicate(), 50 | use_local_output=True, 51 | ), 52 | "model.layers.*mlp.gate_proj": ColwiseParallel(use_local_output=False), 53 | "model.layers.*mlp.up_proj": ColwiseParallel(use_local_output=False), 54 | "model.layers.*mlp.down_proj": RowwiseParallel(), 55 | # "model.layers.*self_attn": update_attention, 56 | } 57 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/distributed/tensor_parallelism/test.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from datetime import timedelta 3 | from timeit import default_timer as timer 4 | from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union 5 | 6 | import torch 7 | import torch.distributed 8 | import torch.distributed.launch 9 | 10 | import nnsight 11 | 12 | from ..parallel_dims import ParallelDims 13 | from ..util import load_hf_model_from_cache 14 | from . import parallelize_model 15 | 16 | 17 | def main(local_rank: int, world_rank: int, world_size: int, model_id: str): 18 | 19 | device = torch.device(f"cuda:{local_rank}") 20 | 21 | nnsight_model = nnsight.LanguageModel(model_id) 22 | 23 | torch.distributed.init_process_group( 24 | "nccl", 25 | init_method="tcp://10.201.22.179:5003", 26 | timeout=timedelta(seconds=10), 27 | world_size=world_size, 28 | rank=world_rank, 29 | 30 | ) 31 | 32 | parallel_dims = ParallelDims( 33 | dp=1, 34 | tp=world_size, 35 | pp=1, 36 | world_size=world_size, 37 | enable_loss_parallel=False, 38 | ) 39 | world_mesh = parallel_dims.build_mesh(device_type=f"cuda") 40 | 41 | model = nnsight_model._model 42 | 43 | torch.set_default_device(device) 44 | 45 | model = parallelize_model(model, model_id, world_mesh["tp"]) 46 | 47 | load_hf_model_from_cache(model, model_id) 48 | 49 | nnsight_model._dispatched = True 50 | 51 | with nnsight_model.trace("hello", scan=False, validate=False): 52 | 53 | output = nnsight_model.model.layers[0].self_attn.q_proj.output.save() 54 | 55 | breakpoint() 56 | 57 | 58 | if __name__ == "__main__": 59 | 60 | import argparse 61 | 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument("local_rank", type=int) 64 | parser.add_argument("world_rank", type=int) 65 | parser.add_argument("world_size", type=int) 66 | parser.add_argument("--model_id", default="meta-llama/Meta-Llama-3-8B") 67 | 68 | main(**vars(parser.parse_args())) 69 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/distributed/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | from functools import wraps 3 | from typing import Any, NamedTuple 4 | 5 | import torch 6 | import os 
7 | 8 | from safetensors.torch import load_file 9 | from torch.distributed._tensor import DTensor, Replicate 10 | from tqdm import tqdm 11 | from transformers.utils.hub import cached_file 12 | 13 | from nnsight import util 14 | from nnsight.intervention.protocols import InterventionProtocol 15 | 16 | from accelerate import load_checkpoint_and_dispatch 17 | from accelerate.utils import modeling 18 | from accelerate.utils.imports import ( 19 | is_mlu_available, 20 | is_mps_available, 21 | is_musa_available, 22 | is_npu_available, 23 | is_peft_available, 24 | is_torch_xla_available, 25 | is_xpu_available, 26 | ) 27 | from accelerate.utils.modeling import check_device_same, clear_device_cache 28 | from accelerate.utils import modeling 29 | from safetensors.torch import load_file 30 | 31 | 32 | def load_hf_model_from_cache(model: torch.nn.Module, repo_id: str): 33 | 34 | model_index_filename = "model.safetensors.index.json" 35 | index_path = cached_file(repo_id, model_index_filename) 36 | 37 | with open(index_path, "r") as f: 38 | index = json.load(f) 39 | 40 | shard_paths = sorted(set(index["weight_map"].values())) 41 | 42 | pbar = tqdm(shard_paths, desc="Loading shards") 43 | 44 | for shard_file in pbar: 45 | # Get path to shard 46 | shard_path = cached_file(repo_id, shard_file) 47 | pbar.set_postfix({"Current shard": shard_file}) 48 | 49 | # Get path to shard 50 | state_dict = load_file(shard_path, device="cpu") 51 | 52 | torch.distributed.barrier() 53 | 54 | if os.environ.get('DELETE_MODEL_SHARDS', '0') == '1' and os.environ['CUDA_VISIBLE_DEVICES'][0] == '0': 55 | binary_path = os.path.realpath(shard_path) 56 | os.remove(binary_path) 57 | os.remove(shard_path) 58 | 59 | 60 | state_dict = {key: value.cuda() for key, value in state_dict.items()} 61 | 62 | model.load_state_dict(state_dict, strict=False, assign=True) 63 | 64 | torch.cuda.empty_cache() 65 | 66 | 67 | def patch_intervention_protocol() -> None: 68 | 69 | def wrap(intervene): 70 | 71 | @wraps(intervene) 72 | def intervene_wrapper(activations: Any, *args, **kwargs): 73 | 74 | placements = [] 75 | 76 | def check_for_dtensor(tensor: torch.Tensor): 77 | 78 | nonlocal placements 79 | 80 | if isinstance(tensor, DTensor): 81 | 82 | placements.append((tensor.placements, tensor.device_mesh)) 83 | 84 | return tensor.full_tensor() 85 | 86 | placements.append(None) 87 | 88 | return tensor 89 | 90 | activations = util.apply( 91 | activations, check_for_dtensor, torch.Tensor 92 | ) 93 | 94 | activations = intervene(activations, *args, **kwargs) 95 | 96 | def redistribute_tensors(tensor: torch.Tensor): 97 | 98 | nonlocal placements 99 | 100 | placement = placements.pop(0) 101 | 102 | if placement is None: 103 | 104 | return tensor 105 | 106 | placement, device_mesh = placement 107 | 108 | return DTensor.from_local( 109 | tensor, device_mesh=device_mesh, placements=[Replicate()] 110 | ).redistribute(device_mesh=device_mesh, placements=placement) 111 | 112 | if len(placements) > 0: 113 | 114 | activations = util.apply( 115 | activations, redistribute_tensors, torch.Tensor 116 | ) 117 | return activations 118 | 119 | return intervene_wrapper 120 | 121 | InterventionProtocol.intervene = wrap(InterventionProtocol.intervene) 122 | 123 | 124 | def to_full_tensor(data: Any) -> Any: 125 | 126 | return util.apply(data, lambda x: x.full_tensor(), DTensor) 127 | 128 | 129 | def set_module_tensor_to_device( 130 | module: torch.nn.Module, 131 | tensor_name: str, 132 | device, 133 | value=None, 134 | dtype=None, 135 | fp16_statistics=None, 136 | 
tied_params_map=None, 137 | ): 138 | """ 139 | A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing 140 | `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). 141 | 142 | Args: 143 | module (`torch.nn.Module`): 144 | The module in which the tensor we want to move lives. 145 | tensor_name (`str`): 146 | The full name of the parameter/buffer. 147 | device (`int`, `str` or `torch.device`): 148 | The device on which to set the tensor. 149 | value (`torch.Tensor`, *optional*): 150 | The value of the tensor (useful when going from the meta device to any other device). 151 | dtype (`torch.dtype`, *optional*): 152 | If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to 153 | the dtype of the existing parameter in the model. 154 | fp16_statistics (`torch.HalfTensor`, *optional*): 155 | The list of fp16 statistics to set on the module, used for 8 bit model serialization. 156 | tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`): 157 | A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given 158 | execution device, this parameter is useful to reuse the first available pointer of a shared weight on the 159 | device for all others, instead of duplicating memory. 160 | """ 161 | # Recurse if needed 162 | if "." in tensor_name: 163 | splits = tensor_name.split(".") 164 | for split in splits[:-1]: 165 | new_module = getattr(module, split) 166 | if new_module is None: 167 | raise ValueError(f"{module} has no attribute {split}.") 168 | module = new_module 169 | tensor_name = splits[-1] 170 | 171 | if ( 172 | tensor_name not in module._parameters 173 | and tensor_name not in module._buffers 174 | ): 175 | raise ValueError( 176 | f"{module} does not have a parameter or a buffer named {tensor_name}." 177 | ) 178 | is_buffer = tensor_name in module._buffers 179 | old_value = getattr(module, tensor_name) 180 | 181 | # Treat the case where old_value (or a custom `value`, typically offloaded to RAM/disk) belongs to a tied group, and one of the weight 182 | # in the tied group has already been dispatched to the device, by avoiding reallocating memory on the device and just copying the pointer. 183 | if ( 184 | value is not None 185 | and tied_params_map is not None 186 | and value.data_ptr() in tied_params_map 187 | and device in tied_params_map[value.data_ptr()] 188 | ): 189 | module._parameters[tensor_name] = tied_params_map[value.data_ptr()][ 190 | device 191 | ] 192 | return 193 | elif ( 194 | tied_params_map is not None 195 | and old_value.data_ptr() in tied_params_map 196 | and device in tied_params_map[old_value.data_ptr()] 197 | ): 198 | module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][ 199 | device 200 | ] 201 | return 202 | 203 | if ( 204 | old_value.device == torch.device("meta") 205 | and device not in ["meta", torch.device("meta")] 206 | and value is None 207 | ): 208 | raise ValueError( 209 | f"{tensor_name} is on the meta device, we need a `value` to put in on {device}." 210 | ) 211 | 212 | param = ( 213 | module._parameters[tensor_name] 214 | if tensor_name in module._parameters 215 | else None 216 | ) 217 | param_cls = type(param) 218 | 219 | if value is not None: 220 | # We can expect mismatches when using bnb 4bit since Params4bit will reshape and pack the weights. 
221 | # In other cases, we want to make sure we're not loading checkpoints that do not match the config. 222 | if ( 223 | old_value.shape != value.shape 224 | and param_cls.__name__ != "Params4bit" 225 | ): 226 | raise ValueError( 227 | f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this looks incorrect.' 228 | ) 229 | 230 | if dtype is None: 231 | # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model 232 | value = value.to(old_value.dtype) 233 | elif not str(value.dtype).startswith( 234 | ("torch.uint", "torch.int", "torch.bool") 235 | ): 236 | value = value.to(dtype) 237 | 238 | device_quantization = None 239 | with torch.no_grad(): 240 | # leave it on cpu first before moving them to cuda 241 | # # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0 242 | if ( 243 | param is not None 244 | and param.device.type != "cuda" 245 | and torch.device(device).type == "cuda" 246 | and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"] 247 | ): 248 | device_quantization = device 249 | device = "cpu" 250 | # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). 251 | if isinstance(device, int): 252 | if is_npu_available(): 253 | device = f"npu:{device}" 254 | elif is_mlu_available(): 255 | device = f"mlu:{device}" 256 | elif is_musa_available(): 257 | device = f"musa:{device}" 258 | elif is_xpu_available(): 259 | device = f"xpu:{device}" 260 | if "xpu" in str(device) and not is_xpu_available(): 261 | raise ValueError( 262 | f'{device} is not available, you should use device="cpu" instead' 263 | ) 264 | if value is None: 265 | new_value = old_value.to(device) 266 | if dtype is not None and device in ["meta", torch.device("meta")]: 267 | if not str(old_value.dtype).startswith( 268 | ("torch.uint", "torch.int", "torch.bool") 269 | ): 270 | new_value = new_value.to(dtype) 271 | 272 | if not is_buffer: 273 | module._parameters[tensor_name] = param_cls( 274 | new_value, requires_grad=old_value.requires_grad 275 | ) 276 | elif isinstance(value, torch.Tensor): 277 | new_value = value.to(device) 278 | else: 279 | new_value = torch.tensor(value, device=device) 280 | if device_quantization is not None: 281 | device = device_quantization 282 | if is_buffer: 283 | module._buffers[tensor_name] = new_value 284 | elif value is not None or not check_device_same( 285 | torch.device(device), module._parameters[tensor_name].device 286 | ): 287 | param_cls = type(module._parameters[tensor_name]) 288 | kwargs = module._parameters[tensor_name].__dict__ 289 | if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]: 290 | if ( 291 | param_cls.__name__ == "Int8Params" 292 | and new_value.dtype == torch.float32 293 | ): 294 | # downcast to fp16 if any - needed for 8bit serialization 295 | new_value = new_value.to(torch.float16) 296 | # quantize module that are going to stay on the cpu so that we offload quantized weights 297 | if device == "cpu" and param_cls.__name__ == "Int8Params": 298 | new_value = ( 299 | param_cls( 300 | new_value, 301 | requires_grad=old_value.requires_grad, 302 | **kwargs, 303 | ) 304 | .to(0) 305 | .to("cpu") 306 | ) 307 | new_value.CB = new_value.CB.to("cpu") 308 | new_value.SCB = new_value.SCB.to("cpu") 309 | else: 310 | new_value = param_cls( 311 | new_value, 312 | requires_grad=old_value.requires_grad, 313 | **kwargs, 314 | ).to(device) 315 | elif 
param_cls.__name__ in ["QTensor", "QBitsTensor"]: 316 | new_value = torch.nn.Parameter( 317 | new_value, requires_grad=old_value.requires_grad 318 | ).to(device) 319 | else: 320 | new_value = param_cls( 321 | new_value, requires_grad=old_value.requires_grad 322 | ).to(device) 323 | 324 | module._parameters[tensor_name] = new_value 325 | if fp16_statistics is not None: 326 | module._parameters[tensor_name].SCB = fp16_statistics.to(device) 327 | del fp16_statistics 328 | # as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight 329 | if ( 330 | module.__class__.__name__ == "Linear8bitLt" 331 | and getattr(module.weight, "SCB", None) is None 332 | and str(module.weight.device) != "meta" 333 | ): 334 | # quantize only if necessary 335 | device_index = ( 336 | torch.device(device).index 337 | if torch.device(device).type == "cuda" 338 | else None 339 | ) 340 | if ( 341 | not getattr(module.weight, "SCB", None) 342 | and device_index is not None 343 | ): 344 | if ( 345 | module.bias is not None 346 | and module.bias.device.type != "meta" 347 | ): 348 | # if a bias exists, we need to wait until the bias is set on the correct device 349 | module = module.cuda(device_index) 350 | elif module.bias is None: 351 | # if no bias exists, we can quantize right away 352 | module = module.cuda(device_index) 353 | elif ( 354 | module.__class__.__name__ == "Linear4bit" 355 | and getattr(module.weight, "quant_state", None) is None 356 | and str(module.weight.device) != "meta" 357 | ): 358 | # quantize only if necessary 359 | device_index = ( 360 | torch.device(device).index 361 | if torch.device(device).type == "cuda" 362 | else None 363 | ) 364 | if ( 365 | not getattr(module.weight, "quant_state", None) 366 | and device_index is not None 367 | ): 368 | module.weight = module.weight.cuda(device_index) 369 | # clean pre and post foward hook 370 | if device != "cpu": 371 | clear_device_cache() 372 | 373 | # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in 374 | # order to avoid duplicating memory, see above. 
375 | if ( 376 | tied_params_map is not None 377 | and old_value.data_ptr() in tied_params_map 378 | and device not in tied_params_map[old_value.data_ptr()] 379 | ): 380 | tied_params_map[old_value.data_ptr()][device] = new_value 381 | elif ( 382 | value is not None 383 | and tied_params_map is not None 384 | and value.data_ptr() in tied_params_map 385 | and device not in tied_params_map[value.data_ptr()] 386 | ): 387 | tied_params_map[value.data_ptr()][device] = new_value 388 | 389 | 390 | #### PATCH ####################################### 391 | 392 | for hook in module._load_state_dict_post_hooks.values(): 393 | 394 | hook(module, None) 395 | 396 | 397 | def patch_accelerate(): 398 | 399 | modeling.set_module_tensor_to_device = set_module_tensor_to_device -------------------------------------------------------------------------------- /src/services/ray/src/ray/raystate.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | try: 4 | from slugify import slugify 5 | except: 6 | pass 7 | import ray 8 | import yaml 9 | from pydantic import BaseModel 10 | from ray.dashboard.modules.serve.sdk import ServeSubmissionClient 11 | from ray.serve.schema import ( 12 | DeploymentSchema, 13 | RayActorOptionsSchema, 14 | ServeApplicationSchema, 15 | ServeDeploySchema, 16 | ) 17 | 18 | from .deployments.base import BaseDeploymentArgs, BaseModelDeploymentArgs 19 | 20 | 21 | class ServiceConfigurationSchema(BaseModel): 22 | class ModelConfigurationSchema(BaseModel): 23 | 24 | model_import_path: str = None 25 | 26 | ray_actor_options: Dict[str, Any] = {} 27 | args: Dict[str, Any] = {} 28 | 29 | model_key: str 30 | num_replicas: int 31 | 32 | default_model_import_path: str 33 | request_import_path: str 34 | request_num_replicas: int 35 | 36 | models: List[ModelConfigurationSchema] 37 | 38 | 39 | class RayState: 40 | 41 | def __init__( 42 | self, 43 | ray_config_path: str, 44 | service_config_path: str, 45 | object_store_url: str, 46 | object_store_access_key: str, 47 | object_store_secret_key: str, 48 | api_url: str, 49 | ) -> None: 50 | 51 | self.ray_config_path = ray_config_path 52 | self.service_config_path = service_config_path 53 | self.object_store_url = object_store_url 54 | self.object_store_access_key = object_store_access_key 55 | self.object_store_secret_key = object_store_secret_key 56 | self.api_url = api_url 57 | 58 | self.runtime_context = ray.get_runtime_context() 59 | self.ray_dashboard_url = ( 60 | f"http://{self.runtime_context.worker.node.address_info['webui_url']}" 61 | ) 62 | 63 | self.name_to_application: Dict[str, ServeApplicationSchema] = {} 64 | 65 | def load_from_disk(self): 66 | 67 | with open(self.ray_config_path, "r") as file: 68 | self.ray_config = ServeDeploySchema(**yaml.safe_load(file)) 69 | 70 | with open(self.service_config_path, "r") as file: 71 | self.service_config = ServiceConfigurationSchema(**yaml.safe_load(file)) 72 | 73 | def redeploy(self): 74 | 75 | self.load_from_disk() 76 | 77 | self.add_request_app() 78 | 79 | for model_config in self.service_config.models: 80 | self.add_model_app(model_config) 81 | 82 | self.apply() 83 | 84 | def apply(self) -> None: 85 | 86 | ServeSubmissionClient(self.ray_dashboard_url).deploy_applications( 87 | self.ray_config.dict(exclude_unset=True), 88 | ) 89 | 90 | def add(self, application: ServeApplicationSchema): 91 | 92 | self.ray_config.applications.append(application) 93 | self.name_to_application[application.name] = application 94 | 95 | def 
add_request_app(self) -> None: 96 | application = ServeApplicationSchema( 97 | name="Request", 98 | import_path=self.service_config.request_import_path, 99 | route_prefix="/request", 100 | deployments=[ 101 | DeploymentSchema( 102 | name="RequestDeployment", 103 | num_replicas=self.service_config.request_num_replicas, 104 | ray_actor_options=RayActorOptionsSchema( 105 | num_cpus=1, resources={"head": 1} 106 | ), 107 | ) 108 | ], 109 | args=BaseDeploymentArgs( 110 | api_url=self.api_url, 111 | object_store_url=self.object_store_url, 112 | object_store_access_key=self.object_store_access_key, 113 | object_store_secret_key=self.object_store_secret_key, 114 | ).model_dump(), 115 | ) 116 | 117 | self.add(application) 118 | 119 | def add_model_app( 120 | self, model_config: ServiceConfigurationSchema.ModelConfigurationSchema 121 | ) -> None: 122 | 123 | model_key = slugify(model_config.model_key) 124 | 125 | model_config.args["model_key"] = model_config.model_key 126 | model_config.args["api_url"] = self.api_url 127 | model_config.args["object_store_url"] = self.object_store_url 128 | model_config.args["object_store_access_key"] = self.object_store_access_key 129 | model_config.args["object_store_secret_key"] = self.object_store_secret_key 130 | 131 | application = ServeApplicationSchema( 132 | name=f"Model:{model_key}", 133 | import_path=model_config.model_import_path 134 | or self.service_config.default_model_import_path, 135 | route_prefix=f"/Model:{model_key}", 136 | deployments=[ 137 | DeploymentSchema( 138 | name="ModelDeployment", 139 | num_replicas=model_config.num_replicas, 140 | ray_actor_options=model_config.ray_actor_options, 141 | ) 142 | ], 143 | args=model_config.args, 144 | runtime_env={ 145 | "env_vars": {"restart_hash": "", 146 | # For distributed model timeout handling 147 | "TORCH_NCCL_ASYNC_ERROR_HANDLING": "0"} 148 | }, 149 | ) 150 | 151 | self.add(application) 152 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/resources.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from .util import get_total_cudamemory_MBs 4 | 5 | 6 | def main(head: bool, name: str = None): 7 | 8 | resources = {} 9 | 10 | if head: 11 | 12 | resources["head"] = 10 13 | 14 | resources["cuda_memory_MB"] = get_total_cudamemory_MBs() 15 | 16 | if name is not None: 17 | 18 | resources[name] = 10 19 | 20 | print(json.dumps(resources)) 21 | 22 | 23 | if __name__ == "__main__": 24 | import argparse 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--head", action="store_true") 28 | parser.add_argument("--name", default=None) 29 | main(**vars(parser.parse_args())) 30 | -------------------------------------------------------------------------------- /src/services/ray/src/ray/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from concurrent.futures import TimeoutError 4 | from contextlib import AbstractContextManager 5 | from functools import wraps 6 | from typing import Callable 7 | 8 | import torch 9 | from torch.overrides import TorchFunctionMode 10 | 11 | from nnsight.util import Patch, Patcher 12 | from nnsight.schema.format import functions 13 | from nnsight.tracing.graph import Graph, Node 14 | 15 | def get_total_cudamemory_MBs(return_ids=False) -> int: 16 | 17 | cudamemory = 0 18 | 19 | ids = [] 20 | 21 | for device in range(torch.cuda.device_count()): 22 | try: 23 | cudamemory += 
torch.cuda.mem_get_info(device)[1] * 1e-6 24 | if return_ids: 25 | ids.append(device) 26 | except: 27 | pass 28 | 29 | if return_ids: 30 | 31 | return int(cudamemory), ids 32 | 33 | return int(cudamemory) 34 | 35 | 36 | def set_cuda_env_var(ids=None): 37 | 38 | del os.environ["CUDA_VISIBLE_DEVICES"] 39 | 40 | if ids == None: 41 | 42 | _, ids = get_total_cudamemory_MBs(return_ids=True) 43 | 44 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(x) for x in ids]) 45 | 46 | 47 | def update_nnsight_print_function(new_function): 48 | 49 | new_function.__name__ = functions.get_function_name(print) 50 | 51 | functions.FUNCTIONS_WHITELIST[functions.get_function_name(print)] = new_function 52 | 53 | 54 | class NNsightTimer(AbstractContextManager): 55 | 56 | class FunctionMode(TorchFunctionMode): 57 | 58 | def __init__(self, timer: "NNsightTimer"): 59 | 60 | self.timer = timer 61 | 62 | super().__init__() 63 | 64 | def __torch_function__(self, func, types, args=(), kwargs=None): 65 | 66 | self.timer.check() 67 | 68 | if kwargs is None: 69 | kwargs = {} 70 | 71 | return func(*args, **kwargs) 72 | 73 | def __init__(self, timeout: float): 74 | 75 | self.timeout = timeout 76 | self.start: float = None 77 | 78 | self.patcher = Patcher( 79 | [ 80 | Patch(Node, self.wrap(Node.execute), "execute"), 81 | Patch(Graph, self.wrap(Graph.execute), "execute"), 82 | ] 83 | ) 84 | 85 | self.fn_mode = NNsightTimer.FunctionMode(self) 86 | 87 | def __enter__(self): 88 | 89 | if self.timeout is not None: 90 | 91 | self.reset() 92 | 93 | self.patcher.__enter__() 94 | self.fn_mode.__enter__() 95 | 96 | return self 97 | 98 | def __exit__(self, exc_type, exc_value, traceback): 99 | 100 | if self.timeout is not None: 101 | 102 | self.patcher.__exit__(None, None, None) 103 | self.fn_mode.__exit__(None, None, None) 104 | 105 | if isinstance(exc_value, Exception): 106 | raise exc_value 107 | 108 | def reset(self): 109 | 110 | self.start = time.time() 111 | 112 | def wrap(self, fn: Callable): 113 | 114 | @wraps(fn) 115 | def inner(*args, **kwargs): 116 | 117 | self.check() 118 | 119 | return fn(*args, **kwargs) 120 | 121 | return inner 122 | 123 | def check(self): 124 | 125 | if self.start and time.time() - self.start > self.timeout: 126 | 127 | self.start = 0 128 | 129 | raise Exception( 130 | f"Job took longer than timeout: {self.timeout} seconds" 131 | ) -------------------------------------------------------------------------------- /src/services/ray/src/schema: -------------------------------------------------------------------------------- 1 | ../../../common/schema/ -------------------------------------------------------------------------------- /src/services/ray/start-worker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$resource_name" ]; then 4 | resources=`python -m src.ray.resources` 5 | else 6 | resources=`python -m src.ray.resources --name $resource_name` 7 | fi 8 | 9 | ray start --resources "$resources" --address $RAY_ADDRESS 10 | 11 | tail -f /dev/null 12 | -------------------------------------------------------------------------------- /src/services/ray/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | resources=`python -m src.ray.resources --head` 4 | 5 | # Start Ray with environment variables from env_file 6 | ray start --head \ 7 | --resources="$resources" \ 8 | --port=$RAY_HEAD_INTERNAL_PORT \ 9 | --object-manager-port=$OBJECT_MANAGER_PORT \ 10 | --include-dashboard=true \ 
11 | --dashboard-host=$RAY_DASHBOARD_HOST \ 12 | --dashboard-port=$RAY_DASHBOARD_INTERNAL_PORT \ 13 | --dashboard-grpc-port=$RAY_DASHBOARD_GRPC_PORT \ 14 | --metrics-export-port=$RAY_SERVE_INTERNAL_PORT 15 | 16 | serve deploy src/ray/config/ray_config.yml 17 | 18 | tail -f /dev/null 19 | -------------------------------------------------------------------------------- /telemetry/grafana/dashboards/telemetry.json: -------------------------------------------------------------------------------- 1 | { 2 | "annotations": { 3 | "list": [ 4 | { 5 | "builtIn": 1, 6 | "datasource": { 7 | "type": "grafana", 8 | "uid": "-- Grafana --" 9 | }, 10 | "enable": true, 11 | "hide": true, 12 | "iconColor": "rgba(0, 211, 255, 1)", 13 | "name": "Annotations & Alerts", 14 | "type": "dashboard" 15 | } 16 | ] 17 | }, 18 | "editable": true, 19 | "fiscalYearStartMonth": 0, 20 | "graphTooltip": 0, 21 | "id": 3, 22 | "links": [], 23 | "panels": [ 24 | { 25 | "datasource": { 26 | "type": "influxdb", 27 | "uid": "P951FEA4DE68E13C5" 28 | }, 29 | "fieldConfig": { 30 | "defaults": { 31 | "color": { 32 | "mode": "thresholds" 33 | }, 34 | "custom": { 35 | "align": "left", 36 | "cellOptions": { 37 | "type": "auto" 38 | }, 39 | "inspect": false 40 | }, 41 | "mappings": [], 42 | "thresholds": { 43 | "mode": "absolute", 44 | "steps": [ 45 | { 46 | "color": "green", 47 | "value": null 48 | }, 49 | { 50 | "color": "red", 51 | "value": 80 52 | } 53 | ] 54 | } 55 | }, 56 | "overrides": [] 57 | }, 58 | "gridPos": { 59 | "h": 8, 60 | "w": 12, 61 | "x": 0, 62 | "y": 0 63 | }, 64 | "id": 1, 65 | "options": { 66 | "cellHeight": "sm", 67 | "footer": { 68 | "countRows": false, 69 | "fields": "", 70 | "reducer": [ 71 | "sum" 72 | ], 73 | "show": false 74 | }, 75 | "frameIndex": 1, 76 | "showHeader": true 77 | }, 78 | "pluginVersion": "11.5.1", 79 | "targets": [ 80 | { 81 | "datasource": { 82 | "type": "influxdb", 83 | "uid": "P951FEA4DE68E13C5" 84 | }, 85 | "query": "from(bucket: \"data\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r[\"_measurement\"] == \"request_status\")", 86 | "refId": "A" 87 | } 88 | ], 89 | "title": "Request Status", 90 | "transformations": [ 91 | { 92 | "id": "labelsToFields", 93 | "options": {} 94 | }, 95 | { 96 | "id": "merge", 97 | "options": {} 98 | }, 99 | { 100 | "id": "sortBy", 101 | "options": { 102 | "fields": {}, 103 | "sort": [ 104 | { 105 | "desc": true, 106 | "field": "Time" 107 | } 108 | ] 109 | } 110 | } 111 | ], 112 | "type": "table" 113 | }, 114 | { 115 | "datasource": { 116 | "type": "influxdb", 117 | "uid": "P951FEA4DE68E13C5" 118 | }, 119 | "fieldConfig": { 120 | "defaults": { 121 | "color": { 122 | "mode": "thresholds" 123 | }, 124 | "custom": { 125 | "align": "left", 126 | "cellOptions": { 127 | "type": "auto" 128 | }, 129 | "inspect": false 130 | }, 131 | "mappings": [], 132 | "thresholds": { 133 | "mode": "absolute", 134 | "steps": [ 135 | { 136 | "color": "green", 137 | "value": null 138 | }, 139 | { 140 | "color": "red", 141 | "value": 80 142 | } 143 | ] 144 | } 145 | }, 146 | "overrides": [] 147 | }, 148 | "gridPos": { 149 | "h": 8, 150 | "w": 12, 151 | "x": 12, 152 | "y": 0 153 | }, 154 | "id": 2, 155 | "options": { 156 | "cellHeight": "sm", 157 | "footer": { 158 | "countRows": false, 159 | "fields": "", 160 | "reducer": [ 161 | "sum" 162 | ], 163 | "show": false 164 | }, 165 | "showHeader": true 166 | }, 167 | "pluginVersion": "11.5.1", 168 | "targets": [ 169 | { 170 | "query": "from(bucket: \"data\")\n |> range(start: v.timeRangeStart, stop: 
v.timeRangeStop)\n |> filter(fn: (r) => r[\"_measurement\"] == \"gpu_mem\")", 171 | "refId": "A" 172 | } 173 | ], 174 | "title": "GPU memory", 175 | "transformations": [ 176 | { 177 | "id": "labelsToFields", 178 | "options": {} 179 | }, 180 | { 181 | "id": "merge", 182 | "options": {} 183 | } 184 | ], 185 | "type": "table" 186 | }, 187 | { 188 | "datasource": { 189 | "type": "influxdb", 190 | "uid": "P951FEA4DE68E13C5" 191 | }, 192 | "fieldConfig": { 193 | "defaults": { 194 | "color": { 195 | "mode": "thresholds" 196 | }, 197 | "custom": { 198 | "align": "left", 199 | "cellOptions": { 200 | "type": "auto" 201 | }, 202 | "inspect": false 203 | }, 204 | "mappings": [], 205 | "thresholds": { 206 | "mode": "absolute", 207 | "steps": [ 208 | { 209 | "color": "green", 210 | "value": null 211 | }, 212 | { 213 | "color": "red", 214 | "value": 80 215 | } 216 | ] 217 | } 218 | }, 219 | "overrides": [] 220 | }, 221 | "gridPos": { 222 | "h": 8, 223 | "w": 12, 224 | "x": 0, 225 | "y": 8 226 | }, 227 | "id": 4, 228 | "options": { 229 | "cellHeight": "sm", 230 | "footer": { 231 | "countRows": false, 232 | "fields": "", 233 | "reducer": [ 234 | "sum" 235 | ], 236 | "show": false 237 | }, 238 | "showHeader": true 239 | }, 240 | "pluginVersion": "11.5.1", 241 | "targets": [ 242 | { 243 | "query": "from(bucket: \"data\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r[\"_measurement\"] == \"network_data\")", 244 | "refId": "A" 245 | } 246 | ], 247 | "title": "Network Data", 248 | "transformations": [ 249 | { 250 | "id": "labelsToFields", 251 | "options": {} 252 | }, 253 | { 254 | "id": "merge", 255 | "options": {} 256 | } 257 | ], 258 | "type": "table" 259 | }, 260 | { 261 | "datasource": { 262 | "type": "influxdb", 263 | "uid": "P951FEA4DE68E13C5" 264 | }, 265 | "fieldConfig": { 266 | "defaults": { 267 | "color": { 268 | "mode": "thresholds" 269 | }, 270 | "custom": { 271 | "align": "left", 272 | "cellOptions": { 273 | "type": "auto" 274 | }, 275 | "inspect": false 276 | }, 277 | "mappings": [], 278 | "thresholds": { 279 | "mode": "absolute", 280 | "steps": [ 281 | { 282 | "color": "green", 283 | "value": null 284 | }, 285 | { 286 | "color": "red", 287 | "value": 80 288 | } 289 | ] 290 | } 291 | }, 292 | "overrides": [] 293 | }, 294 | "gridPos": { 295 | "h": 8, 296 | "w": 12, 297 | "x": 12, 298 | "y": 8 299 | }, 300 | "id": 3, 301 | "options": { 302 | "cellHeight": "sm", 303 | "footer": { 304 | "countRows": false, 305 | "fields": "", 306 | "reducer": [ 307 | "sum" 308 | ], 309 | "show": false 310 | }, 311 | "frameIndex": 0, 312 | "showHeader": true 313 | }, 314 | "pluginVersion": "11.5.1", 315 | "targets": [ 316 | { 317 | "query": "from(bucket: \"data\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r[\"_measurement\"] == \"stage_latency\")", 318 | "refId": "A" 319 | } 320 | ], 321 | "title": "Stage Latency", 322 | "transformations": [ 323 | { 324 | "id": "labelsToFields", 325 | "options": {} 326 | }, 327 | { 328 | "id": "merge", 329 | "options": {} 330 | } 331 | ], 332 | "type": "table" 333 | } 334 | ], 335 | "preload": false, 336 | "refresh": "", 337 | "schemaVersion": 40, 338 | "tags": [], 339 | "templating": { 340 | "list": [] 341 | }, 342 | "time": { 343 | "from": "now-6h", 344 | "to": "now" 345 | }, 346 | "timepicker": {}, 347 | "timezone": "browser", 348 | "title": "Telemetry", 349 | "uid": "dec6d7hcgf9xcc", 350 | "version": 9, 351 | "weekStart": "" 352 | } 
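Note: the dashboard above reads four InfluxDB measurements (request_status, gpu_mem, network_data, stage_latency) from the `data` bucket with Flux. The producing code lives under src/common/metrics/ and is not reproduced in this section, so the following is only a minimal sketch, assuming the points are written with the influxdb-client Python package; the tag names (`request_id`, `status`) are illustrative assumptions, not taken from the source, while the org, bucket, and URL follow the InfluxDB datasource provisioned below.

    # Minimal sketch (assumption): emitting one "request_status" point that the
    # dashboard's Flux filter r["_measurement"] == "request_status" would pick up.
    # Org/bucket/URL match telemetry/grafana/provisioning/datasources/influxdb.yml.
    from influxdb_client import InfluxDBClient, Point
    from influxdb_client.client.write_api import SYNCHRONOUS

    client = InfluxDBClient(
        url="http://localhost:8086",
        token="<INFLUXDB_ADMIN_TOKEN>",  # placeholder, as in the datasource config
        org="NDIF",
    )
    write_api = client.write_api(write_options=SYNCHRONOUS)

    point = (
        Point("request_status")            # measurement queried by the "Request Status" panel
        .tag("request_id", "example-id")   # hypothetical tag; surfaced via the labelsToFields transform
        .tag("status", "APPROVED")         # hypothetical tag
        .field("value", 1)
    )
    write_api.write(bucket="data", record=point)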
-------------------------------------------------------------------------------- /telemetry/grafana/provisioning/dashboards/telemetry.yml: -------------------------------------------------------------------------------- 1 | apiVersion: 1 2 | 3 | providers: 4 | - name: 'Telemetry' 5 | disableDeletion: false 6 | editable: true 7 | options: 8 | path: 9 | /var/lib/grafana/dashboards/ 10 | foldersFromFilesStructure: true -------------------------------------------------------------------------------- /telemetry/grafana/provisioning/datasources/influxdb.yml: -------------------------------------------------------------------------------- 1 | apiVersion: 1 2 | 3 | datasources: 4 | - name: InfluxDB 5 | type: influxdb 6 | access: proxy 7 | url: http://localhost:8086 8 | jsonData: 9 | version: Flux 10 | organization: NDIF 11 | defaultBucket: data 12 | tlsSkipVerify: true 13 | secureJsonData: 14 | token: ${INFLUXDB_ADMIN_TOKEN} 15 | -------------------------------------------------------------------------------- /telemetry/grafana/provisioning/datasources/prometheus.yml: -------------------------------------------------------------------------------- 1 | apiVersion: 1 2 | 3 | datasources: 4 | - name: Prometheus 5 | type: prometheus 6 | access: proxy 7 | url: http://localhost:9090 8 | -------------------------------------------------------------------------------- /telemetry/prometheus/prometheus.yml: -------------------------------------------------------------------------------- 1 | # Prometheus config file. 2 | 3 | # NOTE: Prometheus does not support environment variables, so changes made to the service ports for Ray and FastAPI metrics need to be reflected here! 4 | 5 | global: 6 | scrape_interval: 15s 7 | 8 | scrape_configs: 9 | - job_name: 'combined_metrics' # Merges metrics from FastAPI and Ray together (needed for Grafana to group metrics from a single request together). 10 | static_configs: 11 | - targets: 12 | - localhost:5000 # FastAPI - Prod 13 | - localhost:5001 # FastAPI - Dev 14 | file_sd_configs: 15 | - files: 16 | - '/tmp/ray/prom_metrics_service_discovery.json' # Contains dynamically updated list of ports & IP addresses of all the Ray nodes. 17 | metrics_path: /metrics 18 | metric_relabel_configs: 19 | - source_labels: [__name__] 20 | regex: 'ray_request_status' 21 | target_label: __name__ 22 | replacement: 'request_status' 23 | relabel_configs: 24 | - source_labels: [__address__] 25 | regex: '.*:(8267|5000|5001)' # Metric export ports for FastAPI and Ray services 26 | replacement: '${1}' 27 | target_label: job 28 | action: replace 29 | - source_labels: [job] 30 | regex: '8267' 31 | replacement: 'ray' 32 | target_label: job 33 | - source_labels: [job] 34 | regex: '5000' # Prod 35 | replacement: 'fast_api' 36 | target_label: job 37 | - source_labels: [job] 38 | regex: '5001' # Dev 39 | replacement: 'fast_api' 40 | target_label: job 41 | --------------------------------------------------------------------------------
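Note: the relabel rules in prometheus.yml assume Ray exports application metrics with a `ray_` prefix (e.g. `ray_request_status`, rewritten back to `request_status` so it lines up with the FastAPI series and the Grafana dashboard). The actual metric definitions live under src/common/metrics/ and are not shown in this section; the snippet below is only a sketch of how such a counter could be registered through `ray.util.metrics` so it appears on a node's metrics-export port, with the tag key `status` as an assumed example.

    # Minimal sketch (assumption): a custom counter exported by a Ray worker or
    # Serve deployment. Ray publishes it as "ray_request_status", which the
    # metric_relabel_configs above rename to "request_status".
    from ray.util.metrics import Counter

    request_status_counter = Counter(
        "request_status",                              # exported as ray_request_status
        description="Count of request status transitions.",
        tag_keys=("status",),                          # hypothetical tag key
    )

    def record_status(status: str) -> None:
        # Increment once per status transition; tags become Prometheus labels.
        request_status_counter.inc(tags={"status": status})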