├── test ├── __init__.py ├── utils.py ├── test_service.py ├── test_env.py ├── embedding │ └── test_data.py └── test_ctx.py ├── rust ├── rustfmt.toml ├── others │ ├── persia-common-benchmark │ │ ├── src │ │ │ └── lib.rs │ │ ├── Cargo.toml │ │ └── benches │ │ │ └── memcpy.rs │ ├── persia-nats-marcos │ │ └── Cargo.toml │ ├── persia-rpc │ │ ├── Cargo.toml │ │ └── src │ │ │ └── lib.rs │ ├── persia-nats-client │ │ ├── Cargo.toml │ │ └── src │ │ │ └── lib.rs │ └── persia-rpc-macro │ │ └── Cargo.toml ├── python │ └── persia_core │ │ ├── version.py │ │ └── __init__.py ├── persia-core │ ├── build.rs │ ├── src │ │ ├── cuda │ │ │ ├── utils.rs │ │ │ ├── cuda_stream_pool.rs │ │ │ ├── cuda_event_pool.rs │ │ │ ├── cuda_memory_pool.rs │ │ │ ├── pinned_memory_pool.rs │ │ │ ├── resource_pool.rs │ │ │ └── mod.rs │ │ ├── optim.rs │ │ ├── metrics.rs │ │ ├── dlpack.rs │ │ └── utils.rs │ └── Cargo.toml ├── persia-embedding-server │ ├── build.rs │ ├── src │ │ ├── lib.rs │ │ ├── monitor.rs │ │ └── bin │ │ │ ├── persia-embedding-parameter-server.rs │ │ │ └── persia-embedding-worker.rs │ └── Cargo.toml ├── persia-simd │ └── Cargo.toml ├── pyproject.toml ├── persia-storage │ └── Cargo.toml ├── persia-embedding-config │ ├── Cargo.toml │ └── examples │ │ └── global_config.yaml ├── persia-common │ ├── Cargo.toml │ └── src │ │ ├── grad.rs │ │ ├── utils.rs │ │ └── message_queue.rs ├── persia-metrics │ └── Cargo.toml ├── Cargo.toml ├── persia-model-manager │ └── Cargo.toml ├── persia-embedding-holder │ ├── Cargo.toml │ └── src │ │ ├── sharded.rs │ │ ├── lib.rs │ │ ├── eviction_map.rs │ │ └── emb_entry.rs ├── persia-libs │ ├── src │ │ └── lib.rs │ └── Cargo.toml ├── persia-incremental-update-manager │ └── Cargo.toml └── setup.py ├── docs ├── _autoapi_templates │ ├── index.rst │ └── python │ │ ├── attribute.rst │ │ ├── exception.rst │ │ ├── package.rst │ │ ├── function.rst │ │ ├── data.rst │ │ ├── method.rst │ │ ├── class.rst │ │ └── module.rst ├── doc-requirements.txt ├── redirect-index.html ├── 
build.sh ├── index.rst ├── _templates │ └── versions.html └── conf.py ├── persia ├── version.py ├── __init__.py ├── service.py ├── embedding │ ├── __init__.py │ └── optim.py ├── k8s_utils.py ├── prelude.py ├── utils.py ├── logger.py └── env.py ├── examples ├── src │ └── adult-income │ │ ├── .honcho.env │ │ ├── config │ │ ├── ts_config.properties │ │ ├── embedding_config.yml │ │ ├── global_config.yml │ │ └── global_config_infer.yml │ │ ├── .docker.env │ │ ├── data │ │ ├── prepare_data.sh │ │ └── data_preprocess.py │ │ ├── Procfile │ │ ├── Makefile │ │ ├── launch_ts.sh │ │ ├── data_loader.py │ │ ├── serve_handler.py │ │ ├── model.py │ │ ├── docker-compose.yml │ │ ├── serve_client.py │ │ ├── data_generator.py │ │ └── train.py └── README.md ├── .gitmodules ├── resources ├── grafana │ └── provisioning │ │ ├── dashboards │ │ └── default.yaml │ │ └── datasources │ │ └── default.yaml └── proto │ └── inference.proto ├── k8s ├── resources │ ├── nats.operator.temp.yaml │ ├── operator.persia.com.yaml │ └── server.persia.com.yaml ├── src │ ├── error.rs │ ├── bin │ │ ├── gencrd.rs │ │ └── operator.rs │ ├── finalizer.rs │ └── service.rs ├── Cargo.toml └── example │ └── adult-income-prediction.train.yml ├── setup.cfg ├── pyproject.toml ├── .github ├── dependabot.yml ├── stale.yml └── workflows │ └── rust_pipeline.yml ├── .buildkite ├── e2e │ ├── .ci.env │ ├── README.md │ ├── docker-compose.infer.yml │ └── docker-compose.train.yml ├── script │ ├── docker_image_process.sh │ └── k8s_system_test.sh └── pipeline.yml ├── Makefile ├── LICENSE ├── CITATION.cff ├── .gitignore ├── setup.py ├── Dockerfile └── README.md /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rust/rustfmt.toml: -------------------------------------------------------------------------------- 1 | edition = "2018" 2 | 
-------------------------------------------------------------------------------- /docs/_autoapi_templates/index.rst: -------------------------------------------------------------------------------- 1 | ../index.rst -------------------------------------------------------------------------------- /persia/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.dev13" 2 | -------------------------------------------------------------------------------- /rust/others/persia-common-benchmark/src/lib.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/_autoapi_templates/python/attribute.rst: -------------------------------------------------------------------------------- 1 | {% extends "python/data.rst" %} 2 | -------------------------------------------------------------------------------- /docs/_autoapi_templates/python/exception.rst: -------------------------------------------------------------------------------- 1 | {% extends "python/class.rst" %} 2 | -------------------------------------------------------------------------------- /docs/_autoapi_templates/python/package.rst: -------------------------------------------------------------------------------- 1 | {% extends "python/module.rst" %} 2 | -------------------------------------------------------------------------------- /docs/doc-requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-rtd-theme 2 | sphinx-multiversion 3 | sphinx-autoapi 4 | -------------------------------------------------------------------------------- /rust/python/persia_core/version.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | __version__ = "dev" 4 | -------------------------------------------------------------------------------- 
/examples/src/adult-income/.honcho.env: -------------------------------------------------------------------------------- 1 | PERSIA_NATS_URL=nats://0.0.0.0:4222 2 | 3 | LOG_LEVEL=info 4 | -------------------------------------------------------------------------------- /examples/src/adult-income/config/ts_config.properties: -------------------------------------------------------------------------------- 1 | number_of_gpu=1 2 | default_workers_per_model=3 -------------------------------------------------------------------------------- /rust/persia-core/build.rs: -------------------------------------------------------------------------------- 1 | fn main() -> shadow_rs::SdResult<()> { 2 | shadow_rs::new() 3 | } 4 | -------------------------------------------------------------------------------- /rust/persia-embedding-server/build.rs: -------------------------------------------------------------------------------- 1 | fn main() -> shadow_rs::SdResult<()> { 2 | shadow_rs::new() 3 | } 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "rust/persia-speedy"] 2 | path = rust/persia-speedy 3 | url = git@github.com:PersiaML/persia-speedy.git 4 | -------------------------------------------------------------------------------- /rust/python/persia_core/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from .persia_core import * 4 | from .version import __version__ 5 | -------------------------------------------------------------------------------- /rust/persia-embedding-server/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod embedding_parameter_service; 2 | pub mod embedding_worker_service; 3 | pub mod monitor; 4 | -------------------------------------------------------------------------------- 
/test/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | def random_port(start: int = 10000, end: int = 65535) -> int: 5 | return random.randint(start, end) 6 | -------------------------------------------------------------------------------- /resources/grafana/provisioning/dashboards/default.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: 1 2 | 3 | providers: 4 | - name: Persia 5 | folder: Persia-Training 6 | type: file 7 | options: 8 | path: /workspace/grafana/dashboards/ -------------------------------------------------------------------------------- /rust/persia-simd/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "persia-simd" 3 | version = "0.1.0" 4 | authors = ["Kuaishou AI Platform PersiaML Team "] 5 | license = "MIT" 6 | edition = "2018" 7 | 8 | [dependencies] 9 | -------------------------------------------------------------------------------- /rust/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 43.0.0", "wheel", "setuptools-rust", "colorama", "tqdm", "setuptools_scm[toml]>=6.0"] 3 | build-backend = 'setuptools.build_meta' 4 | 5 | [tool.setuptools_scm] 6 | local_scheme = "no-local-version" 7 | root = ".." 
-------------------------------------------------------------------------------- /examples/src/adult-income/.docker.env: -------------------------------------------------------------------------------- 1 | PERSIA_NN_WORKER_ENTRY=/workspace/train.py 2 | PERSIA_DATALOADER_ENTRY=/workspace/data_loader.py 3 | PERSIA_EMBEDDING_CONFIG=/workspace/config/embedding_config.yml 4 | PERSIA_GLOBAL_CONFIG=/workspace/config/global_config.yml 5 | 6 | LOG_LEVEL=info -------------------------------------------------------------------------------- /k8s/resources/nats.operator.temp.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: "nats.io/v1alpha2" 2 | kind: "NatsCluster" 3 | metadata: 4 | name: "persia-nats-service" 5 | spec: 6 | size: 1 7 | natsConfig: 8 | maxPayload: 52428800 9 | resources: 10 | limits: 11 | memory: "8Gi" 12 | cpu: "2" -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pytype] 2 | inputs = persia 3 | 4 | [flake8] 5 | per-file-ignores = 6 | persia/__init__.py: F401,F403 7 | persia/prelude.py: F401,F402,E402 8 | 9 | ignore = E501,W503 10 | 11 | exclude = 12 | persia/sparse/__init__.py 13 | persia/sparse/emb.py 14 | persia/version.py -------------------------------------------------------------------------------- /docs/redirect-index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Redirecting to main branch 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /persia/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | 3 | from persia import logger as logger 4 | from persia import prelude as prelude 5 | from persia import ctx as ctx 6 | from persia import embedding as embedding 7 | from persia import 
data as data 8 | from persia import service as service 9 | from persia import utils as utils 10 | -------------------------------------------------------------------------------- /examples/src/adult-income/data/prepare_data.sh: -------------------------------------------------------------------------------- 1 | curl -o train.csv https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data 2 | curl -o test.csv https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test 3 | python3 data_preprocess.py --train-dataset train.csv --test-dataset test.csv --output_path . 4 | rm test.csv train.csv -------------------------------------------------------------------------------- /rust/others/persia-nats-marcos/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "persia-nats-marcos" 6 | publish = false 7 | version = "0.1.0" 8 | 9 | [dependencies] 10 | proc-macro2 = "1.0" 11 | quote = "1.0" 12 | syn = "1.0.81" 13 | 14 | [lib] 15 | proc-macro = true 16 | -------------------------------------------------------------------------------- /rust/others/persia-rpc/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "persia-rpc" 6 | publish = false 7 | version = "0.1.0" 8 | 9 | [dependencies] 10 | persia-libs = {path = "../../persia-libs"} 11 | persia-speedy = {path = "../../persia-speedy"} 12 | snafu = "0.6" 13 | -------------------------------------------------------------------------------- /persia/service.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from typing import List 4 | 5 | 6 | def get_embedding_worker_services() -> List[str]: 7 | """Get a list of current embedding 
worker services.""" 8 | return ( 9 | [os.environ["EMBEDDING_WORKER_SERVICE"]] 10 | if os.environ.get("EMBEDDING_WORKER_SERVICE", None) 11 | else ["embedding_worker:8887"] 12 | ) 13 | -------------------------------------------------------------------------------- /rust/persia-storage/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "persia-storage" 6 | publish = false 7 | version = "0.1.0" 8 | 9 | [dependencies] 10 | persia-libs = {path = "../persia-libs"} 11 | persia-speedy = { path = "../persia-speedy" } 12 | enum_dispatch = "0.3.7" 13 | -------------------------------------------------------------------------------- /examples/src/adult-income/config/embedding_config.yml: -------------------------------------------------------------------------------- 1 | feature_index_prefix_bit: 12 2 | slots_config: 3 | workclass: 4 | dim: 8 5 | education: 6 | dim: 8 7 | marital_status: 8 | dim: 8 9 | occupation: 10 | dim: 8 11 | relationship: 12 | dim: 8 13 | race: 14 | dim: 8 15 | gender: 16 | dim: 8 17 | native_country: 18 | dim: 8 19 | feature_groups: {} 20 | -------------------------------------------------------------------------------- /rust/others/persia-nats-client/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "persia-nats-client" 6 | publish = false 7 | version = "0.1.0" 8 | 9 | [dependencies] 10 | async-nats = "0.10.1" 11 | persia-libs = {path = "../../persia-libs"} 12 | persia-speedy = { path = "../../persia-speedy" } 13 | -------------------------------------------------------------------------------- /docs/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 
| 5 | python3 -m pip install -r docs/doc-requirements.txt 6 | if [ "$BUILD_MULTIVERSION" == "1" ]; then 7 | git fetch --all --tags 8 | git checkout main 9 | git pull --unshallow 10 | python3 -m sphinx_multiversion docs build/html 11 | cp docs/redirect-index.html build/html/index.html 12 | else 13 | python3 -m sphinx docs build/html 14 | fi 15 | -------------------------------------------------------------------------------- /rust/persia-embedding-config/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "persia-embedding-config" 6 | publish = false 7 | version = "0.1.0" 8 | 9 | [dependencies] 10 | get_if_addrs = "0.5.3" 11 | local_ipaddress = "0.1.3" 12 | num-traits = "0.2.6" 13 | persia-libs = {path = "../persia-libs"} 14 | persia-speedy = {path = "../persia-speedy"} 15 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 43.0.0", "wheel", "setuptools-rust", "colorama", "tqdm", "setuptools_scm[toml]>=6.0"] 3 | build-backend = 'setuptools.build_meta' 4 | 5 | [tool.black] 6 | line-length = 88 7 | target-version = ['py37', 'py38'] 8 | include = '\.pyi?$' 9 | 10 | [tool.setuptools_scm] 11 | local_scheme = "no-local-version" 12 | write_to = "persia/version.py" 13 | write_to_template = "__version__ = \"{version}\"" 14 | -------------------------------------------------------------------------------- /rust/others/persia-rpc-macro/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "persia-rpc-macro" 6 | publish = false 7 | version = "0.1.0" 8 | 9 | [lib] 10 | proc-macro = true 11 
| 12 | [dependencies] 13 | proc-macro2 = "1" 14 | quote = "1.0" 15 | openssl = { version = "0.10", features = ["vendored"] } 16 | syn = {version = "1.0", features = ["full"]} 17 | -------------------------------------------------------------------------------- /rust/persia-common/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "persia-common" 6 | publish = false 7 | version = "0.1.0" 8 | 9 | [dependencies] 10 | persia-embedding-config = {path = "../persia-embedding-config"} 11 | persia-libs = {path = "../persia-libs"} 12 | persia-simd = {path = "../persia-simd"} 13 | persia-speedy = {path = "../persia-speedy"} 14 | -------------------------------------------------------------------------------- /k8s/src/error.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, thiserror::Error)] 2 | pub enum Error { 3 | #[error("Kubernetes reported error: {source}")] 4 | KubeError { 5 | #[from] 6 | source: kube::Error, 7 | }, 8 | #[error("Invalid PersiaJob CRD: {0}")] 9 | UserInputError(String), 10 | #[error("Failed to decode json format PersiaJobSpec: {0}")] 11 | JobSpecJsonDecodeError(String), 12 | #[error("Pod status is None")] 13 | NonePodStatusError, 14 | } 15 | -------------------------------------------------------------------------------- /rust/persia-metrics/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "persia-metrics" 6 | version = "0.1.0" 7 | 8 | [dependencies] 9 | persia-embedding-config = {path = "../persia-embedding-config"} 10 | persia-libs = {path = "../persia-libs"} 11 | prometheus = {version = "0.13.0", features = ["push"]} 12 | prometheus_exporter = "0.8.3" 13 | scheduled-thread-pool 
= "0.2" 14 | -------------------------------------------------------------------------------- /examples/src/adult-income/config/global_config.yml: -------------------------------------------------------------------------------- 1 | common_config: 2 | metrics_config: 3 | enable_metrics: false 4 | push_interval_sec: 10 5 | job_type: Train 6 | checkpointing_config: 7 | num_workers: 8 8 | embedding_parameter_server_config: 9 | capacity: 1000000 10 | num_hashmap_internal_shards: 1 11 | enable_incremental_update: false 12 | incremental_buffer_size: 5000000 13 | incremental_channel_capacity: 1000 14 | embedding_worker_config: 15 | forward_buffer_size: 1000 -------------------------------------------------------------------------------- /rust/persia-embedding-config/examples/global_config.yaml: -------------------------------------------------------------------------------- 1 | common_config: 2 | metrics_config: 3 | enable_metrics: true 4 | push_interval_sec: 10 5 | job_type: Train 6 | checkpointing_config: 7 | num_workers: 8 8 | embedding_parameter_server_config: 9 | capacity: 1000000 10 | num_hashmap_internal_shards: 1 11 | enable_incremental_update: false 12 | incremental_buffer_size: 5000000 13 | incremental_channel_capacity: 1000 14 | embedding_worker_config: 15 | forward_buffer_size: 1000 -------------------------------------------------------------------------------- /rust/Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | exclude = [] 3 | members = [ 4 | "persia-embedding-config", 5 | "persia-common", 6 | "persia-simd", 7 | "persia-embedding-server", 8 | "persia-core", 9 | "persia-libs", 10 | "persia-model-manager", 11 | "persia-storage", 12 | "persia-embedding-holder", 13 | "persia-metrics", 14 | "persia-speedy", 15 | "others/persia-common-benchmark", 16 | "others/persia-nats-client", 17 | "others/persia-nats-marcos", 18 | "others/persia-rpc", 19 | "others/persia-rpc-macro", 20 | ] 21 | 
-------------------------------------------------------------------------------- /rust/persia-model-manager/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "persia-model-manager" 6 | publish = false 7 | version = "0.1.0" 8 | 9 | [dependencies] 10 | persia-embedding-config = {path = "../persia-embedding-config"} 11 | persia-embedding-holder = {path = "../persia-embedding-holder"} 12 | persia-libs = {path = "../persia-libs"} 13 | persia-speedy = {path = "../persia-speedy"} 14 | persia-storage = {path = "../persia-storage"} 15 | -------------------------------------------------------------------------------- /resources/grafana/provisioning/datasources/default.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: 1 2 | 3 | datasources: 4 | # name of the datasource. Required 5 | - name: Prometheus 6 | # datasource type. Required 7 | type: prometheus 8 | # access mode. direct or proxy. Required 9 | access: proxy 10 | # org id. will default to orgId 1 if not specified 11 | orgId: 1 12 | # url 13 | url: http://prometheus:9090 14 | version: 1 15 | # allow users to edit datasources from the UI. 16 | editable: true -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "cargo" # See documentation for possible values 9 | directory: "rust/" # Location of package manifests 10 | schedule: 11 | interval: "daily" 12 | -------------------------------------------------------------------------------- /examples/src/adult-income/Procfile: -------------------------------------------------------------------------------- 1 | # Trick to block the data_loader to prevent process exit before training done.(works for linux and macos) 2 | data_loader: persia-launcher data-loader data_loader.py --replica-index 0 --replica-size 1 && cat 3 | nn_worker: persia-launcher nn-worker train.py 4 | embedding_worker: persia-launcher embedding-worker --embedding-config config/embedding_config.yml --global-config config/global_config.yml 5 | embedding_server: persia-launcher embedding-parameter-server --embedding-config config/embedding_config.yml --global-config config/global_config.yml 6 | nats_server: nats-server 7 | -------------------------------------------------------------------------------- /.buildkite/e2e/.ci.env: -------------------------------------------------------------------------------- 1 | LOG_LEVEL=info 2 | RUST_BACKTRACE=full 3 | 4 | PYTHONDONTWRITEBYTECODE=1 5 | 6 | PERSIA_NN_WORKER_ENTRY=/home/PERSIA/examples/src/adult-income/train.py 7 | PERSIA_DATALOADER_ENTRY=/home/PERSIA/examples/src/adult-income/data_loader.py 8 | PERSIA_EMBEDDING_CONFIG=/home/PERSIA/examples/src/adult-income/config/embedding_config.yml 9 | PERSIA_GLOBAL_CONFIG=/home/PERSIA/examples/src/adult-income/config/global_config.yml 10 | PERSIA_INFER_GLOBAL_CONFIG=/home/PERSIA/examples/src/adult-income/config/global_config_infer.yml 11 | 12 | PERSIA_CKPT_DIR=/cache/adult_income_ckpt 13 | REPRODUCIBLE=1 14 | EMBEDDING_STALENESS=1 
-------------------------------------------------------------------------------- /rust/others/persia-common-benchmark/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | name = "persia-common-benchmark" 5 | version = "0.1.0" 6 | 7 | #[lib] 8 | #bench = false 9 | 10 | [dev-dependencies] 11 | bincode = "1" 12 | criterion = "0.3" 13 | criterion-macro = "0.3" 14 | persia-speedy = { path = "../../persia-speedy" } 15 | serde = {version = "1", features = ["derive"]} 16 | smallvec = "1" 17 | tinystr = "0.4" 18 | 19 | [[bench]] 20 | name = "memcpy" 21 | #harness = false 22 | 23 | [[bench]] 24 | name = "serialize_inf_request" 25 | #harness = false 26 | -------------------------------------------------------------------------------- /examples/src/adult-income/config/global_config_infer.yml: -------------------------------------------------------------------------------- 1 | common_config: 2 | metrics_config: 3 | enable_metrics: true 4 | push_interval_sec: 10 5 | job_type: Infer 6 | infer_config: 7 | servers: 8 | - embedding_parameter_server:8888 9 | embedding_checkpoint: /cache/adult_income_ckpt/ 10 | checkpointing_config: 11 | num_workers: 8 12 | embedding_parameter_server_config: 13 | capacity: 1000000 14 | num_hashmap_internal_shards: 1 15 | enable_incremental_update: false 16 | incremental_buffer_size: 5000000 17 | incremental_channel_capacity: 1000 18 | embedding_worker_config: 19 | forward_buffer_size: 1000 -------------------------------------------------------------------------------- /rust/persia-embedding-holder/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "persia-embedding-holder" 6 | publish = false 7 | version = "0.1.0" 8 | 9 | [dependencies] 10 | ahash = "0.7" 11 | 
array-linked-list = "0.1" 12 | persia-common = {path = "../persia-common"} 13 | persia-embedding-config = {path = "../persia-embedding-config"} 14 | persia-libs = {path = "../persia-libs"} 15 | persia-speedy = {path = "../persia-speedy"} 16 | 17 | [dev-dependencies] 18 | linked-hash-map = "0.5" 19 | rand = "0.8" 20 | rand_distr = "0.4" 21 | tracing = "0.1" 22 | -------------------------------------------------------------------------------- /.buildkite/script/docker_image_process.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | 5 | upload_image() { 6 | docker tag $1 $2 7 | docker push $2 8 | docker rmi $2 # remove the remote image ref 9 | } 10 | 11 | remove_image() { 12 | docker rmi -f $1 13 | } 14 | 15 | for image_name in "persia-cuda-runtime" "persia-cpu-runtime" 16 | do 17 | local_image_name=$image_name:$BUILDKITE_PIPELINE_ID 18 | remote_image_name=persiaml/$image_name:latest 19 | if [[ $1 == "upload" ]]; then 20 | upload_image $local_image_name $remote_image_name 21 | elif [[ $1 == "remove" ]]; then 22 | remove_image $local_image_name 23 | fi 24 | done 25 | -------------------------------------------------------------------------------- /.buildkite/e2e/README.md: -------------------------------------------------------------------------------- 1 | # Persia CI Process 2 | 3 | ## Build Multiple Persia Runtime Image 4 | 5 | The first step is to build the persia runtime image for e2e ci test.There will build two type of runtime to cover the cpu and cuda environment with `$BUILDKITE_PIPELINE_ID` tag. 6 | 7 | *ci runtime image below* 8 | - persia-cpu-runtime:$BUILDKITE_PIPELINE_ID 9 | - persia-cuda-runtime:$BUILDKITE_PIPELINE_ID 10 | 11 | 12 | ## Provide The Resource Folder 13 | To prevent the `buildkite-agent` can't remove the own files after docker-compose exit, there provide the `Persia/e2e/resource` folder to place the shared files when running multiple times docker-compose within one step. 
14 | -------------------------------------------------------------------------------- /test/test_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from persia.service import get_embedding_worker_services 4 | 5 | 6 | def test_get_embedding_worker_services(): 7 | default_service = "embedding_worker:8887" 8 | for persia_service, service in zip( 9 | get_embedding_worker_services(), [default_service] 10 | ): 11 | assert persia_service == service 12 | 13 | embedding_worker_service = "localhost:8887" 14 | 15 | os.environ["EMBEDDING_WORKER_SERVICE"] = embedding_worker_service 16 | 17 | for persia_service, service in zip( 18 | get_embedding_worker_services(), [embedding_worker_service] 19 | ): 20 | assert persia_service == service 21 | -------------------------------------------------------------------------------- /rust/persia-libs/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub use anyhow; 2 | pub use async_lock; 3 | pub use backoff; 4 | pub use bytes; 5 | pub use chrono; 6 | pub use color_eyre; 7 | pub use flume; 8 | pub use futures; 9 | pub use half; 10 | pub use hashbrown; 11 | pub use hyper; 12 | pub use hyperloglogplus; 13 | pub use indexmap; 14 | pub use itertools; 15 | pub use lz4; 16 | pub use ndarray; 17 | pub use ndarray_rand; 18 | pub use once_cell; 19 | pub use parking_lot; 20 | pub use rand; 21 | pub use rayon; 22 | pub use serde; 23 | pub use serde_bytes; 24 | pub use serde_yaml; 25 | pub use smol; 26 | pub use thiserror; 27 | pub use tokio; 28 | pub use tracing; 29 | pub use tracing_subscriber; 30 | pub use url; 31 | -------------------------------------------------------------------------------- /rust/persia-incremental-update-manager/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | 
name = "persia-incremental-update-manager" 6 | publish = false 7 | version = "0.1.0" 8 | 9 | [dependencies] 10 | griddle = "0.4" 11 | persia-embedding-config = {path = "../persia-embedding-config"} 12 | persia-common = {path = "../persia-common"} 13 | persia-embedding-holder = {path = "../persia-embedding-holder"} 14 | persia-libs = {path = "../persia-libs"} 15 | persia-metrics = {path = "../persia-metrics"} 16 | persia-storage = {path = "../persia-storage"} 17 | persia-speedy = { path = "../persia-speedy" } 18 | -------------------------------------------------------------------------------- /docs/_autoapi_templates/python/function.rst: -------------------------------------------------------------------------------- 1 | {% if obj.display %} 2 | .. function:: {{ obj.short_name }}({{ obj.args }}){% if obj.return_annotation is not none %} -> {{ obj.return_annotation }}{% endif %} 3 | 4 | {% for (args, return_annotation) in obj.overloads %} 5 | {{ obj.short_name }}({{ args }}){% if return_annotation is not none %} -> {{ return_annotation }}{% endif %} 6 | 7 | {% endfor %} 8 | {% if sphinx_version >= (2, 1) %} 9 | {% for property in obj.properties %} 10 | :{{ property }}: 11 | {% endfor %} 12 | {% endif %} 13 | 14 | {% if obj.docstring %} 15 | {{ obj.docstring|prepare_docstring|indent(3) }} 16 | {% else %} 17 | {% endif %} 18 | {% endif %} 19 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 7 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 1 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: wontfix 11 | # Comment to post when marking an issue as stale. 
Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: false 18 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. PERSIA API Documentation documentation master file, created by 2 | sphinx-quickstart on Thu Jun 10 16:09:03 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | .. .. toctree:: 7 | .. :maxdepth: 4 8 | .. :caption: API Documents: 9 | 10 | PERSIA 11 | ====== 12 | 13 | This website contains PERSIA API documentation. See `tutorials `_ if you need step by step instructions on how to use PersiaML. 14 | 15 | 16 | .. 
toctree:: 17 | :titlesonly: 18 | :caption: API Documents 19 | 20 | {% for page in pages %} 21 | {% if page.top_level_object and page.display %} 22 | {{ page.include_path }} 23 | {% endif %} 24 | {% endfor %} 25 | 26 | -------------------------------------------------------------------------------- /k8s/src/bin/gencrd.rs: -------------------------------------------------------------------------------- 1 | use kube::CustomResourceExt; 2 | use persia_operator::crd::PersiaJob; 3 | use std::fs::File; 4 | use std::io::Write; 5 | use structopt::StructOpt; 6 | 7 | #[derive(Debug, StructOpt, Clone)] 8 | #[structopt()] 9 | struct Cli { 10 | #[structopt(long)] 11 | output: String, 12 | } 13 | 14 | fn main() { 15 | let args: Cli = Cli::from_args(); 16 | 17 | let crd = serde_yaml::to_string(&PersiaJob::crd()).unwrap(); 18 | 19 | match File::create(args.output) { 20 | Ok(mut output) => { 21 | if let Err(e) = write!(output, "{}", crd) { 22 | panic!("failed to write file due to {:?}", e); 23 | } 24 | } 25 | Err(e) => { 26 | panic!("failed to create file due to {:?}", e); 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /examples/src/adult-income/Makefile: -------------------------------------------------------------------------------- 1 | SERVICE := DEMO 2 | 3 | run: 4 | make stop && docker stack deploy -c docker-compose.yml $(SERVICE) && make nn_worker 5 | 6 | stop: 7 | docker stack rm $(SERVICE) 8 | 9 | data_loader: 10 | docker service logs -f `docker stack ps $(SERVICE) | grep data_loader|head -n 1|awk '{print $$1}'` 11 | 12 | nn_worker: 13 | docker service logs -f `docker stack ps $(SERVICE) | grep nn_worker|head -n 1|awk '{print $$1}'` 14 | 15 | server: 16 | docker service logs -f `docker stack ps $(SERVICE) | grep server|head -n 1|awk '{print $$1}'` 17 | 18 | embedding_worker: 19 | docker service logs -f `docker stack ps $(SERVICE) | grep embedding_worker|head -n 1|awk '{print $$1}'` 20 | 21 | nats: 22 | docker 
/// Enqueue an asynchronous device-to-host copy of `num_bytes` bytes from
/// `data_ptr` (CUDA device memory) to `host_ptr` (host memory).
///
/// The copy runs on a stream taken from the global stream pool; a CUDA event
/// is recorded immediately after the copy and returned so the caller can wait
/// on it before reading `host_ptr`.
///
/// Panics (via `assert_eq!`) if `cudaMemcpyAsync` reports an error.
///
/// NOTE(review): callers must guarantee both pointers are valid for
/// `num_bytes`, and that `host_ptr` stays alive until the event completes —
/// presumably `host_ptr` is pinned host memory; TODO confirm.
pub fn cuda_d2h(
    num_bytes: usize,
    data_ptr: *mut std::os::raw::c_void,
    host_ptr: *mut std::os::raw::c_void,
) -> Result {
    unsafe {
        // Pooled resources; the `0` size argument is ignored for streams and
        // events (see `Allocatable::new(_size)` in the stream pool).
        let stream = CUDA_STREAM_POOL.allocate(0);
        let result = cuda::cudaMemcpyAsync(
            host_ptr,
            data_ptr,
            num_bytes,
            cuda::cudaMemcpyKind::cudaMemcpyDeviceToHost,
            stream.inner,
        );
        // Fail fast: a silently failed async copy would surface much later
        // as corrupted host data.
        assert_eq!(result, cuda::cudaError::cudaSuccess);
        let event = CUDA_EVENT_POOL.allocate(0);
        event.record(stream);
        Ok(event)
    }
}
Show Value 14 | 15 | .. code-block:: text 16 | :linenos: 17 | 18 | {{ obj.value|indent(width=8) }} 19 | 20 | .. raw:: html 21 | 22 |
23 | 24 | {%- else -%} 25 | {{ obj.value|string|truncate(100) }} 26 | {%- endif %} 27 | {%- endif %} 28 | {% endif %} 29 | 30 | 31 | {{ obj.docstring|prepare_docstring|indent(3) }} 32 | {% endif %} 33 | -------------------------------------------------------------------------------- /docs/_templates/versions.html: -------------------------------------------------------------------------------- 1 | {%- if current_version %} 2 |
3 | 4 | Other Versions 5 | v: {{ current_version.name }} 6 | 7 | 8 |
9 | {%- if versions.tags %} 10 |
11 |
Tags
12 | {%- for item in versions.tags %} 13 |
{{ item.name }}
14 | {%- endfor %} 15 |
16 | {%- endif %} 17 | {%- if versions.branches %} 18 |
19 |
Branches
20 | {%- for item in versions.branches %} 21 |
{{ item.name }}
22 | {%- endfor %} 23 |
24 | {%- endif %} 25 |
26 |
27 | {%- endif %} 28 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # PERSIA Examples 2 | 3 | We provided several examples and multiple laucnher to help you quick start a *PERSIA* task. 4 | 5 | ## Honcho Launcher 6 | [Honcho](https://github.com/nickstenning/honcho) is a tool for managing multiple processes.Current honcho launcher only support launch the PERSIA Task in single node due to some distributed environments is hard to shared across multiple nodes. 7 | 8 | *launch example below* 9 | ```bash 10 | cd PERSIA/examples/src/adult-income 11 | honcho start -e .honcho.env 12 | ``` 13 | 14 | ## Docker Compose Launcher 15 | 16 | Docker [compose](https://docs.docker.com/compose/) can launch the multiple service under the swarm mode.Follow the [swarm mode](https://docs.docker.com/engine/swarm/) to adding multiple machines to swarm cluster to apply the distributed PERSIA training task. 17 | 18 | *launcher example below* 19 | ```bash 20 | cd PERSIA/examples/src/adult-income 21 | make run 22 | ``` -------------------------------------------------------------------------------- /examples/src/adult-income/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | 5 | from persia.embedding.data import PersiaBatch 6 | from persia.logger import get_logger 7 | from persia.ctx import DataCtx 8 | 9 | from data_generator import make_dataloader 10 | 11 | logger = get_logger("data_loader") 12 | 13 | train_filepath = os.path.join( 14 | os.path.dirname(os.path.realpath(__file__)), "data/train.npz" 15 | ) 16 | 17 | if __name__ == "__main__": 18 | with DataCtx() as ctx: 19 | loader = make_dataloader(train_filepath) 20 | for (non_id_type_feature, id_type_features, label) in tqdm( 21 | loader, desc="gen batch data..." 
def test_data_loader_env():
    """Check replica-based env resolution used by the data-loader role.

    RANK is removed defensively first — presumably replica env vars only take
    effect when the rank vars are absent (TODO confirm against persia.env).
    """
    replica_index = 0
    replica_size = 1

    # pop() with a default instead of `del`: the original raised KeyError when
    # this test ran before test_nn_worker_env (which is what sets RANK).
    os.environ.pop("RANK", None)
    os.environ["REPLICA_SIZE"] = str(replica_size)
    os.environ["REPLICA_INDEX"] = str(replica_index)

    reload_env()
    assert get_replica_size() == replica_size
    assert get_replica_index() == replica_index
14 | string model_version = 2; //optional 15 | 16 | // input data for model prediction 17 | map input = 3; //required 18 | } 19 | 20 | message PredictionResponse { 21 | // TorchServe health 22 | bytes prediction = 1; 23 | } 24 | 25 | message TorchServeHealthResponse { 26 | // TorchServe health 27 | string health = 1; 28 | } 29 | 30 | service InferenceAPIsService { 31 | rpc Ping(google.protobuf.Empty) returns (TorchServeHealthResponse) {} 32 | 33 | // Predictions entry point to get inference using default model version. 34 | rpc Predictions(PredictionsRequest) returns (PredictionResponse) {} 35 | } -------------------------------------------------------------------------------- /.github/workflows/rust_pipeline.yml: -------------------------------------------------------------------------------- 1 | name: rust fmt 2 | 3 | on: push 4 | 5 | jobs: 6 | format: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | with: 11 | submodules: recursive 12 | - uses: actions-rs/toolchain@v1 13 | with: 14 | toolchain: stable 15 | components: rustfmt 16 | override: true 17 | - name: rust-rustfmt-check 18 | uses: mbrobbel/rustfmt-check@master 19 | with: 20 | token: ${{ secrets.GITHUB_TOKEN }} 21 | args: --manifest-path rust/Cargo.toml --all 22 | test: 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@v2 26 | with: 27 | submodules: recursive 28 | - name: Install latest nightly 29 | uses: actions-rs/toolchain@v1 30 | with: 31 | toolchain: nightly 32 | override: true 33 | - name: Run tests 34 | run: cd rust/ && cargo test --verbose 35 | -------------------------------------------------------------------------------- /rust/persia-embedding-holder/src/sharded.rs: -------------------------------------------------------------------------------- 1 | use persia_libs::parking_lot; 2 | use std::hash::{Hash, Hasher}; 3 | 4 | #[derive(Debug)] 5 | pub struct Sharded { 6 | pub inner: Vec>, 7 | pub phantom: std::marker::PhantomData, 8 | } 9 | 10 | #[inline] 
// Map `key` to a shard index in `[0, count)` using AHash.
//
// NOTE(review): a default-constructed AHasher is created per call; the
// resulting index is only meaningful for routing within this process —
// confirm cross-process stability is not required.
pub fn get_index(key: &K, count: usize) -> usize
where
    K: Hash + Eq + Clone,
{
    let mut s = ahash::AHasher::default();
    key.hash(&mut s);
    // `finish()` yields a u64; the modulo keeps the result below `count`.
    // The outer cast is redundant (already usize) but harmless.
    (s.finish() as usize % count) as usize
}

impl Sharded {
    // Return the RwLock guarding the shard that `key` hashes to.
    //
    // SAFETY: `get_index(key, len)` always returns a value `< len`, so the
    // unchecked access cannot go out of bounds (assuming `inner` is non-empty).
    #[inline]
    pub fn shard(&self, key: &K) -> &parking_lot::RwLock
    where
        K: Hash + Eq + Clone,
    {
        unsafe { self.inner.get_unchecked(get_index(key, self.inner.len())) }
    }

    // Return the shard at `idx` directly, without hashing.
    //
    // SAFETY: no bounds check is performed; callers must guarantee
    // `idx < self.inner.len()`.
    #[inline]
    pub fn get_shard_by_index(&self, idx: usize) -> &parking_lot::RwLock
    where
        K: Hash + Eq + Clone,
    {
        unsafe { self.inner.get_unchecked(idx) }
    }
}
/// Add the PERSIA finalizer (`persiajobs.persia.com`) to the named PersiaJob
/// via a JSON merge patch.
///
/// The finalizer prevents Kubernetes from garbage-collecting the resource
/// until the operator finishes cleanup and removes it again (see `delete`).
/// Returns the patched resource as stored by the API server.
pub async fn add(client: Client, name: &str, namespace: &str) -> Result {
    let api: Api = Api::namespaced(client, namespace);
    let finalizer: Value = json!({
        "metadata": {
            "finalizers": ["persiajobs.persia.com"]
        }
    });

    // A merge patch only touches the listed fields and leaves the rest of the
    // object untouched.
    let patch: Patch<&Value> = Patch::Merge(&finalizer);
    Ok(api.patch(name, &PatchParams::default(), &patch).await?)
}

/// Clear `metadata.finalizers` on the named PersiaJob, unblocking deletion.
///
/// NOTE(review): setting the list to `null` removes *all* finalizers,
/// including any added by other controllers — confirm that is intended.
pub async fn delete(client: Client, name: &str, namespace: &str) -> Result {
    let api: Api = Api::namespaced(client, namespace);
    let finalizer: Value = json!({
        "metadata": {
            "finalizers": null
        }
    });

    let patch: Patch<&Value> = Patch::Merge(&finalizer);
    Ok(api.patch(name, &PatchParams::default(), &patch).await?)
}
class EmbeddingConfig:
    r"""Hyperparameter bundle for embeddings, passed to :class:`.EmbeddingCtx`."""

    def __init__(
        self,
        emb_initialization: Tuple[float, float] = (-0.01, 0.01),
        admit_probability: float = 1.0,
        weight_bound: float = 10,
    ):
        """
        Arguments:
            emb_initialization (Tuple[float, float], optional): ``(lower, upper)`` bound of the
                uniform distribution used to initialize a new embedding.
            admit_probability (float, optional): probability in ``[0, 1]`` that a new embedding
                is admitted.
            weight_bound (float, optional): clamp every embedding element into
                ``[-weight_bound, weight_bound]``.
        """
        # Plain attribute storage; no validation is performed here.
        self.weight_bound = weight_bound
        self.admit_probability = admit_probability
        self.emb_initialization = emb_initialization
@cli.command()
@click.option(
    "--port", type=int, default="8080", help="Persia k8s schedule server port"
)
def server(
    port: int,
):
    """Run the PERSIA k8s schedule ``server`` binary, forwarding ``--port``."""
    executable_path = resolve_binary_execute_path("server")
    cmd = [
        executable_path,
        "--port",
        # str(): command argv entries must be strings, and click hands us an
        # int here; the sibling `gencrd` command already passes only strings.
        str(port),
    ]
    run_command(cmd)
| openssl-sys = "0.9.70" 28 | structopt = "0.3" 29 | collection_macros = "0.2.0" 30 | anyhow = "1.0" 31 | bytes = "1.1.0" 32 | 33 | [[bin]] 34 | name = "operator" 35 | path = "src/bin/operator.rs" 36 | 37 | [[bin]] 38 | name = "gencrd" 39 | path = "src/bin/gencrd.rs" 40 | 41 | [[bin]] 42 | name = "server" 43 | path = "src/bin/server.rs" 44 | 45 | [[bin]] 46 | name = "e2e" 47 | path = "src/bin/e2e.rs" 48 | -------------------------------------------------------------------------------- /rust/persia-core/src/cuda/cuda_stream_pool.rs: -------------------------------------------------------------------------------- 1 | use super::resource_pool::{Allocatable, Pool}; 2 | 3 | use cuda_runtime_sys as cuda; 4 | use persia_libs::once_cell; 5 | 6 | pub static CUDA_STREAM_POOL: once_cell::sync::Lazy> = 7 | once_cell::sync::Lazy::new(|| return Pool::new()); 8 | 9 | #[derive(Debug)] 10 | pub struct CudaStreamPtr { 11 | pub inner: cuda::cudaStream_t, 12 | } 13 | 14 | impl Default for CudaStreamPtr { 15 | fn default() -> Self { 16 | CudaStreamPtr { 17 | inner: std::ptr::null_mut(), 18 | } 19 | } 20 | } 21 | 22 | unsafe impl Send for CudaStreamPtr {} 23 | 24 | impl Allocatable for CudaStreamPtr { 25 | fn new(_size: usize) -> Self { 26 | let mut stream = std::ptr::null_mut(); 27 | let result = unsafe { 28 | cuda::cudaStreamCreateWithFlags( 29 | &mut stream as *mut cuda::cudaStream_t, 30 | cuda::cudaStreamNonBlocking, 31 | ) 32 | }; 33 | assert_eq!(result, cuda::cudaError::cudaSuccess); 34 | return CudaStreamPtr { inner: stream }; 35 | } 36 | 37 | fn size(&self) -> usize { 38 | 0 39 | } 40 | } 41 | 42 | impl Drop for CudaStreamPtr { 43 | fn drop(&mut self) { 44 | CUDA_STREAM_POOL.recycle(CudaStreamPtr { inner: self.inner }); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /rust/persia-libs/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform 
class PersiaHandler(BaseHandler, ABC):
    """TorchServe handler that fetches PERSIA embeddings before running the
    dense model.

    The handler owns an :class:`InferCtx` connected to the embedding workers;
    each request carries a serialized PersiaBatch in its ``batch`` field.
    """

    def initialize(self, context):
        # TorchServe entry point: load the model via BaseHandler, then connect
        # to the embedding workers and block until they are ready to serve.
        super().initialize(context)
        embedding_worker_addrs = get_embedding_worker_services()
        self.persia_context = InferCtx(embedding_worker_addrs, device_id=device_id)
        self.persia_context.wait_for_serving()

    def preprocess(self, data):
        # "batch" holds the serialized PersiaBatch payload; bytes() converts
        # the transported representation back to raw bytes — presumably sent
        # by serve_client.py, TODO confirm the wire format.
        batch = data[0].get("batch")
        batch = bytes(batch)
        # Look up embeddings remotely; resulting tensors land on `device_id`.
        batch = self.persia_context.get_embedding_from_bytes(batch, device_id)

        model_input = self.persia_context.prepare_features(batch)
        return model_input

    def inference(self, data, *args, **kwargs):
        # prepare_features returns a 3-tuple; the third element is unused at
        # inference time.
        non_id_type_tensors, id_type_tensors, _ = data
        with torch.no_grad():
            results = self.model(non_id_type_tensors, id_type_tensors)
        return results

    def postprocess(self, data):
        # Flatten the prediction tensor and wrap it in a single-element list,
        # as TorchServe expects one response entry per request.
        data = torch.reshape(data, (-1,))
        data = data.tolist()
        return [data]
SkippableFeatureEmbeddingGradientBatch { 39 | GradientBatch(FeatureEmbeddingGradientBatch), 40 | Skipped(SkippedGradientBatch), 41 | } 42 | 43 | #[derive(Deserialize, Serialize, Readable, Writable, Debug)] 44 | #[serde(crate = "self::serde")] 45 | pub struct EmbeddingGradientBatch { 46 | pub gradients: Vec, 47 | } 48 | -------------------------------------------------------------------------------- /.buildkite/e2e/docker-compose.infer.yml: -------------------------------------------------------------------------------- 1 | version: "3.3" 2 | services: 3 | torch_serve: 4 | env_file: 5 | - .ci.env 6 | image: persia-${IMAGE_TYPE}-runtime:${BUILDKITE_PIPELINE_ID} 7 | command: bash -c "/home/PERSIA/examples/src/adult-income/launch_ts.sh" 8 | volumes: 9 | - ../../resources/proto/:/proto 10 | - ./cache:/cache 11 | deploy: 12 | endpoint_mode: dnsrr 13 | depends_on: 14 | - embedding_worker 15 | - embedding_parameter_server 16 | 17 | embedding_worker: 18 | env_file: 19 | - .ci.env 20 | image: persia-${IMAGE_TYPE}-runtime:${BUILDKITE_PIPELINE_ID} 21 | command: bash -c "persia-launcher embedding-worker --replica-index 0 --replica-size 1 --global-config $$PERSIA_INFER_GLOBAL_CONFIG" 22 | deploy: 23 | endpoint_mode: dnsrr 24 | depends_on: 25 | - embedding_parameter_server 26 | volumes: 27 | - ./cache:/cache 28 | 29 | embedding_parameter_server: 30 | env_file: 31 | - .ci.env 32 | image: persia-${IMAGE_TYPE}-runtime:${BUILDKITE_PIPELINE_ID} 33 | command: bash -c "persia-launcher embedding-parameter-server --replica-index 0 --replica-size 1 --global-config $$PERSIA_INFER_GLOBAL_CONFIG" 34 | deploy: 35 | endpoint_mode: dnsrr 36 | volumes: 37 | - ./cache:/cache -------------------------------------------------------------------------------- /rust/persia-core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "persia-core" 6 
| publish = ["private"] 7 | version = "0.1.0" 8 | 9 | [lib] 10 | crate-type = ["cdylib"] 11 | name = "persia_core" 12 | path = "src/lib.rs" 13 | 14 | [dependencies] 15 | arr_macro = {version = "0.1", optional = true} 16 | cuda-runtime-sys = {version = "0.3.0-alpha.1", optional = true} 17 | numpy = "0.15" 18 | persia-common = {path = "../persia-common"} 19 | persia-embedding-config = {path = "../persia-embedding-config"} 20 | persia-embedding-holder = {path = "../persia-embedding-holder"} 21 | persia-embedding-server = {path = "../persia-embedding-server"} 22 | persia-libs = {path = "../persia-libs"} 23 | persia-metrics = {path = "../persia-metrics"} 24 | persia-model-manager = {path = "../persia-model-manager"} 25 | persia-nats-client = {path = "../others/persia-nats-client"} 26 | persia-nats-marcos = {path = "../others/persia-nats-marcos"} 27 | persia-rpc = {path = "../others/persia-rpc"} 28 | persia-rpc-macro = {path = "../others/persia-rpc-macro"} 29 | persia-speedy = {path = "../persia-speedy"} 30 | persia-storage = {path = "../persia-storage"} 31 | shadow-rs = "0.8.1" 32 | 33 | [features] 34 | default = [] 35 | 36 | cuda = ["cuda-runtime-sys", "arr_macro"] 37 | 38 | [dependencies.pyo3] 39 | default-features = false 40 | features = ["macros", "multiple-pymethods"] 41 | version = "0.15.1" 42 | 43 | [build-dependencies] 44 | shadow-rs = "0.8.1" 45 | -------------------------------------------------------------------------------- /rust/persia-embedding-server/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Kuaishou AI Platform PersiaML Team "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "persia-embedding-server" 6 | version = "0.1.0" 7 | 8 | [dependencies] 9 | crossbeam = "0.8" 10 | farmhash = "1" 11 | persia-common = {path = "../persia-common", default-features = false} 12 | persia-embedding-config = {path = "../persia-embedding-config"} 13 | persia-embedding-holder = {path = 
"../persia-embedding-holder"} 14 | persia-incremental-update-manager = {path = "../persia-incremental-update-manager"} 15 | persia-libs = {path = "../persia-libs"} 16 | persia-metrics = {path = "../persia-metrics"} 17 | persia-model-manager = {path = "../persia-model-manager"} 18 | persia-nats-client = {path = "../others/persia-nats-client"} 19 | persia-nats-marcos = {path = "../others/persia-nats-marcos"} 20 | persia-rpc = {path = "../others/persia-rpc"} 21 | persia-rpc-macro = {path = "../others/persia-rpc-macro"} 22 | persia-simd = {path = "../persia-simd"} 23 | persia-speedy = {path = "../persia-speedy"} 24 | shadow-rs = "0.8" 25 | snafu = "0.6" 26 | structopt = "0.3" 27 | tokio = {version = "1.13", features = ["full"]} 28 | tracing = "0.1" 29 | 30 | [dev-dependencies] 31 | rand = "0.8" 32 | 33 | [[bin]] 34 | name = "persia-embedding-parameter-server" 35 | path = "src/bin/persia-embedding-parameter-server.rs" 36 | 37 | [[bin]] 38 | name = "persia-embedding-worker" 39 | path = "src/bin/persia-embedding-worker.rs" 40 | 41 | [build-dependencies] 42 | shadow-rs = "0.8" 43 | -------------------------------------------------------------------------------- /.buildkite/e2e/docker-compose.train.yml: -------------------------------------------------------------------------------- 1 | version: "3.3" 2 | services: 3 | persia_nats_service: 4 | image: nats:latest 5 | deploy: 6 | replicas: 1 7 | 8 | data_loader: 9 | env_file: 10 | - .ci.env 11 | image: persia-${IMAGE_TYPE}-runtime:${BUILDKITE_PIPELINE_ID} 12 | command: persia-launcher data-loader --replica-index 0 --replica-size 1 13 | 14 | nn_worker: 15 | env_file: 16 | - .ci.env 17 | environment: 18 | CUBLAS_WORKSPACE_CONFIG: :4096:8 # Adapt to pytorch deterministic feature 19 | image: persia-${IMAGE_TYPE}-runtime:${BUILDKITE_PIPELINE_ID} 20 | command: persia-launcher nn-worker 21 | depends_on: 22 | - data_loader 23 | - embedding_worker 24 | - embedding_parameter_server 25 | - persia_nats_service 26 | volumes: 27 | - 
./cache:/cache 28 | 29 | embedding_worker: 30 | env_file: 31 | - .ci.env 32 | image: persia-${IMAGE_TYPE}-runtime:${BUILDKITE_PIPELINE_ID} 33 | command: persia-launcher embedding-worker --replica-index 0 --replica-size 1 34 | volumes: 35 | - ./cache:/cache 36 | 37 | embedding_parameter_server: 38 | env_file: 39 | - .ci.env 40 | image: persia-${IMAGE_TYPE}-runtime:${BUILDKITE_PIPELINE_ID} 41 | command: persia-launcher embedding-parameter-server --replica-index 0 --replica-size 1 42 | volumes: 43 | - ./cache:/cache -------------------------------------------------------------------------------- /rust/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | from setuptools import setup, find_packages 5 | from setuptools_rust import Binding, RustExtension 6 | 7 | 8 | if __name__ == "__main__": 9 | import colorama 10 | 11 | colorama.init(autoreset=True) 12 | 13 | use_cuda = os.environ.get("USE_CUDA", False) 14 | features = None if not use_cuda else ["cuda"] 15 | 16 | setup( 17 | name="persia-core", 18 | use_scm_version={ 19 | "local_scheme": "no-local-version", 20 | "root": "..", 21 | "write_to_template": '__version__ = "{version}"', 22 | "write_to": os.path.join( 23 | os.path.dirname(os.path.abspath(__file__)), 24 | "python/persia_core/version.py", 25 | ), 26 | }, 27 | setup_requires=["setuptools_scm"], 28 | url="https://github.com/PersiaML/PersiaML/rust", 29 | python_requires=">=3.7", 30 | description="Core Python binding for PersiaML.", 31 | package_dir={"": "python/"}, 32 | packages=find_packages("python/"), 33 | rust_extensions=[ 34 | RustExtension( 35 | "persia_core.persia_core", 36 | path="persia-core/Cargo.toml", 37 | binding=Binding.PyO3, 38 | native=True, 39 | features=features, 40 | ) 41 | ], 42 | author="Kuaishou AI Platform", 43 | author_email="admin@mail.xrlian.com", 44 | install_requires=[], 45 | zip_safe=False, 46 | ) 47 | 
-------------------------------------------------------------------------------- /k8s/example/adult-income-prediction.train.yml: -------------------------------------------------------------------------------- 1 | apiVersion: persia.com/v1 2 | kind: PersiaJob 3 | metadata: 4 | name: adult-income 5 | namespace: default 6 | spec: 7 | persiaEnv: 8 | PERSIA_GLOBAL_CONFIG: /home/PERSIA/examples/src/adult-income/config/global_config.yml 9 | PERSIA_EMBEDDING_CONFIG: /home/PERSIA/examples/src/adult-income/config/embedding_config.yml 10 | PERSIA_NN_WORKER_ENTRY: /home/PERSIA/examples/src/adult-income/train.py 11 | PERSIA_DATALOADER_ENTRY: /home/PERSIA/examples/src/adult-income/data_loader.py 12 | env: 13 | - name: PERSIA_NATS_URL 14 | value: nats://persia-nats-service:4222 15 | 16 | embeddingParameterServer: 17 | replicas: 1 18 | resources: 19 | limits: 20 | memory: "24Gi" 21 | cpu: "4" 22 | 23 | embeddingWorker: 24 | replicas: 1 25 | resources: 26 | limits: 27 | memory: "24Gi" 28 | cpu: "4" 29 | 30 | nnWorker: 31 | replicas: 1 32 | nprocPerNode: 1 33 | resources: 34 | limits: 35 | memory: "24Gi" 36 | cpu: "12" 37 | nvidia.com/gpu: "1" 38 | env: 39 | - name: CUBLAS_WORKSPACE_CONFIG 40 | value: :4096:8 41 | - name: ENABLE_CUDA 42 | value: "1" 43 | 44 | dataloader: 45 | replicas: 1 46 | resources: 47 | limits: 48 | memory: "8Gi" 49 | cpu: "1" 50 | 51 | --- 52 | apiVersion: "nats.io/v1alpha2" 53 | kind: "NatsCluster" 54 | metadata: 55 | name: "persia-nats-service" 56 | spec: 57 | size: 1 58 | natsConfig: 59 | maxPayload: 52428800 60 | resources: 61 | limits: 62 | memory: "8Gi" 63 | cpu: "2" -------------------------------------------------------------------------------- /persia/prelude.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from types import ModuleType 4 | 5 | # pytype: disable=import-error 6 | import persia_core 7 | 8 | # pytype: enable=import-error 9 | 10 | 11 | def register_submodule(module: ModuleType, 
root_module_path: str): 12 | """Register the persia core module to sys module path. 13 | 14 | Arguments: 15 | module (ModuleType): root module. 16 | root_module_path (str): root module path. 17 | """ 18 | for attr in dir(module): 19 | if attr.startswith("__"): 20 | continue 21 | obj = getattr(module, attr) 22 | if isinstance(obj, ModuleType): 23 | submodule_name = attr 24 | full_path = f"{root_module_path}.{submodule_name}" 25 | sys.modules[full_path] = obj 26 | register_submodule(obj, full_path) 27 | 28 | 29 | register_submodule( 30 | persia_core, 31 | persia_core.__name__, 32 | ) 33 | 34 | # pytype: disable=import-error 35 | from persia_core import PersiaCommonContext, is_cuda_feature_available # noqa 36 | 37 | from persia_core.optim import OptimizerBase 38 | from persia_core.data import PersiaBatch as _PersiaBatch, check_pyarray_dtype_valid 39 | from persia_core.utils import ( 40 | PersiaMessageQueueServer, 41 | PersiaMessageQueueClient, 42 | PersiaBatchDataChannel, 43 | PersiaBatchDataSender, 44 | PersiaBatchDataReceiver, 45 | ) 46 | from persia_core.nats import initialize_dataflow # noqa 47 | 48 | from persia_core.backward import Backward # noqa 49 | from persia_core.forward import Forward, Tensor, PersiaTrainingBatch # noqa 50 | 51 | # pytype: enable=import-error 52 | -------------------------------------------------------------------------------- /examples/src/adult-income/model.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class DNN(nn.Module): 8 | def __init__( 9 | self, dense_mlp_output_size: int = 16, sparse_mlp_output_size: int = 128 10 | ): 11 | super(DNN, self).__init__() 12 | 13 | self.dense_mlp = torch.nn.Linear(5, dense_mlp_output_size) 14 | self.dense_bn = nn.BatchNorm1d(dense_mlp_output_size) 15 | 16 | self.sparse_mlp = torch.nn.Linear(64, sparse_mlp_output_size) 17 | self.sparse_bn = nn.BatchNorm1d(sparse_mlp_output_size) 
18 | 19 | self.ln1 = nn.Linear(dense_mlp_output_size + sparse_mlp_output_size, 256) 20 | self.ln2 = nn.Linear(256, 128) 21 | self.ln3 = nn.Linear(128, 1) 22 | self.sigmoid = nn.Sigmoid() 23 | 24 | def forward( 25 | self, non_id_tensors: List[torch.Tensor], embedding_tensors: List[torch.Tensor] 26 | ): 27 | dense_x = non_id_tensors[0] 28 | 29 | sparse_concat = torch.cat(embedding_tensors, dim=1) 30 | sparse = self.sparse_mlp(sparse_concat.float()) 31 | sparse = self.sparse_bn(sparse) 32 | 33 | dense_x = self.dense_mlp(dense_x) 34 | dense_x = self.dense_bn(dense_x) 35 | x = torch.cat([sparse, dense_x], dim=1) 36 | x = self.ln1(x) 37 | x = self.ln2(x) 38 | x = self.ln3(x) 39 | 40 | return self.sigmoid(x) 41 | 42 | 43 | if __name__ == "__main__": 44 | model = DNN() 45 | batch_size = 64 46 | dense = torch.ones(batch_size, 5) 47 | sparses = [torch.ones(batch_size, 8) for _ in range(8)] 48 | output = model(dense, sparses) 49 | print(output) 50 | -------------------------------------------------------------------------------- /rust/persia-common/src/utils.rs: -------------------------------------------------------------------------------- 1 | use persia_libs::{flume, parking_lot, tracing}; 2 | 3 | #[derive(Clone)] 4 | pub struct ChannelPair { 5 | pub sender: flume::Sender, 6 | pub receiver: flume::Receiver, 7 | } 8 | 9 | impl ChannelPair { 10 | pub fn new(cap: usize) -> Self { 11 | let (sender, receiver) = flume::bounded(cap); 12 | Self { sender, receiver } 13 | } 14 | 15 | pub fn new_unbounded() -> Self { 16 | let (sender, receiver) = flume::unbounded(); 17 | Self { sender, receiver } 18 | } 19 | } 20 | 21 | pub fn start_deadlock_detection_thread() { 22 | if std::env::var("PERSIA_DEADLOCK_DETECTION") 23 | .unwrap_or(String::from("0")) 24 | .parse::() 25 | .expect("PERSIA_DEADLOCK_DETECTION should be 0 or 1") 26 | > 0 27 | { 28 | std::thread::spawn(move || { 29 | tracing::info!("deadlock detection thread started"); 30 | loop { 31 | 
std::thread::sleep(std::time::Duration::from_secs(60)); 32 | let deadlocks = parking_lot::deadlock::check_deadlock(); 33 | if deadlocks.is_empty() { 34 | continue; 35 | } 36 | 37 | tracing::error!("{} deadlocks detected", deadlocks.len()); 38 | for (i, threads) in deadlocks.iter().enumerate() { 39 | tracing::error!("Deadlock #{}", i); 40 | for t in threads { 41 | tracing::error!("Thread Id {:#?}", t.thread_id()); 42 | tracing::error!("{:#?}", t.backtrace()); 43 | } 44 | } 45 | } 46 | }); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /rust/persia-core/src/cuda/cuda_event_pool.rs: -------------------------------------------------------------------------------- 1 | use super::cuda_stream_pool::CudaStreamPtr; 2 | use super::resource_pool::{Allocatable, Pool}; 3 | 4 | use cuda_runtime_sys as cuda; 5 | use persia_libs::once_cell; 6 | 7 | pub static CUDA_EVENT_POOL: once_cell::sync::Lazy> = 8 | once_cell::sync::Lazy::new(|| return Pool::new()); 9 | 10 | #[derive(Debug)] 11 | #[must_use = "d2h or h2d memcpy should synchronize manually"] 12 | pub struct CudaEventPtr { 13 | pub inner: cuda::cudaEvent_t, 14 | } 15 | 16 | impl Default for CudaEventPtr { 17 | fn default() -> Self { 18 | CudaEventPtr { 19 | inner: std::ptr::null_mut(), 20 | } 21 | } 22 | } 23 | 24 | unsafe impl Send for CudaEventPtr {} 25 | 26 | impl CudaEventPtr { 27 | pub fn record(&self, stream: CudaStreamPtr) { 28 | let result = unsafe { cuda::cudaEventRecord(self.inner, stream.inner) }; 29 | assert_eq!(result, cuda::cudaError::cudaSuccess); 30 | } 31 | 32 | pub fn synchronize(&self) { 33 | let result = unsafe { cuda::cudaEventSynchronize(self.inner) }; 34 | assert_eq!(result, cuda::cudaError::cudaSuccess); 35 | } 36 | } 37 | 38 | impl Allocatable for CudaEventPtr { 39 | fn new(_size: usize) -> Self { 40 | let mut event = std::ptr::null_mut(); 41 | let result = unsafe { cuda::cudaEventCreate(&mut event as *mut cuda::cudaEvent_t) }; 42 | assert_eq!(result, 
cuda::cudaError::cudaSuccess); 43 | 44 | return CudaEventPtr { inner: event }; 45 | } 46 | 47 | fn size(&self) -> usize { 48 | 0 49 | } 50 | } 51 | 52 | impl Drop for CudaEventPtr { 53 | fn drop(&mut self) { 54 | CUDA_EVENT_POOL.recycle(CudaEventPtr { inner: self.inner }); 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Lian" 5 | given-names: "Xiangru" 6 | orcid: "https://orcid.org/0000-0003-4456-8127" 7 | - family-names: "Yuan" 8 | given-names: "Binhang" 9 | - family-names: "Zhu" 10 | given-names: "Xuefeng" 11 | - family-names: "Wang" 12 | given-names: "Yulong" 13 | - family-names: "He" 14 | given-names: "Yongjun" 15 | - family-names: "Wu" 16 | given-names: "Honghuan" 17 | - family-names: "Sun" 18 | given-names: "Lei" 19 | - family-names: "Lyu" 20 | given-names: "Haodong" 21 | - family-names: "Liu" 22 | given-names: "Chengjun" 23 | - family-names: "Dong" 24 | given-names: "Xing" 25 | - family-names: "Liao" 26 | given-names: "Yiqiao" 27 | - family-names: "Luo" 28 | given-names: "Mingnan" 29 | - family-names: "Zhang" 30 | given-names: "Congfei" 31 | - family-names: "Xie" 32 | given-names: "Jingru" 33 | - family-names: "Li" 34 | given-names: "Haonan" 35 | - family-names: "Chen" 36 | given-names: "Lei" 37 | - family-names: "Huang" 38 | given-names: "Renjie" 39 | - family-names: "Lin" 40 | given-names: "Jianying" 41 | - family-names: "Shu" 42 | given-names: "Chengchun" 43 | - family-names: "Qiu" 44 | given-names: "Xuezhong" 45 | - family-names: "Liu" 46 | given-names: "Zhishan" 47 | - family-names: "Kong" 48 | given-names: "Dongying" 49 | - family-names: "Yuan" 50 | given-names: "Lei" 51 | - family-names: "Yu" 52 | given-names: "Hai" 53 | - family-names: "Yang" 54 | given-names: "Sen" 55 | - 
family-names: "Zhang" 56 | given-names: "Ce" 57 | - family-names: "Liu" 58 | given-names: "Ji" 59 | title: "Persia: A Hybrid System Scaling Deep Learning Based Recommenders up to 100 Trillion Parameters" 60 | date-released: 2021-11-12 61 | url: "https://github.com/PersiaML/Persia" 62 | 63 | -------------------------------------------------------------------------------- /rust/persia-core/src/cuda/cuda_memory_pool.rs: -------------------------------------------------------------------------------- 1 | use super::resource_pool::{Allocatable, Pool}; 2 | 3 | use cuda_runtime_sys as cuda; 4 | use persia_libs::{once_cell, tracing}; 5 | 6 | pub static CUDA_DEVICE_MEMORY_POOL: once_cell::sync::Lazy> = 7 | once_cell::sync::Lazy::new(|| return Pool::new()); 8 | 9 | /// A device-memory allocation obtained via cudaMalloc; on drop the block is recycled back into CUDA_DEVICE_MEMORY_POOL so it can be reused without another cudaMalloc. 10 | #[derive(Debug)] 11 | pub struct CudaMallocPtr { 12 | pub inner: *mut std::os::raw::c_void, 13 | pub num_bytes: usize, 14 | } 15 | 16 | impl Default for CudaMallocPtr { 17 | fn default() -> Self { 18 | Self { 19 | inner: std::ptr::null_mut::(), 20 | num_bytes: 0, 21 | } 22 | } 23 | } 24 | 25 | impl Drop for CudaMallocPtr { 26 | fn drop(&mut self) { 27 | tracing::debug!("cuda device memory recycled, size {}", self.num_bytes); 28 | CUDA_DEVICE_MEMORY_POOL.recycle(CudaMallocPtr { 29 | inner: self.inner, 30 | num_bytes: self.num_bytes, 31 | }); 32 | } 33 | } 34 | 35 | unsafe impl Send for CudaMallocPtr {} 36 | 37 | impl Allocatable for CudaMallocPtr { 38 | fn new(size: usize) -> Self { 39 | let mut data_ptr: *mut std::os::raw::c_void = std::ptr::null_mut(); 40 | let result = 41 | unsafe { cuda::cudaMalloc(&mut data_ptr as *mut *mut std::os::raw::c_void, size) }; 42 | assert_eq!(result, cuda::cudaError::cudaSuccess); 43 | tracing::debug!("allocating cuda device memory, size {}", size); 44 | return CudaMallocPtr { 45 | inner: data_ptr, 46 | num_bytes: size, 47 | }; 48 | } 49 |
50 | fn size(&self) -> usize { 51 | self.num_bytes 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /rust/persia-core/src/cuda/pinned_memory_pool.rs: -------------------------------------------------------------------------------- 1 | use super::resource_pool::{Allocatable, Pool}; 2 | 3 | use cuda_runtime_sys as cuda; 4 | use persia_libs::once_cell; 5 | 6 | pub static PINNED_MEMORY_POOL: once_cell::sync::Lazy> = 7 | once_cell::sync::Lazy::new(|| return Pool::new()); 8 | 9 | #[derive(Debug)] 10 | pub struct PinnedMemoryPtr { 11 | pub inner: *mut std::os::raw::c_void, 12 | pub num_bytes: usize, 13 | } 14 | 15 | impl Default for PinnedMemoryPtr { 16 | fn default() -> Self { 17 | PinnedMemoryPtr { 18 | inner: std::ptr::null_mut(), 19 | num_bytes: 0, 20 | } 21 | } 22 | } 23 | 24 | impl Drop for PinnedMemoryPtr { 25 | fn drop(&mut self) { 26 | PINNED_MEMORY_POOL.recycle(PinnedMemoryPtr { 27 | inner: self.inner, 28 | num_bytes: self.num_bytes, 29 | }); 30 | } 31 | } 32 | 33 | unsafe impl Send for PinnedMemoryPtr {} 34 | 35 | impl Allocatable for PinnedMemoryPtr { 36 | fn new(size: usize) -> Self { 37 | let mut data_ptr: *mut std::os::raw::c_void = std::ptr::null_mut(); 38 | let result = 39 | unsafe { cuda::cudaMallocHost(&mut data_ptr as *mut *mut std::os::raw::c_void, size) }; 40 | assert_eq!(result, cuda::cudaError::cudaSuccess); 41 | return PinnedMemoryPtr { 42 | inner: data_ptr, 43 | num_bytes: size, 44 | }; 45 | } 46 | 47 | fn size(&self) -> usize { 48 | self.num_bytes 49 | } 50 | } 51 | 52 | impl PinnedMemoryPtr { 53 | pub fn as_slice(&self, num_elements: usize) -> &[T] { 54 | assert!( 55 | num_elements * std::mem::size_of::() <= self.num_bytes, 56 | "num_elements {}, num_bytes {}", 57 | num_elements, 58 | self.num_bytes 59 | ); 60 | unsafe { std::slice::from_raw_parts(self.inner as *const T, num_elements) } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- 
/.buildkite/pipeline.yml: -------------------------------------------------------------------------------- 1 | steps: 2 | - label: "Build docker ci image for latest persiaml framework(GPU)" 3 | command: IMAGE_TAG=$BUILDKITE_PIPELINE_ID make build_cuda_runtime_image -e 4 | 5 | - label: "Build docker ci image for latest persiaml framework(CPU)" 6 | command: IMAGE_TAG=$BUILDKITE_PIPELINE_ID make build_cpu_runtime_image -e 7 | 8 | - wait 9 | 10 | - label: "launch e2e gpu test" 11 | key: "e2e-gpu-test" 12 | env: 13 | IMAGE_TYPE: "cuda" 14 | plugins: 15 | - docker-compose#v3.8.0: 16 | config: ".buildkite/e2e/docker-compose.train.yml" 17 | run: nn_worker 18 | graceful-shutdown: true 19 | - docker-compose#v3.8.0: 20 | config: ".buildkite/e2e/docker-compose.infer.yml" 21 | run: torch_serve 22 | graceful-shutdown: true 23 | 24 | - label: "launch e2e cpu test" 25 | key: "e2e-cpu-test" 26 | env: 27 | IMAGE_TYPE: "cpu" 28 | plugins: 29 | - docker-compose#v3.8.0: 30 | config: ".buildkite/e2e/docker-compose.train.yml" 31 | run: nn_worker 32 | graceful-shutdown: true 33 | - docker-compose#v3.8.0: 34 | config: ".buildkite/e2e/docker-compose.infer.yml" 35 | run: torch_serve 36 | graceful-shutdown: true 37 | 38 | - label: "pytest" 39 | command: 40 | - cd /test/ && pytest . 
-o log_cli=true 41 | plugins: 42 | - docker#v3.9.0: 43 | image: "persia-cuda-runtime:${BUILDKITE_PIPELINE_ID}" 44 | volumes: 45 | - ./test/:/test 46 | 47 | - label: "launch k8s e2e test" 48 | command: .buildkite/script/k8s_system_test.sh 49 | 50 | - wait 51 | 52 | - label: "upload docker image" 53 | branches: "main" 54 | command: 55 | - docker login -u $$DOCKER_USER -p $$DOCKER_TOKEN 56 | - .buildkite/script/docker_image_process.sh upload 57 | 58 | - wait 59 | 60 | - label: "remove docker image" 61 | command: .buildkite/script/docker_image_process.sh remove 62 | -------------------------------------------------------------------------------- /test/embedding/test_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import pytest 4 | 5 | from persia.embedding.data import ( 6 | _ND_ARRAY_SUPPORT_TYPE, 7 | NdarrayDataBase, 8 | IDTypeFeature, 9 | IDTypeFeatureWithSingleID, 10 | PersiaBatch, 11 | ) 12 | from persia.prelude import check_pyarray_dtype_valid 13 | 14 | 15 | def test_ndarray_base_dtype_convert_to_tensor(): 16 | for dtype in _ND_ARRAY_SUPPORT_TYPE: 17 | data = np.zeros((1), dtype=dtype) 18 | check_pyarray_dtype_valid(data, data.dtype) 19 | 20 | 21 | def test_ndarray_base_data(): 22 | # test dtype support 23 | for dtype in _ND_ARRAY_SUPPORT_TYPE: 24 | NdarrayDataBase(np.zeros(1, dtype=dtype)) 25 | 26 | # test batch_size and name 27 | ndarray_base_name = "test_name" 28 | data = NdarrayDataBase(np.array([1]), ndarray_base_name) 29 | assert data.batch_size == 1 30 | assert data.name == ndarray_base_name 31 | 32 | 33 | def test_id_type_feature(): 34 | id_type_feature_name = "test_name" 35 | 36 | assert ( 37 | IDTypeFeature(id_type_feature_name, [np.array([1], dtype=np.uint64)]).batch_size 38 | == 1 39 | ) 40 | 41 | 42 | def test_sparse_id_type_feature(): 43 | id_type_feature_name = "test_name" 44 | 45 | assert ( 46 | IDTypeFeatureWithSingleID( 47 | id_type_feature_name, np.array([1], 
dtype=np.uint64) 48 | ).batch_size 49 | == 1 50 | ) 51 | 52 | 53 | def test_persia_batch(): 54 | 55 | # test grad without label 56 | with pytest.raises(RuntimeError): 57 | PersiaBatch( 58 | id_type_features=[ 59 | IDTypeFeature("test_name", [np.array([1], dtype=np.uint64)]) 60 | ] 61 | ) 62 | 63 | # test serialize bytes 64 | persia_batch = PersiaBatch( 65 | id_type_features=[IDTypeFeature("test_name", [np.array([1], dtype=np.uint64)])], 66 | requires_grad=False, 67 | ) 68 | 69 | persia_batch_bytes = persia_batch.to_bytes() 70 | assert isinstance(persia_batch_bytes, bytes) 71 | -------------------------------------------------------------------------------- /rust/persia-core/src/optim.rs: -------------------------------------------------------------------------------- 1 | use crate::PersiaCommonContextImpl; 2 | 3 | use pyo3::exceptions::PyRuntimeError; 4 | use pyo3::prelude::*; 5 | 6 | use persia_common::optim::{AdagradConfig, AdamConfig, NaiveSGDConfig, OptimizerConfig}; 7 | 8 | #[pyclass] 9 | pub struct OptimizerBase { 10 | inner: Option, 11 | } 12 | 13 | impl OptimizerBase { 14 | pub fn get_inner(&self) -> Option { 15 | self.inner.clone() 16 | } 17 | } 18 | 19 | #[pymethods] 20 | impl OptimizerBase { 21 | #[new] 22 | pub fn new() -> Self { 23 | Self { inner: None } 24 | } 25 | 26 | pub fn init_adagrad( 27 | &mut self, 28 | lr: f32, 29 | wd: f32, 30 | g_square_momentum: f32, 31 | initialization: f32, 32 | eps: f32, 33 | vectorwise_shared: Option, 34 | ) -> () { 35 | let config = AdagradConfig { 36 | lr, 37 | wd, 38 | g_square_momentum, 39 | initialization, 40 | eps, 41 | vectorwise_shared: vectorwise_shared.unwrap_or(false), 42 | }; 43 | self.inner = Some(OptimizerConfig::Adagrad(config)); 44 | } 45 | 46 | pub fn init_sgd(&mut self, lr: f32, wd: f32) -> () { 47 | let config = NaiveSGDConfig { lr, wd }; 48 | self.inner = Some(OptimizerConfig::SGD(config)); 49 | } 50 | 51 | pub fn init_adam(&mut self, lr: f32, betas: (f32, f32), eps: f32) -> () { 52 | let config = 
AdamConfig { 53 | lr, 54 | beta1: betas.0, 55 | beta2: betas.1, 56 | eps, 57 | }; 58 | self.inner = Some(OptimizerConfig::Adam(config)); 59 | } 60 | 61 | pub fn apply(&self) -> PyResult<()> { 62 | let context = PersiaCommonContextImpl::get(); 63 | context 64 | .register_optimizer(self) 65 | .map_err(|e| PyRuntimeError::new_err(e.to_string())) 66 | } 67 | } 68 | 69 | pub fn init_module(super_module: &PyModule, py: Python) -> PyResult<()> { 70 | let module = PyModule::new(py, "optim")?; 71 | module.add_class::()?; 72 | super_module.add_submodule(module)?; 73 | Ok(()) 74 | } 75 | -------------------------------------------------------------------------------- /rust/persia-core/src/metrics.rs: -------------------------------------------------------------------------------- 1 | use persia_metrics::{Gauge, IntCounter, PersiaMetricsManager, PersiaMetricsManagerError}; 2 | 3 | use persia_libs::{anyhow::Result, once_cell::sync::OnceCell}; 4 | 5 | static METRICS_HOLDER: OnceCell = OnceCell::new(); 6 | 7 | pub struct MetricsHolder { 8 | pub forward_client_to_gpu_time_cost_sec: Gauge, 9 | pub forward_client_time_cost_sec: Gauge, 10 | pub forward_error: IntCounter, 11 | pub backward_client_time_cost_sec: Gauge, 12 | pub get_train_batch_time_cost_more_than_1ms_sec: Gauge, 13 | pub update_gradient_batched_time_cost_more_than_1ms_sec: Gauge, 14 | } 15 | 16 | impl MetricsHolder { 17 | pub fn get() -> Result<&'static Self, PersiaMetricsManagerError> { 18 | METRICS_HOLDER.get_or_try_init(|| { 19 | let m = PersiaMetricsManager::get()?; 20 | let holder = Self { 21 | forward_client_to_gpu_time_cost_sec: m.create_gauge( 22 | "forward_client_to_gpu_time_cost_sec", 23 | "get batched dense data and embeddings, then send to device time cost" 24 | )?, 25 | forward_client_time_cost_sec: m.create_gauge("forward_client_time_cost_sec", "get embeddings time cost")?, 26 | forward_error: m.create_counter("forward_error", "get embedding error counter")?, 27 | backward_client_time_cost_sec: 
m.create_gauge( 28 | "backward_client_time_cost_sec", 29 | "get graident backward packet and update it to server time cost" 30 | )?, 31 | get_train_batch_time_cost_more_than_1ms_sec: m.create_gauge( 32 | "get_train_batch_time_cost_more_than_1ms_sec", 33 | "get train batch time cost when it takes more than 1ms" 34 | )?, 35 | update_gradient_batched_time_cost_more_than_1ms_sec: m.create_gauge( 36 | "update_gradient_batched_time_cost_more_than_1ms_sec", 37 | "send gradient of embedding to gradient update buffer time cost when it takes more than 1ms" 38 | )?, 39 | }; 40 | Ok(holder) 41 | }) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /docs/_autoapi_templates/python/class.rst: -------------------------------------------------------------------------------- 1 | {% if obj.display %} 2 | .. {{ obj.type }}:: {{ obj.short_name }}{% if obj.args %}({{ obj.args }}){% endif %} 3 | {% if obj.constructor %} 4 | 5 | {% for (args, return_annotation) in obj.constructor.overloads %} 6 | {% if args and args.startswith("self, ") %}{% set args = args[6:] %}{% endif %} 7 | {{ " " * (obj.type | length) }} {{ obj.short_name }}{% if args %}({{ args }}){% endif %} 8 | {% endfor %} 9 | {% endif %} 10 | 11 | 12 | {% if obj.bases %} 13 | {% if "show-inheritance" in autoapi_options %} 14 | Bases: {% for base in obj.bases %}{{ base|link_objs }}{% if not loop.last %}, {% endif %}{% endfor %} 15 | {% endif %} 16 | 17 | 18 | {% if "show-inheritance-diagram" in autoapi_options and obj.bases != ["object"] %} 19 | .. 
autoapi-inheritance-diagram:: {{ obj.obj["full_name"] }} 20 | :parts: 1 21 | {% if "private-members" in autoapi_options %} 22 | :private-bases: 23 | {% endif %} 24 | 25 | {% endif %} 26 | {% endif %} 27 | {% if obj.docstring %} 28 | {{ obj.docstring|prepare_docstring|indent(3) }} 29 | {% endif %} 30 | {% if "inherited-members" in autoapi_options %} 31 | {% set visible_classes = obj.classes|selectattr("display")|list %} 32 | {% else %} 33 | {% set visible_classes = obj.classes|rejectattr("inherited")|selectattr("display")|list %} 34 | {% endif %} 35 | {% for klass in visible_classes %} 36 | {{ klass.render()|indent(3) }} 37 | {% endfor %} 38 | {% if "inherited-members" in autoapi_options %} 39 | {% set visible_attributes = obj.attributes|selectattr("display")|list %} 40 | {% else %} 41 | {% set visible_attributes = obj.attributes|rejectattr("inherited")|selectattr("display")|list %} 42 | {% endif %} 43 | {% for attribute in visible_attributes %} 44 | {{ attribute.render()|indent(3) }} 45 | {% endfor %} 46 | {% if "inherited-members" in autoapi_options %} 47 | {% set visible_methods = obj.methods|selectattr("display")|list %} 48 | {% else %} 49 | {% set visible_methods = obj.methods|rejectattr("inherited")|selectattr("display")|list %} 50 | {% endif %} 51 | {% for method in visible_methods %} 52 | {{ method.render()|indent(3) }} 53 | {% endfor %} 54 | {% endif %} 55 | -------------------------------------------------------------------------------- /examples/src/adult-income/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.2" 2 | services: 3 | persia_nats_service: 4 | image: nats:latest 5 | deploy: 6 | replicas: 1 7 | 8 | data_loader: 9 | env_file: 10 | - .docker.env 11 | depends_on: 12 | - nn_worker 13 | - embedding_worker 14 | - persia_nats_service 15 | image: persiaml/persia-cuda-runtime:latest 16 | command: persia-launcher data-loader --replica-index 0 --replica-size 1 17 | volumes: 18 | - type: 
bind 19 | source: . 20 | target: /workspace 21 | deploy: 22 | replicas: 1 23 | restart_policy: 24 | condition: on-failure 25 | 26 | nn_worker: 27 | env_file: 28 | - .docker.env 29 | environment: 30 | NCCL_SOCKET_IFNAME: eth0 31 | CUBLAS_WORKSPACE_CONFIG: :4096:8 32 | image: persiaml/persia-cuda-runtime:latest 33 | command: persia-launcher nn-worker --nproc-per-node 1 --nnodes 1 --node-rank 0 34 | volumes: 35 | - type: bind 36 | source: . 37 | target: /workspace 38 | deploy: 39 | replicas: 1 40 | restart_policy: 41 | condition: on-failure 42 | 43 | embedding_worker: 44 | env_file: 45 | - .docker.env 46 | depends_on: 47 | - embedding_parameter_server 48 | image: persiaml/persia-cuda-runtime:latest 49 | command: > 50 | bash -c "persia-launcher embedding-worker --embedding-config $$PERSIA_EMBEDDING_CONFIG 51 | --global-config $$PERSIA_GLOBAL_CONFIG --replica-index 0 --replica-size 1" 52 | deploy: 53 | replicas: 1 54 | restart_policy: 55 | condition: on-failure 56 | volumes: 57 | - type: bind 58 | source: . 59 | target: /workspace 60 | 61 | embedding_parameter_server: 62 | env_file: 63 | - .docker.env 64 | image: persiaml/persia-cuda-runtime:latest 65 | command: > 66 | bash -c "persia-launcher embedding-parameter-server --embedding-config $$PERSIA_EMBEDDING_CONFIG 67 | --global-config $$PERSIA_GLOBAL_CONFIG --replica-index 0 --replica-size 1" 68 | deploy: 69 | replicas: 1 70 | restart_policy: 71 | condition: on-failure 72 | volumes: 73 | - type: bind 74 | source: . 
75 | target: /workspace -------------------------------------------------------------------------------- /k8s/src/service.rs: -------------------------------------------------------------------------------- 1 | use crate::crd::PersiaJobSpec; 2 | use crate::error::Error; 3 | use crate::PersiaJobResources; 4 | 5 | use kube::client::Client; 6 | 7 | pub struct PersiaJobSchedulingService { 8 | pub kubernetes_client: Client, 9 | } 10 | 11 | impl PersiaJobSchedulingService { 12 | pub fn new(kubernetes_client: Client) -> Self { 13 | Self { kubernetes_client } 14 | } 15 | 16 | pub async fn apply( 17 | &self, 18 | job_name: &str, 19 | namespace: &str, 20 | json_spec: &str, 21 | ) -> Result<(), Error> { 22 | let spec: PersiaJobSpec = serde_json::from_str(json_spec) 23 | .map_err(|e| Error::JobSpecJsonDecodeError(format!("{:?}", e)))?; 24 | let job_resources = 25 | PersiaJobResources::new(&spec, job_name, namespace, self.kubernetes_client.clone()); 26 | 27 | job_resources.apply().await?; 28 | 29 | Ok(()) 30 | } 31 | 32 | pub async fn delete(&self, job_name: &str, namespace: &str) -> Result<(), Error> { 33 | PersiaJobResources::delete_services(self.kubernetes_client.clone(), namespace, job_name) 34 | .await?; 35 | PersiaJobResources::delete_pods(self.kubernetes_client.clone(), namespace, job_name) 36 | .await?; 37 | 38 | Ok(()) 39 | } 40 | 41 | pub async fn list_pods(&self, job_name: &str, namespace: &str) -> Result, Error> { 42 | let pods = 43 | PersiaJobResources::get_pods_name(self.kubernetes_client.clone(), namespace, job_name) 44 | .await?; 45 | Ok(pods) 46 | } 47 | 48 | pub async fn pod_log(&self, namespace: &str, pod_name: &str) -> Result { 49 | let log = 50 | PersiaJobResources::get_pod_log(self.kubernetes_client.clone(), namespace, pod_name) 51 | .await?; 52 | Ok(log) 53 | } 54 | 55 | pub async fn job_status(&self, namespace: &str, pod_name: &str) -> Result { 56 | let status = 57 | PersiaJobResources::get_pod_status(self.kubernetes_client.clone(), namespace, pod_name) 
58 | .await? 59 | .ok_or(Error::NonePodStatusError)?; 60 | 61 | let status = serde_json::to_string(&status).unwrap(); 62 | Ok(status) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /examples/src/adult-income/serve_client.py: -------------------------------------------------------------------------------- 1 | import grpc 2 | import os 3 | import sys 4 | import json 5 | 6 | sys.path.append("/cache/proto/") 7 | 8 | import numpy as np 9 | 10 | from tqdm import tqdm 11 | from sklearn import metrics 12 | from persia.embedding.data import PersiaBatch 13 | 14 | import inference_pb2 15 | import inference_pb2_grpc 16 | 17 | from data_generator import make_dataloader 18 | 19 | 20 | def get_inference_stub(): 21 | channel = grpc.insecure_channel("localhost:7070") 22 | stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel) 23 | return stub 24 | 25 | 26 | def infer(stub, model_name, model_input): 27 | input_data = {"batch": model_input} 28 | response = stub.Predictions( 29 | inference_pb2.PredictionsRequest(model_name=model_name, input=input_data) 30 | ) 31 | try: 32 | prediction = response.prediction.decode("utf-8") 33 | prediction = prediction.splitlines() 34 | prediction = [x.strip() for x in prediction] 35 | prediction = [x.replace(",", "") for x in prediction] 36 | prediction = prediction[1:-1] 37 | prediction = [float(x) for x in prediction] 38 | return prediction 39 | except: 40 | exit(1) 41 | 42 | 43 | if __name__ == "__main__": 44 | 45 | test_filepath = os.path.join( 46 | os.path.dirname(os.path.realpath(__file__)), "data/test.npz" 47 | ) 48 | loader = make_dataloader(test_filepath, batch_size=1024) 49 | all_pred = [] 50 | all_label = [] 51 | 52 | for (non_id_type_feature, id_type_features, label) in tqdm( 53 | loader, desc="gen batch data..." 
54 | ): 55 | batch_data = PersiaBatch( 56 | id_type_features, 57 | non_id_type_features=[non_id_type_feature], 58 | requires_grad=False, 59 | ) 60 | model_input = batch_data.to_bytes() 61 | prediction = infer(get_inference_stub(), "adult_income", model_input) 62 | 63 | assert len(prediction) == len( 64 | label 65 | ), f"Missing results: prediction length({len(prediction)}) does not match label length({len(label)})" 66 | 67 | all_label.append(label.data) 68 | all_pred.append(prediction) 69 | 70 | all_pred, all_label = np.concatenate(all_pred), np.concatenate(all_label) 71 | 72 | fpr, tpr, th = metrics.roc_curve(all_label, all_pred) 73 | infer_auc = metrics.auc(fpr, tpr) 74 | 75 | print(f"infer_auc = {infer_auc}") 76 | 77 | assert ( 78 | infer_auc > 0.8927 79 | ), f"infer error, expect infer_auc > 0.8927 but got {infer_auc}" 80 | -------------------------------------------------------------------------------- /rust/persia-core/src/cuda/resource_pool.rs: -------------------------------------------------------------------------------- 1 | use std::sync::atomic::Ordering; 2 | 3 | use persia_libs::tracing; 4 | 5 | use persia_common::utils::ChannelPair; 6 | 7 | pub trait Allocatable { 8 | fn new(size: usize) -> Self; 9 | fn size(&self) -> usize; 10 | } 11 | 12 | pub struct Pool { 13 | /// each entry represents an allocation queue of 2**n bytes block 14 | sub_pools: [SubPool; 30], 15 | } 16 | 17 | impl Default for Pool { 18 | fn default() -> Self { 19 | Pool::new() 20 | } 21 | } 22 | 23 | impl Pool 24 | where 25 | T: Allocatable, 26 | { 27 | pub fn new() -> Self { 28 | Self { 29 | sub_pools: arr_macro::arr![SubPool::new(); 30], 30 | } 31 | } 32 | 33 | fn get_pool_location(size: usize) -> usize { 34 | if size == 0 { 35 | 0 36 | } else { 37 | (size.next_power_of_two().trailing_zeros() + 1) as usize 38 | } 39 | } 40 | 41 | pub fn allocate(&self, size: usize) -> T { 42 | let pool = &self.sub_pools[Self::get_pool_location(size)]; 43 | pool.allocate(size.next_power_of_two()) 44 
| } 45 | 46 | pub fn recycle(&self, item: T) { 47 | let pool = &self.sub_pools[Self::get_pool_location(item.size())]; 48 | pool.recycle(item); 49 | } 50 | } 51 | 52 | pub struct SubPool { 53 | channel: ChannelPair, 54 | num_allocated: std::sync::atomic::AtomicU32, 55 | } 56 | 57 | impl Default for SubPool 58 | where 59 | T: Allocatable, 60 | { 61 | fn default() -> Self { 62 | let channel = ChannelPair::new_unbounded(); 63 | Self { 64 | channel, 65 | num_allocated: std::sync::atomic::AtomicU32::new(0), 66 | } 67 | } 68 | } 69 | 70 | impl SubPool 71 | where 72 | T: Allocatable, 73 | { 74 | pub fn new() -> Self { 75 | Self::default() 76 | } 77 | 78 | pub fn allocate(&self, size: usize) -> T { 79 | if let Ok(allocated) = self.channel.receiver.try_recv() { 80 | allocated 81 | } else { 82 | tracing::debug!( 83 | message = "no available resource in pool, creating a new one", 84 | type_info = tracing::field::debug(std::any::type_name::()), 85 | num_resources = self.num_allocated.load(Ordering::Acquire) 86 | ); 87 | let allocated = T::new(size); 88 | self.num_allocated.fetch_add(1, Ordering::AcqRel); 89 | allocated 90 | } 91 | } 92 | 93 | pub fn recycle(&self, item: T) { 94 | self.channel.sender.send(item).unwrap(); 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /k8s/resources/operator.persia.com.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: v1 3 | kind: ServiceAccount 4 | metadata: 5 | name: persia-operator 6 | namespace: default 7 | 8 | --- 9 | apiVersion: rbac.authorization.k8s.io/v1 10 | kind: ClusterRoleBinding 11 | metadata: 12 | name: persia-operator-binding 13 | roleRef: 14 | apiGroup: rbac.authorization.k8s.io 15 | kind: ClusterRole 16 | name: persia-operator 17 | subjects: 18 | - kind: ServiceAccount 19 | name: persia-operator 20 | namespace: default 21 | 22 | --- 23 | apiVersion: rbac.authorization.k8s.io/v1 24 | kind: ClusterRole 25 | metadata: 
# Allow all actions on Persia Operator managed CRDs (persiajobs)
# Allow all actions on Persia Operator managed CRDs (persiajobs)
["create", "watch", "get", "update", "delete", "list"] 81 | 82 | --- 83 | apiVersion: apps/v1 84 | kind: Deployment 85 | metadata: 86 | name: persia-server 87 | namespace: default 88 | spec: 89 | replicas: 1 90 | selector: 91 | matchLabels: 92 | name: persia-server 93 | template: 94 | metadata: 95 | labels: 96 | name: persia-server 97 | spec: 98 | serviceAccountName: persia-server 99 | containers: 100 | - name: persia-server 101 | image: persiaml/persia-cpu-runtime:latest 102 | imagePullPolicy: Always 103 | args: 104 | - persia-k8s-utils 105 | - server 106 | ports: 107 | - name: http 108 | containerPort: 8080 109 | -------------------------------------------------------------------------------- /persia/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import subprocess 4 | 5 | from typing import List, Optional 6 | 7 | from persia.logger import get_default_logger 8 | from persia.env import PERSIA_LAUNCHER_VERBOSE 9 | 10 | _logger = get_default_logger() 11 | 12 | 13 | def setup_seed(seed: int): 14 | """Set the random seed for dependencies to ensure that experiments are reproducible. 15 | 16 | Arguments: 17 | seed (int): integer to use as seed for random number generator used by random, 18 | NumPy and pyTorch. 19 | """ 20 | import numpy as np 21 | import torch 22 | import random 23 | 24 | np.random.seed(seed) 25 | 26 | random.seed(seed) 27 | 28 | torch.random.manual_seed(seed) 29 | if getattr(torch, "use_deterministic_algorithms", None): 30 | torch.use_deterministic_algorithms(True) 31 | else: 32 | torch.backends.cudnn.deterministic = True 33 | 34 | 35 | def load_yaml(filepath: str) -> dict: 36 | """Load the yaml config by provided filepath. 37 | 38 | Arguments: 39 | filepath (str): yaml config path. 
"""Check whether the given port is available.
It will keep adding the interval to the input port until the 85 | next available port is found.
"native_country", 61 | ] 62 | 63 | CONTINUOUS_COLUMNS = [ 64 | "age", 65 | "education_num", 66 | "capital_gain", 67 | "capital_loss", 68 | "hours_per_week", 69 | ] 70 | 71 | parser = ArgumentParser() 72 | parser.add_argument("--train-dataset", default="data_source/train.csv") 73 | parser.add_argument("--test-dataset", default="data_source/test.csv") 74 | parser.add_argument("--output_path", default="data_source") 75 | args = parser.parse_args() 76 | 77 | output_path = args.output_path 78 | 79 | if args.train_dataset: 80 | df_train = pd.read_csv(args.train_dataset, names=COLUMNS, skipinitialspace=True) 81 | process(df_train, os.path.join(output_path, "train.npz")) 82 | 83 | if args.test_dataset: 84 | df_test = pd.read_csv( 85 | args.test_dataset, names=COLUMNS, skipinitialspace=True, skiprows=1 86 | ) 87 | process(df_test, os.path.join(output_path, "test.npz")) 88 | -------------------------------------------------------------------------------- /rust/persia-core/src/dlpack.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | ///! For detailed documentation, see: 3 | ///! https://github.com/dmlc/dlpack 4 | use std::os::raw::c_void; 5 | 6 | use persia_libs::tracing; 7 | 8 | /// DLpack DeviceType representation. Most of the scene is DLCPU and DLCUDA. 9 | #[repr(C)] 10 | #[derive(Clone, Copy, Debug)] 11 | pub enum DLDeviceType { 12 | DLCPU = 1, 13 | DLCUDA = 2, 14 | DLCUDAHost = 3, 15 | DLOpenCL = 4, 16 | DLVulkan = 7, 17 | DLMetal = 8, 18 | DLVPI = 9, 19 | DLROCM = 10, 20 | DLROCMHost = 11, 21 | DLExtDev = 12, 22 | DLCUDAManaged = 13, 23 | } 24 | 25 | /// Enum type of Dlpack DataTypeCode. 26 | // 27 | /// Use enum to represent the generic datatype. This struct can't infer concrete datatype directly, 28 | /// the concrete datatype should compose with the bits field in [`DLDataType`]. 
/// For example, [`i32`] is represented as [`DLDataType`].code=0 (DLInt) and [`DLDataType`].bits=32, 42 | /// while [`i64`] is represented as [`DLDataType`].code=0 and [`DLDataType`].bits=64.
90 | pub extern "C" fn drop_dl_managed_tensor(drop_ptr: *mut DLManagedTensor) { 91 | if drop_ptr.is_null() { 92 | return; 93 | } 94 | 95 | unsafe { Box::from_raw(drop_ptr) }; 96 | } 97 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # 
According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # npz files 141 | *.npz 142 | 143 | # macos file 144 | .DS_Store -------------------------------------------------------------------------------- /docs/_autoapi_templates/python/module.rst: -------------------------------------------------------------------------------- 1 | {% if not obj.display %} 2 | :orphan: 3 | 4 | {% endif %} 5 | .. :mod:`{{ obj.name }}` 6 | {{ obj.name }} 7 | ======={{ "=" * obj.name|length }} 8 | 9 | .. py:module:: {{ obj.name }} 10 | 11 | {% if obj.docstring %} 12 | .. autoapi-nested-parse:: 13 | 14 | {{ obj.docstring|prepare_docstring|indent(3) }} 15 | 16 | {% endif %} 17 | 18 | {% block subpackages %} 19 | {% set visible_subpackages = obj.subpackages|selectattr("display")|list %} 20 | {% if visible_subpackages %} 21 | Subpackages 22 | ----------- 23 | .. 
toctree:: 24 | :titlesonly: 25 | :maxdepth: 3 26 | 27 | {% for subpackage in visible_subpackages %} 28 | {{ subpackage.short_name }}/index.rst 29 | {% endfor %} 30 | 31 | 32 | {% endif %} 33 | {% endblock %} 34 | {% block submodules %} 35 | {% set visible_submodules = obj.submodules|selectattr("display")|list %} 36 | {% if visible_submodules %} 37 | Submodules 38 | ---------- 39 | .. toctree:: 40 | :titlesonly: 41 | :maxdepth: 1 42 | 43 | {% for submodule in visible_submodules %} 44 | {{ submodule.short_name }}/index.rst 45 | {% endfor %} 46 | 47 | 48 | {% endif %} 49 | {% endblock %} 50 | {% block content %} 51 | {% if obj.all is not none %} 52 | {% set visible_children = obj.children|selectattr("short_name", "in", obj.all)|list %} 53 | {% elif obj.type is equalto("package") %} 54 | {% set visible_children = obj.children|selectattr("display")|list %} 55 | {% else %} 56 | {% set visible_children = obj.children|selectattr("display")|rejectattr("imported")|list %} 57 | {% endif %} 58 | {% if visible_children %} 59 | {{ obj.type|title }} Contents 60 | {{ "-" * obj.type|length }}--------- 61 | 62 | {% set visible_classes = visible_children|selectattr("type", "equalto", "class")|list %} 63 | {% set visible_functions = visible_children|selectattr("type", "equalto", "function")|list %} 64 | {% set visible_attributes = visible_children|selectattr("type", "equalto", "data")|list %} 65 | {% if "show-module-summary" in autoapi_options and (visible_classes or visible_functions) %} 66 | {% block classes scoped %} 67 | {% if visible_classes %} 68 | Classes 69 | ~~~~~~~ 70 | 71 | .. autoapisummary:: 72 | 73 | {% for klass in visible_classes %} 74 | {{ klass.id }} 75 | {% endfor %} 76 | 77 | 78 | {% endif %} 79 | {% endblock %} 80 | 81 | {% block functions scoped %} 82 | {% if visible_functions %} 83 | Functions 84 | ~~~~~~~~~ 85 | 86 | .. 
autoapisummary:: 87 | 88 | {% for function in visible_functions %} 89 | {{ function.id }} 90 | {% endfor %} 91 | 92 | 93 | {% endif %} 94 | {% endblock %} 95 | 96 | {% block attributes scoped %} 97 | {% if visible_attributes %} 98 | Attributes 99 | ~~~~~~~~~~ 100 | 101 | .. autoapisummary:: 102 | 103 | {% for attribute in visible_attributes %} 104 | {{ attribute.id }} 105 | {% endfor %} 106 | 107 | 108 | {% endif %} 109 | {% endblock %} 110 | {% endif %} 111 | {% for obj_item in visible_children %} 112 | {{ obj_item.render()|indent(0) }} 113 | {% endfor %} 114 | {% endif %} 115 | {% endblock %} 116 | -------------------------------------------------------------------------------- /rust/persia-core/src/cuda/mod.rs: -------------------------------------------------------------------------------- 1 | use cuda_runtime_sys as cuda; 2 | 3 | pub mod cuda_event_pool; 4 | pub mod cuda_memory_pool; 5 | pub mod cuda_stream_pool; 6 | pub mod pinned_memory_pool; 7 | pub mod resource_pool; 8 | pub mod utils; 9 | 10 | use cuda_event_pool::{CudaEventPtr, CUDA_EVENT_POOL}; 11 | use cuda_memory_pool::{CudaMallocPtr, CUDA_DEVICE_MEMORY_POOL}; 12 | use cuda_stream_pool::CUDA_STREAM_POOL; 13 | use pinned_memory_pool::{PinnedMemoryPtr, PINNED_MEMORY_POOL}; 14 | 15 | use persia_libs::{anyhow::Result, tracing}; 16 | use persia_speedy::{Readable, Writable}; 17 | 18 | use crate::tensor::{CPUStorage, DTypeImpl}; 19 | 20 | pub fn set_device(card_index: i32) { 21 | let result = unsafe { cuda::cudaSetDevice(card_index) }; 22 | assert_eq!(result, cuda::cudaError::cudaSuccess); 23 | } 24 | 25 | #[derive(Debug, Readable, Writable)] 26 | pub struct GPUStorage { 27 | #[speedy(skip)] 28 | pub ptr: CudaMallocPtr, 29 | pub shape: Vec, 30 | #[speedy(skip)] 31 | pub event: CudaEventPtr, 32 | #[speedy(skip)] 33 | pub host_ptr: PinnedMemoryPtr, 34 | pub dtype: DTypeImpl, 35 | pub is_ready: bool, 36 | } 37 | 38 | impl GPUStorage { 39 | pub fn new(storage: CPUStorage, shape: Vec) -> Result { 40 | unsafe { 41 
| let stream = CUDA_STREAM_POOL.allocate(0); 42 | let mut storage = storage; 43 | let dtype = storage.get_dtype(); 44 | let byte_count = shape.iter().product::() * dtype.get_type_size(); 45 | 46 | tracing::debug!("tensor shape is: {:?}, bytes: {:?}", &shape, &byte_count); 47 | 48 | let host_ptr = storage.get_raw_ptr(); 49 | let pinned_host_ptr = PINNED_MEMORY_POOL.allocate(byte_count); 50 | std::ptr::copy_nonoverlapping(host_ptr, pinned_host_ptr.inner, byte_count); 51 | 52 | let data_ptr = CUDA_DEVICE_MEMORY_POOL.allocate(byte_count); 53 | let result = cuda::cudaMemcpyAsync( 54 | data_ptr.inner, 55 | pinned_host_ptr.inner, 56 | byte_count, 57 | cuda::cudaMemcpyKind::cudaMemcpyHostToDevice, 58 | stream.inner, 59 | ); 60 | assert_eq!( 61 | result, 62 | cuda::cudaError::cudaSuccess, 63 | "data_ptr {:?}, pinned_host_ptr: {:?}, byte_count: {:?}, stream: {:?}", 64 | data_ptr, 65 | pinned_host_ptr, 66 | byte_count, 67 | stream 68 | ); 69 | let event = CUDA_EVENT_POOL.allocate(0); 70 | event.record(stream); 71 | 72 | Ok(GPUStorage { 73 | ptr: data_ptr, 74 | shape: shape.clone(), 75 | host_ptr: pinned_host_ptr, 76 | event, 77 | dtype, 78 | is_ready: false, 79 | }) 80 | } 81 | } 82 | 83 | pub fn sync_event(&mut self) { 84 | if !self.is_ready { 85 | self.event.synchronize(); 86 | self.is_ready = true; 87 | } 88 | } 89 | 90 | pub fn get_raw_ptr(&mut self) -> *mut std::os::raw::c_void { 91 | self.sync_event(); 92 | self.ptr.inner 93 | } 94 | 95 | pub fn get_dtype(&self) -> DTypeImpl { 96 | self.dtype.clone() 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import colorama 3 | 4 | from setuptools import setup, find_packages 5 | from setuptools_rust import Binding, RustExtension 6 | 7 | USE_CUDA = bool(int(os.environ.get("USE_CUDA", "0"))) 8 | USE_K8S = bool(int(os.environ.get("USE_K8S", "1"))) 9 | NATIVE = 
bool(int(os.environ.get("NATIVE", "0"))) 10 | 11 | if __name__ == "__main__": 12 | 13 | colorama.init(autoreset=True) 14 | 15 | features = None if not USE_CUDA else ["cuda"] 16 | rust_extensions = [] 17 | console_scripts = [] 18 | 19 | rust_extensions.append( 20 | RustExtension( 21 | # TODO: Due to this issue https://github.com/PyO3/setuptools-rust/issues/153 still not release 22 | # the new version of setuptool_rust, RustExtension can't enable the script feature 23 | # { 24 | # "persia-embedding-worker": "persia.persia_embedding_worker", 25 | # "persia-embedding-parameter-server": "persia.persia_embedding_parameter_server" 26 | # }, 27 | # script=True, 28 | { 29 | "persia-embedding-worker": "persia.persia-embedding-worker", 30 | "persia-embedding-parameter-server": "persia.persia-embedding-parameter-server", 31 | }, 32 | path="rust/persia-embedding-server/Cargo.toml", 33 | binding=Binding.Exec, 34 | native=NATIVE, 35 | ) 36 | ) 37 | console_scripts.append("persia-launcher=persia.launcher:cli") 38 | 39 | rust_extensions.append( 40 | RustExtension( 41 | "persia_core", 42 | path="rust/persia-core/Cargo.toml", 43 | binding=Binding.PyO3, 44 | native=NATIVE, 45 | features=features, 46 | ) 47 | ) 48 | 49 | if USE_K8S: 50 | rust_extensions.append( 51 | RustExtension( 52 | { 53 | "gencrd": "persia.gencrd", 54 | "operator": "persia.operator", 55 | "e2e": "persia.e2e_test", 56 | }, 57 | path="k8s/Cargo.toml", 58 | binding=Binding.Exec, 59 | native=NATIVE, 60 | ) 61 | ) 62 | console_scripts.append("persia-k8s-utils=persia.k8s_utils:cli") 63 | 64 | install_requires = ["colorlog", "pyyaml", "click", "honcho", "cloudpickle"] 65 | 66 | if USE_CUDA: 67 | name_suffix = os.getenv("PERSIA_CUDA_VERSION", "") 68 | if name_suffix != "": 69 | name_suffix = "-cuda" + name_suffix 70 | else: 71 | name_suffix = "" 72 | 73 | with open(os.path.realpath(os.path.join(__file__, "../README.md"))) as file: 74 | long_description = file.read() 75 | 76 | setup( 77 | name="persia" + name_suffix, 78 | 
use_scm_version={"local_scheme": "no-local-version"}, 79 | setup_requires=["setuptools_scm"], 80 | install_requires=install_requires, 81 | url="https://github.com/PersiaML/PersiaML", 82 | author="Kuaishou AI Platform Persia Team", 83 | author_email="admin@mail.xrlian.com", 84 | description="PersiaML Python Library", 85 | long_description=long_description, 86 | long_description_content_type="text/markdown", 87 | packages=find_packages(exclude=("tests",)), 88 | rust_extensions=rust_extensions, 89 | entry_points={"console_scripts": console_scripts}, 90 | python_requires=">=3.6", 91 | ) 92 | -------------------------------------------------------------------------------- /rust/persia-embedding-holder/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod array_linked_list; 2 | pub mod emb_entry; 3 | pub mod eviction_map; 4 | pub mod sharded; 5 | 6 | use std::sync::Arc; 7 | 8 | use persia_libs::{once_cell, parking_lot::RwLock, thiserror}; 9 | 10 | use emb_entry::HashMapEmbeddingEntry; 11 | use eviction_map::EvictionMap; 12 | use persia_embedding_config::{EmbeddingParameterServerConfig, PersiaGlobalConfigError}; 13 | use persia_speedy::{Readable, Writable}; 14 | use sharded::Sharded; 15 | 16 | #[derive(Clone, Readable, Writable, thiserror::Error, Debug)] 17 | pub enum PersiaEmbeddingHolderError { 18 | #[error("global config error: {0}")] 19 | PersiaGlobalConfigError(#[from] PersiaGlobalConfigError), 20 | #[error("id not fonud")] 21 | IdNotFound, 22 | } 23 | 24 | static PERSIA_EMBEDDING_HOLDER: once_cell::sync::OnceCell = 25 | once_cell::sync::OnceCell::new(); 26 | 27 | #[derive(Clone)] 28 | pub struct PersiaEmbeddingHolder { 29 | inner: Arc, u64>>, 30 | } 31 | 32 | impl PersiaEmbeddingHolder { 33 | pub fn get() -> Result { 34 | let singleton = PERSIA_EMBEDDING_HOLDER.get_or_try_init(|| { 35 | let config = EmbeddingParameterServerConfig::get()?; 36 | 37 | let bucket_size = config.num_hashmap_internal_shards; 38 | let 
cpapacity_per_bucket = config.capacity / bucket_size; 39 | 40 | let handles: Vec> = (0..bucket_size) 41 | .map(|_| { 42 | std::thread::spawn(move || { 43 | EvictionMap::with_capacity(cpapacity_per_bucket as usize) 44 | }) 45 | }) 46 | .collect(); 47 | 48 | let maps: Vec<_> = handles 49 | .into_iter() 50 | .map(|h| RwLock::new(h.join().expect("failed to create map"))) 51 | .collect(); 52 | 53 | let sharded = Sharded { 54 | inner: maps, 55 | phantom: std::marker::PhantomData::default(), 56 | }; 57 | Ok(PersiaEmbeddingHolder { 58 | inner: Arc::new(sharded), 59 | }) 60 | }); 61 | match singleton { 62 | Ok(s) => Ok(s.clone()), 63 | Err(e) => Err(e), 64 | } 65 | } 66 | 67 | pub fn num_total_signs(&self) -> usize { 68 | self.inner 69 | .inner 70 | .iter() 71 | .map(|x| x.read().len()) 72 | .sum::() 73 | } 74 | 75 | pub fn num_internal_shards(&self) -> usize { 76 | self.inner.inner.len() 77 | } 78 | 79 | pub fn capacity(&self) -> usize { 80 | self.inner 81 | .inner 82 | .iter() 83 | .map(|x| x.read().capacity()) 84 | .sum::() 85 | } 86 | 87 | pub fn clear(&self) { 88 | self.inner.inner.iter().for_each(|x| x.write().clear()); 89 | } 90 | 91 | pub fn shard(&self, key: &u64) -> &RwLock> { 92 | self.inner.shard(key) 93 | } 94 | 95 | pub fn get_shard_by_index( 96 | &self, 97 | index: usize, 98 | ) -> &RwLock> { 99 | self.inner.get_shard_by_index(index) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /persia/embedding/optim.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Tuple 3 | 4 | from persia.prelude import OptimizerBase 5 | 6 | 7 | class Optimizer(ABC): 8 | r"""Base optimizer to configurate the embedding update behavior.""" 9 | 10 | def __init__(self): 11 | self.optimizer_base = OptimizerBase() 12 | 13 | def apply(self): 14 | """Register sparse optimizer to embedding server.""" 15 | self.optimizer_base.apply() 16 | 17 | 18 | class 
class SGD(Optimizer):
    r"""A wrapper to config the embedding-server SGD optimizer."""

    def __init__(self, lr: float, momentum: float = 0.0, weight_decay: float = 0.0):
        """
        Arguments:
            lr (float): learning rate.
            momentum (float, optional): momentum factor. NOTE: the server
                binding ``init_sgd`` only takes ``lr`` and ``weight_decay``,
                so a non-zero momentum is stored on the Python side but NOT
                applied by the embedding server; a warning is emitted.
            weight_decay (float, optional): parameters L2 penalty factor.
        """
        super(SGD, self).__init__()
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        if momentum != 0.0:
            # Fix: previously a non-zero momentum was silently dropped.
            import warnings

            warnings.warn(
                "SGD momentum is accepted but not supported by the embedding "
                "server binding; the value will be ignored.",
                RuntimeWarning,
            )
        self.optimizer_base.init_sgd(self.lr, self.weight_decay)


class Adam(Optimizer):
    r"""A wrapper to config the embedding-server Adam optimizer."""

    def __init__(
        self,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        weight_decay: float = 0,
        eps: float = 1e-8,
    ):
        """
        Arguments:
            lr (float): learning rate.
            betas (tuple[float, float], optional): coefficients for the running
                averages of the gradient and its square.
            weight_decay (float, optional): parameters L2 penalty factor.
                NOTE(review): ``init_adam`` below does not receive this value,
                so it is currently stored but not forwarded — confirm against
                the Rust binding.
            eps (float, optional): epsilon to avoid division by zero.
        """
        super(Adam, self).__init__()
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.eps = eps
        self.optimizer_base.init_adam(self.lr, self.betas, self.eps)


class Adagrad(Optimizer):
    r"""A wrapper to config the embedding-server Adagrad optimizer."""

    def __init__(
        self,
        lr: float = 1e-2,
        initial_accumulator_value: float = 1e-2,
        weight_decay: float = 0,
        g_square_momentum: float = 1,
        eps: float = 1e-10,
        vectorwise_shared: bool = False,
    ):
        """
        Arguments:
            lr (float): learning rate.
            initial_accumulator_value (float, optional): initialization
                accumulator value for the adagrad optimizer.
            weight_decay (float, optional): parameters L2 penalty factor.
            g_square_momentum (float, optional): factor of accumulator
                incremental.
            eps (float, optional): epsilon term to avoid divide zero.
            vectorwise_shared (bool, optional): whether to share optimizer
                status vectorwise of embedding.
        """
        super(Adagrad, self).__init__()
        self.lr = lr
        self.weight_decay = weight_decay
        self.initial_accumulator_value = initial_accumulator_value
        self.g_square_momentum = g_square_momentum
        self.eps = eps
        self.vectorwise_shared = vectorwise_shared
        self.optimizer_base.init_adagrad(
            self.lr,
            self.weight_decay,
            self.g_square_momentum,
            self.initial_accumulator_value,
            self.eps,
            self.vectorwise_shared,
        )
from typing import Iterable, List, Tuple

import numpy as np


class DataLoader:
    """Mini-batch iterator over preprocessed adult-income arrays.

    Yields ``(NonIDTypeFeature, [IDTypeFeature, ...], Label)`` tuples, one per
    batch, in dataset order.
    """

    def __init__(
        self,
        non_id_type_feature_data,
        id_type_feature_data,
        label_data,
        id_type_feature_names,
        batch_size: int = 128,
        skip_last_batch: bool = False,
    ):
        """
        Arguments:
            non_id_type_feature_data: 2-D array of dense features.
            id_type_feature_data: per-sample sequences of sparse feature ids.
            label_data: 1-D (or 2-D) array of targets; its length defines the
                dataset size.
            id_type_feature_names: one name per sparse feature column.
            batch_size (int, optional): samples per yielded batch.
            skip_last_batch (bool, optional): drop the final (possibly
                partial) batch.
        """
        self.non_id_type_feature_data = non_id_type_feature_data
        self.id_type_feature_data = id_type_feature_data
        self.label_data = label_data
        self.id_type_feature_names = id_type_feature_names
        self.batch_size = batch_size
        self.skip_last_batch = skip_last_batch

        dataset_size = len(label_data)
        loader_size = (dataset_size - 1) // batch_size + 1
        if skip_last_batch:
            loader_size = loader_size - 1

        # Fix: an empty dataset combined with skip_last_batch=True used to
        # produce loader_size == -1; clamp so len(self) is never negative.
        self.loader_size = max(loader_size, 0)

    def __iter__(
        self,
    ) -> "Iterable[Tuple[NonIDTypeFeature, List[IDTypeFeature], Label]]":
        # Imported lazily so that constructing the loader / calling len() does
        # not require the persia extension to be importable.
        from persia.embedding.data import IDTypeFeature, Label, NonIDTypeFeature

        batch_size = self.batch_size
        skip_last_batch = self.skip_last_batch
        id_type_feature_names = self.id_type_feature_names
        id_type_feature_data = self.id_type_feature_data
        label_data = self.label_data

        dataset_size = len(self.label_data)
        for start in range(0, dataset_size, batch_size):
            end = min(start + batch_size, dataset_size)
            if end == dataset_size and skip_last_batch:
                print("skip last batch...")
                continue

            non_id_type_feature = NonIDTypeFeature(
                self.non_id_type_feature_data[start:end, :]
            )
            id_type_features = []

            # Re-slice per feature column: each IDTypeFeature carries one
            # column of the sparse ids for every sample in the batch.
            for id_type_feature_idx, feature_name in enumerate(id_type_feature_names):
                id_type_feature = []
                id_type_feature_batch = id_type_feature_data[start:end]
                for batch_idx in range(len(id_type_feature_batch)):
                    id_type_feature.append(
                        id_type_feature_batch[batch_idx][
                            id_type_feature_idx : id_type_feature_idx + 1
                        ]
                    )
                id_type_features.append(IDTypeFeature(feature_name, id_type_feature))

            label = label_data[start:end]
            label = Label(label.reshape(len(label), -1))

            yield non_id_type_feature, id_type_features, label

    def __len__(self) -> int:
        return self.loader_size


def make_dataloader(
    data_filepath: str, batch_size: int = 128, skip_last_batch: bool = False
) -> DataLoader:
    """Build a DataLoader from an ``.npz`` produced by data_preprocess.py.

    Arguments:
        data_filepath (str): path to the ``.npz`` archive (expects the keys
            ``target``, ``continuous_data``, ``categorical_data`` and
            ``categorical_columns`` — written by the preprocessing script).
        batch_size (int, optional): samples per batch.
        skip_last_batch (bool, optional): drop the final partial batch.
    """
    with np.load(data_filepath) as data:
        target = data["target"]
        continuous_data = data["continuous_data"]
        categorical_data = data["categorical_data"]
        categorical_columns = data["categorical_columns"]

    return DataLoader(
        continuous_data,
        categorical_data,
        target,
        categorical_columns,
        batch_size,
        skip_last_batch,
    )


if __name__ == "__main__":
    # tqdm is only needed for the demo run; import here so the module itself
    # stays importable without it.
    from tqdm import tqdm

    loader = make_dataloader("data/train.npz", 128, skip_last_batch=False)
    for (dense, sparse, target) in tqdm(loader, desc="generate_data"):
        ...
55 | }) 56 | .await 57 | .expect("failed to init nats connection"); 58 | 59 | NatsClient { 60 | nc, 61 | timeout: Duration::from_secs(10), 62 | } 63 | } 64 | 65 | pub async fn subscribe(&self, subject: &str) -> Result { 66 | tracing::debug!("subscribing nats subject {}", subject); 67 | match self.nc.subscribe(subject).await { 68 | Ok(subscription) => Ok(subscription), 69 | Err(err) => Err(NatsError::from(err)), 70 | } 71 | } 72 | 73 | pub async fn request(&self, subject: &str, msg: &[u8]) -> Result, NatsError> { 74 | tracing::debug!("requesting nats subject {}", subject); 75 | match self.nc.request_timeout(subject, msg, self.timeout).await { 76 | Ok(msg) => Ok(msg.data), 77 | Err(err) => Err(NatsError::from(err)), 78 | } 79 | } 80 | 81 | pub async fn request_multi( 82 | &self, 83 | subject: &str, 84 | msg: &[u8], 85 | ) -> Result>, NatsError> { 86 | match self.nc.request_multi(subject, msg).await { 87 | Ok(subscription) => { 88 | let mut messages = Vec::new(); 89 | while let Some(msg) = subscription.next().await { 90 | messages.push(msg.data); 91 | } 92 | Ok(messages) 93 | } 94 | Err(err) => Err(NatsError::from(err)), 95 | } 96 | } 97 | 98 | pub fn get_subject( 99 | &self, 100 | service_type: &str, 101 | fn_name: &str, 102 | replica_index: Option, 103 | ) -> String { 104 | match replica_index { 105 | Some(idx) => format!("{}.{}.{}", service_type, fn_name, idx), 106 | None => format!("{}.{}", service_type, fn_name), 107 | } 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /k8s/src/bin/operator.rs: -------------------------------------------------------------------------------- 1 | use futures::stream::StreamExt; 2 | use kube::Resource; 3 | use kube::ResourceExt; 4 | use kube::{api::ListParams, client::Client, Api}; 5 | use kube_runtime::controller::{Context, ReconcilerAction}; 6 | use kube_runtime::Controller; 7 | use tokio::time::Duration; 8 | 9 | use persia_operator::crd::PersiaJob; 10 | use 
persia_operator::error::Error; 11 | use persia_operator::finalizer; 12 | use persia_operator::PersiaJobResources; 13 | 14 | #[tokio::main] 15 | async fn main() { 16 | openssl_sys::init(); 17 | 18 | let kubernetes_client: Client = Client::try_default() 19 | .await 20 | .expect("Expected a valid KUBECONFIG environment variable."); 21 | 22 | let crd_api: Api = Api::all(kubernetes_client.clone()); 23 | let context: Context = Context::new(ContextData::new(kubernetes_client.clone())); 24 | 25 | Controller::new(crd_api.clone(), ListParams::default()) 26 | .run(reconcile, on_error, context) 27 | .for_each(|reconciliation_result| async move { 28 | match reconciliation_result { 29 | Ok(persia_resource) => { 30 | println!("Reconciliation successful. Resource: {:?}", persia_resource); 31 | } 32 | Err(reconciliation_err) => { 33 | eprintln!("Reconciliation error: {:?}", reconciliation_err) 34 | } 35 | } 36 | }) 37 | .await; 38 | } 39 | 40 | struct ContextData { 41 | client: Client, 42 | } 43 | 44 | impl ContextData { 45 | pub fn new(client: Client) -> Self { 46 | ContextData { client } 47 | } 48 | } 49 | 50 | enum Action { 51 | Create, 52 | Delete, 53 | NoOp, 54 | } 55 | 56 | async fn reconcile( 57 | job: PersiaJob, 58 | context: Context, 59 | ) -> Result { 60 | let client: Client = context.get_ref().client.clone(); 61 | 62 | let namespace: String = match job.namespace() { 63 | None => { 64 | return Err(Error::UserInputError( 65 | "Expected PersiaJob resource to be namespaced. Can't deploy to an unknown namespace." 
66 | .to_owned(), 67 | )); 68 | } 69 | Some(namespace) => namespace, 70 | }; 71 | 72 | let job_resources = PersiaJobResources::new(&job.spec, &job.name(), &namespace, client.clone()); 73 | 74 | return match determine_action(&job) { 75 | Action::Create => { 76 | let job_name = job.name(); 77 | eprintln!("Creating PersiaJob: {}", job_name); 78 | 79 | finalizer::add(client.clone(), &job_name, &namespace).await?; 80 | 81 | job_resources.apply().await?; 82 | 83 | Ok(ReconcilerAction { 84 | requeue_after: Some(Duration::from_secs(10)), 85 | }) 86 | } 87 | Action::Delete => { 88 | let job_name = job.name(); 89 | eprintln!("Deleting PersiaJob: {}", job_name); 90 | 91 | job_resources.delete().await?; 92 | finalizer::delete(client, &job_name, &namespace).await?; 93 | 94 | Ok(ReconcilerAction { 95 | requeue_after: None, 96 | }) 97 | } 98 | Action::NoOp => Ok(ReconcilerAction { 99 | requeue_after: Some(Duration::from_secs(10)), 100 | }), 101 | }; 102 | } 103 | 104 | fn determine_action(job: &PersiaJob) -> Action { 105 | return if job.meta().deletion_timestamp.is_some() { 106 | Action::Delete 107 | } else if job 108 | .meta() 109 | .finalizers 110 | .as_ref() 111 | .map_or(true, |finalizers| finalizers.is_empty()) 112 | { 113 | Action::Create 114 | } else { 115 | Action::NoOp 116 | }; 117 | } 118 | 119 | fn on_error(error: &Error, _context: Context) -> ReconcilerAction { 120 | eprintln!("Reconciliation error:\n{:?}", error); 121 | ReconcilerAction { 122 | requeue_after: Some(Duration::from_secs(5)), 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /persia/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from typing import Optional 4 | 5 | from colorlog import ColoredFormatter 6 | 7 | 8 | class levelFilter(logging.Filter): 9 | r"""Log level filter. 10 | 11 | Arguments: 12 | level (int): filter log level. 
class levelFilter(logging.Filter):
    r"""Log level filter.

    Arguments:
        level (int): filter log level. Only logs with level higher than
            ``level`` will be kept.
    """

    def __init__(self, level: int):
        self.level = level

    def filter(self, record: logging.LogRecord) -> bool:
        """Keep only records whose level is strictly greater than ``self.level``.

        Arguments:
            record (logging.LogRecord): callback function input record items.
        """
        return record.levelno > self.level


STREAM_LOG_FORMAT = "%(log_color)s%(asctime)s %(levelname)-8s%(reset)s %(blue)s[%(filename)s:%(lineno)d]%(reset)s %(log_color)s%(message)s"
FILE_LOG_FORMAT = "%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s"
DEFAULT_LOGGER_NAME = "log"
_default_logger = None

LOG_COLOR = {
    "DEBUG": "cyan",
    "INFO": "green",
    "WARNING": "yellow",
    "ERROR": "red",
    "CRITICAL": "red,bg_white",
}

# Colored formatter for the console handler; plain formatter for files.
COLOR_FORMATTER = ColoredFormatter(
    STREAM_LOG_FORMAT,
    datefmt=None,
    reset=True,
    log_colors=LOG_COLOR,
    secondary_log_colors={},
    style="%",
)
FORMATTER = logging.Formatter(
    FILE_LOG_FORMAT,
    datefmt=None,
    style="%",
)


def setLogger(
    name: str,
    log_level: int = logging.DEBUG,
    log_filename: str = "train.log",
    enable_file_logger: bool = False,
    err_redirect_filepath: str = "error.log",
    enable_err_redirect: bool = False,
    err_redirect_level: int = logging.INFO,
) -> logging.Logger:
    r"""Helper function to simplify the logger setup process with provided
    log_level and log_filename. Also makes it possible to redirect logs
    above a certain level to a different file.

    Arguments:
        name (str): logger name.
        log_level (int): minimum level the logger itself will emit.
        log_filename (str): log filename.
        enable_file_logger (bool): whether to also save logs into ``log_filename``.
        err_redirect_filepath (str): err log redirect filepath.
        enable_err_redirect (bool): whether to enable err log redirect.
        err_redirect_level (int): only records above this level are redirected.

    Returns:
        logging.Logger: the configured logger.
    """
    logger = logging.getLogger(name)
    handler = logging.StreamHandler()
    handler.setFormatter(COLOR_FORMATTER)

    # NOTE(review): handlers are appended unconditionally, so calling this
    # twice with the same name duplicates output — confirm callers only set
    # up each logger once.
    logger.addHandler(handler)
    logger.setLevel(log_level)

    if enable_file_logger:
        file_normal_handler = logging.FileHandler(log_filename, mode="a")
        file_normal_handler.setFormatter(FORMATTER)
        logger.addHandler(file_normal_handler)

        if enable_err_redirect:
            file_error_handler = logging.FileHandler(err_redirect_filepath, mode="a")
            file_error_handler.setFormatter(FORMATTER)
            file_error_handler.addFilter(levelFilter(err_redirect_level))
            logger.addHandler(file_error_handler)
    return logger


def get_logger(name: str) -> logging.Logger:
    r"""Get logger by name.

    Arguments:
        name (str): logger name.
    """
    return logging.getLogger(name)


def _set_default_logger(name: str, **kwargs) -> logging.Logger:
    r"""Set the default logger, creating it on first use.

    Arguments:
        name (str): default logger name.

    Returns:
        logging.Logger: the process-wide default logger.
    """
    global _default_logger
    if not _default_logger:
        _default_logger = setLogger(name, **kwargs)
    return _default_logger
125 | """ 126 | if _default_logger is None: 127 | _set_default_logger(name or DEFAULT_LOGGER_NAME, **kwargs) 128 | return _default_logger 129 | -------------------------------------------------------------------------------- /rust/persia-core/src/utils.rs: -------------------------------------------------------------------------------- 1 | use pyo3::prelude::*; 2 | use pyo3::types::PyBytes; 3 | 4 | use crate::data::{PersiaBatch, PersiaBatchImpl}; 5 | 6 | use persia_common::message_queue::{PersiaMessageQueueClientImpl, PersiaMessageQueueServerImpl}; 7 | use persia_libs::{flume, tokio::runtime::Runtime}; 8 | 9 | #[pyclass] 10 | pub struct PersiaMessageQueueClient { 11 | pub inner: PersiaMessageQueueClientImpl, 12 | pub runtime: Runtime, 13 | } 14 | 15 | #[pymethods] 16 | impl PersiaMessageQueueClient { 17 | #[new] 18 | fn new(server_addr: &str) -> Self { 19 | let runtime = persia_libs::tokio::runtime::Builder::new_multi_thread() 20 | .enable_all() 21 | .worker_threads(5) 22 | .build() 23 | .unwrap(); 24 | 25 | let _guard = runtime.enter(); 26 | 27 | Self { 28 | inner: PersiaMessageQueueClientImpl::new(server_addr), 29 | runtime, 30 | } 31 | } 32 | 33 | fn put(&self, data: Vec) { 34 | let _gurad = self.runtime.enter(); 35 | self.runtime.block_on(self.inner.send(data)).unwrap(); 36 | } 37 | 38 | fn get<'a>(&self, _py: Python<'a>) -> &'a PyBytes { 39 | let _gurad = self.runtime.enter(); 40 | let bytes = self.runtime.block_on(self.inner.recv()); 41 | PyBytes::new(_py, bytes.unwrap().as_slice()) 42 | } 43 | } 44 | 45 | #[pyclass] 46 | pub struct PersiaMessageQueueServer { 47 | inner: PersiaMessageQueueServerImpl, 48 | runtime: Runtime, // thread unsafe 49 | } 50 | 51 | #[pymethods] 52 | impl PersiaMessageQueueServer { 53 | #[new] 54 | fn new(port: u16, cap: usize) -> Self { 55 | let runtime = persia_libs::tokio::runtime::Builder::new_multi_thread() 56 | .enable_all() 57 | .worker_threads(5) 58 | .build() 59 | .unwrap(); 60 | 61 | let _guard = runtime.enter(); 62 | 63 | 
Self { 64 | inner: PersiaMessageQueueServerImpl::new(port, cap), 65 | runtime, 66 | } 67 | } 68 | 69 | fn put(&self, data: Vec) { 70 | let _gurad = self.runtime.enter(); 71 | self.runtime.block_on(self.inner.send(data)) 72 | } 73 | 74 | fn get<'a>(&self, _py: Python<'a>) -> &'a PyBytes { 75 | let _gurad = self.runtime.enter(); 76 | let bytes = self.runtime.block_on(self.inner.recv()); 77 | PyBytes::new(_py, bytes.as_slice()) 78 | } 79 | } 80 | 81 | #[pyclass] 82 | pub struct PersiaBatchDataSender { 83 | pub inner: flume::Sender, 84 | } 85 | 86 | #[pymethods] 87 | impl PersiaBatchDataSender { 88 | pub fn send(&self, batch_data: &mut PersiaBatch, py: Python) -> PyResult<()> { 89 | let batch_data = std::mem::take(&mut batch_data.inner); 90 | py.allow_threads(move || { 91 | self.inner.send(batch_data).unwrap(); 92 | Ok(()) 93 | }) 94 | } 95 | } 96 | 97 | #[pyclass] 98 | pub struct PersiaBatchDataReceiver { 99 | pub inner: flume::Receiver, 100 | } 101 | #[pyclass] 102 | pub struct PersiaBatchDataChannel { 103 | pub sender: flume::Sender, 104 | pub receiver: flume::Receiver, 105 | } 106 | 107 | #[pymethods] 108 | impl PersiaBatchDataChannel { 109 | #[new] 110 | pub fn new(capacity: usize) -> Self { 111 | let (sender, receiver) = flume::bounded(capacity); 112 | Self { sender, receiver } 113 | } 114 | 115 | pub fn get_sender(&self) -> PersiaBatchDataSender { 116 | PersiaBatchDataSender { 117 | inner: self.sender.clone(), 118 | } 119 | } 120 | 121 | pub fn get_receiver(&self) -> PersiaBatchDataReceiver { 122 | PersiaBatchDataReceiver { 123 | inner: self.receiver.clone(), 124 | } 125 | } 126 | } 127 | 128 | pub fn init_module(super_module: &PyModule, py: Python) -> PyResult<()> { 129 | let module = PyModule::new(py, "utils")?; 130 | module.add_class::()?; 131 | module.add_class::()?; 132 | module.add_class::()?; 133 | module.add_class::()?; 134 | module.add_class::()?; 135 | super_module.add_submodule(module)?; 136 | Ok(()) 137 | } 138 | 
-------------------------------------------------------------------------------- /persia/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from persia.logger import get_default_logger 5 | 6 | _logger = get_default_logger() 7 | 8 | PERSIA_LAUNCHER_VERBOSE = bool(int(os.environ.get("PERSIA_LAUNCHER_VERBOSE", "0"))) 9 | 10 | # Skip all PERSIA data checks except batch size. 11 | # Raise RuntimeError when data does not meet requirement, such as 12 | # type, dtype or shape mismatch. 13 | PERSIA_SKIP_CHECK_DATA = bool(int(os.environ.get("PERSIA_SKIP_CHECK_DATA", "0"))) 14 | 15 | 16 | class _Env: 17 | def __init__(self): 18 | self.replica_size = None 19 | self.replica_index = None 20 | self.world_size = None 21 | self.rank = None 22 | self.local_rank = None 23 | self.is_init = False 24 | 25 | def init(self, force: bool = False): 26 | if self.is_init and not force: 27 | return 28 | 29 | if os.environ.get("RANK", None): 30 | self.rank = int(os.environ["RANK"]) 31 | self.local_rank = int(os.environ["LOCAL_RANK"]) 32 | self.world_size = int(os.environ["WORLD_SIZE"]) 33 | assert self.rank >= 0, "RANK cannot be negative." 34 | assert self.local_rank >= 0, "LOCAL_RANK cannot be negative." 35 | assert self.world_size >= 1, "WORLD_SIZE should be greater than one." 36 | else: 37 | if "REPLICA_INDEX" in os.environ: 38 | self.replica_index = int(os.environ["REPLICA_INDEX"]) 39 | assert self.replica_index >= 0, "REPLICA_IDNEX cannot be negative." 40 | 41 | replica_size = os.environ.get("REPLICA_SIZE", None) 42 | assert ( 43 | replica_size is not None 44 | ), "REPLICA_SIZE not found, setting environment variable REPLICA_SIZE before \ 45 | starting the PERSIA training task." 46 | self.replica_size = int(replica_size) 47 | assert ( 48 | self.replica_size >= 1 49 | ), "REPLICA_SIZE should be greater than one." 
# Process-wide cached environment; parsed lazily on first accessor call.
_env = _Env()


def reload_env():
    """Force a re-parse of the process environment, discarding cached values."""
    _env.init(force=True)


def set_env(**kwargs):
    """Overwrite the cached environment values directly, without validation."""
    _env.set(**kwargs)


def _ensure_parse_env(get_func):
    """Decorator: lazily parse the environment before delegating to the getter."""

    def wrapper():
        if not _env.is_init:
            _env.init()
        return get_func()

    return wrapper


@_ensure_parse_env
def get_world_size() -> int:
    """Total number of distributed training processes."""
    return _env.world_size


@_ensure_parse_env
def get_rank() -> int:
    """Global rank of the current process."""
    return _env.rank


@_ensure_parse_env
def get_local_rank() -> int:
    """Rank of the current process on its local machine."""
    return _env.local_rank


@_ensure_parse_env
def get_replica_size() -> int:
    """Number of service replicas launched by docker service or k8s."""
    return _env.replica_size
They are assigned following 130 | the order of launching. 131 | """ 132 | return _env.replica_index 133 | -------------------------------------------------------------------------------- /rust/persia-embedding-server/src/monitor.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use persia_libs::hyperloglogplus::HyperLogLog; 4 | use persia_libs::{hashbrown, hyperloglogplus, once_cell, parking_lot, tracing}; 5 | 6 | use persia_common::{utils::ChannelPair, SingleSignInFeatureBatch}; 7 | use persia_metrics::{GaugeVec, PersiaMetricsManager, PersiaMetricsManagerError}; 8 | 9 | const INDICES_CHANNEL_CAP: usize = 1000; 10 | 11 | static METRICS_HOLDER: once_cell::sync::OnceCell = once_cell::sync::OnceCell::new(); 12 | 13 | struct MetricsHolder { 14 | pub estimated_distinct_id: GaugeVec, 15 | } 16 | 17 | impl MetricsHolder { 18 | pub fn get() -> Result<&'static Self, PersiaMetricsManagerError> { 19 | METRICS_HOLDER.get_or_try_init(|| { 20 | let m = PersiaMetricsManager::get()?; 21 | let holder = Self { 22 | estimated_distinct_id: m.create_gauge_vec("estimated_distinct_id", "ATT")?, 23 | }; 24 | Ok(holder) 25 | }) 26 | } 27 | } 28 | 29 | pub struct EmbeddingMonitorInner { 30 | _feature_name: String, 31 | distinct_id_estimator: Arc< 32 | parking_lot::Mutex< 33 | hyperloglogplus::HyperLogLogPlus, 34 | >, 35 | >, 36 | indices_channel: ChannelPair>, 37 | _handlers: Vec>, 38 | } 39 | 40 | impl EmbeddingMonitorInner { 41 | pub fn new(feature_name: String) -> Self { 42 | let indices_channel = ChannelPair::new(INDICES_CHANNEL_CAP); 43 | let distinct_id_estimator = Arc::new(parking_lot::Mutex::new( 44 | hyperloglogplus::HyperLogLogPlus::new( 45 | 16, 46 | hashbrown::hash_map::DefaultHashBuilder::default(), 47 | ) 48 | .unwrap(), 49 | )); 50 | let mut handlers = Vec::new(); 51 | 52 | let recv_handler = { 53 | let reveiver = indices_channel.receiver.clone(); 54 | let feature_name = feature_name.clone(); 55 | let 
distinct_id_estimator = distinct_id_estimator.clone(); 56 | std::thread::spawn(move || { 57 | tracing::info!( 58 | "background thread for estimating {} distinct id start...", 59 | feature_name 60 | ); 61 | loop { 62 | let indices = reveiver.recv().unwrap_or(vec![]); 63 | let mut estimator = distinct_id_estimator.lock(); 64 | indices.iter().for_each(|id| { 65 | estimator.insert(id); 66 | }) 67 | } 68 | }) 69 | }; 70 | handlers.push(recv_handler); 71 | 72 | let commit_handler = { 73 | let distinct_id_estimator = distinct_id_estimator.clone(); 74 | let feature_name = feature_name.clone(); 75 | std::thread::spawn(move || loop { 76 | std::thread::sleep(std::time::Duration::from_secs(1)); 77 | if let Ok(m) = MetricsHolder::get() { 78 | let distinct_id = { distinct_id_estimator.lock().count().trunc() as u64 }; 79 | tracing::debug!("distinct_id for {} is {}", feature_name, distinct_id); 80 | m.estimated_distinct_id 81 | .with_label_values(&[feature_name.as_str()]) 82 | .set(distinct_id as f64); 83 | } 84 | }) 85 | }; 86 | handlers.push(commit_handler); 87 | 88 | EmbeddingMonitorInner { 89 | _feature_name: feature_name, 90 | distinct_id_estimator, 91 | indices_channel, 92 | _handlers: handlers, 93 | } 94 | } 95 | } 96 | 97 | impl EmbeddingMonitorInner { 98 | pub fn monitor_index_batch(&self, index_batch: &Vec) { 99 | let channel_size = self.indices_channel.sender.len(); 100 | if channel_size > INDICES_CHANNEL_CAP { 101 | tracing::warn!("too many batches when estimating distinct id, skiping..."); 102 | return; 103 | } 104 | let indices: Vec = index_batch.iter().map(|x| x.sign.clone()).collect(); 105 | let result = self.indices_channel.sender.try_send(indices); 106 | if result.is_err() { 107 | tracing::warn!("too many batches when estimating distinct id, skiping..."); 108 | } 109 | } 110 | 111 | pub fn estimate_distinct_id(&self) -> usize { 112 | self.distinct_id_estimator.lock().count().trunc() as usize 113 | } 114 | } 115 | 
-------------------------------------------------------------------------------- /rust/persia-embedding-server/src/bin/persia-embedding-parameter-server.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::needless_return)] 2 | 3 | #[macro_use] 4 | extern crate shadow_rs; 5 | 6 | use std::{path::PathBuf, sync::Arc}; 7 | 8 | use persia_libs::{anyhow::Result, color_eyre, hyper, tracing, tracing_subscriber}; 9 | use structopt::StructOpt; 10 | 11 | use persia_common::utils::start_deadlock_detection_thread; 12 | use persia_embedding_config::{ 13 | EmbeddingConfig, EmbeddingParameterServerConfig, PerisaJobType, PersiaCommonConfig, 14 | PersiaGlobalConfig, 15 | }; 16 | use persia_embedding_holder::PersiaEmbeddingHolder; 17 | use persia_embedding_server::embedding_parameter_service::{ 18 | EmbeddingParameterNatsService, EmbeddingParameterNatsServiceResponder, 19 | EmbeddingParameterService, EmbeddingParameterServiceInner, 20 | }; 21 | use persia_incremental_update_manager::PerisaIncrementalUpdateManager; 22 | use persia_model_manager::EmbeddingModelManager; 23 | 24 | #[derive(Debug, StructOpt, Clone)] 25 | #[structopt()] 26 | struct Cli { 27 | #[structopt(long)] 28 | port: u16, 29 | #[structopt(long)] 30 | replica_index: usize, 31 | #[structopt(long)] 32 | replica_size: usize, 33 | #[structopt(long, env = "PERSIA_GLOBAL_CONFIG")] 34 | global_config: PathBuf, 35 | #[structopt(long, env = "PERSIA_EMBEDDING_CONFIG")] 36 | embedding_config: PathBuf, 37 | } 38 | 39 | #[tokio::main] 40 | async fn main() -> Result<()> { 41 | color_eyre::install().unwrap(); 42 | tracing_subscriber::fmt() 43 | .with_env_filter(tracing_subscriber::EnvFilter::from_env("LOG_LEVEL")) 44 | .init(); 45 | 46 | shadow!(build); 47 | eprintln!("project_name: {}", build::PROJECT_NAME); 48 | eprintln!("is_debug: {}", shadow_rs::is_debug()); 49 | eprintln!("version: {}", build::version()); 50 | eprintln!("tag: {}", build::TAG); 51 | eprintln!("commit_hash: 
{}", build::COMMIT_HASH); 52 | eprintln!("commit_date: {}", build::COMMIT_DATE); 53 | eprintln!("build_os: {}", build::BUILD_OS); 54 | eprintln!("rust_version: {}", build::RUST_VERSION); 55 | eprintln!("build_time: {}", build::BUILD_TIME); 56 | let args: Cli = Cli::from_args(); 57 | 58 | start_deadlock_detection_thread(); 59 | 60 | PersiaGlobalConfig::set_configures( 61 | &args.global_config, 62 | args.port, 63 | args.replica_index, 64 | args.replica_size, 65 | )?; 66 | 67 | EmbeddingConfig::set(&args.embedding_config)?; 68 | 69 | let embedding_config = EmbeddingConfig::get()?; 70 | let common_config = PersiaCommonConfig::get()?; 71 | let server_config = EmbeddingParameterServerConfig::get()?; 72 | let embedding_holder = PersiaEmbeddingHolder::get()?; 73 | let inc_update_manager = PerisaIncrementalUpdateManager::get()?; 74 | let embedding_model_manager = EmbeddingModelManager::get()?; 75 | let (tx, rx) = tokio::sync::oneshot::channel::<()>(); 76 | 77 | let inner = Arc::new(EmbeddingParameterServiceInner::new( 78 | embedding_holder, 79 | server_config, 80 | common_config, 81 | embedding_config, 82 | inc_update_manager, 83 | embedding_model_manager, 84 | args.replica_index, 85 | )); 86 | 87 | let service = EmbeddingParameterService { 88 | inner: inner.clone(), 89 | shutdown_channel: Arc::new(persia_libs::async_lock::RwLock::new(Some(tx))), 90 | }; 91 | 92 | let server = hyper::Server::bind(&([0, 0, 0, 0], args.port).into()) 93 | .tcp_nodelay(true) 94 | .serve(hyper::service::make_service_fn(|_| { 95 | let service = service.clone(); 96 | async move { Ok::<_, hyper::Error>(service) } 97 | })); 98 | 99 | let job_type = &inner.get_job_type()?; 100 | let _responder = match job_type { 101 | PerisaJobType::Infer => None, 102 | _ => { 103 | let nats_service = EmbeddingParameterNatsService { 104 | inner: inner.clone(), 105 | }; 106 | let responder = EmbeddingParameterNatsServiceResponder::new(nats_service).await; 107 | Some(responder) 108 | } 109 | }; 110 | 111 | match 
job_type { 112 | PerisaJobType::Infer => { 113 | let common_config = PersiaCommonConfig::get()?; 114 | let embedding_cpk = common_config.infer_config.embedding_checkpoint.clone(); 115 | inner.load(embedding_cpk).await?; 116 | } 117 | _ => {} 118 | } 119 | let graceful = server.with_graceful_shutdown(async { 120 | rx.await.ok(); 121 | }); 122 | 123 | if let Err(err) = graceful.await { 124 | tracing::error!("embedding server exited with error: {:?}!", err); 125 | } else { 126 | tracing::info!("embedding server exited successfully"); 127 | } 128 | 129 | Ok(()) 130 | } 131 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG DEVICE=cuda 2 | ARG BASE_IMAGE=nvidia/cuda:11.2.0-devel-ubuntu20.04 3 | 4 | FROM ${BASE_IMAGE} AS base 5 | ARG PYTHON_VERSION=3.8 6 | ARG PYTORCH_VERSION=1.8 7 | ARG MAGMA_CUDA_VERSION=magma-cuda110 8 | ARG DEVICE 9 | 10 | RUN apt-get -y update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends curl \ 11 | build-essential \ 12 | ca-certificates \ 13 | git \ 14 | libgfortran-8-dev \ 15 | vim \ 16 | zsh \ 17 | wget \ 18 | ssh \ 19 | iputils-ping \ 20 | procps \ 21 | net-tools \ 22 | apt-utils \ 23 | rlwrap \ 24 | ethtool \ 25 | telnet \ 26 | openjdk-11-jdk \ 27 | openssh-server 28 | 29 | RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 30 | chmod +x ~/miniconda.sh && \ 31 | ~/miniconda.sh -b -p /opt/conda && \ 32 | rm ~/miniconda.sh && \ 33 | /opt/conda/bin/conda install -y python=${PYTHON_VERSION} numpy scipy mkl mkl-include ninja cython typing && \ 34 | /opt/conda/bin/conda install -y -c conda-forge mpi4py && \ 35 | ln -s /usr/share/pyshared /opt/conda/lib/python${PYTHON_VERSION}/site-packages && \ 36 | if [ "${DEVICE}" = "cuda" ]; then \ 37 | /opt/conda/bin/conda install -y -c pytorch -c conda-forge ${MAGMA_CUDA_VERSION} 
pytorch=${PYTORCH_VERSION} torchvision; \ 38 | /opt/conda/bin/pip3 install bagua-cuda113 --no-cache-dir; \ 39 | else \ 40 | /opt/conda/bin/conda install -y -c pytorch -c conda-forge pytorch=${PYTORCH_VERSION} torchvision cpuonly; \ 41 | /opt/conda/bin/pip3 install scikit-learn --no-cache-dir; \ 42 | fi && \ 43 | /opt/conda/bin/conda install torchserve torch-model-archiver torch-workflow-archiver -c pytorch -y; \ 44 | /opt/conda/bin/conda clean -yapf; 45 | 46 | RUN mkdir -p /opt/hadoop/; \ 47 | cd /opt/hadoop/; \ 48 | wget https://dlcdn.apache.org/hadoop/common/hadoop-3.3.1/hadoop-3.3.1.tar.gz; \ 49 | tar -zxvf hadoop-3.3.1.tar.gz; \ 50 | rm hadoop-3.3.1.tar.gz; 51 | 52 | RUN /opt/conda/bin/pip install --no-cache-dir \ 53 | remote-pdb \ 54 | pytest \ 55 | tqdm \ 56 | pandas \ 57 | tensorboard \ 58 | ipython \ 59 | captum \ 60 | grpcio \ 61 | protobuf \ 62 | grpcio-tools && \ 63 | apt-get purge --auto-remove && \ 64 | apt-get clean 65 | 66 | ENV PATH=/opt/conda/bin:/opt/hadoop/hadoop-3.3.1/bin/:$PATH 67 | ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64/ 68 | ENV LIBRARY_PATH="/usr/local/lib64:/usr/local/lib:/usr/lib" 69 | ENV LD_LIBRARY_PATH="/opt/conda/lib/python${PYTHON_VERSION}/site-packages/torch/lib/:/opt/conda/lib/" 70 | 71 | # alias for cpu builder image 72 | FROM base AS cpu-builder-base 73 | # alias for gpu builder image 74 | FROM base AS cuda-builder-base 75 | ARG DEVICE 76 | 77 | ENV USE_CUDA=1 78 | ENV LIBRARY_PATH="${LIBRARY_PATH}:/usr/local/cuda/lib64/stubs/" 79 | 80 | FROM ${DEVICE}-builder-base AS builder 81 | 82 | ENV RUSTUP_HOME=/rust 83 | ENV CARGO_HOME=/cargo 84 | ENV PATH=/cargo/bin:/rust/bin:/opt/conda/bin:$PATH 85 | 86 | RUN curl -sSf https://sh.rustup.rs | sh -s -- --default-toolchain stable -y --profile default --no-modify-path 87 | 88 | FROM builder AS persia-builder 89 | 90 | WORKDIR /workspace 91 | COPY . 
/workspace 92 | RUN cd /workspace && pip3 install colorama setuptools setuptools-rust setuptools_scm \ 93 | && python setup.py bdist_wheel --dist-dir=/root/dist && rm -rf /workspace 94 | 95 | # Build bagua distributed training framework manually 96 | # RUN if [ "${DEVICE}" = "cuda" ]; then \ 97 | # rm -rf /etc/apt/sources.list.d; \ 98 | # apt-get update; \ 99 | # DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends zlib1g-dev libhwloc-dev; \ 100 | # git clone https://github.com/BaguaSys/bagua.git; \ 101 | # cd bagua; \ 102 | # pip3 install cmake setuptools-rust colorama tqdm wheel --no-cache-dir; \ 103 | # git submodule update --init --recursive; \ 104 | # python setup.py bdist_wheel --dist-dir=/root/dist; \ 105 | # cd ..; \ 106 | # rm -rf bagua; \ 107 | # /opt/conda/bin/conda clean -yapf; \ 108 | # fi 109 | 110 | ARG DEVICE 111 | FROM base AS runtime 112 | 113 | # Install the persia-runtime and bagua (Optional for cpu-runtime) 114 | COPY --from=persia-builder /root/dist . 
115 | RUN pip3 install *.whl && rm -rf *.whl 116 | 117 | # Install nats server 118 | RUN wget https://github.com/nats-io/nats-server/releases/download/v2.6.6/nats-server-v2.6.6-linux-amd64.tar.gz && \ 119 | tar -zxvf nats-server-v2.6.6-linux-amd64.tar.gz && \ 120 | cp nats-server-v2.6.6-linux-amd64/nats-server /usr/bin/ &&\ 121 | rm -rf nats-server-v2.6.6-linux-amd64/ && \ 122 | rm nats-server-v2.6.6-linux-amd64.tar.gz 123 | 124 | # Prepare examples 125 | RUN mkdir -p /home/PERSIA/examples 126 | COPY examples /home/PERSIA/examples 127 | RUN cd /home/PERSIA/examples/src/adult-income/data/ && ./prepare_data.sh -------------------------------------------------------------------------------- /rust/persia-embedding-holder/src/eviction_map.rs: -------------------------------------------------------------------------------- 1 | use persia_libs::hashbrown::HashMap; 2 | use std::convert::TryFrom; 3 | use std::hash::Hash; 4 | 5 | use crate::array_linked_list::ArrayLinkedList; 6 | 7 | pub trait EvictionMapValue { 8 | fn hashmap_key(&self) -> K; 9 | } 10 | 11 | pub struct EvictionMap 12 | where 13 | K: Hash + Eq + Clone, 14 | V: EvictionMapValue, 15 | { 16 | pub hashmap: HashMap, 17 | pub linkedlist: ArrayLinkedList, 18 | pub capacity: usize, 19 | } 20 | 21 | impl EvictionMap 22 | where 23 | K: Hash + Eq + Clone, 24 | V: EvictionMapValue, 25 | { 26 | pub fn with_capacity(capacity: usize) -> Self { 27 | Self { 28 | hashmap: HashMap::with_capacity(capacity + 1), 29 | linkedlist: ArrayLinkedList::with_capacity(capacity as u32 + 1), 30 | capacity, 31 | } 32 | } 33 | 34 | pub fn get(&self, key: &K) -> Option<&V> { 35 | match self.hashmap.get(&key) { 36 | Some(idx) => self.linkedlist[*idx as usize].as_ref(), 37 | None => None, 38 | } 39 | } 40 | 41 | pub fn get_mut(&mut self, key: &K) -> Option<&mut V> { 42 | match self.hashmap.get(&key) { 43 | Some(idx) => self.linkedlist[*idx as usize].as_mut(), 44 | None => None, 45 | } 46 | } 47 | 48 | pub fn get_refresh(&mut self, key: &K) -> 
Option<&V> { 49 | match self.hashmap.get(&key) { 50 | Some(idx) => { 51 | let idx = u32::try_from(*idx).expect("u32 array linked list overflow"); 52 | let v = self.linkedlist.remove(idx).unwrap(); 53 | let new_idx = self.linkedlist.push_back(v); 54 | let idx_ref = self.hashmap.get_mut(key).unwrap(); 55 | *idx_ref = new_idx; 56 | self.linkedlist[new_idx as usize].as_ref() 57 | } 58 | None => None, 59 | } 60 | } 61 | 62 | pub fn get_refresh_mut(&mut self, key: &K) -> Option<&mut V> { 63 | match self.hashmap.get(&key) { 64 | Some(idx) => { 65 | let idx = u32::try_from(*idx).expect("u32 array linked list overflow"); 66 | let v = self.linkedlist.remove(idx).unwrap(); 67 | let new_idx = self.linkedlist.push_back(v); 68 | let idx_ref = self.hashmap.get_mut(key).unwrap(); 69 | *idx_ref = new_idx; 70 | self.linkedlist[new_idx as usize].as_mut() 71 | } 72 | None => None, 73 | } 74 | } 75 | 76 | pub fn insert(&mut self, key: K, value: V) -> (Option, Option) { 77 | let old = match self.hashmap.get(&key) { 78 | Some(idx) => self.linkedlist.remove(*idx), 79 | None => None, 80 | }; 81 | 82 | let new_idx = self.linkedlist.push_back(value); 83 | self.hashmap.insert(key, new_idx); 84 | 85 | let evicted = if self.linkedlist.len() as usize > self.capacity { 86 | let evicted = self.linkedlist.pop_front(); 87 | if let Some(evicted_v) = &evicted { 88 | let evicted_k = evicted_v.hashmap_key(); 89 | self.hashmap.remove(&evicted_k); 90 | } 91 | evicted 92 | } else { 93 | None 94 | }; 95 | 96 | (old, evicted) 97 | } 98 | 99 | pub fn clear(&mut self) { 100 | self.hashmap.clear(); 101 | self.linkedlist.clear(); 102 | } 103 | 104 | pub fn capacity(&self) -> usize { 105 | self.capacity 106 | } 107 | 108 | pub fn len(&self) -> usize { 109 | self.linkedlist.len() as usize 110 | } 111 | } 112 | 113 | #[cfg(test)] 114 | mod eviction_map_tests { 115 | // Note this useful idiom: importing names from outer (for mod tests) scope. 
116 | use super::*; 117 | use crate::emb_entry::HashMapEmbeddingEntry; 118 | use persia_embedding_config::InitializationMethod; 119 | 120 | #[test] 121 | fn test_evict() { 122 | let mut map: EvictionMap = EvictionMap::with_capacity(5); 123 | 124 | let initialization = InitializationMethod::default(); 125 | 126 | for i in 0..5 { 127 | let entry = HashMapEmbeddingEntry::new(&initialization, 8, 16, i, i); 128 | map.insert(i, entry); 129 | } 130 | 131 | assert_eq!(map.len(), 5); 132 | 133 | for i in 5..10 { 134 | let entry = HashMapEmbeddingEntry::new(&initialization, 8, 16, i, i); 135 | map.insert(i, entry); 136 | } 137 | 138 | assert_eq!(map.len(), 5); 139 | assert_eq!(map.get_refresh(&4).is_none(), true); 140 | assert_eq!(map.get_refresh(&5).is_some(), true); 141 | 142 | let entry = HashMapEmbeddingEntry::new(&initialization, 8, 16, 10, 10); 143 | map.insert(10, entry); 144 | 145 | assert_eq!(map.len(), 5); 146 | assert_eq!(map.get_refresh(&6).is_none(), true); 147 | assert_eq!(map.get_refresh(&5).is_some(), true); 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "PERSIA API Documentation" 21 | copyright = "2021, Kuaishou AI Platform" 22 | author = "Kuaishou AI Platform" 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | "autoapi.extension", 32 | "sphinx.ext.autodoc", 33 | "sphinx.ext.intersphinx", 34 | "sphinx.ext.todo", 35 | "sphinx.ext.viewcode", 36 | "sphinx.ext.napoleon", 37 | "sphinx_multiversion", 38 | ] 39 | 40 | 41 | napoleon_numpy_docstring = True 42 | autoapi_python_class_content = "both" 43 | autodoc_typehints = "description" 44 | autoapi_type = "python" 45 | autoapi_dirs = ["../persia"] 46 | autoapi_root = "autoapi" 47 | autoapi_template_dir = "_autoapi_templates" 48 | autoapi_ignore = [] 49 | autoapi_options = [ 50 | "members", 51 | "undoc-members", 52 | "show-inheritance", 53 | "imported-members", 54 | ] 55 | autoapi_member_order = "groupwise" 56 | 57 | master_doc = "autoapi/index" 58 | 59 | # Add any paths that contain templates here, relative to this directory. 60 | templates_path = ["_templates"] 61 | 62 | # List of patterns, relative to source directory, that match files and 63 | # directories to ignore when looking for source files. 64 | # This pattern also affects html_static_path and html_extra_path. 65 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 66 | 67 | 68 | # -- Options for HTML output ------------------------------------------------- 69 | 70 | # Add any paths that contain custom static files (such as style sheets) here, 71 | # relative to this directory. 
They are copied after the builtin static files, 72 | # so a file named "default.css" will overwrite the builtin "default.css". 73 | html_static_path = ["_static"] 74 | 75 | # If true, '()' will be appended to :func: etc. cross-reference text. 76 | add_function_parentheses = False 77 | 78 | # If true, the current module name will be prepended to all description 79 | # unit titles (such as .. function::). 80 | add_module_names = True 81 | 82 | # If true, `todo` and `todoList` produce output, else they produce nothing. 83 | todo_include_todos = True 84 | 85 | # The theme to use for HTML and HTML Help pages. See the documentation for 86 | # a list of builtin themes. 87 | html_theme = "sphinx_rtd_theme" 88 | 89 | # Theme options are theme-specific and customize the look and feel of a theme 90 | # further. For a list of options available for each theme, see the 91 | # documentation. 92 | html_theme_options = { 93 | "show_powered_by": False, 94 | "github_user": "PersiaML", 95 | "github_repo": "PERSIA", 96 | "github_banner": True, 97 | "show_related": False, 98 | "note_bg": "#FFF59C", 99 | } 100 | 101 | # If true, SmartyPants will be used to convert quotes and dashes to 102 | # typographically correct entities. 103 | html_use_smartypants = False 104 | 105 | # If true, links to the reST sources are added to the pages. 106 | html_show_sourcelink = False 107 | 108 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 109 | html_show_sphinx = False 110 | 111 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 
112 | html_show_copyright = True 113 | 114 | 115 | _ignore_methods = [] 116 | 117 | _ignore_functions = [ 118 | "persia.launcher.cli", 119 | "persia.launcher.nn_worker", 120 | "persia.launcher.data_loader", 121 | "persia.launcher.embedding_worker", 122 | "persia.launcher.embedding_parameter_server", 123 | ] 124 | 125 | _ignore_classes = [] 126 | 127 | _ignore_module = [ 128 | "persia.version", 129 | "persia.logger", 130 | "persia.prelude", 131 | "persia.k8s_utils", 132 | ] 133 | 134 | 135 | def skip_methods(app, what, name, obj, skip, options): 136 | if what == "method" and name in _ignore_methods: 137 | skip = True 138 | return skip 139 | 140 | if what == "function" and name in _ignore_functions: 141 | skip = True 142 | return skip 143 | 144 | if what == "class" and name in _ignore_classes: 145 | skip = True 146 | return skip 147 | 148 | if what == "module" and name in _ignore_module: 149 | skip = True 150 | return skip 151 | 152 | return skip 153 | 154 | 155 | def setup(sphinx): 156 | sphinx.connect("autoapi-skip-member", skip_methods) 157 | -------------------------------------------------------------------------------- /rust/persia-embedding-server/src/bin/persia-embedding-worker.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | extern crate shadow_rs; 3 | 4 | use std::path::PathBuf; 5 | use std::sync::Arc; 6 | 7 | use persia_libs::{ 8 | anyhow::Result, color_eyre, hashbrown::HashMap, hyper, rand, tracing, tracing_subscriber, 9 | }; 10 | 11 | use structopt::StructOpt; 12 | 13 | use persia_common::utils::start_deadlock_detection_thread; 14 | use persia_embedding_config::{ 15 | EmbeddingConfig, EmbeddingWorkerConfig, PerisaJobType, PersiaCommonConfig, PersiaGlobalConfig, 16 | }; 17 | use persia_embedding_server::embedding_parameter_service::EmbeddingParameterNatsServicePublisher; 18 | use persia_embedding_server::embedding_worker_service::{ 19 | AllEmbeddingServerClient, EmbeddingWorker, 
EmbeddingWorkerInner, EmbeddingWorkerNatsService, 20 | EmbeddingWorkerNatsServiceResponder, 21 | }; 22 | use persia_model_manager::EmbeddingModelManager; 23 | 24 | #[derive(Debug, StructOpt, Clone)] 25 | #[structopt()] 26 | struct Cli { 27 | #[structopt(long)] 28 | port: u16, 29 | #[structopt(long)] 30 | replica_index: usize, 31 | #[structopt(long)] 32 | replica_size: usize, 33 | #[structopt(long, env = "PERSIA_GLOBAL_CONFIG")] 34 | global_config: PathBuf, 35 | #[structopt(long, env = "PERSIA_EMBEDDING_CONFIG")] 36 | embedding_config: PathBuf, 37 | } 38 | 39 | #[tokio::main] 40 | async fn main() -> Result<()> { 41 | color_eyre::install().unwrap(); 42 | tracing_subscriber::fmt() 43 | .with_env_filter(tracing_subscriber::EnvFilter::from_env("LOG_LEVEL")) 44 | .init(); 45 | shadow!(build); 46 | eprintln!("project_name: {}", build::PROJECT_NAME); 47 | eprintln!("is_debug: {}", shadow_rs::is_debug()); 48 | eprintln!("version: {}", build::version()); 49 | eprintln!("tag: {}", build::TAG); 50 | eprintln!("commit_hash: {}", build::COMMIT_HASH); 51 | eprintln!("commit_date: {}", build::COMMIT_DATE); 52 | eprintln!("build_os: {}", build::BUILD_OS); 53 | eprintln!("rust_version: {}", build::RUST_VERSION); 54 | eprintln!("build_time: {}", build::BUILD_TIME); 55 | let args: Cli = Cli::from_args(); 56 | 57 | start_deadlock_detection_thread(); 58 | 59 | PersiaGlobalConfig::set_configures( 60 | &args.global_config, 61 | args.port, 62 | args.replica_index, 63 | args.replica_size, 64 | )?; 65 | 66 | EmbeddingConfig::set(&args.embedding_config)?; 67 | 68 | let common_config = PersiaCommonConfig::get()?; 69 | let all_embedding_server_client = match &common_config.job_type { 70 | PerisaJobType::Infer => { 71 | let servers = common_config.infer_config.servers.clone(); 72 | AllEmbeddingServerClient::with_addrs(servers).await 73 | } 74 | _ => { 75 | let nats_publisher = EmbeddingParameterNatsServicePublisher::new().await; 76 | AllEmbeddingServerClient::with_nats(nats_publisher).await 77 | 
} 78 | }; 79 | 80 | let replica_size = all_embedding_server_client.replica_size() as u64; 81 | let embedding_worker_config = EmbeddingWorkerConfig::get()?; 82 | let embedding_config = EmbeddingConfig::get()?; 83 | let embedding_model_manager = EmbeddingModelManager::get()?; 84 | 85 | let inner = Arc::new(EmbeddingWorkerInner { 86 | all_embedding_server_client, 87 | replica_size, 88 | forward_id: std::sync::atomic::AtomicU64::new(rand::random()), 89 | forward_id_buffer: persia_libs::async_lock::RwLock::new(HashMap::with_capacity(10000)), 90 | post_forward_buffer: persia_libs::async_lock::RwLock::new(HashMap::with_capacity(10000)), 91 | cannot_forward_batched_time: crossbeam::atomic::AtomicCell::new( 92 | std::time::SystemTime::now(), 93 | ), 94 | embedding_config, 95 | staleness: Default::default(), 96 | embedding_worker_config, 97 | embedding_model_manager, 98 | }); 99 | 100 | let _responder = match &common_config.job_type { 101 | PerisaJobType::Infer => None, 102 | _ => { 103 | let nats_service = EmbeddingWorkerNatsService { 104 | inner: inner.clone(), 105 | }; 106 | let responder = EmbeddingWorkerNatsServiceResponder::new(nats_service).await; 107 | Some(responder) 108 | } 109 | }; 110 | 111 | let (tx, rx) = tokio::sync::oneshot::channel::<()>(); 112 | let service = EmbeddingWorker { 113 | inner: inner, 114 | shutdown_channel: Arc::new(persia_libs::async_lock::RwLock::new(Some(tx))), 115 | }; 116 | 117 | let server = hyper::server::Server::bind(&([0, 0, 0, 0], args.port).into()) 118 | .tcp_nodelay(true) 119 | .serve(hyper::service::make_service_fn(|_| { 120 | let service = service.clone(); 121 | async { Ok::<_, hyper::Error>(service) } 122 | })); 123 | 124 | tracing::info!("embedding worker rpc server started"); 125 | 126 | let graceful = server.with_graceful_shutdown(async { 127 | rx.await.ok(); 128 | }); 129 | 130 | if let Err(err) = graceful.await { 131 | tracing::error!("embedding worker exited with error: {:?}!", err); 132 | } else { 133 | 
tracing::info!("embedding worker exited successfully"); 134 | } 135 | 136 | Ok(()) 137 | } 138 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 |
6 | 7 |

8 | tutorials 9 | Documentation Status 10 | PyPI version 11 | PyPI downloads 12 | Docker Pulls 13 | license 14 |

15 | 16 |
17 | 18 |
19 | 20 | 21 | *WARNING: THIS PROJECT IS CURRENTLY NOT MAINTAINED, DUE TO COMPANY REORGANIZATION.* 22 | 23 | **PERSIA** (**P**arallel r**E**commendation t**R**aining **S**ystem with hybr**I**d **A**cceleration) is developed by [AI platform@Kuaishou Technology](https://www.kuaishou.com/en), collaborating with ETH. It is a PyTorch-based (the first public one to our best knowledge) system for training large scale deep learning recommendation models on commodity hardwares. It is capable of training recommendation models with up to 100 trillion parameters. To the best of our knowledge, this is the largest model size in recommendation systems so far. Empirical study on public datasets indicate PERSIA's significant advantage over several other existing training systems in recommendation [1]. Its efficiency and robustness have also been validated by multiple applications with 100 million level DAU at Kuaishou. 24 | 25 | *Disclaimer: The program is usable and has served several important businesses. However, the official English documentation and tutorials are still under heavy construction and they are a bit raw now. We encourage adventurers to try out PERSIA and contribute!* 26 | 27 | ## News 28 | 29 | * [Training Deep Learning-based recommender models of 100 trillion parameters over Google Cloud](https://archive.ph/8ay0C) 30 | * [突破百万亿参数规模,追求极致的效率和性价比:华人团队开源首个异构并行推荐系统训练框架 PERSIA](https://archive.ph/Mixk0) (In Chinese. Title: Breaking through the trillion parameter scale in pursuit of ultimate efficiency and cost effectiveness: Chinese team open source PERSIA, the first heterogeneous parallel recommendation system) 31 | * [参数量卷到一百万亿!华人团队开源史上最大的推荐训练系统 PERSIA](https://archive.md/FbocB) (In Chinese. Title: PERSIA, the Largest Recommended Training System in the History of Open Source by Far) 32 | * AI Engines in the "Short-video" Era: Eating 100 Trillion Parameters, Invited talk, Facebook, 2021. 33 | * 单机训练速度提升 640 倍!独家解读快手商业广告模型 GPU 训练平台 PERSIA (In Chinese. 
Title: 640x Faster GPU Based Learning System for Ad Recommendation) 34 | * [[AI Front]](https://archive.is/2ii2L) [[中国日报]](https://archive.is/N8fK2) [[InfoQ]](https://archive.is/JESDU) [[CSDN]](https://archive.is/tpvkN) [[Tencent Cloud News]](https://archive.is/kLuaT) [[AcFun]](https://archive.md/vuPmb) 35 | * 创新、平衡与大格局:快手商业化的慢与快 (In Chinese. Title: Innovation, Balance, and Big Picture: The Speed of Kwai Commercialization) 36 | * [[TechSir]](https://archive.is/EOQ18) [[China Daily]](https://archive.is/L2VJE) [[Sohu]](https://archive.is/aY66U) 37 | 38 | ## Links 39 | 40 | * [GitHub Repository](https://github.com/PersiaML/PERSIA) 41 | * [Tutorials](https://persiaml-tutorials.pages.dev/) 42 | * [API documentation](https://persiaml.pages.dev/) (Under Construction) 43 | 44 | ## References 45 | 46 | 1. Xiangru Lian, Binhang Yuan, Xuefeng Zhu, Yulong Wang, Yongjun He, Honghuan Wu, Lei Sun, Haodong Lyu, Chengjun Liu, Xing Dong, Yiqiao Liao, Mingnan Luo, Congfei Zhang, Jingru Xie, Haonan Li, Lei Chen, Renjie Huang, Jianying Lin, Chengchun Shu, Xuezhong Qiu, Zhishan Liu, Dongying Kong, Lei Yuan, Hai Yu, Sen Yang, Ce Zhang, & Ji Liu. (2021). [Persia: A Hybrid System Scaling Deep Learning Based Recommenders up to 100 Trillion Parameters.](https://arxiv.org/abs/2111.05897) 47 | 48 | 2. Ji Liu & Ce Zhang. (2021). [Distributed Learning Systems with First-order Methods](https://arxiv.org/pdf/2104.05245). 49 | 50 | ## License 51 | 52 | This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. 
53 | -------------------------------------------------------------------------------- /rust/others/persia-rpc/src/lib.rs: -------------------------------------------------------------------------------- 1 | use std::ops::Add; 2 | 3 | use persia_libs::{bytes::Buf, hyper, lz4, tokio, url}; 4 | use persia_speedy::{Readable, Writable}; 5 | use snafu::{ensure, Backtrace, ResultExt, Snafu}; 6 | 7 | #[derive(Snafu, Debug)] 8 | #[snafu(visibility = "pub")] 9 | pub enum PersiaRpcError { 10 | #[snafu(display("serialization error"))] 11 | SerializationFailure { 12 | source: persia_speedy::Error, 13 | backtrace: Option, 14 | }, 15 | #[snafu(display("io error"))] 16 | IOFailure { 17 | source: std::io::Error, 18 | backtrace: Option, 19 | }, 20 | #[snafu(display("server addr parse error from {}: {}", server_addr, source))] 21 | ServerAddrParseFailure { 22 | server_addr: String, 23 | source: url::ParseError, 24 | backtrace: Option, 25 | }, 26 | #[snafu(display("transport error {}: {}", msg, source))] 27 | TransportError { 28 | msg: String, 29 | source: hyper::Error, 30 | backtrace: Option, 31 | }, 32 | #[snafu(display("transport server side error {}", msg))] 33 | TransportServerSideError { 34 | msg: String, 35 | backtrace: Option, 36 | }, 37 | } 38 | 39 | pub struct RpcClient { 40 | client: hyper::Client, 41 | server_addr: url::Url, 42 | } 43 | 44 | fn expect_uri(url: url::Url) -> hyper::Uri { 45 | url.as_str() 46 | .parse() 47 | .expect("a parsed Url should always be a valid Uri") 48 | } 49 | 50 | impl RpcClient { 51 | /// server_addr format should be host:port 52 | pub fn new(server_addr: &str) -> Result { 53 | let server_addr = url::Url::parse("http://".to_string().add(server_addr).as_str()) 54 | .context(ServerAddrParseFailure { 55 | server_addr: server_addr.to_string(), 56 | })?; 57 | Ok(Self { 58 | client: hyper::Client::builder() 59 | // .http2_only(true) 60 | // .retry_canceled_requests(true) 61 | // .set_host(false) 62 | // .http2_adaptive_window(true) 63 | .build_http(), 
64 | server_addr, 65 | }) 66 | } 67 | 68 | pub async fn call_async<'a, T, R>( 69 | &self, 70 | endpoint_name: &str, 71 | input: &T, 72 | compress: bool, 73 | ) -> Result 74 | where 75 | R: Readable<'a, persia_speedy::LittleEndian> + Send + 'static, 76 | T: Writable + Send + 'static, 77 | { 78 | let server_addr = self 79 | .server_addr 80 | .join(endpoint_name) 81 | .context(ServerAddrParseFailure { 82 | server_addr: endpoint_name.to_string(), 83 | })?; 84 | 85 | let data = tokio::task::block_in_place(|| input.write_to_vec()) 86 | .context(SerializationFailure {})?; 87 | 88 | let data = if compress && (data.len() > 0) { 89 | tokio::task::block_in_place(|| { 90 | lz4::block::compress( 91 | data.as_slice(), 92 | Some(lz4::block::CompressionMode::FAST(3)), 93 | true, 94 | ) 95 | }) 96 | .context(IOFailure {})? 97 | } else { 98 | data 99 | }; 100 | 101 | let req = hyper::Request::builder() 102 | .method("POST") 103 | .uri(expect_uri(server_addr)) 104 | .body(hyper::Body::from(data)) 105 | .expect("request builder"); 106 | 107 | let response = self.client.request(req).await.context(TransportError { 108 | msg: format!("call {} error", endpoint_name), 109 | })?; 110 | ensure!( 111 | response.status() == hyper::http::StatusCode::OK, 112 | TransportServerSideError { 113 | msg: format!( 114 | "call {} server side error: {:?}", 115 | endpoint_name, 116 | response.into_body() 117 | ), 118 | } 119 | ); 120 | 121 | let mut resp_bytes = 122 | hyper::body::aggregate(response.into_body()) 123 | .await 124 | .context(TransportError { 125 | msg: format!("call {} recv bytes error", endpoint_name), 126 | })?; 127 | 128 | if compress && resp_bytes.remaining() >= 4 { 129 | let resp_bytes = tokio::task::block_in_place(|| { 130 | let mut buffer = vec![0; resp_bytes.remaining()]; 131 | resp_bytes.copy_to_slice(buffer.as_mut()); 132 | lz4::block::decompress(buffer.as_slice(), None) 133 | }) 134 | .context(IOFailure {})?; 135 | let resp: R = 136 | tokio::task::block_in_place(|| 
R::read_from_buffer_owned(resp_bytes.as_slice())) // TODO: this can be zero copy if we use read_from_buffer and correctly deal with lifetime 137 | .context(SerializationFailure {})?; 138 | return Ok(resp); 139 | } else { 140 | let resp: R = 141 | tokio::task::block_in_place(|| R::read_from_stream_unbuffered(resp_bytes.reader())) // TODO: this can be zero copy if we use read_from_buffer and correctly deal with lifetime 142 | .context(SerializationFailure {})?; 143 | return Ok(resp); 144 | } 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /test/test_ctx.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from typing import List 4 | 5 | import numpy as np 6 | import torch 7 | import pytest 8 | 9 | from persia.helper import ensure_persia_service 10 | from persia.ctx import BaseCtx, DataCtx 11 | 12 | from .utils import random_port 13 | 14 | EMBEDDING_CONFIG = {"slots_config": {"age": {"dim": 8}}} 15 | RAW_EMBEDDING_CONFIG = { 16 | "slots_config": { 17 | "user_id": {"dim": 8}, 18 | "user_id_follower_list": {"dim": 8, "embedding_summation": False}, 19 | } 20 | } 21 | 22 | GLOBAL_CONFIG = { 23 | "embedding_worker_config": {"forward_buffer_size": 1000}, 24 | "common_config": {"metrics_config": {"enable_metrics": False}}, 25 | } 26 | 27 | 28 | def assert_ndarray_base_data( 29 | ndarray_base_data_list: List[np.ndarray], 30 | tensors: List[torch.Tensor], 31 | use_cuda: bool, 32 | ): 33 | assert len(ndarray_base_data_list) == len(tensors) 34 | for ndarray_base_data, tensor in zip(ndarray_base_data_list, tensors): 35 | if use_cuda: 36 | tensor = tensor.cpu() 37 | 38 | np.testing.assert_equal(ndarray_base_data, tensor.numpy()) 39 | 40 | 41 | def assert_id_type_feature_data(tensors: List[torch.Tensor], config: dict): 42 | embedding_configs = config["slots_config"] 43 | for tensor, embedding_config in zip(tensors, embedding_configs): 44 | expected_dim = ( 45 | 
embedding_config["dim"] 46 | if embedding_config.get("embedding_summation", True) 47 | else embedding_config["dim"] + 1 48 | ) 49 | expected_ndim = 2 if embedding_config["embedding_summation"] else 3 50 | assert len(tensor.shape) == expected_ndim 51 | assert tensor.shape[-1] == expected_dim 52 | 53 | 54 | # FIXME: Try no-singleton PersiaCommonContext. 55 | # Every time init the PersiaCommonContext, it will reuse the instance created 56 | # before. Any environment update makes no effects on singleton instance PersiaCommonContext, 57 | # such as PERSIA_NATS_URL. 58 | 59 | if torch.cuda.is_available(): 60 | parameter_list = [True] 61 | ids = ["cuda"] 62 | else: 63 | parameter_list = [False] 64 | ids = ["cpu"] 65 | 66 | 67 | @pytest.mark.parametrize("use_cuda", parameter_list, ids=ids) 68 | def test_data_ctx(use_cuda: bool): 69 | non_id_type_features = [np.array([1], dtype=np.float32)] 70 | labels = [ 71 | np.array( 72 | [ 73 | 1, 74 | ], 75 | dtype=np.float32, 76 | ) 77 | ] 78 | 79 | def data_loader(): 80 | from persia.embedding.data import ( 81 | PersiaBatch, 82 | IDTypeFeature, 83 | NonIDTypeFeature, 84 | Label, 85 | ) 86 | 87 | persia_batch = PersiaBatch( 88 | [ 89 | IDTypeFeature( 90 | "age", 91 | [ 92 | np.array( 93 | [ 94 | 1, 95 | 2, 96 | 3, 97 | ], 98 | dtype=np.uint64, 99 | ) 100 | ], 101 | ) 102 | ], 103 | non_id_type_features=[ 104 | NonIDTypeFeature(non_id_type_feature) 105 | for non_id_type_feature in non_id_type_features 106 | ], 107 | labels=[Label(label) for label in labels], 108 | requires_grad=False, 109 | ) 110 | 111 | with DataCtx() as data_ctx: 112 | data_ctx.send_data(persia_batch) 113 | 114 | os.environ["WORLD_SIZE"] = str(1) 115 | os.environ["RANK"] = str(0) 116 | os.environ["LOCAL_RANK"] = str(0) 117 | 118 | from persia.ctx import PreprocessMode, _prepare_feature 119 | from persia.data import DataLoader, StreamingDataset 120 | from persia.embedding import get_default_embedding_config 121 | from persia.env import get_world_size 122 | 123 | 
device_id = 0 if use_cuda else None 124 | 125 | with ensure_persia_service( 126 | data_loader_func=data_loader, 127 | embedding_config=EMBEDDING_CONFIG, 128 | global_config=GLOBAL_CONFIG, 129 | embedding_worker_port=random_port(), 130 | embedding_parameter_server_port=random_port(), 131 | nats_server_port=random_port(), 132 | ): 133 | embedding_config = get_default_embedding_config() 134 | 135 | with BaseCtx(device_id=device_id) as ctx: 136 | ctx.common_context.init_nats_publisher(get_world_size()) 137 | ctx.common_context.configure_embedding_parameter_servers( 138 | embedding_config.emb_initialization[0], 139 | embedding_config.emb_initialization[1], 140 | embedding_config.admit_probability, 141 | embedding_config.weight_bound > 0, 142 | embedding_config.weight_bound, 143 | ) 144 | ctx.common_context.wait_servers_ready() 145 | 146 | data_loader = DataLoader( 147 | StreamingDataset(buffer_size=10), timeout_ms=1000 * 30 148 | ) 149 | data_generator = iter(data_loader) 150 | persia_training_batch = next(data_generator) 151 | ( 152 | non_id_type_tensors, 153 | id_type_embedding_tensors, 154 | label_tensors, 155 | ) = _prepare_feature(persia_training_batch, PreprocessMode.EVAL) 156 | 157 | assert_ndarray_base_data( 158 | non_id_type_features, non_id_type_tensors, use_cuda 159 | ) 160 | assert_ndarray_base_data(labels, label_tensors, use_cuda) 161 | # assert_id_type_feature_data(id_type_embedding_tensors, EMBEDDING_CONFIG) 162 | -------------------------------------------------------------------------------- /examples/src/adult-income/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from tqdm import tqdm 7 | from sklearn import metrics 8 | 9 | from persia.ctx import TrainCtx, eval_ctx 10 | from persia.embedding.optim import Adagrad 11 | from persia.embedding.data import PersiaBatch 12 | from persia.env import get_rank, get_local_rank, get_world_size 13 | from 
class TestDataset(IterableDataset):
    """Iterable dataset that replays the held-out test split as PersiaBatches.

    Wraps ``make_dataloader`` over the npz test file and yields one
    non-gradient ``PersiaBatch`` per mini-batch.
    """

    def __init__(self, test_dir: str, batch_size: int = 128):
        super(TestDataset, self).__init__(buffer_size=10)
        self.loader = make_dataloader(test_dir, batch_size)

    def __iter__(self):
        logger.info("test loader start to generating data...")
        for non_id_type_feature, id_type_features, label in self.loader:
            yield PersiaBatch(
                id_type_features,
                non_id_type_features=[non_id_type_feature],
                labels=[label],
                requires_grad=False,
            )


def test(
    model: torch.nn.Module,
    data_loader: DataLoader,
    cuda: bool,
    loss_fn=None,
):
    """Evaluate ``model`` over ``data_loader`` and return the test AUC.

    Arguments:
        model: the dense model under evaluation; restored to train mode on exit.
        data_loader: persia DataLoader yielding evaluation batches.
        cuda: when True, predictions/labels are moved to CPU before metrics.
        loss_fn: loss used only for logging; defaults to
            ``torch.nn.BCELoss(reduction="mean")``. (Previously this read a
            module-level ``loss_fn`` global that only exists when the script
            runs as ``__main__``, so importing and calling ``test`` raised
            NameError.)

    Returns:
        The ROC-AUC over the whole test set as a float.
    """
    if loss_fn is None:
        loss_fn = torch.nn.BCELoss(reduction="mean")

    logger.info("start to test...")
    model.eval()

    with eval_ctx(model=model) as ctx:
        accuracies, losses = [], []
        all_pred, all_labels = [], []
        for batch_data in tqdm(data_loader, desc="test..."):
            (pred, labels) = ctx.forward(batch_data)
            label = labels[0]
            loss = loss_fn(pred, label)
            if cuda:
                pred = pred.cpu()
                label = label.cpu()
            else:
                label = label.clone()  # cpu mode need copy the target data
            all_pred.append(pred.detach().numpy())
            all_labels.append(label.detach().numpy())
            accuracy = (torch.round(pred) == label).sum() / label.shape[0]
            accuracies.append(accuracy)
            losses.append(loss)

        all_pred, all_labels = np.concatenate(all_pred), np.concatenate(all_labels)

        fpr, tpr, _th = metrics.roc_curve(all_labels, all_pred)
        test_auc = metrics.auc(fpr, tpr)

        test_accuracies = torch.mean(torch.tensor(accuracies))
        test_loss = torch.mean(torch.tensor(losses))
        logger.info(
            f"test auc is {test_auc} accuracy is {test_accuracies}, loss is {test_loss}"
        )

    model.train()

    return test_auc
label.shape[0] 137 | 138 | logger.info( 139 | f"current batch idx: {batch_idx} loss: {float(loss)} scaled_loss: {float(scaled_loss)} accuracy: {float(accuracy)}" 140 | ) 141 | if batch_idx % test_interval == 0 and batch_idx != 0: 142 | test_auc = test(model, test_loader, use_cuda) 143 | 144 | checkpoint_dir = os.environ.get("PERSIA_CKPT_DIR", None) 145 | if checkpoint_dir is not None and rank == 0: 146 | logger.info(f"dump checkpoint to {checkpoint_dir}") 147 | ctx.dump_checkpoint(checkpoint_dir, with_jit_model=True) 148 | 149 | if world_size == 1 and reproducible and embedding_staleness == 1: 150 | np.testing.assert_equal( 151 | np.array([test_auc]), 152 | np.array([GPU_TEST_AUC if use_cuda else CPU_TEST_AUC]), 153 | ) 154 | 155 | break 156 | -------------------------------------------------------------------------------- /rust/persia-embedding-holder/src/emb_entry.rs: -------------------------------------------------------------------------------- 1 | use persia_libs::{ 2 | ndarray::Array1, 3 | ndarray_rand::rand_distr::{Gamma, Normal, Poisson, Uniform}, 4 | ndarray_rand::RandomExt, 5 | rand::prelude::SmallRng, 6 | rand::SeedableRng, 7 | serde::{self, Deserialize, Serialize}, 8 | }; 9 | 10 | use persia_embedding_config::InitializationMethod; 11 | use persia_speedy::{Readable, Writable}; 12 | 13 | use crate::eviction_map::EvictionMapValue; 14 | 15 | #[derive(Serialize, Deserialize, Readable, Writable, Clone, Debug)] 16 | #[serde(crate = "self::serde")] 17 | pub struct HashMapEmbeddingEntry { 18 | inner: Vec, // TODO option1: consider using smallvec and slab allocator, and reference that smallvec with &[f32] here to avoid const generics 19 | // TODO option2: consider wrap BufferPool (see crates.io) or modify sharded slab to allocate &[f32] here 20 | // TODO option3: consider using a object pool of &[f32] with predefined length and all these &[f32] comes from a large continuous Vec. 
When the object pool is exhausted, create a new large continuous Vec and split it to &[f32]s and add them to the object pool 21 | // TODO option4: allocate slices and put them in the slice_arena (see crates.io), then put the slice in the arena into a reusable object pool for consumption 22 | // TODO option5: allocate slices in bumpalo_herd allocator with alloc_slice_fill_default, and unsafely converts it to Vec, then put the Vec in a reusable object pool for consumption. In this case we can actually put the whole entry in the pool 23 | embedding_dim: usize, 24 | sign: u64, 25 | } 26 | 27 | impl HashMapEmbeddingEntry { 28 | pub fn new( 29 | initialization_method: &InitializationMethod, 30 | dim: usize, 31 | require_space: usize, 32 | seed: u64, 33 | sign: u64, 34 | ) -> Self { 35 | let emb = { 36 | let mut rng = SmallRng::seed_from_u64(seed); 37 | match initialization_method { 38 | InitializationMethod::BoundedUniform(x) => { 39 | Array1::random_using((dim,), Uniform::new(x.lower, x.upper), &mut rng) 40 | } 41 | InitializationMethod::BoundedGamma(x) => { 42 | Array1::random_using((dim,), Gamma::new(x.shape, x.scale).unwrap(), &mut rng) 43 | } 44 | InitializationMethod::BoundedPoisson(x) => { 45 | Array1::random_using((dim,), Poisson::new(x.lambda).unwrap(), &mut rng) 46 | } 47 | InitializationMethod::BoundedNormal(x) => Array1::random_using( 48 | (dim,), 49 | Normal::new(x.mean, x.standard_deviation).unwrap(), 50 | &mut rng, 51 | ), 52 | _ => panic!( 53 | "unsupported initialization method for hashmap impl: {:?}", 54 | initialization_method 55 | ), 56 | } 57 | }; 58 | 59 | let mut inner = emb.into_raw_vec(); 60 | if require_space > 0 { 61 | inner.resize(inner.len() + require_space, 0.0_f32); 62 | } 63 | Self { 64 | inner, 65 | embedding_dim: dim, 66 | sign, 67 | } 68 | } 69 | 70 | pub fn new_empty(dim: usize, require_space: usize, sign: u64) -> Self { 71 | Self { 72 | inner: vec![0f32; dim + require_space], 73 | embedding_dim: dim, 74 | sign, 75 | } 76 | } 77 | 78 | 
pub fn from_emb(emb: Vec, sign: u64) -> Self { 79 | let embedding_dim = emb.len(); 80 | Self { 81 | inner: emb, 82 | embedding_dim, 83 | sign, 84 | } 85 | } 86 | 87 | pub fn from_emb_and_opt(emb: Vec, opt: &[f32], sign: u64) -> Self { 88 | let embedding_dim = emb.len(); 89 | let mut inner = emb; 90 | inner.extend_from_slice(opt); 91 | Self { 92 | inner, 93 | embedding_dim, 94 | sign, 95 | } 96 | } 97 | 98 | pub fn copy_from_other(&mut self, other: &Self) -> bool { 99 | if self.embedding_dim() != other.embedding_dim() { 100 | return false; 101 | } 102 | for (dst, src) in self.inner.iter_mut().zip(other.inner.iter()) { 103 | *dst = *src; 104 | } 105 | return true; 106 | } 107 | 108 | pub fn as_mut_emb_entry_slice(&mut self) -> &mut [f32] { 109 | self.inner.as_mut_slice() 110 | } 111 | 112 | pub fn as_emb_entry_slice(&self) -> &[f32] { 113 | self.inner.as_slice() 114 | } 115 | 116 | pub fn inner_size(&self) -> usize { 117 | self.inner.len() 118 | } 119 | 120 | pub fn dim(&self) -> usize { 121 | self.embedding_dim 122 | } 123 | 124 | pub fn embedding_dim(&self) -> usize { 125 | self.embedding_dim 126 | } 127 | 128 | pub fn emb(&self) -> &[f32] { 129 | &self.inner[..self.embedding_dim()] 130 | } 131 | 132 | pub fn emb_mut(&mut self) -> &mut [f32] { 133 | let dim = self.embedding_dim(); 134 | &mut self.inner[..dim] 135 | } 136 | 137 | pub fn boxed(self) -> Box { 138 | Box::new(self) 139 | } 140 | 141 | pub fn opt(&self) -> &[f32] { 142 | &self.inner[self.embedding_dim()..] 143 | } 144 | 145 | pub fn opt_mut(&mut self) -> &mut [f32] { 146 | let dim = self.embedding_dim(); 147 | &mut self.inner[dim..] 
148 | } 149 | 150 | pub fn emb_and_opt_mut(&mut self) -> (&mut [f32], &mut [f32]) { 151 | let dim = self.embedding_dim(); 152 | self.inner.split_at_mut(dim) 153 | } 154 | 155 | pub fn sign(&self) -> u64 { 156 | self.sign 157 | } 158 | } 159 | 160 | impl EvictionMapValue for HashMapEmbeddingEntry { 161 | fn hashmap_key(&self) -> u64 { 162 | self.sign 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /rust/persia-common/src/message_queue.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::ChannelPair; 2 | 3 | use std::ops::Add; 4 | use std::sync::Arc; 5 | 6 | use persia_libs::{ 7 | hyper::{self, Body, Request, Response}, 8 | thiserror, tokio, tracing, url, 9 | }; 10 | 11 | #[derive(thiserror::Error, Debug)] 12 | pub enum PersiaMessageQueueError { 13 | #[error("send error")] 14 | SendError, 15 | #[error("recv error")] 16 | RecvError, 17 | #[error("hyper error")] 18 | HyperError(#[from] hyper::Error), 19 | } 20 | 21 | #[derive(Clone)] 22 | pub struct PersiaMessageQueueClientImpl { 23 | client: hyper::Client, 24 | server_addr: url::Url, 25 | } 26 | 27 | fn expect_uri(url: url::Url) -> hyper::Uri { 28 | url.as_str() 29 | .parse() 30 | .expect("a parsed Url should always be a valid Uri") 31 | } 32 | 33 | impl PersiaMessageQueueClientImpl { 34 | pub fn new(server_addr: &str) -> Self { 35 | let server_addr = url::Url::parse("http://".to_string().add(server_addr).as_str()).unwrap(); 36 | Self { 37 | client: hyper::Client::builder() 38 | .http2_only(true) 39 | .retry_canceled_requests(true) 40 | .set_host(false) 41 | .http2_adaptive_window(true) 42 | .build_http(), 43 | server_addr, 44 | } 45 | } 46 | 47 | pub async fn send(&self, content: Vec) -> Result<(), PersiaMessageQueueError> { 48 | let req = hyper::Request::builder() 49 | .method("POST") 50 | .uri(expect_uri(self.server_addr.join("send").unwrap())) 51 | .body(hyper::Body::from(content)) 52 | .expect("request 
builder"); 53 | let resp = self.client.request(req).await?; 54 | if resp.status().is_success() { 55 | Ok(()) 56 | } else { 57 | Err(PersiaMessageQueueError::SendError) 58 | } 59 | } 60 | 61 | pub async fn recv(&self) -> Result, PersiaMessageQueueError> { 62 | let req = hyper::Request::builder() 63 | .method("POST") 64 | .uri(expect_uri(self.server_addr.join("recv").unwrap())) 65 | .body(hyper::Body::empty()) 66 | .expect("request builder"); 67 | let resp = self.client.request(req).await?; 68 | if resp.status().is_success() { 69 | Ok(hyper::body::to_bytes(resp.into_body()).await?.to_vec()) 70 | } else { 71 | Err(PersiaMessageQueueError::SendError) 72 | } 73 | } 74 | } 75 | 76 | #[derive(Clone)] 77 | pub struct PersiaMessageQueueService { 78 | message_queue: ChannelPair, 79 | } 80 | 81 | #[derive(Clone)] 82 | pub struct PersiaMessageQueueServerImpl { 83 | message_queue: ChannelPair, 84 | server_handler: Arc>>, 85 | } 86 | 87 | impl PersiaMessageQueueServerImpl { 88 | pub fn new(port: u16, cap: usize) -> PersiaMessageQueueServerImpl { 89 | let message_queue = ChannelPair::new(cap); 90 | let service = PersiaMessageQueueService { 91 | message_queue: message_queue.clone(), 92 | }; 93 | 94 | let server = hyper::Server::bind(&([0, 0, 0, 0], port).into()) 95 | .http2_only(true) 96 | .http2_adaptive_window(true) 97 | .tcp_nodelay(true) 98 | .serve(hyper::service::make_service_fn(move |_| { 99 | let service = service.clone(); 100 | async move { 101 | Ok::<_, hyper::Error>(hyper::service::service_fn(move |req: Request| { 102 | let service = service.clone(); 103 | async move { 104 | match req.uri().path() { 105 | "/send" => { 106 | let body: hyper::body::Bytes = 107 | hyper::body::to_bytes(req.into_body()).await?; 108 | service.message_queue.sender.send_async(body).await.unwrap(); 109 | Ok::<_, hyper::Error>(Response::new(hyper::body::Body::empty())) 110 | } 111 | "/recv" => { 112 | let body = 113 | service.message_queue.receiver.recv_async().await.unwrap(); 114 | Ok::<_, 
hyper::Error>(Response::new(Body::from(body))) 115 | } 116 | _ => { 117 | tracing::error!("unsupported uri for persia message queue"); 118 | let mut resp = Response::default(); 119 | *resp.status_mut() = hyper::http::StatusCode::BAD_REQUEST; 120 | Ok(resp) 121 | } 122 | } 123 | } 124 | })) 125 | } 126 | })); 127 | 128 | let server_handler = Arc::new(tokio::task::spawn(async move { server.await })); 129 | 130 | Self { 131 | server_handler, 132 | message_queue, 133 | } 134 | } 135 | 136 | pub async fn send(&self, content: Vec) { 137 | self.message_queue 138 | .sender 139 | .send_async(hyper::body::Bytes::from(content)) 140 | .await 141 | .unwrap() 142 | } 143 | 144 | pub async fn recv(&self) -> Vec { 145 | self.message_queue 146 | .receiver 147 | .recv_async() 148 | .await 149 | .unwrap() 150 | .to_vec() 151 | } 152 | 153 | pub async fn handler(&self) -> Arc>> { 154 | self.server_handler.clone() 155 | } 156 | } 157 | --------------------------------------------------------------------------------